Skip to content

Commit 261b101

Browse files
Port solution analysis functionality from matlab (OpenwaterHealth#105)
* [WIP] finishing up solution analysis * Fix default simulation parameters * update masking * fix TIC calculation * Fix ruff errors * remove stale references * Style changes to fix CI * Fix broken imports * remove spurious file * fix pulse sequences in unit tests * skip the unit tests that still need to be revamped * document and test get_effective_origin * add a few more docstrings * add a few more docstrings * remove unused function mask_focus * add more docstrings * make get_beam_bounds return tuple Makes type annotation more clear * add more docstrings * migrate the offset grid test to the new get_offset_grid unit test rescued from the graveyard! --------- Co-authored-by: Ebrahim Ebrahim <ebrahim.ebrahim@kitware.com>
1 parent d8b8255 commit 261b101

20 files changed

Lines changed: 433 additions & 457 deletions

notebooks/test_first.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@
9292
ds['p_min'].sel(ele=0).plot.imshow()
9393

9494
# %% [markdown]
95-
# We can examine the output object, which is an `xarray.DataSet` object with 3 data variables: `p_max` (Peak Positive Pressure), `p_min` (Peak Negative Pressure), and `ita` (Time Averaged Intensity). It's attributes also contain the `source` pulse (an `xarray.DataArray`), and `output`, the raw K-Wave output structure.
95+
# We can examine the output object, which is an `xarray.DataSet` object with 3 data variables: `p_max` (Peak Positive Pressure), `p_min` (Peak Negative Pressure), and `intensity` (Time Averaged Intensity). It's attributes also contain the `source` pulse (an `xarray.DataArray`), and `output`, the raw K-Wave output structure.
9696

9797
# %%
9898
ds

src/openlifu/__init__.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@
1515
apod_methods,
1616
delay_methods,
1717
focal_patterns,
18-
get_beamwidth,
19-
mask_focus,
20-
offset_grid,
2118
)
2219
from openlifu.db import Database, User
2320

@@ -63,9 +60,6 @@
6360
"focal_patterns",
6461
"delay_methods",
6562
"apod_methods",
66-
"get_beamwidth",
67-
"mask_focus",
68-
"offset_grid",
6963
"SimSetup",
7064
"Database",
7165
"User",

src/openlifu/bf/__init__.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
from .apod_methods import ApodizationMethod
22
from .delay_methods import DelayMethod
33
from .focal_patterns import FocalPattern, SinglePoint, Wheel
4-
from .get_beamwidth import get_beamwidth
5-
from .mask_focus import mask_focus
6-
from .offset_grid import offset_grid
74
from .pulse import Pulse
85
from .sequence import Sequence
96

@@ -14,8 +11,5 @@
1411
"FocalPattern",
1512
"SinglePoint",
1613
"Pulse",
17-
"Sequence",
18-
"offset_grid",
19-
"mask_focus",
20-
"get_beamwidth"
14+
"Sequence"
2115
]

src/openlifu/bf/get_beamwidth.py

Lines changed: 0 additions & 138 deletions
This file was deleted.

src/openlifu/bf/mask_focus.py

Lines changed: 0 additions & 74 deletions
This file was deleted.

src/openlifu/bf/offset_grid.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

src/openlifu/bf/sequence.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,18 @@ class Sequence(DictMixin):
2020
pulse_train_interval: float = 1.0 # s
2121
pulse_train_count: int = 1
2222

23+
def __post_init__(self):
24+
if self.pulse_interval <= 0:
25+
raise ValueError("Pulse interval must be positive")
26+
if self.pulse_count <= 0:
27+
raise ValueError("Pulse count must be positive")
28+
if self.pulse_train_interval < 0:
29+
raise ValueError("Pulse train interval must be non-negative")
30+
elif (self.pulse_train_interval > 0) and (self.pulse_train_interval < (self.pulse_interval * self.pulse_count)):
31+
raise ValueError("Pulse train interval must be greater than or equal to the total pulse interval")
32+
if self.pulse_train_count <= 0:
33+
raise ValueError("Pulse train count must be positive")
34+
2335
def get_table(self) -> pd.DataFrame:
2436
"""
2537
Get a table of the sequence parameters

src/openlifu/geo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111

1212
@dataclass
1313
class Point:
14+
position: np.ndarray = field(default_factory=lambda: np.array([0.0, 0.0, 0.0])) # mm
1415
id: str = "point"
1516
name: str = "Point"
1617
color: Any = (1.0, 0.0, 0.0)
1718
radius: float = 1. # mm
18-
position: np.ndarray = field(default_factory=lambda: np.array([0.0, 0.0, 0.0])) # mm
1919
dims: Tuple[str, str, str] = ("x","y","z")
2020
units: str = "mm"
2121

src/openlifu/plan/protocol.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def calc_solution(
252252
simulation_result_aggregated: xa.Dataset = xa.Dataset()
253253
scaled_solution_analysis: SolutionAnalysis = SolutionAnalysis()
254254
foci: List[Point] = self.focal_pattern.get_targets(target)
255-
simulation_cycles = np.max([np.round(self.pulse.duration * self.pulse.frequency), 20])
255+
simulation_cycles = np.min([np.round(self.pulse.duration * self.pulse.frequency), 20])
256256

257257
# updating solution sequence if pulse mismatch
258258
if (self.sequence.pulse_count % len(foci)) != 0:
@@ -273,7 +273,7 @@ def calc_solution(
273273
cycles = simulation_cycles,
274274
dt=sim_options.dt,
275275
t_end=sim_options.t_end,
276-
amplitude = 1,
276+
amplitude = self.pulse.amplitude,
277277
gpu = use_gpu
278278
)
279279
delays_to_stack.append(delays)
@@ -325,11 +325,11 @@ def calc_solution(
325325
pnp_aggregated = solution.simulation_result['p_min'].max(dim="focal_point_index")
326326
ppp_aggregated = solution.simulation_result['p_max'].max(dim="focal_point_index")
327327
# TODO: Ensure this mean is weighted by the number of times each point is focused on, once openlifu supports hitting points different numbers of times
328-
intensity_aggregated = solution.simulation_result['ita'].mean(dim="focal_point_index")
328+
intensity_aggregated = solution.simulation_result['intensity'].mean(dim="focal_point_index")
329329
simulation_result_aggregated = deepcopy(solution.simulation_result)
330330
simulation_result_aggregated = simulation_result_aggregated.drop_dims("focal_point_index")
331331
simulation_result_aggregated['p_min'] = pnp_aggregated
332332
simulation_result_aggregated['p_max'] = ppp_aggregated
333-
simulation_result_aggregated['ita'] = intensity_aggregated
333+
simulation_result_aggregated['intensity'] = intensity_aggregated
334334

335335
return solution, simulation_result_aggregated, scaled_solution_analysis

0 commit comments

Comments
 (0)