Skip to content

Commit a0cc22d

Browse files
committed
Multithreaded matmul supported for ND
1 parent 5de386c commit a0cc22d

5 files changed

Lines changed: 33 additions & 21 deletions

File tree

bench/ndarray/stringops_bench.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import time
1313
import numpy as np
1414
import blosc2
15-
from blosc2.lazyexpr import _toggle_miniexpr
15+
from blosc2.utils import _toggle_miniexpr
1616

1717
# nparr = np.random.randint(low=0, high=128, size=(N, 10), dtype=np.uint32)
1818
# nparr = nparr.view('S40').astype('U10')

src/blosc2/blosc2_ext.pyx

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2188,8 +2188,10 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
21882188
cdef int ndim = out_arr.ndim
21892189
cdef int nchunk_ = nchunk
21902190
cdef int coord, batch, batch_, batches = 1
2191+
2192+
# batches = sum(strides[i]*elcoords[i])
21912193
for i in range(ndim - 2):
2192-
batches *= out_arr.shape[i]
2194+
batches *= out_arr.blockshape[i]
21932195

21942196
# nchunk = sum(strides[i]*chunkcoords[i])
21952197
for i in range(ndim - 2):
@@ -2203,8 +2205,8 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
22032205

22042206
i = nchunk_ // ncols # ncols * i + j
22052207
j = nchunk_ % ncols
2206-
nchunkA = chunk_startA = nchunkA + i * ncols
2207-
nchunkB = chunk_startB = nchunkB + j
2208+
chunk_startA = nchunkA + i * ncols
2209+
chunk_startB = nchunkB + j
22082210

22092211
# nblock = sum(strides[i]*blockcoords[i])
22102212
cdef int nblock_ = nblock
@@ -2219,14 +2221,14 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
22192221

22202222
block_i = nblock_ // block_ncols
22212223
block_j = nblock_ % block_ncols
2222-
block_startA = nblockA = nblockA + i * block_ncols
2223-
block_startB = nblockB = nblockB + j
2224+
block_startA = nblockA + block_i * block_ncols
2225+
block_startB = nblockB + block_j
22242226

2225-
# batches = sum(strides[i]*elcoords[i])
22262227
dctx = blosc2_create_dctx(BLOSC2_DPARAMS_DEFAULTS)
22272228

22282229
first_run = True
2229-
2230+
nchunkA = chunk_startA
2231+
nchunkB = chunk_startB
22302232
while True: # chunk loop
22312233
for i in range(2):
22322234
chunk_idx = nchunkA if i == 0 else nchunkB
@@ -2268,6 +2270,9 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
22682270
if rc < 0:
22692271
raise ValueError("matmul: error decompressing the B chunk")
22702272
batch = 0
2273+
offsetA = 0
2274+
offsetB = 0
2275+
offset = 0
22712276
while batch < batches:
22722277
batch_ = batch
22732278
for i in range(ndim - 2):
@@ -3268,9 +3273,10 @@ cdef class NDArray:
32683273
i = self.array.ndim - idx
32693274
udata.chunks_strides[0][i] = udata.chunks_strides[0][i + 1] * udata.array.extshape[i + 1] // udata.array.chunkshape[i + 1]
32703275
udata.blocks_strides[0][i] = udata.blocks_strides[0][i + 1] * udata.array.extchunkshape[i + 1] // udata.array.blockshape[i + 1]
3276+
udata.el_strides[0][i] = udata.el_strides[0][i + 1] * udata.array.blockshape[i + 1]
32713277

3272-
for j in range(2):
3273-
inp = inputs_[j]
3278+
for j in range(1, 3):
3279+
inp = inputs_[j - 1]
32743280
cstrides = bstrides = estrides = 1
32753281
for idx in range(2, self.array.ndim + 1):
32763282
i = inp.ndim - idx

src/blosc2/lazyexpr.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
process_key,
6969
reducers,
7070
safe_numpy_globals,
71+
try_miniexpr,
7172
)
7273

7374
if not blosc2.IS_WASM:
@@ -76,14 +77,6 @@
7677
global safe_blosc2_globals
7778
safe_blosc2_globals = {}
7879

79-
# Set this to False if miniexpr should not be tried out
80-
try_miniexpr = not blosc2.IS_WASM or getattr(blosc2, "_WASM_MINIEXPR_ENABLED", False)
81-
82-
83-
def _toggle_miniexpr(FLAG):
84-
global try_miniexpr
85-
try_miniexpr = FLAG
86-
8780

8881
def ne_evaluate(expression, local_dict=None, **kwargs):
8982
"""Safely evaluate expressions using numexpr when possible, falling back to numpy."""

src/blosc2/linalg.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import blosc2
1919

20-
from .utils import get_intersecting_chunks, nptranspose, npvecdot, slice_to_chunktuple
20+
from .utils import get_intersecting_chunks, nptranspose, npvecdot, slice_to_chunktuple, try_miniexpr
2121

2222
if TYPE_CHECKING:
2323
from collections.abc import Sequence
@@ -113,11 +113,14 @@ def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArra
113113
result = blosc2.zeros(result_shape, dtype=blosc2.result_type(x1, x2), **kwargs)
114114

115115
# multithreaded matmul
116-
# TODO: handle a) type promotion, b) padding, c) (improved) >2D
116+
# TODO: handle a) type promotion, b) padding (explicitly), c) (improved) >2D
117117
ops = (x1, x2, result)
118118
blocks = result.blocks
119119
all_ndarray = all(isinstance(value, blosc2.NDArray) and value.shape != () for value in ops)
120-
use_miniexpr = True
120+
global try_miniexpr
121+
122+
# Use a local copy so we don't modify the global
123+
use_miniexpr = try_miniexpr
121124
if all_ndarray:
122125
if any(op.dtype != ops[0].dtype for op in ops): # TODO: Remove this condition
123126
use_miniexpr = False
@@ -165,6 +168,7 @@ def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArra
165168
if prefilter_set:
166169
result.schunk.remove_prefilter("miniexpr")
167170
else: # couldn't do multithreading
171+
print("multithreading failed :( ")
168172
if 0 not in result.shape + x1.shape + x2.shape: # if any array is empty, return array of 0s
169173
p, q = result.chunks[-2:]
170174
r = x2.chunks[-1]

src/blosc2/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@
1919

2020
import blosc2
2121

22+
# Set this to False if miniexpr should not be tried out
23+
try_miniexpr = not blosc2.IS_WASM or getattr(blosc2, "_WASM_MINIEXPR_ENABLED", False)
24+
25+
26+
def _toggle_miniexpr(FLAG):
27+
global try_miniexpr
28+
try_miniexpr = FLAG
29+
30+
2231
# NumPy version and a convenient boolean flag
2332
NUMPY_GE_2_0 = np.__version__ >= "2.0"
2433
# handle different numpy versions

0 commit comments

Comments
 (0)