Skip to content
This repository was archived by the owner on Jan 6, 2026. It is now read-only.

Commit 6ddabb1

Browse files
committed
fix(transformers): slopmop
1 parent 0de44e4 commit 6ddabb1

5 files changed

Lines changed: 43 additions & 31 deletions

File tree

emm/indexing/pandas_sentence_transformer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(
2323
batch_size: Optional[int] = None,
2424
model_kwargs: Optional[Dict[str, Any]] = None,
2525
encode_kwargs: Optional[Dict[str, Any]] = None,
26+
similarity_threshold: float = 0.5,
2627
**kwargs,
2728
) -> None:
2829
"""Initialize sentence transformer indexer
@@ -37,6 +38,7 @@ def __init__(
3738
batch_size: Batch size for encoding
3839
model_kwargs: Additional kwargs for model initialization (e.g. {'truncate_dim': 384})
3940
encode_kwargs: Additional kwargs for encoding method
41+
similarity_threshold: Similarity threshold for filtering matches
4042
**kwargs: Additional indexer parameters
4143
"""
4244
check_sentence_transformers_available()
@@ -60,6 +62,7 @@ def __init__(
6062
self.gt = None
6163
logger.info(f"Initializing SentenceTransformerIndexer with model {model_name}")
6264
self.carry_on_cols = []
65+
self.similarity_threshold = similarity_threshold
6366

6467
def fit(self, X: pd.DataFrame, y: Any = None) -> TransformerMixin:
6568
"""Compute embeddings for base names and fit nearest neighbors
@@ -187,6 +190,8 @@ def transform(self, X: pd.DataFrame, multiple_indexers: bool = False) -> pd.Data
187190
'rank': f'rank_{self.column_prefix()}'
188191
})
189192

193+
candidates = candidates[candidates["similarity_score"] >= self.similarity_threshold]
194+
190195
logger.info(f"Generated {len(candidates)} candidates")
191196
return candidates
192197

@@ -198,7 +203,7 @@ def transform(self, X: pd.DataFrame, multiple_indexers: bool = False) -> pd.Data
198203
def _process_matches(self, similarities, indices, query_indices, base_indices):
199204
"""Process matches and filter by similarity threshold"""
200205
results = []
201-
mask = similarities >= self.cos_sim_lower_bound
206+
mask = similarities >= self.similarity_threshold
202207

203208
for i, (sim_row, idx_row, mask_row) in enumerate(zip(similarities, indices, mask)):
204209
valid_indices = idx_row[mask_row]

emm/models/sentence_transformer/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class BaseSentenceTransformerComponent:
3030
def __init__(
3131
self,
3232
model_name: str = "all-MiniLM-L6-v2",
33+
similarity_threshold: float = 0.5,
3334
device: Optional[str] = None,
3435
batch_size: Optional[int] = None,
3536
model_kwargs: Optional[Dict[str, Any]] = None,
@@ -39,6 +40,7 @@ def __init__(
3940
4041
Args:
4142
model_name: Name or path of the sentence transformer model to use
43+
similarity_threshold: Similarity threshold for filtering candidates
4244
device: Device to run model on ('cpu', 'cuda', or None for auto-detection)
4345
batch_size: Batch size for encoding (None for auto-detection)
4446
model_kwargs: Additional kwargs passed to SentenceTransformer initialization
@@ -74,6 +76,7 @@ def __init__(
7476
self.encode_kwargs = encode_kwargs or {}
7577
self.model = SentenceTransformer(model_name, device=self.device, **self.model_kwargs)
7678
self.model_name = model_name
79+
self.similarity_threshold = similarity_threshold
7780

7881
def encode_texts(self, texts: List[str]) -> NDArray[np.float32]:
7982
"""Encode a list of texts into embeddings.

emm/models/sentence_transformer/tuning/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class TuningConfig:
1717
wandb_project: Optional W&B project for logging
1818
loss_type: Type of loss function ('dae', 'contrastive', 'combined')
1919
device_count: Number of devices to use
20+
similarity_threshold: Similarity threshold for filtering candidates
2021
"""
2122
model_name: str = 'all-MiniLM-L6-v2'
2223
batch_size: int = 32
@@ -27,4 +28,5 @@ class TuningConfig:
2728
output_path: Optional[Path] = None
2829
wandb_project: Optional[str] = None
2930
loss_type: str = 'dae' # 'dae', 'contrastive', or 'combined'
30-
device_count: int = 1
31+
device_count: int = 1
32+
similarity_threshold: float = 0.5

emm/parameters.py

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,25 @@
2727

2828
ROOT_DIRECTORY = Path(__file__).resolve().parent.parent
2929

30+
# Common sentence transformer settings for reuse
31+
SENTENCE_TRANSFORMER_BASE = {
32+
"type": "sentence_transformer",
33+
"model_name": "all-MiniLM-L6-v2", # Default lightweight model
34+
"num_candidates": 10,
35+
"similarity_threshold": 0.5, # Renamed from cos_sim_lower_bound for clarity
36+
"device": None, # Auto-detect CUDA/CPU
37+
"batch_size": None, # Auto-detect based on available memory
38+
"input_col": "preprocessed",
39+
# Support for model-specific parameters as shown in mixedbread example
40+
"model_kwargs": {
41+
"normalize_embeddings": True,
42+
# Other model-specific params like truncate_dim can be added here
43+
},
44+
"encode_kwargs": {
45+
"normalize_embeddings": True,
46+
},
47+
}
48+
3049
# default model parameters picked up in PandasEntityMatching and SparkEntityMatching
3150
MODEL_PARAMS = {
3251
# type of name preprocessor defined in name_preprocessing.py
@@ -44,17 +63,8 @@
4463
"type": "sni", # Sorted Neighbourhood Indexing,
4564
"window_length": 3,
4665
},
47-
# Sentence transformer indexer
48-
{
49-
"type": "sentence_transformer",
50-
"model_name": "all-MiniLM-L6-v2",
51-
"num_candidates": 10,
52-
"cos_sim_lower_bound": 0.5,
53-
"device": None,
54-
"batch_size": None,
55-
"model_kwargs": None,
56-
"encode_kwargs": None,
57-
},
66+
# Sentence transformer indexer with base settings
67+
SENTENCE_TRANSFORMER_BASE,
5868
],
5969
"partition_size": 5000, # Number of names in ground_truth and names_to_match per Spark partition: across-worker division. (Set to None for no automatic repartitioning)
6070
# input columns:
@@ -88,31 +98,20 @@
8898
"cosine_similarity": {
8999
"tokenizer": "words", # "words" or "characters"
90100
"ngram": 1, # number of token per n-gram
91-
"cos_sim_lower_bound": 0.0,
92-
"num_candidates": 10, # Number of candidates returned by indexer.
93-
"binary_countvectorizer": True, # use binary countVectorizer or not
94-
# the same value as is used in Spark pipeline in CountVectorizer(vocabSize) 2**25=33554432, 2**24=16777216
101+
"similarity_threshold": 0.0, # Renamed from cos_sim_lower_bound for consistency
102+
"num_candidates": 10,
103+
"binary_countvectorizer": True,
95104
"max_features": 2**25,
96-
# Python function to be used in blocking ground_truth & names_to_match (only pairs within the same block will be considered in cosine similarity)
97-
# - None # No Blocking
98-
# - blocking_functions.first() # block using first character
99105
"blocking_func": None,
100106
},
101107
"sni": {
102-
"window_length": 3, # window size for SNI
103-
"mapping_func": None, # custom mapping function applied in SNI step
108+
"window_length": 3,
109+
"mapping_func": None,
104110
},
105111
"naive": {},
106112
"sentence_transformer": {
107-
"model_name": "all-MiniLM-L6-v2", # Default lightweight model or path to fine-tuned model
108-
"num_candidates": 10, # Number of candidates returned by indexer
109-
"cos_sim_lower_bound": 0.5, # Minimum similarity threshold
110-
"batch_size": None, # Will use auto-detection
111-
"device": None, # Auto-detect device
112-
"blocking_func": None, # Optional blocking function
113-
"input_col": "preprocessed", # Input column name
114-
"model_kwargs": None, # Optional kwargs for model initialization
115-
"encode_kwargs": None, # Optional kwargs for encoding
113+
**SENTENCE_TRANSFORMER_BASE,
114+
"blocking_func": None, # Additional parameter specific to indexer
116115
},
117116
}
118117

emm/supervised_model/sentence_transformer_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(
4444
self,
4545
score_col: str = "nm_score",
4646
model_name: str = "all-MiniLM-L6-v2",
47+
similarity_threshold: float = 0.5,
4748
device: Optional[str] = None,
4849
batch_size: Optional[int] = None,
4950
model_kwargs: Optional[Dict[str, Any]] = None,
@@ -74,6 +75,7 @@ def __init__(
7475
)
7576
BaseSupervisedModel.__init__(self)
7677
self.score_col = score_col
78+
self.similarity_threshold = similarity_threshold
7779

7880
def fit(self, X: pd.DataFrame, y: pd.Series | None = None) -> SentenceTransformerLayerTransformer:
7981
"""Placeholder for fit method - not used as we use pre-trained models
@@ -187,6 +189,7 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame | None:
187189
})
188190

189191
X = self.calc_score(X)
192+
X = X[X[self.score_col] >= self.similarity_threshold]
190193
X = self.select_best_score(X, group_cols=["uid"])
191194

192195
timer.log_param("cands", len(X))

0 commit comments

Comments (0)