Skip to content

Commit 530e1ba

Browse files
committed
adopt new model name 'fmpose3d_humans' in fmpose3d and tests
1 parent ad7e721 commit 530e1ba

2 files changed

Lines changed: 7 additions & 7 deletions

File tree

fmpose3d/fmpose3d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ class _IngestedInput:
483483
class FMPose3DInference:
484484
"""High-level, two-step inference API for FMPose3D.
485485
486-
Supports both **human** (``model_type="fmpose3d"``, 17 H36M joints)
486+
Supports both **human** (``model_type="fmpose3d_humans"``, 17 H36M joints)
487487
and **animal** (``model_type="fmpose3d_animals"``, 26 Animal3D joints)
488488
pipelines. The skeleton layout, 2D estimator, and post-processing
489489
are chosen automatically from the model configuration.
@@ -722,7 +722,7 @@ def pose_3d(
722722
) -> Pose3DResult:
723723
"""Lift 2D keypoints to 3D using the flow-matching model.
724724
725-
**Human pipeline** (``model_type="fmpose3d"``):
725+
**Human pipeline** (``model_type="fmpose3d_humans"``):
726726
Mirrors ``demo/vis_in_the_wild.py`` -- normalise screen
727727
coordinates, flip-augmented TTA, Euler ODE sampling, zero the
728728
root joint, ``camera_to_world``.

tests/test_fmpose3d.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def forward(self, c_2d: torch.Tensor, y: torch.Tensor, t: torch.Tensor) -> torch
104104

105105

106106
def _make_ready_api(
107-
model_type: str = "fmpose3d",
107+
model_type: str = "fmpose3d_humans",
108108
test_augmentation: bool = False,
109109
) -> FMPose3DInference:
110110
"""Return an ``FMPose3DInference`` with a mock model pre-installed.
@@ -148,7 +148,7 @@ def animal_api() -> FMPose3DInference:
148148
@pytest.fixture
149149
def ready_human_api() -> FMPose3DInference:
150150
"""Human API with mock model (TTA disabled)."""
151-
return _make_ready_api("fmpose3d", test_augmentation=False)
151+
return _make_ready_api("fmpose3d_humans", test_augmentation=False)
152152

153153

154154
@pytest.fixture
@@ -393,7 +393,7 @@ def test_animal(self):
393393

394394
class TestFMPose3DInferenceInit:
395395
def test_default_human(self, human_api):
396-
assert human_api.model_cfg.model_type == "fmpose3d"
396+
assert human_api.model_cfg.model_type == "fmpose3d_humans"
397397
assert human_api._joints_left == [4, 5, 6, 11, 12, 13]
398398
assert human_api._joints_right == [1, 2, 3, 14, 15, 16]
399399
assert human_api._root_joint == 0
@@ -645,7 +645,7 @@ def test_progress_callback(self, ready_human_api):
645645

646646
def test_tta_path_produces_output(self):
647647
"""Test-time augmentation (flip) path produces correct shapes."""
648-
api = _make_ready_api("fmpose3d", test_augmentation=True)
648+
api = _make_ready_api("fmpose3d_humans", test_augmentation=True)
649649
kpts = np.random.randn(1, 17, 2).astype("float32")
650650
result = api.pose_3d(kpts, image_size=(480, 640), seed=42)
651651
assert result.poses_3d.shape == (1, 17, 3)
@@ -660,7 +660,7 @@ def test_animal_api_shapes(self):
660660

661661
def test_predict_end_to_end_with_mock_estimator(self):
662662
"""predict() chains prepare_2d → pose_3d correctly."""
663-
api = _make_ready_api("fmpose3d", test_augmentation=False)
663+
api = _make_ready_api("fmpose3d_humans", test_augmentation=False)
664664

665665
mock_kpts = np.random.randn(1, 1, 17, 2).astype("float32")
666666
mock_scores = np.ones((1, 1, 17), dtype="float32")

0 commit comments

Comments
 (0)