Skip to content

Commit 4dae4da

Browse files
lshaw8317FrancescAlted
authored andcommitted
Adding ND support
1 parent 6945bab commit 4dae4da

2 files changed

Lines changed: 150 additions & 58 deletions

File tree

src/blosc2/blosc2_ext.pyx

Lines changed: 122 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,13 @@ ctypedef struct me_udata:
688688
int64_t blocks_in_chunk[B2ND_MAX_DIM]
689689
me_expr* miniexpr_handle
690690

691+
ctypedef struct mm_udata:
692+
b2nd_array_t** inputs
693+
b2nd_array_t* array
694+
int64_t chunks_strides[3][B2ND_MAX_DIM]
695+
int64_t blocks_strides[3][B2ND_MAX_DIM]
696+
int64_t el_strides[3][B2ND_MAX_DIM]
697+
691698
MAX_TYPESIZE = BLOSC2_MAXTYPESIZE
692699
MAX_BUFFERSIZE = BLOSC2_MAX_BUFFERSIZE
693700
MAX_BLOCKSIZE = BLOSC2_MAXBLOCKSIZE
@@ -2348,12 +2355,12 @@ cdef int matmul_block_kernel(T* A, T* B, T* C, int M, int K, int N) nogil:
23482355
C[rowC + c] += <T>(a * B[rowB + c])
23492356
return 0
23502357

2351-
cdef int aux_matmul(me_udata *udata, int64_t nchunk, int32_t nblock, void *params_output, int32_t typesize, int typecode) nogil:
2358+
cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *params_output, int32_t typesize, int typecode) nogil:
23522359
# Declare all C variables at the beginning
23532360
cdef b2nd_array_t* out_arr
23542361
cdef b2nd_array_t* ndarr
23552362
cdef c_bool first_run
2356-
cdef int rc, M, K, N
2363+
cdef int rc, p, q, r
23572364
cdef void** input_buffers = <void**> malloc(2 * sizeof(uint8_t*))
23582365
cdef uint8_t** src = <uint8_t**> malloc(2 * sizeof(uint8_t*))
23592366
cdef int32_t chunk_nbytes[2]
@@ -2362,42 +2369,56 @@ cdef int aux_matmul(me_udata *udata, int64_t nchunk, int32_t nblock, void *param
23622369
cdef int blocknitems[2]
23632370
cdef int startA, startB, expected_blocknitems
23642371
cdef blosc2_context* dctx
2365-
cdef int base, i, j, nchunkA, nchunkB, nblockA, nblockB, chunk_startA, chunk_startB, block_base, block_i, block_j, block_startA, block_startB, idx, chunk_idx, block_ncols, block_nrows, nblocks_per_2d
2366-
2372+
cdef int i, j, block_i, block_j, ncols, block_ncols, Bblock_ncols, Bncols
2373+
cdef int nchunkA = 0, nchunkB = 0, nblockA = 0, nblockB = 0, offsetA = 0, offsetB = 0, offset = 0
23672374
out_arr = udata.array
23682375
cdef int ndim = out_arr.ndim
2369-
cdef int ncols = <int> udata.chunks_in_array[ndim - 1]
2370-
cdef int nrows = <int> udata.chunks_in_array[ndim - 2]
2371-
cdef int nchunks_per_2d = ncols * nrows
2372-
2373-
block_ncols = <int> udata.blocks_in_chunk[ndim - 1]
2374-
block_nrows = <int> udata.blocks_in_chunk[ndim - 2]
2375-
nblocks_per_2d = block_ncols * block_nrows
2376-
2377-
# nchunk = base * nchunks_per2d + i * ncols + j
2378-
base = nchunk // nchunks_per_2d
2379-
i = (nchunk % nchunks_per_2d) // ncols
2380-
j = nchunk % ncols
2381-
nchunkA = chunk_startA = nchunk - j
2382-
nchunkB = chunk_startB = nchunk - i * ncols
2383-
2384-
# nblock = block_base * nblocks_per_2d + block_i * block_ncols + block_j
2385-
block_base = nblock // nblocks_per_2d
2386-
block_i = (nblock % nblocks_per_2d) // block_ncols
2387-
block_j = nblock % block_ncols
2388-
block_startA = nblock - block_j
2389-
block_startB = nblock - block_i * block_ncols
2376+
cdef int nchunk_ = nchunk
2377+
cdef int coord, batch, batch_, batches = 1
2378+
for i in range(ndim - 2):
2379+
batches *= out_arr.shape[i]
2380+
2381+
# nchunk = sum(strides[i]*chunkcoords[i])
2382+
for i in range(ndim - 2):
2383+
coord = nchunk_ // udata.chunks_strides[0][i]
2384+
nchunk_ = nchunk_ % udata.chunks_strides[0][i]
2385+
nchunkA += coord * udata.chunks_strides[1][i]
2386+
nchunkB += coord * udata.chunks_strides[2][i]
2387+
2388+
ncols = udata.chunks_strides[0][ndim - 2]
2389+
Bncols = udata.chunks_strides[2][ndim - 2]
2390+
2391+
i = nchunk_ // ncols # ncols * i + j
2392+
j = nchunk_ % ncols
2393+
nchunkA = chunk_startA = nchunkA + i * ncols
2394+
nchunkB = chunk_startB = nchunkB + j
2395+
2396+
# nblock = sum(strides[i]*blockcoords[i])
2397+
cdef int nblock_ = nblock
2398+
for i in range(ndim - 2):
2399+
coord = nblock_ // udata.blocks_strides[0][i]
2400+
nblock_ = nblock_ % udata.blocks_strides[0][i]
2401+
nblockA += coord * udata.blocks_strides[1][i]
2402+
nblockB += coord * udata.blocks_strides[2][i]
2403+
2404+
block_ncols = udata.blocks_strides[0][ndim - 2]
2405+
Bblock_ncols = udata.blocks_strides[2][ndim - 2]
2406+
2407+
block_i = nblock_ // block_ncols
2408+
block_j = nblock_ % block_ncols
2409+
block_startA = nblockA = nblockA + i * block_ncols
2410+
block_startB = nblockB = nblockB + j
2411+
2412+
# batches = sum(strides[i]*elcoords[i])
23902413
dctx = blosc2_create_dctx(BLOSC2_DPARAMS_DEFAULTS)
23912414

23922415
first_run = True
23932416

23942417
while True: # chunk loop
2395-
printf("chunks: %i, %i\n", nchunkA, nchunkB)
2396-
nblockA = block_startA
2397-
nblockB = block_startB
23982418
for i in range(2):
23992419
chunk_idx = nchunkA if i == 0 else nchunkB
24002420
ndarr = udata.inputs[i]
2421+
ndim = ndarr.ndim
24012422
src[i] = ndarr.sc.data[chunk_idx]
24022423
rc = blosc2_cbuffer_sizes(src[i], &chunk_nbytes[i], &chunk_cbytes[i], &block_nbytes[i])
24032424
if rc < 0:
@@ -2406,10 +2427,10 @@ cdef int aux_matmul(me_udata *udata, int64_t nchunk, int32_t nblock, void *param
24062427
raise ValueError("miniexpr: invalid block size")
24072428
if first_run:
24082429
if i == 0:
2409-
K = ndarr.blockshape[ndim - 1]
2410-
M = ndarr.blockshape[ndim - 2]
2430+
q = ndarr.blockshape[ndim - 1]
2431+
p = ndarr.blockshape[ndim - 2]
24112432
else: # i = 1
2412-
N = ndarr.blockshape[ndim - 1]
2433+
r = ndarr.blockshape[ndim - 1]
24132434
input_buffers[i] = malloc(block_nbytes[i])
24142435
if input_buffers[i] == NULL:
24152436
raise MemoryError("miniexpr: cannot allocate input block buffer")
@@ -2420,8 +2441,9 @@ cdef int aux_matmul(me_udata *udata, int64_t nchunk, int32_t nblock, void *param
24202441
raise ValueError("miniexpr: inconsistent block element counts across inputs")
24212442

24222443
first_run = False
2444+
nblockA = block_startA
2445+
nblockB = block_startB
24232446
while True: # block loop
2424-
printf("blocks: %i, %i\n", nblockA, nblockB)
24252447
startA = nblockA * blocknitems[0]
24262448
startB = nblockB * blocknitems[1]
24272449
rc = blosc2_getitem_ctx(dctx, src[0], chunk_cbytes[0], startA, blocknitems[0],
@@ -2432,28 +2454,37 @@ cdef int aux_matmul(me_udata *udata, int64_t nchunk, int32_t nblock, void *param
24322454
input_buffers[1], block_nbytes[1])
24332455
if rc < 0:
24342456
raise ValueError("matmul: error decompressing the B chunk")
2435-
if typecode == 0:
2436-
if typesize == 4:
2437-
rc = matmul_block_kernel[float](<float*>input_buffers[0], <float*>input_buffers[1], <float*>params_output, M, K, N)
2438-
else:
2439-
rc = matmul_block_kernel[double](<double*>input_buffers[0], <double*>input_buffers[1], <double*>params_output, M, K, N)
2440-
elif typecode == 1:
2441-
if typesize == 4:
2442-
rc = matmul_block_kernel[int32_t](<int32_t*>input_buffers[0], <int32_t*>input_buffers[1], <int32_t*>params_output, M, K, N)
2457+
batch = 0
2458+
while batch < batches:
2459+
batch_ = batch
2460+
for i in range(ndim - 2):
2461+
coord = batch // udata.el_strides[0][i]
2462+
batch_ = batch_ % udata.el_strides[0][i]
2463+
offsetA += coord * udata.el_strides[1][i]
2464+
offsetB += coord * udata.el_strides[2][i]
2465+
offset += coord * udata.el_strides[0][i]
2466+
if typecode == 0:
2467+
if typesize == 4:
2468+
rc = matmul_block_kernel[float](<float*>input_buffers[0] + offsetA, <float*>input_buffers[1] + offsetB, <float*>params_output + offset, p, q, r)
2469+
else:
2470+
rc = matmul_block_kernel[double](<double*>input_buffers[0] + offsetA, <double*>input_buffers[1] + offsetB, <double*>params_output + offset, p, q, r)
2471+
elif typecode == 1:
2472+
if typesize == 4:
2473+
rc = matmul_block_kernel[int32_t](<int32_t*>input_buffers[0] + offsetA, <int32_t*>input_buffers[1] + offsetB, <int32_t*>params_output + offset, p, q, r)
2474+
else:
2475+
rc = matmul_block_kernel[int64_t](<int64_t*>input_buffers[0] + offsetA, <int64_t*>input_buffers[1] + offsetB, <int64_t*>params_output + offset, p, q, r)
24432476
else:
2444-
rc = matmul_block_kernel[int64_t](<int64_t*>input_buffers[0], <int64_t*>input_buffers[1], <int64_t*>params_output, M, K, N)
2445-
else:
2446-
with gil:
2447-
raise ValueError("Unsupported dtype")
2477+
with gil:
2478+
raise ValueError("Unsupported dtype")
2479+
batch += 1
24482480
nblockA += 1
2449-
nblockB += block_ncols
2481+
nblockB += Bblock_ncols
24502482
if (nblockA % block_ncols == 0):
24512483
break
24522484
nchunkA += 1
2453-
nchunkB += ncols
2485+
nchunkB += Bncols
24542486
if (nchunkA % ncols == 0):
24552487
break
2456-
printf("finished block %i for chunk %i\n", nblock, nchunk)
24572488

24582489

24592490
blosc2_free_ctx(dctx)
@@ -2545,7 +2576,7 @@ cdef int miniexpr_prefilter(blosc2_prefilter_params *params):
25452576
cdef int matmul_prefilter(blosc2_prefilter_params *params):
25462577
cdef int typecode
25472578

2548-
cdef me_udata* udata = <me_udata *> params.user_data
2579+
cdef mm_udata* udata = <mm_udata *> params.user_data
25492580
cdef b2nd_array_t* out_arr = udata.array
25502581
cdef char dtype_kind = out_arr.dtype[1]
25512582
if dtype_kind == 'f':
@@ -3407,6 +3438,48 @@ cdef class NDArray:
34073438

34083439
return udata
34093440

3441+
cdef mm_udata *_fill_mm_udata(self, inputs):
3442+
cdef mm_udata *udata = <mm_udata *> malloc(sizeof(mm_udata))
3443+
cdef int cstrides, bstrides, estrides
3444+
cdef b2nd_array_t* inp
3445+
cdef b2nd_array_t** inputs_ = <b2nd_array_t**> malloc(2 * sizeof(b2nd_array_t*))
3446+
for i in range(2):
3447+
operand = inputs['x1'] if i == 0 else inputs['x2']
3448+
inputs_[i] = <b2nd_array_t*><uintptr_t>operand.c_array
3449+
inputs_[i].chunk_cache.nchunk = -1
3450+
inputs_[i].chunk_cache.data = NULL
3451+
udata.inputs = inputs_
3452+
udata.array = self.array
3453+
3454+
# Save these in udf_udata to avoid computing them for each block
3455+
for i in range(3):
3456+
udata.chunks_strides[i][self.array.ndim - 1] = 1
3457+
udata.blocks_strides[i][self.array.ndim - 1] = 1
3458+
udata.el_strides[i][self.array.ndim - 1] = 1
3459+
for idx in range(2, self.array.ndim + 1):
3460+
i = self.array.ndim - idx
3461+
udata.chunks_strides[0][i] = udata.chunks_strides[0][i + 1] * udata.array.extshape[i + 1] // udata.array.chunkshape[i + 1]
3462+
udata.blocks_strides[0][i] = udata.blocks_strides[0][i + 1] * udata.array.extchunkshape[i + 1] // udata.array.blockshape[i + 1]
3463+
3464+
for j in range(2):
3465+
inp = inputs_[j]
3466+
cstrides = bstrides = estrides = 1
3467+
for idx in range(2, self.array.ndim + 1):
3468+
i = inp.ndim - idx
3469+
if inp.shape[i + 1] == 1 or i < 0:
3470+
udata.chunks_strides[j][i] = 0
3471+
udata.blocks_strides[j][i] = 0
3472+
udata.el_strides[j][i] = 0
3473+
else:
3474+
bstrides *= inp.extchunkshape[i + 1] // inp.blockshape[i + 1]
3475+
cstrides *= inp.extshape[i + 1] // inp.chunkshape[i + 1]
3476+
estrides *= inp.blockshape[i + 1]
3477+
udata.chunks_strides[j][i] = cstrides
3478+
udata.blocks_strides[j][i] = bstrides
3479+
udata.el_strides[j][i] = estrides
3480+
3481+
return udata
3482+
34103483
def _set_pref_expr(self, expression, inputs, fp_accuracy, aux_reduc=None, jit=None):
34113484
# Set prefilter for miniexpr
34123485
cdef blosc2_cparams* cparams = self.array.sc.storage.cparams
@@ -3497,7 +3570,7 @@ cdef class NDArray:
34973570
cdef blosc2_cparams* cparams = self.array.sc.storage.cparams
34983571
cparams.prefilter = <blosc2_prefilter_fn> matmul_prefilter
34993572

3500-
cdef me_udata* udata = self._fill_me_udata(inputs, fp_accuracy, aux_reduc=None)
3573+
cdef mm_udata* udata = self._fill_mm_udata(inputs)
35013574
cdef b2nd_array_t* out_arr = udata.array
35023575
cdef blosc2_prefilter_params* preparams = <blosc2_prefilter_params *> calloc(1, sizeof(blosc2_prefilter_params))
35033576
preparams.user_data = udata

src/blosc2/linalg.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -113,21 +113,40 @@ def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArra
113113
result = blosc2.zeros(result_shape, dtype=blosc2.result_type(x1, x2), **kwargs)
114114

115115
# multithreaded matmul
116-
# TODO: handle a) type promotion, b) non-square blocks, c) and >2D
116+
# TODO: handle a) type promotion, b) padding, c) (improved) >2D
117117
ops = (x1, x2, result)
118-
shape, chunks, blocks = result.shape, result.chunks, result.blocks
118+
blocks = result.blocks
119119
all_ndarray = all(isinstance(value, blosc2.NDArray) and value.shape != () for value in ops)
120120
use_miniexpr = True
121121
if all_ndarray:
122-
# can maybe relax this to just have A.blocks[-1] == B.blocks[-2]
123-
# Require aligned NDArray operands with identical chunk/block grid, and square matrices/chunks/blocks
124-
same_shape = all(op.shape[-1] == op.shape[-2] and op.shape == shape for op in ops)
125-
same_chunks = all(op.shape[-1] == op.shape[-2] and op.chunks == chunks for op in ops)
126-
same_blocks = all(op.shape[-1] == op.shape[-2] and op.blocks == blocks for op in ops)
127-
if not (same_shape and same_chunks and same_blocks):
128-
use_miniexpr = False
129122
if any(op.dtype != ops[0].dtype for op in ops): # TODO: Remove this condition
130123
use_miniexpr = False
124+
125+
# TODO: In fact the following can be relaxed too, just need to load across block boundaries
126+
# Might want to restrict loading across chunk boundaries, in which case would require:
127+
# x1.chunks[-2] % result.blocks[-2] == 0
128+
# x2.chunks[-1] % result.blocks[-1] == 0
129+
# x2.chunks[-2] % x1.blocks[-1] == 0
130+
# Can then load in x1 as slices of size [result.blocks[-2], x1.blocks[-1]]
131+
# and x2 in slices of [x1.blocks[-1], result.blocks[-1]]
132+
133+
# Require that blocks are matmul compatible and broadcastable directly to result
134+
# (M, K) x (K, N) = (M, N)
135+
# so can load block-by-block for inputs and calculate block of output
136+
# Also need to avoid loading across chunk boundaries
137+
chunks_aligned = x1.chunks[-2] % x1.blocks[-2] == 0
138+
chunks_aligned &= x2.chunks[-1] % x2.blocks[-1] == 0
139+
chunks_aligned &= x2.chunks[-2] % x1.blocks[-1] == 0
140+
same_blocks = x2.blocks[-2] == x1.blocks[-1]
141+
same_blocks &= x2.blocks[-1] == result.blocks[-1]
142+
same_blocks &= result.blocks[-2] == x1.blocks[-2]
143+
try:
144+
result_blocks = np.broadcast_shapes(x1.blocks, x2.blocks)
145+
except ValueError:
146+
use_miniexpr = False
147+
if not (same_blocks and chunks_aligned and result_blocks[:-2] == blocks[:-2]):
148+
use_miniexpr = False
149+
131150
else:
132151
use_miniexpr = False
133152

0 commit comments

Comments
 (0)