Skip to content

Commit 4c239b1

Browse files
I8dNLojoein
andauthored
Custom rerankers support (#496)
* Custom rerankers support * Test for reranker_custom_model * test fix * Model description type fix * Test fix * fix: fix naming * fix: remove redundant arg from tests * new: update readme --------- Co-authored-by: George Panchuk <george.panchuk@qdrant.tech>
1 parent 6acfb00 commit 4c239b1

4 files changed

Lines changed: 173 additions & 2 deletions

File tree

README.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,23 @@ scores = list(encoder.rerank(query, documents))
190190
# [-11.48061752319336, 5.472434997558594]
191191
```
192192

193+
Text cross encoders can also be extended with models that are not in the list of supported models.
194+
195+
```python
196+
from fastembed.rerank.cross_encoder import TextCrossEncoder
197+
from fastembed.common.model_description import ModelSource
198+
199+
TextCrossEncoder.add_custom_model(
200+
model="Xenova/ms-marco-MiniLM-L-4-v2",
201+
model_file="onnx/model.onnx",
202+
sources=ModelSource(hf="Xenova/ms-marco-MiniLM-L-4-v2"),
203+
)
204+
model = TextCrossEncoder(model_name="Xenova/ms-marco-MiniLM-L-4-v2")
205+
scores = list(model.rerank_pairs(
206+
[("What is AI?", "Artificial intelligence is ..."), ("What is ML?", "Machine learning is ..."),]
207+
))
208+
```
209+
193210
## ⚡️ FastEmbed on a GPU
194211

195212
FastEmbed supports running on GPU devices.
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from typing import Optional, Sequence, Any
2+
3+
from fastembed.common import OnnxProvider
4+
from fastembed.common.model_description import BaseModelDescription
5+
from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder
6+
7+
8+
class CustomTextCrossEncoder(OnnxTextCrossEncoder):
    """ONNX cross-encoder backend backed by a user-populated model registry.

    Unlike ``OnnxTextCrossEncoder``, which ships with a fixed list of
    supported models, this class reports whatever model descriptions have
    been registered at runtime via :meth:`add_model`.
    """

    # Runtime registry of user-added model descriptions, shared class-wide.
    SUPPORTED_MODELS: list[BaseModelDescription] = []

    def __init__(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence[OnnxProvider]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        lazy_load: bool = False,
        device_id: Optional[int] = None,
        specific_model_path: Optional[str] = None,
        **kwargs: Any,
    ):
        # Pure pass-through: all download/load/inference logic lives in the parent.
        super().__init__(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=threads,
            providers=providers,
            cuda=cuda,
            device_ids=device_ids,
            lazy_load=lazy_load,
            device_id=device_id,
            specific_model_path=specific_model_path,
            **kwargs,
        )

    @classmethod
    def add_model(cls, model_description: BaseModelDescription) -> None:
        """Append one more custom model description to the registry."""
        cls.SUPPORTED_MODELS.append(model_description)

    @classmethod
    def _list_supported_models(cls) -> list[BaseModelDescription]:
        """Return the registry of custom models added so far."""
        return cls.SUPPORTED_MODELS

fastembed/rerank/cross_encoder/text_cross_encoder.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,19 @@
33

44
from fastembed.common import OnnxProvider
55
from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder
6+
from fastembed.rerank.cross_encoder.custom_text_cross_encoder import CustomTextCrossEncoder
7+
68
from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase
7-
from fastembed.common.model_description import BaseModelDescription
9+
from fastembed.common.model_description import (
10+
ModelSource,
11+
BaseModelDescription,
12+
)
813

914

1015
class TextCrossEncoder(TextCrossEncoderBase):
1116
CROSS_ENCODER_REGISTRY: list[Type[TextCrossEncoderBase]] = [
1217
OnnxTextCrossEncoder,
18+
CustomTextCrossEncoder,
1319
]
1420

1521
@classmethod
@@ -124,3 +130,34 @@ def rerank_pairs(
124130
yield from self.model.rerank_pairs(
125131
pairs, batch_size=batch_size, parallel=parallel, **kwargs
126132
)
133+
134+
@classmethod
def add_custom_model(
    cls,
    model: str,
    sources: ModelSource,
    model_file: str = "onnx/model.onnx",
    description: str = "",
    license: str = "",
    size_in_gb: float = 0.0,
    additional_files: Optional[list[str]] = None,
) -> None:
    """Register a cross-encoder model that is not in the built-in list.

    Args:
        model: Model name (e.g. a Hugging Face repo id); must be unique.
        sources: Where the model files can be downloaded from.
        model_file: Path of the ONNX file inside the model repository.
        description: Optional human-readable model description.
        license: Optional license identifier.
        size_in_gb: Approximate model size, in gigabytes.
        additional_files: Extra files to fetch alongside the model, if any.

    Raises:
        ValueError: If a model with the same name is already registered
            (either built-in or previously added).
    """
    # Reject duplicates across both built-in and previously added models.
    if any(registered.model == model for registered in cls._list_supported_models()):
        raise ValueError(
            f"Model {model} is already registered in CrossEncoderModel, if you still want to add this model, "
            f"please use another model name"
        )

    custom_description = BaseModelDescription(
        model=model,
        sources=sources,
        model_file=model_file,
        description=description,
        license=license,
        size_in_GB=size_in_gb,
        additional_files=additional_files or [],
    )
    CustomTextCrossEncoder.add_model(custom_description)

tests/test_custom_models.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,28 @@
33
import numpy as np
44
import pytest
55

6-
from fastembed.common.model_description import PoolingType, ModelSource, DenseModelDescription
6+
from fastembed.common.model_description import (
7+
PoolingType,
8+
ModelSource,
9+
DenseModelDescription,
10+
BaseModelDescription,
11+
)
712
from fastembed.common.onnx_model import OnnxOutputContext
813
from fastembed.common.utils import normalize, mean_pooling
914
from fastembed.text.custom_text_embedding import CustomTextEmbedding, PostprocessingConfig
15+
from fastembed.rerank.cross_encoder.custom_text_cross_encoder import CustomTextCrossEncoder
16+
from fastembed.rerank.cross_encoder import TextCrossEncoder
1017
from fastembed.text.text_embedding import TextEmbedding
1118
from tests.utils import delete_model_cache
1219

1320

1421
@pytest.fixture(autouse=True)
def restore_custom_models_fixture():
    """Isolate each test: wipe the custom-model registries before and after."""
    for registry_owner in (CustomTextEmbedding, CustomTextCrossEncoder):
        registry_owner.SUPPORTED_MODELS = []
    yield
    for registry_owner in (CustomTextEmbedding, CustomTextCrossEncoder):
        registry_owner.SUPPORTED_MODELS = []
1928

2029

2130
def test_text_custom_model():
@@ -65,6 +74,43 @@ def test_text_custom_model():
6574
delete_model_cache(model.model._model_dir)
6675

6776

77+
def test_cross_encoder_custom_model():
    """End-to-end custom cross-encoder: register, inspect registry, rerank."""
    running_in_ci = os.getenv("CI")
    model_name = "Xenova/ms-marco-MiniLM-L-4-v2"
    expected_size_gb = 0.08
    model_source = ModelSource(hf=model_name)
    # Reference rerank scores for the two pairs below, pinned upstream.
    expected_scores = np.array([-5.7170815, -11.112114], dtype=np.float32)

    TextCrossEncoder.add_custom_model(
        model_name,
        model_file="onnx/model.onnx",
        sources=model_source,
        size_in_gb=expected_size_gb,
    )

    # Registration must store an exact description in the custom registry.
    assert CustomTextCrossEncoder.SUPPORTED_MODELS[0] == BaseModelDescription(
        model=model_name,
        sources=model_source,
        model_file="onnx/model.onnx",
        description="",
        license="",
        size_in_GB=expected_size_gb,
    )

    encoder = TextCrossEncoder(model_name)
    query_doc_pairs = [
        ("What is AI?", "Artificial intelligence is ..."),
        ("What is ML?", "Machine learning is ..."),
    ]
    scores = np.stack(list(encoder.rerank_pairs(query_doc_pairs)), axis=0)

    assert scores.shape == (2,)
    assert np.allclose(scores, expected_scores, atol=1e-3)
    if running_in_ci:
        delete_model_cache(encoder.model._model_dir)
112+
113+
68114
def test_mock_add_custom_models():
69115
dim = 5
70116
size_in_gb = 0.1
@@ -156,3 +202,28 @@ def test_do_not_add_existing_model():
156202
dim=384,
157203
size_in_gb=0.47,
158204
)
205+
206+
207+
def test_do_not_add_existing_cross_encoder():
    """Duplicate names are rejected for built-in and custom models alike."""
    builtin_name = "Xenova/ms-marco-MiniLM-L-6-v2"
    new_name = "Xenova/ms-marco-MiniLM-L-4-v2"

    # A name already shipped with the library cannot be re-registered.
    with pytest.raises(ValueError, match=f"Model {builtin_name} is already registered"):
        TextCrossEncoder.add_custom_model(
            builtin_name,
            sources=ModelSource(hf=builtin_name),
            size_in_gb=0.08,
        )

    # A genuinely new name is accepted on the first registration...
    TextCrossEncoder.add_custom_model(
        new_name,
        sources=ModelSource(hf=builtin_name),
        size_in_gb=0.08,
    )

    # ...but registering the same name a second time must fail.
    with pytest.raises(ValueError, match=f"Model {new_name} is already registered"):
        TextCrossEncoder.add_custom_model(
            new_name,
            sources=ModelSource(hf=new_name),
            size_in_gb=0.08,
        )

0 commit comments

Comments
 (0)