Skip to content

Commit 61ef8f4

Browse files
committed
Add a test for reductions with where
1 parent 423a448 commit 61ef8f4

1 file changed

Lines changed: 13 additions & 0 deletions

File tree

tests/ndarray/test_reductions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,19 @@ def test_reduce_bool(array_fixture, reduce_op):
6565
np.testing.assert_allclose(res, nres, atol=tol, rtol=tol)
6666

6767

68+
def test_reduce_where(array_fixture):
69+
a1, a2, a3, a4, na1, na2, na3, na4 = array_fixture
70+
# The next works
71+
# res = blosc2.where(a1 < a2, a2, 0).sum()
72+
# nres = ne_evaluate("sum(where(na1 < na2, na2, 0))")
73+
# This does not work yet (it currently hangs)
74+
res = blosc2.where(a1 < a2, a2, a1).sum()
75+
nres = ne_evaluate("sum(where(na1 < na2, na2, na1))")
76+
print("res:", res, nres)
77+
tol = 1e-15 if a1.dtype == "float64" else 1e-6
78+
np.testing.assert_allclose(res, nres, atol=tol, rtol=tol)
79+
80+
6881
@pytest.mark.parametrize(
6982
"reduce_op", ["sum", "prod", "mean", "std", "var", "min", "max", "any", "all", "argmax", "argmin"]
7083
)

0 commit comments

Comments
 (0)