Skip to content

Commit 9371784

Browse files
committed
tests: added cosine and dot_product HNSW Index to Model #30
1 parent a0b4065 commit 9371784

4 files changed

Lines changed: 72 additions & 51 deletions

File tree

tests/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ def create_test_objectbox(db_name: Optional[str] = None, clear_db: bool = True)
5252
model.entity(TestEntity, last_property_id=IdUid(27, 1027))
5353
model.entity(TestEntityDatetime, last_property_id=IdUid(4, 2004))
5454
model.entity(TestEntityFlex, last_property_id=IdUid(2, 3002))
55-
model.entity(VectorEntity, last_property_id=IdUid(3, 4003))
55+
model.entity(VectorEntity, last_property_id=IdUid(5, 4005))
5656
model.last_entity_id = IdUid(4, 4)
57-
model.last_index_id = IdUid(3, 40001)
57+
model.last_index_id = IdUid(5, 40003)
5858

5959
return objectbox.Builder().model(model).directory(db_path).build()
6060

tests/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class TestEntityFlex:
5757
class VectorEntity:
5858
id = Id(id=1, uid=4001)
5959
name = Property(str, type=PropertyType.string, id=2, uid=4002)
60-
vector = Property(np.ndarray, type=PropertyType.floatVector, id=3, uid=4003,
60+
vector_euclidean = Property(np.ndarray, type=PropertyType.floatVector, id=3, uid=4003,
6161
index=HnswIndex(
6262
id=3, uid=40001,
6363
dimensions=2, distance_type=HnswDistanceType.EUCLIDEAN)

tests/test_hnsw.py

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@ def _find_expected_nn(points: np.ndarray, query: np.ndarray, n: int):
1515
return np.argsort(d)[:n]
1616

1717

18-
def _test_random_points(num_points: int, num_query_points: int, seed: Optional[int] = None):
18+
def _test_random_points(num_points: int, num_query_points: int, seed: Optional[int] = None, distance_type: HnswDistanceType = HnswDistanceType.EUCLIDEAN, min_score: float = 0.5):
1919
""" Generates random points in a 2d plane; checks the queried NN against the expected. """
2020

21+
vector_field_name = "vector_"+distance_type.name.lower()
22+
2123
print(f"Test random points; Points: {num_points}, Query points: {num_query_points}, Seed: {seed}")
2224

2325
k = 10
@@ -37,7 +39,7 @@ def _test_random_points(num_points: int, num_query_points: int, seed: Optional[i
3739
for i in range(points.shape[0]):
3840
object_ = VectorEntity()
3941
object_.name = f"point_{i}"
40-
object_.vector = points[i]
42+
setattr(object_, vector_field_name, points[i])
4143
objects.append(object_)
4244
box.put(*objects)
4345
print(f"DB seeded with {box.count()} random points!")
@@ -58,50 +60,62 @@ def _test_random_points(num_points: int, num_query_points: int, seed: Optional[i
5860

5961
# Run ANN with OBX
6062
qb = box.query()
61-
qb.nearest_neighbors_f32("vector", query_point, k)
63+
qb.nearest_neighbors_f32(vector_field_name, query_point, k)
6264
query = qb.build()
6365
obx_result = [id_ for id_, score in query.find_ids_with_scores()] # Ignore score
6466
assert len(obx_result) == k
6567

6668
# We would like at least half of the expected results to be returned by the search (in any order)
6769
# Remember: it's an approximate search!
6870
search_score = len(np.intersect1d(expected_result, obx_result)) / k
69-
assert search_score >= 0.5 # TODO likely could be increased
71+
assert search_score >= min_score # TODO likely could be increased
7072

7173
print(f"Done!")
7274

7375

7476
def test_random_points():
75-
_test_random_points(num_points=100, num_query_points=10, seed=10)
76-
_test_random_points(num_points=100, num_query_points=10, seed=11)
77-
_test_random_points(num_points=100, num_query_points=10, seed=12)
78-
_test_random_points(num_points=100, num_query_points=10, seed=13)
79-
_test_random_points(num_points=100, num_query_points=10, seed=14)
80-
_test_random_points(num_points=100, num_query_points=10, seed=15)
81-
82-
83-
def test_combined_nn_search():
84-
""" Tests NN search combined with regular query conditions, offset and limit. """
85-
77+
78+
min_score = 0.5
79+
distance_type = HnswDistanceType.EUCLIDEAN
80+
_test_random_points(num_points=100, num_query_points=10, seed=10, distance_type=distance_type, min_score=min_score)
81+
_test_random_points(num_points=100, num_query_points=10, seed=11, distance_type=distance_type, min_score=min_score)
82+
_test_random_points(num_points=100, num_query_points=10, seed=12, distance_type=distance_type, min_score=min_score)
83+
_test_random_points(num_points=100, num_query_points=10, seed=13, distance_type=distance_type, min_score=min_score)
84+
_test_random_points(num_points=100, num_query_points=10, seed=14, distance_type=distance_type, min_score=min_score)
85+
_test_random_points(num_points=100, num_query_points=10, seed=15, distance_type=distance_type, min_score=min_score)
86+
87+
# TODO: Cosine and Dot Product may result in 0 score
88+
89+
def _test_combined_nn_search(distance_type: HnswDistanceType = HnswDistanceType.EUCLIDEAN):
90+
8691
db = create_test_objectbox()
8792

8893
box = objectbox.Box(db, VectorEntity)
8994

90-
box.put(VectorEntity(name="Power of red", vector=[1, 1]))
91-
box.put(VectorEntity(name="Blueberry", vector=[2, 2]))
92-
box.put(VectorEntity(name="Red", vector=[3, 3]))
93-
box.put(VectorEntity(name="Blue sea", vector=[4, 4]))
94-
box.put(VectorEntity(name="Lightblue", vector=[5, 5]))
95-
box.put(VectorEntity(name="Red apple", vector=[6, 6]))
96-
box.put(VectorEntity(name="Hundred", vector=[7, 7]))
97-
box.put(VectorEntity(name="Tired", vector=[8, 8]))
98-
box.put(VectorEntity(name="Power of blue", vector=[9, 9]))
99-
95+
vector_field_name = "vector_"+distance_type.name.lower()
96+
97+
values = [
98+
("Power of red", [1, 1]),
99+
("Blueberry", [2, 2]),
100+
("Red", [3, 3]),
101+
("Blue sea", [4, 4]),
102+
("Lightblue", [5, 5]),
103+
("Red apple", [6, 6]),
104+
("Hundred", [7, 7]),
105+
("Tired", [8, 8]),
106+
("Power of blue", [9, 9])
107+
]
108+
for value in values:
109+
entity = VectorEntity()
110+
setattr(entity, "name", value[0])
111+
setattr(entity, vector_field_name, value[1])
112+
box.put(entity)
113+
100114
assert box.count() == 9
101115

102116
# Test condition + NN search
103117
qb = box.query()
104-
qb.nearest_neighbors_f32("vector", [4.1, 4.2], 6)
118+
qb.nearest_neighbors_f32(vector_field_name, [4.1, 4.2], 6)
105119
qb.contains_string("name", "red", case_sensitive=False)
106120
query = qb.build()
107121
# 4, 5, 3, 6, 2, 7
@@ -121,7 +135,7 @@ def test_combined_nn_search():
121135

122136
# Regular condition + NN search
123137
qb = box.query()
124-
qb.nearest_neighbors_f32("vector", [9.2, 8.9], 7)
138+
qb.nearest_neighbors_f32(vector_field_name, [9.2, 8.9], 7)
125139
qb.starts_with_string("name", "Blue", case_sensitive=True)
126140
query = qb.build()
127141

@@ -131,7 +145,7 @@ def test_combined_nn_search():
131145

132146
# Regular condition + NN search
133147
qb = box.query()
134-
qb.nearest_neighbors_f32("vector", [7.7, 7.7], 8)
148+
qb.nearest_neighbors_f32(vector_field_name, [7.7, 7.7], 8)
135149
qb.contains_string("name", "blue", case_sensitive=False)
136150
query = qb.build()
137151
# 8, 7, 9, 6, 5, 4, 3, 2
@@ -157,3 +171,10 @@ def test_combined_nn_search():
157171
assert len(search_results) == 2
158172
assert search_results[0] == 4
159173
assert search_results[1] == 5
174+
175+
176+
def test_combined_nn_search():
177+
""" Tests NN search combined with regular query conditions, offset and limit. """
178+
distance_type = HnswDistanceType.EUCLIDEAN
179+
_test_combined_nn_search(distance_type)
180+
# TODO: Cosine and DotProduct diverge; see below

tests/test_query.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ def test_basics():
1616
box_test_entity.put(TestEntity(str="bar", int64=456))
1717

1818
box_vector_entity = objectbox.Box(ob, VectorEntity)
19-
box_vector_entity.put(VectorEntity(name="Object 1", vector=[1, 1]))
20-
box_vector_entity.put(VectorEntity(name="Object 2", vector=[2, 2]))
21-
box_vector_entity.put(VectorEntity(name="Object 3", vector=[3, 3]))
19+
box_vector_entity.put(VectorEntity(name="Object 1", vector_euclidean=[1, 1]))
20+
box_vector_entity.put(VectorEntity(name="Object 2", vector_euclidean=[2, 2]))
21+
box_vector_entity.put(VectorEntity(name="Object 3", vector_euclidean=[3, 3]))
2222

2323
# String query
2424
str_prop: Property = TestEntity.get_property("str")
@@ -98,7 +98,7 @@ def test_basics():
9898
assert query.remove() == 1
9999

100100
# NN query
101-
vector_prop: Property = VectorEntity.get_property("vector")
101+
vector_prop: Property = VectorEntity.get_property("vector_euclidean")
102102

103103
query = box_vector_entity.query(vector_prop.nearest_neighbor([2.1, 2.1], 2)).build()
104104
assert query.count() == 2
@@ -258,11 +258,11 @@ def test_set_parameter():
258258
box_test_entity.put(TestEntity(str="Barrakuda", int64=4, int32=386, int8=60))
259259

260260
box_vector_entity = objectbox.Box(db, VectorEntity)
261-
box_vector_entity.put(VectorEntity(name="Object 1", vector=[1, 1]))
262-
box_vector_entity.put(VectorEntity(name="Object 2", vector=[2, 2]))
263-
box_vector_entity.put(VectorEntity(name="Object 3", vector=[3, 3]))
264-
box_vector_entity.put(VectorEntity(name="Object 4", vector=[4, 4]))
265-
box_vector_entity.put(VectorEntity(name="Object 5", vector=[5, 5]))
261+
box_vector_entity.put(VectorEntity(name="Object 1", vector_euclidean=[1, 1]))
262+
box_vector_entity.put(VectorEntity(name="Object 2", vector_euclidean=[2, 2]))
263+
box_vector_entity.put(VectorEntity(name="Object 3", vector_euclidean=[3, 3]))
264+
box_vector_entity.put(VectorEntity(name="Object 4", vector_euclidean=[4, 4]))
265+
box_vector_entity.put(VectorEntity(name="Object 5", vector_euclidean=[5, 5]))
266266

267267
qb = box_test_entity.query()
268268
qb.starts_with_string("str", "fo", case_sensitive=False)
@@ -280,22 +280,22 @@ def test_set_parameter():
280280
assert query.find_ids() == [3]
281281

282282
qb = box_vector_entity.query()
283-
qb.nearest_neighbors_f32("vector", [3.4, 3.4], 3)
283+
qb.nearest_neighbors_f32("vector_euclidean", [3.4, 3.4], 3)
284284
query = qb.build()
285285
assert query.find_ids() == sorted([3, 4, 2])
286286

287287
# set_parameter_vector_f32
288288
# set_parameter_int (NN count)
289-
query.set_parameter_vector_f32("vector", [4.9, 4.9])
289+
query.set_parameter_vector_f32("vector_euclidean", [4.9, 4.9])
290290
assert query.find_ids() == sorted([5, 4, 3])
291291

292-
query.set_parameter_vector_f32("vector", [0, 0])
292+
query.set_parameter_vector_f32("vector_euclidean", [0, 0])
293293
assert query.find_ids() == sorted([1, 2, 3])
294294

295-
query.set_parameter_vector_f32("vector", [2.5, 2.1])
295+
query.set_parameter_vector_f32("vector_euclidean", [2.5, 2.1])
296296
assert query.find_ids() == sorted([2, 3, 1])
297297

298-
query.set_parameter_int("vector", 2)
298+
query.set_parameter_int("vector_euclidean", 2)
299299
assert query.find_ids() == sorted([2, 3])
300300

301301

@@ -307,11 +307,11 @@ def test_set_parameter_alias():
307307
box.put(TestEntity(str="FooBar", int64=10, int32=49, int8=45))
308308

309309
box_vector = objectbox.Box(db, VectorEntity)
310-
box_vector.put(VectorEntity(name="Object 1", vector=[1, 1]))
311-
box_vector.put(VectorEntity(name="Object 2", vector=[2, 2]))
312-
box_vector.put(VectorEntity(name="Object 3", vector=[3, 3]))
313-
box_vector.put(VectorEntity(name="Object 4", vector=[4, 4]))
314-
box_vector.put(VectorEntity(name="Object 5", vector=[5, 5]))
310+
box_vector.put(VectorEntity(name="Object 1", vector_euclidean=[1, 1]))
311+
box_vector.put(VectorEntity(name="Object 2", vector_euclidean=[2, 2]))
312+
box_vector.put(VectorEntity(name="Object 3", vector_euclidean=[3, 3]))
313+
box_vector.put(VectorEntity(name="Object 4", vector_euclidean=[4, 4]))
314+
box_vector.put(VectorEntity(name="Object 5", vector_euclidean=[5, 5]))
315315

316316
str_prop: Property = TestEntity.get_property("str")
317317
int32_prop: Property = TestEntity.get_property("int32")
@@ -354,7 +354,7 @@ def test_set_parameter_alias():
354354
assert query.find()[0].str == "FooBar"
355355

356356
# Test set parameter alias on vector
357-
vector_prop: Property = VectorEntity.get_property("vector")
357+
vector_prop: Property = VectorEntity.get_property("vector_euclidean")
358358

359359
query = box_vector.query(vector_prop.nearest_neighbor([3.4, 3.4], 3).alias("nearest_neighbour_filter")).build()
360360
assert query.count() == 3

0 commit comments

Comments
 (0)