Skip to content

Commit d78e5ed

Browse files
ebrahimebrahimarhowe00
authored andcommitted
Fix solution analysis result data type (OpenwaterHealth#242)
1 parent e525730 commit d78e5ed

2 files changed

Lines changed: 16 additions & 4 deletions

File tree

src/openlifu/plan/solution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def analyze(self, transducer: Transducer, options: SolutionAnalysisOptions = Sol
212212
solution_analysis.p0_Pa += [np.max(p0_Pa)]
213213
solution_analysis.TIC = np.mean(TIC)
214214
solution_analysis.power_W = np.mean(power_W)
215-
solution_analysis.MI = solution_analysis.mainlobe_pnp_MPa/np.sqrt(self.pulse.frequency*1e-6)
215+
solution_analysis.MI = (solution_analysis.mainlobe_pnp_MPa/np.sqrt(self.pulse.frequency*1e-6)).item()
216216
solution_analysis.global_ispta_mWcm2 = float((ita_mWcm2*z_mask).max())
217217
return solution_analysis
218218

tests/test_solution.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from dataclasses import fields
34
from datetime import datetime, timedelta
45
from pathlib import Path
56

@@ -55,17 +56,17 @@ def example_solution() -> Solution:
5556
{
5657
'p_min': xa.DataArray(
5758
data=rng.random((1, 3, 2, 3)),
58-
dims=["focal_point_index", "x", "y", "z"],
59+
dims=["focal_point_index", "lat", "ele", "ax"],
5960
attrs={'units': "Pa"}
6061
),
6162
'p_max': xa.DataArray(
6263
data=rng.random((1, 3, 2, 3)),
63-
dims=["focal_point_index", "x", "y", "z"],
64+
dims=["focal_point_index", "lat", "ele", "ax"],
6465
attrs={'units': "Pa"}
6566
),
6667
'intensity': xa.DataArray(
6768
data=rng.random((1, 3, 2, 3)),
68-
dims=["focal_point_index", "x", "y", "z"],
69+
dims=["focal_point_index", "lat", "ele", "ax"],
6970
attrs={'units': "W/cm^2"}
7071
)
7172
},
@@ -136,6 +137,17 @@ def test_json_serialize_deserialize_solution_analysis(compact_representation: bo
136137
analysis_reconstructed = SolutionAnalysis.from_json(analysis_json)
137138
assert dataclasses_are_equal(analysis_reconstructed, analysis)
138139

140+
def test_solution_analyze_data_types(example_solution:Solution, example_transducer:Transducer):
141+
"""Test that solution analysis field are all floats or lists of floats as expected"""
142+
analysis = example_solution.analyze(example_transducer)
143+
for f in fields(analysis):
144+
value = getattr(analysis, f.name)
145+
if not isinstance(value, float):
146+
assert isinstance(value, list)
147+
if len(value) > 0:
148+
assert isinstance(value[0], float)
149+
150+
139151
def test_solution_created_date():
140152
"""Test that created date is recent when a solution is created."""
141153
tolerance = timedelta(seconds=2)

0 commit comments

Comments
 (0)