Skip to content

Commit 3c199a2

Browse files
committed
Update FMPose3DConfig and add SuperAnimalConfig
1 parent 7a6944f commit 3c199a2

1 file changed

Lines changed: 42 additions & 2 deletions

File tree

fmpose3d/common/config.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,18 @@ class ModelConfig:
3131
"out_joints": 17,
3232
"dataset": "h36m",
3333
"sample_steps": 3,
34+
"joints_left": [4, 5, 6, 11, 12, 13],
35+
"joints_right": [1, 2, 3, 14, 15, 16],
36+
"root_joint": 0,
3437
},
3538
"fmpose3d_animals": {
3639
"n_joints": 26,
3740
"out_joints": 26,
3841
"dataset": "animal3d",
3942
"sample_steps": 3,
43+
"joints_left": [0, 3, 5, 8, 10, 12, 14, 16, 20, 22],
44+
"joints_right": [1, 4, 6, 9, 11, 13, 15, 17, 21, 23],
45+
"root_joint": 7,
4046
},
4147
}
4248

@@ -53,6 +59,9 @@ class FMPose3DConfig(ModelConfig):
5359
token_dim: int = 256
5460
n_joints: int = INFER_FROM_MODEL_TYPE # type: ignore[assignment]
5561
out_joints: int = INFER_FROM_MODEL_TYPE # type: ignore[assignment]
62+
joints_left: List[int] = INFER_FROM_MODEL_TYPE # type: ignore[assignment]
63+
joints_right: List[int] = INFER_FROM_MODEL_TYPE # type: ignore[assignment]
64+
root_joint: int = INFER_FROM_MODEL_TYPE # type: ignore[assignment]
5665
in_channels: int = 2
5766
out_channels: int = 3
5867
frames: int = 1
@@ -207,6 +216,33 @@ class HRNetConfig(Pose2DConfig):
207216
hrnet_weights_path: str = ""
208217

209218

219+
@dataclass
220+
class SuperAnimalConfig(Pose2DConfig):
221+
"""DeepLabCut SuperAnimal 2D pose detector configuration.
222+
223+
Uses the DeepLabCut ``superanimal_analyze_images`` API to detect
224+
animal keypoints in the quadruped80K format, then maps them to the
225+
Animal3D 26-keypoint layout expected by the ``fmpose3d_animals``
226+
3D lifter.
227+
228+
Attributes
229+
----------
230+
superanimal_name : str
231+
Name of the SuperAnimal model (default ``"superanimal_quadruped"``).
232+
sa_model_name : str
233+
Backbone architecture (default ``"hrnet_w32"``).
234+
detector_name : str
235+
Object detector used for animal bounding boxes.
236+
max_individuals : int
237+
Maximum number of individuals to detect per image (default 1).
238+
"""
239+
pose2d_model: str = "superanimal"
240+
superanimal_name: str = "superanimal_quadruped"
241+
sa_model_name: str = "hrnet_w32"
242+
detector_name: str = "fasterrcnn_resnet50_fpn_v2"
243+
max_individuals: int = 1
244+
245+
210246
@dataclass
211247
class DemoConfig:
212248
"""Demo / inference configuration."""
@@ -287,8 +323,12 @@ def _pick(dc_class, src: dict):
287323
for group_name, dc_class in _SUB_CONFIG_CLASSES.items():
288324
if group_name == "model_cfg" and raw.get("model_type", 'fmpose3d') in _FMPOSE3D_DEFAULTS:
289325
dc_class = FMPose3DConfig
290-
elif group_name == "pose2d_cfg" and raw.get("pose2d_model", "hrnet") == "hrnet":
291-
dc_class = HRNetConfig
326+
elif group_name == "pose2d_cfg":
327+
p2d = raw.get("pose2d_model", "hrnet")
328+
if p2d == "superanimal":
329+
dc_class = SuperAnimalConfig
330+
elif p2d == "hrnet":
331+
dc_class = HRNetConfig
292332
kwargs[group_name] = _pick(dc_class, raw)
293333
return cls(**kwargs)
294334

0 commit comments

Comments
 (0)