Skip to content

Commit 7e0a0d6

Browse files
committed
Refactor get/setitem for nonunit steps
1 parent 3b4898a commit 7e0a0d6

2 files changed

Lines changed: 59 additions & 64 deletions

File tree

src/blosc2/ndarray.py

Lines changed: 58 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -3945,20 +3945,12 @@ def get_fselection_numpy(self, key: list | np.ndarray) -> np.ndarray:
39453945
# Default when there are booleans
39463946
# TODO: for boolean indexing could be optimised by avoiding
39473947
# calculating out_shape prior to loop and keeping track on-the-fly (like in LazyExpr machinery)
3948-
return self._get_set_findex_default(_slice, out_shape)
3949-
3950-
def _get_set_findex_default(self, _slice, out_shape=None, updater=None):
3951-
_get = False
3952-
if not ((out_shape is None) or (updater is None)):
3953-
raise ValueError("Cannot provide both out_shape and updater.")
3954-
# we have a getitem
3955-
if out_shape is not None:
3956-
_get = True
3957-
out = np.empty(out_shape, dtype=self.dtype)
3958-
elif updater is None:
3959-
raise ValueError("Must provide one of out_shape or updater.")
3960-
else:
3961-
out = self # default return for no intersecting chunks
3948+
out = np.empty(out_shape, dtype=self.dtype)
3949+
return self._get_set_findex_default(_slice, out)
3950+
3951+
def _get_set_findex_default(self, _slice, out=None, value=None):
3952+
_get = out is not None
3953+
out = self if out is None else out # default return for setitem with no intersecting chunks
39623954
if 0 in self.shape:
39633955
return out
39643956
chunk_size = ndindex.ChunkSize(self.chunks) # only works with nonzero chunks
@@ -3973,10 +3965,10 @@ def _get_set_findex_default(self, _slice, out_shape=None, updater=None):
39733965
chunk = np.empty(tuple(sp - st for st, sp in zip(start, stop, strict=True)), dtype=self.dtype)
39743966
super().get_slice_numpy(chunk, (start, stop))
39753967
if _get:
3976-
new_shape = sel_idx.newshape(out_shape)
3968+
new_shape = sel_idx.newshape(out.shape)
39773969
out[sel_idx.raw] = chunk[sub_idx].reshape(new_shape)
39783970
else:
3979-
chunk[sub_idx] = updater(sel_idx.raw)
3971+
chunk[sub_idx] = value if np.isscalar(value) else value[sel_idx]
39803972
out = super().set_slice((start, stop), chunk)
39813973
return out
39823974

@@ -3998,7 +3990,42 @@ def set_oselection_numpy(self, key: list | np.ndarray, arr: NDArray) -> np.ndarr
39983990
"""
39993991
return super().set_oindex_numpy(key, arr)
40003992

4001-
def __getitem__( # noqa: C901
3993+
def _get_set_nonunit_steps(self, _slice, out=None, value=None):
3994+
start, stop, step, mask = _slice
3995+
_get = out is not None
3996+
out = self if out is None else out # default return for setitem with no intersecting chunks
3997+
if 0 in self.shape:
3998+
return out
3999+
4000+
chunks = self.chunks
4001+
_slice = tuple(slice(s, st, stp) for s, st, stp in zip(start, stop, step, strict=True))
4002+
intersecting_chunks = [
4003+
slice_to_chunktuple(s, c) for s, c in zip(_slice, chunks, strict=True)
4004+
] # internally handles negative steps
4005+
for c in product(*intersecting_chunks):
4006+
sel_idx, glob_selection, sub_idx = _get_selection(c, _slice, chunks)
4007+
sel_idx = tuple(s for s, m in zip(sel_idx, mask, strict=True) if not m)
4008+
sub_idx = tuple(s if not m else s.start for s, m in zip(sub_idx, mask, strict=True))
4009+
locstart, locstop = _get_local_slice(
4010+
glob_selection,
4011+
(),
4012+
((), ()), # switches start and stop for negative steps
4013+
)
4014+
chunk = np.empty(
4015+
tuple(sp - st for st, sp in zip(locstart, locstop, strict=True)), dtype=self.dtype
4016+
)
4017+
# basically load whole chunk, except for slice part at beginning and end
4018+
super().get_slice_numpy(chunk, (locstart, locstop)) # copy relevant slice of chunk
4019+
if _get:
4020+
out[sel_idx] = chunk[sub_idx] # update relevant parts of chunk
4021+
else:
4022+
chunk[sub_idx] = (
4023+
value if np.isscalar(value) else value[sel_idx]
4024+
) # update relevant parts of chunk
4025+
out = super().set_slice((locstart, locstop), chunk) # load updated partial chunk into array
4026+
return out
4027+
4028+
def __getitem__(
40024029
self,
40034030
key: None
40044031
| int
@@ -4080,7 +4107,8 @@ def __getitem__( # noqa: C901
40804107
if key:
40814108
_slice = ndindex.ndindex(()).expand(self.shape) # just get whole array
40824109
out_shape = _slice.newshape(self.shape)
4083-
return np.expand_dims(self._get_set_findex_default(_slice, out_shape=out_shape), 0)
4110+
out = np.empty(out_shape, dtype=self.dtype)
4111+
return np.expand_dims(self._get_set_findex_default(_slice, out=out), 0)
40844112
else: # do nothing
40854113
return np.empty((0,) + self.shape, dtype=self.dtype)
40864114
elif (
@@ -4096,12 +4124,9 @@ def __getitem__( # noqa: C901
40964124
return self.get_fselection_numpy(key) # fancy index default, can be quite slow
40974125

40984126
start, stop, step, none_mask = get_ndarray_start_stop(self.ndim, key_, self.shape)
4099-
for i, s in enumerate(step): # (start, stop, -1) => stop < start
4100-
if s < 0:
4101-
temp = start[i]
4102-
start[i] = stop[i] + 1 # don't want to include stop
4103-
stop[i] = temp + 1 # want to include start
4104-
shape = np.array([sp - st for st, sp in zip(start, stop, strict=True)])
4127+
shape = np.array(
4128+
[(sp - st - np.sign(stp)) // stp + 1 for st, sp, stp in zip(start, stop, step, strict=True)]
4129+
)
41054130
if mask is not None: # there are some dummy dims from ints
41064131
# only get mask for not Nones in key to have nm_ same length as shape
41074132
nm_ = [not m for m, n in zip(mask, none_mask, strict=True) if not n]
@@ -4110,12 +4135,11 @@ def __getitem__( # noqa: C901
41104135
shape = tuple(shape[nm_])
41114136

41124137
# Create the array to store the result
4113-
arr = np.empty(shape, dtype=self.dtype)
4114-
nparr = super().get_slice_numpy(arr, (start, stop))
4115-
if step != (1,) * self.ndim: # TODO: optimise to work like __setitem__ for non-unit steps
4116-
# have to make step refer to sliced dims (which will be less if ints present)
4117-
slice_ = tuple(slice(None, None, st) for st, m in zip(step, nm_, strict=True) if m)
4118-
nparr = nparr[slice_]
4138+
nparr = np.empty(shape, dtype=self.dtype)
4139+
if step != (1,) * self.ndim:
4140+
nparr = self._get_set_nonunit_steps((start, stop, step, [not i for i in nm_]), out=nparr)
4141+
else:
4142+
nparr = super().get_slice_numpy(nparr, (start, stop))
41194143

41204144
if np.any(none_mask):
41214145
nparr = np.expand_dims(nparr, axis=[i for i, n in enumerate(none_mask) if n])
@@ -4127,7 +4151,7 @@ def __getitem__( # noqa: C901
41274151

41284152
return nparr
41294153

4130-
def __setitem__( # noqa : C901
4154+
def __setitem__(
41314155
self,
41324156
key: None | int | slice | Sequence[slice | int | np.bool_ | np.ndarray[int | np.bool_] | None],
41334157
value: object,
@@ -4171,14 +4195,6 @@ def __setitem__( # noqa : C901
41714195
if hasattr(value, "shape") and value.shape == ():
41724196
value = value.item()
41734197

4174-
def updater(sel_idx):
4175-
return value[sel_idx]
4176-
4177-
if np.isscalar(value): # overwrite updater function for simple cases (faster)
4178-
4179-
def updater(sel_idx):
4180-
return value
4181-
41824198
if builtins.any(isinstance(k, (list, np.ndarray)) for k in key_): # fancy indexing
41834199
_slice = ndindex.ndindex(key_).expand(
41844200
self.shape
@@ -4191,36 +4207,14 @@ def updater(sel_idx):
41914207
_slice = ndindex.ndindex(()).expand(self.shape) # just get whole array
41924208
else: # do nothing
41934209
return self
4194-
return self._get_set_findex_default(_slice, updater=updater)
4210+
return self._get_set_findex_default(_slice, value=value)
41954211

41964212
start, stop, step, none_mask = get_ndarray_start_stop(self.ndim, key_, self.shape)
41974213

41984214
if step != (1,) * self.ndim: # handle non-unit or negative steps
41994215
if np.any(none_mask):
42004216
raise ValueError("Cannot mix non-unit steps and None indexing for __setitem__.")
4201-
chunks = self.chunks
4202-
shape = self.shape
4203-
_slice = tuple(slice(s, st, stp) for s, st, stp in zip(start, stop, step, strict=True))
4204-
intersecting_chunks = [
4205-
slice_to_chunktuple(s, c) for s, c in zip(_slice, chunks, strict=True)
4206-
] # internally handles negative steps
4207-
out = self # for when shape has 0 (i.e. arr is empty, as then skip loop)
4208-
for c in product(*intersecting_chunks):
4209-
sel_idx, glob_selection, sub_idx = _get_selection(c, _slice, chunks)
4210-
sel_idx = tuple(s for s, m in zip(sel_idx, mask, strict=True) if not m)
4211-
sub_idx = tuple(s if not m else s.start for s, m in zip(sub_idx, mask, strict=True))
4212-
locstart, locstop = _get_local_slice(
4213-
glob_selection,
4214-
(),
4215-
((), ()), # switches start and stop for negative steps
4216-
)
4217-
chunk = np.empty(
4218-
tuple(sp - st for st, sp in zip(locstart, locstop, strict=True)), dtype=self.dtype
4219-
)
4220-
super().get_slice_numpy(chunk, (locstart, locstop)) # copy relevant slice of chunk
4221-
chunk[sub_idx] = updater(sel_idx) # update relevant parts of chunk
4222-
out = super().set_slice((locstart, locstop), chunk) # load updated partial chunk into array
4223-
return out
4217+
return self._get_set_nonunit_steps((start, stop, step, mask), value=value)
42244218

42254219
shape = [sp - st for sp, st in zip(stop, start, strict=False)]
42264220
if isinstance(value, NDArray):
@@ -6313,7 +6307,7 @@ def _get_selection(ctuple, ptuple, chunks):
63136307
out_pselection = ()
63146308
i = 0
63156309
for ps, pt in zip(pselection, ptuple, strict=True):
6316-
sign_ = pt.step // builtins.abs(pt.step)
6310+
sign_ = np.sign(pt.step)
63176311
n = (ps.start - pt.start - sign_) // pt.step
63186312
out_start = n + 1
63196313
# ps.stop always positive except for case where get full array (it is then -1 since desire 0th element)

tests/ndarray/test_getitem.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
argvalues = [
1717
([456], [258], [73], slice(0, 1), np.int32),
1818
([77, 134, 13], [31, 13, 5], [7, 8, 3], (slice(3, 7), slice(50, 100), 7), np.float64),
19+
([77, 134, 13], [31, 13, 5], [7, 8, 3], (slice(3, 56, 3), slice(100, 50, -4), 7), np.float64),
1920
([12, 13, 14, 15, 16], [5, 5, 5, 5, 5], [2, 2, 2, 2, 2], (slice(1, 3), ..., slice(3, 6)), np.float32),
2021
]
2122

0 commit comments

Comments
 (0)