Skip to content
This repository was archived by the owner on Jan 6, 2026. It is now read-only.

Commit c2a1432

Browse files
committed
fix(transformers): fix imports
1 parent de8c59b commit c2a1432

11 files changed

Lines changed: 398 additions & 585 deletions

File tree

emm/indexing/__init__.py

Lines changed: 19 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
- PandasNaiveIndexer: Simple O(n^2) indexing for small datasets
2525
- PandasSortedNeighbourhoodIndexer: Sorted neighborhood indexing
2626
27-
Optional indexers (require additional dependencies):
27+
Optional indexers:
2828
- Spark indexers (requires pyspark):
2929
- SparkCosSimIndexer
3030
- SparkCandidateSelectionEstimator
@@ -35,28 +35,33 @@
3535

3636
from __future__ import annotations
3737

38+
# Core indexers
3839
from emm.indexing.pandas_cos_sim_matcher import PandasCosSimIndexer
3940
from emm.indexing.pandas_naive_indexer import PandasNaiveIndexer
4041
from emm.indexing.pandas_sni import PandasSortedNeighbourhoodIndexer
4142

4243
__all__ = [
43-
"PandasCosSimIndexer",
44+
"PandasCosSimIndexer",
4445
"PandasNaiveIndexer",
45-
"PandasSortedNeighbourhoodIndexer",
46+
"PandasSortedNeighbourhoodIndexer"
4647
]
4748

48-
# Feature detection for sentence transformers
49-
HAS_SENTENCE_TRANSFORMER = False
49+
# Optional sentence transformer support
5050
try:
51-
import sentence_transformers
52-
HAS_SENTENCE_TRANSFORMER = True
51+
from emm.indexing.pandas_sentence_transformer import PandasSentenceTransformerIndexer
52+
__all__.append("PandasSentenceTransformerIndexer")
5353
except ImportError:
5454
pass
5555

56-
# Only import if dependencies are available
57-
if HAS_SENTENCE_TRANSFORMER:
58-
try:
59-
from emm.indexing.pandas_sentence_transformer import PandasSentenceTransformerIndexer
60-
__all__.append("PandasSentenceTransformerIndexer")
61-
except ImportError:
62-
HAS_SENTENCE_TRANSFORMER = False
56+
# Optional Spark support
57+
try:
58+
from emm.indexing.spark_cos_sim_matcher import SparkCosSimIndexer
59+
from emm.indexing.spark_candidate_selection import SparkCandidateSelectionEstimator
60+
from emm.indexing.spark_sni import SparkSortedNeighbourhoodIndexer
61+
__all__.extend([
62+
"SparkCosSimIndexer",
63+
"SparkCandidateSelectionEstimator",
64+
"SparkSortedNeighbourhoodIndexer"
65+
])
66+
except ImportError:
67+
pass

emm/indexing/base_indexer.py

Lines changed: 23 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,8 @@
2121

2222
from emm.base.module import Module
2323
from emm.version import __version__
24+
from typing import List, Optional, Dict, Any
25+
import torch
2426

2527

2628
class BaseIndexer(Module):
@@ -49,7 +51,7 @@ def decrease_window_by_one_step(self):
4951
class CosSimBaseIndexer(BaseIndexer):
5052
"""Base implementation of CosSimIndexer class"""
5153

52-
def __init__(self, num_candidates: int) -> None:
54+
def __init__(self, num_candidates: int = 5) -> None:
5355
super().__init__()
5456
if num_candidates <= 0:
5557
msg = "Number of candidates should be a positive integer"
@@ -71,6 +73,26 @@ def decrease_window_by_one_step(self) -> None:
7173
self.num_candidates -= 1
7274

7375

76+
class SentenceTransformerBaseIndexer(BaseIndexer):
77+
"""Base class for sentence transformer based indexers"""
78+
def __init__(
79+
self,
80+
model_name: str = "all-MiniLM-L6-v2",
81+
device: Optional[str] = None,
82+
batch_size: int = 32,
83+
model_kwargs: Optional[Dict[str, Any]] = None,
84+
encode_kwargs: Optional[Dict[str, Any]] = None,
85+
similarity_threshold: float = 0.5,
86+
):
87+
super().__init__()
88+
self.model_name = model_name
89+
self.device = device
90+
self.batch_size = batch_size
91+
self.model_kwargs = model_kwargs or {}
92+
self.encode_kwargs = encode_kwargs or {}
93+
self.similarity_threshold = similarity_threshold
94+
95+
7496
class SNBaseIndexer(BaseIndexer):
7597
"""Base implementation of SN Indexer class"""
7698

0 commit comments

Comments
 (0)