Skip to content

Commit b73d9ce

Browse files
committed
Feat: add extendable model registry
- New ABC base_model as template. - Easy access to defined set of models. - Modularly extendable with new implementations.
1 parent db775ea commit b73d9ce

5 files changed

Lines changed: 143 additions & 7 deletions

File tree

fmpose3d/models/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,16 @@
1111
FMPose3D models.
1212
"""
1313

14-
from .graph_frames import Graph
15-
from .model_GAMLP import Model
14+
from .base_model import BaseModel, register_model, get_model, list_models
15+
16+
# Import model subpackages so their @register_model decorators execute.
17+
from .fmpose3d import Graph, Model
1618

1719
__all__ = [
20+
"BaseModel",
21+
"register_model",
22+
"get_model",
23+
"list_models",
1824
"Graph",
1925
"Model",
2026
]

fmpose3d/models/base_model.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""
2+
FMPose3D: monocular 3D Pose Estimation via Flow Matching
3+
4+
Official implementation of the paper:
5+
"FMPose3D: monocular 3D Pose Estimation via Flow Matching"
6+
by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis
7+
Licensed under Apache 2.0
8+
"""
9+
10+
from abc import ABC, abstractmethod
11+
import warnings
12+
13+
import torch
14+
import torch.nn as nn
15+
16+
17+
# ---------------------------------------------------------------------------
18+
# Model registry
19+
# ---------------------------------------------------------------------------
20+
21+
_MODEL_REGISTRY: dict[str, type["BaseModel"]] = {}
22+
23+
24+
def register_model(name: str):
25+
"""Class decorator that registers a model under *name*.
26+
27+
Usage::
28+
29+
@register_model("my_model")
30+
class MyModel(BaseModel):
31+
...
32+
33+
The model can then be retrieved with :func:`get_model`.
34+
"""
35+
36+
def decorator(cls: type["BaseModel"]) -> type["BaseModel"]:
37+
if name in _MODEL_REGISTRY:
38+
warnings.warn(
39+
f"Model '{name}' is already registered "
40+
f"(existing: {_MODEL_REGISTRY[name].__qualname__}, "
41+
f"new: {cls.__qualname__})"
42+
)
43+
# raise ValueError(
44+
# f"Model '{name}' is already registered "
45+
# f"(existing: {_MODEL_REGISTRY[name].__qualname__}, "
46+
# f"new: {cls.__qualname__})"
47+
# )
48+
_MODEL_REGISTRY[name] = cls
49+
return cls
50+
51+
return decorator
52+
53+
54+
def get_model(name: str) -> type["BaseModel"]:
55+
"""Return the model class registered under *name*.
56+
57+
Raises :class:`KeyError` with a helpful message when the name is unknown.
58+
"""
59+
if name not in _MODEL_REGISTRY:
60+
available = ", ".join(sorted(_MODEL_REGISTRY)) or "(none)"
61+
raise KeyError(
62+
f"Unknown model '{name}'. Available models: {available}"
63+
)
64+
return _MODEL_REGISTRY[name]
65+
66+
67+
def list_models() -> list[str]:
68+
"""Return a sorted list of all registered model names."""
69+
return sorted(_MODEL_REGISTRY)
70+
71+
72+
# ---------------------------------------------------------------------------
73+
# Base model
74+
# ---------------------------------------------------------------------------
75+
76+
77+
class BaseModel(ABC, nn.Module):
78+
"""Abstract base class for all FMPose3D lifting models.
79+
80+
Every model must accept a single *args* namespace / object in its
81+
constructor and implement :meth:`forward` with the signature below.
82+
83+
Parameters expected on *args* (at minimum):
84+
- ``channel`` – embedding dimension
85+
- ``layers`` – number of transformer / GCN blocks
86+
- ``d_hid`` – hidden MLP dimension
87+
- ``token_dim`` – token dimension
88+
- ``n_joints`` – number of body joints
89+
"""
90+
91+
@abstractmethod
92+
def __init__(self, args) -> None:
93+
super().__init__()
94+
95+
@abstractmethod
96+
def forward(
97+
self,
98+
pose_2d: torch.Tensor,
99+
y_t: torch.Tensor,
100+
t: torch.Tensor,
101+
) -> torch.Tensor:
102+
"""Predict the velocity field for flow matching.
103+
104+
Args:
105+
pose_2d: 2D keypoints, shape ``(B, F, J, 2)``.
106+
y_t: Noisy 3D hypothesis at time *t*, shape ``(B, F, J, 3)``.
107+
t: Diffusion / flow time, shape ``(B, F, 1, 1)`` with values
108+
in ``[0, 1]``.
109+
110+
Returns:
111+
Predicted velocity ``v``, shape ``(B, F, J, 3)``.
112+
"""
113+
...
114+
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
"""
2+
FMPose3D model subpackage.
3+
"""
4+
5+
from .graph_frames import Graph
6+
from .model_GAMLP import Model
7+
8+
__all__ = [
9+
"Graph",
10+
"Model",
11+
]
12+
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,4 +207,5 @@ def normalize_undigraph(A):
207207
if __name__=="__main__":
208208
graph = Graph('hm36_gt', 'spatial', 1)
209209
print(graph.A.shape)
210-
# print(graph)
210+
# print(graph)
211+
Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
import torch.nn as nn
1212
import math
1313
from einops import rearrange
14-
from fmpose3d.models.graph_frames import Graph
14+
from fmpose3d.models.fmpose3d.graph_frames import Graph
15+
from fmpose3d.models.base_model import BaseModel, register_model
1516
from functools import partial
1617
from einops import rearrange
1718
from timm.models.layers import DropPath
@@ -211,9 +212,10 @@ def forward(self, x):
211212
x = self.fc2(x)
212213
return x
213214

214-
class Model(nn.Module):
215+
@register_model("fmpose3d")
216+
class Model(BaseModel):
215217
def __init__(self, args):
216-
super().__init__()
218+
super().__init__(args)
217219
## GCN
218220
self.graph = Graph('hm36_gt', 'spatial', pad=1)
219221
# Register as buffer (not parameter) to follow module device automatically
@@ -264,4 +266,5 @@ class Args:
264266
y_t = torch.randn(1, 17, 17, 3, device=device)
265267
t = torch.randn(1, 1, 1, 1, device=device)
266268
v = model(x, y_t, t)
267-
print(v.shape)
269+
print(v.shape)
270+

0 commit comments

Comments
 (0)