Skip to content

Commit f9c0324

Browse files
committed
Broaden the use cases for miniexpr to all 1-dim cases
1 parent 3c639ed commit f9c0324

2 files changed

Lines changed: 20 additions & 10 deletions

File tree

src/blosc2/lazyexpr.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,9 +1301,9 @@ def fast_eval( # noqa: C901
13011301
if not (isinstance(op, blosc2.NDArray) and op.urlpath is None and out is None):
13021302
use_miniexpr = False
13031303
break
1304-
# Ensure blocks fit exactly in chunks
1304+
# Ensure blocks fit exactly in chunks for the n-dim case
13051305
blocks_fit = builtins.all(c % b == 0 for c, b in zip(op.chunks, op.blocks, strict=True))
1306-
if not blocks_fit:
1306+
if len(op.shape) != 1 and not blocks_fit:
13071307
use_miniexpr = False
13081308
break
13091309

@@ -2016,9 +2016,9 @@ def reduce_slices( # noqa: C901
20162016
if has_complex and any(tok in expression for tok in ("!=", "==", "<=", ">=", "<", ">")):
20172017
use_miniexpr = False
20182018
for op in operands.values():
2019-
# Check that chunksize is multiple of blocksize and blocks fit exactly in chunks
2019+
# Ensure blocks fit exactly in chunks for the n-dim case
20202020
blocks_fit = builtins.all(c % b == 0 for c, b in zip(op.chunks, op.blocks, strict=True))
2021-
if not blocks_fit:
2021+
if len(op.shape) != 1 and not blocks_fit:
20222022
use_miniexpr = False
20232023
break
20242024

tests/ndarray/test_lazyexpr.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,10 @@ def test_complex_evaluate(array_fixture):
226226
expr += 2
227227
nres = ne_evaluate("tan(na1) * (sin(na2) * sin(na2) + cos(na3)) + (sqrt(na4) * 2) + 2")
228228
res = expr.compute()
229-
np.testing.assert_allclose(res[:], nres)
229+
if na1.dtype == np.float32:
230+
np.testing.assert_allclose(res[:], nres, rtol=1e-5)
231+
else:
232+
np.testing.assert_allclose(res[:], nres)
230233

231234

232235
def test_complex_getitem(array_fixture):
@@ -235,7 +238,10 @@ def test_complex_getitem(array_fixture):
235238
expr += 2
236239
nres = ne_evaluate("tan(na1) * (sin(na2) * sin(na2) + cos(na3)) + (sqrt(na4) * 2) + 2")
237240
res = expr[:]
238-
np.testing.assert_allclose(res, nres)
241+
if na1.dtype == np.float32:
242+
np.testing.assert_allclose(res[:], nres, rtol=1e-5)
243+
else:
244+
np.testing.assert_allclose(res[:], nres)
239245

240246

241247
def test_complex_getitem_slice(array_fixture):
@@ -253,19 +259,23 @@ def test_func_expression(array_fixture):
253259
expr = (a1 + a2) * a3 - a4
254260
expr = blosc2.sin(expr) + blosc2.cos(expr)
255261
nres = ne_evaluate("sin((na1 + na2) * na3 - na4) + cos((na1 + na2) * na3 - na4)")
256-
res = expr.compute(storage={})
257-
np.testing.assert_allclose(res[:], nres)
262+
res = expr.compute()
263+
if na1.dtype == np.float32:
264+
np.testing.assert_allclose(res[:], nres, rtol=1e-5)
265+
else:
266+
np.testing.assert_allclose(res[:], nres)
258267

259268

260269
def test_expression_with_constants(array_fixture):
261270
a1, a2, a3, a4, na1, na2, na3, na4 = array_fixture
262271
# Test with operands with same chunks and blocks
263272
expr = a1 + 2 - a3 * 3.14
264273
nres = ne_evaluate("na1 + 2 - na3 * 3.14")
274+
res = expr.compute()
265275
if na1.dtype == np.float32:
266-
np.testing.assert_allclose(expr[:], nres, rtol=1e-6)
276+
np.testing.assert_allclose(res[:], nres, rtol=1e-5)
267277
else:
268-
np.testing.assert_allclose(expr[:], nres)
278+
np.testing.assert_allclose(res[:], nres)
269279

270280

271281
@pytest.mark.parametrize("compare_expressions", [True, False])

0 commit comments

Comments
 (0)