@@ -688,6 +688,13 @@ ctypedef struct me_udata:
688688 int64_t blocks_in_chunk[B2ND_MAX_DIM]
689689 me_expr* miniexpr_handle
690690
691+ ctypedef struct mm_udata:
692+ b2nd_array_t** inputs
693+ b2nd_array_t* array
694+ int64_t chunks_strides[3 ][B2ND_MAX_DIM]
695+ int64_t blocks_strides[3 ][B2ND_MAX_DIM]
696+ int64_t el_strides[3 ][B2ND_MAX_DIM]
697+
691698MAX_TYPESIZE = BLOSC2_MAXTYPESIZE
692699MAX_BUFFERSIZE = BLOSC2_MAX_BUFFERSIZE
693700MAX_BLOCKSIZE = BLOSC2_MAXBLOCKSIZE
@@ -2348,12 +2355,12 @@ cdef int matmul_block_kernel(T* A, T* B, T* C, int M, int K, int N) nogil:
23482355 C[rowC + c] += < T> (a * B[rowB + c])
23492356 return 0
23502357
2351- cdef int aux_matmul(me_udata * udata, int64_t nchunk, int32_t nblock, void * params_output, int32_t typesize, int typecode) nogil:
2358+ cdef int aux_matmul(mm_udata * udata, int64_t nchunk, int32_t nblock, void * params_output, int32_t typesize, int typecode) nogil:
23522359 # Declare all C variables at the beginning
23532360 cdef b2nd_array_t* out_arr
23542361 cdef b2nd_array_t* ndarr
23552362 cdef c_bool first_run
2356- cdef int rc, M, K, N
2363+ cdef int rc, p, q, r
23572364 cdef void ** input_buffers = < void ** > malloc(2 * sizeof(uint8_t* ))
23582365 cdef uint8_t** src = < uint8_t** > malloc(2 * sizeof(uint8_t* ))
23592366 cdef int32_t chunk_nbytes[2 ]
@@ -2362,42 +2369,56 @@ cdef int aux_matmul(me_udata *udata, int64_t nchunk, int32_t nblock, void *param
23622369 cdef int blocknitems[2 ]
23632370 cdef int startA, startB, expected_blocknitems
23642371 cdef blosc2_context* dctx
2365- cdef int base, i, j, nchunkA, nchunkB, nblockA, nblockB, chunk_startA, chunk_startB, block_base, block_i, block_j, block_startA, block_startB, idx, chunk_idx, block_ncols, block_nrows, nblocks_per_2d
2366-
2372+ cdef int i, j, block_i, block_j, ncols, block_ncols, Bblock_ncols, Bncols
2373+ cdef int nchunkA = 0 , nchunkB = 0 , nblockA = 0 , nblockB = 0 , offsetA = 0 , offsetB = 0 , offset = 0
23672374 out_arr = udata.array
23682375 cdef int ndim = out_arr.ndim
2369- cdef int ncols = < int > udata.chunks_in_array[ndim - 1 ]
2370- cdef int nrows = < int > udata.chunks_in_array[ndim - 2 ]
2371- cdef int nchunks_per_2d = ncols * nrows
2372-
2373- block_ncols = < int > udata.blocks_in_chunk[ndim - 1 ]
2374- block_nrows = < int > udata.blocks_in_chunk[ndim - 2 ]
2375- nblocks_per_2d = block_ncols * block_nrows
2376-
2377- # nchunk = base * nchunks_per2d + i * ncols + j
2378- base = nchunk // nchunks_per_2d
2379- i = (nchunk % nchunks_per_2d) // ncols
2380- j = nchunk % ncols
2381- nchunkA = chunk_startA = nchunk - j
2382- nchunkB = chunk_startB = nchunk - i * ncols
2383-
2384- # nblock = block_base * nblocks_per_2d + block_i * block_ncols + block_j
2385- block_base = nblock // nblocks_per_2d
2386- block_i = (nblock % nblocks_per_2d) // block_ncols
2387- block_j = nblock % block_ncols
2388- block_startA = nblock - block_j
2389- block_startB = nblock - block_i * block_ncols
2376+ cdef int nchunk_ = nchunk
2377+ cdef int coord, batch, batch_, batches = 1
2378+ for i in range (ndim - 2 ):
2379+ batches *= out_arr.shape[i]
2380+
2381+ # nchunk = sum(strides[i]*chunkcoords[i])
2382+ for i in range (ndim - 2 ):
2383+ coord = nchunk_ // udata.chunks_strides[0 ][i]
2384+ nchunk_ = nchunk_ % udata.chunks_strides[0 ][i]
2385+ nchunkA += coord * udata.chunks_strides[1 ][i]
2386+ nchunkB += coord * udata.chunks_strides[2 ][i]
2387+
2388+ ncols = udata.chunks_strides[0 ][ndim - 2 ]
2389+ Bncols = udata.chunks_strides[2 ][ndim - 2 ]
2390+
2391+ i = nchunk_ // ncols # ncols * i + j
2392+ j = nchunk_ % ncols
2393+ nchunkA = chunk_startA = nchunkA + i * ncols
2394+ nchunkB = chunk_startB = nchunkB + j
2395+
2396+ # nblock = sum(strides[i]*blockcoords[i])
2397+ cdef int nblock_ = nblock
2398+ for i in range (ndim - 2 ):
2399+ coord = nblock_ // udata.blocks_strides[0 ][i]
2400+ nblock_ = nblock_ % udata.blocks_strides[0 ][i]
2401+ nblockA += coord * udata.blocks_strides[1 ][i]
2402+ nblockB += coord * udata.blocks_strides[2 ][i]
2403+
2404+ block_ncols = udata.blocks_strides[0 ][ndim - 2 ]
2405+ Bblock_ncols = udata.blocks_strides[2 ][ndim - 2 ]
2406+
2407+ block_i = nblock_ // block_ncols
2408+ block_j = nblock_ % block_ncols
2409+ block_startA = nblockA = nblockA + i * block_ncols
2410+ block_startB = nblockB = nblockB + j
2411+
2412+ # batches = sum(strides[i]*elcoords[i])
23902413 dctx = blosc2_create_dctx(BLOSC2_DPARAMS_DEFAULTS)
23912414
23922415 first_run = True
23932416
23942417 while True : # chunk loop
2395- printf(" chunks: %i , %i \n " , nchunkA, nchunkB)
2396- nblockA = block_startA
2397- nblockB = block_startB
23982418 for i in range (2 ):
23992419 chunk_idx = nchunkA if i == 0 else nchunkB
24002420 ndarr = udata.inputs[i]
2421+ ndim = ndarr.ndim
24012422 src[i] = ndarr.sc.data[chunk_idx]
24022423 rc = blosc2_cbuffer_sizes(src[i], & chunk_nbytes[i], & chunk_cbytes[i], & block_nbytes[i])
24032424 if rc < 0 :
@@ -2406,10 +2427,10 @@ cdef int aux_matmul(me_udata *udata, int64_t nchunk, int32_t nblock, void *param
24062427 raise ValueError (" miniexpr: invalid block size" )
24072428 if first_run:
24082429 if i == 0 :
2409- K = ndarr.blockshape[ndim - 1 ]
2410- M = ndarr.blockshape[ndim - 2 ]
2430+ q = ndarr.blockshape[ndim - 1 ]
2431+ p = ndarr.blockshape[ndim - 2 ]
24112432 else : # i = 1
2412- N = ndarr.blockshape[ndim - 1 ]
2433+ r = ndarr.blockshape[ndim - 1 ]
24132434 input_buffers[i] = malloc(block_nbytes[i])
24142435 if input_buffers[i] == NULL :
24152436 raise MemoryError (" miniexpr: cannot allocate input block buffer" )
@@ -2420,8 +2441,9 @@ cdef int aux_matmul(me_udata *udata, int64_t nchunk, int32_t nblock, void *param
24202441 raise ValueError (" miniexpr: inconsistent block element counts across inputs" )
24212442
24222443 first_run = False
2444+ nblockA = block_startA
2445+ nblockB = block_startB
24232446 while True : # block loop
2424- printf(" blocks: %i , %i \n " , nblockA, nblockB)
24252447 startA = nblockA * blocknitems[0 ]
24262448 startB = nblockB * blocknitems[1 ]
24272449 rc = blosc2_getitem_ctx(dctx, src[0 ], chunk_cbytes[0 ], startA, blocknitems[0 ],
@@ -2432,28 +2454,37 @@ cdef int aux_matmul(me_udata *udata, int64_t nchunk, int32_t nblock, void *param
24322454 input_buffers[1 ], block_nbytes[1 ])
24332455 if rc < 0 :
24342456 raise ValueError (" matmul: error decompressing the B chunk" )
2435- if typecode == 0 :
2436- if typesize == 4 :
2437- rc = matmul_block_kernel[float ](< float * > input_buffers[0 ], < float * > input_buffers[1 ], < float * > params_output, M, K, N)
2438- else :
2439- rc = matmul_block_kernel[double ](< double * > input_buffers[0 ], < double * > input_buffers[1 ], < double * > params_output, M, K, N)
2440- elif typecode == 1 :
2441- if typesize == 4 :
2442- rc = matmul_block_kernel[int32_t](< int32_t* > input_buffers[0 ], < int32_t* > input_buffers[1 ], < int32_t* > params_output, M, K, N)
2457+ batch = 0
2458+ while batch < batches:
2459+ batch_ = batch
2460+ for i in range (ndim - 2 ):
2461+ coord = batch // udata.el_strides[0 ][i]
2462+ batch_ = batch_ % udata.el_strides[0 ][i]
2463+ offsetA += coord * udata.el_strides[1 ][i]
2464+ offsetB += coord * udata.el_strides[2 ][i]
2465+ offset += coord * udata.el_strides[0 ][i]
2466+ if typecode == 0 :
2467+ if typesize == 4 :
2468+ rc = matmul_block_kernel[float ](< float * > input_buffers[0 ] + offsetA, < float * > input_buffers[1 ] + offsetB, < float * > params_output + offset, p, q, r)
2469+ else :
2470+ rc = matmul_block_kernel[double ](< double * > input_buffers[0 ] + offsetA, < double * > input_buffers[1 ] + offsetB, < double * > params_output + offset, p, q, r)
2471+ elif typecode == 1 :
2472+ if typesize == 4 :
2473+ rc = matmul_block_kernel[int32_t](< int32_t* > input_buffers[0 ] + offsetA, < int32_t* > input_buffers[1 ] + offsetB, < int32_t* > params_output + offset, p, q, r)
2474+ else :
2475+ rc = matmul_block_kernel[int64_t](< int64_t* > input_buffers[0 ] + offsetA, < int64_t* > input_buffers[1 ] + offsetB, < int64_t* > params_output + offset, p, q, r)
24432476 else :
2444- rc = matmul_block_kernel[int64_t](< int64_t* > input_buffers[0 ], < int64_t* > input_buffers[1 ], < int64_t* > params_output, M, K, N)
2445- else :
2446- with gil:
2447- raise ValueError (" Unsupported dtype" )
2477+ with gil:
2478+ raise ValueError (" Unsupported dtype" )
2479+ batch += 1
24482480 nblockA += 1
2449- nblockB += block_ncols
2481+ nblockB += Bblock_ncols
24502482 if (nblockA % block_ncols == 0 ):
24512483 break
24522484 nchunkA += 1
2453- nchunkB += ncols
2485+ nchunkB += Bncols
24542486 if (nchunkA % ncols == 0 ):
24552487 break
2456- printf(" finished block %i for chunk %i \n " , nblock, nchunk)
24572488
24582489
24592490 blosc2_free_ctx(dctx)
@@ -2545,7 +2576,7 @@ cdef int miniexpr_prefilter(blosc2_prefilter_params *params):
25452576cdef int matmul_prefilter(blosc2_prefilter_params * params):
25462577 cdef int typecode
25472578
2548- cdef me_udata * udata = < me_udata * > params.user_data
2579+ cdef mm_udata * udata = < mm_udata * > params.user_data
25492580 cdef b2nd_array_t* out_arr = udata.array
25502581 cdef char dtype_kind = out_arr.dtype[1 ]
25512582 if dtype_kind == ' f' :
@@ -3407,6 +3438,48 @@ cdef class NDArray:
34073438
34083439 return udata
34093440
3441+ cdef mm_udata * _fill_mm_udata(self , inputs):
3442+ cdef mm_udata * udata = < mm_udata * > malloc(sizeof(mm_udata))
3443+ cdef int cstrides, bstrides, estrides
3444+ cdef b2nd_array_t* inp
3445+ cdef b2nd_array_t** inputs_ = < b2nd_array_t** > malloc(2 * sizeof(b2nd_array_t* ))
3446+ for i in range (2 ):
3447+ operand = inputs[' x1' ] if i == 0 else inputs[' x2' ]
3448+ inputs_[i] = < b2nd_array_t* >< uintptr_t> operand.c_array
3449+ inputs_[i].chunk_cache.nchunk = - 1
3450+ inputs_[i].chunk_cache.data = NULL
3451+ udata.inputs = inputs_
3452+ udata.array = self .array
3453+
3454+ # Save these in udf_udata to avoid computing them for each block
3455+ for i in range (3 ):
3456+ udata.chunks_strides[i][self .array.ndim - 1 ] = 1
3457+ udata.blocks_strides[i][self .array.ndim - 1 ] = 1
3458+ udata.el_strides[i][self .array.ndim - 1 ] = 1
3459+ for idx in range (2 , self .array.ndim + 1 ):
3460+ i = self .array.ndim - idx
3461+ udata.chunks_strides[0 ][i] = udata.chunks_strides[0 ][i + 1 ] * udata.array.extshape[i + 1 ] // udata.array.chunkshape[i + 1 ]
3462+ udata.blocks_strides[0 ][i] = udata.blocks_strides[0 ][i + 1 ] * udata.array.extchunkshape[i + 1 ] // udata.array.blockshape[i + 1 ]
3463+
3464+ for j in range (2 ):
3465+ inp = inputs_[j]
3466+ cstrides = bstrides = estrides = 1
3467+ for idx in range (2 , self .array.ndim + 1 ):
3468+ i = inp.ndim - idx
3469+ if inp.shape[i + 1 ] == 1 or i < 0 :
3470+ udata.chunks_strides[j][i] = 0
3471+ udata.blocks_strides[j][i] = 0
3472+ udata.el_strides[j][i] = 0
3473+ else :
3474+ bstrides *= inp.extchunkshape[i + 1 ] // inp.blockshape[i + 1 ]
3475+ cstrides *= inp.extshape[i + 1 ] // inp.chunkshape[i + 1 ]
3476+ estrides *= inp.blockshape[i + 1 ]
3477+ udata.chunks_strides[j][i] = cstrides
3478+ udata.blocks_strides[j][i] = bstrides
3479+ udata.el_strides[j][i] = estrides
3480+
3481+ return udata
3482+
34103483 def _set_pref_expr (self , expression , inputs , fp_accuracy , aux_reduc = None , jit = None ):
34113484 # Set prefilter for miniexpr
34123485 cdef blosc2_cparams* cparams = self .array.sc.storage.cparams
@@ -3497,7 +3570,7 @@ cdef class NDArray:
34973570 cdef blosc2_cparams* cparams = self .array.sc.storage.cparams
34983571 cparams.prefilter = < blosc2_prefilter_fn> matmul_prefilter
34993572
3500- cdef me_udata * udata = self ._fill_me_udata (inputs, fp_accuracy, aux_reduc = None )
3573+ cdef mm_udata * udata = self ._fill_mm_udata (inputs)
35013574 cdef b2nd_array_t* out_arr = udata.array
35023575 cdef blosc2_prefilter_params* preparams = < blosc2_prefilter_params * > calloc(1 , sizeof(blosc2_prefilter_params))
35033576 preparams.user_data = udata
0 commit comments