Skip to content

Commit a8bc280

Browse files
committed
Support miniexpr for where inside reductions, e.g. 'sum(where(a < b, b, a))'
1 parent 90565a8 commit a8bc280

3 files changed

Lines changed: 27 additions & 14 deletions

File tree

CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,7 @@ set(MINIEXPR_BUILD_BENCH OFF CACHE BOOL "Build miniexpr benchmarks" FORCE)
5858

5959
FetchContent_Declare(miniexpr
6060
GIT_REPOSITORY https://github.com/Blosc/miniexpr.git
61-
#GIT_TAG ndim # latest me_compile_nd()/me_eval_nd() APIs
62-
GIT_TAG 8c50850094e156ce568186edef667fabecbd00ff # latest commit in ndim
61+
GIT_TAG ndim # latest me_compile_nd()/me_eval_nd() APIs
6362
# In case you want to use a local copy of miniexpr for development, uncomment the line below
6463
# SOURCE_DIR "/Users/faltet/blosc/miniexpr"
6564
)

src/blosc2/lazyexpr.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1309,8 +1309,8 @@ def fast_eval( # noqa: C901
13091309
fp_accuracy = kwargs.pop("fp_accuracy", blosc2.FPAccuracy.DEFAULT)
13101310
res_eval = blosc2.uninit(shape, dtype, chunks=chunks, blocks=blocks, cparams=cparams, **kwargs)
13111311
try:
1312-
print("expr->miniexpr:", expression)
13131312
res_eval._set_pref_expr(expression, operands, fp_accuracy=fp_accuracy)
1313+
print("expr->miniexpr:", expression)
13141314
# Data to compress is fetched from operands, so it can be uninitialized here
13151315
data = np.empty(res_eval.schunk.chunksize, dtype=np.uint8)
13161316
# Exercise prefilter for each chunk
@@ -1993,7 +1993,7 @@ def reduce_slices( # noqa: C901
19931993
del temp
19941994

19951995
# miniexpr reduction path only supported for some cases so far
1996-
if not (where is None and fast_path and all_ndarray and reduced_shape == ()):
1996+
if not (fast_path and all_ndarray and reduced_shape == ()):
19971997
use_miniexpr = False
19981998

19991999
# Some reductions are not supported yet in miniexpr
@@ -2008,6 +2008,8 @@ def reduce_slices( # noqa: C901
20082008
)
20092009
if has_complex and any(tok in expression for tok in ("!=", "==", "<=", ">=", "<", ">")):
20102010
use_miniexpr = False
2011+
if where is not None and len(where) != 2:
2012+
use_miniexpr = False
20112013

20122014
if use_miniexpr:
20132015
# Experiments say that not splitting is best (at least on Apple Silicon M4 Pro)
@@ -2037,9 +2039,12 @@ def reduce_slices( # noqa: C901
20372039
# For other operations, zeros should be safe
20382040
aux_reduc = np.zeros(nblocks, dtype=dtype)
20392041
try:
2042+
if where is not None:
2043+
expression_miniexpr = f"{reduce_op_str}(where({expression}, _where_x, _where_y))"
2044+
else:
2045+
expression_miniexpr = f"{reduce_op_str}({expression})"
2046+
res_eval._set_pref_expr(expression_miniexpr, operands, fp_accuracy, aux_reduc)
20402047
print("expr->miniexpr:", expression, reduce_op)
2041-
expression = f"{reduce_op_str}({expression})"
2042-
res_eval._set_pref_expr(expression, operands, fp_accuracy, aux_reduc)
20432048
# Data won't even try to be compressed, so buffers can be uninitialized and reused
20442049
data = np.empty(res_eval.schunk.chunksize, dtype=np.uint8)
20452050
chunk_data = np.empty(res_eval.schunk.chunksize + blosc2.MAX_OVERHEAD, dtype=np.uint8)

tests/ndarray/test_reductions.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,24 @@ def test_reduce_bool(array_fixture, reduce_op):
6565
np.testing.assert_allclose(res, nres, atol=tol, rtol=tol)
6666

6767

68-
def test_reduce_where(array_fixture):
68+
# @pytest.mark.parametrize("reduce_op", ["sum"])
69+
@pytest.mark.parametrize("reduce_op", ["sum", "prod", "min", "max", "any", "all", "argmax", "argmin"])
70+
def test_reduce_where(array_fixture, reduce_op):
6971
a1, a2, a3, a4, na1, na2, na3, na4 = array_fixture
70-
# The next works
71-
# res = blosc2.where(a1 < a2, a2, 0).sum()
72-
# nres = ne_evaluate("sum(where(na1 < na2, na2, 0))")
73-
# This does not work yet (it currently hangs)
74-
res = blosc2.where(a1 < a2, a2, a1).sum()
75-
nres = ne_evaluate("sum(where(na1 < na2, na2, na1))")
76-
print("res:", res, nres)
72+
if reduce_op == "prod":
73+
# To avoid overflow, create a1 and a2 with small values
74+
na1 = np.linspace(0, 0.1, np.prod(a1.shape), dtype=np.float32).reshape(a1.shape)
75+
a1 = blosc2.asarray(na1)
76+
na2 = np.linspace(0, 0.5, np.prod(a1.shape), dtype=np.float32).reshape(a1.shape)
77+
a2 = blosc2.asarray(na2)
78+
expr = a1 + a2 - 0.2
79+
nres = eval("na1 + na2 - .2")
80+
else:
81+
expr = blosc2.where(a1 < a2, a2, a1)
82+
nres = eval("np.where(na1 < na2, na2, na1)")
83+
res = getattr(expr, reduce_op)()
84+
nres = getattr(nres, reduce_op)()
85+
# print("res:", res, nres, type(res), type(nres))
7786
tol = 1e-15 if a1.dtype == "float64" else 1e-6
7887
np.testing.assert_allclose(res, nres, atol=tol, rtol=tol)
7988

0 commit comments

Comments
 (0)