3333ProgressCallback = Callable [[int , int ], None ]
3434
3535
36+ #: HuggingFace repository hosting the official FMPose3D checkpoints.
37+ _HF_REPO_ID : str = "deruyter92/fmpose_temp"
38+
3639# Default camera-to-world rotation quaternion (from the demo script).
3740_DEFAULT_CAM_ROTATION = np .array (
3841 [0.1407056450843811 , - 0.1500701755285263 , - 0.755240797996521 , 0.6223280429840088 ],
@@ -560,7 +563,7 @@ def __init__(
560563 self ,
561564 model_cfg : FMPose3DConfig | None = None ,
562565 inference_cfg : InferenceConfig | None = None ,
563- model_weights_path : str | Path | None = SKIP_WEIGHTS_VALIDATION ,
566+ model_weights_path : str | Path | None = None ,
564567 device : str | torch .device | None = None ,
565568 * ,
566569 estimator_2d : HRNetEstimator | SuperAnimalEstimator | None = None ,
@@ -601,7 +604,7 @@ def __init__(
601604 @classmethod
602605 def for_animals (
603606 cls ,
604- model_weights_path : str = SKIP_WEIGHTS_VALIDATION ,
607+ model_weights_path : str | None = None ,
605608 * ,
606609 device : str | torch .device | None = None ,
607610 inference_cfg : InferenceConfig | None = None ,
@@ -958,15 +961,11 @@ def _load_weights(self) -> None:
958961 # Private helpers – input resolution
959962 # ------------------------------------------------------------------
960963
961- def _resolve_model_weights_path (self ) -> None :
962- # TODO @deruyter92: THIS IS TEMPORARY UNTIL WE DOWNLOAD THE WEIGHTS FROM HUGGINGFACE
963- if self .model_weights_path is SKIP_WEIGHTS_VALIDATION :
964- return SKIP_WEIGHTS_VALIDATION
965-
966- if not self .model_weights_path :
964+ def _resolve_model_weights_path (self ) -> None :
965+ if self .model_weights_path is None :
967966 self ._download_model_weights ()
968967 self .model_weights_path = Path (self .model_weights_path ).resolve ()
969- if not self .model_weights_path .exists ():
968+ if not self .model_weights_path .is_file ():
970969 raise ValueError (
971970 f"Model weights file not found: { self .model_weights_path } . "
972971 "Please provide a valid path to a .pth checkpoint file in the "
@@ -976,12 +975,28 @@ def _resolve_model_weights_path(self) -> None:
976975 return self .model_weights_path
977976
978977 def _download_model_weights (self ) -> None :
979- """Download model weights from huggingface."""
980- # TODO @deruyter92: Implement download from huggingface
981- raise NotImplementedError (
982- "Downloading model weights from huggingface is not implemented yet."
983- "Please provide a valid path to a .pth checkpoint file in the "
984- "FMPose3DInference constructor."
978+ """Download model weights from HuggingFace Hub.
979+
980+ The weight file is determined by the current ``model_cfg.model_type``
981+ (e.g. ``"fmpose3d_humans"`` -> ``fmpose3d_humans.pth``). Files are
982+ cached locally by :func:`huggingface_hub.hf_hub_download` so
983+ subsequent calls are instant.
984+
985+ Sets ``self.model_weights_path`` to the local cached file path.
986+ """
987+ try :
988+ from huggingface_hub import hf_hub_download
989+ except ImportError :
990+ raise ImportError (
991+ "huggingface_hub is required to download model weights. "
992+ "Install it with: pip install huggingface_hub. Or download "
993+ "the weights manually and set model_weights_path to the weights file."
994+ ) from None
995+
996+ filename = f"{ self .model_cfg .model_type .value } .pth"
997+ self .model_weights_path = hf_hub_download (
998+ repo_id = _HF_REPO_ID ,
999+ filename = filename ,
9851000 )
9861001
9871002 def _ingest_input (self , source : Source ) -> _IngestedInput :
0 commit comments