Skip to content

Commit 56c2da9

Browse files
authored
fix: handle io array shape dimensionality (#157)
* fix: handle io array shape dimensionality * refactor: simplify logic * chore: format * fix: handle else * chore: clearer comments + bigger test * refactor: use iterators instead of `ctr` for size-1 axis
1 parent b8e238d commit 56c2da9

8 files changed

Lines changed: 59 additions & 39 deletions

File tree

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ crate-type = ["cdylib", "rlib"]
1010

1111
[dependencies]
1212
pyo3 = { version = "0.27.1", features = ["abi3-py311"] }
13-
zarrs = { version = "0.23.0", features = ["async", "zlib", "pcodec", "bz2"] }
13+
zarrs = { version = "0.23.6", features = ["async", "zlib", "pcodec", "bz2"] }
1414
rayon_iter_concurrent_limit = "0.2.0"
1515
rayon = "1.10.0"
1616
# fix for https://stackoverflow.com/questions/76593417/package-openssl-was-not-found-in-the-pkg-config-search-path

python/zarrs/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from ._internal import __version__
22
from .pipeline import ZarrsCodecPipeline as _ZarrsCodecPipeline
3-
from .utils import CollapsedDimensionError, DiscontiguousArrayError
3+
from .utils import DiscontiguousArrayError, UnsupportedVIndexingError
44

55

66
# Need to do this redirection so people can access the pipeline as `zarrs.ZarrsCodecPipeline` instead of `zarrs.pipeline.ZarrsCodecPipeline`
@@ -11,6 +11,6 @@ class ZarrsCodecPipeline(_ZarrsCodecPipeline):
1111
__all__ = [
1212
"ZarrsCodecPipeline",
1313
"DiscontiguousArrayError",
14-
"CollapsedDimensionError",
14+
"UnsupportedVIndexingError",
1515
"__version__",
1616
]

python/zarrs/pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626

2727
from ._internal import CodecPipelineImpl
2828
from .utils import (
29-
CollapsedDimensionError,
3029
DiscontiguousArrayError,
3130
FillValueNoneError,
31+
UnsupportedVIndexingError,
3232
make_chunk_info_for_rust_with_indices,
3333
)
3434

@@ -185,7 +185,7 @@ async def read(
185185
except (
186186
UnsupportedMetadataError,
187187
DiscontiguousArrayError,
188-
CollapsedDimensionError,
188+
UnsupportedVIndexingError,
189189
UnsupportedDataTypeError,
190190
FillValueNoneError,
191191
):
@@ -220,7 +220,7 @@ async def write(
220220
except (
221221
UnsupportedMetadataError,
222222
DiscontiguousArrayError,
223-
CollapsedDimensionError,
223+
UnsupportedVIndexingError,
224224
UnsupportedDataTypeError,
225225
FillValueNoneError,
226226
):

python/zarrs/utils.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class DiscontiguousArrayError(Exception):
3030
pass
3131

3232

33-
class CollapsedDimensionError(Exception):
33+
class UnsupportedVIndexingError(Exception):
3434
pass
3535

3636

@@ -160,7 +160,7 @@ def make_chunk_info_for_rust_with_indices(
160160
drop_axes: tuple[int, ...],
161161
shape: tuple[int, ...],
162162
) -> RustChunkInfo:
163-
shape = shape if shape else (1,) # constant array
163+
is_constant = shape == ()
164164
chunk_info_with_indices: list[ChunkItem] = []
165165
write_empty_chunks: bool = True
166166
for (
@@ -171,8 +171,12 @@ def make_chunk_info_for_rust_with_indices(
171171
_,
172172
) in batch_info:
173173
write_empty_chunks = chunk_spec.config.write_empty_chunks
174+
# Convert the selector tuples to ones that only have slices i.e., `i: int` replaced by slice(i, i+1)
174175
out_selection_as_slices = selector_tuple_to_slice_selection(out_selection)
175176
chunk_selection_as_slices = selector_tuple_to_slice_selection(chunk_selection)
177+
# Because `chunk_selection_as_slices` contains only slices, certain types of vindex-ing are not going to be able to be processed by the zarrs pipeline.
178+
# Thus we get the shapes of the input selector and the converted-to-slices selector to check if they differ.
179+
# If they differ, then the indexing operation is not supported because it is not describable as slices.
176180
shape_chunk_selection_slices = get_shape_for_selector(
177181
tuple(chunk_selection_as_slices),
178182
chunk_spec.shape,
@@ -182,17 +186,44 @@ def make_chunk_info_for_rust_with_indices(
182186
shape_chunk_selection = get_shape_for_selector(
183187
chunk_selection, chunk_spec.shape, pad=True, drop_axes=drop_axes
184188
)
185-
if prod_op(shape_chunk_selection) != prod_op(shape_chunk_selection_slices):
186-
raise CollapsedDimensionError(
189+
if (chunk_size := prod_op(shape_chunk_selection)) != prod_op(
190+
shape_chunk_selection_slices
191+
):
192+
raise UnsupportedVIndexingError(
187193
f"{shape_chunk_selection} != {shape_chunk_selection_slices}"
188194
)
195+
if not is_constant and chunk_size > prod_op(shape):
196+
raise IndexError(
197+
f"the size of the chunk subset {shape_chunk_selection} and input/output subset {shape} are incompatible"
198+
)
199+
io_array_shape = list(shape)
200+
out_selection_expanded = out_selection_as_slices
201+
# We need to have io_array_shape and out_selection_expanded with dimensionalities matching that of the underlying array.
202+
# `drop_axes` is only triggered via fancy outer-indexing because applying `chunk_selection_as_slices` to the chunk array would not drop a dimension that the out-array thinks should be dropped, thus that dimension needs to be indicated.
203+
# However, other indexing operations can silently drop a dimension on input to match the output, like `z[1, ...]`.
204+
# In other words, applying the `chunk_selection_as_slices` to a chunk array would drop a dimension, but `out_selection` already encodes this dropped dimension because zarr-python constructs the out-array missing the dimension.
205+
# So if we detect that a dimension has been dropped silently like this after converting to slices, we update to handle the dropped dimension.
206+
scs_iter = iter(shape_chunk_selection)
207+
scs_current = next(scs_iter, None)
208+
for idx_shape, shape_chunk_from_slices in enumerate(
209+
shape_chunk_selection_slices
210+
):
211+
# Detect if this dimension has been dropped on the io_array, i.e., shape_chunk_selection has been exhausted (so there is an extra 1-sized dimension at the end) or it has a mismatch with the "full" chunk shape `shape_chunk_selection_slices`.
212+
if shape_chunk_from_slices == 1 != scs_current:
213+
drop_axes += (idx_shape,)
214+
else:
215+
scs_current = next(scs_iter, None)
216+
if drop_axes:
217+
for axis in drop_axes:
218+
io_array_shape.insert(axis, 1)
219+
out_selection_expanded.insert(axis, slice(0, 1))
189220
chunk_info_with_indices.append(
190221
ChunkItem(
191222
key=byte_getter.path,
192223
chunk_subset=chunk_selection_as_slices,
193224
chunk_shape=chunk_spec.shape,
194-
subset=out_selection_as_slices,
195-
shape=shape,
225+
subset=out_selection_expanded,
226+
shape=io_array_shape,
196227
)
197228
)
198229
return RustChunkInfo(chunk_info_with_indices, write_empty_chunks)

src/chunk_item.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ pub(crate) struct ChunkItem {
3232
pub subset: ArraySubset,
3333
pub shape: Vec<NonZeroU64>,
3434
pub num_elements: u64,
35+
pub array_shape: Vec<NonZeroU64>,
3536
}
3637

3738
#[gen_stub_pymethods]
@@ -65,6 +66,7 @@ impl ChunkItem {
6566
subset,
6667
shape: chunk_shape_nonzero_u64,
6768
num_elements,
69+
array_shape: shape_nonzero_u64,
6870
})
6971
}
7072
}

src/lib.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ mod utils;
3838

3939
use crate::concurrency::ChunkConcurrentLimitAndCodecOptions;
4040
use crate::store::StoreConfig;
41-
use crate::utils::{PyCodecErrExt, PyErrExt as _, PyUntypedArrayExt as _};
41+
use crate::utils::{PyCodecErrExt, PyErrExt as _};
4242

4343
// TODO: Use a OnceLock for store with get_or_try_init when stabilised?
4444
#[gen_stub_pyclass]
@@ -288,7 +288,6 @@ impl CodecPipelineImpl {
288288
) -> PyResult<()> {
289289
// Get input array
290290
let output = Self::nparray_to_unsafe_cell_slice(value)?;
291-
let output_shape: Vec<u64> = value.shape_zarr()?;
292291

293292
// Adjust the concurrency based on the codec chain and the first chunk description
294293
let Some((chunk_concurrent_limit, codec_options)) =
@@ -343,7 +342,7 @@ impl CodecPipelineImpl {
343342
.fixed_size()
344343
.ok_or("variable length data type not supported")
345344
.map_py_err::<PyTypeError>()?,
346-
&output_shape,
345+
bytemuck::must_cast_slice(&item.array_shape),
347346
item.subset.clone(),
348347
)
349348
.map_py_err::<PyRuntimeError>()?
@@ -410,7 +409,6 @@ impl CodecPipelineImpl {
410409
} else {
411410
InputValue::Constant(FillValue::new(input_slice.to_vec()))
412411
};
413-
let input_shape: Vec<u64> = value.shape_zarr()?;
414412

415413
// Adjust the concurrency based on the codec chain and the first chunk description
416414
let Some((chunk_concurrent_limit, mut codec_options)) =
@@ -424,7 +422,11 @@ impl CodecPipelineImpl {
424422
let store_chunk = |item: ChunkItem| match &input {
425423
InputValue::Array(input) => {
426424
let chunk_subset_bytes = input
427-
.extract_array_subset(&item.subset, &input_shape, &self.data_type)
425+
.extract_array_subset(
426+
&item.subset,
427+
bytemuck::must_cast_slice(&item.array_shape),
428+
&self.data_type,
429+
)
428430
.map_codec_err()?;
429431
self.store_chunk_subset_bytes(
430432
&item,

src/utils.rs

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use std::fmt::Display;
22

3-
use numpy::{PyUntypedArray, PyUntypedArrayMethods};
4-
use pyo3::{Bound, PyErr, PyResult, PyTypeInfo};
3+
use pyo3::{PyErr, PyResult, PyTypeInfo};
54
use zarrs::array::CodecError;
65

76
use crate::ChunkItem;
@@ -38,23 +37,6 @@ impl<T> PyCodecErrExt<T> for Result<T, CodecError> {
3837
}
3938
}
4039

41-
pub(crate) trait PyUntypedArrayExt {
42-
fn shape_zarr(&self) -> PyResult<Vec<u64>>;
43-
}
44-
45-
impl PyUntypedArrayExt for Bound<'_, PyUntypedArray> {
46-
fn shape_zarr(&self) -> PyResult<Vec<u64>> {
47-
Ok(if self.shape().is_empty() {
48-
vec![1] // scalar value
49-
} else {
50-
self.shape()
51-
.iter()
52-
.map(|&i| u64::try_from(i))
53-
.collect::<Result<_, _>>()?
54-
})
55-
}
56-
}
57-
5840
pub fn is_whole_chunk(item: &ChunkItem) -> bool {
5941
item.chunk_subset.start().iter().all(|&o| o == 0)
6042
&& item.chunk_subset.shape() == bytemuck::must_cast_slice::<_, u64>(&item.shape)

tests/test_sharding.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,12 @@ def test_sharding_partial_readwrite(
148148

149149
a[:] = data
150150

151-
for x in range(data.shape[0]):
152-
read_data = a[x, :, :]
153-
assert np.array_equal(data[x], read_data)
151+
for axis in range(len(data.shape)):
152+
for x in range(data.shape[0]):
153+
selector = [slice(None), slice(None), slice(None)]
154+
selector[axis] = x
155+
read_data = a[*tuple(selector)]
156+
assert np.array_equal(data[*tuple(selector)], read_data)
154157

155158

156159
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)