
Commit 13c8ed7

Update to pass precomputed embeddings to KeyBERTInspired (#2368)
1 parent: 0ed386c

2 files changed: 24 additions & 5 deletions
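In short, document embeddings that were precomputed and passed to BERTopic are now forwarded through _extract_topics and _extract_words_per_topic to KeyBERTInspired, which reuses them instead of re-embedding the representative documents. A minimal usage sketch of the path this commit exercises; the dataset and embedding model are illustrative choices, not part of the commit:

    from sklearn.datasets import fetch_20newsgroups
    from sentence_transformers import SentenceTransformer
    from bertopic import BERTopic
    from bertopic.representation import KeyBERTInspired

    docs = fetch_20newsgroups(subset="all", remove=("headers", "footers", "quotes"))["data"]

    # Precompute the document embeddings once...
    embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = embedding_model.encode(docs, show_progress_bar=False)

    # ...and hand them to fit_transform. With this change, KeyBERTInspired
    # receives the same embeddings rather than re-embedding the
    # representative documents of each topic.
    topic_model = BERTopic(
        embedding_model=embedding_model,
        representation_model=KeyBERTInspired(),
    )
    topics, probs = topic_model.fit_transform(docs, embeddings)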


bertopic/_bertopic.py

Lines changed: 7 additions & 1 deletion
@@ -59,7 +59,7 @@
 from bertopic.representation._mmr import mmr
 from bertopic.backend._utils import select_backend
 from bertopic.vectorizers import ClassTfidfTransformer
-from bertopic.representation import BaseRepresentation
+from bertopic.representation import BaseRepresentation, KeyBERTInspired
 from bertopic.dimensionality import BaseDimensionalityReduction
 from bertopic.cluster._utils import hdbscan_delegator, is_supported_hdbscan
 from bertopic._utils import (

@@ -4051,6 +4051,7 @@ def _extract_topics(
             documents,
             fine_tune_representation=fine_tune_representation,
             calculate_aspects=fine_tune_representation,
+            embeddings=embeddings,
         )
         self._create_topic_vectors(documents=documents, embeddings=embeddings, mappings=mappings)

@@ -4311,6 +4312,7 @@ def _extract_words_per_topic(
         c_tf_idf: csr_matrix = None,
         fine_tune_representation: bool = True,
         calculate_aspects: bool = True,
+        embeddings: np.ndarray = None,
     ) -> Mapping[str, List[Tuple[str, float]]]:
         """Based on tf_idf scores per topic, extract the top n words per topic.

@@ -4326,6 +4328,8 @@ def _extract_words_per_topic(
             fine_tune_representation: If True, the topic representation will be fine-tuned using representation models.
                                       If False, the topic representation will remain as the base c-TF-IDF representation.
             calculate_aspects: Whether to calculate additional topic aspects
+            embeddings: Pre-trained document embeddings. These can be used
+                        instead of an embedding model

         Returns:
             topics: The top words per topic

@@ -4361,6 +4365,8 @@ def _extract_words_per_topic(
         elif fine_tune_representation and isinstance(self.representation_model, list):
             for tuner in self.representation_model:
                 topics = tuner.extract_topics(self, documents, c_tf_idf, topics)
+        elif fine_tune_representation and isinstance(self.representation_model, KeyBERTInspired):
+            topics = self.representation_model.extract_topics(self, documents, c_tf_idf, topics, embeddings)
         elif fine_tune_representation and isinstance(self.representation_model, BaseRepresentation):
             topics = self.representation_model.extract_topics(self, documents, c_tf_idf, topics)
         elif fine_tune_representation and isinstance(self.representation_model, dict):
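A note on the branch ordering above: KeyBERTInspired is itself a BaseRepresentation subclass, so the new, more specific isinstance check has to come before the generic BaseRepresentation branch, otherwise it would never be reached. A small illustrative check:

    from bertopic.representation import BaseRepresentation, KeyBERTInspired

    rep = KeyBERTInspired()
    # Both tests pass, so the elif chain must try the subclass first.
    print(isinstance(rep, KeyBERTInspired))     # True
    print(isinstance(rep, BaseRepresentation))  # True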

bertopic/representation/_keybert.py

Lines changed: 17 additions & 4 deletions
@@ -71,6 +71,7 @@ def extract_topics(
         documents: pd.DataFrame,
         c_tf_idf: csr_matrix,
         topics: Mapping[str, List[Tuple[str, float]]],
+        embeddings: np.ndarray = None,
     ) -> Mapping[str, List[Tuple[str, float]]]:
         """Extract topics.

@@ -79,6 +80,8 @@ def extract_topics(
             documents: All input documents
             c_tf_idf: The topic c-TF-IDF representation
             topics: The candidate topics as calculated with c-TF-IDF
+            embeddings: Pre-trained document embeddings. These can be used
+                        instead of an embedding model

         Returns:
             updated_topics: Updated topic representations

@@ -88,13 +91,19 @@ def extract_topics(
             c_tf_idf, documents, topics, self.nr_samples, self.nr_repr_docs
         )

+        # If document embeddings are precomputed, extract the embeddings of the representative documents based on repr_doc_indices
+        repr_embeddings = None
+        if embeddings is not None:
+            repr_embeddings = [embeddings[index] for index in np.concatenate(repr_doc_indices)]
+
         # We extract the top n words per class
         topics = self._extract_candidate_words(topic_model, c_tf_idf, topics)

         # We calculate the similarity between word and document embeddings and create
         # topic embeddings from the representative document embeddings
-        sim_matrix, words = self._extract_embeddings(topic_model, topics, representative_docs, repr_doc_indices)
-
+        sim_matrix, words = self._extract_embeddings(
+            topic_model, topics, representative_docs, repr_doc_indices, repr_embeddings
+        )
         # Find the best matching words based on the similarity matrix for each topic
         updated_topics = self._extract_top_words(words, topics, sim_matrix)
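The added lines flatten the per-topic index lists and pick the matching rows out of the full embedding matrix. A self-contained illustration of that indexing pattern, with hypothetical shapes and indices:

    import numpy as np

    # Five documents embedded in four dimensions; the first topic's
    # representative documents are docs 0 and 2, the second topic's is doc 3.
    embeddings = np.random.rand(5, 4)
    repr_doc_indices = [[0, 2], [3]]

    # Mirrors the list comprehension added in this commit: flatten the
    # per-topic indices, then select one embedding row per representative doc.
    repr_embeddings = [embeddings[index] for index in np.concatenate(repr_doc_indices)]
    print(len(repr_embeddings))  # 3 vectors, one per representative document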

@@ -150,6 +159,7 @@ def _extract_embeddings(
         topics: Mapping[str, List[Tuple[str, float]]],
         representative_docs: List[str],
         repr_doc_indices: List[List[int]],
+        repr_embeddings: np.ndarray = None,
     ) -> Union[np.ndarray, List[str]]:
         """Extract the representative document embeddings and create topic embeddings.
         Then extract word embeddings and calculate the cosine similarity between topic

@@ -162,13 +172,16 @@ def _extract_embeddings(
             representative_docs: A flat list of representative documents
             repr_doc_indices: The indices of representative documents
                 that belong to each topic
+            repr_embeddings: Embeddings of respective representative_docs

         Returns:
             sim: The similarity matrix between word and topic embeddings
             vocab: The complete vocabulary of input documents
         """
-        # Calculate representative docs embeddings and create topic embeddings
-        repr_embeddings = topic_model._extract_embeddings(representative_docs, method="document", verbose=False)
+        # Calculate representative document embeddings if there are no precomputed embeddings.
+        if repr_embeddings is None:
+            repr_embeddings = topic_model._extract_embeddings(representative_docs, method="document", verbose=False)
+
         topic_embeddings = [np.mean(repr_embeddings[i[0] : i[-1] + 1], axis=0) for i in repr_doc_indices]

         # Calculate word embeddings and extract best matching with updated topic_embeddings
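Whether the representative-document embeddings were precomputed or freshly encoded, the topic embeddings are still built the same way: the mean over each topic's slice of repr_embeddings, which is why both paths can share the unchanged line above. A toy example of that averaging, with hypothetical values:

    import numpy as np

    # Four representative-document embeddings (3-dimensional); positions 0-1
    # belong to the first topic, positions 2-3 to the second.
    repr_embeddings = np.array([
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [1.0, 1.0, 1.0],
    ])
    repr_doc_indices = [[0, 1], [2, 3]]

    # Same expression as in the diff: mean over each topic's slice.
    topic_embeddings = [np.mean(repr_embeddings[i[0] : i[-1] + 1], axis=0) for i in repr_doc_indices]
    print(topic_embeddings[0])  # [0.5 0.5 0. ]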
