Skip to content

Commit 4bedf3d

Browse files
committed
Add fast-naive stats as well in bench
1 parent 13063c2 commit 4bedf3d

1 file changed

Lines changed: 40 additions & 5 deletions

File tree

bench/ndarray/matmul_path_compare.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def set_path_mode(mode: str) -> bool:
5555

5656

5757
def run_case(
58+
label: str,
5859
mode: str,
5960
block_backend: str,
6061
repeats: int,
@@ -113,6 +114,7 @@ def wrapped_set_pref_matmul(self, inputs, fp_accuracy):
113114
best = min(times)
114115
median = statistics.median(times)
115116
return {
117+
"label": label,
116118
"mode": mode,
117119
"times_s": times,
118120
"best_s": best,
@@ -163,6 +165,7 @@ def main() -> None:
163165
for mode in args.modes:
164166
results.append(
165167
run_case(
168+
mode,
166169
mode,
167170
args.block_backend,
168171
args.repeats,
@@ -178,6 +181,27 @@ def main() -> None:
178181
)
179182
)
180183

184+
if args.block_backend == "auto" and "fast" in args.modes:
185+
fast_naive = run_case(
186+
"fast-naive",
187+
"fast",
188+
"naive",
189+
args.repeats,
190+
shape_a,
191+
shape_b,
192+
dtype,
193+
chunks_a,
194+
chunks_b,
195+
blocks_a,
196+
blocks_b,
197+
chunks_out,
198+
blocks_out,
199+
)
200+
if fast_naive["selected_block_backend"] != next(
201+
item["selected_block_backend"] for item in results if item["mode"] == "fast"
202+
):
203+
results.append(fast_naive)
204+
181205
summary = {
182206
"shape_a": shape_a,
183207
"shape_b": shape_b,
@@ -192,9 +216,13 @@ def main() -> None:
192216
"results": results,
193217
}
194218

195-
best_by_mode = {item["mode"]: item["best_s"] for item in results}
196-
if "chunked" in best_by_mode and "fast" in best_by_mode:
197-
summary["speedup_fast_vs_chunked"] = best_by_mode["chunked"] / best_by_mode["fast"]
219+
best_by_label = {item["label"]: item["best_s"] for item in results}
220+
if "chunked" in best_by_label and "fast" in best_by_label:
221+
summary["speedup_fast_vs_chunked"] = best_by_label["chunked"] / best_by_label["fast"]
222+
if "chunked" in best_by_label and "fast-naive" in best_by_label:
223+
summary["speedup_fast_naive_vs_chunked"] = best_by_label["chunked"] / best_by_label["fast-naive"]
224+
if "fast" in best_by_label and "fast-naive" in best_by_label:
225+
summary["speedup_fast_vs_fast_naive"] = best_by_label["fast-naive"] / best_by_label["fast"]
198226

199227
if args.json:
200228
print(json.dumps(summary, indent=2, sort_keys=True))
@@ -208,10 +236,13 @@ def main() -> None:
208236
print(f" blocks A/B/out: {blocks_a} / {blocks_b} / {blocks_out}")
209237
print(f" repeats: {args.repeats}")
210238
print(f" fast block backend: {args.block_backend}")
211-
for item in results:
239+
display_order = ["chunked", "fast-naive", "fast", "auto"]
240+
ordered_results = sorted(results, key=lambda item: display_order.index(item["label"]) if item["label"] in display_order else len(display_order))
241+
242+
for item in ordered_results:
212243
gflops_best = "-" if item["gflops_best"] is None else f"{item['gflops_best']:.3f}"
213244
print(
214-
f"{item['mode']:>7}: "
245+
f"{item['label']:>10}: "
215246
f"best={item['best_s']:.6f}s "
216247
f"median={item['median_s']:.6f}s "
217248
f"gflops={gflops_best} "
@@ -221,6 +252,10 @@ def main() -> None:
221252
)
222253
if "speedup_fast_vs_chunked" in summary:
223254
print(f"Speedup fast vs chunked: {summary['speedup_fast_vs_chunked']:.3f}x")
255+
if "speedup_fast_naive_vs_chunked" in summary:
256+
print(f"Speedup fast-naive vs chunked: {summary['speedup_fast_naive_vs_chunked']:.3f}x")
257+
if "speedup_fast_vs_fast_naive" in summary:
258+
print(f"Speedup fast vs fast-naive: {summary['speedup_fast_vs_fast_naive']:.3f}x")
224259

225260

226261
if __name__ == "__main__":

0 commit comments

Comments
 (0)