Skip to content

Commit 5de386c

Browse files
committed
Adding ND support
1 parent c41dca8 commit 5de386c

2 files changed

Lines changed: 150 additions & 58 deletions

File tree

src/blosc2/blosc2_ext.pyx

Lines changed: 122 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,13 @@ ctypedef struct me_udata:
674674
int64_t blocks_in_chunk[B2ND_MAX_DIM]
675675
me_expr* miniexpr_handle
676676

677+
# User data attached to the matmul prefilter: the operand arrays plus three
# precomputed stride tables, so they are not recomputed for every block.
# As read by aux_matmul, each table has three rows: row 0 is used to decompose
# the output index, rows 1 and 2 to accumulate the matching index for the
# first and second input respectively. Columns are dimensions.
ctypedef struct mm_udata:
678+
# Two-element array with the b2nd arrays of the operands (x1, then x2)
b2nd_array_t** inputs
679+
# Output array (the NDArray the prefilter is installed on)
b2nd_array_t* array
680+
# Strides, counted in chunks, used to map an output chunk index (nchunk)
# to chunk indices in the inputs
int64_t chunks_strides[3][B2ND_MAX_DIM]
681+
# Strides, counted in blocks, used to map an output block index (nblock)
# to block indices inside the input chunks
int64_t blocks_strides[3][B2ND_MAX_DIM]
682+
# Strides, counted in elements, used to compute per-batch offsets into the
# decompressed block buffers
int64_t el_strides[3][B2ND_MAX_DIM]
683+
677684
MAX_TYPESIZE = BLOSC2_MAXTYPESIZE
678685
MAX_BUFFERSIZE = BLOSC2_MAX_BUFFERSIZE
679686
MAX_BLOCKSIZE = BLOSC2_MAXBLOCKSIZE
@@ -2161,12 +2168,12 @@ cdef int matmul_block_kernel(T* A, T* B, T* C, int M, int K, int N) nogil:
21612168
C[rowC + c] += <T>(a * B[rowB + c])
21622169
return 0
21632170

2164-
cdef int aux_matmul(me_udata *udata, int64_t nchunk, int32_t nblock, void *params_output, int32_t typesize, int typecode) nogil:
2171+
cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *params_output, int32_t typesize, int typecode) nogil:
21652172
# Declare all C variables at the beginning
21662173
cdef b2nd_array_t* out_arr
21672174
cdef b2nd_array_t* ndarr
21682175
cdef c_bool first_run
2169-
cdef int rc, M, K, N
2176+
cdef int rc, p, q, r
21702177
cdef void** input_buffers = <void**> malloc(2 * sizeof(uint8_t*))
21712178
cdef uint8_t** src = <uint8_t**> malloc(2 * sizeof(uint8_t*))
21722179
cdef int32_t chunk_nbytes[2]
@@ -2175,42 +2182,56 @@ cdef int aux_matmul(me_udata *udata, int64_t nchunk, int32_t nblock, void *param
21752182
cdef int blocknitems[2]
21762183
cdef int startA, startB, expected_blocknitems
21772184
cdef blosc2_context* dctx
2178-
cdef int base, i, j, nchunkA, nchunkB, nblockA, nblockB, chunk_startA, chunk_startB, block_base, block_i, block_j, block_startA, block_startB, idx, chunk_idx, block_ncols, block_nrows, nblocks_per_2d
2179-
2185+
cdef int i, j, block_i, block_j, ncols, block_ncols, Bblock_ncols, Bncols
2186+
cdef int nchunkA = 0, nchunkB = 0, nblockA = 0, nblockB = 0, offsetA = 0, offsetB = 0, offset = 0
21802187
out_arr = udata.array
21812188
cdef int ndim = out_arr.ndim
2182-
cdef int ncols = <int> udata.chunks_in_array[ndim - 1]
2183-
cdef int nrows = <int> udata.chunks_in_array[ndim - 2]
2184-
cdef int nchunks_per_2d = ncols * nrows
2185-
2186-
block_ncols = <int> udata.blocks_in_chunk[ndim - 1]
2187-
block_nrows = <int> udata.blocks_in_chunk[ndim - 2]
2188-
nblocks_per_2d = block_ncols * block_nrows
2189-
2190-
# nchunk = base * nchunks_per2d + i * ncols + j
2191-
base = nchunk // nchunks_per_2d
2192-
i = (nchunk % nchunks_per_2d) // ncols
2193-
j = nchunk % ncols
2194-
nchunkA = chunk_startA = nchunk - j
2195-
nchunkB = chunk_startB = nchunk - i * ncols
2196-
2197-
# nblock = block_base * nblocks_per_2d + block_i * block_ncols + block_j
2198-
block_base = nblock // nblocks_per_2d
2199-
block_i = (nblock % nblocks_per_2d) // block_ncols
2200-
block_j = nblock % block_ncols
2201-
block_startA = nblock - block_j
2202-
block_startB = nblock - block_i * block_ncols
2189+
cdef int nchunk_ = nchunk
2190+
cdef int coord, batch, batch_, batches = 1
2191+
for i in range(ndim - 2):
2192+
batches *= out_arr.shape[i]
2193+
2194+
# nchunk = sum(strides[i]*chunkcoords[i])
2195+
for i in range(ndim - 2):
2196+
coord = nchunk_ // udata.chunks_strides[0][i]
2197+
nchunk_ = nchunk_ % udata.chunks_strides[0][i]
2198+
nchunkA += coord * udata.chunks_strides[1][i]
2199+
nchunkB += coord * udata.chunks_strides[2][i]
2200+
2201+
ncols = udata.chunks_strides[0][ndim - 2]
2202+
Bncols = udata.chunks_strides[2][ndim - 2]
2203+
2204+
i = nchunk_ // ncols # ncols * i + j
2205+
j = nchunk_ % ncols
2206+
nchunkA = chunk_startA = nchunkA + i * ncols
2207+
nchunkB = chunk_startB = nchunkB + j
2208+
2209+
# nblock = sum(strides[i]*blockcoords[i])
2210+
cdef int nblock_ = nblock
2211+
for i in range(ndim - 2):
2212+
coord = nblock_ // udata.blocks_strides[0][i]
2213+
nblock_ = nblock_ % udata.blocks_strides[0][i]
2214+
nblockA += coord * udata.blocks_strides[1][i]
2215+
nblockB += coord * udata.blocks_strides[2][i]
2216+
2217+
block_ncols = udata.blocks_strides[0][ndim - 2]
2218+
Bblock_ncols = udata.blocks_strides[2][ndim - 2]
2219+
2220+
block_i = nblock_ // block_ncols
2221+
block_j = nblock_ % block_ncols
2222+
block_startA = nblockA = nblockA + i * block_ncols
2223+
block_startB = nblockB = nblockB + j
2224+
2225+
# batches = sum(strides[i]*elcoords[i])
22032226
dctx = blosc2_create_dctx(BLOSC2_DPARAMS_DEFAULTS)
22042227

22052228
first_run = True
22062229

22072230
while True: # chunk loop
2208-
printf("chunks: %i, %i\n", nchunkA, nchunkB)
2209-
nblockA = block_startA
2210-
nblockB = block_startB
22112231
for i in range(2):
22122232
chunk_idx = nchunkA if i == 0 else nchunkB
22132233
ndarr = udata.inputs[i]
2234+
ndim = ndarr.ndim
22142235
src[i] = ndarr.sc.data[chunk_idx]
22152236
rc = blosc2_cbuffer_sizes(src[i], &chunk_nbytes[i], &chunk_cbytes[i], &block_nbytes[i])
22162237
if rc < 0:
@@ -2219,10 +2240,10 @@ cdef int aux_matmul(me_udata *udata, int64_t nchunk, int32_t nblock, void *param
22192240
raise ValueError("miniexpr: invalid block size")
22202241
if first_run:
22212242
if i == 0:
2222-
K = ndarr.blockshape[ndim - 1]
2223-
M = ndarr.blockshape[ndim - 2]
2243+
q = ndarr.blockshape[ndim - 1]
2244+
p = ndarr.blockshape[ndim - 2]
22242245
else: # i = 1
2225-
N = ndarr.blockshape[ndim - 1]
2246+
r = ndarr.blockshape[ndim - 1]
22262247
input_buffers[i] = malloc(block_nbytes[i])
22272248
if input_buffers[i] == NULL:
22282249
raise MemoryError("miniexpr: cannot allocate input block buffer")
@@ -2233,8 +2254,9 @@ cdef int aux_matmul(me_udata *udata, int64_t nchunk, int32_t nblock, void *param
22332254
raise ValueError("miniexpr: inconsistent block element counts across inputs")
22342255

22352256
first_run = False
2257+
nblockA = block_startA
2258+
nblockB = block_startB
22362259
while True: # block loop
2237-
printf("blocks: %i, %i\n", nblockA, nblockB)
22382260
startA = nblockA * blocknitems[0]
22392261
startB = nblockB * blocknitems[1]
22402262
rc = blosc2_getitem_ctx(dctx, src[0], chunk_cbytes[0], startA, blocknitems[0],
@@ -2245,28 +2267,37 @@ cdef int aux_matmul(me_udata *udata, int64_t nchunk, int32_t nblock, void *param
22452267
input_buffers[1], block_nbytes[1])
22462268
if rc < 0:
22472269
raise ValueError("matmul: error decompressing the B chunk")
2248-
if typecode == 0:
2249-
if typesize == 4:
2250-
rc = matmul_block_kernel[float](<float*>input_buffers[0], <float*>input_buffers[1], <float*>params_output, M, K, N)
2251-
else:
2252-
rc = matmul_block_kernel[double](<double*>input_buffers[0], <double*>input_buffers[1], <double*>params_output, M, K, N)
2253-
elif typecode == 1:
2254-
if typesize == 4:
2255-
rc = matmul_block_kernel[int32_t](<int32_t*>input_buffers[0], <int32_t*>input_buffers[1], <int32_t*>params_output, M, K, N)
2270+
batch = 0
2271+
while batch < batches:
2272+
batch_ = batch
2273+
for i in range(ndim - 2):
2274+
coord = batch // udata.el_strides[0][i]
2275+
batch_ = batch_ % udata.el_strides[0][i]
2276+
offsetA += coord * udata.el_strides[1][i]
2277+
offsetB += coord * udata.el_strides[2][i]
2278+
offset += coord * udata.el_strides[0][i]
2279+
if typecode == 0:
2280+
if typesize == 4:
2281+
rc = matmul_block_kernel[float](<float*>input_buffers[0] + offsetA, <float*>input_buffers[1] + offsetB, <float*>params_output + offset, p, q, r)
2282+
else:
2283+
rc = matmul_block_kernel[double](<double*>input_buffers[0] + offsetA, <double*>input_buffers[1] + offsetB, <double*>params_output + offset, p, q, r)
2284+
elif typecode == 1:
2285+
if typesize == 4:
2286+
rc = matmul_block_kernel[int32_t](<int32_t*>input_buffers[0] + offsetA, <int32_t*>input_buffers[1] + offsetB, <int32_t*>params_output + offset, p, q, r)
2287+
else:
2288+
rc = matmul_block_kernel[int64_t](<int64_t*>input_buffers[0] + offsetA, <int64_t*>input_buffers[1] + offsetB, <int64_t*>params_output + offset, p, q, r)
22562289
else:
2257-
rc = matmul_block_kernel[int64_t](<int64_t*>input_buffers[0], <int64_t*>input_buffers[1], <int64_t*>params_output, M, K, N)
2258-
else:
2259-
with gil:
2260-
raise ValueError("Unsupported dtype")
2290+
with gil:
2291+
raise ValueError("Unsupported dtype")
2292+
batch += 1
22612293
nblockA += 1
2262-
nblockB += block_ncols
2294+
nblockB += Bblock_ncols
22632295
if (nblockA % block_ncols == 0):
22642296
break
22652297
nchunkA += 1
2266-
nchunkB += ncols
2298+
nchunkB += Bncols
22672299
if (nchunkA % ncols == 0):
22682300
break
2269-
printf("finished block %i for chunk %i\n", nblock, nchunk)
22702301

22712302

22722303
blosc2_free_ctx(dctx)
@@ -2358,7 +2389,7 @@ cdef int miniexpr_prefilter(blosc2_prefilter_params *params):
23582389
cdef int matmul_prefilter(blosc2_prefilter_params *params):
23592390
cdef int typecode
23602391

2361-
cdef me_udata* udata = <me_udata *> params.user_data
2392+
cdef mm_udata* udata = <mm_udata *> params.user_data
23622393
cdef b2nd_array_t* out_arr = udata.array
23632394
cdef char dtype_kind = out_arr.dtype[1]
23642395
if dtype_kind == 'f':
@@ -3215,6 +3246,48 @@ cdef class NDArray:
32153246

32163247
return udata
32173248

3249+
cdef mm_udata *_fill_mm_udata(self, inputs):
    # Build the mm_udata struct consumed by the matmul prefilter.
    #
    # Parameters
    #   inputs: mapping with keys 'x1' and 'x2'; each value must expose a
    #           ``c_array`` attribute holding the address of its b2nd_array_t.
    #
    # Returns
    #   A heap-allocated mm_udata (with a heap-allocated ``inputs`` array).
    #   Nothing is freed here; releasing both is up to the caller.
    #
    # Raises
    #   KeyError if 'x1' or 'x2' is missing, ValueError if the output array
    #   has no dimensions, MemoryError if an allocation fails.
    #
    # Layout of the stride tables (this is how aux_matmul reads them):
    #   row 0 -> output array, row 1 -> x1, row 2 -> x2
    # and every row is indexed by the *output* dimension, with the inputs
    # right-aligned against the output as in broadcasting.
    cdef int ndim = self.array.ndim
    cdef int k, row, idx, out_i, in_i
    cdef int64_t cstride, bstride, estride
    cdef uintptr_t addr[2]
    cdef b2nd_array_t* out = self.array
    cdef b2nd_array_t* inp
    cdef b2nd_array_t** inputs_
    cdef mm_udata* udata

    if ndim < 1:
        raise ValueError("matmul: output array must have at least one dimension")

    # Do everything that may raise at the Python level before allocating,
    # so that a failure here cannot leak C memory.
    addr[0] = inputs['x1'].c_array
    addr[1] = inputs['x2'].c_array

    # calloc: any stride entry that is not explicitly filled below is 0
    # instead of garbage.
    udata = <mm_udata *> calloc(1, sizeof(mm_udata))
    inputs_ = <b2nd_array_t**> malloc(2 * sizeof(b2nd_array_t*))
    if udata == NULL or inputs_ == NULL:
        free(udata)
        free(inputs_)
        raise MemoryError("matmul: cannot allocate prefilter user data")

    for k in range(2):
        inputs_[k] = <b2nd_array_t*> addr[k]
        # Invalidate the chunk cache of each operand
        inputs_[k].chunk_cache.nchunk = -1
        inputs_[k].chunk_cache.data = NULL
    udata.inputs = inputs_
    udata.array = out

    # Save these in mm_udata to avoid computing them for each block.
    # The innermost dimension always has unit stride.
    for row in range(3):
        udata.chunks_strides[row][ndim - 1] = 1
        udata.blocks_strides[row][ndim - 1] = 1
        udata.el_strides[row][ndim - 1] = 1

    # Row 0: strides of the output array, from the innermost dimension outwards.
    for idx in range(2, ndim + 1):
        out_i = ndim - idx
        udata.chunks_strides[0][out_i] = udata.chunks_strides[0][out_i + 1] * out.extshape[out_i + 1] // out.chunkshape[out_i + 1]
        udata.blocks_strides[0][out_i] = udata.blocks_strides[0][out_i + 1] * out.extchunkshape[out_i + 1] // out.blockshape[out_i + 1]
        # aux_matmul divides by el_strides[0][i] for the batch dimensions,
        # so the output row must be filled as well.
        udata.el_strides[0][out_i] = udata.el_strides[0][out_i + 1] * out.blockshape[out_i + 1]

    # Rows 1 and 2: strides of x1 and x2.
    for k in range(2):
        inp = inputs_[k]
        row = k + 1
        cstride = bstride = estride = 1
        for idx in range(2, ndim + 1):
            out_i = ndim - idx       # column in the table (output dimension)
            in_i = inp.ndim - idx    # matching dimension of this input
            if in_i < 0:
                # The input lacks this leading dimension: it is broadcast,
                # so moving along it must not move inside the input.
                udata.chunks_strides[row][out_i] = 0
                udata.blocks_strides[row][out_i] = 0
                udata.el_strides[row][out_i] = 0
                continue
            # The stride of dimension in_i is the product of the extents of
            # all the dimensions to its right.
            cstride *= inp.extshape[in_i + 1] // inp.chunkshape[in_i + 1]
            bstride *= inp.extchunkshape[in_i + 1] // inp.blockshape[in_i + 1]
            estride *= inp.blockshape[in_i + 1]
            if idx > 2 and inp.shape[in_i] == 1:
                # Size-1 batch dimension: broadcast against the output.
                # (idx == 2 is the matrix row dimension, never broadcast.)
                udata.chunks_strides[row][out_i] = 0
                udata.blocks_strides[row][out_i] = 0
                udata.el_strides[row][out_i] = 0
            else:
                udata.chunks_strides[row][out_i] = cstride
                udata.blocks_strides[row][out_i] = bstride
                udata.el_strides[row][out_i] = estride

    return udata
3290+
32183291
def _set_pref_expr(self, expression, inputs, fp_accuracy, aux_reduc=None, jit=None):
32193292
# Set prefilter for miniexpr
32203293
cdef blosc2_cparams* cparams = self.array.sc.storage.cparams
@@ -3305,7 +3378,7 @@ cdef class NDArray:
33053378
cdef blosc2_cparams* cparams = self.array.sc.storage.cparams
33063379
cparams.prefilter = <blosc2_prefilter_fn> matmul_prefilter
33073380

3308-
cdef me_udata* udata = self._fill_me_udata(inputs, fp_accuracy, aux_reduc=None)
3381+
cdef mm_udata* udata = self._fill_mm_udata(inputs)
33093382
cdef b2nd_array_t* out_arr = udata.array
33103383
cdef blosc2_prefilter_params* preparams = <blosc2_prefilter_params *> calloc(1, sizeof(blosc2_prefilter_params))
33113384
preparams.user_data = udata

src/blosc2/linalg.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -113,21 +113,40 @@ def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArra
113113
result = blosc2.zeros(result_shape, dtype=blosc2.result_type(x1, x2), **kwargs)
114114

115115
# multithreaded matmul
116-
# TODO: handle a) type promotion, b) non-square blocks, c) and >2D
116+
# TODO: handle a) type promotion, b) padding, c) (improved) >2D
117117
ops = (x1, x2, result)
118-
shape, chunks, blocks = result.shape, result.chunks, result.blocks
118+
blocks = result.blocks
119119
all_ndarray = all(isinstance(value, blosc2.NDArray) and value.shape != () for value in ops)
120120
use_miniexpr = True
121121
if all_ndarray:
122-
# can maybe relax this to just have A.blocks[-1] == B.blocks[-2]
123-
# Require aligned NDArray operands with identical chunk/block grid, and square matrices/chunks/blocks
124-
same_shape = all(op.shape[-1] == op.shape[-2] and op.shape == shape for op in ops)
125-
same_chunks = all(op.shape[-1] == op.shape[-2] and op.chunks == chunks for op in ops)
126-
same_blocks = all(op.shape[-1] == op.shape[-2] and op.blocks == blocks for op in ops)
127-
if not (same_shape and same_chunks and same_blocks):
128-
use_miniexpr = False
129122
if any(op.dtype != ops[0].dtype for op in ops): # TODO: Remove this condition
130123
use_miniexpr = False
124+
125+
# TODO: In fact the following can be relaxed too, just need to load across block boundaries
126+
# Might want to restrict loading across chunk boundaries, in which case would require:
127+
# x1.chunks[-2] % result.blocks[-2] == 0
128+
# x2.chunks[-1] % result.blocks[-1] == 0
129+
# x2.chunks[-2] % x1.blocks[-1] == 0
130+
# Can then load in x1 as slices of size [result.blocks[-2], x1.blocks[-1]]
131+
# and x2 in slices of [x1.blocks[-1], result.blocks[-1]]
132+
133+
# Require that blocks are matmul compatible and broadcastable directly to result
134+
# (M, K) x (K, N) = (M, N)
135+
# so can load block-by-block for inputs and calculate block of output
136+
# Also need to avoid loading across chunk boundaries
137+
chunks_aligned = x1.chunks[-2] % x1.blocks[-2] == 0
138+
chunks_aligned &= x2.chunks[-1] % x2.blocks[-1] == 0
139+
chunks_aligned &= x2.chunks[-2] % x1.blocks[-1] == 0
140+
same_blocks = x2.blocks[-2] == x1.blocks[-1]
141+
same_blocks &= x2.blocks[-1] == result.blocks[-1]
142+
same_blocks &= result.blocks[-2] == x1.blocks[-2]
143+
try:
144+
result_blocks = np.broadcast_shapes(x1.blocks, x2.blocks)
145+
except ValueError:
146+
use_miniexpr = False
147+
if not (same_blocks and chunks_aligned and result_blocks[:-2] == blocks[:-2]):
148+
use_miniexpr = False
149+
131150
else:
132151
use_miniexpr = False
133152

0 commit comments

Comments
 (0)