Skip to content

Commit 2ae5387

Browse files
committed
Restrict matmul fast path to supported 2D cases and fall back to chunked path
1 parent f84b360 commit 2ae5387

4 files changed

Lines changed: 185 additions & 84 deletions

File tree

src/blosc2/blosc2_ext.pyx

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2102,6 +2102,7 @@ cdef class SChunk:
21022102
cpdef remove_prefilter(self, func_name, _new_ctx=True):
21032103
cdef udf_udata* udf_data
21042104
cdef user_filters_udata* udata
2105+
cdef mm_udata* mm_data
21052106

21062107
if func_name is not None and func_name in blosc2.prefilter_funcs:
21072108
del blosc2.prefilter_funcs[func_name]
@@ -2123,6 +2124,13 @@ cdef class SChunk:
21232124
if me_data.eval_params != NULL:
21242125
free(me_data.eval_params)
21252126
free(me_data)
2127+
elif self.schunk.storage.cparams.prefilter == <blosc2_prefilter_fn>matmul_prefilter:
2128+
if self.schunk.storage.cparams.preparams != NULL:
2129+
mm_data = <mm_udata*>self.schunk.storage.cparams.preparams.user_data
2130+
if mm_data != NULL:
2131+
if mm_data.inputs != NULL:
2132+
free(mm_data.inputs)
2133+
free(mm_data)
21262134
elif self.schunk.storage.cparams.prefilter != NULL:
21272135
# From Python the preparams->udata will always have the field py_func
21282136
if self.schunk.storage.cparams.preparams != NULL:
@@ -2408,6 +2416,8 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
24082416
out_block_nrows = out_arr.blockshape[ndim - 2]
24092417
out_block_ncols = out_arr.blockshape[ndim - 1]
24102418

2419+
memset(params_output, 0, out_arr.blocknitems * typesize)
2420+
24112421
dctx = blosc2_create_dctx(BLOSC2_DPARAMS_DEFAULTS)
24122422

24132423
first_run = True
@@ -2464,13 +2474,13 @@ cdef int aux_matmul(mm_udata *udata, int64_t nchunk, int32_t nblock, void *param
24642474
if rc < 0:
24652475
raise ValueError("matmul: error decompressing the B chunk")
24662476
batch = 0
2467-
offsetA = 0
2468-
offsetB = 0
2469-
offset = 0
24702477
while batch < batches:
24712478
batch_ = batch
2479+
offsetA = 0
2480+
offsetB = 0
2481+
offset = 0
24722482
for i in range(ndim - 2):
2473-
coord = batch // udata.el_strides[0][i]
2483+
coord = batch_ // udata.el_strides[0][i]
24742484
batch_ = batch_ % udata.el_strides[0][i]
24752485
offsetA += coord * udata.el_strides[1][i]
24762486
offsetB += coord * udata.el_strides[2][i]

src/blosc2/linalg.py

Lines changed: 81 additions & 79 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,79 @@
2323
from collections.abc import Sequence
2424

2525

26-
def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArray: # noqa: C901
26+
def _matmul_chunked(
27+
x1: blosc2.Array, x2: blosc2.NDArray, result: blosc2.NDArray, n: int, m: int, k: int
28+
) -> None:
29+
p, q = result.chunks[-2:]
30+
r = x2.chunks[-1]
31+
32+
intersecting_chunks = get_intersecting_chunks((), result.shape[:-2], result.chunks[:-2])
33+
for chunk in intersecting_chunks:
34+
chunk = chunk.raw
35+
for row in range(0, n, p):
36+
row_end = builtins.min(row + p, n)
37+
for col in range(0, m, q):
38+
col_end = builtins.min(col + q, m)
39+
for aux in range(0, k, r):
40+
aux_end = builtins.min(aux + r, k)
41+
bx1 = (
42+
x1[chunk[-x1.ndim + 2 :] + (slice(row, row_end), slice(aux, aux_end))]
43+
if x1.ndim > 2
44+
else x1[row:row_end, aux:aux_end]
45+
)
46+
bx2 = (
47+
x2[chunk[-x2.ndim + 2 :] + (slice(aux, aux_end), slice(col, col_end))]
48+
if x2.ndim > 2
49+
else x2[aux:aux_end, col:col_end]
50+
)
51+
result[chunk + (slice(row, row_end), slice(col, col_end))] += np.matmul(bx1, bx2)
52+
53+
54+
def _matmul_can_use_fast_path(
55+
x1: blosc2.Array, x2: blosc2.NDArray, result: blosc2.NDArray, use_miniexpr: bool
56+
) -> bool:
57+
if not use_miniexpr:
58+
return False
59+
60+
ops = (x1, x2, result)
61+
all_ndarray = all(isinstance(value, blosc2.NDArray) and value.shape != () for value in ops)
62+
if not all_ndarray:
63+
return False
64+
65+
# The current prefilter-backed implementation is only supported for 2-D layouts.
66+
if result.ndim != 2 or x1.ndim != 2 or x2.ndim != 2:
67+
return False
68+
69+
if any(op.dtype != ops[0].dtype for op in ops):
70+
return False
71+
72+
chunks_aligned = x1.chunks[-2] % x1.blocks[-2] == 0
73+
chunks_aligned &= x2.chunks[-1] % x2.blocks[-1] == 0
74+
chunks_aligned &= x2.chunks[-2] % x1.blocks[-1] == 0
75+
if not chunks_aligned:
76+
return False
77+
78+
same_blocks = x2.blocks[-2] == x1.blocks[-1]
79+
same_blocks &= x2.blocks[-1] == result.blocks[-1]
80+
same_blocks &= result.blocks[-2] == x1.blocks[-2]
81+
if not same_blocks:
82+
return False
83+
84+
try:
85+
result_blocks = np.broadcast_shapes(x1.blocks, x2.blocks)
86+
except ValueError:
87+
return False
88+
if result_blocks[:-2] != result.blocks[:-2]:
89+
return False
90+
91+
if x1.dtype.kind not in ("i", "f"):
92+
return False
93+
if x2.dtype.kind not in ("i", "f"):
94+
return False
95+
return x1.dtype == x2.dtype
96+
97+
98+
def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArray:
2799
"""
28100
Computes the matrix product between two Blosc2 NDArrays.
29101
@@ -112,60 +184,10 @@ def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArra
112184
kwargs["_chunksize_reduc_factor"] = 1
113185
result = blosc2.zeros(result_shape, dtype=blosc2.result_type(x1, x2), **kwargs)
114186

115-
# multithreaded matmul
116-
# TODO: handle a) type promotion, b) padding (explicitly), c) (improved) >2D
117-
ops = (x1, x2, result)
118-
all_ndarray = all(isinstance(value, blosc2.NDArray) and value.shape != () for value in ops)
119187
global try_miniexpr
120188

121-
# Use a local copy so we don't modify the global
122-
use_miniexpr = try_miniexpr
123189
if 0 not in result.shape + x1.shape + x2.shape: # if any array is empty, return array of 0s
124-
if all_ndarray:
125-
if any(op.dtype != ops[0].dtype for op in ops): # TODO: Remove this condition
126-
use_miniexpr = False
127-
128-
# TODO: We can relax this to even just load according to result blockshape, but that's difficult.
129-
# Just force same chunk/block shapes
130-
# same_chunks = all(op.chunks == result.chunks for op in (x1, x2))
131-
# same_blocks = all(op.blocks == result.blocks for op in (x1, x2))
132-
# same_shape = all(op.shape == result.shape for op in (x1, x2))
133-
134-
# use_miniexpr &= same_blocks & same_chunks & same_shape
135-
# Two easier cases are presented below
136-
# Case 1: Might want to restrict loading across chunk boundaries, in which case would require:
137-
# x1.chunks[-2] % result.blocks[-2] == 0
138-
# x2.chunks[-1] % result.blocks[-1] == 0
139-
# x2.chunks[-2] % x1.blocks[-1] == 0
140-
# Can then load in x1 as slices of size [result.blocks[-2], x1.blocks[-1]]
141-
# and x2 in slices of [x1.blocks[-1], result.blocks[-1]]
142-
143-
# Case 2: Slightly easier to implement this maybe
144-
# Require that blocks are matmul compatible and broadcastable directly to result
145-
# (M, K) x (K, N) = (M, N)
146-
# so can load block-by-block for inputs and calculate block of output
147-
# Also need to avoid loading across chunk boundaries
148-
chunks_aligned = x1.chunks[-2] % x1.blocks[-2] == 0
149-
chunks_aligned &= x2.chunks[-1] % x2.blocks[-1] == 0
150-
chunks_aligned &= x2.chunks[-2] % x1.blocks[-1] == 0
151-
same_blocks = x2.blocks[-2] == x1.blocks[-1]
152-
same_blocks &= x2.blocks[-1] == result.blocks[-1]
153-
same_blocks &= result.blocks[-2] == x1.blocks[-2]
154-
try:
155-
result_blocks = np.broadcast_shapes(x1.blocks, x2.blocks)
156-
if not (same_blocks and chunks_aligned and result_blocks[:-2] == result.blocks[:-2]):
157-
use_miniexpr = False
158-
except ValueError:
159-
use_miniexpr = False
160-
161-
use_miniexpr &= x1.dtype.kind in ("i", "f")
162-
use_miniexpr &= x2.dtype.kind in ("i", "f")
163-
use_miniexpr &= x1.dtype == x2.dtype
164-
165-
else:
166-
use_miniexpr = False
167-
168-
if use_miniexpr:
190+
if _matmul_can_use_fast_path(x1, x2, result, try_miniexpr):
169191
prefilter_set = False
170192
try:
171193
result._set_pref_matmul({"x1": x1, "x2": x2}, fp_accuracy=blosc2.FPAccuracy.DEFAULT)
@@ -174,36 +196,16 @@ def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArra
174196
data = np.empty(result.schunk.chunksize, dtype=np.uint8)
175197
for nchunk_out in range(result.schunk.nchunks):
176198
result.schunk.update_data(nchunk_out, data, copy=False)
177-
except Exception as e:
178-
raise Exception from e
199+
except Exception as exc:
200+
warnings.warn(
201+
f"Fast matmul path unavailable; falling back to chunked path: {exc}", RuntimeWarning
202+
)
203+
_matmul_chunked(x1, x2, result, n, m, k)
179204
finally:
180205
if prefilter_set:
181206
result.schunk.remove_prefilter("miniexpr")
182-
else: # couldn't do multithreading
183-
print("multithreading failed :( ")
184-
p, q = result.chunks[-2:]
185-
r = x2.chunks[-1]
186-
187-
intersecting_chunks = get_intersecting_chunks((), result.shape[:-2], result.chunks[:-2])
188-
for chunk in intersecting_chunks:
189-
chunk = chunk.raw
190-
for row in range(0, n, p):
191-
row_end = builtins.min(row + p, n)
192-
for col in range(0, m, q):
193-
col_end = builtins.min(col + q, m)
194-
for aux in range(0, k, r):
195-
aux_end = builtins.min(aux + r, k)
196-
bx1 = (
197-
x1[chunk[-x1.ndim + 2 :] + (slice(row, row_end), slice(aux, aux_end))]
198-
if x1.ndim > 2
199-
else x1[row:row_end, aux:aux_end]
200-
)
201-
bx2 = (
202-
x2[chunk[-x2.ndim + 2 :] + (slice(aux, aux_end), slice(col, col_end))]
203-
if x2.ndim > 2
204-
else x2[aux:aux_end, col:col_end]
205-
)
206-
result[chunk + (slice(row, row_end), slice(col, col_end))] += np.matmul(bx1, bx2)
207+
else:
208+
_matmul_chunked(x1, x2, result, n, m, k)
207209

208210
if x1_is_vector:
209211
result = result.squeeze(axis=-2)

src/blosc2/utils.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
import builtins
1010
import inspect
1111
import math
12+
import sys
1213
import warnings
1314
from itertools import product
1415

@@ -26,6 +27,10 @@
2627
def _toggle_miniexpr(FLAG):
2728
global try_miniexpr
2829
try_miniexpr = FLAG
30+
for module_name in ("blosc2.lazyexpr", "blosc2.linalg"):
31+
module = sys.modules.get(module_name)
32+
if module is not None:
33+
module.try_miniexpr = FLAG
2934

3035

3136
# NumPy version and a convenient boolean flag

tests/ndarray/test_linalg.py

Lines changed: 85 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,10 @@
1212
import pytest
1313

1414
import blosc2
15+
import blosc2.linalg as blosc2_linalg
16+
import blosc2.utils as utils_mod
1517
from blosc2.lazyexpr import linalg_funcs
16-
from blosc2.utils import npvecdot
18+
from blosc2.utils import _toggle_miniexpr, npvecdot
1719

1820
# Conditionally import torch for proxy tests
1921
try:
@@ -69,6 +71,88 @@ def test_matmul(ashape, achunks, ablocks, bshape, bchunks, bblocks, dtype):
6971
np.testing.assert_allclose(b2_res[()], np_res, rtol=1e-6)
7072

7173

74+
def test_toggle_miniexpr_updates_linalg_runtime_flag():
75+
old_flag = utils_mod.try_miniexpr
76+
try:
77+
_toggle_miniexpr(False)
78+
assert utils_mod.try_miniexpr is False
79+
assert blosc2_linalg.try_miniexpr is False
80+
81+
_toggle_miniexpr(True)
82+
assert utils_mod.try_miniexpr is True
83+
assert blosc2_linalg.try_miniexpr is True
84+
finally:
85+
_toggle_miniexpr(old_flag)
86+
87+
88+
def test_matmul_uses_fast_path_for_supported_2d(monkeypatch):
89+
old_flag = utils_mod.try_miniexpr
90+
calls = []
91+
original = blosc2.NDArray._set_pref_matmul
92+
93+
def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
94+
calls.append((self.shape, inputs["x1"].shape, inputs["x2"].shape))
95+
return original(self, inputs, fp_accuracy)
96+
97+
monkeypatch.setattr(blosc2.NDArray, "_set_pref_matmul", wrapped_set_pref_matmul)
98+
try:
99+
_toggle_miniexpr(True)
100+
a = blosc2.ones(shape=(400, 400), dtype=np.int64, chunks=(200, 200), blocks=(100, 100))
101+
b = blosc2.full(shape=(400, 400), fill_value=2, dtype=np.int64, chunks=(200, 200), blocks=(100, 100))
102+
103+
c = blosc2.matmul(a, b, chunks=(200, 200), blocks=(100, 100))
104+
105+
assert calls == [((400, 400), (400, 400), (400, 400))]
106+
np.testing.assert_allclose(c[:], np.matmul(a[:], b[:]), rtol=1e-6, atol=1e-6)
107+
finally:
108+
_toggle_miniexpr(old_flag)
109+
110+
111+
def test_matmul_falls_back_for_nd_inputs(monkeypatch):
112+
old_flag = utils_mod.try_miniexpr
113+
calls = []
114+
original = blosc2.NDArray._set_pref_matmul
115+
116+
def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
117+
calls.append((self.shape, inputs["x1"].shape, inputs["x2"].shape))
118+
return original(self, inputs, fp_accuracy)
119+
120+
monkeypatch.setattr(blosc2.NDArray, "_set_pref_matmul", wrapped_set_pref_matmul)
121+
try:
122+
_toggle_miniexpr(True)
123+
a = blosc2.ones(shape=(2, 40, 40), dtype=np.int64, chunks=(1, 20, 20), blocks=(1, 10, 10))
124+
b = blosc2.full(
125+
shape=(2, 40, 40), fill_value=2, dtype=np.int64, chunks=(1, 20, 20), blocks=(1, 10, 10)
126+
)
127+
128+
c = blosc2.matmul(a, b, chunks=(1, 20, 20), blocks=(1, 10, 10))
129+
130+
assert calls == []
131+
np.testing.assert_allclose(c[:], np.matmul(a[:], b[:]), rtol=1e-6, atol=1e-6)
132+
finally:
133+
_toggle_miniexpr(old_flag)
134+
135+
136+
def test_matmul_fast_path_failure_falls_back(monkeypatch):
137+
old_flag = utils_mod.try_miniexpr
138+
139+
def failing_set_pref_matmul(self, inputs, fp_accuracy):
140+
raise RuntimeError("boom")
141+
142+
monkeypatch.setattr(blosc2.NDArray, "_set_pref_matmul", failing_set_pref_matmul)
143+
try:
144+
_toggle_miniexpr(True)
145+
a = blosc2.ones(shape=(200, 200), dtype=np.int64, chunks=(100, 100), blocks=(50, 50))
146+
b = blosc2.full(shape=(200, 200), fill_value=2, dtype=np.int64, chunks=(100, 100), blocks=(50, 50))
147+
148+
with pytest.warns(RuntimeWarning, match="falling back to chunked path"):
149+
c = blosc2.matmul(a, b, chunks=(100, 100), blocks=(50, 50))
150+
151+
np.testing.assert_allclose(c[:], np.matmul(a[:], b[:]), rtol=1e-6, atol=1e-6)
152+
finally:
153+
_toggle_miniexpr(old_flag)
154+
155+
72156
@pytest.mark.parametrize(
73157
("ashape", "achunks", "ablocks"),
74158
{

0 commit comments

Comments (0)