
Commit 81d7b92

feat(transformers): big performance overhaul

1 parent 6ddabb1, commit 81d7b92

5 files changed: 439 additions, 143 deletions

emm/indexing/pandas_sentence_transformer.py: 66 additions, 28 deletions
@@ -4,6 +4,7 @@
 from sentence_transformers import SentenceTransformer
 from sklearn.neighbors import NearestNeighbors
 from sklearn.base import TransformerMixin
+import torch

 from emm.indexing.base_indexer import CosSimBaseIndexer
 from emm.helper.blocking_functions import _parse_blocking_func
@@ -24,6 +25,7 @@ def __init__(
         model_kwargs: Optional[Dict[str, Any]] = None,
         encode_kwargs: Optional[Dict[str, Any]] = None,
         similarity_threshold: float = 0.5,
+        num_candidates: int = 10,
         **kwargs,
     ) -> None:
         """Initialize sentence transformer indexer
@@ -39,6 +41,7 @@ def __init__(
             model_kwargs: Additional kwargs for model initialization (e.g. {'truncate_dim': 384})
             encode_kwargs: Additional kwargs for encoding method
             similarity_threshold: Similarity threshold for filtering matches
+            num_candidates: Number of nearest neighbors to return
             **kwargs: Additional indexer parameters
         """
         check_sentence_transformers_available()
@@ -50,11 +53,11 @@ def __init__(
             model_kwargs=model_kwargs,
             encode_kwargs=encode_kwargs
         )
-        CosSimBaseIndexer.__init__(self, **kwargs)
+        CosSimBaseIndexer.__init__(self, num_candidates=num_candidates)
         self.input_col = input_col
         self.blocking_func = _parse_blocking_func(kwargs.get('blocking_func'))
         self.nn = NearestNeighbors(
-            n_neighbors=kwargs.get('num_candidates', 10),
+            n_neighbors=num_candidates,
             metric='cosine'
         )
         self.base_embeddings = None
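With these changes num_candidates is a first-class constructor argument instead of being fished out of **kwargs in two places, so the base indexer and the NearestNeighbors fallback can no longer disagree on k. Note that the new CosSimBaseIndexer.__init__ call forwards only num_candidates; other options that previously travelled through **kwargs no longer reach the base class. A hedged usage sketch; the class name SentenceTransformerIndexer is inferred from the Timer label in transform below, and the column name is a placeholder:

    from emm.indexing.pandas_sentence_transformer import SentenceTransformerIndexer

    indexer = SentenceTransformerIndexer(
        input_col="name",        # placeholder column name
        similarity_threshold=0.5,
        num_candidates=10,       # forwarded to both CosSimBaseIndexer and NearestNeighbors
    )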
@@ -121,15 +124,30 @@ def transform(self, X: pd.DataFrame, multiple_indexers: bool = False) -> pd.DataFrame:
             with Timer("SentenceTransformerIndexer.transform") as timer:
                 logger.info(f"Transforming {len(X)} records")

+                # Python lists grow dynamically and cannot be pre-allocated
+                est_size = len(X) * self.num_candidates
                 results = []
+                logger.debug(f"Expecting up to {est_size} candidate pairs")

                 if self.blocking_func is not None:
+                    # Process in blocks to reduce memory usage
                     blocks = X[self.input_col].map(self.blocking_func)
+
                     for block in blocks.unique():
                         if block not in self.base_embeddings:
                             continue

-                        block_embeddings = self.encode_texts(
-                            X[blocks == block][self.input_col].tolist(),
+                        block_mask = blocks == block
+                        block_texts = X[block_mask][self.input_col].tolist()
+
+                        # Use efficient encoding with mixed precision
+                        with torch.cuda.amp.autocast(enabled=self.use_mixed_precision):
+                            block_embeddings = self.encode_texts(block_texts)
+
+                        # Calculate similarities efficiently
+                        similarities = self.calculate_pairwise_cosine_similarity(
+                            block_embeddings,
+                            self.base_embeddings[block],
+                            batch_size=self.batch_size
                         )

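Encoding now runs under torch.cuda.amp.autocast, which executes matmul-heavy CUDA ops in reduced precision; on CPU the context manager is a no-op. A minimal standalone sketch of the same pattern; the model name is an arbitrary choice, and use_mixed_precision stands in for the self.use_mixed_precision attribute defined outside this hunk:

    import torch
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-MiniLM-L6-v2")  # placeholder model
    texts = ["Apple Inc.", "Apple Incorporated", "Microsoft Corp"]

    # Autocast only affects CUDA ops, so enable it only when a GPU is present
    use_mixed_precision = torch.cuda.is_available()
    with torch.cuda.amp.autocast(enabled=use_mixed_precision):
        embeddings = model.encode(texts, batch_size=32, convert_to_numpy=True)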
@@ -136,15 +154,19 @@ def transform(self, X: pd.DataFrame, multiple_indexers: bool = False) -> pd.DataFrame:
-                        distances, indices = self.nn.kneighbors(block_embeddings)
-                        similarities = 1 - distances
+                        # Get top k efficiently using numpy
+                        top_k_indices = np.argpartition(-similarities,
+                                                        self.num_candidates - 1,
+                                                        axis=1)[:, :self.num_candidates]
+                        top_k_similarities = np.take_along_axis(similarities, top_k_indices, axis=1)

+                        # Process matches efficiently
                         block_results = self._process_matches(
-                            similarities,
-                            indices,
-                            X[blocks == block].index,
+                            top_k_similarities,
+                            top_k_indices,
+                            X[block_mask].index,
                             self.base_indices[block]
                         )
                         results.extend(block_results)

-                        # Clean up memory after each block
-                        del block_embeddings
+                        # Clean up memory
+                        del block_embeddings, similarities, top_k_indices, top_k_similarities
                         if self.device.type == 'cuda':
                             torch.cuda.empty_cache()
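The nn.kneighbors call is gone: with the full similarity matrix already computed, np.argpartition selects the k best columns per row in average linear time instead of the O(n log n) of a full argsort, and np.take_along_axis gathers the matching scores. A self-contained demo of the pattern:

    import numpy as np

    rng = np.random.default_rng(0)
    similarities = rng.random((4, 100))  # 4 queries x 100 ground-truth rows
    k = 10

    # The k largest scores per row end up in the first k slots, unsorted
    top_k_indices = np.argpartition(-similarities, k - 1, axis=1)[:, :k]
    top_k_similarities = np.take_along_axis(similarities, top_k_indices, axis=1)
    assert top_k_similarities.shape == (4, k)

argpartition leaves the top k unordered within each row; the sort_values call later in this method handles the final ordering, so no per-row sort is needed here.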
@@ -151,19 +173,30 @@ def transform(self, X: pd.DataFrame, multiple_indexers: bool = False) -> pd.DataFrame:
                 else:
+                    # Process all records at once with batching
                     query_embeddings = self.encode_texts(
                         X[self.input_col].tolist(),
                     )

-                    distances, indices = self.nn.kneighbors(query_embeddings)
-                    similarities = 1 - distances
+                    # Use efficient batch-wise similarity calculation
+                    similarities = self.calculate_pairwise_cosine_similarity(
+                        query_embeddings,
+                        self.base_embeddings,
+                        batch_size=self.batch_size
+                    )
+
+                    # Get top k efficiently
+                    top_k_indices = np.argpartition(-similarities,
+                                                    self.num_candidates - 1,
+                                                    axis=1)[:, :self.num_candidates]
+                    top_k_similarities = np.take_along_axis(similarities, top_k_indices, axis=1)

                     results = self._process_matches(
-                        similarities,
-                        indices,
+                        top_k_similarities,
+                        top_k_indices,
                         X.index,
                         self.base_indices
                     )

                     # Clean up memory
-                    del query_embeddings
+                    del query_embeddings, similarities, top_k_indices, top_k_similarities
                     if self.device.type == 'cuda':
                         torch.cuda.empty_cache()
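Both branches call calculate_pairwise_cosine_similarity with a batch_size keyword, but its definition sits outside this diff. A minimal NumPy sketch of what such a batched helper could look like, assuming unnormalized float inputs; the function name mirrors the call site, but the body is an illustration, not the commit's implementation:

    import numpy as np

    def calculate_pairwise_cosine_similarity(queries, base, batch_size=1024):
        # Normalize once so each batch reduces to a plain matrix product
        queries = queries / np.linalg.norm(queries, axis=1, keepdims=True)
        base = base / np.linalg.norm(base, axis=1, keepdims=True)
        out = np.empty((queries.shape[0], base.shape[0]), dtype=np.float32)
        for start in range(0, queries.shape[0], batch_size):
            out[start:start + batch_size] = queries[start:start + batch_size] @ base.T
        return out

Batching bounds the size of each individual matrix product; the full similarity matrix is still materialized, and the argpartition step then reduces it to k columns per query.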
@@ -170,18 +203,21 @@ def transform(self, X: pd.DataFrame, multiple_indexers: bool = False) -> pd.DataFrame:
-
+
+                # Create DataFrame efficiently
                 candidates = pd.DataFrame(results)
                 if len(candidates) == 0:
                     candidates = pd.DataFrame(columns=['uid', 'gt_uid', 'score', 'rank'])
-
-                # Sort and rank within groups like other indexers
+
+                # Optimize sorting and ranking
                 candidates = candidates.sort_values(['uid', 'score'], ascending=[True, False])
-                gb = candidates.groupby('uid')
-                candidates['rank'] = gb['score'].transform(lambda x: range(1, len(x) + 1))
+                candidates['rank'] = candidates.groupby('uid').cumcount() + 1

                 if multiple_indexers:
                     candidates[self.column_prefix()] = 1
-
+
+                # Efficient column operations
                 if self.carry_on_cols:
-                    for col in self.carry_on_cols:
-                        if col in X.columns:
-                            candidates[col] = X[col]
-
+                    candidates = candidates.merge(
+                        X[['uid'] + self.carry_on_cols],
+                        on='uid',
+                        how='left'
+                    )
+
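Two hot spots in the post-processing are rewritten. Ranking drops the per-group Python lambda for a single vectorized cumcount pass, which is equivalent because the frame is already sorted by score within each uid. A toy demo:

    import pandas as pd

    candidates = pd.DataFrame({
        "uid":   [1, 1, 1, 2, 2],
        "score": [0.9, 0.7, 0.4, 0.8, 0.6],
    })
    candidates = candidates.sort_values(["uid", "score"], ascending=[True, False])

    # cumcount() numbers rows 0..n-1 within each uid group in one vectorized pass
    candidates["rank"] = candidates.groupby("uid").cumcount() + 1

The carry_on_cols change is not a pure refactor: the old loop copied columns by index alignment, while the merge joins on a uid column that X must now provide.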
@@ -188,15 +224,17 @@ def transform(self, X: pd.DataFrame, multiple_indexers: bool = False) -> pd.DataFrame:
+                # Rename columns efficiently
                 candidates = candidates.rename(columns={
                     'score': f'score_{self.column_prefix()}',
                     'rank': f'rank_{self.column_prefix()}'
                 })

-                candidates = candidates[candidates["similarity_score"] >= self.similarity_threshold]
+                # Filter by threshold
+                candidates = candidates[candidates[f"score_{self.column_prefix()}"] >= self.similarity_threshold]

                 logger.info(f"Generated {len(candidates)} candidates")
                 return candidates

         finally:
-            # Ensure memory is cleaned up even if an error occurs
+            # Ensure memory cleanup
             if self.device.type == 'cuda':
                 torch.cuda.empty_cache()

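The threshold change at the end fixes a real bug: the old filter indexed candidates["similarity_score"], a column that never exists at that point because the score column has already been renamed with the indexer prefix, so the lookup raised a KeyError. A toy illustration of the fixed ordering; the sct prefix is a made-up stand-in for column_prefix():

    import pandas as pd

    candidates = pd.DataFrame({"uid": [1, 1], "score": [0.9, 0.3], "rank": [1, 2]})

    prefix = "sct"  # stand-in for self.column_prefix()
    candidates = candidates.rename(
        columns={"score": f"score_{prefix}", "rank": f"rank_{prefix}"}
    )

    # After the rename only the prefixed column exists, so filter on it
    candidates = candidates[candidates[f"score_{prefix}"] >= 0.5]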