Skip to content

Commit 307a15f

Browse files
authored
Add pipeline_tag to model card (#1287)
1 parent 74e6484 commit 307a15f

1 file changed

Lines changed: 16 additions & 1 deletion

File tree

bertopic/_save_utils.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
tags:
5555
- bertopic
5656
library_name: bertopic
57+
pipeline_tag: {PIPELINE_TAG}
5758
---
5859
5960
# {MODEL_NAME}
@@ -284,6 +285,13 @@ def generate_readme(model, repo_id: str):
284285
model_card = model_card.replace("{HYPERPARAMS}", params)
285286
model_card = model_card.replace("{FRAMEWORKS}", frameworks)
286287

288+
# Fill Pipeline tag
289+
has_visual_aspect = check_has_visual_aspect(model)
290+
if not has_visual_aspect:
291+
model_card = model_card.replace("{PIPELINE_TAG}", "text-classification")
292+
else:
293+
model_card = model_card.replace("pipeline_tag: {PIPELINE_TAG} /n","") # TODO add proper tag for this instance
294+
287295
return model_card
288296

289297

@@ -363,6 +371,13 @@ def save_config(model, path: str, embedding_model):
363371

364372
return config
365373

374+
def check_has_visual_aspect(model):
375+
"""Check if model has visual aspect"""
376+
if _has_vision:
377+
for aspect, value in model.topic_aspects_.items():
378+
if isinstance(value[0], Image.Image):
379+
visual_aspects = model.topic_aspects_[aspect]
380+
return True
366381

367382
def save_images(model, path: str):
368383
""" Save topic images """
@@ -470,4 +485,4 @@ def save_safetensors(path, tensors):
470485
import safetensors
471486
safetensors.torch.save_file(tensors, path)
472487
except ImportError:
473-
raise ValueError("`pip install safetensors` to save as .safetensors")
488+
raise ValueError("`pip install safetensors` to save as .safetensors")

0 commit comments

Comments
 (0)