2323 from collections .abc import Sequence
2424
2525
26- def matmul (x1 : blosc2 .Array , x2 : blosc2 .NDArray , ** kwargs : Any ) -> blosc2 .NDArray : # noqa: C901
26+ def _matmul_chunked (
27+ x1 : blosc2 .Array , x2 : blosc2 .NDArray , result : blosc2 .NDArray , n : int , m : int , k : int
28+ ) -> None :
29+ p , q = result .chunks [- 2 :]
30+ r = x2 .chunks [- 1 ]
31+
32+ intersecting_chunks = get_intersecting_chunks ((), result .shape [:- 2 ], result .chunks [:- 2 ])
33+ for chunk in intersecting_chunks :
34+ chunk = chunk .raw
35+ for row in range (0 , n , p ):
36+ row_end = builtins .min (row + p , n )
37+ for col in range (0 , m , q ):
38+ col_end = builtins .min (col + q , m )
39+ for aux in range (0 , k , r ):
40+ aux_end = builtins .min (aux + r , k )
41+ bx1 = (
42+ x1 [chunk [- x1 .ndim + 2 :] + (slice (row , row_end ), slice (aux , aux_end ))]
43+ if x1 .ndim > 2
44+ else x1 [row :row_end , aux :aux_end ]
45+ )
46+ bx2 = (
47+ x2 [chunk [- x2 .ndim + 2 :] + (slice (aux , aux_end ), slice (col , col_end ))]
48+ if x2 .ndim > 2
49+ else x2 [aux :aux_end , col :col_end ]
50+ )
51+ result [chunk + (slice (row , row_end ), slice (col , col_end ))] += np .matmul (bx1 , bx2 )
52+
53+
54+ def _matmul_can_use_fast_path (
55+ x1 : blosc2 .Array , x2 : blosc2 .NDArray , result : blosc2 .NDArray , use_miniexpr : bool
56+ ) -> bool :
57+ if not use_miniexpr :
58+ return False
59+
60+ ops = (x1 , x2 , result )
61+ all_ndarray = all (isinstance (value , blosc2 .NDArray ) and value .shape != () for value in ops )
62+ if not all_ndarray :
63+ return False
64+
65+ # The current prefilter-backed implementation is only supported for 2-D layouts.
66+ if result .ndim != 2 or x1 .ndim != 2 or x2 .ndim != 2 :
67+ return False
68+
69+ if any (op .dtype != ops [0 ].dtype for op in ops ):
70+ return False
71+
72+ chunks_aligned = x1 .chunks [- 2 ] % x1 .blocks [- 2 ] == 0
73+ chunks_aligned &= x2 .chunks [- 1 ] % x2 .blocks [- 1 ] == 0
74+ chunks_aligned &= x2 .chunks [- 2 ] % x1 .blocks [- 1 ] == 0
75+ if not chunks_aligned :
76+ return False
77+
78+ same_blocks = x2 .blocks [- 2 ] == x1 .blocks [- 1 ]
79+ same_blocks &= x2 .blocks [- 1 ] == result .blocks [- 1 ]
80+ same_blocks &= result .blocks [- 2 ] == x1 .blocks [- 2 ]
81+ if not same_blocks :
82+ return False
83+
84+ try :
85+ result_blocks = np .broadcast_shapes (x1 .blocks , x2 .blocks )
86+ except ValueError :
87+ return False
88+ if result_blocks [:- 2 ] != result .blocks [:- 2 ]:
89+ return False
90+
91+ if x1 .dtype .kind not in ("i" , "f" ):
92+ return False
93+ if x2 .dtype .kind not in ("i" , "f" ):
94+ return False
95+ return x1 .dtype == x2 .dtype
96+
97+
98+ def matmul (x1 : blosc2 .Array , x2 : blosc2 .NDArray , ** kwargs : Any ) -> blosc2 .NDArray :
2799 """
28100 Computes the matrix product between two Blosc2 NDArrays.
29101
@@ -112,60 +184,10 @@ def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArra
112184 kwargs ["_chunksize_reduc_factor" ] = 1
113185 result = blosc2 .zeros (result_shape , dtype = blosc2 .result_type (x1 , x2 ), ** kwargs )
114186
115- # multithreaded matmul
116- # TODO: handle a) type promotion, b) padding (explicitly), c) (improved) >2D
117- ops = (x1 , x2 , result )
118- all_ndarray = all (isinstance (value , blosc2 .NDArray ) and value .shape != () for value in ops )
119187 global try_miniexpr
120188
121- # Use a local copy so we don't modify the global
122- use_miniexpr = try_miniexpr
123189 if 0 not in result .shape + x1 .shape + x2 .shape : # if any array is empty, return array of 0s
124- if all_ndarray :
125- if any (op .dtype != ops [0 ].dtype for op in ops ): # TODO: Remove this condition
126- use_miniexpr = False
127-
128- # TODO: We can relax this to even just load according to result blockshape, but that's difficult.
129- # Just force same chunk/block shapes
130- # same_chunks = all(op.chunks == result.chunks for op in (x1, x2))
131- # same_blocks = all(op.blocks == result.blocks for op in (x1, x2))
132- # same_shape = all(op.shape == result.shape for op in (x1, x2))
133-
134- # use_miniexpr &= same_blocks & same_chunks & same_shape
135- # Two easier cases are presented below
136- # Case 1: Might want to restrict loading across chunk boundaries, in which case would require:
137- # x1.chunks[-2] % result.blocks[-2] == 0
138- # x2.chunks[-1] % result.blocks[-1] == 0
139- # x2.chunks[-2] % x1.blocks[-1] == 0
140- # Can then load in x1 as slices of size [result.blocks[-2], x1.blocks[-1]]
141- # and x2 in slices of [x1.blocks[-1], result.blocks[-1]]
142-
143- # Case 2: Slightly easier to implement this maybe
144- # Require that blocks are matmul compatible and broadcastable directly to result
145- # (M, K) x (K, N) = (M, N)
146- # so can load block-by-block for inputs and calculate block of output
147- # Also need to avoid loading across chunk boundaries
148- chunks_aligned = x1 .chunks [- 2 ] % x1 .blocks [- 2 ] == 0
149- chunks_aligned &= x2 .chunks [- 1 ] % x2 .blocks [- 1 ] == 0
150- chunks_aligned &= x2 .chunks [- 2 ] % x1 .blocks [- 1 ] == 0
151- same_blocks = x2 .blocks [- 2 ] == x1 .blocks [- 1 ]
152- same_blocks &= x2 .blocks [- 1 ] == result .blocks [- 1 ]
153- same_blocks &= result .blocks [- 2 ] == x1 .blocks [- 2 ]
154- try :
155- result_blocks = np .broadcast_shapes (x1 .blocks , x2 .blocks )
156- if not (same_blocks and chunks_aligned and result_blocks [:- 2 ] == result .blocks [:- 2 ]):
157- use_miniexpr = False
158- except ValueError :
159- use_miniexpr = False
160-
161- use_miniexpr &= x1 .dtype .kind in ("i" , "f" )
162- use_miniexpr &= x2 .dtype .kind in ("i" , "f" )
163- use_miniexpr &= x1 .dtype == x2 .dtype
164-
165- else :
166- use_miniexpr = False
167-
168- if use_miniexpr :
190+ if _matmul_can_use_fast_path (x1 , x2 , result , try_miniexpr ):
169191 prefilter_set = False
170192 try :
171193 result ._set_pref_matmul ({"x1" : x1 , "x2" : x2 }, fp_accuracy = blosc2 .FPAccuracy .DEFAULT )
@@ -174,36 +196,16 @@ def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArra
174196 data = np .empty (result .schunk .chunksize , dtype = np .uint8 )
175197 for nchunk_out in range (result .schunk .nchunks ):
176198 result .schunk .update_data (nchunk_out , data , copy = False )
177- except Exception as e :
178- raise Exception from e
199+ except Exception as exc :
200+ warnings .warn (
201+ f"Fast matmul path unavailable; falling back to chunked path: { exc } " , RuntimeWarning
202+ )
203+ _matmul_chunked (x1 , x2 , result , n , m , k )
179204 finally :
180205 if prefilter_set :
181206 result .schunk .remove_prefilter ("miniexpr" )
182- else : # couldn't do multithreading
183- print ("multithreading failed :( " )
184- p , q = result .chunks [- 2 :]
185- r = x2 .chunks [- 1 ]
186-
187- intersecting_chunks = get_intersecting_chunks ((), result .shape [:- 2 ], result .chunks [:- 2 ])
188- for chunk in intersecting_chunks :
189- chunk = chunk .raw
190- for row in range (0 , n , p ):
191- row_end = builtins .min (row + p , n )
192- for col in range (0 , m , q ):
193- col_end = builtins .min (col + q , m )
194- for aux in range (0 , k , r ):
195- aux_end = builtins .min (aux + r , k )
196- bx1 = (
197- x1 [chunk [- x1 .ndim + 2 :] + (slice (row , row_end ), slice (aux , aux_end ))]
198- if x1 .ndim > 2
199- else x1 [row :row_end , aux :aux_end ]
200- )
201- bx2 = (
202- x2 [chunk [- x2 .ndim + 2 :] + (slice (aux , aux_end ), slice (col , col_end ))]
203- if x2 .ndim > 2
204- else x2 [aux :aux_end , col :col_end ]
205- )
206- result [chunk + (slice (row , row_end ), slice (col , col_end ))] += np .matmul (bx1 , bx2 )
207+ else :
208+ _matmul_chunked (x1 , x2 , result , n , m , k )
207209
208210 if x1_is_vector :
209211 result = result .squeeze (axis = - 2 )
0 commit comments