@@ -480,6 +480,9 @@ class _IngestedInput:
480480# ---------------------------------------------------------------------------
481481
482482
483+ # FIXME @deruyter92: THIS IS TEMPORARY UNTIL WE DOWNLOAD THE WEIGHTS FROM HUGGINGFACE
484+ SKIP_WEIGHTS_VALIDATION = object () # sentinel value to indicate that the weights should not be validated
485+
483486class FMPose3DInference :
484487 """High-level, two-step inference API for FMPose3D.
485488
@@ -534,7 +537,7 @@ def __init__(
534537 self ,
535538 model_cfg : FMPose3DConfig | None = None ,
536539 inference_cfg : InferenceConfig | None = None ,
537- model_weights_path : str = "" ,
540+ model_weights_path : str | Path | None = SKIP_WEIGHTS_VALIDATION ,
538541 device : str | torch .device | None = None ,
539542 * ,
540543 estimator_2d : HRNetEstimator | SuperAnimalEstimator | None = None ,
@@ -544,6 +547,9 @@ def __init__(
544547 self .inference_cfg = inference_cfg or InferenceConfig ()
545548 self .model_weights_path = model_weights_path
546549
550+ # Validate model weights path (download if needed)
551+ self ._resolve_model_weights_path ()
552+
547553 # Skeleton configuration from the model config.
548554 self ._joints_left : list [int ] = list (self .model_cfg .joints_left )
549555 self ._joints_right : list [int ] = list (self .model_cfg .joints_right )
@@ -572,7 +578,7 @@ def __init__(
572578 @classmethod
573579 def for_animals (
574580 cls ,
575- model_weights_path : str = "" ,
581+ model_weights_path : str = SKIP_WEIGHTS_VALIDATION ,
576582 * ,
577583 device : str | torch .device | None = None ,
578584 inference_cfg : InferenceConfig | None = None ,
@@ -915,35 +921,46 @@ def _load_weights(self) -> None:
915921 state-dict keys and pull matching entries from the checkpoint so that
916922 extra keys in the checkpoint are silently ignored.
917923 """
918- if not self .model_weights_path :
919- raise ValueError (
920- "No model weights path provided. Pass 'model_weights_path' "
921- "to the FMPose3DInference constructor."
922- )
923- weights = Path (self .model_weights_path )
924- if not weights .exists ():
925- raise ValueError (
926- f"Model weights file not found: { weights } . "
927- "Please provide a valid path to a .pth checkpoint file in the "
928- "FMPose3DInference constructor."
929- )
930924 if self ._model_3d is None :
931925 raise ValueError ("Model not initialised. Call setup_runtime() first." )
932- pre_dict = torch .load (
933- self .model_weights_path ,
926+ weights = self ._resolve_model_weights_path ()
927+ state_dict = torch .load (
928+ weights ,
934929 weights_only = True ,
935930 map_location = self .device ,
936931 )
937- model_dict = self ._model_3d .state_dict ()
938- for name in model_dict :
939- if name in pre_dict :
940- model_dict [name ] = pre_dict [name ]
941- self ._model_3d .load_state_dict (model_dict )
932+ self ._model_3d .load_state_dict (state_dict )
942933
943934 # ------------------------------------------------------------------
944935 # Private helpers – input resolution
945936 # ------------------------------------------------------------------
946937
938+ def _resolve_model_weights_path (self ) -> None :
939+ # TODO @deruyter92: THIS IS TEMPORARY UNTIL WE DOWNLOAD THE WEIGHTS FROM HUGGINGFACE
940+ if self .model_weights_path is SKIP_WEIGHTS_VALIDATION :
941+ return SKIP_WEIGHTS_VALIDATION
942+
943+ if not self .model_weights_path :
944+ self ._download_model_weights ()
945+ self .model_weights_path = Path (self .model_weights_path ).resolve ()
946+ if not self .model_weights_path .exists ():
947+ raise ValueError (
948+ f"Model weights file not found: { self .model_weights_path } . "
949+ "Please provide a valid path to a .pth checkpoint file in the "
950+ "FMPose3DInference constructor. Or leave it empty to download "
951+ "the weights from huggingface."
952+ )
953+ return self .model_weights_path
954+
955+ def _download_model_weights (self ) -> None :
956+ """Download model weights from huggingface."""
957+ # TODO @deruyter92: Implement download from huggingface
958+ raise NotImplementedError (
959+ "Downloading model weights from huggingface is not implemented yet."
960+ "Please provide a valid path to a .pth checkpoint file in the "
961+ "FMPose3DInference constructor."
962+ )
963+
947964 def _ingest_input (self , source : Source ) -> _IngestedInput :
948965 """Normalise *source* into a ``(N, H, W, C)`` frames array.
949966
0 commit comments