Skip to content

Commit d9bceb8

Browse files
committed
Add huggingface functionality
1 parent ec9d855 commit d9bceb8

4 files changed

Lines changed: 43 additions & 16 deletions

File tree

fmpose3d/common/config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""
99

1010
import math
11+
import json
1112
from dataclasses import dataclass, field, fields, asdict
1213
from enum import Enum
1314
from typing import Dict, List
@@ -36,6 +37,16 @@ class ModelConfig:
3637
"""Model architecture configuration."""
3738
model_type: str = "fmpose3d_humans"
3839

40+
def to_json(self, filename: str | None = None, **kwargs) -> str:
41+
json_str = json.dumps(asdict(self), **kwargs)
42+
with open(filename, "w") as f:
43+
f.write(json_str)
44+
45+
@classmethod
46+
def from_json(cls, filename: str, **kwargs) -> "ModelConfig":
47+
with open(filename, "r") as f:
48+
return cls(**json.loads(f.read(), **kwargs))
49+
3950

4051
# Per-model-type defaults for fields marked with INFER_FROM_MODEL_TYPE.
4152
# Also consumed by PipelineConfig.for_model_type to set cross-config

fmpose3d/fmpose3d.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333
ProgressCallback = Callable[[int, int], None]
3434

3535

36+
#: HuggingFace repository hosting the official FMPose3D checkpoints.
37+
_HF_REPO_ID: str = "deruyter92/fmpose_temp"
38+
3639
# Default camera-to-world rotation quaternion (from the demo script).
3740
_DEFAULT_CAM_ROTATION = np.array(
3841
[0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088],
@@ -560,7 +563,7 @@ def __init__(
560563
self,
561564
model_cfg: FMPose3DConfig | None = None,
562565
inference_cfg: InferenceConfig | None = None,
563-
model_weights_path: str | Path | None = SKIP_WEIGHTS_VALIDATION,
566+
model_weights_path: str | Path | None = None,
564567
device: str | torch.device | None = None,
565568
*,
566569
estimator_2d: HRNetEstimator | SuperAnimalEstimator | None = None,
@@ -601,7 +604,7 @@ def __init__(
601604
@classmethod
602605
def for_animals(
603606
cls,
604-
model_weights_path: str = SKIP_WEIGHTS_VALIDATION,
607+
model_weights_path: str | None = None,
605608
*,
606609
device: str | torch.device | None = None,
607610
inference_cfg: InferenceConfig | None = None,
@@ -958,15 +961,11 @@ def _load_weights(self) -> None:
958961
# Private helpers – input resolution
959962
# ------------------------------------------------------------------
960963

961-
def _resolve_model_weights_path(self) -> None:
962-
# TODO @deruyter92: THIS IS TEMPORARY UNTIL WE DOWNLOAD THE WEIGHTS FROM HUGGINGFACE
963-
if self.model_weights_path is SKIP_WEIGHTS_VALIDATION:
964-
return SKIP_WEIGHTS_VALIDATION
965-
966-
if not self.model_weights_path:
964+
def _resolve_model_weights_path(self) -> None:
965+
if self.model_weights_path is None:
967966
self._download_model_weights()
968967
self.model_weights_path = Path(self.model_weights_path).resolve()
969-
if not self.model_weights_path.exists():
968+
if not self.model_weights_path.is_file():
970969
raise ValueError(
971970
f"Model weights file not found: {self.model_weights_path}. "
972971
"Please provide a valid path to a .pth checkpoint file in the "
@@ -976,12 +975,28 @@ def _resolve_model_weights_path(self) -> None:
976975
return self.model_weights_path
977976

978977
def _download_model_weights(self) -> None:
979-
"""Download model weights from huggingface."""
980-
# TODO @deruyter92: Implement download from huggingface
981-
raise NotImplementedError(
982-
"Downloading model weights from huggingface is not implemented yet."
983-
"Please provide a valid path to a .pth checkpoint file in the "
984-
"FMPose3DInference constructor."
978+
"""Download model weights from HuggingFace Hub.
979+
980+
The weight file is determined by the current ``model_cfg.model_type``
981+
(e.g. ``"fmpose3d_humans"`` -> ``fmpose3d_humans.pth``). Files are
982+
cached locally by :func:`huggingface_hub.hf_hub_download` so
983+
subsequent calls are instant.
984+
985+
Sets ``self.model_weights_path`` to the local cached file path.
986+
"""
987+
try:
988+
from huggingface_hub import hf_hub_download
989+
except ImportError:
990+
raise ImportError(
991+
"huggingface_hub is required to download model weights. "
992+
"Install it with: pip install huggingface_hub. Or download "
993+
"the weights manually and set model_weights_path to the weights file."
994+
) from None
995+
996+
filename = f"{self.model_cfg.model_type.value}.pth"
997+
self.model_weights_path = hf_hub_download(
998+
repo_id=_HF_REPO_ID,
999+
filename=filename,
9851000
)
9861001

9871002
def _ingest_input(self, source: Source) -> _IngestedInput:

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ dependencies = [
3838
"filterpy>=1.4.5",
3939
"pandas>=1.0.1",
4040
"deeplabcut==3.0.0rc13",
41+
"huggingface_hub>=0.20.0",
4142
]
4243

4344
[project.optional-dependencies]

tests/test_fmpose3d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,7 @@ def test_corrupt_image_raises(self, api, tmp_path):
547547

548548
class TestLoadWeightsErrors:
549549
def test_empty_path_raises(self):
550-
with pytest.raises((ValueError, NotImplementedError)):
550+
with pytest.raises(ValueError, match="Model weights file not found"):
551551
api = FMPose3DInference(model_weights_path="", device="cpu")
552552
api._model_3d = _ZeroVelocityModel()
553553
api._load_weights()

0 commit comments

Comments
 (0)