Skip to content

Commit a7096e5

Browse files
committed
Add linalg tests, modify squeeze for array api compliance
1 parent 9057c20 commit a7096e5

7 files changed

Lines changed: 155 additions & 89 deletions

File tree

src/blosc2/lazyexpr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2900,7 +2900,7 @@ def __getitem__(self, item):
29002900
# Squeeze single-element dimensions when indexing with integers
29012901
# See e.g. examples/ndarray/animated_plot.py
29022902
if isinstance(item, int) or (hasattr(item, "__iter__") and any(isinstance(i, int) for i in item)):
2903-
result = result.squeeze()
2903+
result = result.squeeze(axis=tuple(i for i in range(result.ndim) if result.shape[i] == 1))
29042904
return result
29052905

29062906
def slice(self, item):

src/blosc2/linalg.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArra
8080
raise ValueError("Arguments can't be scalars.")
8181

8282
# Makes a SimpleProxy if inputs are not blosc2 arrays
83-
x1, x2 = blosc2.as_simpleproxy(x1), blosc2.as_simpleproxy(x2)
83+
x1, x2 = blosc2.as_simpleproxy(x1, x2)
8484

8585
# Validate matrix multiplication compatibility
8686
if x1.shape[builtins.max(-1, -len(x2.shape))] != x2.shape[builtins.max(-2, -len(x2.shape))]:
@@ -184,7 +184,7 @@ def tensordot(
184184
# TODO: add fast path for when we don't need to change chunkshapes
185185

186186
# Makes a SimpleProxy if inputs are not blosc2 arrays
187-
x1, x2 = blosc2.as_simpleproxy(x1), blosc2.as_simpleproxy(x2)
187+
x1, x2 = blosc2.as_simpleproxy(x1, x2)
188188

189189
if isinstance(axes, tuple):
190190
a_axes, b_axes = axes
@@ -326,7 +326,7 @@ def vecdot(x1: blosc2.NDArray, x2: blosc2.NDArray, axis: int = -1, **kwargs) ->
326326
return npvecdot(x1, x2, axis=axis)
327327

328328
# Makes a SimpleProxy if inputs are not blosc2 arrays
329-
x1, x2 = blosc2.as_simpleproxy(x1), blosc2.as_simpleproxy(x2)
329+
x1, x2 = blosc2.as_simpleproxy(x1, x2)
330330

331331
N = builtins.min(x1.ndim, x2.ndim)
332332
if axis < -N or axis > -1:
@@ -489,9 +489,15 @@ def permute_dims(
489489

490490
chunks = arr.chunks
491491
shape = arr.shape
492-
493-
for info in arr.iterchunks_info():
494-
coords = info.coords
492+
# handle SimpleProxy which doesn't have iterchunks_info
493+
if hasattr(arr, "iterchunks_info"):
494+
my_it = arr.iterchunks_info()
495+
_get_el = lambda x: x.coords # noqa: E731
496+
else:
497+
my_it = get_intersecting_chunks((), shape, chunks)
498+
_get_el = lambda x: x.raw # noqa: E731
499+
for info in my_it:
500+
coords = _get_el(info)
495501
start_stop = [
496502
(coord * chunk, builtins.min(chunk * (coord + 1), dim))
497503
for coord, chunk, dim in zip(coords, chunks, shape, strict=False)
@@ -636,7 +642,7 @@ def outer(x1: blosc2.blosc2.NDArray, x2: blosc2.blosc2.NDArray, **kwargs: Any) -
636642
out: blosc2.NDArray
637643
A two-dimensional array containing the outer product and whose shape is (N, M).
638644
"""
639-
x1, x2 = blosc2.as_simpleproxy(x1), blosc2.as_simpleproxy(x2)
645+
x1, x2 = blosc2.as_simpleproxy(x1, x2)
640646
if (x1.ndim != 1) or (x2.ndim != 1):
641647
raise ValueError("outer only valid for 1D inputs.")
642648
return tensordot(x1, x2, ((), ()), **kwargs) # for testing purposes

src/blosc2/ndarray.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4642,7 +4642,7 @@ def slice(self, key: int | slice | Sequence[slice], **kwargs: Any) -> NDArray:
46424642

46434643
return ndslice
46444644

4645-
def squeeze(self, axis=None) -> NDArray:
4645+
def squeeze(self, axis: int | Sequence[int]) -> NDArray:
46464646
"""Remove single-dimensional entries from the shape of the array.
46474647
46484648
This method modifies the array in-place, removing the dimensions (of size 1) selected by axis.
@@ -4666,18 +4666,15 @@ def squeeze(self, axis=None) -> NDArray:
46664666
>>> a.shape
46674667
(23, 11)
46684668
"""
4669-
if axis is None:
4670-
super().squeeze()
4671-
else:
4672-
axis = [axis] if isinstance(axis, int) else axis
4673-
mask = [False for i in range(self.ndim)]
4674-
for a in axis:
4675-
if a < 0:
4676-
a += self.ndim # Adjust axis to be within the array's dimensions
4677-
if mask[a]:
4678-
raise ValueError("Axis values must be unique.")
4679-
mask[a] = True
4680-
super().squeeze(mask=mask)
4669+
axis = [axis] if isinstance(axis, int) else axis
4670+
mask = [False for i in range(self.ndim)]
4671+
for a in axis:
4672+
if a < 0:
4673+
a += self.ndim # Adjust axis to be within the array's dimensions
4674+
if mask[a]:
4675+
raise ValueError("Axis values must be unique.")
4676+
mask[a] = True
4677+
super().squeeze(mask=mask)
46814678
return self
46824679

46834680
def indices(self, order: str | list[str] | None = None, **kwargs: Any) -> NDArray:
@@ -4711,17 +4708,23 @@ def __matmul__(self, other):
47114708
return blosc2.linalg.matmul(self, other)
47124709

47134710

4714-
def squeeze(x: NDArray, axis: int | None = None) -> NDArray:
4711+
def squeeze(x: Array, axis: int | Sequence[int]) -> NDArray:
47154712
"""
47164713
Remove single-dimensional entries from the shape of the array.
47174714
4718-
This method modifies the array in-place. If mask is None removes any dimensions with size 1.
4719-
If axis is provided, it should be an int or tuple of ints and the corresponding
4720-
dimensions (of size 1) will be removed.
4715+
This method modifies the array in-place.
4716+
4717+
Parameters
4718+
----------
4719+
x: Array
4720+
input array.
4721+
axis: int | Sequence[int]
4722+
Axis (or axes) to squeeze.
47214723
47224724
Returns
47234725
-------
4724-
out: NDArray
4726+
out: Array
4727+
An output array having the same data type and elements as x.
47254728
47264729
Examples
47274730
--------

src/blosc2/proxy.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -674,25 +674,28 @@ def __getitem__(self, item: slice | list[slice]) -> np.ndarray:
674674
return np.asarray(out) # avoids copy for PyTorch at least
675675

676676

677-
def as_simpleproxy(x: blosc2.Array) -> SimpleProxy | blosc2.Operand:
677+
def as_simpleproxy(*arrs: Sequence[blosc2.Array]) -> tuple[SimpleProxy | blosc2.Operand]:
678678
"""
679679
Convert one or more objects fulfilling the Array protocol into SimpleProxy objects. Any input that is
680680
already a blosc2.Operand is returned unchanged.
681681
682682
Parameters
683683
----------
684-
x: blosc2.Array
685-
Object fulfilling Array protocol.
684+
arrs: Sequence[blosc2.Array]
685+
Objects fulfilling Array protocol.
686686
687687
Returns
688688
-------
689-
out: blosc2.SimpleProxy | blosc2.Operand
690-
Object with minimal interface for blosc2 LazyExpr computations.
689+
out: tuple[blosc2.SimpleProxy | blosc2.Operand]
690+
Objects with minimal interface for blosc2 LazyExpr computations.
691691
"""
692-
if isinstance(x, blosc2.Operand):
693-
return x
694-
else:
695-
return SimpleProxy(x)
692+
out = ()
693+
for x in arrs:
694+
if isinstance(x, blosc2.Operand):
695+
out += (x,)
696+
else:
697+
out += (SimpleProxy(x),)
698+
return out[0] if len(out) == 1 else out
696699

697700

698701
def jit(func=None, *, out=None, disable=False, **kwargs): # noqa: C901

tests/ndarray/test_lazyexpr.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import numpy as np
1212
import pytest
13+
import torch
1314

1415
import blosc2
1516
from blosc2.lazyexpr import ne_evaluate
@@ -1697,13 +1698,13 @@ def test_lazylinalg():
16971698
np.testing.assert_array_almost_equal(out[()], npres)
16981699

16991700
# --- squeeze ---
1700-
out = blosc2.lazyexpr("squeeze(D)")
1701-
npres = np.squeeze(npD)
1701+
out = blosc2.lazyexpr("squeeze(D, axis=-1)")
1702+
npres = np.squeeze(npD, -1)
17021703
assert out.shape == npres.shape
17031704
np.testing.assert_array_almost_equal(out[()], npres)
17041705

1705-
out = blosc2.lazyexpr("D.squeeze()")
1706-
npres = np.squeeze(npD)
1706+
out = blosc2.lazyexpr("D.squeeze(axis=-1)")
1707+
npres = np.squeeze(npD, -1)
17071708
assert out.shape == npres.shape
17081709
np.testing.assert_array_almost_equal(out[()], npres)
17091710

@@ -1772,3 +1773,52 @@ def test_lazyexpr_2args():
17721773
newexpr = blosc2.hypot(lexpr, 3)
17731774
assert newexpr.expression == "hypot((sin(o0)), 3)"
17741775
assert newexpr.operands["o0"] is a
1776+
1777+
1778+
@pytest.mark.parametrize(
1779+
"xp",
1780+
[torch, np],
1781+
)
1782+
@pytest.mark.parametrize(
1783+
"dtype",
1784+
["bool", "int32", "int64", "float32", "float64", "complex128"],
1785+
)
1786+
def test_simpleproxy(xp, dtype):
1787+
dtype_ = getattr(xp, dtype) if hasattr(xp, dtype) else np.dtype(dtype)
1788+
if dtype == "bool":
1789+
blosc_matrix = blosc2.asarray([True, False, False], dtype=np.dtype(dtype), chunks=(2,))
1790+
foreign_matrix = xp.zeros((3,), dtype=dtype_)
1791+
# Create a lazy expression object
1792+
lexpr = blosc2.lazyexpr(
1793+
"(b & a) | (~b)", operands={"a": blosc_matrix, "b": foreign_matrix}
1794+
) # this does not
1795+
# Compare with numpy computation result
1796+
npb = np.asarray(foreign_matrix)
1797+
npa = blosc_matrix[()]
1798+
res = (npb & npa) | np.logical_not(npb)
1799+
else:
1800+
N = 10
1801+
shape_a = (N, N, N)
1802+
blosc_matrix = blosc2.full(shape=shape_a, fill_value=3, dtype=np.dtype(dtype), chunks=(N // 3,) * 3)
1803+
foreign_matrix = xp.ones(shape_a, dtype=dtype_)
1804+
if dtype == "complex128":
1805+
foreign_matrix += 0.5j
1806+
blosc_matrix = blosc2.full(
1807+
shape=shape_a, fill_value=3 + 2j, dtype=np.dtype(dtype), chunks=(N // 3,) * 3
1808+
)
1809+
1810+
# Create a lazy expression object
1811+
lexpr = blosc2.lazyexpr(
1812+
"b + sin(a) + sum(b) - tensordot(a, b, axes=1)",
1813+
operands={"a": blosc_matrix, "b": foreign_matrix},
1814+
) # this does not
1815+
# Compare with numpy computation result
1816+
npb = np.asarray(foreign_matrix)
1817+
npa = blosc_matrix[()]
1818+
res = npb + np.sin(npa) + np.sum(npb) - np.tensordot(npa, npb, axes=1)
1819+
1820+
# Test object metadata and result
1821+
assert isinstance(lexpr, blosc2.LazyExpr)
1822+
assert lexpr.dtype == res.dtype
1823+
assert lexpr.shape == res.shape
1824+
np.testing.assert_array_equal(lexpr[()], res)

tests/ndarray/test_linalg.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import inspect
12
from itertools import permutations
23

34
import numpy as np
45
import pytest
6+
import torch
57

68
import blosc2
9+
from blosc2.lazyexpr import linalg_funcs
710
from blosc2.ndarray import npvecdot
811

912

@@ -817,3 +820,54 @@ def test_diagonal(shape, chunkshape, offset):
817820

818821
# Assert equality
819822
np.testing.assert_array_equal(result_np, expected)
823+
824+
825+
@pytest.mark.parametrize(
826+
"xp",
827+
[torch, np],
828+
)
829+
@pytest.mark.parametrize(
830+
"dtype",
831+
["bool", "int32", "int64", "float32", "float64", "complex128"],
832+
)
833+
def test_linalgproxy(xp, dtype):
834+
dtype_ = getattr(xp, dtype) if hasattr(xp, dtype) else np.dtype(dtype)
835+
for name in linalg_funcs:
836+
if name == "transpose":
837+
continue # deprecated
838+
func = getattr(blosc2, name)
839+
N = 10
840+
shape_a = (N,)
841+
chunks = (N // 3,)
842+
if name != "outer":
843+
shape_a *= 3
844+
chunks *= 3
845+
blosc_matrix = blosc2.full(shape=shape_a, fill_value=3, dtype=np.dtype(dtype), chunks=chunks)
846+
foreign_matrix = xp.ones(shape_a, dtype=dtype_)
847+
if dtype == "complex128":
848+
foreign_matrix += 0.5j
849+
blosc_matrix = blosc2.full(
850+
shape=shape_a, fill_value=3 + 2j, dtype=np.dtype(dtype), chunks=chunks
851+
)
852+
853+
# Check this works
854+
argspec = inspect.getfullargspec(func)
855+
num_args = len(argspec.args)
856+
npfunc = blosc2.linalg.nptranspose if name == "permute_dims" else getattr(np, name)
857+
if num_args > 2 or name in ("outer", "matmul"):
858+
try:
859+
lexpr = func(blosc_matrix, foreign_matrix)
860+
except NotImplementedError:
861+
continue
862+
foreign_matrix = np.asarray(foreign_matrix)
863+
res = npfunc(blosc_matrix[()], foreign_matrix)
864+
else:
865+
try:
866+
lexpr = func(foreign_matrix)
867+
except NotImplementedError:
868+
continue
869+
except TypeError:
870+
continue
871+
foreign_matrix = np.asarray(foreign_matrix)
872+
res = npfunc(foreign_matrix, 0) if name == "expand_dims" else npfunc(foreign_matrix)
873+
np.testing.assert_array_equal(res, lexpr[()])

tests/ndarray/test_proxy_expr.py

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import numpy as np
1111
import pytest
12-
import torch
1312

1413
import blosc2
1514
from blosc2.lazyexpr import ne_evaluate
@@ -88,52 +87,3 @@ def test_expr_proxy_operands(chunks_blocks, c2sub_context):
8887
blosc2.remove_urlpath(urlpath)
8988
for path in cleanup_paths:
9089
blosc2.remove_urlpath(path)
91-
92-
93-
@pytest.mark.parametrize(
94-
"xp",
95-
[torch, np],
96-
)
97-
@pytest.mark.parametrize(
98-
"dtype",
99-
["bool", "int32", "int64", "float32", "float64", "complex128"],
100-
)
101-
def test_simpleproxy(xp, dtype):
102-
dtype_ = getattr(xp, dtype) if hasattr(xp, dtype) else np.dtype(dtype)
103-
if dtype == "bool":
104-
blosc_matrix = blosc2.asarray([True, False, False], dtype=np.dtype(dtype), chunks=(2,))
105-
foreign_matrix = xp.zeros((3,), dtype=dtype_)
106-
# Create a lazy expression object
107-
lexpr = blosc2.lazyexpr(
108-
"(b & a) | (~b)", operands={"a": blosc_matrix, "b": foreign_matrix}
109-
) # this does not
110-
# Compare with numpy computation result
111-
npb = np.asarray(foreign_matrix)
112-
npa = blosc_matrix[()]
113-
res = (npb & npa) | np.logical_not(npb)
114-
else:
115-
N = 10
116-
shape_a = (N, N, N)
117-
blosc_matrix = blosc2.full(shape=shape_a, fill_value=3, dtype=np.dtype(dtype), chunks=(N // 3,) * 3)
118-
foreign_matrix = xp.ones(shape_a, dtype=dtype_)
119-
if dtype == "complex128":
120-
foreign_matrix += 0.5j
121-
blosc_matrix = blosc2.full(
122-
shape=shape_a, fill_value=3 + 2j, dtype=np.dtype(dtype), chunks=(N // 3,) * 3
123-
)
124-
125-
# Create a lazy expression object
126-
lexpr = blosc2.lazyexpr(
127-
"b + sin(a) + sum(b) - tensordot(a, b, axes=1)",
128-
operands={"a": blosc_matrix, "b": foreign_matrix},
129-
) # this does not
130-
# Compare with numpy computation result
131-
npb = np.asarray(foreign_matrix)
132-
npa = blosc_matrix[()]
133-
res = npb + np.sin(npa) + np.sum(npb) - np.tensordot(npa, npb, axes=1)
134-
135-
# Test object metadata and result
136-
assert isinstance(lexpr, blosc2.LazyExpr)
137-
assert lexpr.dtype == res.dtype
138-
assert lexpr.shape == res.shape
139-
np.testing.assert_array_equal(lexpr[()], res)

0 commit comments

Comments
 (0)