99
1010import math
1111from dataclasses import dataclass , field , fields , asdict
12- from typing import List
13-
12+ from typing import Dict , List
1413
1514# ---------------------------------------------------------------------------
1615# Dataclass configuration groups
@@ -23,21 +22,51 @@ class ModelConfig:
2322 model_type : str = "fmpose3d"
2423
2524
25+ # Per-model-type defaults for fields marked with INFER_FROM_MODEL_TYPE.
26+ # Also consumed by PipelineConfig.for_model_type to set cross-config
27+ # values (dataset, sample_steps, etc.).
28+ _FMPOSE3D_DEFAULTS : Dict [str , Dict ] = {
29+ "fmpose3d" : {
30+ "n_joints" : 17 ,
31+ "out_joints" : 17 ,
32+ "dataset" : "h36m" ,
33+ "sample_steps" : 3 ,
34+ },
35+ "fmpose3d_animals" : {
36+ "n_joints" : 26 ,
37+ "out_joints" : 26 ,
38+ "dataset" : "animal3d" ,
39+ "sample_steps" : 3 ,
40+ },
41+ }
42+
43+ # Sentinel object for defaults that are inferred from the model type.
44+ INFER_FROM_MODEL_TYPE = object ()
45+
2646@dataclass
2747class FMPose3DConfig (ModelConfig ):
28- model : str = ""
2948 model_type : str = "fmpose3d"
30- layers : int = 3
49+ model : str = ""
50+ layers : int = 5
3151 channel : int = 512
3252 d_hid : int = 1024
3353 token_dim : int = 256
34- n_joints : int = 17
35- out_joints : int = 17
54+ n_joints : int = INFER_FROM_MODEL_TYPE # type: ignore[assignment]
55+ out_joints : int = INFER_FROM_MODEL_TYPE # type: ignore[assignment]
3656 in_channels : int = 2
3757 out_channels : int = 3
3858 frames : int = 1
39- """Optional: load model class from a specific file path."""
4059
60+ def __post_init__ (self ):
61+ defaults = _FMPOSE3D_DEFAULTS .get (self .model_type )
62+ if defaults is None :
63+ supported = ", " .join (sorted (_FMPOSE3D_DEFAULTS ))
64+ raise ValueError (
65+ f"Unknown model_type { self .model_type !r} ; supported: { supported } "
66+ )
67+ for f in fields (self ):
68+ if getattr (self , f .name ) is INFER_FROM_MODEL_TYPE :
69+ setattr (self , f .name , defaults [f .name ])
4170
4271@dataclass
4372class DatasetConfig :
@@ -239,8 +268,6 @@ class PipelineConfig:
239268 demo_cfg : DemoConfig = field (default_factory = DemoConfig )
240269 runtime_cfg : RuntimeConfig = field (default_factory = RuntimeConfig )
241270
242- # -- construction from argparse namespace ---------------------------------
243-
244271 @classmethod
245272 def from_namespace (cls , ns ) -> "PipelineConfig" :
246273 """Build a :class:`PipelineConfig` from an ``argparse.Namespace``
@@ -258,7 +285,7 @@ def _pick(dc_class, src: dict):
258285
259286 kwargs = {}
260287 for group_name , dc_class in _SUB_CONFIG_CLASSES .items ():
261- if group_name == "model_cfg" and raw .get ("model_type" , " fmpose3d" ) == "fmpose3d" :
288+ if group_name == "model_cfg" and raw .get ("model_type" , ' fmpose3d' ) in _FMPOSE3D_DEFAULTS :
262289 dc_class = FMPose3DConfig
263290 elif group_name == "pose2d_cfg" and raw .get ("pose2d_model" , "hrnet" ) == "hrnet" :
264291 dc_class = HRNetConfig
0 commit comments