
Commit 45533f2

Merge pull request #14 from AdaptiveMotorControlLab/feat/add_api
Add high-level API for the pose3d inference pipeline
2 parents: 58fcb2c + 2d89595

27 files changed: 1998 additions & 151 deletions

.gitignore

Lines changed: 5 additions & 0 deletions
@@ -45,3 +45,8 @@ htmlcov/
 *.pkl
 *.h5
 *.ckpt
+
+# Excluded directories
+pre_trained_models/
+demo/predictions/
+demo/images/

README.md

Lines changed: 4 additions & 4 deletions
@@ -2,11 +2,11 @@
 
 ![Version](https://img.shields.io/badge/python_version-3.10-purple)
 [![PyPI version](https://badge.fury.io/py/fmpose3d.svg?icon=si%3Apython)](https://badge.fury.io/py/fmpose3d)
-[![License: LApache 2.0](https://img.shields.io/badge/License-Apache2.0-blue.svg)](https://www.gnu.org/licenses/apach2.0)
+[![License: Apache 2.0](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://www.apache.org/licenses/LICENSE-2.0)
 
 This is the official implementation of the approach described in the preprint:
 
-[**FMPose3D: monocular 3D pose estimation via flow matching**](http://arxiv.org/abs/2602.05755)
+[**FMPose3D: monocular 3D pose estimation via flow matching**](https://arxiv.org/abs/2602.05755)
 Ti Wang, Xiaohang Yu, Mackenzie Weygandt Mathis
 
 <!-- <p align="center"><img src="./images/Frame 4.jpg" width="50%" alt="" /></p> -->
@@ -51,7 +51,7 @@ sh vis_in_the_wild.sh
 ```
 The predictions will be saved to folder `demo/predictions`.
 
-<p align="center"><img src="./images/demo.jpg" width="95%" alt="" /></p>
+<p align="center"><img src="./images/demo.gif" width="95%" alt="" /></p>
 
 ## Training and Inference
 
@@ -79,7 +79,7 @@ The training logs, checkpoints, and related files of each training time will be
 
 For training on Human3.6M:
 ```bash
-sh /scripts/FMPose3D_train.sh
+sh ./scripts/FMPose3D_train.sh
 ```
 
 ### Inference

animals/demo/vis_animals.py

Lines changed: 7 additions & 7 deletions
@@ -8,7 +8,6 @@
 """
 
 # SuperAnimal Demo: https://github.com/DeepLabCut/DeepLabCut/blob/main/examples/COLAB/COLAB_YOURDATA_SuperAnimal.ipynb
-import sys
 import os
 import numpy as np
 import glob
@@ -25,8 +24,6 @@
 from fmpose3d.animals.common.arguments import opts as parse_args
 from fmpose3d.common.camera import normalize_screen_coordinates, camera_to_world
 
-sys.path.append(os.getcwd())
-
 args = parse_args().parse()
 os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
 
@@ -334,13 +331,15 @@ def get_pose3D(path, output_dir, type='image'):
     print(f"args.n_joints: {args.n_joints}, args.out_joints: {args.out_joints}")
 
     ## Reload model
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     model = {}
-    model['CFM'] = CFM(args).cuda()
+    model['CFM'] = CFM(args).to(device)
 
     model_dict = model['CFM'].state_dict()
     model_path = args.saved_model_path
     print(f"Loading model from: {model_path}")
-    pre_dict = torch.load(model_path)
+    pre_dict = torch.load(model_path, map_location=device, weights_only=True)
     for name, key in model_dict.items():
         model_dict[name] = pre_dict[name]
     model['CFM'].load_state_dict(model_dict)
@@ -400,7 +399,8 @@ def get_3D_pose_from_image(args, keypoints, i, img, model, output_dir):
     input_2D = np.expand_dims(input_2D, axis=0)  # (1, J, 2)
 
     # Convert to tensor format matching visualize_animal_poses.py
-    input_2D = torch.from_numpy(input_2D.astype('float32')).cuda()  # (1, J, 2)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    input_2D = torch.from_numpy(input_2D.astype('float32')).to(device)  # (1, J, 2)
     input_2D = input_2D.unsqueeze(0)  # (1, 1, J, 2)
 
     # Euler sampler for CFM
@@ -418,7 +418,7 @@ def euler_sample(c_2d, y_local, steps, model_3d):
 
     # Single inference without flip augmentation
     # Create 3D random noise with shape (1, 1, J, 3)
-    y = torch.randn(input_2D.size(0), input_2D.size(1), input_2D.size(2), 3).cuda()
+    y = torch.randn(input_2D.size(0), input_2D.size(1), input_2D.size(2), 3, device=device)
     output_3D = euler_sample(input_2D, y, steps=args.sample_steps, model_3d=model)
 
     output_3D = output_3D[0:, args.pad].unsqueeze(1)
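The same two-part pattern recurs across the touched scripts: pick a device with a CPU fallback, then load the checkpoint with `map_location` and `weights_only=True`. A minimal standalone sketch of that pattern; the helper name `load_cfm_checkpoint` is hypothetical, and it skips the key-by-key state-dict copy the scripts perform:

```python
import torch

def load_cfm_checkpoint(model: torch.nn.Module, path: str) -> torch.nn.Module:
    # Fall back to CPU when CUDA is unavailable instead of hard-coding .cuda()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # map_location remaps GPU-saved tensors onto the chosen device;
    # weights_only=True restricts unpickling to tensor/state-dict payloads
    state = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state)
    return model
```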

animals/scripts/main_animal3d.py

Lines changed: 5 additions & 3 deletions
@@ -75,7 +75,7 @@ def step(split, args, actions, dataLoader, model, optimizer=None, epoch=None, st
         # gt_3D shape: torch.Size([B, J, 4]) (x,y,z + homogeneous coordinate)
         gt_3D = gt_3D[:,:,:3]  # only use x,y,z for 3D ground truth
 
-        # [input_2D, gt_3D, batch_cam, vis_3D] = get_varialbe(split, [input_2D, gt_3D, batch_cam, vis_3D])
+        # [input_2D, gt_3D, batch_cam, vis_3D] = get_variable(split, [input_2D, gt_3D, batch_cam, vis_3D])
 
         # unsqueeze frame dimension
         input_2D = input_2D.unsqueeze(1)  # (B,F,J,C)
@@ -264,15 +264,17 @@ def get_parameter_number(net):
 test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,
                                               shuffle=False, num_workers=int(args.workers), pin_memory=True)
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 model = {}
-model['CFM'] = CFM(args).cuda()
+model['CFM'] = CFM(args).to(device)
 
 if args.reload:
     model_dict = model['CFM'].state_dict()
     # Prefer explicit saved_model_path; otherwise fallback to previous_dir glob
     model_path = args.saved_model_path
     print(model_path)
-    pre_dict = torch.load(model_path)
+    pre_dict = torch.load(model_path, weights_only=True, map_location=device)
     for name, key in model_dict.items():
         model_dict[name] = pre_dict[name]
     model['CFM'].load_state_dict(model_dict)

demo/vis_in_the_wild.py

Lines changed: 40 additions & 33 deletions
@@ -7,7 +7,6 @@
 Licensed under Apache 2.0
 """
 
-import sys
 import cv2
 import os
 import numpy as np
@@ -16,8 +15,6 @@
 from tqdm import tqdm
 import copy
 
-sys.path.append(os.getcwd())
-
 # Auto-download checkpoint files if missing
 from fmpose3d.lib.checkpoint.download_checkpoints import ensure_checkpoints
 ensure_checkpoints()
@@ -28,17 +25,10 @@
 
 args = parse_args().parse()
 os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
-if getattr(args, 'model_path', ''):
-    import importlib.util
-    import pathlib
-    model_abspath = os.path.abspath(args.model_path)
-    module_name = pathlib.Path(model_abspath).stem
-    spec = importlib.util.spec_from_file_location(module_name, model_abspath)
-    module = importlib.util.module_from_spec(spec)
-    assert spec.loader is not None
-    spec.loader.exec_module(module)
-    CFM = getattr(module, 'Model')
-
+
+from fmpose3d.models import get_model
+CFM = get_model(args.model_type)
+
 from fmpose3d.common.camera import *
 
 import matplotlib
@@ -50,15 +40,27 @@
 matplotlib.rcParams['pdf.fonttype'] = 42
 matplotlib.rcParams['ps.fonttype'] = 42
 
-def show2Dpose(kps, img):
-    connections = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5],
-                   [5, 6], [0, 7], [7, 8], [8, 9], [9, 10],
-                   [8, 11], [11, 12], [12, 13], [8, 14], [14, 15], [15, 16]]
+# Shared skeleton definition so 2D/3D segment colors match
+SKELETON_CONNECTIONS = [
+    [0, 1], [1, 2], [2, 3], [0, 4], [4, 5],
+    [5, 6], [0, 7], [7, 8], [8, 9], [9, 10],
+    [8, 11], [11, 12], [12, 13], [8, 14], [14, 15], [15, 16]
+]
+# LR mask for skeleton segments: True -> left color, False -> right color
+SKELETON_LR = np.array(
+    [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
+    dtype=bool,
+)
 
-    LR = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], dtype=bool)
+def show2Dpose(kps, img):
+    connections = SKELETON_CONNECTIONS
+    LR = SKELETON_LR
 
     lcolor = (255, 0, 0)
     rcolor = (0, 0, 255)
+    # lcolor = (240, 176, 0)
+    # rcolor = (240, 176, 0)
+
     thickness = 3
 
     for j,c in enumerate(connections):
@@ -67,8 +69,8 @@ def show2Dpose(kps, img):
         start = list(start)
         end = list(end)
         cv2.line(img, (start[0], start[1]), (end[0], end[1]), lcolor if LR[j] else rcolor, thickness)
-        cv2.circle(img, (start[0], start[1]), thickness=-1, color=(0, 255, 0), radius=3)
-        cv2.circle(img, (end[0], end[1]), thickness=-1, color=(0, 255, 0), radius=3)
+        # cv2.circle(img, (start[0], start[1]), thickness=-1, color=(0, 255, 0), radius=3)
+        # cv2.circle(img, (end[0], end[1]), thickness=-1, color=(0, 255, 0), radius=3)
 
     return img
 
@@ -77,11 +79,13 @@ def show3Dpose(vals, ax):
 
     lcolor=(0,0,1)
     rcolor=(1,0,0)
-
-    I = np.array( [0, 0, 1, 4, 2, 5, 0, 7, 8, 8, 14, 15, 11, 12, 8, 9])
-    J = np.array( [1, 4, 2, 5, 3, 6, 7, 8, 14, 11, 15, 16, 12, 13, 9, 10])
-
-    LR = np.array([0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0], dtype=bool)
+    # lcolor=(0/255, 176/255, 240/255)
+    # rcolor=(0/255, 176/255, 240/255)
+
+
+    I = np.array([c[0] for c in SKELETON_CONNECTIONS])
+    J = np.array([c[1] for c in SKELETON_CONNECTIONS])
+    LR = SKELETON_LR
 
     for i in np.arange( len(I) ):
         x, y, z = [np.array( [vals[I[i], j], vals[J[i], j]] ) for j in range(3)]
@@ -199,7 +203,8 @@ def get_3D_pose_from_image(args, keypoints, i, img, model, output_dir):
 
     input_2D = input_2D[np.newaxis, :, :, :, :]
 
-    input_2D = torch.from_numpy(input_2D.astype('float32')).cuda()
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    input_2D = torch.from_numpy(input_2D.astype('float32')).to(device)
 
     N = input_2D.size(0)
 
@@ -215,10 +220,10 @@ def euler_sample(c_2d, y_local, steps, model_3d):
 
     ## estimation
 
-    y = torch.randn(input_2D.size(0), input_2D.size(2), input_2D.size(3), 3).cuda()
+    y = torch.randn(input_2D.size(0), input_2D.size(2), input_2D.size(3), 3, device=device)
     output_3D_non_flip = euler_sample(input_2D[:, 0], y, steps=args.sample_steps, model_3d=model)
 
-    y_flip = torch.randn(input_2D.size(0), input_2D.size(2), input_2D.size(3), 3).cuda()
+    y_flip = torch.randn(input_2D.size(0), input_2D.size(2), input_2D.size(3), 3, device=device)
     output_3D_flip = euler_sample(input_2D[:, 1], y_flip, steps=args.sample_steps, model_3d=model)
 
     output_3D_flip[:, :, :, 0] *= -1
@@ -266,14 +271,16 @@ def get_pose3D(path, output_dir, type='image'):
     # args.type = type
 
     ## Reload
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     model = {}
-    model['CFM'] = CFM(args).cuda()
+    model['CFM'] = CFM(args).to(device)
 
     # if args.reload:
     model_dict = model['CFM'].state_dict()
-    model_path = args.saved_model_path
+    model_path = args.model_weights_path
     print(model_path)
-    pre_dict = torch.load(model_path)
+    pre_dict = torch.load(model_path, map_location=device, weights_only=True)
     for name, key in model_dict.items():
         model_dict[name] = pre_dict[name]
     model['CFM'].load_state_dict(model_dict)
@@ -336,7 +343,7 @@ def get_pose3D(path, output_dir, type='image'):
         ## save
         output_dir_pose = output_dir +'pose/'
         os.makedirs(output_dir_pose, exist_ok=True)
-        plt.savefig(output_dir_pose + str(('%04d'% i)) + '_pose.jpg', dpi=200, bbox_inches = 'tight')
+        plt.savefig(output_dir_pose + str(('%04d'% i)) + '_pose.png', dpi=200, bbox_inches = 'tight')
 
 
 if __name__ == "__main__":
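Both demos sample through `euler_sample`, which carries Gaussian noise `y` to a 3D pose under the conditional flow-matching (CFM) model in a fixed number of Euler steps. A sketch of that sampler's shape, assuming the network acts as a velocity field over a unit time interval (the real signature is `euler_sample(c_2d, y_local, steps, model_3d)`; how the model consumes the time variable is an assumption):

```python
import torch

@torch.no_grad()
def euler_sample_sketch(c_2d, y, steps, velocity_fn):
    """Integrate dy/dt = v(y, t | c_2d) from t=0 (noise) to t=1 (3D pose)."""
    dt = 1.0 / steps
    for i in range(steps):
        t = torch.full((y.size(0),), i * dt, device=y.device)
        y = y + dt * velocity_fn(c_2d, y, t)  # one explicit Euler step
    return y
```

The flipped pass mirrors the 2D input, so `output_3D_flip[:, :, :, 0] *= -1` un-mirrors the prediction before the two estimates are combined.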

demo/vis_in_the_wild.sh

Lines changed: 9 additions & 8 deletions
@@ -1,21 +1,22 @@
 #Test
 layers=5
-gpu_id=1
+gpu_id=0
 sample_steps=3
 batch_size=1
 sh_file='vis_in_the_wild.sh'
 
-model_path='../pre_trained_models/fmpose_detected2d/model_GAMLP.py'
-saved_model_path='../pre_trained_models/fmpose_detected2d/FMpose_36_4972_best.pth'
+model_type='fmpose3d'
+model_weights_path='../pre_trained_models/fmpose3d_h36m/FMpose3D_pretrained_weights.pth'
 
-# path='./images/image_00068.jpg' # single image
-input_images_folder='./images/' # folder containing multiple images
+target_path='./images/' # folder containing multiple images
+# target_path='./images/xx.png' # single image
+# target_path='./videos/xxx.mp4' # video path
 
 python3 vis_in_the_wild.py \
     --type 'image' \
-    --path ${input_images_folder} \
-    --saved_model_path "${saved_model_path}" \
-    --model_path "${model_path}" \
+    --path ${target_path} \
+    --model_weights_path "${model_weights_path}" \
+    --model_type "${model_type}" \
     --sample_steps ${sample_steps} \
     --batch_size ${batch_size} \
     --layers ${layers} \

fmpose3d/__init__.py

Lines changed: 32 additions & 0 deletions
@@ -18,17 +18,49 @@
     aggregation_RPEA_joint_level,
 )
 
+# Configuration dataclasses
+from .common.config import (
+    FMPose3DConfig,
+    HRNetConfig,
+    InferenceConfig,
+    ModelConfig,
+    PipelineConfig,
+)
+
+# High-level inference API
+from .fmpose3d import (
+    FMPose3DInference,
+    HRNetEstimator,
+    Pose2DResult,
+    Pose3DResult,
+    Source,
+)
+
 # Import 2D pose detection utilities
 from .lib.hrnet.gen_kpts import gen_video_kpts
+from .lib.hrnet.hrnet import HRNetPose2d
 from .lib.preprocess import h36m_coco_format, revise_kpts
 
 # Make commonly used classes/functions available at package level
 __all__ = [
+    # Inference API
+    "FMPose3DInference",
+    "HRNetEstimator",
+    "Pose2DResult",
+    "Pose3DResult",
+    "Source",
+    # Configuration
+    "FMPose3DConfig",
+    "HRNetConfig",
+    "InferenceConfig",
+    "ModelConfig",
+    "PipelineConfig",
     # Aggregation methods
     "average_aggregation",
    "aggregation_select_single_best_hypothesis_by_2D_error",
     "aggregation_RPEA_joint_level",
     # 2D pose detection
+    "HRNetPose2d",
     "gen_video_kpts",
     "h36m_coco_format",
     "revise_kpts",

fmpose3d/aggregation_methods.py

Lines changed: 0 additions & 4 deletions
@@ -166,17 +166,13 @@ def aggregation_RPEA_joint_level(
     dist[:, :, 0] = 0.0
 
     # Convert 2D losses to weights using softmax over top-k hypotheses per joint
-    tau = float(getattr(args, "weight_softmax_tau", 1.0))
     H = dist.size(1)
     k = int(getattr(args, "topk", None))
-    # print("k:", k)
-    # k = int(H//2)+1
     k = max(1, min(k, H))
 
     # top-k smallest distances along hypothesis dim
     topk_vals, topk_idx = torch.topk(dist, k=k, dim=1, largest=False)  # (B,k,J)
 
-    # Weight calculation method ; weight_method = 'exp'
     temp = args.exp_temp
     max_safe_val = temp * 20
     topk_vals_clipped = torch.clamp(topk_vals, max=max_safe_val)
topk_vals_clipped = torch.clamp(topk_vals, max=max_safe_val)

fmpose3d/animals/common/arber_dataset.py

Lines changed: 2 additions & 5 deletions
@@ -12,7 +12,6 @@
 import glob
 import os
 import random
-import sys
 
 import cv2
 import matplotlib.pyplot as plt
@@ -23,10 +22,8 @@
 from torch.utils.data import Dataset
 from tqdm import tqdm
 
-sys.path.append(os.path.dirname(sys.path[0]))
-
-from common.camera import normalize_screen_coordinates
-from common.lifter3d import load_camera_params, load_h5_keypoints
+from fmpose3d.common.camera import normalize_screen_coordinates
+from fmpose3d.animals.common.lifter3d import load_camera_params, load_h5_keypoints
 
 
 class ArberDataset(Dataset):
