44from sentence_transformers import SentenceTransformer
55from sklearn .neighbors import NearestNeighbors
66from sklearn .base import TransformerMixin
7+ import torch
78
89from emm .indexing .base_indexer import CosSimBaseIndexer
910from emm .helper .blocking_functions import _parse_blocking_func
@@ -24,6 +25,7 @@ def __init__(
2425 model_kwargs : Optional [Dict [str , Any ]] = None ,
2526 encode_kwargs : Optional [Dict [str , Any ]] = None ,
2627 similarity_threshold : float = 0.5 ,
28+ num_candidates : int = 10 ,
2729 ** kwargs ,
2830 ) -> None :
2931 """Initialize sentence transformer indexer
@@ -39,6 +41,7 @@ def __init__(
3941 model_kwargs: Additional kwargs for model initialization (e.g. {'truncate_dim': 384})
4042 encode_kwargs: Additional kwargs for encoding method
4143 similarity_threshold: Similarity threshold for filtering matches
44+ num_candidates: Number of nearest neighbors to return
4245 **kwargs: Additional indexer parameters
4346 """
4447 check_sentence_transformers_available ()
@@ -50,11 +53,11 @@ def __init__(
5053 model_kwargs = model_kwargs ,
5154 encode_kwargs = encode_kwargs
5255 )
53- CosSimBaseIndexer .__init__ (self , ** kwargs )
56+ CosSimBaseIndexer .__init__ (self , num_candidates = num_candidates )
5457 self .input_col = input_col
5558 self .blocking_func = _parse_blocking_func (kwargs .get ('blocking_func' ))
5659 self .nn = NearestNeighbors (
57- n_neighbors = kwargs . get ( ' num_candidates' , 10 ) ,
60+ n_neighbors = num_candidates ,
5861 metric = 'cosine'
5962 )
6063 self .base_embeddings = None
@@ -121,82 +124,117 @@ def transform(self, X: pd.DataFrame, multiple_indexers: bool = False) -> pd.Data
121124 with Timer ("SentenceTransformerIndexer.transform" ) as timer :
122125 logger .info (f"Transforming { len (X )} records" )
123126
127+ # Pre-allocate results list with estimated size
128+ est_size = len (X ) * self .num_candidates
124129 results = []
130+ results .reserve (est_size ) # Pre-allocate memory
125131
126132 if self .blocking_func is not None :
133+ # Process in blocks to reduce memory usage
127134 blocks = X [self .input_col ].map (self .blocking_func )
135+
128136 for block in blocks .unique ():
129137 if block not in self .base_embeddings :
130138 continue
131139
132- block_embeddings = self .encode_texts (
133- X [blocks == block ][self .input_col ].tolist (),
140+ block_mask = blocks == block
141+ block_texts = X [block_mask ][self .input_col ].tolist ()
142+
143+ # Use efficient encoding with mixed precision
144+ with torch .cuda .amp .autocast (enabled = self .use_mixed_precision ):
145+ block_embeddings = self .encode_texts (block_texts )
146+
147+ # Calculate similarities efficiently
148+ similarities = self .calculate_pairwise_cosine_similarity (
149+ block_embeddings ,
150+ self .base_embeddings [block ],
151+ batch_size = self .batch_size
134152 )
135153
136- distances , indices = self .nn .kneighbors (block_embeddings )
137- similarities = 1 - distances
154+ # Get top k efficiently using numpy
155+ top_k_indices = np .argpartition (- similarities ,
156+ self .num_candidates - 1 ,
157+ axis = 1 )[:,:self .num_candidates ]
158+ top_k_similarities = np .take_along_axis (similarities , top_k_indices , axis = 1 )
138159
160+ # Process matches efficiently
139161 block_results = self ._process_matches (
140- similarities ,
141- indices ,
142- X [blocks == block ].index ,
162+ top_k_similarities ,
163+ top_k_indices ,
164+ X [block_mask ].index ,
143165 self .base_indices [block ]
144166 )
145167 results .extend (block_results )
146168
147- # Clean up memory after each block
148- del block_embeddings
169+ # Clean up memory
170+ del block_embeddings , similarities , top_k_indices , top_k_similarities
149171 if self .device .type == 'cuda' :
150172 torch .cuda .empty_cache ()
151173 else :
174+ # Process all records at once with batching
152175 query_embeddings = self .encode_texts (
153176 X [self .input_col ].tolist (),
154177 )
155178
156- distances , indices = self .nn .kneighbors (query_embeddings )
157- similarities = 1 - distances
179+ # Use efficient batch-wise similarity calculation
180+ similarities = self .calculate_pairwise_cosine_similarity (
181+ query_embeddings ,
182+ self .base_embeddings ,
183+ batch_size = self .batch_size
184+ )
185+
186+ # Get top k efficiently
187+ top_k_indices = np .argpartition (- similarities ,
188+ self .num_candidates - 1 ,
189+ axis = 1 )[:,:self .num_candidates ]
190+ top_k_similarities = np .take_along_axis (similarities , top_k_indices , axis = 1 )
158191
159192 results = self ._process_matches (
160- similarities ,
161- indices ,
193+ top_k_similarities ,
194+ top_k_indices ,
162195 X .index ,
163196 self .base_indices
164197 )
165198
166199 # Clean up memory
167- del query_embeddings
200+ del query_embeddings , similarities , top_k_indices , top_k_similarities
168201 if self .device .type == 'cuda' :
169202 torch .cuda .empty_cache ()
170-
203+
204+ # Create DataFrame efficiently
171205 candidates = pd .DataFrame (results )
172206 if len (candidates ) == 0 :
173207 candidates = pd .DataFrame (columns = ['uid' , 'gt_uid' , 'score' , 'rank' ])
174-
175- # Sort and rank within groups like other indexers
208+
209+ # Optimize sorting and ranking
176210 candidates = candidates .sort_values (['uid' , 'score' ], ascending = [True , False ])
177- gb = candidates .groupby ('uid' )
178- candidates ['rank' ] = gb ['score' ].transform (lambda x : range (1 , len (x ) + 1 ))
211+ candidates ['rank' ] = candidates .groupby ('uid' ).cumcount () + 1
179212
180213 if multiple_indexers :
181214 candidates [self .column_prefix ()] = 1
182-
215+
216+ # Efficient column operations
183217 if self .carry_on_cols :
184- for col in self .carry_on_cols :
185- if col in X .columns :
186- candidates [col ] = X [col ]
187-
218+ candidates = candidates .merge (
219+ X [['uid' ] + self .carry_on_cols ],
220+ on = 'uid' ,
221+ how = 'left'
222+ )
223+
224+ # Rename columns efficiently
188225 candidates = candidates .rename (columns = {
189226 'score' : f'score_{ self .column_prefix ()} ' ,
190227 'rank' : f'rank_{ self .column_prefix ()} '
191228 })
192229
193- candidates = candidates [candidates ["similarity_score" ] >= self .similarity_threshold ]
230+ # Filter by threshold
231+ candidates = candidates [candidates [f"score_{ self .column_prefix ()} " ] >= self .similarity_threshold ]
194232
195233 logger .info (f"Generated { len (candidates )} candidates" )
196234 return candidates
197235
198236 finally :
199- # Ensure memory is cleaned up even if an error occurs
237+ # Ensure memory cleanup
200238 if self .device .type == 'cuda' :
201239 torch .cuda .empty_cache ()
202240
0 commit comments