Skip to content

Commit dc48bdf

Browse files
committed
[FIX] Comments added
1 parent 0fea63b commit dc48bdf

2 files changed

Lines changed: 49 additions & 14 deletions

File tree

src/blosc2/ndarray.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3709,6 +3709,10 @@ def matmul(x1: NDArray, x2: NDArray, **kwargs: Any) -> NDArray:
37093709
if np.isscalar(x1) or np.isscalar(x2):
37103710
raise ValueError("Arguments can't be scalars.")
37113711

3712+
# Validate arguments are dimension 1 or 2
3713+
if x1.ndim > 2 or x2.ndim > 2:
3714+
raise ValueError("Multiplication of arrays with dimension greater than 2 is not supported yet.")
3715+
37123716
# Promote 1D arrays to 2D if necessary
37133717
x1_is_vector = False
37143718
x2_is_vector = False
@@ -3729,7 +3733,7 @@ def matmul(x1: NDArray, x2: NDArray, **kwargs: Any) -> NDArray:
37293733
p1, q1 = x1.chunks[-2:]
37303734
q2 = x2.chunks[-1]
37313735

3732-
result = blosc2.zeros((n, m), dtype=x1.dtype)
3736+
result = blosc2.zeros((n, m), dtype=np.result_type(x1, x2), **kwargs)
37333737

37343738
for row in range(0, n, p1):
37353739
row_end = (row + p1) if (row + p1) < n else n

tests/ndarray/test_matmul.py

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import pytest
21
import numpy as np
2+
import pytest
3+
34
import blosc2
45

56

@@ -19,7 +20,8 @@
1920
],
2021
)
2122
@pytest.mark.parametrize(
22-
"dtype", [np.float32, np.float64, np.complex64, np.complex128],
23+
"dtype",
24+
[np.float32, np.float64, np.complex64, np.complex128],
2325
)
2426
def test_matmul(ashape, achunks, ablocks, bshape, bchunks, bblocks, dtype):
2527
a = blosc2.linspace(0, 10, dtype=dtype, shape=ashape, chunks=achunks, blocks=ablocks)
@@ -60,17 +62,20 @@ def test_matmul_shapes(ashape, achunks, ablocks, bshape, bchunks, bblocks):
6062
blosc2.matmul(b, a)
6163

6264

63-
@pytest.mark.parametrize("scalar", [
64-
5, # int
65-
5.3, # float
66-
1 + 2j, # complex
67-
np.int32(5), # NumPy int32
68-
np.int64(5), # NumPy int64
69-
np.float32(5.3), # NumPy float32
70-
np.float64(5.3), # NumPy float64
71-
np.complex64(1 + 2j), # NumPy complex64
72-
np.complex128(1 + 2j), # NumPy complex128
73-
])
65+
@pytest.mark.parametrize(
66+
"scalar",
67+
[
68+
5, # int
69+
5.3, # float
70+
1 + 2j, # complex
71+
np.int32(5), # NumPy int32
72+
np.int64(5), # NumPy int64
73+
np.float32(5.3), # NumPy float32
74+
np.float64(5.3), # NumPy float64
75+
np.complex64(1 + 2j), # NumPy complex64
76+
np.complex128(1 + 2j), # NumPy complex128
77+
],
78+
)
7479
def test_matmul_scalars(scalar):
7580
vector = blosc2.asarray(np.array([1, 2, 3]))
7681

@@ -82,3 +87,29 @@ def test_matmul_scalars(scalar):
8287

8388
with pytest.raises(ValueError):
8489
blosc2.matmul(scalar, scalar)
90+
91+
92+
@pytest.mark.parametrize(
93+
"ashape",
94+
[
95+
(12, 10, 10),
96+
(7, 5, 5),
97+
(3, 3, 3),
98+
],
99+
)
100+
@pytest.mark.parametrize(
101+
"bshape",
102+
[
103+
(10, 10, 10, 11),
104+
(3, 2, 9),
105+
],
106+
)
107+
def test_matmul_dims(ashape, bshape):
108+
a = blosc2.linspace(0, 10, shape=ashape)
109+
b = blosc2.linspace(0, 1, shape=bshape)
110+
111+
with pytest.raises(ValueError):
112+
blosc2.matmul(a, b)
113+
114+
with pytest.raises(ValueError):
115+
blosc2.matmul(b, a)

0 commit comments

Comments
 (0)