Skip to content

Commit 81d6d4e

Browse files
author
Mark Saroufim
committed
Shuffle iteration data and fix recheck bug
Additional hardening on top of the object-identity caching fix: - Shuffle data order each timing iteration to prevent call-count caching (a submission could track invocation count and predict which data item appears at each position) - Move clone before torch.cuda.synchronize() so clone GPU copies can overlap with previous iteration's tail work - Fix pre-existing recheck bug where only the last item's correctness was checked (if not good was outside the for loop) - Use shuffle_order indices to correctly pair shuffled outputs with their reference data during recheck
1 parent 583bf08 commit 81d6d4e

1 file changed

Lines changed: 12 additions & 7 deletions

File tree

problems/nvidia/eval_better_bench_grouped_gemm.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import base64
22
import dataclasses
33
import multiprocessing
4+
import random
45
import re
56
import time
67
import os
@@ -274,10 +275,14 @@ def _run_single_benchmark(
274275

275276
bm_start_time = time.perf_counter_ns()
276277
for i in range(max_repeats):
277-
torch.cuda.synchronize()
278-
279-
# Clone data before timing to prevent object-identity caching exploits
278+
# Clone and shuffle data before timing to prevent both
279+
# object-identity caching and call-order caching exploits
280280
iteration_data = _clone_data(data_list)
281+
shuffle_order = list(range(len(iteration_data)))
282+
random.shuffle(shuffle_order)
283+
iteration_data = [iteration_data[j] for j in shuffle_order]
284+
285+
torch.cuda.synchronize()
281286

282287
outputs = []
283288
clear_l2_cache()
@@ -294,10 +299,10 @@ def _run_single_benchmark(
294299
) * 1e6 # Convert ms to ns
295300

296301
if recheck:
297-
for reference_output, custom_output in zip(check_copy, outputs):
298-
good, message = check_implementation(reference_output, custom_output)
299-
if not good:
300-
return message
302+
for j, custom_output in zip(shuffle_order, outputs):
303+
good, message = check_implementation(check_copy[j], custom_output)
304+
if not good:
305+
return message
301306

302307
durations.append(duration)
303308

0 commit comments

Comments
 (0)