11import pytest
22import numpy as np
33import blosc2
4- from blosc2 .ndarray import matmul
54
65
76@pytest .mark .parametrize (
87 ("ashape" , "achunks" , "ablocks" ),
98 [
10- ((12 , 10 ), (6 , 5 ), (3 , 3 )),
11- ((10 , ), (9 , ), (7 , )),
9+ ((12 , 10 ), (7 , 5 ), (3 , 3 )),
10+ ((10 ,), (9 ,), (7 ,)),
1211 ],
1312)
1413@pytest .mark .parametrize (
2524def test_matmul (ashape , achunks , ablocks , bshape , bchunks , bblocks , dtype ):
2625 a = blosc2 .linspace (0 , 10 , dtype = dtype , shape = ashape , chunks = achunks , blocks = ablocks )
2726 b = blosc2 .linspace (0 , 10 , dtype = dtype , shape = bshape , chunks = bchunks , blocks = bblocks )
27+ blosc2_res = blosc2 .matmul (a , b )
2828
2929 na = a [:]
3030 nb = b [:]
31- blosc2_res = matmul (a , b )
3231 np_res = np .matmul (na , nb )
32+
3333 np .testing .assert_allclose (blosc2_res , np_res , rtol = 1e-6 )
3434
3535
3636@pytest .mark .parametrize (
3737 ("ashape" , "achunks" , "ablocks" ),
3838 [
39- ((12 , 11 ), (6 , 5 ), (3 , 1 )),
39+ ((12 , 11 ), (7 , 5 ), (3 , 1 )),
4040 ((0 , 0 ), (0 , 0 ), (0 , 0 )),
4141 ((10 ,), (4 ,), (2 ,)),
4242 ],
@@ -45,13 +45,40 @@ def test_matmul(ashape, achunks, ablocks, bshape, bchunks, bblocks, dtype):
4545 ("bshape" , "bchunks" , "bblocks" ),
4646 [
4747 ((1 , 5 ), (1 , 4 ), (1 , 3 )),
48- ((4 , 12 ), (2 , 4 ), (1 , 3 )),
49- ((10 ,), (4 ,), (2 ,)),
48+ ((4 , 6 ), (2 , 4 ), (1 , 3 )),
49+ ((5 ,), (4 ,), (2 ,)),
5050 ],
5151)
52- def test_matmul_raises (ashape , achunks , ablocks , bshape , bchunks , bblocks ):
52+ def test_matmul_shapes (ashape , achunks , ablocks , bshape , bchunks , bblocks ):
5353 a = blosc2 .linspace (0 , 10 , shape = ashape , chunks = achunks , blocks = ablocks )
5454 b = blosc2 .linspace (0 , 10 , shape = bshape , chunks = bchunks , blocks = bblocks )
55- if a .shape [- 1 ] != b .shape [0 ]:
56- with pytest .raises (ValueError ):
57- matmul (a , b )
55+
56+ with pytest .raises (ValueError ):
57+ blosc2 .matmul (a , b )
58+
59+ with pytest .raises (ValueError ):
60+ blosc2 .matmul (b , a )
61+
62+
63+ @pytest .mark .parametrize ("scalar" , [
64+ 5 , # int
65+ 5.3 , # float
66+ 1 + 2j , # complex
67+ np .int32 (5 ), # NumPy int32
68+ np .int64 (5 ), # NumPy int64
69+ np .float32 (5.3 ), # NumPy float32
70+ np .float64 (5.3 ), # NumPy float64
71+ np .complex64 (1 + 2j ), # NumPy complex64
72+ np .complex128 (1 + 2j ), # NumPy complex128
73+ ])
74+ def test_matmul_scalars (scalar ):
75+ vector = blosc2 .asarray (np .array ([1 , 2 , 3 ]))
76+
77+ with pytest .raises (ValueError ):
78+ blosc2 .matmul (scalar , vector )
79+
80+ with pytest .raises (ValueError ):
81+ blosc2 .matmul (vector , scalar )
82+
83+ with pytest .raises (ValueError ):
84+ blosc2 .matmul (scalar , scalar )
0 commit comments