Skip to content

Commit 003b703

Browse files
TongFuKitwareebrahimebrahim
authored andcommitted
Implement virtual fitting algorithm (OpenwaterHealth#147)
Squashed from: - Fix comment for the parameter steering limits - Redefine the function to extract the skin surface - Add functions for the conversion between LPS and spherical coordinates - Add function to compute the pose of the transducer - Add function to analyse the target position in transducer coordinates - Add structure to get the optimal transform - Remove unused comments
1 parent c2cd92b commit 003b703

1 file changed

Lines changed: 233 additions & 40 deletions

File tree

src/openlifu/vf/virtual_fit.py

Lines changed: 233 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Optional, Tuple
44

55
import numpy as np
6+
import scipy.interpolate
67
import xarray as xa
78

89
from openlifu.db.session import ArrayTransform
@@ -20,7 +21,7 @@ class VirtualFit:
2021
finding the optimal transducer transform (position and orientation)
2122
given an input MRI volume in LPS coordinates and the associated target.
2223
"""
23-
pitch_range: Tuple[int, int] = (10, 40)
24+
pitch_range: Tuple[int, int] = (-10, 40)
2425
"""The pitch range for the grid search."""
2526

2627
pitch_step: int = 3
@@ -35,46 +36,198 @@ class VirtualFit:
3536
search_range_units: str = "deg"
3637
"""Search grid units."""
3738

38-
steering_limits: Tuple[TargetConstraints] = field(default_factory=list)
39-
"""Defines the accepteable range for a target in the transducer space, usually LPS."""
39+
radius_in_mm: float = 50
40+
"""Radius from the transducer in millimeters"""
41+
42+
steering_limits: Tuple[TargetConstraints, TargetConstraints, TargetConstraints] = (TargetConstraints(), TargetConstraints(), TargetConstraints())
43+
"""Defines the steering range limits for the transducer in the local coordinate system, usually in (lat, ele, ax)."""
4044

4145
blocked_elems_threshold: float = 0.1
4246
"""The acceptable fraction of blocked elements."""
4347

4448
volume: xa.Dataset = field(default_factory=xa.Dataset)
4549
"""The MRI volume in LPS coordinates, on which to optimize the position."""
4650

51+
scene_matrix: np.ndarray = field(default_factory=lambda: np.eye(3))
52+
"""The transform representing the MRI volume scene"""
53+
54+
scene_origin: Tuple[float, float, float] = (0, 0, 0)
55+
"""The origin point of the MRI volume scene"""
56+
4757
transducer: Transducer = field(default_factory=Transducer)
4858
"""Transducer that sits on the skin."""
4959

5060
def __post_init__(self):
5161
self.logger = logging.getLogger(__name__)
52-
"""The VirtualFit logger."""
5362
self.logger.info(f"Initializing VirtualFit with the following parameters: {self.__dict__}")
5463
self.logger.info("VirtualFit: Skin extraction...")
5564
# 1. extract skin surface, this is done only once at initialization
56-
# self.skin_surface = self.extract_skin_surface(volume: xa.Dataset)
57-
"""A list of vertices representing the skin surface."""
65+
# self.skin_origin, self.skin_surface, self.skin_interpolator = self.extract_skin_surface(volume: xa.Dataset)
5866

59-
def extract_skin_surface(self, volume: xa.Dataset, quantile: float = 0.05):
60-
#TODO: basic thresholding + convex hull
61-
# from scipy.spatial import ConvexHull
67+
def extract_skin_surface(
        self,
        volume: "xa.Dataset",
        quantile: float = 0.05,
        scene_origin: Optional[Tuple[float, float, float]] = None,
        scene_matrix: Optional[np.ndarray] = None):
    """
    Extract the skin surface from an MRI volume in LPS coordinates.

    Not yet implemented: currently resolves its default arguments and returns None.

    Args:
        volume: The MRI volume in LPS coordinates. The target is expected
            to be in the simulation grid coordinates (lat, ele, ax).
        quantile: The threshold used to define the surface.
        scene_origin: The origin of the scene; defaults to ``self.scene_origin``.
        scene_matrix: The transform of the scene; defaults to ``self.scene_matrix``.

    Returns:
        skin_origin: The origin of the skin.
        skin_surface: The list of points representing the skin surface in LPS coordinates.
        skin_interpolator: An interpolator representing the skin surface in
            spherical coordinates (pitch, yaw, r).
    """
    scene_origin = self.scene_origin if scene_origin is None else scene_origin
    scene_matrix = self.scene_matrix if scene_matrix is None else scene_matrix
    # Intended return type:
    # Tuple[Tuple[float, float, float], np.ndarray, scipy.interpolate.LinearNDInterpolator]

    # TODO: Segmentation (basic thresholding)
    # threshold = np.quantile(volume, 0.05)  # TODO: check Otsu thresholding instead
    # volume_thresholded = volume[volume > threshold]

    # TODO: option 1 -- interpolant + list of points
    # from scipy.interpolate import LinearNDInterpolator
    # return skin_origin, skin_surface, LinearNDInterpolator(...)

    # TODO: option 2 -- ConvexHull (combine interpolant and list of points)
    # from scipy.spatial import ConvexHull
    # return skin_origin, ConvexHull(volume)
    pass
67105

68-
def fit_to_surface(
106+
def pyr2lps(self, pitch: float, yaw: float, r: float, origin: Tuple[float, float, float] = (0, 0, 0)):
    """
    Convert spherical coordinates (pitch, yaw, r) to LPS coordinates.

    Args:
        pitch: Angle in degrees (rotation about the L axis).
        yaw: Angle in degrees (rotation about the S* axis).
        r: Radial distance from the origin.
        origin: LPS offset added to the converted point.

    Returns:
        The (L, P, S) coordinates corresponding to the spherical point.
    """
    # Pitch is measured against the flipped P axis, hence the 180-degree shift;
    # this is the exact inverse of the angle convention used by lps2pyr.
    theta = np.deg2rad(180 - pitch)
    phi = np.deg2rad(yaw)
    cos_phi = np.cos(phi)
    p_val = r * cos_phi * np.cos(theta)
    s_val = r * cos_phi * np.sin(theta)
    l_val = r * np.sin(phi)
    return l_val + origin[0], p_val + origin[1], s_val + origin[2]
116+
117+
def lps2pyr(self, l: float, p:float, s:float, origin: Tuple[float, float, float] = (0, 0, 0)):
118+
"""
119+
Convert LPS coordinates to spherical coordinates
120+
"""
121+
x = p - origin[1]
122+
y = s - origin[2]
123+
z = l - origin[0]
124+
th = np.arctan2(y, x)
125+
phi = np.arctan2(z, np.sqrt(x**2 + y**2))
126+
r = np.sqrt(x**2 + y**2 + z**2)
127+
pitch = (360 - np.degrees(th)) % 360 - 180
128+
yaw = np.degrees(phi)
129+
return pitch, yaw, r
130+
131+
def get_transducer_pose(
        self,
        sph_coords: Tuple[float, float],
        skin_origin: Optional[Tuple[float, float, float]] = None,
        skin_interpolator: Optional[scipy.interpolate.LinearNDInterpolator] = None,
        z_offset: float = 13.55,
        dzdy: float = 0.15,
        search_x: float = 20,
        search_dx: float = 1,
        search_y: float = 20,
        search_dy: float = 1) -> np.ndarray:
    """
    Compute the pose of the transducer placed on the segmented skin surface
    at the point given by spherical coordinates (pitch, yaw).

    Args:
        sph_coords: Spherical coordinates (pitch, yaw) in degrees.
            pitch: Angle above the S=0 "eye" line (rotation about the "L" axis).
            yaw: Angle along the pitched circle towards the subject's left ear
                (rotation about the "S*" axis).
        skin_origin: The skin surface origin.
        skin_interpolator: Function mapping spherical coordinates to radial distance.
        z_offset: Distance of the transducer from the skin surface (mm).
        dzdy: Slope of the transducer away from the skin surface. Default 0.15
            (bottom of the transducer is raised 15% relative to the top).
        search_x: Lateral (yaw) ROI extent for surface fitting (one-sided, mm).
        search_dx: Lateral (yaw) ROI step size for surface fitting (mm).
        search_y: Elevation (pitch) ROI extent for surface fitting (one-sided, mm).
        search_dy: Elevation (pitch) ROI step size for surface fitting (mm).

    Returns:
        A 4x4 homogeneous matrix representing the transducer's pose in terms
        of position and orientation (lat, ele, ax).
    """
    pitch, yaw = sph_coords
    # NOTE(review): once extract_skin_surface is implemented, None arguments
    # should fall back to self.skin_origin / self.skin_interpolator.

    # Anchor point on the skin surface, in LPS coordinates.
    radial_dist = skin_interpolator(pitch, yaw)
    anchor_l, anchor_p, anchor_s = self.pyr2lps(pitch, yaw, radial_dist, skin_origin)
    anchor = np.array([anchor_l, anchor_p, anchor_s])

    # Local orthonormal frame for the ROI. Axis 2 points from the anchor back
    # toward the LPS origin (NOTE(review): not toward skin_origin -- confirm
    # this is intended); axis 0 follows decreasing yaw along the surface,
    # re-orthogonalized against axis 2; axis 1 completes the right-handed frame.
    frame = [None] * 3
    frame[2] = -anchor / np.linalg.norm(anchor, 2)
    nb_l, nb_p, nb_s = self.pyr2lps(pitch, yaw - 1, radial_dist, skin_origin)
    frame[0] = np.array([nb_l, nb_p, nb_s]) - anchor
    frame[0] -= frame[2] * np.dot(frame[0], frame[2])
    frame[0] /= np.linalg.norm(frame[0], 2)
    frame[1] = np.cross(frame[2], frame[0])

    # Homogeneous ROI matrix and its pseudo-inverse (LPS -> local ROI coords).
    roi_matrix = np.eye(4)
    roi_matrix[:3, :3] = np.column_stack(frame)
    roi_matrix[:3, 3] = anchor
    roi_forward_matrix = np.linalg.pinv(roi_matrix)

    # Sample a planar grid of points around the anchor in the ROI frame,
    # then project each sample onto the skin via the interpolator.
    offsets_x = np.arange(-search_x, search_x + search_dx, search_dx)
    offsets_y = np.arange(-search_y, search_y + search_dy, search_dy)
    grid_x, grid_y = np.meshgrid(offsets_x, offsets_y, indexing='ij')
    # np.outer flattens the grids, giving one (n_points, 3) array of LPS points.
    roi_points = anchor + np.outer(grid_x, frame[0]) + np.outer(grid_y, frame[1])
    roi_pyr = [self.lps2pyr(pt[0], pt[1], pt[2], skin_origin) for pt in roi_points]
    surf_pyr = [[*pyr[:2], skin_interpolator(pyr[0], pyr[1]).item()] for pyr in roi_pyr]
    surf_lps = [self.pyr2lps(pyr[0], pyr[1], pyr[2], skin_origin) for pyr in surf_pyr]

    # Express the sampled surface points in the local ROI frame.
    surf_homogeneous = np.hstack([surf_lps, np.ones((len(surf_lps), 1))]).T
    surf_xyz = roi_forward_matrix @ surf_homogeneous

    # Least-squares plane fit z = a*x + b*y through the local surface points,
    # then build an orthonormal plane frame and rotate it back into LPS.
    slope = np.linalg.lstsq(surf_xyz[:2, :].T, surf_xyz[2, :], rcond=None)[0]
    plane_axes = np.column_stack([[1, 0, slope[0]], [0, 1, slope[1]], [0, 0, 1]])
    plane_axes /= np.linalg.norm(plane_axes, axis=0)
    plane_axes[:, 2] = np.cross(plane_axes[:, 0], plane_axes[:, 1])
    plane_matrix = np.eye(4)
    plane_matrix[:3, :3] = roi_matrix[:3, :3] @ plane_axes
    plane_matrix[:3, 3] = anchor

    # Offset the transducer origin away from the skin along the plane normal,
    # and tilt the elevation axis by dzdy.
    transducer_origin = anchor - plane_matrix[:3, 2] * z_offset + plane_matrix[:3, 1] * z_offset * dzdy
    axes = [None] * 3
    axes[0] = plane_matrix[:3, 0]
    axes[1] = plane_matrix[:3, 1] + dzdy * plane_matrix[:3, 2]
    axes[1] /= np.linalg.norm(axes[1], 2)
    axes[2] = np.cross(axes[0], axes[1])

    # Assemble the homogeneous pose matrix.
    transducer_pose = np.eye(4)
    transducer_pose[:3, :3] = np.column_stack(axes)
    transducer_pose[:3, 3] = transducer_origin
    return transducer_pose
78231

79232
def get_search_grid(
80233
self,
@@ -92,28 +245,55 @@ def get_search_grid(
92245

93246
return pitch_yaw_grid
94247

95-
def analyse_position(self, pos: np.ndarray, transducer: Transducer, target: Point):
248+
def analyse_target_position(
        self,
        target: "Point",
        transducer_pose: np.ndarray,
        radius_in_mm: Optional[float] = None,
        steering_limits: "Optional[Tuple[TargetConstraints, TargetConstraints, TargetConstraints]]" = None):
    """
    Analyze the pose of a transducer relative to a specific target point.

    Determines whether the target lies within the transducer's steering
    limits and computes the steering distance.

    Args:
        target: The target point.
        transducer_pose: A 4x4 transformation matrix representing the transducer's pose.
        radius_in_mm: Radius of the transducer in millimeters;
            defaults to ``self.radius_in_mm``.
        steering_limits: Steering range limits for the transducer in the local
            coordinate system (lat, ele, ax); defaults to ``self.steering_limits``.

    Returns:
        in_bounds: Whether the target is within the steering limits.
        steering_dist: Euclidean distance from the transducer's center to the
            target in the local coordinate system.
    """
    # TODO: In the future, we should implement the ray-tracing analysis
    # given a full segmentation.
    radius_in_mm = self.radius_in_mm if radius_in_mm is None else radius_in_mm
    steering_limits = self.steering_limits if steering_limits is None else steering_limits

    # Map the target from LPS into the transducer's local frame.
    target_lps = np.append(target.position, 1)
    local_target = np.linalg.pinv(transducer_pose) @ target_lps
    local_pos = local_target[:3]
    # Measure the axial component relative to the transducer's focal radius.
    local_pos[2] -= radius_in_mm

    steering_dist = np.linalg.norm(local_pos)

    # Each constraint signals an out-of-range coordinate by raising ValueError.
    in_bounds = True
    for axis_index, constraint in enumerate(steering_limits):
        try:
            constraint.check_bounds(local_pos[axis_index])
        except ValueError:
            in_bounds = False

    return in_bounds, steering_dist
117297

118298
def run(
119299
self,
@@ -122,6 +302,7 @@ def run(
122302
pitch_step: Optional[int] = None,
123303
yaw_range: Optional[Tuple[int, int]] = None,
124304
yaw_step: Optional[int] = None,
305+
radius_in_mm: Optional[float] = None,
125306
steering_limits: Optional[Tuple[TargetConstraints]] = None,
126307
blocked_elems_threshold: Optional[float] = None
127308
) -> ArrayTransform:
@@ -140,6 +321,8 @@ def run(
140321
yaw_range = self.yaw_range
141322
if yaw_step is None:
142323
yaw_step = self.yaw_step
324+
if radius_in_mm is None:
325+
radius_in_mm = self.radius_in_mm
143326
if steering_limits is None:
144327
steering_limits = self.steering_limits
145328
if blocked_elems_threshold is None:
@@ -149,15 +332,25 @@ def run(
149332
self.logger.info("VirtualFit: Searching optimal position...")
150333
# 2. get search grid
151334
search_grid = self.get_search_grid(yaw_range, yaw_step, pitch_range, pitch_step)
335+
transducer_poses = np.empty(search_grid[0].shape, dtype=object)
336+
in_bounds = np.zeros_like(search_grid[0])
337+
steering_dists = np.zeros_like(search_grid[0])
152338
for i in range(search_grid[0].shape[0]):
153339
for j in range(search_grid[0].shape[1]):
154-
yaw, pitch = (search_grid[0][i, j], search_grid[1][i, j])
155-
self.logger.info(f"VirtualFit: Analysing {(yaw, pitch)}...")
156-
# 3. define transducer transform (plane fitting) on the surface (skin) given spherical coordinate (yaw, pitch)
157-
# self.fit_to_surface(sph_coords: Tuple[float, float], skin_surface: np.ndarray)
340+
pitch, yaw = (search_grid[0][i, j], search_grid[1][i, j])
341+
self.logger.info(f"VirtualFit: Analysing {(pitch, yaw)}...")
342+
# 3. define transducer transform (plane fitting) on the surface (skin) given spherical coordinate (pitch, yaw)
343+
transducer_poses[i, j] = self.get_transducer_pose([pitch, yaw])
158344
# 4. analyse current transform
159-
# self.analyse_position(pos: np.ndarray, transducer: Transducer, target: Point)
160-
optimal_transform = np.zeros((4, 4))
345+
in_bounds[i, j], steering_dists[i, j] = self.analyse_target_position(transducer_poses[i, j], target)
346+
# 5. get optimal transform
347+
optimal_transform = None
348+
for i in range(in_bounds.shape[0]):
349+
for j in range(in_bounds.shape[1]):
350+
if in_bounds[i, j]:
351+
#TODO: Check blocked element
352+
# self.check_blocked_elements()
353+
optimal_transform = transducer_poses[i, j]
161354
self.logger.info("VirtualFit: Found optimal position!")
162355

163356
return optimal_transform

0 commit comments

Comments
 (0)