@@ -2369,12 +2369,13 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
23692369 cdef int blocknitems[2 ]
23702370 cdef int startA, startB, expected_blocknitems
23712371 cdef blosc2_context* dctx
2372- cdef int i, j, block_i, block_j, ncols, block_ncols, Bblock_ncols, Bncols
2372+ cdef int i, j, block_i, block_j, chunk_i, chunk_j, ncols, block_ncols, Bblock_ncols, Bncols, Ablock_ncols, Ancols
23732373 cdef int nchunkA = 0 , nchunkB = 0 , nblockA = 0 , nblockB = 0 , offsetA = 0 , offsetB = 0 , offset = 0
23742374 out_arr = udata.array
23752375 cdef int ndim = out_arr.ndim
23762376 cdef int nchunk_ = nchunk
23772377 cdef int coord, batch, batch_, batches = 1
2378+ cdef int out_chunk_nrows, out_chunk_ncols, out_block_nrows, out_block_ncols
23782379
23792380 # batches = sum(strides[i]*elcoords[i])
23802381 for i in range (ndim - 2 ):
@@ -2388,12 +2389,10 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
23882389 nchunkB += coord * udata.chunks_strides[2 ][i]
23892390
23902391 ncols = udata.chunks_strides[0 ][ndim - 2 ]
2392+ Ancols = udata.chunks_strides[1 ][ndim - 2 ]
23912393 Bncols = udata.chunks_strides[2 ][ndim - 2 ]
2392-
2393- i = nchunk_ // ncols # ncols * i + j
2394- j = nchunk_ % ncols
2395- chunk_startA = nchunkA + i * ncols
2396- chunk_startB = nchunkB + j
2394+ out_chunk_nrows = out_arr.chunkshape[ndim - 2 ]
2395+ out_chunk_ncols = out_arr.chunkshape[ndim - 1 ]
23972396
23982397 # nblock = sum(strides[i]*blockcoords[i])
23992398 cdef int nblock_ = nblock
@@ -2404,18 +2403,14 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
24042403 nblockB += coord * udata.blocks_strides[2 ][i]
24052404
24062405 block_ncols = udata.blocks_strides[0 ][ndim - 2 ]
2406+ Ablock_ncols = udata.blocks_strides[1 ][ndim - 2 ]
24072407 Bblock_ncols = udata.blocks_strides[2 ][ndim - 2 ]
2408-
2409- block_i = nblock_ // block_ncols
2410- block_j = nblock_ % block_ncols
2411- block_startA = nblockA + block_i * block_ncols
2412- block_startB = nblockB + block_j
2408+ out_block_nrows = out_arr.blockshape[ndim - 2 ]
2409+ out_block_ncols = out_arr.blockshape[ndim - 1 ]
24132410
24142411 dctx = blosc2_create_dctx(BLOSC2_DPARAMS_DEFAULTS)
24152412
24162413 first_run = True
2417- nchunkA = chunk_startA
2418- nchunkB = chunk_startB
24192414 while True : # chunk loop
24202415 for i in range (2 ):
24212416 chunk_idx = nchunkA if i == 0 else nchunkB
@@ -2431,16 +2426,28 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
24312426 if i == 0 :
24322427 q = ndarr.blockshape[ndim - 1 ]
24332428 p = ndarr.blockshape[ndim - 2 ]
2429+ # nchunk_ = chunks_in_row * chunk_row + chunk_col
2430+ # convert from chunk_idx to element idx chunk_i (row)
2431+ chunk_i = nchunk_ // ncols * out_chunk_nrows
2432+ chunk_startA = nchunkA + chunk_i // ndarr.chunkshape[ndim - 2 ] * Ancols
2433+ nchunkA = chunk_startA
2434+ # nblock_ = blocks_in_chunkrow * block_row + block_col
2435+ # convert from block_idx to element idx block_i (row)
2436+ block_i = nblock_ // block_ncols * out_block_nrows
2437+ block_startA = nblockA + block_i // p * Ablock_ncols
24342438 else : # i = 1
24352439 r = ndarr.blockshape[ndim - 1 ]
2440+ # convert from chunk_idx to element idx chunk_j (col)
2441+ chunk_j = nchunk_ % ncols * out_chunk_ncols
2442+ chunk_startB = nchunkB + chunk_j // ndarr.chunkshape[ndim - 1 ]
2443+ nchunkB = chunk_startB
2444+ # convert from block_idx to element idx block_j (col)
2445+ block_j = nblock_ % block_ncols * out_block_ncols
2446+ block_startB = nblockB + block_j // r
24362447 input_buffers[i] = malloc(block_nbytes[i])
24372448 if input_buffers[i] == NULL :
24382449 raise MemoryError (" miniexpr: cannot allocate input block buffer" )
24392450 blocknitems[i] = block_nbytes[i] // < int > ndarr.sc.typesize
2440- if i == 0 :
2441- expected_blocknitems = blocknitems[i]
2442- elif blocknitems[i] != expected_blocknitems:
2443- raise ValueError (" miniexpr: inconsistent block element counts across inputs" )
24442451
24452452 first_run = False
24462453 nblockA = block_startA
@@ -2484,11 +2491,11 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
24842491 batch += 1
24852492 nblockA += 1
24862493 nblockB += Bblock_ncols
2487- if (nblockA % block_ncols == 0 ):
2494+ if (nblockA % Ablock_ncols == 0 ):
24882495 break
24892496 nchunkA += 1
24902497 nchunkB += Bncols
2491- if (nchunkA % ncols == 0 ):
2498+ if (nchunkA % Ancols == 0 ):
24922499 break
24932500
24942501
@@ -3472,7 +3479,7 @@ cdef class NDArray:
34723479 cstrides = bstrides = estrides = 1
34733480 for idx in range (2 , self .array.ndim + 1 ):
34743481 i = inp.ndim - idx
3475- if inp.shape[i + 1 ] == 1 or i < 0 :
3482+ if ( inp.shape[i + 1 ] == 1 and i < inp.ndim - 3 ) or i < 0 :
34763483 udata.chunks_strides[j][i] = 0
34773484 udata.blocks_strides[j][i] = 0
34783485 udata.el_strides[j][i] = 0
0 commit comments