Skip to content

Commit bfc3c3c

Browse files
xiu-csderuyter92
authored andcommitted
Update model_type to "fmpose3d_humans" across configuration and model files for consistency in the FMPose3D framework.
1 parent fea9ea6 commit bfc3c3c

4 files changed

Lines changed: 9 additions & 9 deletions

File tree

fmpose3d/common/arguments.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ def init(self):
7474
self.parser.add_argument("--model_dir", type=str, default="")
7575
# Optional: load model class from a specific file path
7676
self.parser.add_argument("--model_path", type=str, default="")
77-
# Model registry name (e.g. "fmpose3d"); used instead of --model_path
78-
self.parser.add_argument("--model_type", type=str, default="fmpose3d")
77+
# Model registry name (e.g. "fmpose3d_humans"); used instead of --model_path
78+
self.parser.add_argument("--model_type", type=str, default="fmpose3d_humans")
7979
self.parser.add_argument("--model_weights_path", type=str, default="")
8080

8181
self.parser.add_argument("--post_refine_reload", action="store_true")

fmpose3d/common/config.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919
@dataclass
2020
class ModelConfig:
2121
"""Model architecture configuration."""
22-
model_type: str = "fmpose3d"
22+
model_type: str = "fmpose3d_humans"
2323

2424

2525
# Per-model-type defaults for fields marked with INFER_FROM_MODEL_TYPE.
2626
# Also consumed by PipelineConfig.for_model_type to set cross-config
2727
# values (dataset, sample_steps, etc.).
2828
_FMPOSE3D_DEFAULTS: Dict[str, Dict] = {
29-
"fmpose3d": {
29+
"fmpose3d_humans": {
3030
"n_joints": 17,
3131
"out_joints": 17,
3232
"dataset": "h36m",
@@ -39,7 +39,7 @@ class ModelConfig:
3939
"n_joints": 26,
4040
"out_joints": 26,
4141
"dataset": "animal3d",
42-
"sample_steps": 3,
42+
"sample_steps": 5,
4343
"joints_left": [0, 3, 5, 8, 10, 12, 14, 16, 20, 22],
4444
"joints_right": [1, 4, 6, 9, 11, 13, 15, 17, 21, 23],
4545
"root_joint": 7,
@@ -51,7 +51,7 @@ class ModelConfig:
5151

5252
@dataclass
5353
class FMPose3DConfig(ModelConfig):
54-
model_type: str = "fmpose3d"
54+
model_type: str = "fmpose3d_humans"
5555
model: str = ""
5656
layers: int = 5
5757
channel: int = 512
@@ -321,7 +321,7 @@ def _pick(dc_class, src: dict):
321321

322322
kwargs = {}
323323
for group_name, dc_class in _SUB_CONFIG_CLASSES.items():
324-
if group_name == "model_cfg" and raw.get("model_type", 'fmpose3d') in _FMPOSE3D_DEFAULTS:
324+
if group_name == "model_cfg" and raw.get("model_type", 'fmpose3d_humans') in _FMPOSE3D_DEFAULTS:
325325
dc_class = FMPose3DConfig
326326
elif group_name == "pose2d_cfg":
327327
p2d = raw.get("pose2d_model", "hrnet")

fmpose3d/models/fmpose3d/model_GAMLP.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def forward(self, x):
212212
x = self.fc2(x)
213213
return x
214214

215-
@register_model("fmpose3d")
215+
@register_model("fmpose3d_humans")
216216
class Model(BaseModel):
217217
def __init__(self, args):
218218
super().__init__(args)

tests/test_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def test_from_namespace_basic(self):
250250
ns = argparse.Namespace(
251251
# FMPose3DConfig
252252
model="test_model",
253-
model_type="fmpose3d",
253+
model_type="fmpose3d_humans",
254254
layers=5,
255255
channel=256,
256256
d_hid=512,

0 commit comments

Comments
 (0)