@@ -115,61 +115,73 @@ def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArra
115115 # multithreaded matmul
116116 # TODO: handle a) type promotion, b) padding (explicitly), c) (improved) >2D
117117 ops = (x1 , x2 , result )
118- blocks = result .blocks
119118 all_ndarray = all (isinstance (value , blosc2 .NDArray ) and value .shape != () for value in ops )
120119 global try_miniexpr
121120
122121 # Use a local copy so we don't modify the global
123122 use_miniexpr = try_miniexpr
124- if all_ndarray :
125- if any (op .dtype != ops [0 ].dtype for op in ops ): # TODO: Remove this condition
123+ if 0 not in result .shape + x1 .shape + x2 .shape : # if any array is empty, return array of 0s
124+ if all_ndarray :
125+ if any (op .dtype != ops [0 ].dtype for op in ops ): # TODO: Remove this condition
126+ use_miniexpr = False
127+
128+ # Just force same chunk/block shapes
129+ same_chunks = all (op .chunks == result .chunks for op in (x1 , x2 ))
130+ same_blocks = all (op .blocks == result .blocks for op in (x1 , x2 ))
131+ same_shape = all (op .shape == result .shape for op in (x1 , x2 ))
132+
133+ use_miniexpr &= same_blocks & same_chunks & same_shape
134+
135+ # TODO: We can relax this to even just load according to result blockshape, but that's difficult.
136+ # Two easier cases are presented below
137+ # Case 1: Might want to restrict loading across chunk boundaries, in which case would require:
138+ # x1.chunks[-2] % result.blocks[-2] == 0
139+ # x2.chunks[-1] % result.blocks[-1] == 0
140+ # x2.chunks[-2] % x1.blocks[-1] == 0
141+ # Can then load in x1 as slices of size [result.blocks[-2], x1.blocks[-1]]
142+ # and x2 in slices of [x1.blocks[-1], result.blocks[-1]]
143+
144+ # Case 2: Slightly easier to implement this maybe
145+ # Require that blocks are matmul compatible and broadcastable directly to result
146+ # (M, K) x (K, N) = (M, N)
147+ # so can load block-by-block for inputs and calculate block of output
148+ # Also need to avoid loading across chunk boundaries
149+ # chunks_aligned = x1.chunks[-2] % x1.blocks[-2] == 0
150+ # chunks_aligned &= x2.chunks[-1] % x2.blocks[-1] == 0
151+ # chunks_aligned &= x2.chunks[-2] % x1.blocks[-1] == 0
152+ # same_blocks = x2.blocks[-2] == x1.blocks[-1]
153+ # same_blocks &= x2.blocks[-1] == result.blocks[-1]
154+ # same_blocks &= result.blocks[-2] == x1.blocks[-2]
155+ # try:
156+ # result_blocks = np.broadcast_shapes(x1.blocks, x2.blocks)
157+ # if not (same_blocks and chunks_aligned and result_blocks[:-2] == result.blocks[:-2]):
158+ # use_miniexpr = False
159+ # except ValueError:
160+ # use_miniexpr = False
161+
162+ use_miniexpr &= x1 .dtype .kind in ("i" , "f" )
163+ use_miniexpr &= x2 .dtype .kind in ("i" , "f" )
164+ use_miniexpr &= x1 .dtype == x2 .dtype
165+
166+ else :
126167 use_miniexpr = False
127168
128- # TODO: In fact the following can be relaxed too, just need to load across block boundaries
129- # Might want to restrict loading across chunk boundaries, in which case would require:
130- # x1.chunks[-2] % result.blocks[-2] == 0
131- # x2.chunks[-1] % result.blocks[-1] == 0
132- # x2.chunks[-2] % x1.blocks[-1] == 0
133- # Can then load in x1 as slices of size [result.blocks[-2], x1.blocks[-1]]
134- # and x2 in slices of [x1.blocks[-1], result.blocks[-1]]
135-
136- # Require that blocks are matmul compatible and broadcastable directly to result
137- # (M, K) x (K, N) = (M, N)
138- # so can load block-by-block for inputs and calculate block of output
139- # Also need to avoid loading across chunk boundaries
140- chunks_aligned = x1 .chunks [- 2 ] % x1 .blocks [- 2 ] == 0
141- chunks_aligned &= x2 .chunks [- 1 ] % x2 .blocks [- 1 ] == 0
142- chunks_aligned &= x2 .chunks [- 2 ] % x1 .blocks [- 1 ] == 0
143- same_blocks = x2 .blocks [- 2 ] == x1 .blocks [- 1 ]
144- same_blocks &= x2 .blocks [- 1 ] == result .blocks [- 1 ]
145- same_blocks &= result .blocks [- 2 ] == x1 .blocks [- 2 ]
146- try :
147- result_blocks = np .broadcast_shapes (x1 .blocks , x2 .blocks )
148- except ValueError :
149- use_miniexpr = False
150- if not (same_blocks and chunks_aligned and result_blocks [:- 2 ] == blocks [:- 2 ]):
151- use_miniexpr = False
152-
153- else :
154- use_miniexpr = False
155-
156- if use_miniexpr :
157- prefilter_set = False
158- try :
159- result ._set_pref_matmul ({"x1" : x1 , "x2" : x2 }, fp_accuracy = blosc2 .FPAccuracy .DEFAULT )
160- prefilter_set = True
161- # Data to compress is fetched from operands, so it can be uninitialized here
162- data = np .empty (result .schunk .chunksize , dtype = np .uint8 )
163- for nchunk_out in range (result .schunk .nchunks ):
164- result .schunk .update_data (nchunk_out , data , copy = False )
165- except Exception as e :
166- raise Exception from e
167- finally :
168- if prefilter_set :
169- result .schunk .remove_prefilter ("miniexpr" )
170- else : # couldn't do multithreading
171- print ("multithreading failed :( " )
172- if 0 not in result .shape + x1 .shape + x2 .shape : # if any array is empty, return array of 0s
169+ if use_miniexpr :
170+ prefilter_set = False
171+ try :
172+ result ._set_pref_matmul ({"x1" : x1 , "x2" : x2 }, fp_accuracy = blosc2 .FPAccuracy .DEFAULT )
173+ prefilter_set = True
174+ # Data to compress is fetched from operands, so it can be uninitialized here
175+ data = np .empty (result .schunk .chunksize , dtype = np .uint8 )
176+ for nchunk_out in range (result .schunk .nchunks ):
177+ result .schunk .update_data (nchunk_out , data , copy = False )
178+ except Exception as e :
179+ raise Exception from e
180+ finally :
181+ if prefilter_set :
182+ result .schunk .remove_prefilter ("miniexpr" )
183+ else : # couldn't do multithreading
184+ print ("multithreading failed :( " )
173185 p , q = result .chunks [- 2 :]
174186 r = x2 .chunks [- 1 ]
175187
0 commit comments