2121
2222from emm .base .module import Module
2323from emm .version import __version__
24+ from typing import List , Optional , Dict , Any
25+ import torch
2426
2527
2628class BaseIndexer (Module ):
@@ -49,7 +51,7 @@ def decrease_window_by_one_step(self):
4951class CosSimBaseIndexer (BaseIndexer ):
5052 """Base implementation of CosSimIndexer class"""
5153
52- def __init__ (self , num_candidates : int ) -> None :
54+ def __init__ (self , num_candidates : int = 5 ) -> None :
5355 super ().__init__ ()
5456 if num_candidates <= 0 :
5557 msg = "Number of candidates should be a positive integer"
@@ -71,6 +73,26 @@ def decrease_window_by_one_step(self) -> None:
7173 self .num_candidates -= 1
7274
7375
76+ class SentenceTransformerBaseIndexer (BaseIndexer ):
77+ """Base class for sentence transformer based indexers"""
78+ def __init__ (
79+ self ,
80+ model_name : str = "all-MiniLM-L6-v2" ,
81+ device : Optional [str ] = None ,
82+ batch_size : int = 32 ,
83+ model_kwargs : Optional [Dict [str , Any ]] = None ,
84+ encode_kwargs : Optional [Dict [str , Any ]] = None ,
85+ similarity_threshold : float = 0.5 ,
86+ ):
87+ super ().__init__ ()
88+ self .model_name = model_name
89+ self .device = device
90+ self .batch_size = batch_size
91+ self .model_kwargs = model_kwargs or {}
92+ self .encode_kwargs = encode_kwargs or {}
93+ self .similarity_threshold = similarity_threshold
94+
95+
7496class SNBaseIndexer (BaseIndexer ):
7597 """Base implementation of SN Indexer class"""
7698
0 commit comments