33
44import numpy as np
55
6+ from fastembed .common .onnx_model import OnnxOutputContext
67from fastembed .common .types import NumpyArray
78from fastembed .text .pooled_normalized_embedding import PooledNormalizedEmbedding
89from fastembed .text .onnx_embedding import OnnxTextEmbeddingWorker
@@ -44,9 +45,11 @@ class JinaEmbeddingV3(PooledNormalizedEmbedding):
4445 PASSAGE_TASK = Task .RETRIEVAL_PASSAGE
4546 QUERY_TASK = Task .RETRIEVAL_QUERY
4647
47- def __init__ (self , * args : Any , ** kwargs : Any ):
48+ def __init__ (self , * args : Any , task_id : Optional [ int ] = None , ** kwargs : Any ):
4849 super ().__init__ (* args , ** kwargs )
49- self .current_task_id : Union [Task , int ] = self .PASSAGE_TASK
50+ self .default_task_id : Union [Task , int ] = (
51+ task_id if task_id is not None else self .PASSAGE_TASK
52+ )
5053
5154 @classmethod
5255 def _get_worker_class (cls ) -> Type [OnnxTextEmbeddingWorker ]:
@@ -57,30 +60,34 @@ def _list_supported_models(cls) -> list[DenseModelDescription]:
5760 return supported_multitask_models
5861
5962 def _preprocess_onnx_input (
60- self , onnx_input : dict [str , NumpyArray ], ** kwargs : Any
63+ self ,
64+ onnx_input : dict [str , NumpyArray ],
65+ task_id : Optional [Union [int , Task ]] = None ,
66+ ** kwargs : Any ,
6167 ) -> dict [str , NumpyArray ]:
62- onnx_input ["task_id" ] = np .array (self .current_task_id , dtype = np .int64 )
68+ if task_id is None :
69+ raise ValueError (f"task_id must be provided for JinaEmbeddingV3, got <{ task_id } >" )
70+ onnx_input ["task_id" ] = np .array (task_id , dtype = np .int64 )
6371 return onnx_input
6472
6573 def embed (
6674 self ,
6775 documents : Union [str , Iterable [str ]],
6876 batch_size : int = 256 ,
6977 parallel : Optional [int ] = None ,
70- task_id : int = PASSAGE_TASK ,
78+ task_id : Optional [ int ] = None ,
7179 ** kwargs : Any ,
7280 ) -> Iterable [NumpyArray ]:
73- self .current_task_id = task_id
74- kwargs ["task_id" ] = task_id
75- yield from super ().embed (documents , batch_size , parallel , ** kwargs )
81+ task_id = (
82+ task_id if task_id is not None else self .default_task_id
83+ ) # required for multiprocessing
84+ yield from super ().embed (documents , batch_size , parallel , task_id = task_id , ** kwargs )
7685
7786 def query_embed (self , query : Union [str , Iterable [str ]], ** kwargs : Any ) -> Iterable [NumpyArray ]:
78- self .current_task_id = self .QUERY_TASK
79- yield from super ().embed (query , ** kwargs )
87+ yield from super ().embed (query , task_id = self .QUERY_TASK , ** kwargs )
8088
8189 def passage_embed (self , texts : Iterable [str ], ** kwargs : Any ) -> Iterable [NumpyArray ]:
82- self .current_task_id = self .PASSAGE_TASK
83- yield from super ().embed (texts , ** kwargs )
90+ yield from super ().embed (texts , task_id = self .PASSAGE_TASK , ** kwargs )
8491
8592
8693class JinaEmbeddingV3Worker (OnnxTextEmbeddingWorker ):
@@ -90,11 +97,15 @@ def init_embedding(
9097 cache_dir : str ,
9198 ** kwargs : Any ,
9299 ) -> JinaEmbeddingV3 :
93- model = JinaEmbeddingV3 (
100+ return JinaEmbeddingV3 (
94101 model_name = model_name ,
95102 cache_dir = cache_dir ,
96103 threads = 1 ,
97104 ** kwargs ,
98105 )
99- model .current_task_id = kwargs ["task_id" ]
100- return model
106+
107+ def process (self , items : Iterable [tuple [int , Any ]]) -> Iterable [tuple [int , OnnxOutputContext ]]:
108+ self .model : JinaEmbeddingV3 # mypy complaints `self.model` does not have `default_task_id`
109+ for idx , batch in items :
110+ onnx_output = self .model .onnx_embed (batch , task_id = self .model .default_task_id )
111+ yield idx , onnx_output
0 commit comments