Skip to content

Commit f8ab0a6

Browse files
committed
Add config dataclasses (in parallel to arguments.py)
1 parent 81a22c6 commit f8ab0a6

2 files changed

Lines changed: 257 additions & 0 deletions

File tree

fmpose3d/common/__init__.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,19 @@
1212
"""
1313

1414
from .arguments import opts
15+
from .config import (
16+
FMPoseConfig,
17+
ModelConfig,
18+
DatasetConfig,
19+
TrainingConfig,
20+
InferenceConfig,
21+
AggregationConfig,
22+
CheckpointConfig,
23+
RefinementConfig,
24+
OutputConfig,
25+
DemoConfig,
26+
RuntimeConfig,
27+
)
1528
from .h36m_dataset import Human36mDataset
1629
from .load_data_hm36 import Fusion
1730
from .utils import (
@@ -27,6 +40,17 @@
2740

2841
__all__ = [
2942
"opts",
43+
"FMPoseConfig",
44+
"ModelConfig",
45+
"DatasetConfig",
46+
"TrainingConfig",
47+
"InferenceConfig",
48+
"AggregationConfig",
49+
"CheckpointConfig",
50+
"RefinementConfig",
51+
"OutputConfig",
52+
"DemoConfig",
53+
"RuntimeConfig",
3054
"Human36mDataset",
3155
"Fusion",
3256
"mpjpe_cal",

fmpose3d/common/config.py

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
"""
2+
FMPose3D: monocular 3D Pose Estimation via Flow Matching
3+
4+
Official implementation of the paper:
5+
"FMPose3D: monocular 3D Pose Estimation via Flow Matching"
6+
by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis
7+
Licensed under Apache 2.0
8+
"""
9+
10+
import math
11+
from dataclasses import dataclass, field, fields, asdict
12+
from typing import List
13+
14+
15+
# ---------------------------------------------------------------------------
16+
# Dataclass configuration groups
17+
# ---------------------------------------------------------------------------
18+
19+
20+
@dataclass
class ModelConfig:
    """Model architecture configuration.

    Network-shape hyperparameters (depth, widths, joint counts) plus the
    identifiers used to locate the model implementation.
    """

    model: str = ""        # model identifier; empty string selects the default
    layers: int = 3        # number of network layers
    channel: int = 512     # base channel width
    d_hid: int = 1024      # hidden dimension — presumably of the feed-forward blocks; confirm
    token_dim: int = 256   # per-token embedding size
    n_joints: int = 17     # input skeleton joint count
    out_joints: int = 17   # output skeleton joint count
    in_channels: int = 2   # coordinates per input joint (2D keypoints)
    out_channels: int = 3  # coordinates per output joint (3D pose)
    frames: int = 1        # number of input frames
    # Optional: load the model class from a specific file path.
    model_path: str = ""
36+
37+
38+
@dataclass
class DatasetConfig:
    """Dataset and data loading configuration."""

    dataset: str = "h36m"               # dataset name
    keypoints: str = "cpn_ft_h36m_dbb"  # 2D keypoint source
    root_path: str = "dataset/"         # directory holding the dataset files
    actions: str = "*"                  # action filter; "*" means all actions
    downsample: int = 1                 # temporal downsampling factor
    subset: float = 1.0                 # fraction of data to use — presumably; confirm against loader
    stride: int = 1                     # frame sampling stride
    crop_uv: int = 0
    out_all: int = 1
    # Camera view indices used for training / testing (fresh list per instance).
    train_views: List[int] = field(default_factory=lambda: [0, 1, 2, 3])
    test_views: List[int] = field(default_factory=lambda: [0, 1, 2, 3])

    # Derived / set during parse based on the dataset choice.
    subjects_train: str = "S1,S5,S6,S7,S8"
    subjects_test: str = "S9,S11"
    root_joint: int = 0                 # index of the skeleton root joint
    joints_left: List[int] = field(default_factory=list)   # left-side joint indices
    joints_right: List[int] = field(default_factory=list)  # right-side joint indices
60+
61+
62+
@dataclass
class TrainingConfig:
    """Training hyperparameters and settings."""

    train: bool = False                 # whether to run in training mode
    nepoch: int = 41                    # total number of training epochs
    batch_size: int = 128
    lr: float = 1e-3                    # initial learning rate
    lr_decay: float = 0.95              # per-epoch learning-rate decay factor
    lr_decay_large: float = 0.5         # extra decay applied every `large_decay_epoch` epochs
    large_decay_epoch: int = 5
    workers: int = 8                    # data-loader worker processes
    data_augmentation: bool = True
    reverse_augmentation: bool = False
    norm: float = 0.01
77+
78+
79+
@dataclass
class InferenceConfig:
    """Evaluation and testing configuration."""

    test: int = 1                                      # whether to run evaluation
    test_augmentation: bool = True                     # flip augmentation at test time
    test_augmentation_flip_hypothesis: bool = False
    test_augmentation_FlowAug: bool = False
    sample_steps: int = 3                              # flow-matching sampling steps
    eval_multi_steps: bool = False                     # evaluate at several step counts
    eval_sample_steps: str = "1,3,5,7,9"               # comma-separated step counts to sweep
    num_hypothesis_list: str = "1"                     # comma-separated hypothesis counts to sweep
    hypothesis_num: int = 1                            # number of pose hypotheses to sample
    guidance_scale: float = 1.0
93+
94+
95+
@dataclass
class AggregationConfig:
    """Hypothesis aggregation configuration."""

    topk: int = 3            # number of top hypotheses to aggregate
    exp_temp: float = 0.002  # temperature for exponential weighting
    mode: str = "exp"        # aggregation mode
    opt_steps: int = 2       # optimization steps during aggregation
103+
104+
105+
@dataclass
class CheckpointConfig:
    """Checkpoint loading and saving configuration."""

    reload: bool = False                  # resume from an existing checkpoint
    model_dir: str = ""
    model_weights_path: str = ""
    checkpoint: str = ""
    previous_dir: str = "./pre_trained_model/pretrained"
    num_saved_models: int = 3             # how many best checkpoints to keep
    # Best metric seen so far; starts at +inf so the first result always wins.
    previous_best_threshold: float = math.inf
    previous_name: str = ""               # filename of the last saved checkpoint
117+
118+
119+
@dataclass
class RefinementConfig:
    """Post-refinement model configuration."""

    post_refine: bool = False             # enable the post-refinement model
    post_refine_reload: bool = False      # reload post-refinement weights
    previous_post_refine_name: str = ""
    lr_refine: float = 1e-5               # learning rate for the refinement model
    refine: bool = False
    reload_refine: bool = False
    previous_refine_name: str = ""
130+
131+
132+
@dataclass
class OutputConfig:
    """Output, logging, and file management configuration."""

    create_time: str = ""   # timestamp used in output naming
    filename: str = ""
    create_file: int = 1    # whether to create output files
    debug: bool = False
    folder_name: str = ""   # output folder name
    sh_file: str = ""
142+
143+
144+
@dataclass
class DemoConfig:
    """Demo / inference configuration."""

    # Input type: 'image' or 'video'.
    type: str = "image"
    # Path to the input file or directory.
    path: str = "demo/images/running.png"
152+
153+
154+
@dataclass
class RuntimeConfig:
    """Runtime environment configuration."""

    gpu: str = "0"            # GPU device id(s)
    pad: int = 0              # derived: (frames - 1) // 2
    single: bool = False
    reload_3d: bool = False
162+
163+
164+
# ---------------------------------------------------------------------------
165+
# Composite configuration
166+
# ---------------------------------------------------------------------------
167+
168+
# Mapping of FMPoseConfig attribute name -> sub-config dataclass.
# Iterated by FMPoseConfig.from_namespace (to split a flat argparse
# namespace into groups) and by FMPoseConfig.to_dict (to flatten the
# groups back out); keep it in sync with FMPoseConfig's declared fields.
_SUB_CONFIG_CLASSES = {
    "model_cfg": ModelConfig,
    "dataset_cfg": DatasetConfig,
    "training_cfg": TrainingConfig,
    "inference_cfg": InferenceConfig,
    "aggregation_cfg": AggregationConfig,
    "checkpoint_cfg": CheckpointConfig,
    "refinement_cfg": RefinementConfig,
    "output_cfg": OutputConfig,
    "demo_cfg": DemoConfig,
    "runtime_cfg": RuntimeConfig,
}
180+
181+
182+
@dataclass
class FMPoseConfig:
    """Top-level configuration for FMPose3D.

    Groups related settings into sub-configs::

        config.model_cfg.layers
        config.training_cfg.lr
    """

    model_cfg: ModelConfig = field(default_factory=ModelConfig)
    dataset_cfg: DatasetConfig = field(default_factory=DatasetConfig)
    training_cfg: TrainingConfig = field(default_factory=TrainingConfig)
    inference_cfg: InferenceConfig = field(default_factory=InferenceConfig)
    aggregation_cfg: AggregationConfig = field(default_factory=AggregationConfig)
    checkpoint_cfg: CheckpointConfig = field(default_factory=CheckpointConfig)
    refinement_cfg: RefinementConfig = field(default_factory=RefinementConfig)
    output_cfg: OutputConfig = field(default_factory=OutputConfig)
    demo_cfg: DemoConfig = field(default_factory=DemoConfig)
    runtime_cfg: RuntimeConfig = field(default_factory=RuntimeConfig)

    # -- construction from argparse namespace ---------------------------------

    @classmethod
    def from_namespace(cls, ns) -> "FMPoseConfig":
        """Build a :class:`FMPoseConfig` from an ``argparse.Namespace``

        Example::

            args = opts().parse()
            cfg = FMPoseConfig.from_namespace(args)
        """
        # Accept either an object with attributes or a plain mapping.
        if hasattr(ns, "__dict__"):
            flat = vars(ns)
        else:
            flat = dict(ns)

        # Route each flat key to the sub-config that declares a field of
        # that name; keys matching no group are silently ignored.
        kwargs = {}
        for attr_name, group_cls in _SUB_CONFIG_CLASSES.items():
            accepted = {f.name for f in fields(group_cls)}
            kwargs[attr_name] = group_cls(
                **{key: value for key, value in flat.items() if key in accepted}
            )
        return cls(**kwargs)

    # -- utilities ------------------------------------------------------------

    def to_dict(self) -> dict:
        """Return a flat dictionary of all configuration values."""
        # Flatten every sub-config into one namespace, in group order.
        return {
            key: value
            for attr_name in _SUB_CONFIG_CLASSES
            for key, value in asdict(getattr(self, attr_name)).items()
        }
233+

0 commit comments

Comments
 (0)