1313import pytest
1414
1515from fmpose3d .common .config import (
16- FMPoseConfig ,
17- ModelConfig ,
16+ PipelineConfig ,
17+ FMPose3DConfig ,
1818 DatasetConfig ,
1919 TrainingConfig ,
2020 InferenceConfig ,
3333# ---------------------------------------------------------------------------
3434
3535
36- class TestModelConfig :
36+ class TestFMPose3DConfig :
3737 def test_defaults (self ):
38- cfg = ModelConfig ()
38+ cfg = FMPose3DConfig ()
3939 assert cfg .layers == 3
4040 assert cfg .channel == 512
4141 assert cfg .d_hid == 1024
4242 assert cfg .n_joints == 17
4343 assert cfg .out_joints == 17
4444 assert cfg .frames == 1
45- assert cfg .model_path == ""
4645
4746 def test_custom_values (self ):
48- cfg = ModelConfig (layers = 5 , channel = 256 , n_joints = 26 )
47+ cfg = FMPose3DConfig (layers = 5 , channel = 256 , n_joints = 26 )
4948 assert cfg .layers == 5
5049 assert cfg .channel == 256
5150 assert cfg .n_joints == 26
@@ -167,15 +166,15 @@ def test_defaults(self):
167166
168167
169168# ---------------------------------------------------------------------------
170- # FMPoseConfig
169+ # PipelineConfig
171170# ---------------------------------------------------------------------------
172171
173172
174- class TestFMPoseConfig :
173+ class TestPipelineConfig :
175174 def test_default_construction (self ):
176175 """All sub-configs are initialised with their defaults."""
177- cfg = FMPoseConfig ()
178- assert isinstance (cfg .model_cfg , ModelConfig )
176+ cfg = PipelineConfig ()
177+ assert isinstance (cfg .model_cfg , FMPose3DConfig )
179178 assert isinstance (cfg .dataset_cfg , DatasetConfig )
180179 assert isinstance (cfg .training_cfg , TrainingConfig )
181180 assert isinstance (cfg .inference_cfg , InferenceConfig )
@@ -188,8 +187,8 @@ def test_default_construction(self):
188187
189188 def test_partial_construction (self ):
190189 """Supplying only some sub-configs leaves the rest at defaults."""
191- cfg = FMPoseConfig (
192- model_cfg = ModelConfig (layers = 5 ),
190+ cfg = PipelineConfig (
191+ model_cfg = FMPose3DConfig (layers = 5 ),
193192 training_cfg = TrainingConfig (lr = 2e-4 ),
194193 )
195194 assert cfg .model_cfg .layers == 5
@@ -200,21 +199,21 @@ def test_partial_construction(self):
200199
201200 def test_sub_config_mutation (self ):
202201 """Mutating a sub-config field is reflected on the config."""
203- cfg = FMPoseConfig ()
202+ cfg = PipelineConfig ()
204203 cfg .training_cfg .lr = 0.01
205204 assert cfg .training_cfg .lr == pytest .approx (0.01 )
206205
207206 def test_sub_config_replacement (self ):
208207 """Replacing an entire sub-config works."""
209- cfg = FMPoseConfig ()
210- cfg .model_cfg = ModelConfig (layers = 10 , channel = 1024 )
208+ cfg = PipelineConfig ()
209+ cfg .model_cfg = FMPose3DConfig (layers = 10 , channel = 1024 )
211210 assert cfg .model_cfg .layers == 10
212211 assert cfg .model_cfg .channel == 1024
213212
214213 # -- to_dict --------------------------------------------------------------
215214
216215 def test_to_dict_returns_flat_dict (self ):
217- cfg = FMPoseConfig ()
216+ cfg = PipelineConfig ()
218217 d = cfg .to_dict ()
219218 assert isinstance (d , dict )
220219 # Spot-check keys from different groups
@@ -225,8 +224,8 @@ def test_to_dict_returns_flat_dict(self):
225224 assert "gpu" in d
226225
227226 def test_to_dict_reflects_custom_values (self ):
228- cfg = FMPoseConfig (
229- model_cfg = ModelConfig (layers = 7 ),
227+ cfg = PipelineConfig (
228+ model_cfg = FMPose3DConfig (layers = 7 ),
230229 aggregation_cfg = AggregationConfig (topk = 5 ),
231230 )
232231 d = cfg .to_dict ()
@@ -235,7 +234,7 @@ def test_to_dict_reflects_custom_values(self):
235234
236235 def test_to_dict_no_duplicate_keys (self ):
237236 """Every field name should be unique across all sub-configs."""
238- cfg = FMPoseConfig ()
237+ cfg = PipelineConfig ()
239238 d = cfg .to_dict ()
240239 all_field_names = []
241240 for dc_class in _SUB_CONFIG_CLASSES .values ():
@@ -249,8 +248,9 @@ def test_to_dict_no_duplicate_keys(self):
249248
250249 def test_from_namespace_basic (self ):
251250 ns = argparse .Namespace (
252- # ModelConfig
251+ # FMPose3DConfig
253252 model = "test_model" ,
253+ model_type = "fmpose3d" ,
254254 layers = 5 ,
255255 channel = 256 ,
256256 d_hid = 512 ,
@@ -260,7 +260,6 @@ def test_from_namespace_basic(self):
260260 in_channels = 2 ,
261261 out_channels = 3 ,
262262 frames = 3 ,
263- model_path = "/tmp/model.py" ,
264263 # DatasetConfig
265264 dataset = "rat7m" ,
266265 keypoints = "cpn" ,
@@ -339,7 +338,7 @@ def test_from_namespace_basic(self):
339338 single = True ,
340339 reload_3d = False ,
341340 )
342- cfg = FMPoseConfig .from_namespace (ns )
341+ cfg = PipelineConfig .from_namespace (ns )
343342
344343 # Verify a sample from each group
345344 assert cfg .model_cfg .layers == 5
@@ -363,15 +362,15 @@ def test_from_namespace_ignores_unknown_fields(self):
363362 ns = argparse .Namespace (
364363 layers = 3 , channel = 512 , unknown_field = "should_be_ignored" ,
365364 )
366- cfg = FMPoseConfig .from_namespace (ns )
365+ cfg = PipelineConfig .from_namespace (ns )
367366 assert cfg .model_cfg .layers == 3
368367 assert cfg .model_cfg .channel == 512
369368 assert not hasattr (cfg , "unknown_field" )
370369
371370 def test_from_namespace_partial_namespace (self ):
372371 """A namespace missing some fields uses dataclass defaults for those."""
373372 ns = argparse .Namespace (layers = 10 , gpu = "2" )
374- cfg = FMPoseConfig .from_namespace (ns )
373+ cfg = PipelineConfig .from_namespace (ns )
375374 assert cfg .model_cfg .layers == 10
376375 assert cfg .runtime_cfg .gpu == "2"
377376 # Unset fields keep defaults
@@ -385,7 +384,7 @@ def test_roundtrip_from_namespace_to_dict(self):
385384 ns = argparse .Namespace (
386385 layers = 8 , channel = 1024 , dataset = "animal3d" , lr = 2e-4 , topk = 7 , gpu = "3" ,
387386 )
388- cfg = FMPoseConfig .from_namespace (ns )
387+ cfg = PipelineConfig .from_namespace (ns )
389388 d = cfg .to_dict ()
390389 assert d ["layers" ] == 8
391390 assert d ["channel" ] == 1024
@@ -396,10 +395,9 @@ def test_roundtrip_from_namespace_to_dict(self):
396395
397396 def test_to_dict_after_mutation (self ):
398397 """to_dict reflects in-place mutations on sub-configs."""
399- cfg = FMPoseConfig ()
398+ cfg = PipelineConfig ()
400399 cfg .training_cfg .lr = 0.123
401400 cfg .model_cfg .layers = 99
402401 d = cfg .to_dict ()
403402 assert d ["lr" ] == pytest .approx (0.123 )
404403 assert d ["layers" ] == 99
405-
0 commit comments