Skip to content

Commit 583bf08

Browse files
author
Mark Saroufim
committed
Fix benchmark exploit via object-identity caching
The benchmark harness was vulnerable to submissions that cache results based on Python object identity (e.g., id(tensor)). Since the same data objects were reused across all timing iterations, a submission could cache on first call and return cached results on subsequent calls, showing artificial speedups of 12-36%. Changes: - Clone data before each timing iteration (outside the timed region) to give each iteration fresh object identities while not affecting measured kernel time - Use local seed variable instead of mutating test.args["seed"] to avoid shared mutable state between benchmark runs
1 parent 2998db4 commit 583bf08

1 file changed

Lines changed: 11 additions & 4 deletions

File tree

problems/nvidia/eval_better_bench_grouped_gemm.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,10 +242,14 @@ def _run_single_benchmark(
242242
data_list = []
243243
# generate input data once
244244

245+
local_seed = test.args.get("seed", None)
245246
for i in range(NUM_ITERATIONS_PER_BENCHMARK):
246-
if "seed" in test.args:
247-
test.args["seed"] += 42
248-
data = generate_input(**test.args)
247+
if local_seed is not None:
248+
local_seed += 42
249+
args = {**test.args, "seed": local_seed}
250+
else:
251+
args = test.args
252+
data = generate_input(**args)
249253
data_list.append(data)
250254

251255
check_copy = _clone_data(data_list)
@@ -272,12 +276,15 @@ def _run_single_benchmark(
272276
for i in range(max_repeats):
273277
torch.cuda.synchronize()
274278

279+
# Clone data before timing to prevent object-identity caching exploits
280+
iteration_data = _clone_data(data_list)
281+
275282
outputs = []
276283
clear_l2_cache()
277284
start_event = torch.cuda.Event(enable_timing=True)
278285
end_event = torch.cuda.Event(enable_timing=True)
279286
start_event.record()
280-
for data in data_list:
287+
for data in iteration_data:
281288
output = custom_kernel(data)
282289
outputs.append(output)
283290
end_event.record()

0 commit comments

Comments
 (0)