Skip to content

Commit 1354fff

Browse files
committed
[TEST] New test for special cases
1 parent dc48bdf commit 1354fff

1 file changed

Lines changed: 38 additions & 11 deletions

File tree

tests/ndarray/test_matmul.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,22 @@
66

77
@pytest.mark.parametrize(
88
("ashape", "achunks", "ablocks"),
9-
[
9+
{
1010
((12, 10), (7, 5), (3, 3)),
1111
((10,), (9,), (7,)),
12-
],
12+
},
1313
)
1414
@pytest.mark.parametrize(
1515
("bshape", "bchunks", "bblocks"),
16-
[
16+
{
1717
((10,), (4,), (2,)),
1818
((10, 5), (3, 4), (1, 3)),
1919
((10, 12), (2, 4), (1, 2)),
20-
],
20+
},
2121
)
2222
@pytest.mark.parametrize(
2323
"dtype",
24-
[np.float32, np.float64, np.complex64, np.complex128],
24+
{np.float32, np.float64, np.complex64, np.complex128},
2525
)
2626
def test_matmul(ashape, achunks, ablocks, bshape, bchunks, bblocks, dtype):
2727
a = blosc2.linspace(0, 10, dtype=dtype, shape=ashape, chunks=achunks, blocks=ablocks)
@@ -37,19 +37,19 @@ def test_matmul(ashape, achunks, ablocks, bshape, bchunks, bblocks, dtype):
3737

3838
@pytest.mark.parametrize(
3939
("ashape", "achunks", "ablocks"),
40-
[
40+
{
4141
((12, 11), (7, 5), (3, 1)),
4242
((0, 0), (0, 0), (0, 0)),
4343
((10,), (4,), (2,)),
44-
],
44+
},
4545
)
4646
@pytest.mark.parametrize(
4747
("bshape", "bchunks", "bblocks"),
48-
[
48+
{
4949
((1, 5), (1, 4), (1, 3)),
5050
((4, 6), (2, 4), (1, 3)),
5151
((5,), (4,), (2,)),
52-
],
52+
},
5353
)
5454
def test_matmul_shapes(ashape, achunks, ablocks, bshape, bchunks, bblocks):
5555
a = blosc2.linspace(0, 10, shape=ashape, chunks=achunks, blocks=ablocks)
@@ -64,7 +64,7 @@ def test_matmul_shapes(ashape, achunks, ablocks, bshape, bchunks, bblocks):
6464

6565
@pytest.mark.parametrize(
6666
"scalar",
67-
[
67+
{
6868
5, # int
6969
5.3, # float
7070
1 + 2j, # complex
@@ -74,7 +74,7 @@ def test_matmul_shapes(ashape, achunks, ablocks, bshape, bchunks, bblocks):
7474
np.float64(5.3), # NumPy float64
7575
np.complex64(1 + 2j), # NumPy complex64
7676
np.complex128(1 + 2j), # NumPy complex128
77-
],
77+
},
7878
)
7979
def test_matmul_scalars(scalar):
8080
vector = blosc2.asarray(np.array([1, 2, 3]))
@@ -113,3 +113,30 @@ def test_matmul_dims(ashape, bshape):
113113

114114
with pytest.raises(ValueError):
115115
blosc2.matmul(b, a)
116+
117+
118+
@pytest.mark.parametrize(
119+
("ashape", "achunks", "ablocks", "adtype"),
120+
{
121+
((7, 10), (7, 5), (3, 5), np.float32),
122+
((10,), (9,), (7,), np.complex64),
123+
},
124+
)
125+
@pytest.mark.parametrize(
126+
("bshape", "bchunks", "bblocks", "bdtype"),
127+
{
128+
((10,), (4,), (2,), np.float64),
129+
((10, 6), (9, 4), (2, 3), np.complex128),
130+
((10, 12), (2, 4), (1, 2), np.complex128),
131+
},
132+
)
133+
def test_matmul_especial_cases(ashape, achunks, ablocks, adtype, bshape, bchunks, bblocks, bdtype):
134+
a = blosc2.linspace(0, 10, dtype=adtype, shape=ashape, chunks=achunks, blocks=ablocks)
135+
b = blosc2.linspace(0, 10, dtype=bdtype, shape=bshape, chunks=bchunks, blocks=bblocks)
136+
blosc2_res = blosc2.matmul(a, b)
137+
138+
na = a[:]
139+
nb = b[:]
140+
np_res = np.matmul(na, nb)
141+
142+
np.testing.assert_allclose(blosc2_res, np_res, rtol=1e-6)

0 commit comments

Comments
 (0)