Skip to content

Commit 3863ddf

Browse files
committed
Update config: extendable configs and changed name -> PipelineConfig
1 parent acb4feb commit 3863ddf

3 files changed

Lines changed: 67 additions & 39 deletions

File tree

fmpose3d/common/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313

1414
from .arguments import opts
1515
from .config import (
16-
FMPoseConfig,
16+
PipelineConfig,
1717
ModelConfig,
18+
FMPose3DConfig,
19+
HRNetConfig,
20+
Pose2DConfig,
1821
DatasetConfig,
1922
TrainingConfig,
2023
InferenceConfig,
@@ -40,7 +43,10 @@
4043

4144
__all__ = [
4245
"opts",
43-
"FMPoseConfig",
46+
"PipelineConfig",
47+
"FMPose3DConfig",
48+
"HRNetConfig",
49+
"Pose2DConfig",
4450
"ModelConfig",
4551
"DatasetConfig",
4652
"TrainingConfig",

fmpose3d/common/config.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
@dataclass
2121
class ModelConfig:
2222
"""Model architecture configuration."""
23+
model_type: str = "fmpose3d"
24+
2325

26+
@dataclass
27+
class FMPose3DConfig(ModelConfig):
2428
model: str = ""
2529
model_type: str = "fmpose3d"
2630
layers: int = 3
@@ -141,6 +145,20 @@ class OutputConfig:
141145
sh_file: str = ""
142146

143147

148+
@dataclass
149+
class Pose2DConfig:
150+
"""2D pose estimator configuration."""
151+
pose2d_model: str = "hrnet"
152+
153+
154+
@dataclass
155+
class HRNetConfig(Pose2DConfig):
156+
"""HRNet 2D pose detector configuration."""
157+
pose2d_model: str = "hrnet"
158+
det_dim: int = 416 # YOLO input dimension (HRNet-specific).
159+
num_persons: int = 1 # Maximum number of persons to estimate per frame.
160+
161+
144162
@dataclass
145163
class DemoConfig:
146164
"""Demo / inference configuration."""
@@ -174,53 +192,59 @@ class RuntimeConfig:
174192
"checkpoint_cfg": CheckpointConfig,
175193
"refinement_cfg": RefinementConfig,
176194
"output_cfg": OutputConfig,
195+
"pose2d_cfg": Pose2DConfig,
177196
"demo_cfg": DemoConfig,
178197
"runtime_cfg": RuntimeConfig,
179198
}
180199

181200

182201
@dataclass
183-
class FMPoseConfig:
184-
"""Top-level configuration for FMPose3D.
202+
class PipelineConfig:
203+
"""Top-level configuration for FMPose3D pipeline.
185204
186205
Groups related settings into sub-configs::
187206
188207
config.model_cfg.layers
189208
config.training_cfg.lr
190209
"""
191210

192-
model_cfg: ModelConfig = field(default_factory=ModelConfig)
211+
model_cfg: ModelConfig = field(default_factory=FMPose3DConfig)
193212
dataset_cfg: DatasetConfig = field(default_factory=DatasetConfig)
194213
training_cfg: TrainingConfig = field(default_factory=TrainingConfig)
195214
inference_cfg: InferenceConfig = field(default_factory=InferenceConfig)
196215
aggregation_cfg: AggregationConfig = field(default_factory=AggregationConfig)
197216
checkpoint_cfg: CheckpointConfig = field(default_factory=CheckpointConfig)
198217
refinement_cfg: RefinementConfig = field(default_factory=RefinementConfig)
199218
output_cfg: OutputConfig = field(default_factory=OutputConfig)
219+
pose2d_cfg: Pose2DConfig = field(default_factory=HRNetConfig)
200220
demo_cfg: DemoConfig = field(default_factory=DemoConfig)
201221
runtime_cfg: RuntimeConfig = field(default_factory=RuntimeConfig)
202222

203223
# -- construction from argparse namespace ---------------------------------
204224

205225
@classmethod
206-
def from_namespace(cls, ns) -> "FMPoseConfig":
207-
"""Build a :class:`FMPoseConfig` from an ``argparse.Namespace``
226+
def from_namespace(cls, ns) -> "PipelineConfig":
227+
"""Build a :class:`PipelineConfig` from an ``argparse.Namespace``
208228
209229
Example::
210230
211231
args = opts().parse()
212-
cfg = FMPoseConfig.from_namespace(args)
232+
cfg = PipelineConfig.from_namespace(args)
213233
"""
214234
raw = vars(ns) if hasattr(ns, "__dict__") else dict(ns)
215235

216236
def _pick(dc_class, src: dict):
217237
names = {f.name for f in fields(dc_class)}
218238
return dc_class(**{k: v for k, v in src.items() if k in names})
219239

220-
return cls(**{
221-
group_name: _pick(dc_class, raw)
222-
for group_name, dc_class in _SUB_CONFIG_CLASSES.items()
223-
})
240+
kwargs = {}
241+
for group_name, dc_class in _SUB_CONFIG_CLASSES.items():
242+
if group_name == "model_cfg" and raw.get("model_type", "fmpose3d") == "fmpose3d":
243+
dc_class = FMPose3DConfig
244+
elif group_name == "pose2d_cfg" and raw.get("pose2d_model", "hrnet") == "hrnet":
245+
dc_class = HRNetConfig
246+
kwargs[group_name] = _pick(dc_class, raw)
247+
return cls(**kwargs)
224248

225249
# -- utilities ------------------------------------------------------------
226250

tests/test_config.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
import pytest
1414

1515
from fmpose3d.common.config import (
16-
FMPoseConfig,
17-
ModelConfig,
16+
PipelineConfig,
17+
FMPose3DConfig,
1818
DatasetConfig,
1919
TrainingConfig,
2020
InferenceConfig,
@@ -33,19 +33,18 @@
3333
# ---------------------------------------------------------------------------
3434

3535

36-
class TestModelConfig:
36+
class TestFMPose3DConfig:
3737
def test_defaults(self):
38-
cfg = ModelConfig()
38+
cfg = FMPose3DConfig()
3939
assert cfg.layers == 3
4040
assert cfg.channel == 512
4141
assert cfg.d_hid == 1024
4242
assert cfg.n_joints == 17
4343
assert cfg.out_joints == 17
4444
assert cfg.frames == 1
45-
assert cfg.model_path == ""
4645

4746
def test_custom_values(self):
48-
cfg = ModelConfig(layers=5, channel=256, n_joints=26)
47+
cfg = FMPose3DConfig(layers=5, channel=256, n_joints=26)
4948
assert cfg.layers == 5
5049
assert cfg.channel == 256
5150
assert cfg.n_joints == 26
@@ -167,15 +166,15 @@ def test_defaults(self):
167166

168167

169168
# ---------------------------------------------------------------------------
170-
# FMPoseConfig
169+
# PipelineConfig
171170
# ---------------------------------------------------------------------------
172171

173172

174-
class TestFMPoseConfig:
173+
class TestPipelineConfig:
175174
def test_default_construction(self):
176175
"""All sub-configs are initialised with their defaults."""
177-
cfg = FMPoseConfig()
178-
assert isinstance(cfg.model_cfg, ModelConfig)
176+
cfg = PipelineConfig()
177+
assert isinstance(cfg.model_cfg, FMPose3DConfig)
179178
assert isinstance(cfg.dataset_cfg, DatasetConfig)
180179
assert isinstance(cfg.training_cfg, TrainingConfig)
181180
assert isinstance(cfg.inference_cfg, InferenceConfig)
@@ -188,8 +187,8 @@ def test_default_construction(self):
188187

189188
def test_partial_construction(self):
190189
"""Supplying only some sub-configs leaves the rest at defaults."""
191-
cfg = FMPoseConfig(
192-
model_cfg=ModelConfig(layers=5),
190+
cfg = PipelineConfig(
191+
model_cfg=FMPose3DConfig(layers=5),
193192
training_cfg=TrainingConfig(lr=2e-4),
194193
)
195194
assert cfg.model_cfg.layers == 5
@@ -200,21 +199,21 @@ def test_partial_construction(self):
200199

201200
def test_sub_config_mutation(self):
202201
"""Mutating a sub-config field is reflected on the config."""
203-
cfg = FMPoseConfig()
202+
cfg = PipelineConfig()
204203
cfg.training_cfg.lr = 0.01
205204
assert cfg.training_cfg.lr == pytest.approx(0.01)
206205

207206
def test_sub_config_replacement(self):
208207
"""Replacing an entire sub-config works."""
209-
cfg = FMPoseConfig()
210-
cfg.model_cfg = ModelConfig(layers=10, channel=1024)
208+
cfg = PipelineConfig()
209+
cfg.model_cfg = FMPose3DConfig(layers=10, channel=1024)
211210
assert cfg.model_cfg.layers == 10
212211
assert cfg.model_cfg.channel == 1024
213212

214213
# -- to_dict --------------------------------------------------------------
215214

216215
def test_to_dict_returns_flat_dict(self):
217-
cfg = FMPoseConfig()
216+
cfg = PipelineConfig()
218217
d = cfg.to_dict()
219218
assert isinstance(d, dict)
220219
# Spot-check keys from different groups
@@ -225,8 +224,8 @@ def test_to_dict_returns_flat_dict(self):
225224
assert "gpu" in d
226225

227226
def test_to_dict_reflects_custom_values(self):
228-
cfg = FMPoseConfig(
229-
model_cfg=ModelConfig(layers=7),
227+
cfg = PipelineConfig(
228+
model_cfg=FMPose3DConfig(layers=7),
230229
aggregation_cfg=AggregationConfig(topk=5),
231230
)
232231
d = cfg.to_dict()
@@ -235,7 +234,7 @@ def test_to_dict_reflects_custom_values(self):
235234

236235
def test_to_dict_no_duplicate_keys(self):
237236
"""Every field name should be unique across all sub-configs."""
238-
cfg = FMPoseConfig()
237+
cfg = PipelineConfig()
239238
d = cfg.to_dict()
240239
all_field_names = []
241240
for dc_class in _SUB_CONFIG_CLASSES.values():
@@ -249,8 +248,9 @@ def test_to_dict_no_duplicate_keys(self):
249248

250249
def test_from_namespace_basic(self):
251250
ns = argparse.Namespace(
252-
# ModelConfig
251+
# FMPose3DConfig
253252
model="test_model",
253+
model_type="fmpose3d",
254254
layers=5,
255255
channel=256,
256256
d_hid=512,
@@ -260,7 +260,6 @@ def test_from_namespace_basic(self):
260260
in_channels=2,
261261
out_channels=3,
262262
frames=3,
263-
model_path="/tmp/model.py",
264263
# DatasetConfig
265264
dataset="rat7m",
266265
keypoints="cpn",
@@ -339,7 +338,7 @@ def test_from_namespace_basic(self):
339338
single=True,
340339
reload_3d=False,
341340
)
342-
cfg = FMPoseConfig.from_namespace(ns)
341+
cfg = PipelineConfig.from_namespace(ns)
343342

344343
# Verify a sample from each group
345344
assert cfg.model_cfg.layers == 5
@@ -363,15 +362,15 @@ def test_from_namespace_ignores_unknown_fields(self):
363362
ns = argparse.Namespace(
364363
layers=3, channel=512, unknown_field="should_be_ignored",
365364
)
366-
cfg = FMPoseConfig.from_namespace(ns)
365+
cfg = PipelineConfig.from_namespace(ns)
367366
assert cfg.model_cfg.layers == 3
368367
assert cfg.model_cfg.channel == 512
369368
assert not hasattr(cfg, "unknown_field")
370369

371370
def test_from_namespace_partial_namespace(self):
372371
"""A namespace missing some fields uses dataclass defaults for those."""
373372
ns = argparse.Namespace(layers=10, gpu="2")
374-
cfg = FMPoseConfig.from_namespace(ns)
373+
cfg = PipelineConfig.from_namespace(ns)
375374
assert cfg.model_cfg.layers == 10
376375
assert cfg.runtime_cfg.gpu == "2"
377376
# Unset fields keep defaults
@@ -385,7 +384,7 @@ def test_roundtrip_from_namespace_to_dict(self):
385384
ns = argparse.Namespace(
386385
layers=8, channel=1024, dataset="animal3d", lr=2e-4, topk=7, gpu="3",
387386
)
388-
cfg = FMPoseConfig.from_namespace(ns)
387+
cfg = PipelineConfig.from_namespace(ns)
389388
d = cfg.to_dict()
390389
assert d["layers"] == 8
391390
assert d["channel"] == 1024
@@ -396,10 +395,9 @@ def test_roundtrip_from_namespace_to_dict(self):
396395

397396
def test_to_dict_after_mutation(self):
398397
"""to_dict reflects in-place mutations on sub-configs."""
399-
cfg = FMPoseConfig()
398+
cfg = PipelineConfig()
400399
cfg.training_cfg.lr = 0.123
401400
cfg.model_cfg.layers = 99
402401
d = cfg.to_dict()
403402
assert d["lr"] == pytest.approx(0.123)
404403
assert d["layers"] == 99
405-

0 commit comments

Comments
 (0)