Skip to content

Commit 31302bf

Browse files
author
MaartenGr
committed
Fix not saving embedding model using .push_to_hf_hub
1 parent 05cfefc commit 31302bf

2 files changed

Lines changed: 17 additions & 3 deletions

File tree

bertopic/_bertopic.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -480,9 +480,15 @@ def transform(self,
480480
images=images,
481481
method="document",
482482
verbose=self.verbose)
483+
484+
# Check if an embedding model was found
485+
if embeddings is None:
486+
raise ValueError("No embedding model was found to embed the documents."
487+
"Make sure when loading in the model using BERTopic.load()"
488+
"to also specify the embedding model.")
483489

484490
# Transform without hdbscan_model and umap_model using only cosine similarity
485-
if type(self.hdbscan_model) == BaseCluster:
491+
elif type(self.hdbscan_model) == BaseCluster:
486492
sim_matrix = cosine_similarity(embeddings, np.array(self.topic_embeddings_))
487493
predictions = np.argmax(sim_matrix, axis=1) - self._outliers
488494

@@ -2938,14 +2944,18 @@ def save(self,
29382944
# Check embedding model
29392945
if save_embedding_model and hasattr(self.embedding_model, '_hf_model') and not isinstance(save_embedding_model, str):
29402946
save_embedding_model = self.embedding_model._hf_model
2947+
elif not save_embedding_model:
2948+
warnings.warn("You are saving a BERTopic model without explicitly defining an embedding model."
2949+
"If you are using a sentence-transformers model or a HuggingFace model supported"
2950+
"by sentence-transformers, please save the model by using a pointer towards that model."
2951+
"For example, `save_embedding_model=sentence-transformers/all-mpnet-base-v2`")
29412952

29422953
# Minimal
29432954
save_utils.save_hf(model=self, save_directory=save_directory, serialization=serialization)
29442955
save_utils.save_topics(model=self, path=save_directory / "topics.json")
29452956
save_utils.save_images(model=self, path=save_directory / "images")
29462957
save_utils.save_config(model=self, path=save_directory / 'config.json', embedding_model=save_embedding_model)
29472958

2948-
29492959
# Additional
29502960
if save_ctfidf:
29512961
save_utils.save_ctfidf(model=self, save_directory=save_directory, serialization=serialization)
@@ -3962,6 +3972,10 @@ def _create_model_from_files(
39623972
embedding_model = select_backend(SentenceTransformer(params['embedding_model']))
39633973
except:
39643974
embedding_model = BaseEmbedder()
3975+
warnings.warn("You are loading a BERTopic model without explicitly defining an embedding model."
3976+
"If you want to also load in an embedding model, make sure to use"
3977+
"BERTopic.load(my_model, embedding_model=my_embedding_model).")
3978+
39653979
if params.get("embedding_model") is not None:
39663980
del params['embedding_model']
39673981

bertopic/_save_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def push_to_hf_hub(
111111
create_pr: bool = False,
112112
model_card: bool = True,
113113
serialization: str = "safetensors",
114-
save_embedding_model: str = None,
114+
save_embedding_model: Union[str, bool] = True,
115115
save_ctfidf: bool = False,
116116
):
117117
""" Push your BERTopic model to a HuggingFace Hub

0 commit comments

Comments
 (0)