@@ -480,9 +480,15 @@ def transform(self,
480480 images = images ,
481481 method = "document" ,
482482 verbose = self .verbose )
483+
484+ # Check if an embedding model was found
485+ if embeddings is None :
486+ raise ValueError ("No embedding model was found to embed the documents."
487+ "Make sure when loading in the model using BERTopic.load()"
488+ "to also specify the embedding model." )
483489
484490 # Transform without hdbscan_model and umap_model using only cosine similarity
485- if type (self .hdbscan_model ) == BaseCluster :
491+ elif type (self .hdbscan_model ) == BaseCluster :
486492 sim_matrix = cosine_similarity (embeddings , np .array (self .topic_embeddings_ ))
487493 predictions = np .argmax (sim_matrix , axis = 1 ) - self ._outliers
488494
@@ -2938,14 +2944,18 @@ def save(self,
29382944 # Check embedding model
29392945 if save_embedding_model and hasattr (self .embedding_model , '_hf_model' ) and not isinstance (save_embedding_model , str ):
29402946 save_embedding_model = self .embedding_model ._hf_model
2947+ elif not save_embedding_model :
2948+ warnings .warn ("You are saving a BERTopic model without explicitly defining an embedding model."
2949+ "If you are using a sentence-transformers model or a HuggingFace model supported"
2950+ "by sentence-transformers, please save the model by using a pointer towards that model."
2951+ "For example, `save_embedding_model=sentence-transformers/all-mpnet-base-v2`" )
29412952
29422953 # Minimal
29432954 save_utils .save_hf (model = self , save_directory = save_directory , serialization = serialization )
29442955 save_utils .save_topics (model = self , path = save_directory / "topics.json" )
29452956 save_utils .save_images (model = self , path = save_directory / "images" )
29462957 save_utils .save_config (model = self , path = save_directory / 'config.json' , embedding_model = save_embedding_model )
29472958
2948-
29492959 # Additional
29502960 if save_ctfidf :
29512961 save_utils .save_ctfidf (model = self , save_directory = save_directory , serialization = serialization )
@@ -3962,6 +3972,10 @@ def _create_model_from_files(
39623972 embedding_model = select_backend (SentenceTransformer (params ['embedding_model' ]))
39633973 except :
39643974 embedding_model = BaseEmbedder ()
3975+ warnings .warn ("You are loading a BERTopic model without explicitly defining an embedding model."
3976+ "If you want to also load in an embedding model, make sure to use"
3977+ "BERTopic.load(my_model, embedding_model=my_embedding_model)." )
3978+
39653979 if params .get ("embedding_model" ) is not None :
39663980 del params ['embedding_model' ]
39673981
0 commit comments