Skip to content

Commit bc6729c

Browse files
committed
[DOC, TESTS] Add doc, tests
1 parent ebea726 commit bc6729c

5 files changed

Lines changed: 112 additions & 21 deletions

File tree

doc/reference/array_operations.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ Operations with arrays
66

77
lazy_functions
88
reduction_functions
9+
linear_algebra

doc/reference/linear_algebra.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
.. _linear_algebra:
2+
3+
Linear Algebra
4+
--------------
5+
6+
The next functions can be used for computing linear algebra operations with :ref:`NDArray <NDArray>`.
7+
8+
.. currentmodule:: blosc2
9+
10+
.. autosummary::
11+
:toctree: autofiles/operations_with_arrays/
12+
:nosignatures:
13+
14+
matmul

src/blosc2/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ class Tuner(Enum):
235235
ones,
236236
full,
237237
save,
238+
matmul,
238239
)
239240

240241
from .c2array import c2context, C2Array, URLPath

src/blosc2/ndarray.py

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3645,23 +3645,70 @@ def sort(array: NDArray, order: str | list[str] | None = None, **kwargs: Any) ->
36453645

36463646
def matmul(x1: NDArray, x2: NDArray, **kwargs: Any) -> NDArray:
36473647
"""
3648-
Perform matrix multiplication between two Blosc2 NDArrays.
3648+
Computes the matrix product between two Blosc2 NDArrays.
36493649
36503650
Parameters
36513651
----------
36523652
x1: `NDArray`
3653-
First input array.
3653+
The first input array.
36543654
x2: `NDArray`
3655-
Second input array.
3655+
The second input array.
36563656
kwargs: Any, optional
36573657
Keyword arguments that are supported by the :func:`empty` constructor.
36583658
36593659
Returns
36603660
-------
36613661
out: :ref:`NDArray`
3662-
The result matrix multiplication.
3662+
The matrix product of the inputs. This is a scalar only when both x1,
3663+
x2 are 1-d vectors.
3664+
3665+
Raises
3666+
------
3667+
ValueError
3668+
If the last dimension of ``x1`` is not the same size as
3669+
the second-to-last dimension of ``x2``.
3670+
3671+
If a scalar value is passed in.
3672+
3673+
References
3674+
----------
3675+
`numpy.matmul <https://numpy.org/doc/stable/reference/generated/numpy.matmul.html>`_
3676+
3677+
Examples
3678+
--------
3679+
For 2-D arrays it is the matrix product:
3680+
3681+
>>> import numpy as np
3682+
>>> import blosc2
3683+
>>> a = np.array([[1, 2],
3684+
... [3, 4]])
3685+
>>> nd_a = blosc2.asarray(a)
3686+
>>> b = np.array([[2, 3],
3687+
... [2, 1]])
3688+
>>> nd_b = blosc2.asarray(b)
3689+
>>> blosc2.matmul(nd_a, nd_b)
3690+
array([[ 6, 5],
3691+
[14, 13]])
3692+
3693+
3694+
For 2-D mixed with 1-D, the result is the usual.
3695+
3696+
>>> a = np.array([[1, 3],
3697+
... [0, 1]])
3698+
>>> nd_a = blosc2.asarray(a)
3699+
>>> v = np.array([1, 2])
3700+
>>> nd_v = blosc2.asarray(v)
3701+
>>> blosc2.matmul(nd_a, nd_v)
3702+
array([7, 2])
3703+
>>> blosc2.matmul(nd_v, nd_a)
3704+
array([1, 5])
3705+
36633706
"""
36643707

3708+
# Validate arguments are not scalars
3709+
if np.isscalar(x1) or np.isscalar(x2):
3710+
raise ValueError("Arguments can't be scalars.")
3711+
36653712
# Promote 1D arrays to 2D if necessary
36663713
x1_is_vector = False
36673714
x2_is_vector = False
@@ -3673,13 +3720,14 @@ def matmul(x1: NDArray, x2: NDArray, **kwargs: Any) -> NDArray:
36733720
x2_is_vector = True
36743721

36753722
# Validate matrix multiplication compatibility
3676-
if x1.shape[-1] != x2.shape[0]:
3723+
if x1.shape[-1] != x2.shape[-2]:
36773724
raise ValueError("Shapes are not aligned for matrix multiplication.")
36783725

3679-
n, l = x1.shape
3680-
_, m = x2.shape
3681-
p1, q1 = x1.chunks
3682-
p2, q2 = x2.chunks
3726+
n, l = x1.shape[-2:]
3727+
m = x2.shape[-1]
3728+
3729+
p1, q1 = x1.chunks[-2:]
3730+
q2 = x2.chunks[-1]
36833731

36843732
result = np.zeros((n, m), dtype=x1.dtype)
36853733
# result = blosc2.zeros(n) # TODO: file a ticket for blosc2.zeros()
@@ -3696,7 +3744,7 @@ def matmul(x1: NDArray, x2: NDArray, **kwargs: Any) -> NDArray:
36963744
bres[:] += np.matmul(bx1, bx2)
36973745

36983746
if x1_is_vector and x2_is_vector:
3699-
result = result.reshape((n, m))
3747+
result = result[0][0]
37003748
elif x1_is_vector:
37013749
result = result.reshape((m,))
37023750
elif x2_is_vector:

tests/ndarray/test_matmul.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import pytest
22
import numpy as np
33
import blosc2
4-
from blosc2.ndarray import matmul
54

65

76
@pytest.mark.parametrize(
87
("ashape", "achunks", "ablocks"),
98
[
10-
((12, 10), (6, 5), (3, 3)),
11-
((10, ), (9, ), (7, )),
9+
((12, 10), (7, 5), (3, 3)),
10+
((10,), (9,), (7,)),
1211
],
1312
)
1413
@pytest.mark.parametrize(
@@ -25,18 +24,19 @@
2524
def test_matmul(ashape, achunks, ablocks, bshape, bchunks, bblocks, dtype):
2625
a = blosc2.linspace(0, 10, dtype=dtype, shape=ashape, chunks=achunks, blocks=ablocks)
2726
b = blosc2.linspace(0, 10, dtype=dtype, shape=bshape, chunks=bchunks, blocks=bblocks)
27+
blosc2_res = blosc2.matmul(a, b)
2828

2929
na = a[:]
3030
nb = b[:]
31-
blosc2_res = matmul(a, b)
3231
np_res = np.matmul(na, nb)
32+
3333
np.testing.assert_allclose(blosc2_res, np_res, rtol=1e-6)
3434

3535

3636
@pytest.mark.parametrize(
3737
("ashape", "achunks", "ablocks"),
3838
[
39-
((12, 11), (6, 5), (3, 1)),
39+
((12, 11), (7, 5), (3, 1)),
4040
((0, 0), (0, 0), (0, 0)),
4141
((10,), (4,), (2,)),
4242
],
@@ -45,13 +45,40 @@ def test_matmul(ashape, achunks, ablocks, bshape, bchunks, bblocks, dtype):
4545
("bshape", "bchunks", "bblocks"),
4646
[
4747
((1, 5), (1, 4), (1, 3)),
48-
((4, 12), (2, 4), (1, 3)),
49-
((10,), (4,), (2,)),
48+
((4, 6), (2, 4), (1, 3)),
49+
((5,), (4,), (2,)),
5050
],
5151
)
52-
def test_matmul_raises(ashape, achunks, ablocks, bshape, bchunks, bblocks):
52+
def test_matmul_shapes(ashape, achunks, ablocks, bshape, bchunks, bblocks):
5353
a = blosc2.linspace(0, 10, shape=ashape, chunks=achunks, blocks=ablocks)
5454
b = blosc2.linspace(0, 10, shape=bshape, chunks=bchunks, blocks=bblocks)
55-
if a.shape[-1] != b.shape[0]:
56-
with pytest.raises(ValueError):
57-
matmul(a, b)
55+
56+
with pytest.raises(ValueError):
57+
blosc2.matmul(a, b)
58+
59+
with pytest.raises(ValueError):
60+
blosc2.matmul(b, a)
61+
62+
63+
@pytest.mark.parametrize("scalar", [
64+
5, # int
65+
5.3, # float
66+
1 + 2j, # complex
67+
np.int32(5), # NumPy int32
68+
np.int64(5), # NumPy int64
69+
np.float32(5.3), # NumPy float32
70+
np.float64(5.3), # NumPy float64
71+
np.complex64(1 + 2j), # NumPy complex64
72+
np.complex128(1 + 2j), # NumPy complex128
73+
])
74+
def test_matmul_scalars(scalar):
75+
vector = blosc2.asarray(np.array([1, 2, 3]))
76+
77+
with pytest.raises(ValueError):
78+
blosc2.matmul(scalar, vector)
79+
80+
with pytest.raises(ValueError):
81+
blosc2.matmul(vector, scalar)
82+
83+
with pytest.raises(ValueError):
84+
blosc2.matmul(scalar, scalar)

0 commit comments

Comments
 (0)