Skip to content

Commit 78bf78d

Browse files
author
Luke Shaw
committed
Ensure casting to SimpleProxy for non-string expressions
1 parent f344957 commit 78bf78d

6 files changed

Lines changed: 246 additions & 62 deletions

File tree

src/blosc2/lazyexpr.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2258,6 +2258,12 @@ def __init__(self, new_op): # noqa: C901
22582258
return
22592259
value1, op, value2 = new_op
22602260
dtype_ = check_dtype(op, value1, value2) # perform some checks
2261+
# Check that operands are proper Operands, NumPy arrays or scalars; if not, wrap them in a SimpleProxy
2262+
value1 = (
2263+
blosc2.SimpleProxy(value1)
2264+
if not (isinstance(value1, (blosc2.Operand, np.ndarray)) or np.isscalar(value1))
2265+
else value1
2266+
)
22612267
if value2 is None:
22622268
if isinstance(value1, LazyExpr):
22632269
self.expression = value1.expression if op is None else f"{op}({value1.expression})"
@@ -2266,7 +2272,12 @@ def __init__(self, new_op): # noqa: C901
22662272
self.operands = {"o0": value1}
22672273
self.expression = "o0" if op is None else f"{op}(o0)"
22682274
return
2269-
elif isinstance(value1, LazyExpr) or isinstance(value2, LazyExpr):
2275+
value2 = (
2276+
blosc2.SimpleProxy(value2)
2277+
if not (isinstance(value2, (blosc2.Operand, np.ndarray)) or np.isscalar(value2))
2278+
else value2
2279+
)
2280+
if isinstance(value1, LazyExpr) or isinstance(value2, LazyExpr):
22702281
if isinstance(value1, LazyExpr):
22712282
newexpr = value1.update_expr(new_op)
22722283
else:
@@ -2739,7 +2750,7 @@ def find_args(expr):
27392750

27402751
def _compute_expr(self, item, kwargs): # noqa : C901
27412752
# ne_evaluate will need safe_blosc2_globals for some functions (e.g. clip, logaddexp)
2742-
# that are implemenetd in python-blosc2 not in numexpr
2753+
# that are implemented in python-blosc2 not in numexpr
27432754
global safe_blosc2_globals
27442755
if len(safe_blosc2_globals) == 0:
27452756
# First eval call, fill blosc2_safe_globals for ne_evaluate
@@ -3011,7 +3022,7 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No
30113022
_operands = operands | local_vars
30123023
# Check that operands are proper Operands, LazyArray or scalars; if not, wrap them in SimpleProxy objects
30133024
for op, val in _operands.items():
3014-
if not (isinstance(val, (blosc2.Operand, blosc2.LazyArray, np.ndarray)) or np.isscalar(val)):
3025+
if not (isinstance(val, (blosc2.Operand, np.ndarray)) or np.isscalar(val)):
30153026
_operands[op] = blosc2.SimpleProxy(val)
30163027
# for scalars just return value (internally converts to () if necessary)
30173028
opshapes = {k: v if not hasattr(v, "shape") else v.shape for k, v in _operands.items()}

src/blosc2/linalg.py

Lines changed: 32 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010

1111
import blosc2
12-
from blosc2.ndarray import get_intersecting_chunks, npvecdot, slice_to_chunktuple
12+
from blosc2.ndarray import get_intersecting_chunks, nptranspose, npvecdot, slice_to_chunktuple
1313

1414
if TYPE_CHECKING:
1515
from collections.abc import Sequence
@@ -79,9 +79,8 @@ def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArra
7979
if np.isscalar(x1) or np.isscalar(x2):
8080
raise ValueError("Arguments can't be scalars.")
8181

82-
# Added this to pass array-api tests (which use internal getitem to check results)
83-
x1 = blosc2.asarray(x1)
84-
x2 = blosc2.asarray(x2)
82+
# Makes a SimpleProxy if inputs are not blosc2 arrays
83+
x1, x2 = blosc2.asarray(x1), blosc2.asarray(x2)
8584

8685
# Validate matrix multiplication compatibility
8786
if x1.shape[builtins.max(-1, -len(x2.shape))] != x2.shape[builtins.max(-2, -len(x2.shape))]:
@@ -183,9 +182,6 @@ def tensordot(
183182
"""
184183
fast_path = kwargs.pop("fast_path", None) # for testing purposes
185184
# TODO: add fast path for when we don't need to change chunkshapes
186-
# Added this to pass array-api tests (which use internal getitem to check results)
187-
if isinstance(x1, np.ndarray) and isinstance(x2, np.ndarray):
188-
return np.tensordot(x1, x2, axes=axes)
189185

190186
x1, x2 = blosc2.asarray(x1), blosc2.asarray(x2)
191187

@@ -261,24 +257,8 @@ def tensordot(
261257
a_selection = tuple(next(rchunk_iter) if a else slice(None, None, 1) for a in a_keep)
262258
b_selection = tuple(next(rchunk_iter) if b else slice(None, None, 1) for b in b_keep)
263259
res_chunks = tuple(s.stop - s.start for s in res_chunk)
264-
265-
if fast_path: # just load everything
266-
bx1 = x1[a_selection]
267-
bx2 = x2[b_selection]
268-
newshape_a = (
269-
math.prod([bx1.shape[i] for i in a_keep_axes]),
270-
math.prod([bx1.shape[a] for a in a_axes]),
271-
)
272-
newshape_b = (
273-
math.prod([bx2.shape[b] for b in b_axes]),
274-
math.prod([bx2.shape[i] for i in b_keep_axes]),
275-
)
276-
at = bx1.transpose(newaxes_a).reshape(newshape_a)
277-
bt = bx2.transpose(newaxes_b).reshape(newshape_b)
278-
res = np.dot(at, bt)
279-
result[res_chunk] += res.reshape(res_chunks)
280-
else: # operands too big, have to go chunk-by-chunk
281-
for ochunk in product(*op_chunks):
260+
for ochunk in product(*op_chunks):
261+
if not fast_path: # operands too big, have to go chunk-by-chunk
282262
op_chunk = tuple(
283263
slice(rc * rcs, builtins.min((rc + 1) * rcs, x1s), 1)
284264
for rc, rcs, x1s in zip(ochunk, a_chunks_red, a_shape_red, strict=True)
@@ -293,21 +273,23 @@ def tensordot(
293273
op_chunk[next(order_iter)] if not b else bs_
294274
for bs_, b in zip(b_selection, b_keep, strict=True)
295275
)
296-
bx1 = x1[a_selection]
297-
bx2 = x2[b_selection]
298-
# adapted from numpy tensordot
299-
newshape_a = (
300-
math.prod([bx1.shape[i] for i in a_keep_axes]),
301-
math.prod([bx1.shape[a] for a in a_axes]),
302-
)
303-
newshape_b = (
304-
math.prod([bx2.shape[b] for b in b_axes]),
305-
math.prod([bx2.shape[i] for i in b_keep_axes]),
306-
)
307-
at = bx1.transpose(newaxes_a).reshape(newshape_a)
308-
bt = bx2.transpose(newaxes_b).reshape(newshape_b)
309-
res = np.dot(at, bt)
310-
result[res_chunk] += res.reshape(res_chunks)
276+
bx1 = x1[a_selection]
277+
bx2 = x2[b_selection]
278+
# adapted from numpy tensordot
279+
newshape_a = (
280+
math.prod([bx1.shape[i] for i in a_keep_axes]),
281+
math.prod([bx1.shape[a] for a in a_axes]),
282+
)
283+
newshape_b = (
284+
math.prod([bx2.shape[b] for b in b_axes]),
285+
math.prod([bx2.shape[i] for i in b_keep_axes]),
286+
)
287+
at = nptranspose(bx1, newaxes_a).reshape(newshape_a)
288+
bt = nptranspose(bx2, newaxes_b).reshape(newshape_b)
289+
res = np.dot(at, bt)
290+
result[res_chunk] += res.reshape(res_chunks)
291+
if fast_path: # already done everything
292+
break
311293
return result
312294

313295

@@ -396,19 +378,17 @@ def vecdot(x1: blosc2.NDArray, x2: blosc2.NDArray, axis: int = -1, **kwargs) ->
396378
)
397379
b_selection = tuple(next(rchunk_iter) if b else slice(None, None, 1) for b in b_keep)
398380

399-
if fast_path: # just load everything, also handles case of 0 in shapes
400-
bx1 = x1[a_selection]
401-
bx2 = x2[b_selection]
402-
result[res_chunk] += npvecdot(bx1, bx2, axis=axis) # handles conjugation of bx1
403-
else: # operands too big, have to go chunk-by-chunk
404-
for ochunk in range(0, a_shape_red, a_chunks_red):
381+
for ochunk in range(0, a_shape_red, a_chunks_red):
382+
if not fast_path: # operands too big, go chunk-by-chunk
405383
op_chunk = (slice(ochunk, builtins.min(ochunk + a_chunks_red, x1.shape[a_axes]), 1),)
406384
a_selection = a_selection[:a_axes] + op_chunk + a_selection[a_axes + 1 :]
407385
b_selection = b_selection[:b_axes] + op_chunk + b_selection[b_axes + 1 :]
408-
bx1 = x1[a_selection]
409-
bx2 = x2[b_selection]
410-
res = npvecdot(bx1, bx2, axis=axis) # handles conjugation of bx1
411-
result[res_chunk] += res
386+
bx1 = x1[a_selection]
387+
bx2 = x2[b_selection]
388+
res = npvecdot(bx1, bx2, axis=axis) # handles conjugation of bx1
389+
result[res_chunk] += res
390+
if fast_path: # already done everything
391+
break
412392
return result
413393

414394

@@ -517,7 +497,7 @@ def permute_dims(
517497
src_slice = tuple(slice(start, stop) for start, stop in start_stop)
518498
dst_slice = tuple(slice(start_stop[ax][0], start_stop[ax][1]) for ax in axes)
519499

520-
transposed = np.transpose(arr[src_slice], axes=axes)
500+
transposed = nptranspose(arr[src_slice], axes=axes)
521501
result[dst_slice] = np.ascontiguousarray(transposed)
522502

523503
return result
@@ -648,6 +628,7 @@ def outer(x1: blosc2.blosc2.NDArray, x2: blosc2.blosc2.NDArray, **kwargs: Any) -
648628
out: blosc2.NDArray
649629
A two-dimensional array containing the outer product and whose shape is (N, M).
650630
"""
631+
x1, x2 = blosc2.asarray(x1), blosc2.asarray(x2)
651632
if (x1.ndim != 1) or (x2.ndim != 1):
652633
raise ValueError("outer only valid for 1D inputs.")
653634
return tensordot(x1, x2, ((), ()), **kwargs) # for testing purposes

src/blosc2/ndarray.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,12 @@
4242
nprshift = np.bitwise_right_shift
4343
npbinvert = np.bitwise_invert
4444
npvecdot = np.vecdot
45+
nptranspose = np.permute_dims
4546
else: # not array-api compliant
4647
nplshift = np.left_shift
4748
nprshift = np.right_shift
4849
npbinvert = np.bitwise_not
50+
nptranspose = np.transpose
4951

5052
def npvecdot(a, b, axis=-1):
5153
return np.einsum("...i,...i->...", np.moveaxis(np.conj(a), axis, -1), np.moveaxis(b, axis, -1))
@@ -2969,7 +2971,8 @@ def chunkwise_clip(inputs, output, offset):
29692971
x, min, max = inputs
29702972
output[:] = np.clip(x, min, max)
29712973

2972-
return blosc2.lazyudf(chunkwise_clip, (x, min, max), dtype=x.dtype, shape=x.shape, **kwargs)
2974+
dtype = blosc2.result_type(x)
2975+
return blosc2.lazyudf(chunkwise_clip, (x, min, max), dtype=dtype, shape=x.shape, **kwargs)
29732976

29742977

29752978
def logaddexp(x1: int | float | blosc2.Array, x2: int | float | blosc2.Array, **kwargs: Any) -> NDArray:
@@ -3001,7 +3004,7 @@ def chunkwise_logaddexp(inputs, output, offset):
30013004
x1, x2 = inputs
30023005
output[:] = np.logaddexp(x1, x2)
30033006

3004-
dtype = blosc2.result_type(x1.dtype, x2.dtype)
3007+
dtype = blosc2.result_type(x1, x2)
30053008
if dtype == blosc2.bool_:
30063009
raise TypeError("logaddexp doesn't accept boolean arguments.")
30073010

@@ -5653,7 +5656,8 @@ def asarray(array: Sequence | blosc2.Array, copy: bool | None = None, **kwargs:
56535656
raise ValueError("Only unsafe casting is supported at the moment.")
56545657
if not hasattr(array, "shape"):
56555658
array = np.asarray(array) # defaults if dtype=None
5656-
dtype = kwargs.pop("dtype", array.dtype) # check if dtype provided
5659+
dtype_ = blosc2.proxy._convert_dtype(array.dtype)
5660+
dtype = kwargs.pop("dtype", dtype_) # check if dtype provided
56575661
kwargs = _check_ndarray_kwargs(**kwargs)
56585662
chunks = kwargs.pop("chunks", None)
56595663
blocks = kwargs.pop("blocks", None)
@@ -5664,14 +5668,13 @@ def asarray(array: Sequence | blosc2.Array, copy: bool | None = None, **kwargs:
56645668
# Let's avoid this
56655669
if blocks is None and hasattr(array, "blocks") and isinstance(array.blocks, (tuple, list)):
56665670
blocks = array.blocks
5667-
chunks, blocks = compute_chunks_blocks(array.shape, chunks, blocks, array.dtype, **kwargs)
56685671

5669-
copy = True if copy is None and not isinstance(array, NDArray) else copy
56705672
if copy:
5673+
chunks, blocks = compute_chunks_blocks(array.shape, chunks, blocks, dtype_, **kwargs)
56715674
# Fast path for small arrays. This is not too expensive in terms of memory consumption.
56725675
shape = array.shape
56735676
small_size = 2**24 # 16 MB
5674-
array_nbytes = math.prod(shape) * array.dtype.itemsize
5677+
array_nbytes = math.prod(shape) * dtype_.itemsize
56755678
if array_nbytes < small_size:
56765679
if not isinstance(array, np.ndarray) and hasattr(array, "chunks"):
56775680
# A getitem operation should be enough to get a numpy array
@@ -5682,7 +5685,7 @@ def asarray(array: Sequence | blosc2.Array, copy: bool | None = None, **kwargs:
56825685
return blosc2_ext.asarray(array, chunks, blocks, **kwargs)
56835686

56845687
# Create the empty array
5685-
ndarr = empty(shape, array.dtype, chunks=chunks, blocks=blocks, **kwargs)
5688+
ndarr = empty(shape, dtype_, chunks=chunks, blocks=blocks, **kwargs)
56865689
behaved = are_partitions_behaved(shape, chunks, blocks)
56875690

56885691
# Get the coordinates of the chunks
@@ -5705,7 +5708,7 @@ def asarray(array: Sequence | blosc2.Array, copy: bool | None = None, **kwargs:
57055708
ndarr[slice_] = array_slice
57065709
else:
57075710
if not isinstance(array, NDArray):
5708-
raise ValueError("Must always do a copy for asarray unless NDArray provided.")
5711+
return blosc2.SimpleProxy(array, chunks, blocks)
57095712
# TODO: make a direct view possible
57105713
return array
57115714

src/blosc2/proxy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,11 @@ def dtype(self):
649649
"""The data type of the source array."""
650650
return self._dtype
651651

652+
@property
653+
def ndim(self):
654+
"""The number of dimensions of the source array."""
655+
return len(self.shape)
656+
652657
def __getitem__(self, item: slice | list[slice]) -> np.ndarray:
653658
"""
654659
Get a slice as a numpy.ndarray (via this proxy).

0 commit comments

Comments
 (0)