Skip to content

Commit 04bc7a3

Browse files
generall and joein authored
MiniCOIL v1 (#513)
* add token embeddings * fix parallel worker init * implement minicoil * fix mypy * fix mypy * register minicoil * rollback suggested "fix" * add minicoil test * some pr issues (#514) * some pr issues * revert query embed refactor * test: add query embed tests * nit * Update tests/test_sparse_embeddings.py --------- Co-authored-by: Andrey Vasnetsov <andrey@vasnetsov.com> * review * fix: revert change to colbert query_embed --------- Co-authored-by: George <george.panchuk@qdrant.tech>
1 parent b785640 commit 04bc7a3

25 files changed

Lines changed: 1190 additions & 42 deletions

fastembed/common/onnx_model.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,16 @@ class OnnxModel(Generic[T]):
2828
def _get_worker_class(cls) -> Type["EmbeddingWorker[T]"]:
2929
raise NotImplementedError("Subclasses must implement this method")
3030

31-
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[T]:
31+
def _post_process_onnx_output(self, output: OnnxOutputContext, **kwargs: Any) -> Iterable[T]:
32+
"""Post-process the ONNX model output to convert it into a usable format.
33+
34+
Args:
35+
output (OnnxOutputContext): The raw output from the ONNX model.
36+
**kwargs: Additional keyword arguments that may be needed by specific implementations.
37+
38+
Returns:
39+
Iterable[T]: Post-processed output as an iterable of type T.
40+
"""
3241
raise NotImplementedError("Subclasses must implement this method")
3342

3443
def __init__(self) -> None:

fastembed/common/preprocessor_utils.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -36,9 +36,9 @@ def load_tokenizer(model_dir: Path) -> tuple[Tokenizer, dict[str, int]]:
3636

3737
with open(str(tokenizer_config_path)) as tokenizer_config_file:
3838
tokenizer_config = json.load(tokenizer_config_file)
39-
assert (
40-
"model_max_length" in tokenizer_config or "max_length" in tokenizer_config
41-
), "Models without model_max_length or max_length are not supported."
39+
assert "model_max_length" in tokenizer_config or "max_length" in tokenizer_config, (
40+
"Models without model_max_length or max_length are not supported."
41+
)
4242
if "model_max_length" not in tokenizer_config:
4343
max_context = tokenizer_config["max_length"]
4444
elif "max_length" not in tokenizer_config:

fastembed/common/types.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616

1717
OnnxProvider: TypeAlias = Union[str, tuple[str, dict[Any, Any]]]
1818
NumpyArray = Union[
19+
NDArray[np.float64],
1920
NDArray[np.float32],
2021
NDArray[np.float16],
2122
NDArray[np.int8],

fastembed/image/onnx_embedding.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -193,7 +193,9 @@ def _preprocess_onnx_input(
193193

194194
return onnx_input
195195

196-
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
196+
def _post_process_onnx_output(
197+
self, output: OnnxOutputContext, **kwargs: Any
198+
) -> Iterable[NumpyArray]:
197199
return normalize(output.model_output)
198200

199201

fastembed/image/onnx_image_model.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,16 @@ class OnnxImageModel(OnnxModel[T]):
2323
def _get_worker_class(cls) -> Type["ImageEmbeddingWorker[T]"]:
2424
raise NotImplementedError("Subclasses must implement this method")
2525

26-
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[T]:
26+
def _post_process_onnx_output(self, output: OnnxOutputContext, **kwargs: Any) -> Iterable[T]:
27+
"""Post-process the ONNX model output to convert it into a usable format.
28+
29+
Args:
30+
output (OnnxOutputContext): The raw output from the ONNX model.
31+
**kwargs: Additional keyword arguments that may be needed by specific implementations.
32+
33+
Returns:
34+
Iterable[T]: Post-processed output as an iterable of type T.
35+
"""
2736
raise NotImplementedError("Subclasses must implement this method")
2837

2938
def __init__(self) -> None:
@@ -104,7 +113,7 @@ def _embed_images(
104113
self.load_onnx_model()
105114

106115
for batch in iter_batch(images, batch_size):
107-
yield from self._post_process_onnx_output(self.onnx_embed(batch))
116+
yield from self._post_process_onnx_output(self.onnx_embed(batch), **kwargs)
108117
else:
109118
if parallel == 0:
110119
parallel = os.cpu_count()
@@ -125,7 +134,7 @@ def _embed_images(
125134
start_method=start_method,
126135
)
127136
for batch in pool.ordered_map(iter_batch(images, batch_size), **params):
128-
yield from self._post_process_onnx_output(batch) # type: ignore
137+
yield from self._post_process_onnx_output(batch, **kwargs) # type: ignore
129138

130139

131140
class ImageEmbeddingWorker(EmbeddingWorker[T]):

fastembed/image/transform/functional.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -72,26 +72,26 @@ def normalize(
7272
if not np.issubdtype(image.dtype, np.floating):
7373
image = image.astype(np.float32)
7474

75-
mean = mean if isinstance(mean, list) else [mean] * num_channels
75+
mean_list = mean if isinstance(mean, list) else [mean] * num_channels
7676

77-
if len(mean) != num_channels:
77+
if len(mean_list) != num_channels:
7878
raise ValueError(
7979
f"mean must have the same number of channels as the image, image has {num_channels} channels, got "
80-
f"{len(mean)}"
80+
f"{len(mean_list)}"
8181
)
8282

83-
mean_arr = np.array(mean, dtype=np.float32)
83+
mean_arr = np.array(mean_list, dtype=np.float32)
8484

85-
std = std if isinstance(std, list) else [std] * num_channels
86-
if len(std) != num_channels:
85+
std_list = std if isinstance(std, list) else [std] * num_channels
86+
if len(std_list) != num_channels:
8787
raise ValueError(
88-
f"std must have the same number of channels as the image, image has {num_channels} channels, got {len(std)}"
88+
f"std must have the same number of channels as the image, image has {num_channels} channels, got {len(std_list)}"
8989
)
9090

91-
std_arr = np.array(std, dtype=np.float32)
91+
std_arr = np.array(std_list, dtype=np.float32)
9292

93-
image = ((image.T - mean_arr) / std_arr).T
94-
return image
93+
image_upd = ((image.T - mean_arr) / std_arr).T
94+
return image_upd
9595

9696

9797
def resize(

fastembed/late_interaction/colbert.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ class Colbert(LateInteractionTextEmbeddingBase, OnnxTextModel[NumpyArray]):
4343
MASK_TOKEN = "[MASK]"
4444

4545
def _post_process_onnx_output(
46-
self, output: OnnxOutputContext, is_doc: bool = True
46+
self, output: OnnxOutputContext, is_doc: bool = True, **kwargs: Any
4747
) -> Iterable[NumpyArray]:
4848
if not is_doc:
4949
return output.model_output
Lines changed: 102 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,102 @@
1+
from dataclasses import asdict
2+
from typing import Union, Iterable, Optional, Any, Type
3+
4+
from fastembed.common.model_description import DenseModelDescription, ModelSource
5+
from fastembed.common.onnx_model import OnnxOutputContext
6+
from fastembed.common.types import NumpyArray
7+
from fastembed.late_interaction.late_interaction_embedding_base import (
8+
LateInteractionTextEmbeddingBase,
9+
)
10+
from fastembed.text.onnx_embedding import OnnxTextEmbedding
11+
from fastembed.text.onnx_text_model import TextEmbeddingWorker
12+
import numpy as np
13+
14+
supported_token_embeddings_models = [
15+
DenseModelDescription(
16+
model="jinaai/jina-embeddings-v2-small-en-tokens",
17+
dim=512,
18+
description="Text embeddings, Unimodal (text), English, 8192 input tokens truncation,"
19+
" Prefixes for queries/documents: not necessary, 2023 year.",
20+
license="apache-2.0",
21+
size_in_GB=0.12,
22+
sources=ModelSource(hf="xenova/jina-embeddings-v2-small-en"),
23+
model_file="onnx/model.onnx",
24+
),
25+
]
26+
27+
28+
class TokenEmbeddingsModel(OnnxTextEmbedding, LateInteractionTextEmbeddingBase):
29+
@classmethod
30+
def _list_supported_models(cls) -> list[DenseModelDescription]:
31+
"""Lists the supported models.
32+
33+
Returns:
34+
list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
35+
"""
36+
return supported_token_embeddings_models
37+
38+
@classmethod
39+
def list_supported_models(cls) -> list[dict[str, Any]]:
40+
"""Lists the supported models.
41+
42+
Returns:
43+
list[dict[str, Any]]: A list of dictionaries containing the model information.
44+
"""
45+
return [asdict(model) for model in cls._list_supported_models()]
46+
47+
@classmethod
48+
def _get_worker_class(cls) -> Type[TextEmbeddingWorker[NumpyArray]]:
49+
return TokensEmbeddingWorker
50+
51+
def _post_process_onnx_output(
52+
self, output: OnnxOutputContext, **kwargs: Any
53+
) -> Iterable[NumpyArray]:
54+
# Size: (batch_size, sequence_length, hidden_size)
55+
embeddings = output.model_output
56+
# Size: (batch_size, sequence_length)
57+
assert output.attention_mask is not None
58+
masks = output.attention_mask
59+
60+
# For each document we only select those embeddings that are not masked out
61+
for i in range(embeddings.shape[0]):
62+
yield embeddings[i, masks[i] == 1]
63+
64+
def embed(
65+
self,
66+
documents: Union[str, Iterable[str]],
67+
batch_size: int = 256,
68+
parallel: Optional[int] = None,
69+
**kwargs: Any,
70+
) -> Iterable[NumpyArray]:
71+
yield from super().embed(documents, batch_size=batch_size, parallel=parallel, **kwargs)
72+
73+
def tokenize_docs(self, documents: list[str]) -> list[NumpyArray]:
74+
if self.tokenizer is None:
75+
raise ValueError("Tokenizer not initialized")
76+
encoded = self.tokenizer.encode_batch(documents)
77+
return [np.array(e.ids, dtype=np.int32) for e in encoded]
78+
79+
80+
class TokensEmbeddingWorker(TextEmbeddingWorker[NumpyArray]):
81+
def init_embedding(
82+
self, model_name: str, cache_dir: str, **kwargs: Any
83+
) -> TokenEmbeddingsModel:
84+
return TokenEmbeddingsModel(
85+
model_name=model_name,
86+
cache_dir=cache_dir,
87+
threads=1,
88+
**kwargs,
89+
)
90+
91+
92+
if __name__ == "__main__":
93+
# Example usage
94+
print(TokenEmbeddingsModel.list_supported_models())
95+
model = TokenEmbeddingsModel(model_name="jinaai/jina-embeddings-v2-small-en-tokens")
96+
docs = ["Hello, world!", "hello", "hello hello"]
97+
98+
embeddings = model.embed(docs)
99+
for emb in embeddings:
100+
print(emb.shape)
101+
102+
print(model.tokenize_docs(docs))

fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -196,7 +196,9 @@ def rerank_pairs(
196196
def _get_worker_class(cls) -> Type[TextRerankerWorker]:
197197
return TextCrossEncoderWorker
198198

199-
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[float]:
199+
def _post_process_onnx_output(
200+
self, output: OnnxOutputContext, **kwargs: Any
201+
) -> Iterable[float]:
200202
return (float(elem) for elem in output.model_output)
201203

202204

fastembed/rerank/cross_encoder/onnx_text_model.py

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -133,7 +133,18 @@ def _rerank_pairs(
133133
for batch in pool.ordered_map(iter_batch(pairs, batch_size), **params):
134134
yield from self._post_process_onnx_output(batch) # type: ignore
135135

136-
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[float]:
136+
def _post_process_onnx_output(
137+
self, output: OnnxOutputContext, **kwargs: Any
138+
) -> Iterable[float]:
139+
"""Post-process the ONNX model output to convert it into a usable format.
140+
141+
Args:
142+
output (OnnxOutputContext): The raw output from the ONNX model.
143+
**kwargs: Additional keyword arguments that may be needed by specific implementations.
144+
145+
Returns:
146+
Iterable[float]: Post-processed output as an iterable of float values.
147+
"""
137148
raise NotImplementedError("Subclasses must implement this method")
138149

139150
def _preprocess_onnx_input(

0 commit comments

Comments (0)