Skip to content

Commit 217032d

Browse files
committed
Add none indexing for lazyudf/lazyarray
1 parent 37acccb commit 217032d

4 files changed

Lines changed: 59 additions & 31 deletions

File tree

src/blosc2/lazyexpr.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -510,13 +510,15 @@ def check_smaller_shape(value_shape, shape, slice_shape, slice_):
510510
This follows the NumPy broadcasting rules.
511511
"""
512512
# slice_shape must be as long as shape
513-
if len(slice_shape) != len(shape):
514-
raise ValueError("slice_shape must be as long as shape")
513+
if len(slice_shape) != len(slice_):
514+
raise ValueError("slice_shape must be as long as slice_")
515+
no_nones_shape = tuple(sh for sh, s in zip(slice_shape, slice_, strict=True) if s is not None)
516+
no_nones_slice = tuple(s for sh, s in zip(slice_shape, slice_, strict=True) if s is not None)
515517
is_smaller_shape = any(
516-
s > (1 if i >= len(value_shape) else value_shape[i]) for i, s in enumerate(slice_shape)
518+
s > (1 if i >= len(value_shape) else value_shape[i]) for i, s in enumerate(no_nones_shape)
517519
)
518520
slice_past_bounds = any(
519-
s.stop > (1 if i >= len(value_shape) else value_shape[i]) for i, s in enumerate(slice_)
521+
s.stop > (1 if i >= len(value_shape) else value_shape[i]) for i, s in enumerate(no_nones_slice)
520522
)
521523
return len(value_shape) < len(shape) or is_smaller_shape or slice_past_bounds
522524

@@ -547,10 +549,31 @@ def compute_smaller_slice(larger_shape, smaller_shape, larger_slice):
547549
"""
548550
Returns the slice of the smaller array that corresponds to the slice of the larger array.
549551
"""
550-
diff_dims = len(larger_shape) - len(smaller_shape)
552+
j_small = len(smaller_shape) - 1
553+
j_large = len(larger_shape) - 1
554+
smaller_shape_nones = []
555+
larger_shape_nones = []
556+
for s in reversed(larger_slice):
557+
if s is None:
558+
smaller_shape_nones.append(1)
559+
larger_shape_nones.append(1)
560+
else:
561+
if j_small >= 0:
562+
smaller_shape_nones.append(smaller_shape[j_small])
563+
j_small -= 1
564+
if j_large >= 0:
565+
larger_shape_nones.append(larger_shape[j_large])
566+
j_large -= 1
567+
smaller_shape_nones.reverse()
568+
larger_shape_nones.reverse()
569+
diff_dims = len(larger_shape_nones) - len(smaller_shape_nones)
551570
return tuple(
552-
larger_slice[i] if smaller_shape[i - diff_dims] != 1 else slice(0, larger_shape[i])
553-
for i in range(diff_dims, len(larger_shape))
571+
None
572+
if larger_slice[i] is None
573+
else (
574+
larger_slice[i] if smaller_shape_nones[i - diff_dims] != 1 else slice(0, larger_shape_nones[i])
575+
)
576+
for i in range(diff_dims, len(larger_shape_nones))
554577
)
555578

556579

@@ -1694,7 +1717,6 @@ def slices_eval_getitem(
16941717
_slice_bcast = tuple(slice(i, i + 1) if isinstance(i, int) else i for i in _slice.raw)
16951718
slice_shape = ndindex.ndindex(_slice_bcast).newshape(shape) # includes dummy dimensions
16961719
_slice = _slice.raw
1697-
offset = tuple(s.start for s in _slice_bcast) # offset for the udf
16981720

16991721
# Get the slice of each operand
17001722
slice_operands = {}
@@ -1715,6 +1737,7 @@ def slices_eval_getitem(
17151737

17161738
# Evaluate the expression using slices of operands
17171739
if callable(expression):
1740+
offset = tuple(0 if s is None else s.start for s in _slice_bcast) # offset for the udf
17181741
result = np.empty(slice_shape, dtype=dtype)
17191742
expression(tuple(slice_operands.values()), result, offset=offset)
17201743
else:
@@ -2161,7 +2184,7 @@ def chunked_eval( # noqa: C901
21612184
"""
21622185
try:
21632186
# standardise slice to be ndindex.Tuple
2164-
item = () if item in (None, slice(None, None, None)) else item
2187+
item = () if item == slice(None, None, None) else item
21652188
item = item if isinstance(item, tuple) else (item,)
21662189
item = tuple(
21672190
slice(s.start, s.stop, 1 if s.step is None else s.step) if isinstance(s, slice) else s

tests/ndarray/test_getitem.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@
1818
([77, 134, 13], [31, 13, 5], [7, 8, 3], (slice(3, 7), slice(50, 100), 7), np.float64),
1919
([77, 134, 13], [31, 13, 5], [7, 8, 3], (slice(3, 56, 3), slice(100, 50, -4), 7), np.float64),
2020
([12, 13, 14, 15, 16], [5, 5, 5, 5, 5], [2, 2, 2, 2, 2], (slice(1, 3), ..., slice(3, 6)), np.float32),
21+
(
22+
[12, 13, 14, 15, 16],
23+
[5, 5, 5, 5, 5],
24+
[2, 2, 2, 2, 2],
25+
(None, slice(1, 3), None, ..., slice(3, 6)),
26+
np.float32,
27+
),
2128
]
2229

2330

tests/ndarray/test_lazyexpr.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ def test_simple_getitem(array_fixture):
9090
res = expr[sl]
9191
np.testing.assert_allclose(res, nres[sl])
9292

93+
# Test None indexing
94+
sl = (None, slice(3, 8), None)
95+
res = expr[sl]
96+
np.testing.assert_allclose(res, nres[sl])
97+
9398

9499
# Mix Proxy and NDArray operands
95100
def test_proxy_simple_getitem(array_fixture):
@@ -114,14 +119,13 @@ def test_mix_operands(array_fixture):
114119
np.testing.assert_allclose(expr[:], nres)
115120
np.testing.assert_allclose(expr.compute()[:], nres)
116121

117-
# TODO: fix this
118-
# expr = na2 + a1
119-
# nres = ne_evaluate("na2 + na1")
120-
# sl = slice(100)
121-
# res = expr[sl]
122-
# np.testing.assert_allclose(res, nres[sl])
123-
# np.testing.assert_allclose(expr[:], nres)
124-
# np.testing.assert_allclose(expr.compute()[:], nres)
122+
expr = na2 + a1
123+
nres = ne_evaluate("na2 + na1")
124+
sl = slice(100)
125+
res = expr[sl]
126+
np.testing.assert_allclose(res, nres[sl])
127+
np.testing.assert_allclose(expr[:], nres)
128+
np.testing.assert_allclose(expr.compute()[:], nres)
125129

126130
expr = a1 + na2 + a3
127131
nres = ne_evaluate("na1 + na2 + na3")
@@ -151,19 +155,13 @@ def test_mix_operands(array_fixture):
151155
np.testing.assert_allclose(expr[:], nres)
152156
np.testing.assert_allclose(expr.compute()[:], nres)
153157

154-
# TODO: support this case
155-
# expr = a1 + na2 * a3
156-
# print("--------------------------------------------------------")
157-
# print(type(expr))
158-
# print(expr.expression)
159-
# print(expr.operands)
160-
# print("--------------------------------------------------------")
161-
# nres = ne_evaluate("na1 + na2 * na3")
162-
# sl = slice(100)
163-
# res = expr[sl]
164-
# np.testing.assert_allclose(res, nres[sl])
165-
# np.testing.assert_allclose(expr[:], nres)
166-
# np.testing.assert_allclose(expr.compute()[:], nres)
158+
expr = a1 + na2 * a3
159+
nres = ne_evaluate("na1 + na2 * na3")
160+
sl = slice(100)
161+
res = expr[sl]
162+
np.testing.assert_allclose(res, nres[sl])
163+
np.testing.assert_allclose(expr[:], nres)
164+
np.testing.assert_allclose(expr.compute()[:], nres)
167165

168166

169167
# Add more test functions to test different aspects of the code

tests/ndarray/test_lazyudf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def test_params(chunked_eval):
259259
[
260260
((40, 20), (30, 10), (5, 5), (slice(0, 5), slice(5, 20)), "eval.b2nd", False),
261261
((13, 13, 10), (10, 10, 5), (5, 5, 3), (slice(0, 12), slice(3, 13), ...), "eval.b2nd", True),
262-
((13, 13), (10, 10), (5, 5), (slice(3, 8), slice(9, 12)), None, False),
262+
((13, 13), (10, 10), (5, 5), (slice(3, 8), None, slice(9, 12)), None, False),
263263
],
264264
)
265265
def test_getitem(shape, chunks, blocks, slices, urlpath, contiguous, chunked_eval):

0 commit comments

Comments
 (0)