Skip to content

Commit c2cd92b

Browse files
ltetrelebrahimebrahim
authored andcommitted
Set up virtual fitting algorithm framework (OpenwaterHealth#147)
Squashed from: - Initial implementation for the virtual fitting class. - precising on which coordinates the volume is - Improving docstring to give more information on target coordinate space - better typing for transform output from run method - update volume and target coordinates information in docstrings
1 parent f941ba7 commit c2cd92b

4 files changed

Lines changed: 192 additions & 0 deletions

File tree

src/openlifu/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
seg_methods,
3535
)
3636
from openlifu.sim import SimSetup
37+
from openlifu.vf import VirtualFit
3738
from openlifu.xdc import Transducer
3839

3940
from ._version import version as __version__
@@ -43,6 +44,7 @@
4344
"Transducer",
4445
"Protocol",
4546
"Solution",
47+
"VirtualFit",
4648
"Material",
4749
"SegmentationMethod",
4850
"seg_methods",

src/openlifu/vf/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .virtual_fit import VirtualFit
2+
3+
__all__ = [
4+
"VirtualFit"
5+
]

src/openlifu/vf/virtual_fit.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
import logging
2+
from dataclasses import dataclass, field
3+
from typing import Optional, Tuple
4+
5+
import numpy as np
6+
import xarray as xa
7+
8+
from openlifu.db.session import ArrayTransform
9+
from openlifu.geo import Point
10+
from openlifu.plan import TargetConstraints
11+
from openlifu.xdc import Transducer
12+
13+
14+
@dataclass
15+
class VirtualFit:
16+
"""
17+
VirtualFit class.
18+
19+
Represents the virtual fitting algorithm which consists in
20+
finding the optimal transducer transform (position and orientation)
21+
given an input MRI volume in LPS coordinates and the associated target.
22+
"""
23+
pitch_range: Tuple[int, int] = (10, 40)
24+
"""The pitch range for the grid search."""
25+
26+
pitch_step: int = 3
27+
"""The pitch step for the grid search."""
28+
29+
yaw_range: Tuple[int, int] = (-5, 25)
30+
"""The yaw range for the grid search."""
31+
32+
yaw_step: int = 3
33+
"""The yaw step for the grid search."""
34+
35+
search_range_units: str = "deg"
36+
"""Search grid units."""
37+
38+
steering_limits: Tuple[TargetConstraints] = field(default_factory=list)
39+
"""Defines the accepteable range for a target in the transducer space, usually LPS."""
40+
41+
blocked_elems_threshold: float = 0.1
42+
"""How much blocked elements are acceptable."""
43+
44+
volume: xa.Dataset = field(default_factory=xa.Dataset)
45+
"""The MRI volume in LPS coordinates, on which to optimize the position."""
46+
47+
transducer: Transducer = field(default_factory=Transducer)
48+
"""Transducer that sits on the skin."""
49+
50+
def __post_init__(self):
51+
self.logger = logging.getLogger(__name__)
52+
"""The VirtualFit logger."""
53+
self.logger.info(f"Initializing VirtualFit with the following parameters: {self.__dict__}")
54+
self.logger.info("VirtualFit: Skin extraction...")
55+
# 1. extract skin surface, this is done only once at initialization
56+
# self.skin_surface = self.extract_skin_surface(volume: xa.Dataset)
57+
"""A list of vertices representing the skin surface."""
58+
59+
def extract_skin_surface(self, volume: xa.Dataset, quantile: float = 0.05):
60+
#TODO: basic thresholding + convex hull
61+
# from scipy.spatial import ConvexHull
62+
# threshold = np.quantile(volume, 0.05) #TODO: check otsu threhsolding instead
63+
# volume_thresholded = volume[volume > threshold]
64+
#
65+
# return ConvexHull(volume)
66+
pass
67+
68+
def fit_to_surface(
69+
self,
70+
sph_coords: Tuple[float, float],
71+
skin_surface: np.ndarray
72+
) -> np.ndarray:
73+
"""
74+
Fit a 3D plane plane given spherical coordinates (yaw, pitch)
75+
and a set of points coordinates LPS.
76+
"""
77+
pass
78+
79+
def get_search_grid(
80+
self,
81+
yaw_range: Tuple[int, int],
82+
yaw_step: int,
83+
pitch_range: Tuple[int, int],
84+
pitch_step: int
85+
) -> np.ndarray:
86+
"""
87+
Defines the transducer search grid in (yaw, pitch) coordinates.
88+
"""
89+
yaw_sequence = np.arange(yaw_range[0], yaw_range[-1], yaw_step)
90+
pitch_sequence = np.arange(pitch_range[0], pitch_range[-1], pitch_step)
91+
pitch_yaw_grid = np.meshgrid(pitch_sequence, yaw_sequence, indexing="ij")
92+
93+
return pitch_yaw_grid
94+
95+
def analyse_position(self, pos: np.ndarray, transducer: Transducer, target: Point):
96+
"""
97+
Analyse the transducer position relative to a specific target.
98+
"""
99+
#TODO: Compute if target is within steering limits
100+
#TODO: In the future, we should implement the ray-tracing analysis given a full segmentation
101+
102+
# pos_analysis = 1.0
103+
# target_tr_space = target2trspace(pos, target)
104+
# for target_constraint in self.steering_limits:
105+
# pos = target_tr_space.get_position(
106+
# dim=target_constraint.dim,
107+
# units=target_constraint.units
108+
# )
109+
# try:
110+
# target_constraint.check_bounds(pos)
111+
# except ValueError:
112+
# pos_analysis = 0.0
113+
#
114+
# return pos_analysis
115+
116+
pass
117+
118+
def run(
119+
self,
120+
target: Point,
121+
pitch_range: Optional[Tuple[int, int]] = None,
122+
pitch_step: Optional[int] = None,
123+
yaw_range: Optional[Tuple[int, int]] = None,
124+
yaw_step: Optional[int] = None,
125+
steering_limits: Optional[Tuple[TargetConstraints]] = None,
126+
blocked_elems_threshold: Optional[float] = None
127+
) -> ArrayTransform:
128+
"""
129+
VirtualFit main process.
130+
131+
Finds the optimal transducer transform (position and orientation)
132+
given an input MRI volume in LPS coordinates, and the associated
133+
target in same coordinates LPS.
134+
"""
135+
if pitch_range is None:
136+
pitch_range = self.pitch_range
137+
if pitch_step is None:
138+
pitch_step = self.pitch_step
139+
if yaw_range is None:
140+
yaw_range = self.yaw_range
141+
if yaw_step is None:
142+
yaw_step = self.yaw_step
143+
if steering_limits is None:
144+
steering_limits = self.steering_limits
145+
if blocked_elems_threshold is None:
146+
blocked_elems_threshold = self.blocked_elems_threshold
147+
148+
self.logger.info("Running VirtualFit main process.")
149+
self.logger.info("VirtualFit: Searching optimal position...")
150+
# 2. get search grid
151+
search_grid = self.get_search_grid(yaw_range, yaw_step, pitch_range, pitch_step)
152+
for i in range(search_grid[0].shape[0]):
153+
for j in range(search_grid[0].shape[1]):
154+
yaw, pitch = (search_grid[0][i, j], search_grid[1][i, j])
155+
self.logger.info(f"VirtualFit: Analysing {(yaw, pitch)}...")
156+
# 3. define transducer transform (plane fitting) on the surface (skin) given spherical coordinate (yaw, pitch)
157+
# self.fit_to_surface(sph_coords: Tuple[float, float], skin_surface: np.ndarray)
158+
# 4. analyse current transform
159+
# self.analyse_position(pos: np.ndarray, transducer: Transducer, target: Point)
160+
optimal_transform = np.zeros((4, 4))
161+
self.logger.info("VirtualFit: Found optimal position!")
162+
163+
return optimal_transform

tests/test_virtual_fit.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import logging
2+
from pathlib import Path
3+
4+
import pytest
5+
6+
from openlifu.db import Session
7+
from openlifu.vf import VirtualFit
8+
9+
10+
@pytest.fixture()
11+
def example_session() -> Session:
12+
return Session.from_file(Path(__file__).parent/"resources/example_db/subjects/example_subject/sessions/example_session/example_session.json")
13+
14+
def test_virtual_fit(
15+
example_session: Session
16+
):
17+
"""Test if virtual fit runs."""
18+
logging.disable(logging.CRITICAL)
19+
20+
target = example_session.targets[0]
21+
vf = VirtualFit()
22+
vf.run(target)

0 commit comments

Comments
 (0)