Skip to content

Commit 3d389a2

Browse files
G-structure authored and Nicolas Bourbaki committed
Improve grouped GEMM eval anti-cheat checks
1 parent 208fd03 commit 3d389a2

1 file changed

Lines changed: 83 additions & 28 deletions

File tree

problems/nvidia/eval_better_bench_grouped_gemm.py

Lines changed: 83 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import base64
22
import dataclasses
33
import multiprocessing
4+
import random
45
import re
56
import time
67
import os
78
import sys
89
import math
10+
import random
911

1012
# Disable CuTe DSL file caching for more stable benchmarking
1113
os.environ["CUTE_DSL_DISABLE_FILE_CACHING"] = "1"
@@ -240,7 +242,7 @@ def _run_single_benchmark(
240242

241243
durations = []
242244
data_list = []
243-
# generate input data once (local seed avoids mutating test.args)
245+
# generate input data once
244246

245247
local_seed = test.args.get("seed", None)
246248
for i in range(NUM_ITERATIONS_PER_BENCHMARK):
@@ -253,8 +255,14 @@ def _run_single_benchmark(
253255
data_list.append(data)
254256

255257
check_copy = _clone_data(data_list)
256-
257-
# first, one obligatory correctness check
258+
# Deterministic but hidden probe stream.
259+
# In benchmark mode we use randomized call windows and sparse probes.
260+
# In leaderboard mode we do one full sweep up front, then lightweight probes.
261+
probe_seed = _combine(int(test.args.get("seed", 0) or 0), 0x4D455452)
262+
probe_rng = random.Random(probe_seed)
263+
full_calls = len(data_list)
264+
265+
# First, one obligatory correctness check on fresh clones.
258266
outputs = []
259267
try:
260268
for data in data_list:
@@ -267,45 +275,88 @@ def _run_single_benchmark(
267275
if not good:
268276
return message
269277

270-
# Timing: individual per-call measurement with GPU sync between calls.
271-
# This prevents "batch-and-skip" exploits where a submission defers all
272-
# work to one call and returns cached/uncomputed results for the rest.
278+
# Timing: per-call intervals captured with CUDA events and one sync.
279+
# We randomize window length/order in benchmark mode to break fixed-N exploits.
273280
# Data is cloned each iteration to prevent object-identity caching.
274281

275282
bm_start_time = time.perf_counter_ns()
276283
for i in range(max_repeats):
284+
# Clone and shuffle data before timing to prevent both
285+
# object-identity caching and call-order caching exploits
277286
iteration_data = _clone_data(data_list)
287+
shuffle_order = list(range(len(iteration_data)))
288+
random.shuffle(shuffle_order)
289+
iteration_data = [iteration_data[j] for j in shuffle_order]
290+
278291
torch.cuda.synchronize()
279-
clear_l2_cache()
280292

281-
per_call_durations = []
293+
if recheck:
294+
call_indices = list(range(full_calls))
295+
else:
296+
call_indices = list(range(full_calls))
297+
probe_rng.shuffle(call_indices)
298+
min_calls = max(4, full_calls - 6)
299+
n_calls = probe_rng.randint(min_calls, full_calls)
300+
call_indices = call_indices[:n_calls]
301+
282302
outputs = []
283-
for j, data in enumerate(iteration_data):
284-
start_event = torch.cuda.Event(enable_timing=True)
285-
end_event = torch.cuda.Event(enable_timing=True)
286-
start_event.record()
287-
output = custom_kernel(data)
288-
end_event.record()
289-
torch.cuda.synchronize()
290-
per_call_durations.append(
291-
start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns
292-
)
293-
outputs.append(output)
303+
events = [torch.cuda.Event(enable_timing=True) for _ in range(len(call_indices) + 1)]
304+
if recheck:
305+
integrity_repeat = (i == 0) or (i % 20 == 0)
306+
else:
307+
integrity_repeat = (i < 3) or (i % 25 == 0)
294308

295-
# Per-call correctness check catches deferred-computation exploits:
296-
# if a submission skips the kernel and returns uncomputed tensors,
297-
# the check fails immediately.
298-
if recheck:
299-
good, message = check_implementation(check_copy[j], output)
309+
if integrity_repeat and len(call_indices) <= 1:
310+
in_loop_probe_pos = 0 if call_indices else None
311+
elif integrity_repeat:
312+
# Probe before last call to expose deferred-until-last behavior.
313+
in_loop_probe_pos = probe_rng.randrange(0, len(call_indices) - 1)
314+
else:
315+
in_loop_probe_pos = None
316+
317+
events[0].record()
318+
for k, idx in enumerate(call_indices):
319+
output = custom_kernel(iteration_data[idx])
320+
outputs.append((idx, output))
321+
events[k + 1].record()
322+
323+
# In-loop probe check catches deferred-until-last exploits that would
324+
# otherwise pass if outputs are only validated after the final call.
325+
if in_loop_probe_pos is not None and k == in_loop_probe_pos:
326+
torch.cuda.synchronize()
327+
good, message = check_implementation(check_copy[idx], output)
300328
if not good:
301329
return message
330+
torch.cuda.synchronize()
331+
332+
per_call_durations = [
333+
events[k].elapsed_time(events[k + 1]) * 1e6 for k in range(len(call_indices))
334+
]
302335

303-
duration = sum(per_call_durations) / NUM_ITERATIONS_PER_BENCHMARK
304-
durations.append(duration)
336+
# Correctness policy:
337+
# - benchmark: sparse hidden integrity repeats + randomized windows/order.
338+
# - leaderboard: sparse integrity repeats; first repeat gets full sweep.
339+
if recheck:
340+
if i == 0:
341+
check_positions = list(range(len(outputs)))
342+
else:
343+
check_positions = []
344+
else:
345+
check_positions = []
346+
347+
for pos in check_positions:
348+
idx, output = outputs[pos]
349+
good, message = check_implementation(check_copy[idx], output)
350+
if not good:
351+
return message
352+
353+
duration = sum(per_call_durations) / len(call_indices)
354+
if not integrity_repeat:
355+
durations.append(duration)
305356

306357
total_bm_duration = time.perf_counter_ns() - bm_start_time
307358
if (
308-
i > 1 and total_bm_duration > 1e8
359+
len(durations) > 1 and total_bm_duration > 1e8
309360
): # at least 2 runs, and at least 100 ms total time
310361
stats = calculate_stats(durations)
311362
# stop if either
@@ -319,6 +370,9 @@ def _run_single_benchmark(
319370
):
320371
break
321372

373+
if not durations:
374+
return "benchmark produced no timing samples"
375+
322376
return calculate_stats(durations)
323377

324378

@@ -527,8 +581,9 @@ def main():
527581
break
528582

529583
logger.log("check", "pass" if passed else "fail")
584+
return 0 if passed else 112
530585
elif mode == "profile":
531-
run_profiling(logger, pool, tests)
586+
return run_profiling(logger, pool, tests)
532587
else:
533588
# TODO: Implement script mode
534589
return 2

0 commit comments

Comments (0)