|
| 1 | +""" |
| 2 | +FMPose3D: monocular 3D Pose Estimation via Flow Matching |
| 3 | +
|
| 4 | +Official implementation of the paper: |
| 5 | +"FMPose3D: monocular 3D Pose Estimation via Flow Matching" |
| 6 | +by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis |
| 7 | +Licensed under Apache 2.0 |
| 8 | +""" |
| 9 | + |
| 10 | +import math |
| 11 | +from dataclasses import dataclass, field, fields, asdict |
| 12 | +from typing import List |
| 13 | + |
| 14 | + |
| 15 | +# --------------------------------------------------------------------------- |
| 16 | +# Dataclass configuration groups |
| 17 | +# --------------------------------------------------------------------------- |
| 18 | + |
| 19 | + |
@dataclass
class ModelConfig:
    """Model architecture configuration.

    Values are typically overridden from the command line; semantics of the
    numeric fields are inferred from names/defaults — confirm against the
    model definition.
    """

    model: str = ""        # model identifier/class name to use; empty = unset
    layers: int = 3        # number of network layers
    channel: int = 512     # channel width -- presumably the hidden size; TODO confirm
    d_hid: int = 1024      # hidden (feed-forward) dimension -- TODO confirm
    token_dim: int = 256   # per-token embedding dimension -- TODO confirm
    n_joints: int = 17     # number of input joints (17 matches Human3.6M)
    out_joints: int = 17   # number of output joints
    in_channels: int = 2   # input channels per joint -- presumably 2D (x, y)
    out_channels: int = 3  # output channels per joint -- presumably 3D (x, y, z)
    frames: int = 1        # frames per sample (temporal context)
    model_path: str = ""
    """Optional: load model class from a specific file path."""
| 36 | + |
| 37 | + |
@dataclass
class DatasetConfig:
    """Dataset and data loading configuration."""

    dataset: str = "h36m"               # dataset name
    keypoints: str = "cpn_ft_h36m_dbb"  # 2D keypoint source/detector tag
    root_path: str = "dataset/"         # root directory holding dataset files
    actions: str = "*"                  # action filter; "*" = all actions
    downsample: int = 1                 # temporal downsampling factor
    subset: float = 1.0                 # fraction of the dataset to use
    stride: int = 1                     # sampling stride -- TODO confirm exact use
    crop_uv: int = 0                    # int flag: crop UV coordinates -- TODO confirm semantics
    out_all: int = 1                    # int flag: output all frames -- TODO confirm semantics
    train_views: List[int] = field(default_factory=lambda: [0, 1, 2, 3])  # camera views for training
    test_views: List[int] = field(default_factory=lambda: [0, 1, 2, 3])   # camera views for testing

    # Derived / set during parse based on dataset choice
    subjects_train: str = "S1,S5,S6,S7,S8"  # comma-separated training subjects
    subjects_test: str = "S9,S11"           # comma-separated test subjects
    root_joint: int = 0                     # index of the root joint
    joints_left: List[int] = field(default_factory=list)   # left-side joint indices (filled during parse)
    joints_right: List[int] = field(default_factory=list)  # right-side joint indices (filled during parse)
| 60 | + |
| 61 | + |
@dataclass
class TrainingConfig:
    """Training hyperparameters and settings."""

    train: bool = False                 # run training (False = evaluation only)
    nepoch: int = 41                    # total number of training epochs
    batch_size: int = 128               # mini-batch size
    lr: float = 1e-3                    # initial learning rate
    lr_decay: float = 0.95              # multiplicative LR decay -- presumably per epoch; TODO confirm
    lr_decay_large: float = 0.5         # larger LR decay applied periodically
    large_decay_epoch: int = 5          # epoch interval for the large LR decay -- TODO confirm
    workers: int = 8                    # data-loader worker processes
    data_augmentation: bool = True      # enable train-time augmentation -- presumably flip; TODO confirm
    reverse_augmentation: bool = False  # enable reverse (time-flip?) augmentation -- TODO confirm
    norm: float = 0.01                  # normalization coefficient -- semantics not visible here; TODO confirm
| 77 | + |
| 78 | + |
@dataclass
class InferenceConfig:
    """Evaluation and testing configuration."""

    test: int = 1                                    # int flag: run evaluation
    test_augmentation: bool = True                   # test-time augmentation -- presumably flip; TODO confirm
    test_augmentation_flip_hypothesis: bool = False  # flip augmentation applied per hypothesis -- TODO confirm
    test_augmentation_FlowAug: bool = False          # flow-based test-time augmentation -- TODO confirm
    sample_steps: int = 3                            # number of sampling steps (flow matching)
    eval_multi_steps: bool = False                   # evaluate at multiple sampling-step counts
    eval_sample_steps: str = "1,3,5,7,9"             # comma-separated step counts used when eval_multi_steps
    num_hypothesis_list: str = "1"                   # comma-separated hypothesis counts to evaluate
    hypothesis_num: int = 1                          # number of pose hypotheses to sample
    guidance_scale: float = 1.0                      # guidance strength during sampling -- TODO confirm
| 93 | + |
| 94 | + |
@dataclass
class AggregationConfig:
    """Hypothesis aggregation configuration."""

    topk: int = 3            # number of top hypotheses kept for aggregation
    exp_temp: float = 0.002  # temperature for exponential weighting -- presumably used when mode == "exp"
    mode: str = "exp"        # aggregation mode identifier
    opt_steps: int = 2       # optimization steps during aggregation -- TODO confirm
| 103 | + |
| 104 | + |
@dataclass
class CheckpointConfig:
    """Checkpoint loading and saving configuration."""

    reload: bool = False          # load pretrained weights before training/eval
    model_dir: str = ""           # directory containing model checkpoints
    model_weights_path: str = ""  # explicit path to a weights file
    checkpoint: str = ""          # checkpoint path used for save/restore -- TODO confirm
    previous_dir: str = "./pre_trained_model/pretrained"  # default pretrained-model location
    num_saved_models: int = 3     # how many saved models to keep -- TODO confirm rotation policy
    previous_best_threshold: float = math.inf  # best metric so far; inf = nothing saved yet -- TODO confirm direction
    previous_name: str = ""       # name of the most recently saved checkpoint
| 117 | + |
| 118 | + |
@dataclass
class RefinementConfig:
    """Post-refinement model configuration."""

    post_refine: bool = False            # enable the post-refinement model
    post_refine_reload: bool = False     # load pretrained post-refinement weights
    previous_post_refine_name: str = ""  # checkpoint name for the post-refinement model
    lr_refine: float = 1e-5              # learning rate for refinement training
    refine: bool = False                 # enable the refinement stage -- relation to post_refine unclear; TODO confirm
    reload_refine: bool = False          # load pretrained refinement weights
    previous_refine_name: str = ""       # checkpoint name for the refinement model
| 130 | + |
| 131 | + |
@dataclass
class OutputConfig:
    """Output, logging, and file management configuration."""

    create_time: str = ""  # run timestamp (set at startup)
    filename: str = ""     # output/log file name
    create_file: int = 1   # int flag: create output files
    debug: bool = False    # debug mode -- presumably reduced data / extra logging; TODO confirm
    folder_name: str = ""  # output folder name
    sh_file: str = ""      # launching shell-script path -- TODO confirm
| 142 | + |
| 143 | + |
@dataclass
class DemoConfig:
    """Demo / inference configuration.

    NOTE: the field name ``type`` intentionally mirrors the CLI flag; it is a
    dataclass field, so shadowing the builtin is harmless here.
    """

    type: str = "image"
    """Input type: ``'image'`` or ``'video'``."""
    path: str = "demo/images/running.png"
    """Path to input file or directory."""
| 152 | + |
| 153 | + |
@dataclass
class RuntimeConfig:
    """Runtime environment configuration."""

    gpu: str = "0"           # GPU device id(s), kept as a string
    pad: int = 0             # derived: (frames - 1) // 2
    single: bool = False     # single-frame mode -- TODO confirm semantics
    reload_3d: bool = False  # reload a 3D model checkpoint -- TODO confirm semantics
| 162 | + |
| 163 | + |
| 164 | +# --------------------------------------------------------------------------- |
| 165 | +# Composite configuration |
| 166 | +# --------------------------------------------------------------------------- |
| 167 | + |
# Maps each FMPoseConfig attribute name to the dataclass backing that group.
_SUB_CONFIG_CLASSES = dict(
    model_cfg=ModelConfig,
    dataset_cfg=DatasetConfig,
    training_cfg=TrainingConfig,
    inference_cfg=InferenceConfig,
    aggregation_cfg=AggregationConfig,
    checkpoint_cfg=CheckpointConfig,
    refinement_cfg=RefinementConfig,
    output_cfg=OutputConfig,
    demo_cfg=DemoConfig,
    runtime_cfg=RuntimeConfig,
)
| 180 | + |
| 181 | + |
@dataclass
class FMPoseConfig:
    """Top-level configuration for FMPose3D.

    Related settings are grouped into sub-configs, accessed as::

        config.model_cfg.layers
        config.training_cfg.lr
    """

    model_cfg: ModelConfig = field(default_factory=ModelConfig)
    dataset_cfg: DatasetConfig = field(default_factory=DatasetConfig)
    training_cfg: TrainingConfig = field(default_factory=TrainingConfig)
    inference_cfg: InferenceConfig = field(default_factory=InferenceConfig)
    aggregation_cfg: AggregationConfig = field(default_factory=AggregationConfig)
    checkpoint_cfg: CheckpointConfig = field(default_factory=CheckpointConfig)
    refinement_cfg: RefinementConfig = field(default_factory=RefinementConfig)
    output_cfg: OutputConfig = field(default_factory=OutputConfig)
    demo_cfg: DemoConfig = field(default_factory=DemoConfig)
    runtime_cfg: RuntimeConfig = field(default_factory=RuntimeConfig)

    # -- construction from argparse namespace ---------------------------------

    @classmethod
    def from_namespace(cls, ns) -> "FMPoseConfig":
        """Build a :class:`FMPoseConfig` from an ``argparse.Namespace``

        Attributes of ``ns`` that match a field of a sub-config are routed to
        that group; unrecognized attributes are silently ignored.

        Example::

            args = opts().parse()
            cfg = FMPoseConfig.from_namespace(args)
        """
        # Accept either an object with attributes or a plain mapping.
        source = vars(ns) if hasattr(ns, "__dict__") else dict(ns)

        groups = {}
        for attr_name, group_cls in _SUB_CONFIG_CLASSES.items():
            accepted = {f.name for f in fields(group_cls)}
            kwargs = {key: val for key, val in source.items() if key in accepted}
            groups[attr_name] = group_cls(**kwargs)
        return cls(**groups)

    # -- utilities ------------------------------------------------------------

    def to_dict(self) -> dict:
        """Return a flat dictionary of all configuration values."""
        # Later groups win on duplicate keys, matching dict.update semantics.
        return {
            key: val
            for attr_name in _SUB_CONFIG_CLASSES
            for key, val in asdict(getattr(self, attr_name)).items()
        }
| 233 | + |
0 commit comments