|
| 1 | +from dataclasses import asdict |
| 2 | +from typing import Union, Iterable, Optional, Any, Type |
| 3 | + |
| 4 | +from fastembed.common.model_description import DenseModelDescription, ModelSource |
| 5 | +from fastembed.common.onnx_model import OnnxOutputContext |
| 6 | +from fastembed.common.types import NumpyArray |
| 7 | +from fastembed.late_interaction.late_interaction_embedding_base import ( |
| 8 | + LateInteractionTextEmbeddingBase, |
| 9 | +) |
| 10 | +from fastembed.text.onnx_embedding import OnnxTextEmbedding |
| 11 | +from fastembed.text.onnx_text_model import TextEmbeddingWorker |
| 12 | +import numpy as np |
| 13 | + |
# Registry of token-level embedding models served by TokenEmbeddingsModel below.
# NOTE(review): the public model name is under the jinaai/ namespace while the
# ONNX weights are pulled from the xenova/ HF mirror — confirm the mirror is intended.
supported_token_embeddings_models = [
    DenseModelDescription(
        model="jinaai/jina-embeddings-v2-small-en-tokens",
        dim=512,
        description="Text embeddings, Unimodal (text), English, 8192 input tokens truncation,"
        " Prefixes for queries/documents: not necessary, 2023 year.",
        license="apache-2.0",
        size_in_GB=0.12,
        sources=ModelSource(hf="xenova/jina-embeddings-v2-small-en"),
        model_file="onnx/model.onnx",
    ),
]
| 26 | + |
| 27 | + |
class TokenEmbeddingsModel(OnnxTextEmbedding, LateInteractionTextEmbeddingBase):
    """Embedder that yields one vector per (non-padded) token of each document.

    Reuses the dense ONNX text pipeline from ``OnnxTextEmbedding`` but, instead
    of pooling, returns the full token-level embedding matrix per document.
    """

    @classmethod
    def _list_supported_models(cls) -> list[DenseModelDescription]:
        """Lists the supported models.

        Returns:
            list[DenseModelDescription]: A list of DenseModelDescription objects
                containing the model information.
        """
        return supported_token_embeddings_models

    @classmethod
    def list_supported_models(cls) -> list[dict[str, Any]]:
        """Lists the supported models.

        Returns:
            list[dict[str, Any]]: A list of dictionaries containing the model information.
        """
        return [asdict(model) for model in cls._list_supported_models()]

    @classmethod
    def _get_worker_class(cls) -> Type[TextEmbeddingWorker[NumpyArray]]:
        return TokensEmbeddingWorker

    def _post_process_onnx_output(
        self, output: OnnxOutputContext, **kwargs: Any
    ) -> Iterable[NumpyArray]:
        """Strip padding tokens from the raw ONNX output.

        Args:
            output: Raw model output; ``model_output`` has shape
                (batch_size, sequence_length, hidden_size) and
                ``attention_mask`` has shape (batch_size, sequence_length).

        Yields:
            One (num_real_tokens, hidden_size) array per document, keeping only
            the embeddings whose attention-mask entry is 1.

        Raises:
            ValueError: If the attention mask is missing from the output.
        """
        # Size: (batch_size, sequence_length, hidden_size)
        embeddings = output.model_output
        # Explicit error instead of `assert`: asserts are stripped under -O,
        # and the mask is required below to drop the padding tokens.
        if output.attention_mask is None:
            raise ValueError("attention_mask is required to post-process token embeddings")
        # Size: (batch_size, sequence_length)
        masks = output.attention_mask

        # For each document we only select those embeddings that are not masked out
        for i in range(embeddings.shape[0]):
            yield embeddings[i, masks[i] == 1]

    def embed(
        self,
        documents: Union[str, Iterable[str]],
        batch_size: int = 256,
        parallel: Optional[int] = None,
        **kwargs: Any,
    ) -> Iterable[NumpyArray]:
        """Embed documents, yielding one token-embedding matrix per document.

        Args:
            documents: A single string or an iterable of strings to embed.
            batch_size: Number of documents per ONNX inference call.
            parallel: Number of data-parallel workers (None = single process).

        Yields:
            NumpyArray of shape (num_real_tokens, hidden_size) per document.
        """
        yield from super().embed(documents, batch_size=batch_size, parallel=parallel, **kwargs)

    def tokenize_docs(self, documents: list[str]) -> list[NumpyArray]:
        """Tokenize documents and return the raw token-id arrays.

        Args:
            documents: Texts to tokenize.

        Returns:
            One int32 array of token ids per document.

        Raises:
            ValueError: If the tokenizer has not been initialized.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer not initialized")
        encoded = self.tokenizer.encode_batch(documents)
        return [np.array(e.ids, dtype=np.int32) for e in encoded]
| 78 | + |
| 79 | + |
class TokensEmbeddingWorker(TextEmbeddingWorker[NumpyArray]):
    """Data-parallel worker that hosts a single-threaded TokenEmbeddingsModel."""

    def init_embedding(
        self, model_name: str, cache_dir: str, **kwargs: Any
    ) -> TokenEmbeddingsModel:
        # Each worker process pins its model to one thread; parallelism comes
        # from running several workers side by side.
        return TokenEmbeddingsModel(model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs)
| 90 | + |
| 91 | + |
if __name__ == "__main__":
    # Smoke test: list the registry, embed a few documents, and inspect the
    # per-document token-embedding shapes plus the raw token ids.
    print(TokenEmbeddingsModel.list_supported_models())
    model = TokenEmbeddingsModel(model_name="jinaai/jina-embeddings-v2-small-en-tokens")
    docs = ["Hello, world!", "hello", "hello hello"]

    for token_matrix in model.embed(docs):
        print(token_matrix.shape)

    print(model.tokenize_docs(docs))