Skip to content

Commit 6997ec9

Browse files
lshaw8317FrancescAlted
authored andcommitted
Multithreaded matmul supported for ND
1 parent 4dae4da commit 6997ec9

5 files changed

Lines changed: 33 additions & 21 deletions

File tree

bench/ndarray/stringops_bench.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import time
1313
import numpy as np
1414
import blosc2
15-
from blosc2.lazyexpr import _toggle_miniexpr
15+
from blosc2.utils import _toggle_miniexpr
1616

1717
# nparr = np.random.randint(low=0, high=128, size=(N, 10), dtype=np.uint32)
1818
# nparr = nparr.view('S40').astype('U10')

src/blosc2/blosc2_ext.pyx

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2375,8 +2375,10 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
23752375
cdef int ndim = out_arr.ndim
23762376
cdef int nchunk_ = nchunk
23772377
cdef int coord, batch, batch_, batches = 1
2378+
2379+
# batches = sum(strides[i]*elcoords[i])
23782380
for i in range(ndim - 2):
2379-
batches *= out_arr.shape[i]
2381+
batches *= out_arr.blockshape[i]
23802382

23812383
# nchunk = sum(strides[i]*chunkcoords[i])
23822384
for i in range(ndim - 2):
@@ -2390,8 +2392,8 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
23902392

23912393
i = nchunk_ // ncols # ncols * i + j
23922394
j = nchunk_ % ncols
2393-
nchunkA = chunk_startA = nchunkA + i * ncols
2394-
nchunkB = chunk_startB = nchunkB + j
2395+
chunk_startA = nchunkA + i * ncols
2396+
chunk_startB = nchunkB + j
23952397

23962398
# nblock = sum(strides[i]*blockcoords[i])
23972399
cdef int nblock_ = nblock
@@ -2406,14 +2408,14 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
24062408

24072409
block_i = nblock_ // block_ncols
24082410
block_j = nblock_ % block_ncols
2409-
block_startA = nblockA = nblockA + i * block_ncols
2410-
block_startB = nblockB = nblockB + j
2411+
block_startA = nblockA + block_i * block_ncols
2412+
block_startB = nblockB + block_j
24112413

2412-
# batches = sum(strides[i]*elcoords[i])
24132414
dctx = blosc2_create_dctx(BLOSC2_DPARAMS_DEFAULTS)
24142415

24152416
first_run = True
2416-
2417+
nchunkA = chunk_startA
2418+
nchunkB = chunk_startB
24172419
while True: # chunk loop
24182420
for i in range(2):
24192421
chunk_idx = nchunkA if i == 0 else nchunkB
@@ -2455,6 +2457,9 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
24552457
if rc < 0:
24562458
raise ValueError("matmul: error decompressing the B chunk")
24572459
batch = 0
2460+
offsetA = 0
2461+
offsetB = 0
2462+
offset = 0
24582463
while batch < batches:
24592464
batch_ = batch
24602465
for i in range(ndim - 2):
@@ -3460,9 +3465,10 @@ cdef class NDArray:
34603465
i = self.array.ndim - idx
34613466
udata.chunks_strides[0][i] = udata.chunks_strides[0][i + 1] * udata.array.extshape[i + 1] // udata.array.chunkshape[i + 1]
34623467
udata.blocks_strides[0][i] = udata.blocks_strides[0][i + 1] * udata.array.extchunkshape[i + 1] // udata.array.blockshape[i + 1]
3468+
udata.el_strides[0][i] = udata.el_strides[0][i + 1] * udata.array.blockshape[i + 1]
34633469

3464-
for j in range(2):
3465-
inp = inputs_[j]
3470+
for j in range(1, 3):
3471+
inp = inputs_[j - 1]
34663472
cstrides = bstrides = estrides = 1
34673473
for idx in range(2, self.array.ndim + 1):
34683474
i = inp.ndim - idx

src/blosc2/lazyexpr.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
process_key,
6969
reducers,
7070
safe_numpy_globals,
71+
try_miniexpr,
7172
)
7273

7374
if not blosc2.IS_WASM:
@@ -76,14 +77,6 @@
7677
global safe_blosc2_globals
7778
safe_blosc2_globals = {}
7879

79-
# Set this to False if miniexpr should not be tried out
80-
try_miniexpr = not blosc2.IS_WASM or getattr(blosc2, "_WASM_MINIEXPR_ENABLED", False)
81-
82-
83-
def _toggle_miniexpr(FLAG):
84-
global try_miniexpr
85-
try_miniexpr = FLAG
86-
8780

8881
def ne_evaluate(expression, local_dict=None, **kwargs):
8982
"""Safely evaluate expressions using numexpr when possible, falling back to numpy."""

src/blosc2/linalg.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import blosc2
1919

20-
from .utils import get_intersecting_chunks, nptranspose, npvecdot, slice_to_chunktuple
20+
from .utils import get_intersecting_chunks, nptranspose, npvecdot, slice_to_chunktuple, try_miniexpr
2121

2222
if TYPE_CHECKING:
2323
from collections.abc import Sequence
@@ -113,11 +113,14 @@ def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArra
113113
result = blosc2.zeros(result_shape, dtype=blosc2.result_type(x1, x2), **kwargs)
114114

115115
# multithreaded matmul
116-
# TODO: handle a) type promotion, b) padding, c) (improved) >2D
116+
# TODO: handle a) type promotion, b) padding (explicitly), c) (improved) >2D
117117
ops = (x1, x2, result)
118118
blocks = result.blocks
119119
all_ndarray = all(isinstance(value, blosc2.NDArray) and value.shape != () for value in ops)
120-
use_miniexpr = True
120+
global try_miniexpr
121+
122+
# Use a local copy so we don't modify the global
123+
use_miniexpr = try_miniexpr
121124
if all_ndarray:
122125
if any(op.dtype != ops[0].dtype for op in ops): # TODO: Remove this condition
123126
use_miniexpr = False
@@ -165,6 +168,7 @@ def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArra
165168
if prefilter_set:
166169
result.schunk.remove_prefilter("miniexpr")
167170
else: # couldn't do multithreading
171+
print("multithreading failed :( ")
168172
if 0 not in result.shape + x1.shape + x2.shape: # if any array is empty, return array of 0s
169173
p, q = result.chunks[-2:]
170174
r = x2.chunks[-1]

src/blosc2/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@
1919

2020
import blosc2
2121

22+
# Set this to False if miniexpr should not be tried out
23+
try_miniexpr = not blosc2.IS_WASM or getattr(blosc2, "_WASM_MINIEXPR_ENABLED", False)
24+
25+
26+
def _toggle_miniexpr(FLAG):
27+
global try_miniexpr
28+
try_miniexpr = FLAG
29+
30+
2231
# NumPy version and a convenient boolean flag
2332
NUMPY_GE_2_0 = np.__version__ >= "2.0"
2433
# handle different numpy versions

0 commit comments

Comments
 (0)