@@ -674,6 +674,13 @@ ctypedef struct me_udata:
674674 int64_t blocks_in_chunk[B2ND_MAX_DIM]
675675 me_expr* miniexpr_handle
676676
# User data attached to the matmul prefilter (read by aux_matmul).
# Each stride table has three rows; aux_matmul reads row 0 together with the
# output chunk/block/batch index, and rows 1 and 2 to build the matching
# indices/offsets of the first and second operand respectively.
677+ ctypedef struct mm_udata:
# Array of operand pointers; _fill_mm_udata allocates room for two entries
# and stores the 'x1' operand at index 0 and the 'x2' operand at index 1.
678+ b2nd_array_t** inputs
# The output array (set from self.array by _fill_mm_udata).
679+ b2nd_array_t* array
# Per-dimension strides used to split a flat chunk number into coordinates.
680+ int64_t chunks_strides[3 ][B2ND_MAX_DIM]
# Per-dimension strides used to split a flat block number into coordinates.
681+ int64_t blocks_strides[3 ][B2ND_MAX_DIM]
# Per-dimension strides used in aux_matmul's batch loop to compute the
# element offsets into the operand buffers and the output buffer.
682+ int64_t el_strides[3 ][B2ND_MAX_DIM]
683+
677684MAX_TYPESIZE = BLOSC2_MAXTYPESIZE
678685MAX_BUFFERSIZE = BLOSC2_MAX_BUFFERSIZE
679686MAX_BLOCKSIZE = BLOSC2_MAXBLOCKSIZE
@@ -2161,12 +2168,12 @@ cdef int matmul_block_kernel(T* A, T* B, T* C, int M, int K, int N) nogil:
21612168 C[rowC + c] += < T> (a * B[rowB + c])
21622169 return 0
21632170
2164- cdef int aux_matmul(me_udata * udata, int64_t nchunk, int32_t nblock, void * params_output, int32_t typesize, int typecode) nogil:
2171+ cdef int aux_matmul(mm_udata * udata, int64_t nchunk, int32_t nblock, void * params_output, int32_t typesize, int typecode) nogil:
21652172 # Declare all C variables at the beginning
21662173 cdef b2nd_array_t* out_arr
21672174 cdef b2nd_array_t* ndarr
21682175 cdef c_bool first_run
2169- cdef int rc, M, K, N
2176+ cdef int rc, p, q, r
21702177 cdef void ** input_buffers = < void ** > malloc(2 * sizeof(uint8_t* ))
21712178 cdef uint8_t** src = < uint8_t** > malloc(2 * sizeof(uint8_t* ))
21722179 cdef int32_t chunk_nbytes[2 ]
@@ -2175,42 +2182,56 @@ cdef int aux_matmul(me_udata *udata, int64_t nchunk, int32_t nblock, void *param
21752182 cdef int blocknitems[2 ]
21762183 cdef int startA, startB, expected_blocknitems
21772184 cdef blosc2_context* dctx
2178- cdef int base, i, j, nchunkA, nchunkB, nblockA, nblockB, chunk_startA, chunk_startB, block_base, block_i, block_j, block_startA, block_startB, idx, chunk_idx, block_ncols, block_nrows, nblocks_per_2d
2179-
2185+ cdef int i, j, block_i, block_j, ncols, block_ncols, Bblock_ncols, Bncols
2186+ cdef int nchunkA = 0 , nchunkB = 0 , nblockA = 0 , nblockB = 0 , offsetA = 0 , offsetB = 0 , offset = 0
21802187 out_arr = udata.array
21812188 cdef int ndim = out_arr.ndim
2182- cdef int ncols = < int > udata.chunks_in_array[ndim - 1 ]
2183- cdef int nrows = < int > udata.chunks_in_array[ndim - 2 ]
2184- cdef int nchunks_per_2d = ncols * nrows
2185-
2186- block_ncols = < int > udata.blocks_in_chunk[ndim - 1 ]
2187- block_nrows = < int > udata.blocks_in_chunk[ndim - 2 ]
2188- nblocks_per_2d = block_ncols * block_nrows
2189-
2190- # nchunk = base * nchunks_per2d + i * ncols + j
2191- base = nchunk // nchunks_per_2d
2192- i = (nchunk % nchunks_per_2d) // ncols
2193- j = nchunk % ncols
2194- nchunkA = chunk_startA = nchunk - j
2195- nchunkB = chunk_startB = nchunk - i * ncols
2196-
2197- # nblock = block_base * nblocks_per_2d + block_i * block_ncols + block_j
2198- block_base = nblock // nblocks_per_2d
2199- block_i = (nblock % nblocks_per_2d) // block_ncols
2200- block_j = nblock % block_ncols
2201- block_startA = nblock - block_j
2202- block_startB = nblock - block_i * block_ncols
2189+ cdef int nchunk_ = nchunk
2190+ cdef int coord, batch, batch_, batches = 1
2191+ for i in range (ndim - 2 ):
2192+ batches *= out_arr.shape[i]
2193+
2194+ # nchunk = sum(strides[i]*chunkcoords[i])
2195+ for i in range (ndim - 2 ):
2196+ coord = nchunk_ // udata.chunks_strides[0 ][i]
2197+ nchunk_ = nchunk_ % udata.chunks_strides[0 ][i]
2198+ nchunkA += coord * udata.chunks_strides[1 ][i]
2199+ nchunkB += coord * udata.chunks_strides[2 ][i]
2200+
2201+ ncols = udata.chunks_strides[0 ][ndim - 2 ]
2202+ Bncols = udata.chunks_strides[2 ][ndim - 2 ]
2203+
2204+ i = nchunk_ // ncols # ncols * i + j
2205+ j = nchunk_ % ncols
2206+ nchunkA = chunk_startA = nchunkA + i * ncols
2207+ nchunkB = chunk_startB = nchunkB + j
2208+
2209+ # nblock = sum(strides[i]*blockcoords[i])
2210+ cdef int nblock_ = nblock
2211+ for i in range (ndim - 2 ):
2212+ coord = nblock_ // udata.blocks_strides[0 ][i]
2213+ nblock_ = nblock_ % udata.blocks_strides[0 ][i]
2214+ nblockA += coord * udata.blocks_strides[1 ][i]
2215+ nblockB += coord * udata.blocks_strides[2 ][i]
2216+
2217+ block_ncols = udata.blocks_strides[0 ][ndim - 2 ]
2218+ Bblock_ncols = udata.blocks_strides[2 ][ndim - 2 ]
2219+
2220+ block_i = nblock_ // block_ncols
2221+ block_j = nblock_ % block_ncols
2222+ block_startA = nblockA = nblockA + i * block_ncols
2223+ block_startB = nblockB = nblockB + j
2224+
2225+ # batches = sum(strides[i]*elcoords[i])
22032226 dctx = blosc2_create_dctx(BLOSC2_DPARAMS_DEFAULTS)
22042227
22052228 first_run = True
22062229
22072230 while True : # chunk loop
2208- printf(" chunks: %i , %i \n " , nchunkA, nchunkB)
2209- nblockA = block_startA
2210- nblockB = block_startB
22112231 for i in range (2 ):
22122232 chunk_idx = nchunkA if i == 0 else nchunkB
22132233 ndarr = udata.inputs[i]
2234+ ndim = ndarr.ndim
22142235 src[i] = ndarr.sc.data[chunk_idx]
22152236 rc = blosc2_cbuffer_sizes(src[i], & chunk_nbytes[i], & chunk_cbytes[i], & block_nbytes[i])
22162237 if rc < 0 :
@@ -2219,10 +2240,10 @@ cdef int aux_matmul(me_udata *udata, int64_t nchunk, int32_t nblock, void *param
22192240 raise ValueError (" miniexpr: invalid block size" )
22202241 if first_run:
22212242 if i == 0 :
2222- K = ndarr.blockshape[ndim - 1 ]
2223- M = ndarr.blockshape[ndim - 2 ]
2243+ q = ndarr.blockshape[ndim - 1 ]
2244+ p = ndarr.blockshape[ndim - 2 ]
22242245 else : # i = 1
2225- N = ndarr.blockshape[ndim - 1 ]
2246+ r = ndarr.blockshape[ndim - 1 ]
22262247 input_buffers[i] = malloc(block_nbytes[i])
22272248 if input_buffers[i] == NULL :
22282249 raise MemoryError (" miniexpr: cannot allocate input block buffer" )
@@ -2233,8 +2254,9 @@ cdef int aux_matmul(me_udata *udata, int64_t nchunk, int32_t nblock, void *param
22332254 raise ValueError (" miniexpr: inconsistent block element counts across inputs" )
22342255
22352256 first_run = False
2257+ nblockA = block_startA
2258+ nblockB = block_startB
22362259 while True : # block loop
2237- printf(" blocks: %i , %i \n " , nblockA, nblockB)
22382260 startA = nblockA * blocknitems[0 ]
22392261 startB = nblockB * blocknitems[1 ]
22402262 rc = blosc2_getitem_ctx(dctx, src[0 ], chunk_cbytes[0 ], startA, blocknitems[0 ],
@@ -2245,28 +2267,37 @@ cdef int aux_matmul(me_udata *udata, int64_t nchunk, int32_t nblock, void *param
22452267 input_buffers[1 ], block_nbytes[1 ])
22462268 if rc < 0 :
22472269 raise ValueError (" matmul: error decompressing the B chunk" )
2248- if typecode == 0 :
2249- if typesize == 4 :
2250- rc = matmul_block_kernel[float ](< float * > input_buffers[0 ], < float * > input_buffers[1 ], < float * > params_output, M, K, N)
2251- else :
2252- rc = matmul_block_kernel[double ](< double * > input_buffers[0 ], < double * > input_buffers[1 ], < double * > params_output, M, K, N)
2253- elif typecode == 1 :
2254- if typesize == 4 :
2255- rc = matmul_block_kernel[int32_t](< int32_t* > input_buffers[0 ], < int32_t* > input_buffers[1 ], < int32_t* > params_output, M, K, N)
2270+ batch = 0
2271+ while batch < batches:
2272+ batch_ = batch
2273+ for i in range (ndim - 2 ):
2274+ coord = batch // udata.el_strides[0 ][i]
2275+ batch_ = batch_ % udata.el_strides[0 ][i]
2276+ offsetA += coord * udata.el_strides[1 ][i]
2277+ offsetB += coord * udata.el_strides[2 ][i]
2278+ offset += coord * udata.el_strides[0 ][i]
2279+ if typecode == 0 :
2280+ if typesize == 4 :
2281+ rc = matmul_block_kernel[float ](< float * > input_buffers[0 ] + offsetA, < float * > input_buffers[1 ] + offsetB, < float * > params_output + offset, p, q, r)
2282+ else :
2283+ rc = matmul_block_kernel[double ](< double * > input_buffers[0 ] + offsetA, < double * > input_buffers[1 ] + offsetB, < double * > params_output + offset, p, q, r)
2284+ elif typecode == 1 :
2285+ if typesize == 4 :
2286+ rc = matmul_block_kernel[int32_t](< int32_t* > input_buffers[0 ] + offsetA, < int32_t* > input_buffers[1 ] + offsetB, < int32_t* > params_output + offset, p, q, r)
2287+ else :
2288+ rc = matmul_block_kernel[int64_t](< int64_t* > input_buffers[0 ] + offsetA, < int64_t* > input_buffers[1 ] + offsetB, < int64_t* > params_output + offset, p, q, r)
22562289 else :
2257- rc = matmul_block_kernel[int64_t](< int64_t* > input_buffers[0 ], < int64_t* > input_buffers[1 ], < int64_t* > params_output, M, K, N)
2258- else :
2259- with gil:
2260- raise ValueError (" Unsupported dtype" )
2290+ with gil:
2291+ raise ValueError (" Unsupported dtype" )
2292+ batch += 1
22612293 nblockA += 1
2262- nblockB += block_ncols
2294+ nblockB += Bblock_ncols
22632295 if (nblockA % block_ncols == 0 ):
22642296 break
22652297 nchunkA += 1
2266- nchunkB += ncols
2298+ nchunkB += Bncols
22672299 if (nchunkA % ncols == 0 ):
22682300 break
2269- printf(" finished block %i for chunk %i \n " , nblock, nchunk)
22702301
22712302
22722303 blosc2_free_ctx(dctx)
@@ -2358,7 +2389,7 @@ cdef int miniexpr_prefilter(blosc2_prefilter_params *params):
23582389cdef int matmul_prefilter(blosc2_prefilter_params * params):
23592390 cdef int typecode
23602391
2361- cdef me_udata * udata = < me_udata * > params.user_data
2392+ cdef mm_udata * udata = < mm_udata * > params.user_data
23622393 cdef b2nd_array_t* out_arr = udata.array
23632394 cdef char dtype_kind = out_arr.dtype[1 ]
23642395 if dtype_kind == ' f' :
@@ -3215,6 +3246,48 @@ cdef class NDArray:
32153246
32163247 return udata
32173248
cdef mm_udata* _fill_mm_udata(self, inputs):
    # Build the user-data block consumed by the matmul prefilter.
    #
    # Parameters
    #   inputs: mapping with the two operands under the keys 'x1' and 'x2';
    #           each operand must expose its b2nd_array_t address as `c_array`.
    #
    # Returns
    #   A heap-allocated mm_udata (the caller owns it and its `inputs` array).
    #
    # Layout of the stride tables (all indexed by OUTPUT dimension):
    #   row 0 -> output array, row 1 -> operand x1, row 2 -> operand x2.
    # A stride of 0 marks a batch dimension along which an operand is
    # broadcast (missing in the operand, or of length 1).
    #
    # Raises MemoryError if an allocation fails.
    cdef int ndim = self.array.ndim
    cdef int row, idx, out_dim, in_dim
    cdef int64_t cstride, bstride, estride
    cdef b2nd_array_t* inp
    cdef b2nd_array_t** inputs_
    cdef mm_udata* udata

    # Resolve the operands before allocating so a bad `inputs` cannot leak.
    c_arrays = (inputs['x1'].c_array, inputs['x2'].c_array)

    # calloc: every stride slot that is not written below reads as 0 instead
    # of uninitialised memory.
    udata = <mm_udata*> calloc(1, sizeof(mm_udata))
    if udata == NULL:
        raise MemoryError("matmul: cannot allocate prefilter user data")
    inputs_ = <b2nd_array_t**> malloc(2 * sizeof(b2nd_array_t*))
    if inputs_ == NULL:
        free(udata)
        raise MemoryError("matmul: cannot allocate prefilter inputs")

    for row in range(2):
        inp = <b2nd_array_t*><uintptr_t> c_arrays[row]
        # Invalidate the operand's chunk cache.
        inp.chunk_cache.nchunk = -1
        inp.chunk_cache.data = NULL
        inputs_[row] = inp
    udata.inputs = inputs_
    udata.array = self.array

    # Save these in udata to avoid computing them for each block.
    # The innermost dimension always has unit stride.
    if ndim > 0:
        for row in range(3):
            udata.chunks_strides[row][ndim - 1] = 1
            udata.blocks_strides[row][ndim - 1] = 1
            udata.el_strides[row][ndim - 1] = 1

    # Row 0: strides of the output array, accumulated from the innermost
    # dimension outwards.
    for out_dim in range(ndim - 2, -1, -1):
        udata.chunks_strides[0][out_dim] = udata.chunks_strides[0][out_dim + 1] * udata.array.extshape[out_dim + 1] // udata.array.chunkshape[out_dim + 1]
        udata.blocks_strides[0][out_dim] = udata.blocks_strides[0][out_dim + 1] * udata.array.extchunkshape[out_dim + 1] // udata.array.blockshape[out_dim + 1]
        udata.el_strides[0][out_dim] = udata.el_strides[0][out_dim + 1] * udata.array.blockshape[out_dim + 1]

    # Rows 1 and 2: strides of the operands, right-aligned with the output
    # dimensions so that an operand with fewer dimensions is broadcast.
    for row in range(1, 3):
        inp = inputs_[row - 1]
        cstride = bstride = estride = 1
        for idx in range(2, ndim + 1):
            out_dim = ndim - idx
            in_dim = inp.ndim - idx
            if in_dim < 0:
                # The operand has no such dimension: broadcast along it.
                udata.chunks_strides[row][out_dim] = 0
                udata.blocks_strides[row][out_dim] = 0
                udata.el_strides[row][out_dim] = 0
                continue
            # The true stride of in_dim is the product of the extents of all
            # inner dimensions; keep accumulating even across broadcast dims.
            cstride *= inp.extshape[in_dim + 1] // inp.chunkshape[in_dim + 1]
            bstride *= inp.extchunkshape[in_dim + 1] // inp.blockshape[in_dim + 1]
            estride *= inp.blockshape[in_dim + 1]
            if idx > 2 and inp.shape[in_dim] == 1:
                # Length-1 batch dimension: broadcast. The two matrix
                # dimensions (idx <= 2) are never broadcast.
                udata.chunks_strides[row][out_dim] = 0
                udata.blocks_strides[row][out_dim] = 0
                udata.el_strides[row][out_dim] = 0
            else:
                udata.chunks_strides[row][out_dim] = cstride
                udata.blocks_strides[row][out_dim] = bstride
                udata.el_strides[row][out_dim] = estride

    return udata
3290+
32183291 def _set_pref_expr (self , expression , inputs , fp_accuracy , aux_reduc = None , jit = None ):
32193292 # Set prefilter for miniexpr
32203293 cdef blosc2_cparams* cparams = self .array.sc.storage.cparams
@@ -3305,7 +3378,7 @@ cdef class NDArray:
33053378 cdef blosc2_cparams* cparams = self .array.sc.storage.cparams
33063379 cparams.prefilter = < blosc2_prefilter_fn> matmul_prefilter
33073380
3308- cdef me_udata * udata = self ._fill_me_udata (inputs, fp_accuracy, aux_reduc = None )
3381+ cdef mm_udata * udata = self ._fill_mm_udata (inputs)
33093382 cdef b2nd_array_t* out_arr = udata.array
33103383 cdef blosc2_prefilter_params* preparams = < blosc2_prefilter_params * > calloc(1 , sizeof(blosc2_prefilter_params))
33113384 preparams.user_data = udata
0 commit comments