Skip to content

Commit dad0ef5

Browse files
feat: Implement advanced query filtering with Where clauses and introduce Record data types for items.
1 parent 7535e14 commit dad0ef5

26 files changed

Lines changed: 1458 additions & 1029 deletions

README.md

Lines changed: 334 additions & 493 deletions
Large diffs are not rendered by default.

src/Api.php

Lines changed: 18 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,8 @@ public function __construct(
3939
public readonly StreamFactoryInterface $streamFactory,
4040
public readonly string $baseUri,
4141
public readonly array $headers = [],
42-
) {}
42+
) {
43+
}
4344

4445
/**
4546
* Retrieves the current user's identity, tenant, and databases.
@@ -62,7 +63,7 @@ public function getCollectionByCrn(string $crn, string $database, string $tenant
6263
{
6364
$response = $this->sendRequest('GET', "/api/v2/collections/{$crn}");
6465

65-
return Collection::make(json_decode($response->getBody()->getContents(), true), $this, $database, $tenant);
66+
return Collection::fromArray(json_decode($response->getBody()->getContents(), true), $this, $database, $tenant);
6667
}
6768

6869
/**
@@ -134,7 +135,7 @@ public function getTenant(string $tenant): ?Tenant
134135

135136
$result = json_decode($response->getBody()->getContents(), true);
136137

137-
return Tenant::make($result);
138+
return Tenant::fromArray($result);
138139
}
139140

140141
/**
@@ -180,7 +181,7 @@ public function listDatabases(string $tenant, ?int $limit = null, ?int $offset =
180181

181182
$result = json_decode($response->getBody()->getContents(), true);
182183

183-
return array_map(fn(array $item) => Database::make($item), $result);
184+
return array_map(fn(array $item) => Database::fromArray($item), $result);
184185
}
185186

186187
/**
@@ -197,7 +198,7 @@ public function getDatabase(string $database, string $tenant): Database
197198

198199
$result = json_decode($response->getBody()->getContents(), true);
199200

200-
return Database::make($result);
201+
return Database::fromArray($result);
201202
}
202203

203204
/**
@@ -232,7 +233,7 @@ public function listCollections(string $database, string $tenant, ?int $limit =
232233

233234
$result = json_decode($response->getBody()->getContents(), true);
234235

235-
return array_map(fn(array $item) => Collection::make($item, $this, $database, $tenant), $result);
236+
return array_map(fn(array $item) => Collection::fromArray($item, $this, $database, $tenant), $result);
236237
}
237238

238239
/**
@@ -252,7 +253,7 @@ public function createCollection(string $database, string $tenant, CreateCollect
252253

253254
$result = json_decode($response->getBody()->getContents(), true);
254255

255-
return Collection::make($result, $this, $database, $tenant);
256+
return Collection::fromArray($result, $this, $database, $tenant);
256257
}
257258

258259
/**
@@ -270,7 +271,7 @@ public function getCollection(string $collectionId, string $database, string $te
270271

271272
$result = json_decode($response->getBody()->getContents(), true);
272273

273-
return Collection::make($result, $this, $database, $tenant);
274+
return Collection::fromArray($result, $this, $database, $tenant);
274275
}
275276

276277
/**
@@ -395,7 +396,7 @@ public function getCollectionItems(string $collectionId, string $database, strin
395396

396397
$result = json_decode($response->getBody()->getContents(), true);
397398

398-
return GetItemsResponse::from($result);
399+
return GetItemsResponse::fromArray($result);
399400
}
400401

401402
/**
@@ -431,7 +432,7 @@ public function queryCollectionItems(string $collectionId, string $database, str
431432

432433
$result = json_decode($response->getBody()->getContents(), true);
433434

434-
return QueryItemsResponse::from($result);
435+
return QueryItemsResponse::fromArray($result);
435436
}
436437

437438

@@ -454,11 +455,13 @@ private function handleErrorResponse(ResponseInterface $response): void
454455
if ($error !== null) {
455456

456457
// If the structure is 'error' => 'NotFoundError("Collection not found")'
457-
if (preg_match(
458-
'/^(?P<error_type>\w+)\((?P<message>.*)\)$/',
459-
$error['error'] ?? '',
460-
$matches
461-
)) {
458+
if (
459+
preg_match(
460+
'/^(?P<error_type>\w+)\((?P<message>.*)\)$/',
461+
$error['error'] ?? '',
462+
$matches
463+
)
464+
) {
462465
if (isset($matches['message'])) {
463466
$error_type = $matches['error_type'] ?? 'UnknownError';
464467
$message = $matches['message'];

src/Models/Collection.php

Lines changed: 97 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,8 @@
1414
use Codewithkyrian\ChromaDB\Requests\UpdateItemsRequest;
1515
use Codewithkyrian\ChromaDB\Responses\GetItemsResponse;
1616
use Codewithkyrian\ChromaDB\Responses\QueryItemsResponse;
17+
use Codewithkyrian\ChromaDB\Types\Includes;
18+
use Codewithkyrian\ChromaDB\Types\Record;
1719

1820
class Collection
1921
{
@@ -27,16 +29,17 @@ class Collection
2729
* @param EmbeddingFunction|null $embeddingFunction Optional embedding function. Must match the one used to create the collection.
2830
*/
2931
public function __construct(
30-
public readonly Api $api,
31-
public readonly string $name,
32-
public readonly string $id,
33-
public readonly ?array $metadata = null,
34-
public readonly ?string $database = null,
35-
public readonly ?string $tenant = null,
32+
public readonly Api $api,
33+
public readonly string $name,
34+
public readonly string $id,
35+
public readonly ?array $metadata = null,
36+
public readonly ?string $database = null,
37+
public readonly ?string $tenant = null,
3638
public ?EmbeddingFunction $embeddingFunction = null,
37-
) {}
39+
) {
40+
}
3841

39-
public static function make(array $data, Api $api, string $database, string $tenant): self
42+
public static function fromArray(array $data, Api $api, string $database, string $tenant): self
4043
{
4144
return new self(
4245
api: $api,
@@ -60,20 +63,37 @@ public function toArray(): array
6063
/**
6164
* Add items to the collection.
6265
*
63-
* @param string[] $ids The IDs of the items to add.
66+
* @param string[]|Record[] $ids The IDs of the items to add, or an array of Record objects.
6467
* @param number[][]|null $embeddings The embeddings of the items to add (optional).
6568
* @param array<string, array<string, mixed>>|null $metadatas The metadatas of the items to add (optional).
6669
* @param string[]|null $documents The documents of the items to add (optional).
6770
* @param string[]|null $images The base64 encoded images of the items to add (optional).
6871
* @return void
6972
*/
7073
public function add(
71-
array $ids,
74+
array $ids,
7275
?array $embeddings = null,
7376
?array $metadatas = null,
7477
?array $documents = null,
7578
?array $images = null
7679
): void {
80+
if (!empty($ids) && $ids[0] instanceof Record) {
81+
$records = $ids;
82+
$ids = [];
83+
$embeddings = [];
84+
$metadatas = [];
85+
$documents = [];
86+
$images = [];
87+
88+
foreach ($records as $record) {
89+
$ids[] = $record->id;
90+
$embeddings[] = $record->embedding;
91+
$metadatas[] = $record->metadata;
92+
$documents[] = $record->document;
93+
$images[] = $record->image;
94+
}
95+
}
96+
7797
$validated = $this->validate(
7898
ids: $ids,
7999
embeddings: $embeddings,
@@ -97,20 +117,37 @@ public function add(
97117
/**
98118
* Update the embeddings, documents, and/or metadatas of existing items.
99119
*
100-
* @param string[] $ids The IDs of the items to update.
120+
* @param string[]|Record[] $ids The IDs of the items to update, or an array of Record objects.
101121
* @param number[][]|null $embeddings The embeddings of the items to update (optional).
102122
* @param array<string, array<string, mixed>>|null $metadatas The metadatas of the items to update (optional).
103123
* @param string[]|null $documents The documents of the items to update (optional).
104124
* @param string[]|null $images The base64 encoded images of the items to update (optional).
105125
*
106126
*/
107127
public function update(
108-
array $ids,
128+
array $ids,
109129
?array $embeddings = null,
110130
?array $metadatas = null,
111131
?array $documents = null,
112132
?array $images = null
113133
) {
134+
if (!empty($ids) && $ids[0] instanceof Record) {
135+
$records = $ids;
136+
$ids = [];
137+
$embeddings = [];
138+
$metadatas = [];
139+
$documents = [];
140+
$images = [];
141+
142+
foreach ($records as $record) {
143+
$ids[] = $record->id;
144+
$embeddings[] = $record->embedding;
145+
$metadatas[] = $record->metadata;
146+
$documents[] = $record->document;
147+
$images[] = $record->image;
148+
}
149+
}
150+
114151
$validated = $this->validate(
115152
ids: $ids,
116153
embeddings: $embeddings,
@@ -134,20 +171,37 @@ public function update(
134171
/**
135172
* Upsert items in the collection.
136173
*
137-
* @param string[] $ids The IDs of the items to upsert.
174+
* @param string[]|Record[] $ids The IDs of the items to upsert, or an array of Record objects.
138175
* @param number[][]|null $embeddings The embeddings of the items to upsert (optional).
139176
* @param array<string, array<string, mixed>>|null $metadatas The metadatas of the items to upsert (optional).
140177
* @param string[]|null $documents The documents of the items to upsert (optional).
141178
* @param string[]|null $images The base64 encoded images of the items to upsert (optional).
142179
*
143180
*/
144181
public function upsert(
145-
array $ids,
182+
array $ids,
146183
?array $embeddings = null,
147184
?array $metadatas = null,
148185
?array $documents = null,
149186
?array $images = null
150187
): void {
188+
if (!empty($ids) && $ids[0] instanceof Record) {
189+
$records = $ids;
190+
$ids = [];
191+
$embeddings = [];
192+
$metadatas = [];
193+
$documents = [];
194+
$images = [];
195+
196+
foreach ($records as $record) {
197+
$ids[] = $record->id;
198+
$embeddings[] = $record->embedding;
199+
$metadatas[] = $record->metadata;
200+
$documents[] = $record->document;
201+
$images[] = $record->image;
202+
}
203+
}
204+
151205
$validated = $this->validate(
152206
ids: $ids,
153207
embeddings: $embeddings,
@@ -179,23 +233,25 @@ public function count(): int
179233
/**
180234
* Get items from the collection.
181235
*
182-
* @param array $ids The IDs of the items to get (optional).
183-
* @param array $where The where clause to filter items by (optional).
184-
* @param array $whereDocument The where clause to filter items by (optional).
185-
* @param int $limit The limit on the number of items to get (optional).
186-
* @param int $offset The offset on the number of items to get (optional).
187-
* @param string[] $include The list of fields to include in the response (optional).
236+
* @param array|null $ids The IDs of the items to get (optional).
237+
* @param array|null $where The where clause to filter items by (optional).
238+
* @param array|null $whereDocument The where clause to filter items by (optional).
239+
* @param int|null $limit The limit on the number of items to get (optional).
240+
* @param int|null $offset The offset on the number of items to get (optional).
241+
* @param string[]|Includes[]|null $include The list of fields to include in the response (optional).
188242
*/
189243
public function get(
190244
?array $ids = null,
191245
?array $where = null,
192246
?array $whereDocument = null,
193-
?int $limit = null,
194-
?int $offset = null,
247+
?int $limit = null,
248+
?int $offset = null,
195249
?array $include = null
196250
): GetItemsResponse {
197251
$include ??= ['embeddings', 'metadatas', 'distances'];
198252

253+
$include = array_map(fn($i) => $i instanceof Includes ? $i->value : $i, $include);
254+
199255
$request = new GetEmbeddingRequest(
200256
ids: $ids,
201257
where: $where,
@@ -212,16 +268,19 @@ public function get(
212268
* Retrieves a preview of records from the collection.
213269
*
214270
* @param int $limit The number of entries to return. Defaults to 10.
215-
* @param string[] $include The list of fields to include in the response (optional).
216-
*/
217-
public function peek(int $limit = 10, ?array $include = null): GetItemsResponse {
271+
* @param string[]|Includes[]|null $include The list of fields to include in the response (optional).
272+
*/
273+
public function peek(int $limit = 10, ?array $include = null): GetItemsResponse
274+
{
218275
$include ??= ['embeddings', 'metadatas', 'distances'];
219-
276+
277+
$include = array_map(fn($i) => $i instanceof Includes ? $i->value : $i, $include);
278+
220279
$request = new GetEmbeddingRequest(
221280
limit: $limit,
222281
include: $include,
223282
);
224-
283+
225284
return $this->api->getCollectionItems($this->id, $this->database, $this->tenant, $request);
226285
}
227286

@@ -252,21 +311,23 @@ public function delete(?array $ids = null, ?array $where = null, ?array $whereDo
252311
* @param int $nResults The number of results to return (optional).
253312
* @param ?array $where The where clause to filter items to search based on metadata values (optional).
254313
* @param ?array $whereDocument The where clause to filter to search based on document content (optional).
255-
* @param ?array $include The list of fields to include in the response (optional).
314+
* @param string[]|Includes[]|null $include The list of fields to include in the response (optional).
256315
*/
257316
public function query(
258317
?array $queryEmbeddings = null,
259318
?array $queryTexts = null,
260319
?array $queryImages = null,
261-
int $nResults = 10,
320+
int $nResults = 10,
262321
?array $where = null,
263322
?array $whereDocument = null,
264323
?array $include = null
265324
): QueryItemsResponse {
266325
$include ??= ['embeddings', 'metadatas', 'distances'];
267326

327+
$include = array_map(fn($i) => $i instanceof Includes ? $i->value : $i, $include);
328+
268329
if (
269-
!(($queryEmbeddings != null xor $queryTexts != null xor $queryImages != null))
330+
!(($queryEmbeddings != null xor $queryTexts != null xor $queryImages != null))
270331
) {
271332
throw new \InvalidArgumentException(
272333
'You must provide only one of queryEmbeddings, queryTexts, queryImages, or queryUris'
@@ -290,6 +351,8 @@ public function query(
290351
);
291352
}
292353
} else {
354+
355+
293356
$finalEmbeddings = $queryEmbeddings;
294357
}
295358

@@ -326,13 +389,13 @@ public function setEmbeddingFunction(EmbeddingFunction $embeddingFunction): void
326389
* @return array{ids: string[], embeddings: int[][], metadatas: array[], documents: string[], images: string[], uris: string[]}
327390
*/
328391
protected
329-
function validate(
330-
array $ids,
392+
function validate(
393+
array $ids,
331394
?array $embeddings,
332395
?array $metadatas,
333396
?array $documents,
334397
?array $images,
335-
bool $requireEmbeddingsOrDocuments
398+
bool $requireEmbeddingsOrDocuments
336399
): array {
337400

338401
if ($requireEmbeddingsOrDocuments) {
@@ -373,7 +436,7 @@ function validate(
373436
}
374437

375438
$ids = array_map(function ($id) {
376-
$id = (string)$id;
439+
$id = (string) $id;
377440
if ($id === '') {
378441
throw new \InvalidArgumentException('Expected IDs to be non-empty strings');
379442
}

src/Models/Database.php

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,9 +22,10 @@ public function __construct(
2222
* Tenant of the database.
2323
*/
2424
public readonly ?string $tenant,
25-
) {}
25+
) {
26+
}
2627

27-
public static function make(array $data): self
28+
public static function fromArray(array $data): self
2829
{
2930
return new self(
3031
id: $data['id'],

0 commit comments

Comments
 (0)