@@ -86,44 +86,61 @@ def test_toggle_miniexpr_updates_linalg_runtime_flag():
8686 _toggle_miniexpr (old_flag )
8787
8888
89- def test_matmul_uses_fast_path_for_supported_2d (monkeypatch ):
90- old_flag = utils_mod .try_miniexpr
89+ def _set_pref_matmul_call_recorder (monkeypatch ):
9190 calls = []
9291 original = blosc2 .NDArray ._set_pref_matmul
9392
9493 def wrapped_set_pref_matmul (self , inputs , fp_accuracy ):
95- calls .append ((self .shape , inputs ["x1" ].shape , inputs ["x2" ].shape ))
94+ calls .append ((self .shape , inputs ["x1" ].shape , inputs ["x2" ].shape , self . dtype ))
9695 return original (self , inputs , fp_accuracy )
9796
9897 monkeypatch .setattr (blosc2 .NDArray , "_set_pref_matmul" , wrapped_set_pref_matmul )
98+ return calls
99+
100+
101+ @pytest .mark .parametrize ("dtype" , [np .float32 , np .float64 ])
102+ def test_matmul_uses_fast_path_for_supported_2d (monkeypatch , dtype ):
103+ old_flag = utils_mod .try_miniexpr
104+ calls = _set_pref_matmul_call_recorder (monkeypatch )
99105 try :
100106 _toggle_miniexpr (True )
101- a = blosc2 .ones (shape = (400 , 400 ), dtype = np .float64 , chunks = (200 , 200 ), blocks = (100 , 100 ))
102- b = blosc2 .full (
103- shape = (400 , 400 ), fill_value = 2 , dtype = np .float64 , chunks = (200 , 200 ), blocks = (100 , 100 )
104- )
107+ a = blosc2 .ones (shape = (400 , 400 ), dtype = dtype , chunks = (200 , 200 ), blocks = (100 , 100 ))
108+ b = blosc2 .full (shape = (400 , 400 ), fill_value = 2 , dtype = dtype , chunks = (200 , 200 ), blocks = (100 , 100 ))
105109
106110 with warnings .catch_warnings ():
107111 warnings .simplefilter ("ignore" , RuntimeWarning )
108112 c = blosc2 .matmul (a , b , chunks = (200 , 200 ), blocks = (100 , 100 ))
109113 expected = np .matmul (a [:], b [:])
110114
111- assert calls == [((400 , 400 ), (400 , 400 ), (400 , 400 ))]
115+ assert calls == [((400 , 400 ), (400 , 400 ), (400 , 400 ), np . dtype ( dtype ) )]
112116 np .testing .assert_allclose (c [:], expected , rtol = 1e-6 , atol = 1e-6 )
113117 finally :
114118 _toggle_miniexpr (old_flag )
115119
116120
117- def test_matmul_falls_back_for_integer_inputs (monkeypatch ):
121+ @pytest .mark .parametrize ("dtype" , [np .float32 , np .float64 ])
122+ def test_matmul_uses_fast_path_with_multiple_inner_blocks (monkeypatch , dtype ):
118123 old_flag = utils_mod .try_miniexpr
119- calls = []
120- original = blosc2 .NDArray ._set_pref_matmul
124+ calls = _set_pref_matmul_call_recorder (monkeypatch )
125+ try :
126+ _toggle_miniexpr (True )
127+ a = blosc2 .ones (shape = (256 , 384 ), dtype = dtype , chunks = (128 , 192 ), blocks = (64 , 64 ))
128+ b = blosc2 .full (shape = (384 , 256 ), fill_value = 2 , dtype = dtype , chunks = (192 , 128 ), blocks = (64 , 64 ))
121129
122- def wrapped_set_pref_matmul (self , inputs , fp_accuracy ):
123- calls .append ((self .shape , inputs ["x1" ].shape , inputs ["x2" ].shape ))
124- return original (self , inputs , fp_accuracy )
130+ with warnings .catch_warnings ():
131+ warnings .simplefilter ("ignore" , RuntimeWarning )
132+ c = blosc2 .matmul (a , b , chunks = (128 , 128 ), blocks = (64 , 64 ))
133+ expected = np .matmul (a [:], b [:])
125134
126- monkeypatch .setattr (blosc2 .NDArray , "_set_pref_matmul" , wrapped_set_pref_matmul )
135+ assert calls == [((256 , 256 ), (256 , 384 ), (384 , 256 ), np .dtype (dtype ))]
136+ np .testing .assert_allclose (c [:], expected , rtol = 1e-6 , atol = 1e-6 )
137+ finally :
138+ _toggle_miniexpr (old_flag )
139+
140+
141+ def test_matmul_falls_back_for_integer_inputs (monkeypatch ):
142+ old_flag = utils_mod .try_miniexpr
143+ calls = _set_pref_matmul_call_recorder (monkeypatch )
127144 try :
128145 _toggle_miniexpr (True )
129146 a = blosc2 .ones (shape = (200 , 200 ), dtype = np .int64 , chunks = (100 , 100 ), blocks = (50 , 50 ))
@@ -139,14 +156,7 @@ def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
139156
140157def test_matmul_falls_back_for_nd_inputs (monkeypatch ):
141158 old_flag = utils_mod .try_miniexpr
142- calls = []
143- original = blosc2 .NDArray ._set_pref_matmul
144-
145- def wrapped_set_pref_matmul (self , inputs , fp_accuracy ):
146- calls .append ((self .shape , inputs ["x1" ].shape , inputs ["x2" ].shape ))
147- return original (self , inputs , fp_accuracy )
148-
149- monkeypatch .setattr (blosc2 .NDArray , "_set_pref_matmul" , wrapped_set_pref_matmul )
159+ calls = _set_pref_matmul_call_recorder (monkeypatch )
150160 try :
151161 _toggle_miniexpr (True )
152162 a = blosc2 .ones (shape = (2 , 40 , 40 ), dtype = np .float64 , chunks = (1 , 20 , 20 ), blocks = (1 , 10 , 10 ))
@@ -165,6 +175,65 @@ def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
165175 _toggle_miniexpr (old_flag )
166176
167177
178+ @pytest .mark .parametrize ("dtype" , [np .float32 , np .float64 ])
179+ def test_matmul_falls_back_for_misaligned_blocks (monkeypatch , dtype ):
180+ old_flag = utils_mod .try_miniexpr
181+ calls = _set_pref_matmul_call_recorder (monkeypatch )
182+ try :
183+ _toggle_miniexpr (True )
184+ a = blosc2 .ones (shape = (400 , 400 ), dtype = dtype , chunks = (200 , 200 ), blocks = (120 , 100 ))
185+ b = blosc2 .full (shape = (400 , 400 ), fill_value = 2 , dtype = dtype , chunks = (200 , 200 ), blocks = (100 , 100 ))
186+
187+ with warnings .catch_warnings ():
188+ warnings .simplefilter ("ignore" , RuntimeWarning )
189+ c = blosc2 .matmul (a , b , chunks = (200 , 200 ), blocks = (120 , 100 ))
190+ expected = np .matmul (a [:], b [:])
191+
192+ assert calls == []
193+ np .testing .assert_allclose (c [:], expected , rtol = 1e-6 , atol = 1e-6 )
194+ finally :
195+ _toggle_miniexpr (old_flag )
196+
197+
198+ def test_matmul_falls_back_for_dtype_mismatch (monkeypatch ):
199+ old_flag = utils_mod .try_miniexpr
200+ calls = _set_pref_matmul_call_recorder (monkeypatch )
201+ try :
202+ _toggle_miniexpr (True )
203+ a = blosc2 .ones (shape = (200 , 200 ), dtype = np .float32 , chunks = (100 , 100 ), blocks = (50 , 50 ))
204+ b = blosc2 .full (shape = (200 , 200 ), fill_value = 2 , dtype = np .float64 , chunks = (100 , 100 ), blocks = (50 , 50 ))
205+
206+ with warnings .catch_warnings ():
207+ warnings .simplefilter ("ignore" , RuntimeWarning )
208+ c = blosc2 .matmul (a , b , chunks = (100 , 100 ), blocks = (50 , 50 ))
209+ expected = np .matmul (a [:], b [:])
210+
211+ assert calls == []
212+ np .testing .assert_allclose (c [:], expected , rtol = 1e-6 , atol = 1e-6 )
213+ finally :
214+ _toggle_miniexpr (old_flag )
215+
216+
217+ @pytest .mark .parametrize ("dtype" , [np .complex64 , np .complex128 ])
218+ def test_matmul_complex_falls_back_to_chunked (monkeypatch , dtype ):
219+ old_flag = utils_mod .try_miniexpr
220+ calls = _set_pref_matmul_call_recorder (monkeypatch )
221+ try :
222+ _toggle_miniexpr (True )
223+ a = blosc2 .asarray (np .ones ((100 , 100 ), dtype = dtype ))
224+ b = blosc2 .asarray (np .full ((100 , 100 ), 2 + 0j , dtype = dtype ))
225+
226+ with warnings .catch_warnings ():
227+ warnings .simplefilter ("ignore" , RuntimeWarning )
228+ c = blosc2 .matmul (a , b , chunks = (50 , 50 ), blocks = (25 , 25 ))
229+ expected = np .matmul (a [:], b [:])
230+
231+ assert calls == []
232+ np .testing .assert_allclose (c [:], expected , rtol = 1e-6 , atol = 1e-6 )
233+ finally :
234+ _toggle_miniexpr (old_flag )
235+
236+
168237def test_matmul_fast_path_failure_falls_back (monkeypatch ):
169238 old_flag = utils_mod .try_miniexpr
170239
0 commit comments