Skip to content

Commit 3c6fc80

Browse files
authored
Merge pull request #435 from Blosc/fancyIndex
Fixing reductions and slicing for lazy expressions
2 parents 2f97340 + 1d786a9 commit 3c6fc80

3 files changed

Lines changed: 131 additions & 63 deletions

File tree

src/blosc2/lazyexpr.py

Lines changed: 102 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ def compute_broadcast_shape(arrays):
568568
return np.broadcast_shapes(*shapes) if shapes else None
569569

570570

571-
def check_smaller_shape(value_shape, shape, slice_shape):
571+
def check_smaller_shape(value_shape, shape, slice_shape, slice_):
572572
"""Check whether the shape of the value is smaller than the shape of the array.
573573
574574
This follows the NumPy broadcasting rules.
@@ -579,7 +579,10 @@ def check_smaller_shape(value_shape, shape, slice_shape):
579579
is_smaller_shape = any(
580580
s > (1 if i >= len(value_shape) else value_shape[i]) for i, s in enumerate(slice_shape)
581581
)
582-
return len(value_shape) < len(shape) or is_smaller_shape
582+
slice_past_bounds = any(
583+
s.stop > (1 if i >= len(value_shape) else value_shape[i]) for i, s in enumerate(slice_)
584+
)
585+
return len(value_shape) < len(shape) or is_smaller_shape or slice_past_bounds
583586

584587

585588
def _compute_smaller_slice(larger_shape, smaller_shape, larger_slice):
@@ -1396,29 +1399,35 @@ def slices_eval( # noqa: C901
13961399

13971400
dtype = kwargs.pop("dtype", None)
13981401
shape_slice = None
1399-
need_orig_slice = False
1402+
need_final_slice = False
14001403

1401-
# keep orig_slice for final getitem
1404+
# keep orig_slice
14021405
_slice = _slice.raw
14031406
orig_slice = _slice
1407+
full_slice = () # by default the full_slice is the whole array
1408+
final_slice = () # by default the final_slice is the whole array
1409+
1410+
# Compute the shape and chunks of the output array, including broadcasting
1411+
shape = compute_broadcast_shape(operands.values())
14041412
if out is None:
1405-
# Compute the shape and chunks of the output array, including broadcasting
1406-
shape = compute_broadcast_shape(operands.values())
14071413
if _slice != ():
14081414
# Check whether _slice contains an integer, or any step that is not None or 1
14091415
if any(
14101416
(isinstance(s, int)) or (isinstance(s, slice) and s.step not in (None, 1)) for s in _slice
14111417
):
1412-
need_orig_slice = True
1418+
need_final_slice = True
14131419
_slice = tuple(slice(i, i + 1) if isinstance(i, int) else i for i in _slice)
14141420
full_slice = tuple(
1415-
slice(s.start or 0, s.stop or shape[i], None) for i, s in enumerate(_slice)
1421+
slice(s.start or 0, s.stop or shape[i], 1) for i, s in enumerate(_slice)
14161422
) # get rid of non-unit steps
14171423
# shape_slice in general not equal to final shape:
1418-
# dummy dims (due to ints) or non-unit steps will be dealt with by taking orig_slice
1424+
# dummy dims (due to ints) or non-unit steps will be dealt with by taking final_slice
14191425
shape_slice = ndindex.ndindex(full_slice).newshape(shape)
1426+
final_slice = ndindex.ndindex(orig_slice).as_subindex(full_slice).raw
14201427
else:
1421-
# TODO: check that this is fine since __out__ could have shape_slice and __shape__ should refer to full array
1428+
# # out should always have shape of full array
1429+
# if shape is not None and shape != out.shape:
1430+
# raise ValueError("Provided output shape does not match the operands' shape.")
14221431
shape = out.shape
14231432

14241433
if chunks is None: # Guess chunk shape
@@ -1457,9 +1466,8 @@ def slices_eval( # noqa: C901
14571466
intersecting_chunks = chunk_size.as_subchunks(_slice, shape) # if _slice is (), returns all chunks
14581467

14591468
for nchunk, chunk_slice in enumerate(intersecting_chunks):
1460-
# Check whether current chunk_slice intersects with _slice
1469+
# get intersection of chunk and target
14611470
if _slice != ():
1462-
# get intersection of chunk and target
14631471
cslice = tuple(
14641472
slice(max(s1.start, s2.start), min(s1.stop, s2.stop))
14651473
for s1, s2 in zip(chunk_slice.raw, _slice, strict=True)
@@ -1469,6 +1477,10 @@ def slices_eval( # noqa: C901
14691477

14701478
cslice_shape = tuple(s.stop - s.start for s in cslice)
14711479
len_chunk = math.prod(cslice_shape)
1480+
# get local index of part of out that is to be updated
1481+
cslice_subidx = (
1482+
ndindex.ndindex(cslice).as_subindex(full_slice).raw
1483+
) # in the case full_slice=(), just gives cslice
14721484

14731485
# Get the starts and stops for the slice
14741486
starts = [s.start if s.start is not None else 0 for s in cslice]
@@ -1482,7 +1494,7 @@ def slices_eval( # noqa: C901
14821494
if value.shape == ():
14831495
chunk_operands[key] = value[()]
14841496
continue
1485-
if check_smaller_shape(value.shape, shape, cslice_shape):
1497+
if check_smaller_shape(value.shape, shape, cslice_shape, cslice):
14861498
# We need to fetch the part of the value that broadcasts with the operand
14871499
smaller_slice = compute_smaller_slice(shape, value.shape, cslice)
14881500
chunk_operands[key] = value[smaller_slice]
@@ -1502,6 +1514,7 @@ def slices_eval( # noqa: C901
15021514

15031515
if callable(expression):
15041516
result = np.empty(cslice_shape, dtype=out.dtype) # raises error if out is None
1517+
# cslice should be equal to cslice_subidx
15051518
# Call the udf directly and use result as the output array
15061519
offset = tuple(s.start for s in cslice)
15071520
expression(tuple(chunk_operands.values()), result, offset=offset)
@@ -1573,12 +1586,12 @@ def slices_eval( # noqa: C901
15731586
out.schunk.update_data(nchunk, result, copy=False)
15741587
else:
15751588
try:
1576-
out[cslice] = result
1589+
out[cslice_subidx] = result
15771590
except ComplexWarning:
15781591
# The result is a complex number, so we need to convert it to real.
15791592
# This is a workaround for rigidness of numpy with type casting.
15801593
result = result.real.astype(out.dtype)
1581-
out[cslice] = result
1594+
out[cslice_subidx] = result
15821595
elif len(where) == 1:
15831596
lenres = len(result)
15841597
out[lenout : lenout + lenres] = result
@@ -1588,7 +1601,7 @@ def slices_eval( # noqa: C901
15881601
else:
15891602
raise ValueError("The where condition must be a tuple with one or two elements")
15901603

1591-
if where is not None and len(where) < 2: # Don't need to take orig_slice since filled up from 0 index
1604+
if where is not None and len(where) < 2: # Don't need to take final_slice since filled up from 0 index
15921605
if _order is not None:
15931606
# argsort the result following _order
15941607
new_order = np.argsort(out[:lenout])
@@ -1600,16 +1613,16 @@ def slices_eval( # noqa: C901
16001613
else:
16011614
out.resize((lenout,))
16021615

1603-
else: # Need to take orig_slice since filled up array according to slice_ for each chunk
1604-
if orig_slice != ():
1616+
else: # Need to take final_slice since filled up array according to slice_ for each chunk
1617+
if final_slice != ():
16051618
if isinstance(out, np.ndarray):
1606-
if need_orig_slice:
1607-
out = out[orig_slice]
1619+
if need_final_slice: # only called if out was None
1620+
out = out[final_slice]
16081621
elif isinstance(out, blosc2.NDArray):
16091622
# It *seems* better to choose automatic chunks and blocks for the output array
16101623
# out = out.slice(_slice, chunks=out.chunks, blocks=out.blocks)
1611-
if need_orig_slice:
1612-
out = out.slice(orig_slice)
1624+
if need_final_slice: # only called if out was None
1625+
out = out.slice(final_slice)
16131626
else:
16141627
raise ValueError("The output array is not a NumPy array or a NDArray")
16151628

@@ -1679,7 +1692,7 @@ def slices_eval_getitem(
16791692
if value.shape == ():
16801693
slice_operands[key] = value[()]
16811694
continue
1682-
if check_smaller_shape(value.shape, shape, slice_shape):
1695+
if check_smaller_shape(value.shape, shape, slice_shape, _slice_bcast):
16831696
# We need to fetch the part of the value that broadcasts with the operand
16841697
smaller_slice = compute_smaller_slice(shape, value.shape, _slice)
16851698
slice_operands[key] = value[smaller_slice]
@@ -1703,7 +1716,8 @@ def slices_eval_getitem(
17031716
# This is a workaround for rigidness of numpy with type casting.
17041717
return result.real.astype(dtype, copy=False)
17051718
else:
1706-
out[()] = result
1719+
# out should always have maximal shape
1720+
out[_slice] = result
17071721
return out
17081722

17091723

@@ -1721,6 +1735,19 @@ def infer_reduction_dtype(dtype, operation):
17211735
raise ValueError(f"Unsupported operation: {operation}")
17221736

17231737

1738+
def step_handler(s1start, s2start, s1stop, s2stop, s2step):
    """Intersect a unit-step range with a (possibly strided) slice.

    The first range, ``[s1start, s1stop)``, is assumed to have step 1. The
    second is ``slice(s2start, s2stop, s2step)`` and is assumed to have a
    positive step; ``None`` is accepted and treated as 1, following standard
    Python slice semantics.

    Parameters
    ----------
    s1start, s1stop : int
        Bounds of the unit-step range.
    s2start, s2stop : int
        Bounds of the strided slice.
    s2step : int or None
        Step of the strided slice.

    Returns
    -------
    slice
        Slice selecting the elements of the strided slice that fall inside
        the unit-step range. Its step is always an explicit integer.
    """
    # A missing step means unit step; normalise so the modulo below is valid.
    step = 1 if s2step is None else s2step
    newstart = max(s1start, s2start)
    newstop = min(s1stop, s2stop)
    # Distance of newstart past the previous element of the strided slice.
    rem = (newstart - s2start) % step
    if rem != 0:  # only reachable when step is not 1
        # Advance to the first strided element at or after newstart.
        newstart += step - rem
        # Tighten the stop to one past the last selected element:
        # last = newstart + n * step, with n = (newstop - newstart - 1) // step
        newstop = newstart + (newstop - newstart - 1) // step * step + 1
    return slice(newstart, newstop, step)
1749+
1750+
17241751
def reduce_slices( # noqa: C901
17251752
expression: str | Callable[[tuple, np.ndarray, tuple[int]], None],
17261753
operands: dict,
@@ -1770,14 +1797,26 @@ def reduce_slices( # noqa: C901
17701797
# Compute the shape and chunks of the output array, including broadcasting
17711798
shape = compute_broadcast_shape(operands.values())
17721799

1800+
_slice = _slice.raw
1801+
shape_slice = shape
1802+
full_slice = () # by default the full_slice is the whole array
1803+
if out is None and _slice != ():
1804+
shape_slice = ndindex.ndindex(_slice).newshape(shape)
1805+
full_slice = _slice
1806+
1807+
# after slicing, we reduce to calculate shape of output
17731808
if axis is None:
1774-
axis = tuple(range(len(shape)))
1809+
axis = tuple(range(len(shape_slice)))
17751810
elif not isinstance(axis, tuple):
17761811
axis = (axis,)
1812+
axis = tuple(a if a >= 0 else a + len(shape_slice) for a in axis)
17771813
if keepdims:
1778-
reduced_shape = tuple(1 if i in axis else s for i, s in enumerate(shape))
1814+
reduced_shape = tuple(1 if i in axis else s for i, s in enumerate(shape_slice))
17791815
else:
1780-
reduced_shape = tuple(s for i, s in enumerate(shape) if i not in axis)
1816+
reduced_shape = tuple(s for i, s in enumerate(shape_slice) if i not in axis)
1817+
1818+
if out is not None and reduced_shape != out.shape:
1819+
raise ValueError("Provided output shape does not match the reduced shape.")
17811820

17821821
if is_inside_new_expr():
17831822
# We already have the dtype and reduced_shape, so return immediately
@@ -1820,23 +1859,17 @@ def reduce_slices( # noqa: C901
18201859
chunk_operands = {}
18211860
# Check which chunks intersect with _slice
18221861
chunk_size = ndindex.ChunkSize(chunks)
1823-
_slice = _slice.raw
18241862
intersecting_chunks = chunk_size.as_subchunks(_slice, shape) # if _slice is (), returns all chunks
18251863
out_init = False
18261864

18271865
for nchunk, chunk_slice in enumerate(intersecting_chunks):
18281866
cslice = chunk_slice.raw
1829-
if keepdims:
1830-
reduced_slice = tuple(slice(None) if i in axis else sl for i, sl in enumerate(cslice))
1831-
else:
1832-
reduced_slice = tuple(sl for i, sl in enumerate(cslice) if i not in axis)
18331867
offset = tuple(s.start for s in cslice) # offset for the udf
1834-
# TODO: Is this necessary, shouldn't slice always be None for a reduction?
18351868
# Check whether current cslice intersects with _slice
18361869
if cslice != () and _slice != ():
18371870
# get intersection of chunk and target
18381871
cslice = tuple(
1839-
slice(max(s1.start, s2.start), min(s1.stop, s2.stop))
1872+
step_handler(s1.start, s2.start, s1.stop, s2.stop, s2.step)
18401873
for s1, s2 in zip(cslice, _slice, strict=True)
18411874
)
18421875
chunks_ = tuple(s.stop - s.start for s in cslice)
@@ -1859,7 +1892,7 @@ def reduce_slices( # noqa: C901
18591892
if value.shape == ():
18601893
chunk_operands[key] = value[()]
18611894
continue
1862-
if check_smaller_shape(value.shape, shape, chunks_):
1895+
if check_smaller_shape(value.shape, shape, chunks_, cslice):
18631896
# We need to fetch the part of the value that broadcasts with the operand
18641897
smaller_slice = compute_smaller_slice(operand.shape, value.shape, cslice)
18651898
chunk_operands[key] = value[smaller_slice]
@@ -1874,6 +1907,15 @@ def reduce_slices( # noqa: C901
18741907
continue
18751908
chunk_operands[key] = value[cslice]
18761909

1910+
# get local index of part of out that is to be updated
1911+
cslice_subidx = (
1912+
ndindex.ndindex(cslice).as_subindex(full_slice).raw
1913+
) # if full_slice is (), just gives cslice
1914+
if keepdims:
1915+
reduced_slice = tuple(slice(None) if i in axis else sl for i, sl in enumerate(cslice_subidx))
1916+
else:
1917+
reduced_slice = tuple(sl for i, sl in enumerate(cslice_subidx) if i not in axis)
1918+
18771919
# Evaluate and reduce the expression using chunks of operands
18781920

18791921
if callable(expression):
@@ -2622,23 +2664,16 @@ def get_num_elements(self, axis, item):
26222664
num_elements = self.sum(axis=axis, dtype=np.int64, item=item)
26232665
self._where_args = orig_where_args
26242666
return num_elements
2625-
if np.isscalar(axis):
2626-
axis = (axis,)
26272667
# Compute the number of elements in the array
26282668
shape = self.shape
2669+
if np.isscalar(axis):
2670+
axis = (axis,)
26292671
if item is not None:
26302672
# Compute the shape of the slice
2631-
if not isinstance(item, tuple):
2632-
item = (item,)
2633-
# Ensure that the limits in item slices are not None
2634-
item = tuple(slice(s.start or 0, s.stop or self.shape[i], s.step) for i, s in enumerate(item))
2635-
# Compute the intersection of the slice with the shape
2636-
item = tuple(slice(s1.start, min(s1.stop, s2)) for s1, s2 in zip(item, shape, strict=True))
2637-
if axis is None:
2638-
shape = [s.stop - s.start for s in item]
2639-
else:
2640-
shape = [s.stop - s.start for i, s in enumerate(item) if i in axis]
2641-
return math.prod(shape) if axis is None else math.prod([shape[i] for i in axis])
2673+
shape = ndindex.ndindex(item).newshape(shape)
2674+
axis = tuple(range(len(shape))) if axis is None else axis
2675+
axis = tuple(a if a >= 0 else a + len(shape) for a in axis) # handle negative indexing
2676+
return math.prod([shape[i] for i in axis])
26422677

26432678
def mean(self, axis=None, dtype=None, keepdims=False, **kwargs):
26442679
item = kwargs.pop("item", None)
@@ -2657,9 +2692,15 @@ def mean(self, axis=None, dtype=None, keepdims=False, **kwargs):
26572692

26582693
def std(self, axis=None, dtype=None, keepdims=False, ddof=0, **kwargs):
26592694
item = kwargs.pop("item", None)
2660-
mean_value = self.mean(axis=axis, dtype=dtype, keepdims=True, item=item)
2661-
expr = (self - mean_value) ** 2
2662-
out = expr.mean(axis=axis, dtype=dtype, keepdims=keepdims, item=item)
2695+
if item is None: # fast path
2696+
mean_value = self.mean(axis=axis, dtype=dtype, keepdims=True)
2697+
expr = (self - mean_value) ** 2
2698+
else:
2699+
mean_value = self.mean(axis=axis, dtype=dtype, keepdims=True, item=item)
2700+
# TODO: Not optimal because we load the whole slice in memory. Avoiding this would
2701+
# probably require a bespoke std function that executes within slice_eval.
2702+
expr = (self.slice(item) - mean_value) ** 2
2703+
out = expr.mean(axis=axis, dtype=dtype, keepdims=keepdims)
26632704
if ddof != 0:
26642705
num_elements = self.get_num_elements(axis, item)
26652706
out = np.sqrt(out * num_elements / (num_elements - ddof))
@@ -2675,14 +2716,18 @@ def std(self, axis=None, dtype=None, keepdims=False, ddof=0, **kwargs):
26752716

26762717
def var(self, axis=None, dtype=None, keepdims=False, ddof=0, **kwargs):
26772718
item = kwargs.pop("item", None)
2678-
mean_value = self.mean(axis=axis, dtype=dtype, keepdims=True, item=item)
2679-
expr = (self - mean_value) ** 2
2719+
if item is None: # fast path
2720+
mean_value = self.mean(axis=axis, dtype=dtype, keepdims=True)
2721+
expr = (self - mean_value) ** 2
2722+
else:
2723+
mean_value = self.mean(axis=axis, dtype=dtype, keepdims=True, item=item)
2724+
# TODO: Not optimal because we load the whole slice in memory. Avoiding this would
2725+
# probably require a bespoke var function that executes within slice_eval.
2726+
expr = (self.slice(item) - mean_value) ** 2
2727+
out = expr.mean(axis=axis, dtype=dtype, keepdims=keepdims)
26802728
if ddof != 0:
2681-
out = expr.mean(axis=axis, dtype=dtype, keepdims=keepdims, item=item)
26822729
num_elements = self.get_num_elements(axis, item)
26832730
out = out * num_elements / (num_elements - ddof)
2684-
else:
2685-
out = expr.mean(axis=axis, dtype=dtype, keepdims=keepdims, item=item)
26862731
out2 = kwargs.pop("out", None)
26872732
if out2 is not None:
26882733
out2[:] = out
@@ -3241,6 +3286,7 @@ def __getitem__(self, item):
32413286
# TODO: as this creates a big array, this can potentially consume a lot of memory
32423287
output = np.empty(self.shape, self.dtype)
32433288
# It is important to pass kwargs here, because chunks can be used internally
3289+
# fills numpy array with desired slice
32443290
chunked_eval(self.func, self.inputs_dict, item, _getitem=True, _output=output, **self.kwargs)
32453291
return output[item]
32463292
return self.res_getitem[item]

tests/ndarray/test_lazyexpr.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,23 +1078,27 @@ def test_eval_getitem(array_fixture):
10781078
np.testing.assert_allclose(expr[:10], nres[:10])
10791079
np.testing.assert_allclose(expr[0:10:2], nres[0:10:2])
10801080

1081+
1082+
def test_eval_getitem2():
10811083
# Small test for non-isomorphic shape
10821084
shape = (2, 10, 5)
1083-
test_arr = blosc2.linspace(0, 10, np.prod(shape), shape=shape)
1085+
test_arr = blosc2.linspace(0, 10, np.prod(shape), shape=shape, chunks=(1, 5, 1))
10841086
expr = test_arr * 30
10851087
nres = test_arr[:] * 30
10861088
np.testing.assert_allclose(expr[0], nres[0])
1087-
np.testing.assert_allclose(expr[:10], nres[:10])
1089+
np.testing.assert_allclose(expr[1:, :7], nres[1:, :7])
10881090
np.testing.assert_allclose(expr[0:10:2], nres[0:10:2])
1091+
# This works, but it is not very efficient since it relies on blosc2.ndarray.slice for non-unit steps
1092+
np.testing.assert_allclose(expr.slice((slice(None, None, None), slice(0, 10, 2)))[:], nres[:, 0:10:2])
10891093

10901094
# Small test for broadcasting
1091-
shape = (2, 10, 5)
1092-
test_arr = blosc2.linspace(0, 10, np.prod(shape), shape=shape)
1093-
expr = test_arr + test_arr.slice(slice(1, 2))
1095+
expr = test_arr + test_arr.slice(1)
10941096
nres = test_arr[:] + test_arr[1]
10951097
np.testing.assert_allclose(expr[0], nres[0])
1096-
np.testing.assert_allclose(expr[:10], nres[:10])
1097-
np.testing.assert_allclose(expr[0:10:2], nres[0:10:2])
1098+
np.testing.assert_allclose(expr[1:, :7], nres[1:, :7])
1099+
np.testing.assert_allclose(expr[:, 0:10:2], nres[:, 0:10:2])
1100+
# This works, but it is not very efficient since it relies on blosc2.ndarray.slice for non-unit steps
1101+
np.testing.assert_allclose(expr.slice((slice(None, None, None), slice(0, 10, 2)))[:], nres[:, 0:10:2])
10981102

10991103

11001104
# Test lazyexpr's slice method

0 commit comments

Comments
 (0)