66
77@pytest .mark .parametrize (
88 ("ashape" , "achunks" , "ablocks" ),
9- [
9+ {
1010 ((12 , 10 ), (7 , 5 ), (3 , 3 )),
1111 ((10 ,), (9 ,), (7 ,)),
12- ] ,
12+ } ,
1313)
1414@pytest .mark .parametrize (
1515 ("bshape" , "bchunks" , "bblocks" ),
16- [
16+ {
1717 ((10 ,), (4 ,), (2 ,)),
1818 ((10 , 5 ), (3 , 4 ), (1 , 3 )),
1919 ((10 , 12 ), (2 , 4 ), (1 , 2 )),
20- ] ,
20+ } ,
2121)
2222@pytest .mark .parametrize (
2323 "dtype" ,
24- [ np .float32 , np .float64 , np .complex64 , np .complex128 ] ,
24+ { np .float32 , np .float64 , np .complex64 , np .complex128 } ,
2525)
2626def test_matmul (ashape , achunks , ablocks , bshape , bchunks , bblocks , dtype ):
2727 a = blosc2 .linspace (0 , 10 , dtype = dtype , shape = ashape , chunks = achunks , blocks = ablocks )
@@ -37,19 +37,19 @@ def test_matmul(ashape, achunks, ablocks, bshape, bchunks, bblocks, dtype):
3737
3838@pytest .mark .parametrize (
3939 ("ashape" , "achunks" , "ablocks" ),
40- [
40+ {
4141 ((12 , 11 ), (7 , 5 ), (3 , 1 )),
4242 ((0 , 0 ), (0 , 0 ), (0 , 0 )),
4343 ((10 ,), (4 ,), (2 ,)),
44- ] ,
44+ } ,
4545)
4646@pytest .mark .parametrize (
4747 ("bshape" , "bchunks" , "bblocks" ),
48- [
48+ {
4949 ((1 , 5 ), (1 , 4 ), (1 , 3 )),
5050 ((4 , 6 ), (2 , 4 ), (1 , 3 )),
5151 ((5 ,), (4 ,), (2 ,)),
52- ] ,
52+ } ,
5353)
5454def test_matmul_shapes (ashape , achunks , ablocks , bshape , bchunks , bblocks ):
5555 a = blosc2 .linspace (0 , 10 , shape = ashape , chunks = achunks , blocks = ablocks )
@@ -64,7 +64,7 @@ def test_matmul_shapes(ashape, achunks, ablocks, bshape, bchunks, bblocks):
6464
6565@pytest .mark .parametrize (
6666 "scalar" ,
67- [
67+ {
6868 5 , # int
6969 5.3 , # float
7070 1 + 2j , # complex
@@ -74,7 +74,7 @@ def test_matmul_shapes(ashape, achunks, ablocks, bshape, bchunks, bblocks):
7474 np .float64 (5.3 ), # NumPy float64
7575 np .complex64 (1 + 2j ), # NumPy complex64
7676 np .complex128 (1 + 2j ), # NumPy complex128
77- ] ,
77+ } ,
7878)
7979def test_matmul_scalars (scalar ):
8080 vector = blosc2 .asarray (np .array ([1 , 2 , 3 ]))
@@ -113,3 +113,30 @@ def test_matmul_dims(ashape, bshape):
113113
114114 with pytest .raises (ValueError ):
115115 blosc2 .matmul (b , a )
116+
117+
118+ @pytest .mark .parametrize (
119+ ("ashape" , "achunks" , "ablocks" , "adtype" ),
120+ {
121+ ((7 , 10 ), (7 , 5 ), (3 , 5 ), np .float32 ),
122+ ((10 ,), (9 ,), (7 ,), np .complex64 ),
123+ },
124+ )
125+ @pytest .mark .parametrize (
126+ ("bshape" , "bchunks" , "bblocks" , "bdtype" ),
127+ {
128+ ((10 ,), (4 ,), (2 ,), np .float64 ),
129+ ((10 , 6 ), (9 , 4 ), (2 , 3 ), np .complex128 ),
130+ ((10 , 12 ), (2 , 4 ), (1 , 2 ), np .complex128 ),
131+ },
132+ )
133+ def test_matmul_especial_cases (ashape , achunks , ablocks , adtype , bshape , bchunks , bblocks , bdtype ):
134+ a = blosc2 .linspace (0 , 10 , dtype = adtype , shape = ashape , chunks = achunks , blocks = ablocks )
135+ b = blosc2 .linspace (0 , 10 , dtype = bdtype , shape = bshape , chunks = bchunks , blocks = bblocks )
136+ blosc2_res = blosc2 .matmul (a , b )
137+
138+ na = a [:]
139+ nb = b [:]
140+ np_res = np .matmul (na , nb )
141+
142+ np .testing .assert_allclose (blosc2_res , np_res , rtol = 1e-6 )
0 commit comments