Skip to content

Commit ef132c1

Browse files
committed
Add UnaryOp handling to infer_shape
1 parent ed4f37b commit ef132c1

5 files changed

Lines changed: 67 additions & 3 deletions

File tree

src/blosc2/lazyexpr.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
process_key,
4848
)
4949

50+
from .proxy import _convert_dtype
5051
from .shape_utils import constructors, elementwise_funcs, infer_shape, linalg_attrs, linalg_funcs, reducers
5152

5253
if not blosc2.IS_WASM:
@@ -2213,7 +2214,9 @@ def result_type(
22132214
# Follow NumPy rules for scalar-array operations
22142215
# Create small arrays with the same dtypes and let NumPy's type promotion determine the result type
22152216
arrs = [
2216-
value if (np.isscalar(value) or not hasattr(value, "dtype")) else np.array([0], dtype=value.dtype)
2217+
value
2218+
if (np.isscalar(value) or not hasattr(value, "dtype"))
2219+
else np.array([0], dtype=_convert_dtype(value.dtype))
22172220
for value in arrays_and_dtypes
22182221
]
22192222
return np.result_type(*arrs)

src/blosc2/proxy.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,8 @@ def _convert_dtype(dt: str | DTypeLike):
579579
"""
580580
Attempts to convert to blosc2.dtype (i.e. numpy dtype)
581581
"""
582+
if hasattr(dt, "as_numpy_dtype"):
583+
dt = dt.as_numpy_dtype
582584
try:
583585
return np.dtype(dt)
584586
except TypeError: # likely passed e.g. a torch.float64
@@ -616,7 +618,7 @@ def __init__(self, src, chunks: tuple | None = None, blocks: tuple | None = None
616618
raise TypeError("The source must have a __getitem__ method")
617619
self._src = src
618620
self._dtype = _convert_dtype(src.dtype)
619-
self._shape = src.shape
621+
self._shape = src.shape if isinstance(src.shape, tuple) else tuple(src.shape)
620622
# Compute reasonable values for chunks and blocks
621623
cparams = blosc2.CParams(clevel=0)
622624

src/blosc2/shape_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,9 @@ def visit_BinOp(self, node):
520520
right = self.visit(node.right)
521521
return elementwise(left, right)
522522

523+
def visit_UnaryOp(self, node):
524+
return self.visit(node.operand)
525+
523526
def _eval_slice(self, node):
524527
if isinstance(node, ast.Slice):
525528
return slice(

tests/ndarray/test_lazyexpr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1195,7 +1195,7 @@ def test_fill_disk_operands(chunks, blocks, disk, fill_value):
11951195
b = blosc2.open("b.b2nd")
11961196
c = blosc2.open("c.b2nd")
11971197

1198-
expr = ((a**3 + blosc2.sin(c * 2)) < b) & (c > 0)
1198+
expr = ((a**3 + blosc2.sin(c * 2)) < b) & ~(c > 0)
11991199

12001200
out = expr.compute()
12011201
assert out.shape == (N, N)

tests/ndarray/test_proxy.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,15 @@
66
# LICENSE file in the root directory of this source tree)
77
#######################################################################
88

9+
import dask.array as da
10+
import jax.numpy as jnp
911
import numpy as np
1012
import pytest
13+
import tensorflow as tf
14+
import torch
15+
16+
# TODO: import cupy as cp
17+
import zarr
1118

1219
import blosc2
1320
from blosc2.ndarray import get_chunks_idx
@@ -172,3 +179,52 @@ def get_chunk(self, nchunk):
172179
proxy = blosc2.Proxy(source)
173180
result = proxy[...]
174181
np.testing.assert_array_equal(result, data)
182+
183+
184+
@pytest.mark.parametrize(
185+
"xp",
186+
[torch, tf, np, jnp, da, zarr],
187+
)
188+
@pytest.mark.parametrize(
189+
"dtype",
190+
["bool", "int32", "int64", "float32", "float64", "complex128"],
191+
)
192+
def test_simpleproxy(xp, dtype):
193+
dtype_ = getattr(xp, dtype) if hasattr(xp, dtype) else np.dtype(dtype)
194+
if dtype == "bool":
195+
blosc_matrix = blosc2.asarray([True, False, False], dtype=np.dtype(dtype), chunks=(2,))
196+
foreign_matrix = xp.zeros((3,), dtype=dtype_)
197+
# Create a lazy expression object
198+
lexpr = blosc2.lazyexpr(
199+
"(b & a) | (~b)", operands={"a": blosc_matrix, "b": foreign_matrix}
200+
) # this does not
201+
# Compare with numpy computation result
202+
npb = np.asarray(foreign_matrix)
203+
npa = blosc_matrix[()]
204+
res = (npb & npa) | np.logical_not(npb)
205+
else:
206+
N = 10
207+
shape_a = (N, N, N)
208+
blosc_matrix = blosc2.full(shape=shape_a, fill_value=3, dtype=np.dtype(dtype), chunks=(N // 3,) * 3)
209+
foreign_matrix = xp.ones(shape_a, dtype=dtype_)
210+
if dtype == "complex128":
211+
foreign_matrix += 0.5j
212+
blosc_matrix = blosc2.full(
213+
shape=shape_a, fill_value=3 + 2j, dtype=np.dtype(dtype), chunks=(N // 3,) * 3
214+
)
215+
216+
# Create a lazy expression object
217+
lexpr = blosc2.lazyexpr(
218+
"b + sin(a) + sum(b) - tensordot(a, b, axes=1)",
219+
operands={"a": blosc_matrix, "b": foreign_matrix},
220+
) # this does not
221+
# Compare with numpy computation result
222+
npb = np.asarray(foreign_matrix)
223+
npa = blosc_matrix[()]
224+
res = npb + np.sin(npa) + np.sum(npb) - np.tensordot(npa, npb, axes=1)
225+
226+
# Test object metadata and result
227+
assert isinstance(lexpr, blosc2.LazyExpr)
228+
assert lexpr.dtype == res.dtype
229+
assert lexpr.shape == res.shape
230+
np.testing.assert_array_equal(lexpr[()], res)

0 commit comments

Comments
 (0)