Skip to content

Commit a428a39

Browse files
authored
Merge pull request #513 from Blosc/cleanup_indexing
Refactor get/setitem for nonunit steps
2 parents f7667f3 + 7e0a0d6 commit a428a39

2 files changed

Lines changed: 59 additions & 64 deletions

File tree

src/blosc2/ndarray.py

Lines changed: 58 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -3948,20 +3948,12 @@ def get_fselection_numpy(self, key: list | np.ndarray) -> np.ndarray:
39483948
# Default when there are booleans
39493949
# TODO: for boolean indexing could be optimised by avoiding
39503950
# calculating out_shape prior to loop and keeping track on-the-fly (like in LazyExpr machinery)
3951-
return self._get_set_findex_default(_slice, out_shape)
3952-
3953-
def _get_set_findex_default(self, _slice, out_shape=None, updater=None):
3954-
_get = False
3955-
if not ((out_shape is None) or (updater is None)):
3956-
raise ValueError("Cannot provide both out_shape and updater.")
3957-
# we have a getitem
3958-
if out_shape is not None:
3959-
_get = True
3960-
out = np.empty(out_shape, dtype=self.dtype)
3961-
elif updater is None:
3962-
raise ValueError("Must provide one of out_shape or updater.")
3963-
else:
3964-
out = self # default return for no intersecting chunks
3951+
out = np.empty(out_shape, dtype=self.dtype)
3952+
return self._get_set_findex_default(_slice, out)
3953+
3954+
def _get_set_findex_default(self, _slice, out=None, value=None):
3955+
_get = out is not None
3956+
out = self if out is None else out # default return for setitem with no intersecting chunks
39653957
if 0 in self.shape:
39663958
return out
39673959
chunk_size = ndindex.ChunkSize(self.chunks) # only works with nonzero chunks
@@ -3976,10 +3968,10 @@ def _get_set_findex_default(self, _slice, out_shape=None, updater=None):
39763968
chunk = np.empty(tuple(sp - st for st, sp in zip(start, stop, strict=True)), dtype=self.dtype)
39773969
super().get_slice_numpy(chunk, (start, stop))
39783970
if _get:
3979-
new_shape = sel_idx.newshape(out_shape)
3971+
new_shape = sel_idx.newshape(out.shape)
39803972
out[sel_idx.raw] = chunk[sub_idx].reshape(new_shape)
39813973
else:
3982-
chunk[sub_idx] = updater(sel_idx.raw)
3974+
chunk[sub_idx] = value if np.isscalar(value) else value[sel_idx]
39833975
out = super().set_slice((start, stop), chunk)
39843976
return out
39853977

@@ -4001,7 +3993,42 @@ def set_oselection_numpy(self, key: list | np.ndarray, arr: NDArray) -> np.ndarr
40013993
"""
40023994
return super().set_oindex_numpy(key, arr)
40033995

4004-
def __getitem__( # noqa: C901
3996+
def _get_set_nonunit_steps(self, _slice, out=None, value=None):
3997+
start, stop, step, mask = _slice
3998+
_get = out is not None
3999+
out = self if out is None else out # default return for setitem with no intersecting chunks
4000+
if 0 in self.shape:
4001+
return out
4002+
4003+
chunks = self.chunks
4004+
_slice = tuple(slice(s, st, stp) for s, st, stp in zip(start, stop, step, strict=True))
4005+
intersecting_chunks = [
4006+
slice_to_chunktuple(s, c) for s, c in zip(_slice, chunks, strict=True)
4007+
] # internally handles negative steps
4008+
for c in product(*intersecting_chunks):
4009+
sel_idx, glob_selection, sub_idx = _get_selection(c, _slice, chunks)
4010+
sel_idx = tuple(s for s, m in zip(sel_idx, mask, strict=True) if not m)
4011+
sub_idx = tuple(s if not m else s.start for s, m in zip(sub_idx, mask, strict=True))
4012+
locstart, locstop = _get_local_slice(
4013+
glob_selection,
4014+
(),
4015+
((), ()), # switches start and stop for negative steps
4016+
)
4017+
chunk = np.empty(
4018+
tuple(sp - st for st, sp in zip(locstart, locstop, strict=True)), dtype=self.dtype
4019+
)
4020+
# basically load whole chunk, except for slice part at beginning and end
4021+
super().get_slice_numpy(chunk, (locstart, locstop)) # copy relevant slice of chunk
4022+
if _get:
4023+
out[sel_idx] = chunk[sub_idx] # update relevant parts of chunk
4024+
else:
4025+
chunk[sub_idx] = (
4026+
value if np.isscalar(value) else value[sel_idx]
4027+
) # update relevant parts of chunk
4028+
out = super().set_slice((locstart, locstop), chunk) # load updated partial chunk into array
4029+
return out
4030+
4031+
def __getitem__(
40054032
self,
40064033
key: None
40074034
| int
@@ -4083,7 +4110,8 @@ def __getitem__( # noqa: C901
40834110
if key:
40844111
_slice = ndindex.ndindex(()).expand(self.shape) # just get whole array
40854112
out_shape = _slice.newshape(self.shape)
4086-
return np.expand_dims(self._get_set_findex_default(_slice, out_shape=out_shape), 0)
4113+
out = np.empty(out_shape, dtype=self.dtype)
4114+
return np.expand_dims(self._get_set_findex_default(_slice, out=out), 0)
40874115
else: # do nothing
40884116
return np.empty((0,) + self.shape, dtype=self.dtype)
40894117
elif (
@@ -4099,12 +4127,9 @@ def __getitem__( # noqa: C901
40994127
return self.get_fselection_numpy(key) # fancy index default, can be quite slow
41004128

41014129
start, stop, step, none_mask = get_ndarray_start_stop(self.ndim, key_, self.shape)
4102-
for i, s in enumerate(step): # (start, stop, -1) => stop < start
4103-
if s < 0:
4104-
temp = start[i]
4105-
start[i] = stop[i] + 1 # don't want to include stop
4106-
stop[i] = temp + 1 # want to include start
4107-
shape = np.array([sp - st for st, sp in zip(start, stop, strict=True)])
4130+
shape = np.array(
4131+
[(sp - st - np.sign(stp)) // stp + 1 for st, sp, stp in zip(start, stop, step, strict=True)]
4132+
)
41084133
if mask is not None: # there are some dummy dims from ints
41094134
# only get mask for not Nones in key to have nm_ same length as shape
41104135
nm_ = [not m for m, n in zip(mask, none_mask, strict=True) if not n]
@@ -4113,12 +4138,11 @@ def __getitem__( # noqa: C901
41134138
shape = tuple(shape[nm_])
41144139

41154140
# Create the array to store the result
4116-
arr = np.empty(shape, dtype=self.dtype)
4117-
nparr = super().get_slice_numpy(arr, (start, stop))
4118-
if step != (1,) * self.ndim: # TODO: optimise to work like __setitem__ for non-unit steps
4119-
# have to make step refer to sliced dims (which will be less if ints present)
4120-
slice_ = tuple(slice(None, None, st) for st, m in zip(step, nm_, strict=True) if m)
4121-
nparr = nparr[slice_]
4141+
nparr = np.empty(shape, dtype=self.dtype)
4142+
if step != (1,) * self.ndim:
4143+
nparr = self._get_set_nonunit_steps((start, stop, step, [not i for i in nm_]), out=nparr)
4144+
else:
4145+
nparr = super().get_slice_numpy(nparr, (start, stop))
41224146

41234147
if np.any(none_mask):
41244148
nparr = np.expand_dims(nparr, axis=[i for i, n in enumerate(none_mask) if n])
@@ -4130,7 +4154,7 @@ def __getitem__( # noqa: C901
41304154

41314155
return nparr
41324156

4133-
def __setitem__( # noqa : C901
4157+
def __setitem__(
41344158
self,
41354159
key: None | int | slice | Sequence[slice | int | np.bool_ | np.ndarray[int | np.bool_] | None],
41364160
value: object,
@@ -4174,14 +4198,6 @@ def __setitem__( # noqa : C901
41744198
if hasattr(value, "shape") and value.shape == ():
41754199
value = value.item()
41764200

4177-
def updater(sel_idx):
4178-
return value[sel_idx]
4179-
4180-
if np.isscalar(value): # overwrite updater function for simple cases (faster)
4181-
4182-
def updater(sel_idx):
4183-
return value
4184-
41854201
if builtins.any(isinstance(k, (list, np.ndarray)) for k in key_): # fancy indexing
41864202
_slice = ndindex.ndindex(key_).expand(
41874203
self.shape
@@ -4194,36 +4210,14 @@ def updater(sel_idx):
41944210
_slice = ndindex.ndindex(()).expand(self.shape) # just get whole array
41954211
else: # do nothing
41964212
return self
4197-
return self._get_set_findex_default(_slice, updater=updater)
4213+
return self._get_set_findex_default(_slice, value=value)
41984214

41994215
start, stop, step, none_mask = get_ndarray_start_stop(self.ndim, key_, self.shape)
42004216

42014217
if step != (1,) * self.ndim: # handle non-unit or negative steps
42024218
if np.any(none_mask):
42034219
raise ValueError("Cannot mix non-unit steps and None indexing for __setitem__.")
4204-
chunks = self.chunks
4205-
shape = self.shape
4206-
_slice = tuple(slice(s, st, stp) for s, st, stp in zip(start, stop, step, strict=True))
4207-
intersecting_chunks = [
4208-
slice_to_chunktuple(s, c) for s, c in zip(_slice, chunks, strict=True)
4209-
] # internally handles negative steps
4210-
out = self # for when shape has 0 (i.e. arr is empty, as then skip loop)
4211-
for c in product(*intersecting_chunks):
4212-
sel_idx, glob_selection, sub_idx = _get_selection(c, _slice, chunks)
4213-
sel_idx = tuple(s for s, m in zip(sel_idx, mask, strict=True) if not m)
4214-
sub_idx = tuple(s if not m else s.start for s, m in zip(sub_idx, mask, strict=True))
4215-
locstart, locstop = _get_local_slice(
4216-
glob_selection,
4217-
(),
4218-
((), ()), # switches start and stop for negative steps
4219-
)
4220-
chunk = np.empty(
4221-
tuple(sp - st for st, sp in zip(locstart, locstop, strict=True)), dtype=self.dtype
4222-
)
4223-
super().get_slice_numpy(chunk, (locstart, locstop)) # copy relevant slice of chunk
4224-
chunk[sub_idx] = updater(sel_idx) # update relevant parts of chunk
4225-
out = super().set_slice((locstart, locstop), chunk) # load updated partial chunk into array
4226-
return out
4220+
return self._get_set_nonunit_steps((start, stop, step, mask), value=value)
42274221

42284222
shape = [sp - st for sp, st in zip(stop, start, strict=False)]
42294223
if isinstance(value, NDArray):
@@ -6320,7 +6314,7 @@ def _get_selection(ctuple, ptuple, chunks):
63206314
out_pselection = ()
63216315
i = 0
63226316
for ps, pt in zip(pselection, ptuple, strict=True):
6323-
sign_ = pt.step // builtins.abs(pt.step)
6317+
sign_ = np.sign(pt.step)
63246318
n = (ps.start - pt.start - sign_) // pt.step
63256319
out_start = n + 1
63266320
# ps.stop always positive except for case where get full array (it is then -1 since desire 0th element)

tests/ndarray/test_getitem.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
argvalues = [
1717
([456], [258], [73], slice(0, 1), np.int32),
1818
([77, 134, 13], [31, 13, 5], [7, 8, 3], (slice(3, 7), slice(50, 100), 7), np.float64),
19+
([77, 134, 13], [31, 13, 5], [7, 8, 3], (slice(3, 56, 3), slice(100, 50, -4), 7), np.float64),
1920
([12, 13, 14, 15, 16], [5, 5, 5, 5, 5], [2, 2, 2, 2, 2], (slice(1, 3), ..., slice(3, 6)), np.float32),
2021
]
2122

0 commit comments

Comments
 (0)