Skip to content

Commit 7e09737

Browse files
lshaw8317FrancescAlted
authored andcommitted
Extended allowable cases
1 parent 0b7b0b5 commit 7e09737

2 files changed

Lines changed: 44 additions & 38 deletions

File tree

src/blosc2/blosc2_ext.pyx

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2369,12 +2369,13 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
23692369
cdef int blocknitems[2]
23702370
cdef int startA, startB, expected_blocknitems
23712371
cdef blosc2_context* dctx
2372-
cdef int i, j, block_i, block_j, ncols, block_ncols, Bblock_ncols, Bncols
2372+
cdef int i, j, block_i, block_j, chunk_i, chunk_j, ncols, block_ncols, Bblock_ncols, Bncols, Ablock_ncols, Ancols
23732373
cdef int nchunkA = 0, nchunkB = 0, nblockA = 0, nblockB = 0, offsetA = 0, offsetB = 0, offset = 0
23742374
out_arr = udata.array
23752375
cdef int ndim = out_arr.ndim
23762376
cdef int nchunk_ = nchunk
23772377
cdef int coord, batch, batch_, batches = 1
2378+
cdef int out_chunk_nrows, out_chunk_ncols, out_block_nrows, out_block_ncols
23782379

23792380
# batches = sum(strides[i]*elcoords[i])
23802381
for i in range(ndim - 2):
@@ -2388,12 +2389,10 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
23882389
nchunkB += coord * udata.chunks_strides[2][i]
23892390

23902391
ncols = udata.chunks_strides[0][ndim - 2]
2392+
Ancols = udata.chunks_strides[1][ndim - 2]
23912393
Bncols = udata.chunks_strides[2][ndim - 2]
2392-
2393-
i = nchunk_ // ncols # ncols * i + j
2394-
j = nchunk_ % ncols
2395-
chunk_startA = nchunkA + i * ncols
2396-
chunk_startB = nchunkB + j
2394+
out_chunk_nrows = out_arr.chunkshape[ndim - 2]
2395+
out_chunk_ncols = out_arr.chunkshape[ndim - 1]
23972396

23982397
# nblock = sum(strides[i]*blockcoords[i])
23992398
cdef int nblock_ = nblock
@@ -2404,18 +2403,14 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
24042403
nblockB += coord * udata.blocks_strides[2][i]
24052404

24062405
block_ncols = udata.blocks_strides[0][ndim - 2]
2406+
Ablock_ncols = udata.blocks_strides[1][ndim - 2]
24072407
Bblock_ncols = udata.blocks_strides[2][ndim - 2]
2408-
2409-
block_i = nblock_ // block_ncols
2410-
block_j = nblock_ % block_ncols
2411-
block_startA = nblockA + block_i * block_ncols
2412-
block_startB = nblockB + block_j
2408+
out_block_nrows = out_arr.blockshape[ndim - 2]
2409+
out_block_ncols = out_arr.blockshape[ndim - 1]
24132410

24142411
dctx = blosc2_create_dctx(BLOSC2_DPARAMS_DEFAULTS)
24152412

24162413
first_run = True
2417-
nchunkA = chunk_startA
2418-
nchunkB = chunk_startB
24192414
while True: # chunk loop
24202415
for i in range(2):
24212416
chunk_idx = nchunkA if i == 0 else nchunkB
@@ -2431,16 +2426,28 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
24312426
if i == 0:
24322427
q = ndarr.blockshape[ndim - 1]
24332428
p = ndarr.blockshape[ndim - 2]
2429+
# nchunk_ = chunks_in_row * chunk_row + chunk_col
2430+
# convert from chunk_idx to element idx chunk_i (row)
2431+
chunk_i = nchunk_ // ncols * out_chunk_nrows
2432+
chunk_startA = nchunkA + chunk_i // ndarr.chunkshape[ndim - 2] * Ancols
2433+
nchunkA = chunk_startA
2434+
# nblock_ = blocks_in_chunkrow * block_row + block_col
2435+
# convert from block_idx to element idx block_i (row)
2436+
block_i = nblock_ // block_ncols * out_block_nrows
2437+
block_startA = nblockA + block_i // p * Ablock_ncols
24342438
else: # i = 1
24352439
r = ndarr.blockshape[ndim - 1]
2440+
# convert from chunk_idx to element idx chunk_j (col)
2441+
chunk_j = nchunk_ % ncols * out_chunk_ncols
2442+
chunk_startB = nchunkB + chunk_j // ndarr.chunkshape[ndim - 1]
2443+
nchunkB = chunk_startB
2444+
# convert from block_idx to element idx block_j (col)
2445+
block_j = nblock_ % block_ncols * out_block_ncols
2446+
block_startB = nblockB + block_j // r
24362447
input_buffers[i] = malloc(block_nbytes[i])
24372448
if input_buffers[i] == NULL:
24382449
raise MemoryError("miniexpr: cannot allocate input block buffer")
24392450
blocknitems[i] = block_nbytes[i] // <int> ndarr.sc.typesize
2440-
if i == 0:
2441-
expected_blocknitems = blocknitems[i]
2442-
elif blocknitems[i] != expected_blocknitems:
2443-
raise ValueError("miniexpr: inconsistent block element counts across inputs")
24442451

24452452
first_run = False
24462453
nblockA = block_startA
@@ -2484,11 +2491,11 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
24842491
batch += 1
24852492
nblockA += 1
24862493
nblockB += Bblock_ncols
2487-
if (nblockA % block_ncols == 0):
2494+
if (nblockA % Ablock_ncols == 0):
24882495
break
24892496
nchunkA += 1
24902497
nchunkB += Bncols
2491-
if (nchunkA % ncols == 0):
2498+
if (nchunkA % Ancols == 0):
24922499
break
24932500

24942501

@@ -3472,7 +3479,7 @@ cdef class NDArray:
34723479
cstrides = bstrides = estrides = 1
34733480
for idx in range(2, self.array.ndim + 1):
34743481
i = inp.ndim - idx
3475-
if inp.shape[i + 1] == 1 or i < 0:
3482+
if (inp.shape[i + 1] == 1 and i < inp.ndim - 3) or i < 0:
34763483
udata.chunks_strides[j][i] = 0
34773484
udata.blocks_strides[j][i] = 0
34783485
udata.el_strides[j][i] = 0

src/blosc2/linalg.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -125,14 +125,13 @@ def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArra
125125
if any(op.dtype != ops[0].dtype for op in ops): # TODO: Remove this condition
126126
use_miniexpr = False
127127

128+
# TODO: We can relax this to even just load according to result blockshape, but that's difficult.
128129
# Just force same chunk/block shapes
129-
same_chunks = all(op.chunks == result.chunks for op in (x1, x2))
130-
same_blocks = all(op.blocks == result.blocks for op in (x1, x2))
131-
same_shape = all(op.shape == result.shape for op in (x1, x2))
132-
133-
use_miniexpr &= same_blocks & same_chunks & same_shape
130+
# same_chunks = all(op.chunks == result.chunks for op in (x1, x2))
131+
# same_blocks = all(op.blocks == result.blocks for op in (x1, x2))
132+
# same_shape = all(op.shape == result.shape for op in (x1, x2))
134133

135-
# TODO: We can relax this to even just load according to result blockshape, but that's difficult.
134+
# use_miniexpr &= same_blocks & same_chunks & same_shape
136135
# Two easier cases are presented below
137136
# Case 1: Might want to restrict loading across chunk boundaries, in which case would require:
138137
# x1.chunks[-2] % result.blocks[-2] == 0
@@ -146,18 +145,18 @@ def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArra
146145
# (M, K) x (K, N) = (M, N)
147146
# so can load block-by-block for inputs and calculate block of output
148147
# Also need to avoid loading across chunk boundaries
149-
# chunks_aligned = x1.chunks[-2] % x1.blocks[-2] == 0
150-
# chunks_aligned &= x2.chunks[-1] % x2.blocks[-1] == 0
151-
# chunks_aligned &= x2.chunks[-2] % x1.blocks[-1] == 0
152-
# same_blocks = x2.blocks[-2] == x1.blocks[-1]
153-
# same_blocks &= x2.blocks[-1] == result.blocks[-1]
154-
# same_blocks &= result.blocks[-2] == x1.blocks[-2]
155-
# try:
156-
# result_blocks = np.broadcast_shapes(x1.blocks, x2.blocks)
157-
# if not (same_blocks and chunks_aligned and result_blocks[:-2] == result.blocks[:-2]):
158-
# use_miniexpr = False
159-
# except ValueError:
160-
# use_miniexpr = False
148+
chunks_aligned = x1.chunks[-2] % x1.blocks[-2] == 0
149+
chunks_aligned &= x2.chunks[-1] % x2.blocks[-1] == 0
150+
chunks_aligned &= x2.chunks[-2] % x1.blocks[-1] == 0
151+
same_blocks = x2.blocks[-2] == x1.blocks[-1]
152+
same_blocks &= x2.blocks[-1] == result.blocks[-1]
153+
same_blocks &= result.blocks[-2] == x1.blocks[-2]
154+
try:
155+
result_blocks = np.broadcast_shapes(x1.blocks, x2.blocks)
156+
if not (same_blocks and chunks_aligned and result_blocks[:-2] == result.blocks[:-2]):
157+
use_miniexpr = False
158+
except ValueError:
159+
use_miniexpr = False
161160

162161
use_miniexpr &= x1.dtype.kind in ("i", "f")
163162
use_miniexpr &= x2.dtype.kind in ("i", "f")

0 commit comments

Comments
 (0)