Skip to content

Commit 17f8a39

Browse files
use chunks parameter for dataarray models (#1031)
* Correctly parse the chunks parameter for RasterModels when using DataArrays
* Remove parsing logic, as it is not necessary for the SpatialImage class
* Don't overwrite the chunks of the input data if no specific chunk size is set by the user
* Add test cases for the chunks parameter of raster models
---------
Co-authored-by: Luca Marconato <m.lucalmer@gmail.com>
1 parent aec5181 commit 17f8a39

2 files changed

Lines changed: 63 additions & 0 deletions

File tree

src/spatialdata/models/models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,10 @@ def parse(
239239
chunks=chunks,
240240
)
241241
_parse_transformations(data, parsed_transform)
242+
else:
243+
# Chunk single scale images
244+
if chunks is not None:
245+
data = data.chunk(chunks=chunks)
242246
cls()._check_chunk_size_not_too_large(data)
243247
# recompute coordinates for (multiscale) spatial image
244248
return compute_coordinates(data)

tests/models/test_models.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,65 @@ def test_raster_schema(
195195
with pytest.raises(ValueError):
196196
model.parse(image, **kwargs)
197197

198+
@pytest.mark.parametrize(
199+
"model,chunks,expected",
200+
[
201+
(Labels2DModel, None, (10, 10)),
202+
(Labels2DModel, 5, (5, 5)),
203+
(Labels2DModel, (5, 5), (5, 5)),
204+
(Labels2DModel, {"x": 5, "y": 5}, (5, 5)),
205+
(Labels3DModel, None, (2, 10, 10)),
206+
(Labels3DModel, 5, (2, 5, 5)),
207+
(Labels3DModel, (2, 5, 5), (2, 5, 5)),
208+
(Labels3DModel, {"z": 2, "x": 5, "y": 5}, (2, 5, 5)),
209+
(Image2DModel, None, (1, 10, 10)), # Image2D Models always have a c dimension
210+
(Image2DModel, 5, (1, 5, 5)),
211+
(Image2DModel, (1, 5, 5), (1, 5, 5)),
212+
(Image2DModel, {"c": 1, "x": 5, "y": 5}, (1, 5, 5)),
213+
(Image3DModel, None, (1, 2, 10, 10)), # Image3D models have z in addition, so 4 total dimensions
214+
(Image3DModel, 5, (1, 2, 5, 5)),
215+
(Image3DModel, (1, 2, 5, 5), (1, 2, 5, 5)),
216+
(
217+
Image3DModel,
218+
{"c": 1, "z": 2, "x": 5, "y": 5},
219+
(1, 2, 5, 5),
220+
),
221+
],
222+
)
223+
def test_raster_models_parse_with_chunks_parameter(self, model, chunks, expected):
224+
image: ArrayLike = np.arange(100).reshape((10, 10))
225+
if model in [Labels3DModel, Image3DModel]:
226+
image = np.stack([image] * 2)
227+
228+
if model in [Image2DModel, Image3DModel]:
229+
image = np.expand_dims(image, axis=0)
230+
231+
# parse as numpy array
232+
# single scale
233+
x_ss = model.parse(image, chunks=chunks)
234+
assert x_ss.data.chunksize == expected
235+
# multi scale
236+
x_ms = model.parse(image, chunks=chunks, scale_factors=(2,))
237+
assert x_ms["scale0"]["image"].data.chunksize == expected
238+
239+
# parse as dask array
240+
dask_image = from_array(image)
241+
# single scale
242+
y_ss = model.parse(dask_image, chunks=chunks)
243+
assert y_ss.data.chunksize == expected
244+
# multi scale
245+
y_ms = model.parse(dask_image, chunks=chunks, scale_factors=(2,))
246+
assert y_ms["scale0"]["image"].data.chunksize == expected
247+
248+
# parse as DataArray
249+
data_array = DataArray(image, dims=model.dims.dims)
250+
# single scale
251+
z_ss = model.parse(data_array, chunks=chunks)
252+
assert z_ss.data.chunksize == expected
253+
# multi scale
254+
z_ms = model.parse(data_array, chunks=chunks, scale_factors=(2,))
255+
assert z_ms["scale0"]["image"].data.chunksize == expected
256+
198257
@pytest.mark.parametrize("model", [Labels2DModel, Labels3DModel])
199258
def test_labels_model_with_multiscales(self, model):
200259
# Passing "scale_factors" should generate multiscales with a "method" appropriate for labels

0 commit comments

Comments
 (0)