Skip to content

Commit 5570532

Browse files
committed
Update weight loading
1 parent 530e1ba commit 5570532

3 files changed

Lines changed: 47 additions & 30 deletions

File tree

fmpose3d/fmpose3d.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,9 @@ class _IngestedInput:
480480
# ---------------------------------------------------------------------------
481481

482482

483+
# FIXME @deruyter92: THIS IS TEMPORARY UNTIL WE DOWNLOAD THE WEIGHTS FROM HUGGINGFACE
484+
SKIP_WEIGHTS_VALIDATION = object() # sentinel value to indicate that the weights should not be validated
485+
483486
class FMPose3DInference:
484487
"""High-level, two-step inference API for FMPose3D.
485488
@@ -534,7 +537,7 @@ def __init__(
534537
self,
535538
model_cfg: FMPose3DConfig | None = None,
536539
inference_cfg: InferenceConfig | None = None,
537-
model_weights_path: str = "",
540+
model_weights_path: str | Path | None = SKIP_WEIGHTS_VALIDATION,
538541
device: str | torch.device | None = None,
539542
*,
540543
estimator_2d: HRNetEstimator | SuperAnimalEstimator | None = None,
@@ -544,6 +547,9 @@ def __init__(
544547
self.inference_cfg = inference_cfg or InferenceConfig()
545548
self.model_weights_path = model_weights_path
546549

550+
# Validate model weights path (download if needed)
551+
self._resolve_model_weights_path()
552+
547553
# Skeleton configuration from the model config.
548554
self._joints_left: list[int] = list(self.model_cfg.joints_left)
549555
self._joints_right: list[int] = list(self.model_cfg.joints_right)
@@ -572,7 +578,7 @@ def __init__(
572578
@classmethod
573579
def for_animals(
574580
cls,
575-
model_weights_path: str = "",
581+
model_weights_path: str = SKIP_WEIGHTS_VALIDATION,
576582
*,
577583
device: str | torch.device | None = None,
578584
inference_cfg: InferenceConfig | None = None,
@@ -915,35 +921,46 @@ def _load_weights(self) -> None:
915921
state-dict keys and pull matching entries from the checkpoint so that
916922
extra keys in the checkpoint are silently ignored.
917923
"""
918-
if not self.model_weights_path:
919-
raise ValueError(
920-
"No model weights path provided. Pass 'model_weights_path' "
921-
"to the FMPose3DInference constructor."
922-
)
923-
weights = Path(self.model_weights_path)
924-
if not weights.exists():
925-
raise ValueError(
926-
f"Model weights file not found: {weights}. "
927-
"Please provide a valid path to a .pth checkpoint file in the "
928-
"FMPose3DInference constructor."
929-
)
930924
if self._model_3d is None:
931925
raise ValueError("Model not initialised. Call setup_runtime() first.")
932-
pre_dict = torch.load(
933-
self.model_weights_path,
926+
weights = self._resolve_model_weights_path()
927+
state_dict = torch.load(
928+
weights,
934929
weights_only=True,
935930
map_location=self.device,
936931
)
937-
model_dict = self._model_3d.state_dict()
938-
for name in model_dict:
939-
if name in pre_dict:
940-
model_dict[name] = pre_dict[name]
941-
self._model_3d.load_state_dict(model_dict)
932+
self._model_3d.load_state_dict(state_dict)
942933

943934
# ------------------------------------------------------------------
944935
# Private helpers – input resolution
945936
# ------------------------------------------------------------------
946937

938+
def _resolve_model_weights_path(self) -> None:
939+
# TODO @deruyter92: THIS IS TEMPORARY UNTIL WE DOWNLOAD THE WEIGHTS FROM HUGGINGFACE
940+
if self.model_weights_path is SKIP_WEIGHTS_VALIDATION:
941+
return SKIP_WEIGHTS_VALIDATION
942+
943+
if not self.model_weights_path:
944+
self._download_model_weights()
945+
self.model_weights_path = Path(self.model_weights_path).resolve()
946+
if not self.model_weights_path.exists():
947+
raise ValueError(
948+
f"Model weights file not found: {self.model_weights_path}. "
949+
"Please provide a valid path to a .pth checkpoint file in the "
950+
"FMPose3DInference constructor. Or leave it empty to download "
951+
"the weights from huggingface."
952+
)
953+
return self.model_weights_path
954+
955+
def _download_model_weights(self) -> None:
956+
"""Download model weights from huggingface."""
957+
# TODO @deruyter92: Implement download from huggingface
958+
raise NotImplementedError(
959+
"Downloading model weights from huggingface is not implemented yet."
960+
"Please provide a valid path to a .pth checkpoint file in the "
961+
"FMPose3DInference constructor."
962+
)
963+
947964
def _ingest_input(self, source: Source) -> _IngestedInput:
948965
"""Normalise *source* into a ``(N, H, W, C)`` frames array.
949966

tests/test_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
class TestFMPose3DConfig:
3737
def test_defaults(self):
3838
cfg = FMPose3DConfig()
39-
assert cfg.layers == 3
39+
assert cfg.layers == 5
4040
assert cfg.channel == 512
4141
assert cfg.d_hid == 1024
4242
assert cfg.n_joints == 17

tests/test_fmpose3d.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -547,18 +547,18 @@ def test_corrupt_image_raises(self, api, tmp_path):
547547

548548
class TestLoadWeightsErrors:
549549
def test_empty_path_raises(self):
550-
api = FMPose3DInference(model_weights_path="", device="cpu")
551-
api._model_3d = _ZeroVelocityModel()
552-
with pytest.raises(ValueError, match="No model weights path"):
550+
with pytest.raises((ValueError, NotImplementedError)):
551+
api = FMPose3DInference(model_weights_path="", device="cpu")
552+
api._model_3d = _ZeroVelocityModel()
553553
api._load_weights()
554554

555555
def test_nonexistent_file_raises(self):
556-
api = FMPose3DInference(
557-
model_weights_path="/nonexistent/weights.pth",
558-
device="cpu",
559-
)
560-
api._model_3d = _ZeroVelocityModel()
561556
with pytest.raises(ValueError, match="Model weights file not found"):
557+
api = FMPose3DInference(
558+
model_weights_path="/nonexistent/weights.pth",
559+
device="cpu",
560+
)
561+
api._model_3d = _ZeroVelocityModel()
562562
api._load_weights()
563563

564564
def test_model_not_initialized_raises(self, tmp_path):

0 commit comments

Comments
 (0)