Skip to content

Commit 08773e0

Browse files
author
Luke Shaw
committed
Add support for scalar lazyexprs
1 parent 1984bdf commit 08773e0

2 files changed

Lines changed: 27 additions & 10 deletions

File tree

src/blosc2/lazyexpr.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2274,6 +2274,9 @@ def __init__(self, new_op): # noqa: C901
22742274
self.expression = value1.expression if op is None else f"{op}({value1.expression})"
22752275
self.operands = value1.operands
22762276
else:
2277+
if np.isscalar(value1):
2278+
value1 = ne_evaluate(f"{op}({value1})")
2279+
op = None
22772280
self.operands = {"o0": value1}
22782281
self.expression = "o0" if op is None else f"{op}(o0)"
22792282
return
@@ -2293,7 +2296,8 @@ def __init__(self, new_op): # noqa: C901
22932296
return
22942297
elif op in funcs_2args:
22952298
if np.isscalar(value1) and np.isscalar(value2):
2296-
self.expression = f"{op}({value1}, {value2})"
2299+
self.expression = "o0"
2300+
self.operands = {"o0": ne_evaluate(f"{op}({value1}, {value2})")} # eager evaluation
22972301
elif np.isscalar(value2):
22982302
self.operands = {"o0": value1}
22992303
self.expression = f"{op}(o0, {value2})"
@@ -2307,7 +2311,8 @@ def __init__(self, new_op): # noqa: C901
23072311

23082312
self._dtype = dtype_
23092313
if np.isscalar(value1) and np.isscalar(value2):
2310-
self.expression = f"({value1} {op} {value2})"
2314+
self.expression = "o0"
2315+
self.operands = {"o0": ne_evaluate(f"({value1} {op} {value2})")} # eager evaluation
23112316
elif np.isscalar(value2):
23122317
self.operands = {"o0": value1}
23132318
self.expression = f"(o0 {op} {value2})"

tests/ndarray/test_lazyexpr.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -359,12 +359,7 @@ def test_functions(function, dtype_fixture, shape_fixture):
359359
)
360360
@pytest.mark.parametrize(
361361
("value1", "value2"),
362-
[
363-
("NDArray", "scalar"),
364-
("NDArray", "NDArray"),
365-
("scalar", "NDArray"),
366-
# ("scalar", "scalar") # Not supported by LazyExpr
367-
],
362+
[("NDArray", "scalar"), ("NDArray", "NDArray"), ("scalar", "NDArray"), ("scalar", "scalar")],
368363
)
369364
def test_arctan2_pow(urlpath, shape_fixture, dtype_fixture, function, value1, value2):
370365
nelems = np.prod(shape_fixture)
@@ -406,7 +401,7 @@ def test_arctan2_pow(urlpath, shape_fixture, dtype_fixture, function, value1, va
406401
else:
407402
expr_string = f"{function}(na1, value2)"
408403
res_numexpr = ne_evaluate(expr_string)
409-
else: # ("scalar", "NDArray")
404+
elif value2 == "NDArray": # ("scalar", "NDArray")
410405
value1 = 12
411406
na2 = np.linspace(0, 10, nelems, dtype=dtype_fixture).reshape(shape_fixture)
412407
a2 = blosc2.asarray(na2, urlpath=urlpath2, mode="w")
@@ -422,9 +417,21 @@ def test_arctan2_pow(urlpath, shape_fixture, dtype_fixture, function, value1, va
422417
else:
423418
expr_string = f"{function}(value1, na2)"
424419
res_numexpr = ne_evaluate(expr_string)
420+
else: # ("scalar", "scalar")
421+
value1 = 12
422+
value2 = 3
423+
# Construct the lazy expression based on the function name
424+
expr = blosc2.LazyExpr(new_op=(value1, function, value2))
425+
res_lazyexpr = expr.compute()
426+
# Evaluate using NumExpr
427+
if function == "**":
428+
res_numexpr = ne_evaluate("value1**value2")
429+
else:
430+
expr_string = f"{function}(value1, value2)"
431+
res_numexpr = ne_evaluate(expr_string)
425432
# Compare the results
426433
tol = 1e-15 if dtype_fixture == "float64" else 1e-6
427-
np.testing.assert_allclose(res_lazyexpr[:], res_numexpr, atol=tol, rtol=tol)
434+
np.testing.assert_allclose(res_lazyexpr[()], res_numexpr, atol=tol, rtol=tol)
428435

429436
for path in [urlpath1, urlpath2, urlpath_save]:
430437
blosc2.remove_urlpath(path)
@@ -1511,6 +1518,11 @@ def test_scalar_dtypes(values):
15111518
dtype2 = (avalue1 * avalue2).dtype
15121519
assert dtype1 == dtype2, f"Expected {dtype1} but got {dtype2}"
15131520

1521+
# test scalars
1522+
value = value1 if np.isscalar(value1) else value2
1523+
assert blosc2.sin(value)[()] == np.sin(value)
1524+
assert (value + blosc2.sin(value))[()] == value + np.sin(value)
1525+
15141526

15151527
def test_to_cframe():
15161528
N = 1_000

0 commit comments

Comments
 (0)