@@ -2375,8 +2375,10 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
23752375 cdef int ndim = out_arr.ndim
23762376 cdef int nchunk_ = nchunk
23772377 cdef int coord, batch, batch_, batches = 1
2378+
2379+ # batches = sum(strides[i]*elcoords[i])
23782380 for i in range (ndim - 2 ):
2379- batches *= out_arr.shape [i]
2381+ batches *= out_arr.blockshape [i]
23802382
23812383 # nchunk = sum(strides[i]*chunkcoords[i])
23822384 for i in range (ndim - 2 ):
@@ -2390,8 +2392,8 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
23902392
23912393 i = nchunk_ // ncols # ncols * i + j
23922394 j = nchunk_ % ncols
2393- nchunkA = chunk_startA = nchunkA + i * ncols
2394- nchunkB = chunk_startB = nchunkB + j
2395+ chunk_startA = nchunkA + i * ncols
2396+ chunk_startB = nchunkB + j
23952397
23962398 # nblock = sum(strides[i]*blockcoords[i])
23972399 cdef int nblock_ = nblock
@@ -2406,14 +2408,14 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
24062408
24072409 block_i = nblock_ // block_ncols
24082410 block_j = nblock_ % block_ncols
2409- block_startA = nblockA = nblockA + i * block_ncols
2410- block_startB = nblockB = nblockB + j
2411+ block_startA = nblockA + block_i * block_ncols
2412+ block_startB = nblockB + block_j
24112413
2412- # batches = sum(strides[i]*elcoords[i])
24132414 dctx = blosc2_create_dctx(BLOSC2_DPARAMS_DEFAULTS)
24142415
24152416 first_run = True
2416-
2417+ nchunkA = chunk_startA
2418+ nchunkB = chunk_startB
24172419 while True : # chunk loop
24182420 for i in range (2 ):
24192421 chunk_idx = nchunkA if i == 0 else nchunkB
@@ -2455,6 +2457,9 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
24552457 if rc < 0 :
24562458 raise ValueError (" matmul: error decompressing the B chunk" )
24572459 batch = 0
2460+ offsetA = 0
2461+ offsetB = 0
2462+ offset = 0
24582463 while batch < batches:
24592464 batch_ = batch
24602465 for i in range (ndim - 2 ):
@@ -3460,9 +3465,10 @@ cdef class NDArray:
34603465 i = self .array.ndim - idx
34613466 udata.chunks_strides[0 ][i] = udata.chunks_strides[0 ][i + 1 ] * udata.array.extshape[i + 1 ] // udata.array.chunkshape[i + 1 ]
34623467 udata.blocks_strides[0 ][i] = udata.blocks_strides[0 ][i + 1 ] * udata.array.extchunkshape[i + 1 ] // udata.array.blockshape[i + 1 ]
3468+ udata.el_strides[0 ][i] = udata.el_strides[0 ][i + 1 ] * udata.array.blockshape[i + 1 ]
34633469
3464- for j in range (2 ):
3465- inp = inputs_[j]
3470+ for j in range (1 , 3 ):
3471+ inp = inputs_[j - 1 ]
34663472 cstrides = bstrides = estrides = 1
34673473 for idx in range (2 , self .array.ndim + 1 ):
34683474 i = inp.ndim - idx
0 commit comments