@@ -15,9 +15,11 @@ def _find_expected_nn(points: np.ndarray, query: np.ndarray, n: int):
1515 return np .argsort (d )[:n ]
1616
1717
18- def _test_random_points (num_points : int , num_query_points : int , seed : Optional [int ] = None ):
18+ def _test_random_points (num_points : int , num_query_points : int , seed : Optional [int ] = None , distance_type : HnswDistanceType = HnswDistanceType . EUCLIDEAN , min_score : float = 0.5 ):
1919 """ Generates random points in a 2d plane; checks the queried NN against the expected. """
2020
21+ vector_field_name = "vector_" + distance_type .name .lower ()
22+
2123 print (f"Test random points; Points: { num_points } , Query points: { num_query_points } , Seed: { seed } " )
2224
2325 k = 10
@@ -37,7 +39,7 @@ def _test_random_points(num_points: int, num_query_points: int, seed: Optional[i
3739 for i in range (points .shape [0 ]):
3840 object_ = VectorEntity ()
3941 object_ .name = f"point_{ i } "
40- object_ . vector = points [i ]
42+ setattr ( object_ , vector_field_name , points [i ])
4143 objects .append (object_ )
4244 box .put (* objects )
4345 print (f"DB seeded with { box .count ()} random points!" )
@@ -58,50 +60,62 @@ def _test_random_points(num_points: int, num_query_points: int, seed: Optional[i
5860
5961 # Run ANN with OBX
6062 qb = box .query ()
61- qb .nearest_neighbors_f32 ("vector" , query_point , k )
63+ qb .nearest_neighbors_f32 (vector_field_name , query_point , k )
6264 query = qb .build ()
6365 obx_result = [id_ for id_ , score in query .find_ids_with_scores ()] # Ignore score
6466 assert len (obx_result ) == k
6567
6668 # We would like at least half of the expected results, to be returned by the search (in any order)
6769 # Remember: it's an approximate search!
6870 search_score = len (np .intersect1d (expected_result , obx_result )) / k
69- assert search_score >= 0.5 # TODO likely could be increased
71+ assert search_score >= min_score # TODO likely could be increased
7072
7173 print (f"Done!" )
7274
7375
7476def test_random_points ():
75- _test_random_points (num_points = 100 , num_query_points = 10 , seed = 10 )
76- _test_random_points (num_points = 100 , num_query_points = 10 , seed = 11 )
77- _test_random_points (num_points = 100 , num_query_points = 10 , seed = 12 )
78- _test_random_points (num_points = 100 , num_query_points = 10 , seed = 13 )
79- _test_random_points (num_points = 100 , num_query_points = 10 , seed = 14 )
80- _test_random_points (num_points = 100 , num_query_points = 10 , seed = 15 )
81-
82-
83- def test_combined_nn_search ():
84- """ Tests NN search combined with regular query conditions, offset and limit. """
85-
77+
78+ min_score = 0.5
79+ distance_type = HnswDistanceType .EUCLIDEAN
80+ _test_random_points (num_points = 100 , num_query_points = 10 , seed = 10 , distance_type = distance_type , min_score = min_score )
81+ _test_random_points (num_points = 100 , num_query_points = 10 , seed = 11 , distance_type = distance_type , min_score = min_score )
82+ _test_random_points (num_points = 100 , num_query_points = 10 , seed = 12 , distance_type = distance_type , min_score = min_score )
83+ _test_random_points (num_points = 100 , num_query_points = 10 , seed = 13 , distance_type = distance_type , min_score = min_score )
84+ _test_random_points (num_points = 100 , num_query_points = 10 , seed = 14 , distance_type = distance_type , min_score = min_score )
85+ _test_random_points (num_points = 100 , num_query_points = 10 , seed = 15 , distance_type = distance_type , min_score = min_score )
86+
87+ # TODO: Cosine and Dot Product may result in a score of 0
88+
89+ def _test_combined_nn_search (distance_type : HnswDistanceType = HnswDistanceType .EUCLIDEAN ):
90+
8691 db = create_test_objectbox ()
8792
8893 box = objectbox .Box (db , VectorEntity )
8994
90- box .put (VectorEntity (name = "Power of red" , vector = [1 , 1 ]))
91- box .put (VectorEntity (name = "Blueberry" , vector = [2 , 2 ]))
92- box .put (VectorEntity (name = "Red" , vector = [3 , 3 ]))
93- box .put (VectorEntity (name = "Blue sea" , vector = [4 , 4 ]))
94- box .put (VectorEntity (name = "Lightblue" , vector = [5 , 5 ]))
95- box .put (VectorEntity (name = "Red apple" , vector = [6 , 6 ]))
96- box .put (VectorEntity (name = "Hundred" , vector = [7 , 7 ]))
97- box .put (VectorEntity (name = "Tired" , vector = [8 , 8 ]))
98- box .put (VectorEntity (name = "Power of blue" , vector = [9 , 9 ]))
99-
95+ vector_field_name = "vector_" + distance_type .name .lower ()
96+
97+ values = [
98+ ("Power of red" , [1 , 1 ]),
99+ ("Blueberry" , [2 , 2 ]),
100+ ("Red" , [3 , 3 ]),
101+ ("Blue sea" , [4 , 4 ]),
102+ ("Lightblue" , [5 , 5 ]),
103+ ("Red apple" , [6 , 6 ]),
104+ ("Hundred" , [7 , 7 ]),
105+ ("Tired" , [8 , 8 ]),
106+ ("Power of blue" , [9 , 9 ])
107+ ]
108+ for value in values :
109+ entity = VectorEntity ()
110+ setattr (entity , "name" , value [0 ])
111+ setattr (entity , vector_field_name , value [1 ])
112+ box .put (entity )
113+
100114 assert box .count () == 9
101115
102116 # Test condition + NN search
103117 qb = box .query ()
104- qb .nearest_neighbors_f32 ("vector" , [4.1 , 4.2 ], 6 )
118+ qb .nearest_neighbors_f32 (vector_field_name , [4.1 , 4.2 ], 6 )
105119 qb .contains_string ("name" , "red" , case_sensitive = False )
106120 query = qb .build ()
107121 # 4, 5, 3, 6, 2, 7
@@ -121,7 +135,7 @@ def test_combined_nn_search():
121135
122136 # Regular condition + NN search
123137 qb = box .query ()
124- qb .nearest_neighbors_f32 ("vector" , [9.2 , 8.9 ], 7 )
138+ qb .nearest_neighbors_f32 (vector_field_name , [9.2 , 8.9 ], 7 )
125139 qb .starts_with_string ("name" , "Blue" , case_sensitive = True )
126140 query = qb .build ()
127141
@@ -131,7 +145,7 @@ def test_combined_nn_search():
131145
132146 # Regular condition + NN search
133147 qb = box .query ()
134- qb .nearest_neighbors_f32 ("vector" , [7.7 , 7.7 ], 8 )
148+ qb .nearest_neighbors_f32 (vector_field_name , [7.7 , 7.7 ], 8 )
135149 qb .contains_string ("name" , "blue" , case_sensitive = False )
136150 query = qb .build ()
137151 # 8, 7, 9, 6, 5, 4, 3, 2
@@ -157,3 +171,10 @@ def test_combined_nn_search():
157171 assert len (search_results ) == 2
158172 assert search_results [0 ] == 4
159173 assert search_results [1 ] == 5
174+
175+
176+ def test_combined_nn_search ():
177+ """ Tests NN search combined with regular query conditions, offset and limit. """
178+ distance_type = HnswDistanceType .EUCLIDEAN
179+ _test_combined_nn_search (distance_type )
180+ # TODO: Cosine and DotProduct diverge; see below
0 commit comments