Skip to content

Commit c91d42d

Browse files
authored
Update setting jina v3 tasks (#503)
* new: improve task setter in jina v3
* refactor
* new: add hf_token secret
* fix: cross platform env propagation
1 parent aa0c475 commit c91d42d

4 files changed

Lines changed: 49 additions & 41 deletions

File tree

.github/workflows/python-tests.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,7 @@ jobs:
4242
poetry install --no-interaction --no-ansi --without dev,docs
4343
4444
- name: Run pytest
45+
env:
46+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
4547
run: |
46-
poetry run pytest
48+
poetry run pytest

fastembed/text/multitask_embedding.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55

6+
from fastembed.common.onnx_model import OnnxOutputContext
67
from fastembed.common.types import NumpyArray
78
from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
89
from fastembed.text.onnx_embedding import OnnxTextEmbeddingWorker
@@ -44,9 +45,11 @@ class JinaEmbeddingV3(PooledNormalizedEmbedding):
4445
PASSAGE_TASK = Task.RETRIEVAL_PASSAGE
4546
QUERY_TASK = Task.RETRIEVAL_QUERY
4647

47-
def __init__(self, *args: Any, **kwargs: Any):
48+
def __init__(self, *args: Any, task_id: Optional[int] = None, **kwargs: Any):
4849
super().__init__(*args, **kwargs)
49-
self.current_task_id: Union[Task, int] = self.PASSAGE_TASK
50+
self.default_task_id: Union[Task, int] = (
51+
task_id if task_id is not None else self.PASSAGE_TASK
52+
)
5053

5154
@classmethod
5255
def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
@@ -57,30 +60,34 @@ def _list_supported_models(cls) -> list[DenseModelDescription]:
5760
return supported_multitask_models
5861

5962
def _preprocess_onnx_input(
60-
self, onnx_input: dict[str, NumpyArray], **kwargs: Any
63+
self,
64+
onnx_input: dict[str, NumpyArray],
65+
task_id: Optional[Union[int, Task]] = None,
66+
**kwargs: Any,
6167
) -> dict[str, NumpyArray]:
62-
onnx_input["task_id"] = np.array(self.current_task_id, dtype=np.int64)
68+
if task_id is None:
69+
raise ValueError(f"task_id must be provided for JinaEmbeddingV3, got <{task_id}>")
70+
onnx_input["task_id"] = np.array(task_id, dtype=np.int64)
6371
return onnx_input
6472

6573
def embed(
6674
self,
6775
documents: Union[str, Iterable[str]],
6876
batch_size: int = 256,
6977
parallel: Optional[int] = None,
70-
task_id: int = PASSAGE_TASK,
78+
task_id: Optional[int] = None,
7179
**kwargs: Any,
7280
) -> Iterable[NumpyArray]:
73-
self.current_task_id = task_id
74-
kwargs["task_id"] = task_id
75-
yield from super().embed(documents, batch_size, parallel, **kwargs)
81+
task_id = (
82+
task_id if task_id is not None else self.default_task_id
83+
) # required for multiprocessing
84+
yield from super().embed(documents, batch_size, parallel, task_id=task_id, **kwargs)
7685

7786
def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]:
78-
self.current_task_id = self.QUERY_TASK
79-
yield from super().embed(query, **kwargs)
87+
yield from super().embed(query, task_id=self.QUERY_TASK, **kwargs)
8088

8189
def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
82-
self.current_task_id = self.PASSAGE_TASK
83-
yield from super().embed(texts, **kwargs)
90+
yield from super().embed(texts, task_id=self.PASSAGE_TASK, **kwargs)
8491

8592

8693
class JinaEmbeddingV3Worker(OnnxTextEmbeddingWorker):
@@ -90,11 +97,15 @@ def init_embedding(
9097
cache_dir: str,
9198
**kwargs: Any,
9299
) -> JinaEmbeddingV3:
93-
model = JinaEmbeddingV3(
100+
return JinaEmbeddingV3(
94101
model_name=model_name,
95102
cache_dir=cache_dir,
96103
threads=1,
97104
**kwargs,
98105
)
99-
model.current_task_id = kwargs["task_id"]
100-
return model
106+
107+
def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, OnnxOutputContext]]:
108+
self.model: JinaEmbeddingV3 # mypy complaints `self.model` does not have `default_task_id`
109+
for idx, batch in items:
110+
onnx_output = self.model.onnx_embed(batch, task_id=self.model.default_task_id)
111+
yield idx, onnx_output

fastembed/text/onnx_text_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def _embed_documents(
115115
if not hasattr(self, "model") or self.model is None:
116116
self.load_onnx_model()
117117
for batch in iter_batch(documents, batch_size):
118-
yield from self._post_process_onnx_output(self.onnx_embed(batch))
118+
yield from self._post_process_onnx_output(self.onnx_embed(batch, **kwargs))
119119
else:
120120
if parallel == 0:
121121
parallel = os.cpu_count()

tests/test_text_multitask_embeddings.py

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,25 @@ def test_single_embedding():
109109

110110
canonical_vector = task["vectors"]
111111
assert np.allclose(
112-
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
112+
embeddings[:, : canonical_vector.shape[1]], canonical_vector, atol=1e-4
113113
), model_desc.model
114114

115+
classification_embeddings = list(model.embed(documents=docs, task_id=Task.CLASSIFICATION))
116+
classification_embeddings = np.stack(classification_embeddings, axis=0)
117+
118+
assert classification_embeddings.shape == (len(docs), dim)
119+
120+
model = TextEmbedding(model_name=model_name, task_id=Task.CLASSIFICATION)
121+
default_embeddings = list(model.embed(documents=docs))
122+
default_embeddings = np.stack(default_embeddings, axis=0)
123+
124+
assert default_embeddings.shape == (len(docs), dim)
125+
126+
assert np.allclose(
127+
classification_embeddings,
128+
default_embeddings,
129+
atol=1e-4,
130+
), model_desc.model
115131
if is_ci:
116132
delete_model_cache(model.model._model_dir)
117133

@@ -140,7 +156,7 @@ def test_single_embedding_query():
140156

141157
canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"]
142158
assert np.allclose(
143-
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
159+
embeddings[:, : canonical_vector.shape[1]], canonical_vector, atol=1e-4
144160
), model_desc.model
145161

146162
if is_ci:
@@ -172,7 +188,7 @@ def test_single_embedding_passage():
172188

173189
canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"]
174190
assert np.allclose(
175-
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
191+
embeddings[:, : canonical_vector.shape[1]], canonical_vector, atol=1e-4
176192
), model_desc.model
177193

178194
if is_ci:
@@ -207,27 +223,6 @@ def test_parallel_processing(dim: int, model_name: str):
207223
delete_model_cache(model.model._model_dir)
208224

209225

210-
def test_task_assignment():
211-
is_ci = os.getenv("CI")
212-
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
213-
214-
if is_ci and not is_manual:
215-
pytest.skip("Skipping in CI non-manual mode")
216-
217-
for model_desc in JinaEmbeddingV3._list_supported_models():
218-
# todo: once we add more models, we should not test models >1GB size locally
219-
model_name = model_desc.model
220-
221-
model = TextEmbedding(model_name=model_name)
222-
223-
for i, task_id in enumerate(Task):
224-
_ = list(model.embed(documents=docs, batch_size=1, task_id=i))
225-
assert model.model.current_task_id == task_id
226-
227-
if is_ci:
228-
delete_model_cache(model.model._model_dir)
229-
230-
231226
@pytest.mark.parametrize("model_name", ["jinaai/jina-embeddings-v3"])
232227
def test_lazy_load(model_name: str):
233228
is_ci = os.getenv("CI")

0 commit comments

Comments (0)