Skip to content

Commit 33d6a5f

Browse files
committed
Handle foreign inputs to logaddexp, clip
1 parent e1307bf commit 33d6a5f

7 files changed

Lines changed: 59 additions & 27 deletions

File tree

ADD_LAZYFUNCS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ Once you have written a (public API) function in Blosc2, it is important to:
55
* Add it to the list of functions in ``__all__`` in the ``__init__.py`` file
66
* If it is present in numpy, add it to the relevant dictionary (``local_ufunc_map``, ``ufunc_map`` ``ufunc_map_1param``) in ``ndarray.py``
77

8+
If your function is implemented at the Blosc2 level (and not via either the `LazyUDF` or `LazyExpr` classes), you will need to add some conversion of the inputs to SimpleProxy instances (see e.g. ``matmul`` for an example).
9+
810
Finally, you also need to deal with it correctly within ``shape_utils.py``.
911

1012
If the function does not change the shape of the output, simply add it to ``elementwise_funcs`` and you're done.

src/blosc2/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ def _raise(exc):
445445
result_type,
446446
can_cast,
447447
)
448-
from .proxy import Proxy, ProxySource, ProxyNDSource, ProxyNDField, SimpleProxy, jit
448+
from .proxy import Proxy, ProxySource, ProxyNDSource, ProxyNDField, SimpleProxy, jit, as_simpleproxy
449449

450450
from .schunk import SChunk, open
451451
from . import linalg
@@ -648,6 +648,7 @@ def _raise(exc):
648648
"asarray",
649649
"asin",
650650
"asinh",
651+
"as_simpleproxy",
651652
"astype",
652653
"atan",
653654
"atan2",

src/blosc2/lazyexpr.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -434,13 +434,13 @@ def convert_inputs(inputs):
434434
return []
435435
inputs_ = []
436436
for obj in inputs:
437-
if not isinstance(obj, blosc2.Array) and not np.isscalar(obj):
437+
if not isinstance(obj, (np.ndarray, blosc2.Operand)) and not np.isscalar(obj):
438438
try:
439-
obj = np.asarray(obj)
439+
obj = blosc2.SimpleProxy(obj)
440440
except Exception:
441441
print(
442442
"Inputs not being np.ndarray, Array or Python scalar objects"
443-
" should be convertible to np.ndarray."
443+
" should be convertible to SimpleProxy."
444444
)
445445
raise
446446
inputs_.append(obj)
@@ -1077,9 +1077,9 @@ def fill_chunk_operands( # noqa: C901
10771077
if nchunk == 0:
10781078
# Initialize the iterator for reading the chunks
10791079
# Take any operand (all should have the same shape and chunks)
1080-
arr = next(iter(operands.values()))
1080+
key, arr = next(iter(operands.items()))
10811081
chunks_idx, _ = get_chunks_idx(arr.shape, arr.chunks)
1082-
info = (reduc, aligned, low_mem, chunks_idx)
1082+
info = (reduc, aligned[key], low_mem, chunks_idx)
10831083
iter_chunks = read_nchunk(list(operands.values()), info)
10841084
# Run the asynchronous file reading function from a synchronous context
10851085
chunks = next(iter_chunks)
@@ -1095,7 +1095,7 @@ def fill_chunk_operands( # noqa: C901
10951095
# The chunk is a special zero chunk, so we can treat it as a scalar
10961096
chunk_operands[key] = np.zeros((), dtype=value.dtype)
10971097
continue
1098-
if aligned:
1098+
if aligned[key]:
10991099
buff = blosc2.decompress2(chunks[i])
11001100
bsize = value.dtype.itemsize * math.prod(chunks_)
11011101
chunk_operands[key] = np.frombuffer(buff[:bsize], dtype=value.dtype).reshape(chunks_)
@@ -1115,10 +1115,6 @@ def fill_chunk_operands( # noqa: C901
11151115
chunk_operands[key] = value[()]
11161116
continue
11171117

1118-
if isinstance(value, np.ndarray | blosc2.C2Array):
1119-
chunk_operands[key] = value[slice_]
1120-
continue
1121-
11221118
if not full_chunk or not isinstance(value, blosc2.NDArray):
11231119
# The chunk is not a full one, or has padding, or is not a blosc2.NDArray,
11241120
# so we need to go the slow path
@@ -1143,7 +1139,7 @@ def fill_chunk_operands( # noqa: C901
11431139
value.get_slice_numpy(chunk_operands[key], (starts, stops))
11441140
continue
11451141

1146-
if aligned:
1142+
if aligned[key]:
11471143
# Decompress the whole chunk and store it
11481144
buff = value.schunk.decompress_chunk(nchunk)
11491145
bsize = value.dtype.itemsize * math.prod(chunks_)
@@ -1203,7 +1199,10 @@ def fast_eval( # noqa: C901
12031199
if blocks is None:
12041200
blocks = basearr.blocks
12051201
# Check whether the partitions are aligned and behaved
1206-
aligned = blosc2.are_partitions_aligned(shape, chunks, blocks)
1202+
aligned = {
1203+
k: False if not hasattr(k, "chunks") else blosc2.are_partitions_aligned(k.shape, k.chunks, k.blocks)
1204+
for k in operands
1205+
}
12071206
behaved = blosc2.are_partitions_behaved(shape, chunks, blocks)
12081207

12091208
# Check that all operands are NDArray for fast path
@@ -1227,7 +1226,7 @@ def fast_eval( # noqa: C901
12271226
offset = tuple(s.start for s in cslice) # offset for the udf
12281227
chunks_ = tuple(s.stop - s.start for s in cslice)
12291228

1230-
full_chunk = chunks_ == chunks
1229+
full_chunk = chunks_ == chunks # slice is same as chunk
12311230
fill_chunk_operands(
12321231
operands, cslice, chunks_, full_chunk, aligned, nchunk, iter_disk, chunk_operands
12331232
)
@@ -1811,7 +1810,7 @@ def reduce_slices( # noqa: C901
18111810
same_chunks = all(operand.chunks == o.chunks for o in operands.values() if hasattr(o, "chunks"))
18121811
same_blocks = all(operand.blocks == o.blocks for o in operands.values() if hasattr(o, "blocks"))
18131812
fast_path = same_shape and same_chunks and same_blocks and (0 not in operand.chunks)
1814-
aligned, iter_disk = False, False
1813+
aligned, iter_disk = dict.fromkeys(operands.keys(), False), False
18151814
if fast_path:
18161815
chunks = operand.chunks
18171816
# Check that all operands are NDArray for fast path

src/blosc2/linalg.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArra
8080
raise ValueError("Arguments can't be scalars.")
8181

8282
# Makes a SimpleProxy if inputs are not blosc2 arrays
83-
x1, x2 = blosc2.asarray(x1), blosc2.asarray(x2)
83+
x1, x2 = blosc2.as_simpleproxy(x1), blosc2.as_simpleproxy(x2)
8484

8585
# Validate matrix multiplication compatibility
8686
if x1.shape[builtins.max(-1, -len(x2.shape))] != x2.shape[builtins.max(-2, -len(x2.shape))]:
@@ -183,7 +183,8 @@ def tensordot(
183183
fast_path = kwargs.pop("fast_path", None) # for testing purposes
184184
# TODO: add fast path for when don't need to change chunkshapes
185185

186-
x1, x2 = blosc2.asarray(x1), blosc2.asarray(x2)
186+
# Makes a SimpleProxy if inputs are not blosc2 arrays
187+
x1, x2 = blosc2.as_simpleproxy(x1), blosc2.as_simpleproxy(x2)
187188

188189
if isinstance(axes, tuple):
189190
a_axes, b_axes = axes
@@ -324,7 +325,8 @@ def vecdot(x1: blosc2.NDArray, x2: blosc2.NDArray, axis: int = -1, **kwargs) ->
324325
if isinstance(x1, np.ndarray) and isinstance(x2, np.ndarray):
325326
return npvecdot(x1, x2, axis=axis)
326327

327-
x1, x2 = blosc2.asarray(x1), blosc2.asarray(x2)
328+
# Makes a SimpleProxy if inputs are not blosc2 arrays
329+
x1, x2 = blosc2.as_simpleproxy(x1), blosc2.as_simpleproxy(x2)
328330

329331
N = builtins.min(x1.ndim, x2.ndim)
330332
if axis < -N or axis > -1:
@@ -466,8 +468,9 @@ def permute_dims(
466468
"""
467469
if np.isscalar(arr) or arr.ndim < 2:
468470
return arr
469-
if isinstance(arr, np.ndarray): # for array-api test compliance (does getitem for comparison)
470-
return np.permute_dims(arr, axes)
471+
472+
# Makes a SimpleProxy if input is not blosc2 array
473+
arr = blosc2.as_simpleproxy(arr)
471474

472475
ndim = arr.ndim
473476

@@ -535,7 +538,8 @@ def transpose(x, **kwargs: Any) -> blosc2.NDArray:
535538
# If arguments are dimension < 2, they are returned
536539
if np.isscalar(x) or x.ndim < 2:
537540
return x
538-
541+
# Makes a SimpleProxy if input is not blosc2 array
542+
x = blosc2.as_simpleproxy(x)
539543
# Validate arguments are dimension 2
540544
if x.ndim > 2:
541545
raise ValueError("Transposing arrays with dimension greater than 2 is not supported yet.")
@@ -559,6 +563,8 @@ def matrix_transpose(arr: blosc2.Array, **kwargs: Any) -> blosc2.NDArray:
559563
``(..., N, M)``.
560564
"""
561565
axes = None
566+
# Makes a SimpleProxy if input is not blosc2 array
567+
arr = blosc2.as_simpleproxy(arr)
562568
if not np.isscalar(arr) and arr.ndim > 2:
563569
axes = list(range(arr.ndim))
564570
axes[-2], axes[-1] = axes[-1], axes[-2]
@@ -592,6 +598,8 @@ def diagonal(x: blosc2.blosc2.NDArray, offset: int = 0) -> blosc2.blosc2.NDArray
592598
593599
Reference: https://data-apis.org/array-api/latest/extensions/generated/array_api.linalg.diag.html#diag
594600
"""
601+
# Makes a SimpleProxy if input is not blosc2 array
602+
x = blosc2.as_simpleproxy(x)
595603
n_rows, n_cols = x.shape[-2:]
596604
min_idx = builtins.min(n_rows, n_cols)
597605
if offset < 0:
@@ -628,7 +636,7 @@ def outer(x1: blosc2.blosc2.NDArray, x2: blosc2.blosc2.NDArray, **kwargs: Any) -
628636
out: blosc2.NDArray
629637
A two-dimensional array containing the outer product and whose shape is (N, M).
630638
"""
631-
x1, x2 = blosc2.asarray(x1), blosc2.asarray(x2)
639+
x1, x2 = blosc2.as_simpleproxy(x1), blosc2.as_simpleproxy(x2)
632640
if (x1.ndim != 1) or (x2.ndim != 1):
633641
raise ValueError("outer only valid for 1D inputs.")
634642
return tensordot(x1, x2, ((), ()), **kwargs) # for testing purposes

src/blosc2/ndarray.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5669,6 +5669,7 @@ def asarray(array: Sequence | blosc2.Array, copy: bool | None = None, **kwargs:
56695669
if blocks is None and hasattr(array, "blocks") and isinstance(array.blocks, (tuple, list)):
56705670
blocks = array.blocks
56715671

5672+
copy = True if copy is None and not isinstance(array, NDArray) else copy
56725673
if copy:
56735674
chunks, blocks = compute_chunks_blocks(array.shape, chunks, blocks, dtype_, **kwargs)
56745675
# Fast path for small arrays. This is not too expensive in terms of memory consumption.
@@ -5708,7 +5709,7 @@ def asarray(array: Sequence | blosc2.Array, copy: bool | None = None, **kwargs:
57085709
ndarr[slice_] = array_slice
57095710
else:
57105711
if not isinstance(array, NDArray):
5711-
return blosc2.SimpleProxy(array, chunks, blocks)
5712+
raise ValueError("Must always do a copy for asarray unless NDArray provided.")
57125713
# TODO: make a direct view possible
57135714
return array
57145715

src/blosc2/proxy.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,27 @@ def __getitem__(self, item: slice | list[slice]) -> np.ndarray:
670670
return np.asarray(self._src[item]) # avoids copy for PyTorch at least
671671

672672

673+
def as_simpleproxy(x: blosc2.Array) -> SimpleProxy | blosc2.Operand:
674+
"""
675+
Convert an Array object which fulfills Array protocol into SimpleProxy. If x is already a
676+
blosc2.Operand simply returns object.
677+
678+
Parameters
679+
----------
680+
x: blosc2.Array
681+
Object fulfilling Array protocol.
682+
683+
Returns
684+
-------
685+
out: blosc2.SimpleProxy | blosc2.Operand
686+
Object with minimal interface for blosc2 LazyExpr computations.
687+
"""
688+
if isinstance(x, blosc2.Operand):
689+
return x
690+
else:
691+
return SimpleProxy(x)
692+
693+
673694
def jit(func=None, *, out=None, disable=False, **kwargs): # noqa: C901
674695
"""
675696
Prepare a function so that it can be used with the Blosc2 compute engine.

tests/ndarray/test_elementwise_funcs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def _test_unary_func_proxy(np_func, blosc_func, dtype, shape, xp):
188188
a_blosc[tuple(i // 4 for i in shape)] = 1 + 1j
189189
a_blosc[tuple(i // 2 for i in shape)] = xp.nan + xp.nan * 1j
190190
if dtype == blosc2.bool_ and np_func.__name__ == "arctanh":
191-
a_blosc = xp.zeros(shape=shape, dtype=dtype_)
191+
a_blosc = xp.zeros(shape, dtype=dtype_)
192192

193193
arr = np.asarray(a_blosc)
194194
success = False
@@ -223,7 +223,7 @@ def _test_binary_func_impl(np_func, blosc_func, dtype, shape, chunkshape): # no
223223
1, stop=np.prod(shape), num=np.prod(shape), chunks=chunkshape, shape=shape, dtype=dtype
224224
)
225225
if np_func.__name__ in ("right_shift", "left_shift"):
226-
a_blosc2 = blosc2.asarray(2)
226+
a_blosc2 = blosc2.asarray(2, copy=True)
227227
else:
228228
a_blosc2 = blosc2.linspace(
229229
start=np.prod(shape) * 2,
@@ -304,8 +304,8 @@ def test_unary_funcs(np_func, blosc_func, dtype, shape, chunkshape):
304304
@pytest.mark.parametrize("dtype", STR_DTYPES)
305305
@pytest.mark.parametrize("shape", [(10,), (20, 20)])
306306
@pytest.mark.parametrize("xp", [torch])
307-
def test_unfuncs_proxy(np_func, blosc_func, dtype, shape, chunkshape, xp):
308-
_test_unary_func_proxy(np_func, blosc_func, dtype, shape, chunkshape, xp)
307+
def test_unfuncs_proxy(np_func, blosc_func, dtype, shape, xp):
308+
_test_unary_func_proxy(np_func, blosc_func, dtype, shape, xp)
309309

310310

311311
@pytest.mark.heavy

0 commit comments

Comments
 (0)