Skip to content

Commit d60eebf

Browse files
committed
Expand matmul path coverage and report selected backend in benchmarks
1 parent 46b0842 commit d60eebf

2 files changed

Lines changed: 310 additions & 23 deletions

File tree

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
import argparse
2+
import json
3+
import statistics
4+
import time
5+
import warnings
6+
7+
import numpy as np
8+
9+
import blosc2
10+
import blosc2.linalg as linalg
11+
12+
13+
def parse_int_tuple(value: str) -> tuple[int, ...]:
    """Parse a comma-separated string of integers into a tuple.

    Blank segments are skipped, so "1, 2," yields (1, 2) and an
    empty string yields ().
    """
    pieces = [piece.strip() for piece in value.split(",")]
    return tuple(int(piece) for piece in pieces if piece)
15+
16+
17+
def build_arrays(
    shape_a: tuple[int, ...],
    shape_b: tuple[int, ...],
    dtype: np.dtype,
    chunks_a: tuple[int, ...] | None,
    chunks_b: tuple[int, ...] | None,
    blocks_a: tuple[int, ...] | None,
    blocks_b: tuple[int, ...] | None,
):
    """Create matching NumPy and blosc2 operands for a benchmark case.

    A is filled with ones and B with twos so the matmul result is easy
    to verify.  Returns ``(a_blosc2, b_blosc2, a_numpy, b_numpy)``.
    """
    lhs_np = np.ones(shape_a, dtype=dtype)
    rhs_np = np.full(shape_b, 2, dtype=dtype)
    lhs = blosc2.asarray(lhs_np, chunks=chunks_a, blocks=blocks_a)
    rhs = blosc2.asarray(rhs_np, chunks=chunks_b, blocks=blocks_b)
    return lhs, rhs, lhs_np, rhs_np
31+
32+
33+
def expected_gflops(shape_a: tuple[int, ...], shape_b: tuple[int, ...], elapsed: float) -> float | None:
34+
if elapsed <= 0 or len(shape_a) < 2 or len(shape_b) < 2:
35+
return None
36+
m = shape_a[-2]
37+
k = shape_a[-1]
38+
n = shape_b[-1]
39+
batch = int(np.prod(np.broadcast_shapes(shape_a[:-2], shape_b[:-2]))) if len(shape_a) > 2 or len(shape_b) > 2 else 1
40+
flops = 2 * batch * m * n * k
41+
return flops / elapsed / 1e9
42+
43+
44+
def set_path_mode(mode: str) -> bool:
    """Force the matmul backend selection and return the previous flag.

    "chunked" disables the miniexpr fast path, "fast" enables it, and
    "auto" leaves the current setting untouched.  The caller is expected
    to restore the returned value when the benchmark finishes.
    """
    previous = linalg.try_miniexpr
    if mode == "fast":
        linalg.try_miniexpr = True
    elif mode == "chunked":
        linalg.try_miniexpr = False
    elif mode != "auto":
        # "auto" keeps the flag as-is; anything else is a caller error.
        raise ValueError(f"unknown mode: {mode}")
    return previous
55+
56+
57+
def run_case(
    mode: str,
    repeats: int,
    shape_a: tuple[int, ...],
    shape_b: tuple[int, ...],
    dtype: np.dtype,
    chunks_a: tuple[int, ...] | None,
    chunks_b: tuple[int, ...] | None,
    blocks_a: tuple[int, ...] | None,
    blocks_b: tuple[int, ...] | None,
    chunks_out: tuple[int, ...] | None,
    blocks_out: tuple[int, ...] | None,
):
    """Benchmark blosc2.matmul under one backend mode and validate the result.

    Runs ``repeats`` timed matmuls, checks the output against NumPy, and
    records which backend ("fast" or "chunked") handled each repeat by
    temporarily wrapping ``blosc2.NDArray._set_pref_matmul``.  Returns a
    dict with raw timings, best/median seconds, derived GFLOP/s, and the
    selected path(s).  Raises RuntimeError if ``repeats`` produced no
    result (e.g. repeats == 0).
    """
    a, b, a_np, b_np = build_arrays(shape_a, shape_b, dtype, chunks_a, chunks_b, blocks_a, blocks_b)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        expected = np.matmul(a_np, b_np)
    # Save globals so they can be restored in the finally block below.
    original_flag = set_path_mode(mode)
    original_set_pref_matmul = blosc2.NDArray._set_pref_matmul
    selected_paths = []
    times = []
    result = None

    # Instrumented wrapper: the fast path calls _set_pref_matmul, so an
    # appended "fast" marker means the fast backend ran for that repeat.
    def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
        selected_paths.append("fast")
        return original_set_pref_matmul(self, inputs, fp_accuracy)

    blosc2.NDArray._set_pref_matmul = wrapped_set_pref_matmul
    try:
        for _ in range(repeats):
            before = len(selected_paths)
            t0 = time.perf_counter()
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", RuntimeWarning)
                result = blosc2.matmul(a, b, chunks=chunks_out, blocks=blocks_out)
            times.append(time.perf_counter() - t0)
            # No new "fast" marker during this repeat => chunked fallback ran.
            if len(selected_paths) == before:
                selected_paths.append("chunked")
    finally:
        # Always undo the monkeypatch and the backend-flag override.
        blosc2.NDArray._set_pref_matmul = original_set_pref_matmul
        linalg.try_miniexpr = original_flag

    if result is None:
        raise RuntimeError("matmul did not produce a result")

    # Correctness check: raises (so "correct" below can only be True).
    actual = result[:]
    np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6)

    best = min(times)
    median = statistics.median(times)
    return {
        "mode": mode,
        "times_s": times,
        "best_s": best,
        "median_s": median,
        "gflops_best": expected_gflops(shape_a, shape_b, best),
        "gflops_median": expected_gflops(shape_a, shape_b, median),
        "correct": True,
        "selected_paths": selected_paths,
        # "mixed" when repeats disagreed on which backend was used.
        "selected_path": selected_paths[0] if selected_paths and len(set(selected_paths)) == 1 else "mixed",
    }
118+
119+
120+
def main() -> None:
    """CLI entry point: benchmark the requested blosc2.matmul backend modes."""
    parser = argparse.ArgumentParser(description="Compare chunked and fast blosc2.matmul paths.")
    parser.add_argument("--shape-a", default="1000,1000", help="Comma-separated shape for A.")
    parser.add_argument("--shape-b", default="1000,1000", help="Comma-separated shape for B.")
    parser.add_argument("--dtype", default="float64", choices=["float32", "float64", "int32", "int64"])
    parser.add_argument("--chunks-a", default="500,500", help="Comma-separated chunk shape for A.")
    parser.add_argument("--chunks-b", default="500,500", help="Comma-separated chunk shape for B.")
    parser.add_argument("--blocks-a", default="100,100", help="Comma-separated block shape for A.")
    parser.add_argument("--blocks-b", default="100,100", help="Comma-separated block shape for B.")
    parser.add_argument("--chunks-out", default="500,500", help="Comma-separated chunk shape for output.")
    parser.add_argument("--blocks-out", default="100,100", help="Comma-separated block shape for output.")
    parser.add_argument("--repeats", type=int, default=5)
    parser.add_argument("--modes", nargs="+", default=["chunked", "fast", "auto"], choices=["chunked", "fast", "auto"])
    parser.add_argument("--json", action="store_true", help="Emit full JSON instead of a compact text summary.")
    args = parser.parse_args()

    def optional_tuple(text: str) -> tuple[int, ...] | None:
        # Empty string means "let blosc2 pick" (None).
        return parse_int_tuple(text) if text else None

    shape_a = parse_int_tuple(args.shape_a)
    shape_b = parse_int_tuple(args.shape_b)
    chunks_a = optional_tuple(args.chunks_a)
    chunks_b = optional_tuple(args.chunks_b)
    blocks_a = optional_tuple(args.blocks_a)
    blocks_b = optional_tuple(args.blocks_b)
    chunks_out = optional_tuple(args.chunks_out)
    blocks_out = optional_tuple(args.blocks_out)
    dtype = np.dtype(args.dtype)

    results = [
        run_case(
            mode,
            args.repeats,
            shape_a,
            shape_b,
            dtype,
            chunks_a,
            chunks_b,
            blocks_a,
            blocks_b,
            chunks_out,
            blocks_out,
        )
        for mode in args.modes
    ]

    summary = {
        "shape_a": shape_a,
        "shape_b": shape_b,
        "dtype": str(dtype),
        "chunks_a": chunks_a,
        "chunks_b": chunks_b,
        "blocks_a": blocks_a,
        "blocks_b": blocks_b,
        "chunks_out": chunks_out,
        "blocks_out": blocks_out,
        "results": results,
    }

    best_by_mode = {entry["mode"]: entry["best_s"] for entry in results}
    if "chunked" in best_by_mode and "fast" in best_by_mode:
        summary["speedup_fast_vs_chunked"] = best_by_mode["chunked"] / best_by_mode["fast"]

    if args.json:
        print(json.dumps(summary, indent=2, sort_keys=True))
        return

    # Compact text mode: one "case" header line, one "result" line per mode.
    case_info = {
        "shape_a": shape_a,
        "shape_b": shape_b,
        "dtype": str(dtype),
        "chunks_out": chunks_out,
        "blocks_out": blocks_out,
    }
    print("case", json.dumps(case_info, sort_keys=True))
    for entry in results:
        line = {
            "mode": entry["mode"],
            "best_s": round(entry["best_s"], 6),
            "median_s": round(entry["median_s"], 6),
            "gflops_best": None if entry["gflops_best"] is None else round(entry["gflops_best"], 3),
            "correct": entry["correct"],
            "selected_path": entry["selected_path"],
        }
        print("result", json.dumps(line, sort_keys=True))
    if "speedup_fast_vs_chunked" in summary:
        print("speedup", json.dumps({"fast_vs_chunked": round(summary["speedup_fast_vs_chunked"], 3)}, sort_keys=True))
215+
216+
217+
# Allow running this benchmark directly as a script.
if __name__ == "__main__":
    main()

tests/ndarray/test_linalg.py

Lines changed: 92 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -86,44 +86,61 @@ def test_toggle_miniexpr_updates_linalg_runtime_flag():
8686
_toggle_miniexpr(old_flag)
8787

8888

89-
def test_matmul_uses_fast_path_for_supported_2d(monkeypatch):
90-
old_flag = utils_mod.try_miniexpr
89+
def _set_pref_matmul_call_recorder(monkeypatch):
9190
calls = []
9291
original = blosc2.NDArray._set_pref_matmul
9392

9493
def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
95-
calls.append((self.shape, inputs["x1"].shape, inputs["x2"].shape))
94+
calls.append((self.shape, inputs["x1"].shape, inputs["x2"].shape, self.dtype))
9695
return original(self, inputs, fp_accuracy)
9796

9897
monkeypatch.setattr(blosc2.NDArray, "_set_pref_matmul", wrapped_set_pref_matmul)
98+
return calls
99+
100+
101+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
102+
def test_matmul_uses_fast_path_for_supported_2d(monkeypatch, dtype):
103+
old_flag = utils_mod.try_miniexpr
104+
calls = _set_pref_matmul_call_recorder(monkeypatch)
99105
try:
100106
_toggle_miniexpr(True)
101-
a = blosc2.ones(shape=(400, 400), dtype=np.float64, chunks=(200, 200), blocks=(100, 100))
102-
b = blosc2.full(
103-
shape=(400, 400), fill_value=2, dtype=np.float64, chunks=(200, 200), blocks=(100, 100)
104-
)
107+
a = blosc2.ones(shape=(400, 400), dtype=dtype, chunks=(200, 200), blocks=(100, 100))
108+
b = blosc2.full(shape=(400, 400), fill_value=2, dtype=dtype, chunks=(200, 200), blocks=(100, 100))
105109

106110
with warnings.catch_warnings():
107111
warnings.simplefilter("ignore", RuntimeWarning)
108112
c = blosc2.matmul(a, b, chunks=(200, 200), blocks=(100, 100))
109113
expected = np.matmul(a[:], b[:])
110114

111-
assert calls == [((400, 400), (400, 400), (400, 400))]
115+
assert calls == [((400, 400), (400, 400), (400, 400), np.dtype(dtype))]
112116
np.testing.assert_allclose(c[:], expected, rtol=1e-6, atol=1e-6)
113117
finally:
114118
_toggle_miniexpr(old_flag)
115119

116120

117-
def test_matmul_falls_back_for_integer_inputs(monkeypatch):
121+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
122+
def test_matmul_uses_fast_path_with_multiple_inner_blocks(monkeypatch, dtype):
118123
old_flag = utils_mod.try_miniexpr
119-
calls = []
120-
original = blosc2.NDArray._set_pref_matmul
124+
calls = _set_pref_matmul_call_recorder(monkeypatch)
125+
try:
126+
_toggle_miniexpr(True)
127+
a = blosc2.ones(shape=(256, 384), dtype=dtype, chunks=(128, 192), blocks=(64, 64))
128+
b = blosc2.full(shape=(384, 256), fill_value=2, dtype=dtype, chunks=(192, 128), blocks=(64, 64))
121129

122-
def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
123-
calls.append((self.shape, inputs["x1"].shape, inputs["x2"].shape))
124-
return original(self, inputs, fp_accuracy)
130+
with warnings.catch_warnings():
131+
warnings.simplefilter("ignore", RuntimeWarning)
132+
c = blosc2.matmul(a, b, chunks=(128, 128), blocks=(64, 64))
133+
expected = np.matmul(a[:], b[:])
125134

126-
monkeypatch.setattr(blosc2.NDArray, "_set_pref_matmul", wrapped_set_pref_matmul)
135+
assert calls == [((256, 256), (256, 384), (384, 256), np.dtype(dtype))]
136+
np.testing.assert_allclose(c[:], expected, rtol=1e-6, atol=1e-6)
137+
finally:
138+
_toggle_miniexpr(old_flag)
139+
140+
141+
def test_matmul_falls_back_for_integer_inputs(monkeypatch):
142+
old_flag = utils_mod.try_miniexpr
143+
calls = _set_pref_matmul_call_recorder(monkeypatch)
127144
try:
128145
_toggle_miniexpr(True)
129146
a = blosc2.ones(shape=(200, 200), dtype=np.int64, chunks=(100, 100), blocks=(50, 50))
@@ -139,14 +156,7 @@ def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
139156

140157
def test_matmul_falls_back_for_nd_inputs(monkeypatch):
141158
old_flag = utils_mod.try_miniexpr
142-
calls = []
143-
original = blosc2.NDArray._set_pref_matmul
144-
145-
def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
146-
calls.append((self.shape, inputs["x1"].shape, inputs["x2"].shape))
147-
return original(self, inputs, fp_accuracy)
148-
149-
monkeypatch.setattr(blosc2.NDArray, "_set_pref_matmul", wrapped_set_pref_matmul)
159+
calls = _set_pref_matmul_call_recorder(monkeypatch)
150160
try:
151161
_toggle_miniexpr(True)
152162
a = blosc2.ones(shape=(2, 40, 40), dtype=np.float64, chunks=(1, 20, 20), blocks=(1, 10, 10))
@@ -165,6 +175,65 @@ def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
165175
_toggle_miniexpr(old_flag)
166176

167177

178+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_matmul_falls_back_for_misaligned_blocks(monkeypatch, dtype):
    # A uses blocks (120, 100) — 120 does not divide the (200, 200) chunk
    # evenly and differs from B's (100, 100) blocks.  With miniexpr forced
    # on, the fast path must still be skipped (calls stays empty) and the
    # chunked fallback must produce the NumPy result.
    old_flag = utils_mod.try_miniexpr
    calls = _set_pref_matmul_call_recorder(monkeypatch)
    try:
        _toggle_miniexpr(True)
        a = blosc2.ones(shape=(400, 400), dtype=dtype, chunks=(200, 200), blocks=(120, 100))
        b = blosc2.full(shape=(400, 400), fill_value=2, dtype=dtype, chunks=(200, 200), blocks=(100, 100))

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            c = blosc2.matmul(a, b, chunks=(200, 200), blocks=(120, 100))
        expected = np.matmul(a[:], b[:])

        # The _set_pref_matmul hook never fired => fast path not attempted.
        assert calls == []
        np.testing.assert_allclose(c[:], expected, rtol=1e-6, atol=1e-6)
    finally:
        # Restore the global miniexpr flag regardless of test outcome.
        _toggle_miniexpr(old_flag)
196+
197+
198+
def test_matmul_falls_back_for_dtype_mismatch(monkeypatch):
    # Operands with different float widths (float32 A, float64 B) must not
    # take the fast path even when miniexpr is enabled; the chunked path
    # still has to match NumPy's result.
    old_flag = utils_mod.try_miniexpr
    calls = _set_pref_matmul_call_recorder(monkeypatch)
    try:
        _toggle_miniexpr(True)
        a = blosc2.ones(shape=(200, 200), dtype=np.float32, chunks=(100, 100), blocks=(50, 50))
        b = blosc2.full(shape=(200, 200), fill_value=2, dtype=np.float64, chunks=(100, 100), blocks=(50, 50))

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            c = blosc2.matmul(a, b, chunks=(100, 100), blocks=(50, 50))
        expected = np.matmul(a[:], b[:])

        # The _set_pref_matmul hook never fired => fast path not attempted.
        assert calls == []
        np.testing.assert_allclose(c[:], expected, rtol=1e-6, atol=1e-6)
    finally:
        # Restore the global miniexpr flag regardless of test outcome.
        _toggle_miniexpr(old_flag)
215+
216+
217+
@pytest.mark.parametrize("dtype", [np.complex64, np.complex128])
def test_matmul_complex_falls_back_to_chunked(monkeypatch, dtype):
    # Complex dtypes must not use the fast path even with miniexpr forced
    # on; the chunked fallback still has to agree with NumPy.
    old_flag = utils_mod.try_miniexpr
    calls = _set_pref_matmul_call_recorder(monkeypatch)
    try:
        _toggle_miniexpr(True)
        a = blosc2.asarray(np.ones((100, 100), dtype=dtype))
        b = blosc2.asarray(np.full((100, 100), 2 + 0j, dtype=dtype))

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            c = blosc2.matmul(a, b, chunks=(50, 50), blocks=(25, 25))
        expected = np.matmul(a[:], b[:])

        # The _set_pref_matmul hook never fired => fast path not attempted.
        assert calls == []
        np.testing.assert_allclose(c[:], expected, rtol=1e-6, atol=1e-6)
    finally:
        # Restore the global miniexpr flag regardless of test outcome.
        _toggle_miniexpr(old_flag)
235+
236+
168237
def test_matmul_fast_path_failure_falls_back(monkeypatch):
169238
old_flag = utils_mod.try_miniexpr
170239

0 commit comments

Comments
 (0)