@@ -31,12 +31,18 @@ class ModelConfig:
3131 "out_joints" : 17 ,
3232 "dataset" : "h36m" ,
3333 "sample_steps" : 3 ,
34+ "joints_left" : [4 , 5 , 6 , 11 , 12 , 13 ],
35+ "joints_right" : [1 , 2 , 3 , 14 , 15 , 16 ],
36+ "root_joint" : 0 ,
3437 },
3538 "fmpose3d_animals" : {
3639 "n_joints" : 26 ,
3740 "out_joints" : 26 ,
3841 "dataset" : "animal3d" ,
3942 "sample_steps" : 3 ,
43+ "joints_left" : [0 , 3 , 5 , 8 , 10 , 12 , 14 , 16 , 20 , 22 ],
44+ "joints_right" : [1 , 4 , 6 , 9 , 11 , 13 , 15 , 17 , 21 , 23 ],
45+ "root_joint" : 7 ,
4046 },
4147}
4248
@@ -53,6 +59,9 @@ class FMPose3DConfig(ModelConfig):
5359 token_dim : int = 256
5460 n_joints : int = INFER_FROM_MODEL_TYPE # type: ignore[assignment]
5561 out_joints : int = INFER_FROM_MODEL_TYPE # type: ignore[assignment]
62+ joints_left : List [int ] = INFER_FROM_MODEL_TYPE # type: ignore[assignment]
63+ joints_right : List [int ] = INFER_FROM_MODEL_TYPE # type: ignore[assignment]
64+ root_joint : int = INFER_FROM_MODEL_TYPE # type: ignore[assignment]
5665 in_channels : int = 2
5766 out_channels : int = 3
5867 frames : int = 1
@@ -207,6 +216,33 @@ class HRNetConfig(Pose2DConfig):
207216 hrnet_weights_path : str = ""
208217
209218
219+ @dataclass
220+ class SuperAnimalConfig (Pose2DConfig ):
221+ """DeepLabCut SuperAnimal 2D pose detector configuration.
222+
223+ Uses the DeepLabCut ``superanimal_analyze_images`` API to detect
224+ animal keypoints in the quadruped80K format, then maps them to the
225+ Animal3D 26-keypoint layout expected by the ``fmpose3d_animals``
226+ 3D lifter.
227+
228+ Attributes
229+ ----------
230+ superanimal_name : str
231+ Name of the SuperAnimal model (default ``"superanimal_quadruped"``).
232+ sa_model_name : str
233+ Backbone architecture (default ``"hrnet_w32"``).
234+ detector_name : str
235+ Object detector used for animal bounding boxes.
236+ max_individuals : int
237+ Maximum number of individuals to detect per image (default 1).
238+ """
239+ pose2d_model : str = "superanimal"
240+ superanimal_name : str = "superanimal_quadruped"
241+ sa_model_name : str = "hrnet_w32"
242+ detector_name : str = "fasterrcnn_resnet50_fpn_v2"
243+ max_individuals : int = 1
244+
245+
210246@dataclass
211247class DemoConfig :
212248 """Demo / inference configuration."""
@@ -287,8 +323,12 @@ def _pick(dc_class, src: dict):
287323 for group_name , dc_class in _SUB_CONFIG_CLASSES .items ():
288324 if group_name == "model_cfg" and raw .get ("model_type" , 'fmpose3d' ) in _FMPOSE3D_DEFAULTS :
289325 dc_class = FMPose3DConfig
290- elif group_name == "pose2d_cfg" and raw .get ("pose2d_model" , "hrnet" ) == "hrnet" :
291- dc_class = HRNetConfig
326+ elif group_name == "pose2d_cfg" :
327+ p2d = raw .get ("pose2d_model" , "hrnet" )
328+ if p2d == "superanimal" :
329+ dc_class = SuperAnimalConfig
330+ elif p2d == "hrnet" :
331+ dc_class = HRNetConfig
292332 kwargs [group_name ] = _pick (dc_class , raw )
293333 return cls (** kwargs )
294334
0 commit comments