@@ -71,6 +71,7 @@ def extract_topics(
7171 documents : pd .DataFrame ,
7272 c_tf_idf : csr_matrix ,
7373 topics : Mapping [str , List [Tuple [str , float ]]],
74+ embeddings : np .ndarray = None ,
7475 ) -> Mapping [str , List [Tuple [str , float ]]]:
7576 """Extract topics.
7677
@@ -79,6 +80,8 @@ def extract_topics(
7980 documents: All input documents
8081 c_tf_idf: The topic c-TF-IDF representation
8182 topics: The candidate topics as calculated with c-TF-IDF
83+ embeddings: Pre-trained document embeddings. These can be used
84+ instead of an embedding model.
8285
8386 Returns:
8487 updated_topics: Updated topic representations
@@ -88,13 +91,19 @@ def extract_topics(
8891 c_tf_idf , documents , topics , self .nr_samples , self .nr_repr_docs
8992 )
9093
94+ # If document embeddings are precomputed, extract the embeddings of the representative documents based on repr_doc_indices
95+ repr_embeddings = None
96+ if embeddings is not None :
97+ repr_embeddings = [embeddings [index ] for index in np .concatenate (repr_doc_indices )]
98+
9199 # We extract the top n words per class
92100 topics = self ._extract_candidate_words (topic_model , c_tf_idf , topics )
93101
94102 # We calculate the similarity between word and document embeddings and create
95103 # topic embeddings from the representative document embeddings
96- sim_matrix , words = self ._extract_embeddings (topic_model , topics , representative_docs , repr_doc_indices )
97-
104+ sim_matrix , words = self ._extract_embeddings (
105+ topic_model , topics , representative_docs , repr_doc_indices , repr_embeddings
106+ )
98107 # Find the best matching words based on the similarity matrix for each topic
99108 updated_topics = self ._extract_top_words (words , topics , sim_matrix )
100109
@@ -150,6 +159,7 @@ def _extract_embeddings(
150159 topics : Mapping [str , List [Tuple [str , float ]]],
151160 representative_docs : List [str ],
152161 repr_doc_indices : List [List [int ]],
162+ repr_embeddings : np .ndarray = None ,
153163 ) -> Union [np .ndarray , List [str ]]:
154164 """Extract the representative document embeddings and create topic embeddings.
155165 Then extract word embeddings and calculate the cosine similarity between topic
@@ -162,13 +172,16 @@ def _extract_embeddings(
162172 representative_docs: A flat list of representative documents
163173 repr_doc_indices: The indices of representative documents
164174 that belong to each topic
175+ repr_embeddings: Precomputed embeddings corresponding to representative_docs.
165176
166177 Returns:
167178 sim: The similarity matrix between word and topic embeddings
168179 vocab: The complete vocabulary of input documents
169180 """
170- # Calculate representative docs embeddings and create topic embeddings
171- repr_embeddings = topic_model ._extract_embeddings (representative_docs , method = "document" , verbose = False )
181+ # Calculate representative document embeddings if there are no precomputed embeddings.
182+ if repr_embeddings is None :
183+ repr_embeddings = topic_model ._extract_embeddings (representative_docs , method = "document" , verbose = False )
184+
172185 topic_embeddings = [np .mean (repr_embeddings [i [0 ] : i [- 1 ] + 1 ], axis = 0 ) for i in repr_doc_indices ]
173186
174187 # Calculate word embeddings and extract best matching with updated topic_embeddings
0 commit comments