Skip to content

Commit ef244df

Browse files
committed
utils: bubble up vector_distance_f32 and vector_distance_to_relevance #47
1 parent 8be276a commit ef244df

3 files changed

Lines changed: 42 additions & 4 deletions

File tree

objectbox/c.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,12 @@ def c_array_pointer(py_list: Union[List[Any], np.ndarray], c_type):
360360
return ctypes.cast(c_array(py_list, c_type), ctypes.POINTER(c_type))
361361

362362

363+
# Binding for the native vector-distance routine. C declaration:
# OBX_C_API float obx_vector_distance_float32(OBXVectorDistanceType type, const float* vector1,
#                                             const float* vector2, size_t dimension);
obx_vector_distance_float32 = c_fn(
    "obx_vector_distance_float32",
    ctypes.c_float,
    [
        OBXVectorDistanceType,
        ctypes.POINTER(ctypes.c_float),
        ctypes.POINTER(ctypes.c_float),
        ctypes.c_size_t,
    ],
)
365+
366+
# Binding for the native distance-to-relevance conversion. C declaration:
# OBX_C_API float obx_vector_distance_to_relevance(OBXVectorDistanceType type, float distance);
obx_vector_distance_to_relevance = c_fn(
    "obx_vector_distance_to_relevance",
    ctypes.c_float,
    [OBXVectorDistanceType, ctypes.c_float],
)
368+
363369
# OBX_model* (void);
364370
obx_model = c_fn('obx_model', OBX_model_p, [])
365371

objectbox/query_builder.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import numpy as np
33
from typing import *
44

5+
from objectbox.c import *
56
from objectbox.model.properties import Property
67
from objectbox.objectbox import ObjectBox
78
from objectbox.query import Query
8-
from objectbox.c import *
9+
from objectbox.utils import check_float_vector
910

1011

1112
class QueryBuilder:
@@ -108,10 +109,11 @@ def between_2ints(self, prop: Union[int, str, Property], value_a: int, value_b:
108109
cond = obx_qb_between_2ints(self._c_builder, prop_id, value_a, value_b)
109110
return cond
110111

111-
def nearest_neighbors_f32(self, prop: Union[int, str, Property], query_vector: Union[np.ndarray, List[float]],
112+
def nearest_neighbors_f32(self,
113+
prop: Union[int, str, Property],
114+
query_vector: Union[np.ndarray, List[float]],
112115
element_count: int):
113-
if isinstance(query_vector, np.ndarray) and query_vector.dtype != np.float32:
114-
raise Exception(f"query_vector dtype is expected to be np.float32, got: {query_vector.dtype}")
116+
check_float_vector(query_vector, "query_vector")
115117
prop_id = self._entity.get_property_id(prop)
116118
c_query_vector = c_array(query_vector, ctypes.c_float)
117119
cond = obx_qb_nearest_neighbors_f32(self._c_builder, prop_id, c_query_vector, element_count)

objectbox/utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import numpy as np
2+
3+
from objectbox.c import *
4+
from objectbox.model.properties import VectorDistanceType
5+
6+
7+
def check_float_vector(vector: Union[np.ndarray, List[float]], vector_name: str):
    """ Checks that the given vector is a float vector (either np.ndarray or Python's list).

    Only ``np.ndarray`` and ``list`` arguments are inspected; any other type passes
    through unchecked. For lists, only the first element is examined, so an empty
    list is always accepted.

    :param vector: the vector to validate.
    :param vector_name: name of the vector, used in the error message.
    :raises TypeError: if an ndarray's dtype is not np.float32, or if the first
        element of a non-empty list is not a float (or a subclass of float).
    """
    if isinstance(vector, np.ndarray):
        if vector.dtype != np.float32:
            raise TypeError(f"{vector_name} dtype is expected to be np.float32, got: {vector.dtype}")
    elif isinstance(vector, list) and vector and not isinstance(vector[0], float):
        # TypeError derives from Exception, so callers catching Exception keep working.
        raise TypeError(f"{vector_name} is expected to be a float list, got vector[0]: {type(vector[0])}")
13+
14+
15+
def vector_distance_f32(distance_type: VectorDistanceType,
                        vector1: Union[np.ndarray, List[float]],
                        vector2: Union[np.ndarray, List[float]],
                        dimension: int) -> float:
    """ Utility function to calculate the distance of two vectors.

    :param distance_type: the distance type to apply.
    :param vector1: first vector (np.float32 ndarray or list of floats).
    :param vector2: second vector (np.float32 ndarray or list of floats).
    :param dimension: number of elements of each vector to take into account.
    :return: the distance as computed by the native ``obx_vector_distance_float32``.
    :raises ValueError: if ``dimension`` is negative or larger than the number of
        elements in either vector.
    """
    check_float_vector(vector1, "vector1")
    check_float_vector(vector2, "vector2")

    # The native call receives raw float pointers plus a size_t dimension; it presumably
    # reads `dimension` elements from each buffer, so guard against out-of-bounds reads.
    if dimension < 0:
        raise ValueError(f"dimension must not be negative, got: {dimension}")
    for name, vector in (("vector1", vector1), ("vector2", vector2)):
        element_count = vector.size if isinstance(vector, np.ndarray) else len(vector)
        if dimension > element_count:
            raise ValueError(f"dimension ({dimension}) exceeds the size of {name} ({element_count})")

    return obx_vector_distance_float32(distance_type,
                                       c_array(vector1, ctypes.c_float),
                                       c_array(vector2, ctypes.c_float),
                                       dimension)
26+
27+
28+
def vector_distance_to_relevance(distance_type: VectorDistanceType, distance: float) -> float:
    """ Maps the given distance onto a relevance score in range [0.0, 1.0], depending on the distance type. """
    relevance = obx_vector_distance_to_relevance(distance_type, distance)
    return relevance

0 commit comments

Comments
 (0)