Skip to content

Commit 7a6944f

Browse files
committed
Add default FMPose3DConfig per model_type
1 parent d05195d commit 7a6944f

1 file changed

Lines changed: 37 additions & 10 deletions

File tree

fmpose3d/common/config.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99

1010
import math
1111
from dataclasses import dataclass, field, fields, asdict
12-
from typing import List
13-
12+
from typing import Dict, List
1413

1514
# ---------------------------------------------------------------------------
1615
# Dataclass configuration groups
@@ -23,21 +22,51 @@ class ModelConfig:
2322
model_type: str = "fmpose3d"
2423

2524

25+
# Per-model-type defaults for fields marked with INFER_FROM_MODEL_TYPE.
26+
# Also consumed by PipelineConfig.for_model_type to set cross-config
27+
# values (dataset, sample_steps, etc.).
28+
_FMPOSE3D_DEFAULTS: Dict[str, Dict] = {
29+
"fmpose3d": {
30+
"n_joints": 17,
31+
"out_joints": 17,
32+
"dataset": "h36m",
33+
"sample_steps": 3,
34+
},
35+
"fmpose3d_animals": {
36+
"n_joints": 26,
37+
"out_joints": 26,
38+
"dataset": "animal3d",
39+
"sample_steps": 3,
40+
},
41+
}
42+
43+
# Sentinel object for defaults that are inferred from the model type.
44+
INFER_FROM_MODEL_TYPE = object()
45+
2646
@dataclass
2747
class FMPose3DConfig(ModelConfig):
28-
model: str = ""
2948
model_type: str = "fmpose3d"
30-
layers: int = 3
49+
model: str = ""
50+
layers: int = 5
3151
channel: int = 512
3252
d_hid: int = 1024
3353
token_dim: int = 256
34-
n_joints: int = 17
35-
out_joints: int = 17
54+
n_joints: int = INFER_FROM_MODEL_TYPE # type: ignore[assignment]
55+
out_joints: int = INFER_FROM_MODEL_TYPE # type: ignore[assignment]
3656
in_channels: int = 2
3757
out_channels: int = 3
3858
frames: int = 1
39-
"""Optional: load model class from a specific file path."""
4059

60+
def __post_init__(self):
61+
defaults = _FMPOSE3D_DEFAULTS.get(self.model_type)
62+
if defaults is None:
63+
supported = ", ".join(sorted(_FMPOSE3D_DEFAULTS))
64+
raise ValueError(
65+
f"Unknown model_type {self.model_type!r}; supported: {supported}"
66+
)
67+
for f in fields(self):
68+
if getattr(self, f.name) is INFER_FROM_MODEL_TYPE:
69+
setattr(self, f.name, defaults[f.name])
4170

4271
@dataclass
4372
class DatasetConfig:
@@ -239,8 +268,6 @@ class PipelineConfig:
239268
demo_cfg: DemoConfig = field(default_factory=DemoConfig)
240269
runtime_cfg: RuntimeConfig = field(default_factory=RuntimeConfig)
241270

242-
# -- construction from argparse namespace ---------------------------------
243-
244271
@classmethod
245272
def from_namespace(cls, ns) -> "PipelineConfig":
246273
"""Build a :class:`PipelineConfig` from an ``argparse.Namespace``
@@ -258,7 +285,7 @@ def _pick(dc_class, src: dict):
258285

259286
kwargs = {}
260287
for group_name, dc_class in _SUB_CONFIG_CLASSES.items():
261-
if group_name == "model_cfg" and raw.get("model_type", "fmpose3d") == "fmpose3d":
288+
if group_name == "model_cfg" and raw.get("model_type", 'fmpose3d') in _FMPOSE3D_DEFAULTS:
262289
dc_class = FMPose3DConfig
263290
elif group_name == "pose2d_cfg" and raw.get("pose2d_model", "hrnet") == "hrnet":
264291
dc_class = HRNetConfig

0 commit comments

Comments
 (0)