Skip to content

Commit c4bf891

Browse files
committed
Add HRNet model API
- This is an adapter of the `gen_video_kpts` function: it reads in-memory arrays instead of image paths, and can be configured via `HRNetConfig`.
1 parent 3863ddf commit c4bf891

3 files changed

Lines changed: 297 additions & 3 deletions

File tree

fmpose3d/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
# Import 2D pose detection utilities
2222
from .lib.hrnet.gen_kpts import gen_video_kpts
23+
from .lib.hrnet.hrnet import HRNetPose2d
2324
from .lib.preprocess import h36m_coco_format, revise_kpts
2425

2526
# Make commonly used classes/functions available at package level
@@ -29,6 +30,7 @@
2930
"aggregation_select_single_best_hypothesis_by_2D_error",
3031
"aggregation_RPEA_joint_level",
3132
# 2D pose detection
33+
"HRNetPose2d",
3234
"gen_video_kpts",
3335
"h36m_coco_format",
3436
"revise_kpts",

fmpose3d/common/config.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,29 @@ class Pose2DConfig:
153153

154154
@dataclass
class HRNetConfig(Pose2DConfig):
    """Configuration for the HRNet 2D pose detector.

    Attributes
    ----------
    det_dim : int
        Input resolution of the YOLO human detector (default 416).
    num_persons : int
        Maximum number of persons estimated per frame (default 1).
    thred_score : float
        Object-confidence threshold for YOLO detections (default 0.30).
    hrnet_cfg_file : str
        Path to the HRNet YAML experiment config. When left empty the
        bundled ``w48_384x288_adam_lr1e-3.yaml`` is used.
    hrnet_weights_path : str
        Path to the HRNet ``.pth`` checkpoint. When left empty the
        auto-downloaded ``pose_hrnet_w48_384x288.pth`` is used.
    """
    pose2d_model: str = "hrnet"
    det_dim: int = 416
    num_persons: int = 1
    thred_score: float = 0.30
    hrnet_cfg_file: str = ""
    hrnet_weights_path: str = ""
160179

161180

162181
@dataclass

fmpose3d/lib/hrnet/hrnet.py

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
"""
2+
FMPose3D – clean HRNet 2D pose estimation API.
3+
4+
Provides :class:`HRNetPose2d`, a self-contained wrapper around the
5+
HRNet + YOLO detection pipeline that accepts numpy arrays directly
6+
(no file I/O, no argparse, no global yacs config leaking out).
7+
8+
Usage::
9+
10+
api = HRNetPose2d(det_dim=416, num_persons=1)
11+
api.setup() # loads YOLO + HRNet weights
12+
keypoints, scores = api.predict(frames) # (M, N, 17, 2), (M, N, 17)
13+
"""
14+
15+
from __future__ import annotations
16+
17+
import copy
18+
import os.path as osp
19+
from collections import OrderedDict
20+
from typing import Tuple
21+
22+
import numpy as np
23+
import torch
24+
import torch.backends.cudnn as cudnn
25+
26+
from fmpose3d.lib.checkpoint.download_checkpoints import (
27+
ensure_checkpoints,
28+
get_checkpoint_path,
29+
)
30+
31+
32+
class HRNetPose2d:
    """Self-contained 2D pose estimator (YOLO detector + HRNet).

    A self-contained HRNet 2D pose estimator that accepts numpy arrays
    directly. It serves as an alternative to the ``gen_video_kpts``
    function in ``fmpose3d/lib/hrnet/gen_kpts.py``, which generates 2D
    keypoints from a video file.

    Parameters
    ----------
    det_dim : int
        YOLO input resolution (default 416).
    num_persons : int
        Maximum number of persons to track per frame (default 1).
    thred_score : float
        YOLO object-confidence threshold (default 0.30).
    hrnet_cfg_file : str
        Path to the HRNet YAML experiment config. Empty string (default)
        uses the bundled ``w48_384x288_adam_lr1e-3.yaml``.
    hrnet_weights_path : str
        Path to the HRNet ``.pth`` checkpoint. Empty string (default)
        uses the auto-downloaded ``pose_hrnet_w48_384x288.pth``.
    """

    def __init__(
        self,
        det_dim: int = 416,
        num_persons: int = 1,
        thred_score: float = 0.30,
        hrnet_cfg_file: str = "",
        hrnet_weights_path: str = "",
    ) -> None:
        self.det_dim = det_dim
        self.num_persons = num_persons
        self.thred_score = thred_score
        self.hrnet_cfg_file = hrnet_cfg_file
        self.hrnet_weights_path = hrnet_weights_path

        # Populated by setup()
        self._human_model = None
        self._pose_model = None
        self._people_sort = None
        self._hrnet_cfg = None  # frozen yacs CfgNode used by PreProcess / get_final_preds

    # ------------------------------------------------------------------
    # Setup
    # ------------------------------------------------------------------

    @property
    def is_ready(self) -> bool:
        """``True`` once :meth:`setup` has been called."""
        return self._human_model is not None

    def setup(self) -> "HRNetPose2d":
        """Load YOLO detector and HRNet pose model.

        Can safely be called more than once (subsequent calls are no-ops).

        Returns ``self`` so you can write ``api = HRNetPose2d().setup()``.
        """
        if self.is_ready:
            return self

        ensure_checkpoints()

        # --- resolve paths ---------------------------------------------------
        hrnet_cfg_file = self.hrnet_cfg_file
        if not hrnet_cfg_file:
            hrnet_cfg_file = osp.join(
                osp.dirname(osp.abspath(__file__)),
                "experiments",
                "w48_384x288_adam_lr1e-3.yaml",
            )

        hrnet_weights = self.hrnet_weights_path
        if not hrnet_weights:
            hrnet_weights = get_checkpoint_path("pose_hrnet_w48_384x288.pth")

        # --- build internal yacs config (kept private) -----------------------
        from fmpose3d.lib.hrnet.lib.config import cfg as _global_cfg
        from fmpose3d.lib.hrnet.lib.config import update_config as _update_cfg
        from types import SimpleNamespace

        _global_cfg.defrost()
        # update_config mimics the original argparse interface: cfg is the
        # YAML path and modelDir carries the checkpoint path.
        _update_cfg(
            _global_cfg,
            SimpleNamespace(cfg=hrnet_cfg_file, opts=[], modelDir=hrnet_weights),
        )
        # Snapshot the frozen cfg so we can pass it to PreProcess / get_final_preds.
        self._hrnet_cfg = _global_cfg

        # cudnn tuning
        cudnn.benchmark = self._hrnet_cfg.CUDNN.BENCHMARK
        cudnn.deterministic = self._hrnet_cfg.CUDNN.DETERMINISTIC
        cudnn.enabled = self._hrnet_cfg.CUDNN.ENABLED

        # --- load models -----------------------------------------------------
        from fmpose3d.lib.yolov3.human_detector import load_model as _yolo_load
        from fmpose3d.lib.sort.sort import Sort

        self._human_model = _yolo_load(inp_dim=self.det_dim)
        self._pose_model = self._load_hrnet(self._hrnet_cfg)
        self._people_sort = Sort(min_hits=0)

        return self

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------

    def predict(
        self, frames: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Estimate 2D keypoints for a batch of BGR frames.

        Parameters
        ----------
        frames : ndarray, shape ``(N, H, W, C)``
            BGR images. A single frame ``(H, W, C)`` is also accepted
            and will be treated as a batch of one.

        Returns
        -------
        keypoints : ndarray, shape ``(num_persons, N, 17, 2)``
            COCO-format 2D keypoints in pixel coordinates.
        scores : ndarray, shape ``(num_persons, N, 17)``
            Per-joint confidence scores.
        """
        if not self.is_ready:
            self.setup()

        if frames.ndim == 3:
            frames = frames[np.newaxis]

        # Empty batch: transposing a 1-D empty result would raise, so
        # return empty arrays with the documented layout instead.
        if frames.shape[0] == 0:
            return (
                np.zeros((self.num_persons, 0, 17, 2), dtype=np.float32),
                np.zeros((self.num_persons, 0, 17), dtype=np.float32),
            )

        kpts_result = []
        scores_result = []

        for i in range(frames.shape[0]):
            kpts, sc = self._estimate_frame(frames[i])
            kpts_result.append(kpts)
            scores_result.append(sc)

        keypoints = np.array(kpts_result)  # (N, M, 17, 2)
        scores = np.array(scores_result)   # (N, M, 17)

        # (N, M, 17, 2) → (M, N, 17, 2)
        keypoints = keypoints.transpose(1, 0, 2, 3)
        # (N, M, 17) → (M, N, 17)
        scores = scores.transpose(1, 0, 2)

        return keypoints, scores

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _load_hrnet(config):
        """Instantiate HRNet and load checkpoint weights.

        NOTE(review): ``update_config`` stores the ``modelDir`` argument in
        ``OUTPUT_DIR``, so ``config.OUTPUT_DIR`` is the checkpoint path here
        — confirm against the upstream config helper.
        """
        from fmpose3d.lib.hrnet.lib.models import pose_hrnet

        model = pose_hrnet.get_pose_net(config, is_train=False)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)

        # map_location keeps CUDA-saved checkpoints loadable on CPU-only
        # hosts; weights_only avoids executing arbitrary pickled code.
        state_dict = torch.load(
            config.OUTPUT_DIR, map_location=device, weights_only=True
        )
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def _estimate_frame(
        self, frame: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Run detection + pose estimation on a single BGR frame.

        Returns
        -------
        kpts : ndarray, shape ``(num_persons, 17, 2)``
        scores : ndarray, shape ``(num_persons, 17)``
        """
        from fmpose3d.lib.yolov3.human_detector import yolo_human_det
        from fmpose3d.lib.hrnet.lib.utils.utilitys import PreProcess
        from fmpose3d.lib.hrnet.lib.utils.inference import get_final_preds

        num_persons = self.num_persons

        bboxs, det_scores = yolo_human_det(
            frame, self._human_model, reso=self.det_dim, confidence=self.thred_score,
        )

        if bboxs is None or not bboxs.any():
            # No detection – return zeros
            kpts = np.zeros((num_persons, 17, 2), dtype=np.float32)
            scores = np.zeros((num_persons, 17), dtype=np.float32)
            return kpts, scores

        # Track
        people_track = self._people_sort.update(bboxs)

        if people_track.shape[0] == 1:
            people_track_ = people_track[-1, :-1].reshape(1, 4)
        elif people_track.shape[0] >= 2:
            # NOTE(review): assumes at least num_persons tracks are present;
            # fewer tracks would make this reshape fail — confirm upstream.
            people_track_ = people_track[-num_persons:, :-1].reshape(num_persons, 4)
            people_track_ = people_track_[::-1]
        else:
            kpts = np.zeros((num_persons, 17, 2), dtype=np.float32)
            scores = np.zeros((num_persons, 17), dtype=np.float32)
            return kpts, scores

        track_bboxs = []
        for bbox in people_track_:
            bbox = [round(i, 2) for i in list(bbox)]
            track_bboxs.append(bbox)

        with torch.no_grad():
            inputs, origin_img, center, scale = PreProcess(
                frame, track_bboxs, self._hrnet_cfg, num_persons,
            )
            # BGR → RGB channel swap expected by HRNet.
            inputs = inputs[:, [2, 1, 0]]

            if torch.cuda.is_available():
                inputs = inputs.cuda()
            output = self._pose_model(inputs)

            preds, maxvals = get_final_preds(
                self._hrnet_cfg,
                output.clone().cpu().numpy(),
                np.asarray(center),
                np.asarray(scale),
            )

        kpts = np.zeros((num_persons, 17, 2), dtype=np.float32)
        scores = np.zeros((num_persons, 17), dtype=np.float32)
        for i, kpt in enumerate(preds):
            kpts[i] = kpt
        for i, score in enumerate(maxvals):
            scores[i] = score.squeeze()

        return kpts, scores

0 commit comments

Comments
 (0)