|
| 1 | +""" |
| 2 | +FMPose3D – clean HRNet 2D pose estimation API. |
| 3 | +
|
| 4 | +Provides :class:`HRNetPose2d`, a self-contained wrapper around the |
| 5 | +HRNet + YOLO detection pipeline that accepts numpy arrays directly |
| 6 | +(no file I/O, no argparse, no global yacs config leaking out). |
| 7 | +
|
| 8 | +Usage:: |
| 9 | +
|
| 10 | + api = HRNetPose2d(det_dim=416, num_persons=1) |
| 11 | + api.setup() # loads YOLO + HRNet weights |
| 12 | + keypoints, scores = api.predict(frames) # (M, N, 17, 2), (M, N, 17) |
| 13 | +""" |
| 14 | + |
| 15 | +from __future__ import annotations |
| 16 | + |
| 17 | +import copy |
| 18 | +import os.path as osp |
| 19 | +from collections import OrderedDict |
| 20 | +from typing import Tuple |
| 21 | + |
| 22 | +import numpy as np |
| 23 | +import torch |
| 24 | +import torch.backends.cudnn as cudnn |
| 25 | + |
| 26 | +from fmpose3d.lib.checkpoint.download_checkpoints import ( |
| 27 | + ensure_checkpoints, |
| 28 | + get_checkpoint_path, |
| 29 | +) |
| 30 | + |
| 31 | + |
class HRNetPose2d:
    """Self-contained 2D pose estimator (YOLO detector + HRNet).

    A self-contained HRNet 2D pose estimator that accepts numpy arrays directly.
    It serves as alternative to the gen_video_kpts function in fmpose3d/lib/hrnet/gen_kpts.py,
    which generates 2D keypoints from a video file.

    Parameters
    ----------
    det_dim : int
        YOLO input resolution (default 416).
    num_persons : int
        Maximum number of persons to track per frame (default 1).
    thred_score : float
        YOLO object-confidence threshold (default 0.30).
    hrnet_cfg_file : str
        Path to the HRNet YAML experiment config. Empty string (default)
        uses the bundled ``w48_384x288_adam_lr1e-3.yaml``.
    hrnet_weights_path : str
        Path to the HRNet ``.pth`` checkpoint. Empty string (default)
        uses the auto-downloaded ``pose_hrnet_w48_384x288.pth``.
    """

    def __init__(
        self,
        det_dim: int = 416,
        num_persons: int = 1,
        thred_score: float = 0.30,
        hrnet_cfg_file: str = "",
        hrnet_weights_path: str = "",
    ) -> None:
        self.det_dim = det_dim
        self.num_persons = num_persons
        self.thred_score = thred_score
        self.hrnet_cfg_file = hrnet_cfg_file
        self.hrnet_weights_path = hrnet_weights_path

        # Populated by setup()
        self._human_model = None   # YOLO v3 person detector
        self._pose_model = None    # HRNet keypoint network
        self._people_sort = None   # SORT tracker instance
        self._hrnet_cfg = None     # frozen yacs CfgNode used by PreProcess / get_final_preds

    # ------------------------------------------------------------------
    # Setup
    # ------------------------------------------------------------------

    @property
    def is_ready(self) -> bool:
        """``True`` once :meth:`setup` has been called."""
        return self._human_model is not None

    def setup(self) -> "HRNetPose2d":
        """Load YOLO detector and HRNet pose model.

        Can safely be called more than once (subsequent calls are no-ops).

        Returns ``self`` so you can write ``api = HRNetPose2d().setup()``.
        """
        if self.is_ready:
            return self

        ensure_checkpoints()

        # --- resolve paths ---------------------------------------------------
        hrnet_cfg_file = self.hrnet_cfg_file
        if not hrnet_cfg_file:
            hrnet_cfg_file = osp.join(
                osp.dirname(osp.abspath(__file__)),
                "experiments",
                "w48_384x288_adam_lr1e-3.yaml",
            )

        hrnet_weights = self.hrnet_weights_path
        if not hrnet_weights:
            hrnet_weights = get_checkpoint_path("pose_hrnet_w48_384x288.pth")

        # --- build internal yacs config (kept private) -----------------------
        from fmpose3d.lib.hrnet.lib.config import cfg as _global_cfg
        from fmpose3d.lib.hrnet.lib.config import update_config as _update_cfg
        from types import SimpleNamespace

        _global_cfg.defrost()
        # update_config() stores modelDir into cfg.OUTPUT_DIR, which is why
        # _load_hrnet() reads the checkpoint path from OUTPUT_DIR below.
        _update_cfg(
            _global_cfg,
            SimpleNamespace(cfg=hrnet_cfg_file, opts=[], modelDir=hrnet_weights),
        )
        # Snapshot the frozen cfg so we can pass it to PreProcess / get_final_preds.
        self._hrnet_cfg = _global_cfg

        # cudnn tuning
        cudnn.benchmark = self._hrnet_cfg.CUDNN.BENCHMARK
        cudnn.deterministic = self._hrnet_cfg.CUDNN.DETERMINISTIC
        cudnn.enabled = self._hrnet_cfg.CUDNN.ENABLED

        # --- load models -----------------------------------------------------
        from fmpose3d.lib.yolov3.human_detector import load_model as _yolo_load
        from fmpose3d.lib.sort.sort import Sort

        self._human_model = _yolo_load(inp_dim=self.det_dim)
        self._pose_model = self._load_hrnet(self._hrnet_cfg)
        self._people_sort = Sort(min_hits=0)

        return self

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------

    def predict(
        self, frames: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Estimate 2D keypoints for a batch of BGR frames.

        Parameters
        ----------
        frames : ndarray, shape ``(N, H, W, C)``
            BGR images. A single frame ``(H, W, C)`` is also accepted
            and will be treated as a batch of one.

        Returns
        -------
        keypoints : ndarray, shape ``(num_persons, N, 17, 2)``
            COCO-format 2D keypoints in pixel coordinates.
        scores : ndarray, shape ``(num_persons, N, 17)``
            Per-joint confidence scores.
        """
        if not self.is_ready:
            self.setup()

        if frames.ndim == 3:
            frames = frames[np.newaxis]

        # An empty batch would make np.array([]).transpose(1, 0, 2, 3) raise;
        # return correctly shaped empty outputs instead.
        if frames.shape[0] == 0:
            return (
                np.zeros((self.num_persons, 0, 17, 2), dtype=np.float32),
                np.zeros((self.num_persons, 0, 17), dtype=np.float32),
            )

        kpts_result = []
        scores_result = []

        for i in range(frames.shape[0]):
            kpts, sc = self._estimate_frame(frames[i])
            kpts_result.append(kpts)
            scores_result.append(sc)

        keypoints = np.array(kpts_result)  # (N, M, 17, 2)
        scores = np.array(scores_result)   # (N, M, 17)

        # (N, M, 17, 2) → (M, N, 17, 2)
        keypoints = keypoints.transpose(1, 0, 2, 3)
        # (N, M, 17) → (M, N, 17)
        scores = scores.transpose(1, 0, 2)

        return keypoints, scores

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _load_hrnet(config):
        """Instantiate HRNet and load checkpoint weights.

        ``config.OUTPUT_DIR`` holds the checkpoint path (set by
        ``update_config`` from the ``modelDir`` argument in :meth:`setup`).
        """
        from fmpose3d.lib.hrnet.lib.models import pose_hrnet

        model = pose_hrnet.get_pose_net(config, is_train=False)
        if torch.cuda.is_available():
            model = model.cuda()

        # map_location="cpu" keeps loading working on CPU-only machines even
        # when the checkpoint was saved from a GPU; load_state_dict then moves
        # the tensors onto the (possibly CUDA) model parameters.
        state_dict = torch.load(
            config.OUTPUT_DIR, map_location="cpu", weights_only=True
        )
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def _estimate_frame(
        self, frame: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Run detection + pose estimation on a single BGR frame.

        Returns
        -------
        kpts : ndarray, shape ``(num_persons, 17, 2)``
            Zero-filled for undetected person slots.
        scores : ndarray, shape ``(num_persons, 17)``
            Zero-filled for undetected person slots.
        """
        from fmpose3d.lib.yolov3.human_detector import yolo_human_det
        from fmpose3d.lib.hrnet.lib.utils.utilitys import PreProcess
        from fmpose3d.lib.hrnet.lib.utils.inference import get_final_preds

        num_persons = self.num_persons

        bboxs, det_scores = yolo_human_det(
            frame, self._human_model, reso=self.det_dim, confidence=self.thred_score,
        )

        if bboxs is None or not bboxs.any():
            # No detection – return zeros
            kpts = np.zeros((num_persons, 17, 2), dtype=np.float32)
            scores = np.zeros((num_persons, 17), dtype=np.float32)
            return kpts, scores

        # Track
        people_track = self._people_sort.update(bboxs)

        n_tracks = people_track.shape[0]
        if n_tracks == 0:
            # Tracker produced no confirmed tracks for this frame.
            kpts = np.zeros((num_persons, 17, 2), dtype=np.float32)
            scores = np.zeros((num_persons, 17), dtype=np.float32)
            return kpts, scores

        # Keep at most num_persons of the most recent tracks. Clamping to the
        # number of available tracks avoids the ValueError the previous
        # reshape(num_persons, 4) raised when fewer than num_persons tracks
        # existed. Reversing restores newest-first order (no-op for one row,
        # identical to the old behavior for two or more).
        n_keep = min(num_persons, n_tracks)
        people_track_ = people_track[-n_keep:, :-1].reshape(n_keep, 4)[::-1]

        # Round coordinates for PreProcess (matches upstream gen_kpts).
        track_bboxs = [[round(coord, 2) for coord in bbox] for bbox in people_track_]

        with torch.no_grad():
            inputs, origin_img, center, scale = PreProcess(
                frame, track_bboxs, self._hrnet_cfg, num_persons,
            )
            # BGR → RGB channel swap expected by HRNet.
            inputs = inputs[:, [2, 1, 0]]

            if torch.cuda.is_available():
                inputs = inputs.cuda()
            output = self._pose_model(inputs)

            # Map heatmap-space argmax back to original-image pixel coords.
            preds, maxvals = get_final_preds(
                self._hrnet_cfg,
                output.clone().cpu().numpy(),
                np.asarray(center),
                np.asarray(scale),
            )

        kpts = np.zeros((num_persons, 17, 2), dtype=np.float32)
        scores = np.zeros((num_persons, 17), dtype=np.float32)
        for i, kpt in enumerate(preds):
            kpts[i] = kpt
        for i, score in enumerate(maxvals):
            scores[i] = score.squeeze()

        return kpts, scores
| 273 | + |
0 commit comments