@@ -2188,8 +2188,10 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
21882188 cdef int ndim = out_arr.ndim
21892189 cdef int nchunk_ = nchunk
21902190 cdef int coord, batch, batch_, batches = 1
2191+
2192+ # batches = sum(strides[i]*elcoords[i])
21912193 for i in range (ndim - 2 ):
2192- batches *= out_arr.shape [i]
2194+ batches *= out_arr.blockshape [i]
21932195
21942196 # nchunk = sum(strides[i]*chunkcoords[i])
21952197 for i in range (ndim - 2 ):
@@ -2203,8 +2205,8 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
22032205
22042206 i = nchunk_ // ncols # ncols * i + j
22052207 j = nchunk_ % ncols
2206- nchunkA = chunk_startA = nchunkA + i * ncols
2207- nchunkB = chunk_startB = nchunkB + j
2208+ chunk_startA = nchunkA + i * ncols
2209+ chunk_startB = nchunkB + j
22082210
22092211 # nblock = sum(strides[i]*blockcoords[i])
22102212 cdef int nblock_ = nblock
@@ -2219,14 +2221,14 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
22192221
22202222 block_i = nblock_ // block_ncols
22212223 block_j = nblock_ % block_ncols
2222- block_startA = nblockA = nblockA + i * block_ncols
2223- block_startB = nblockB = nblockB + j
2224+ block_startA = nblockA + block_i * block_ncols
2225+ block_startB = nblockB + block_j
22242226
2225- # batches = sum(strides[i]*elcoords[i])
22262227 dctx = blosc2_create_dctx(BLOSC2_DPARAMS_DEFAULTS)
22272228
22282229 first_run = True
2229-
2230+ nchunkA = chunk_startA
2231+ nchunkB = chunk_startB
22302232 while True : # chunk loop
22312233 for i in range (2 ):
22322234 chunk_idx = nchunkA if i == 0 else nchunkB
@@ -2268,6 +2270,9 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
22682270 if rc < 0 :
22692271 raise ValueError (" matmul: error decompressing the B chunk" )
22702272 batch = 0
2273+ offsetA = 0
2274+ offsetB = 0
2275+ offset = 0
22712276 while batch < batches:
22722277 batch_ = batch
22732278 for i in range (ndim - 2 ):
@@ -3268,9 +3273,10 @@ cdef class NDArray:
32683273 i = self .array.ndim - idx
32693274 udata.chunks_strides[0 ][i] = udata.chunks_strides[0 ][i + 1 ] * udata.array.extshape[i + 1 ] // udata.array.chunkshape[i + 1 ]
32703275 udata.blocks_strides[0 ][i] = udata.blocks_strides[0 ][i + 1 ] * udata.array.extchunkshape[i + 1 ] // udata.array.blockshape[i + 1 ]
3276+ udata.el_strides[0 ][i] = udata.el_strides[0 ][i + 1 ] * udata.array.blockshape[i + 1 ]
32713277
3272- for j in range (2 ):
3273- inp = inputs_[j]
3278+ for j in range (1 , 3 ):
3279+ inp = inputs_[j - 1 ]
32743280 cstrides = bstrides = estrides = 1
32753281 for idx in range (2 , self .array.ndim + 1 ):
32763282 i = inp.ndim - idx
0 commit comments