11import base64
22import dataclasses
33import multiprocessing
4+ import random
45import re
56import time
67import os
78import sys
89import math
10+ import random
911
1012# Disable CuTe DSL file caching for more stable benchmarking
1113os .environ ["CUTE_DSL_DISABLE_FILE_CACHING" ] = "1"
@@ -240,7 +242,7 @@ def _run_single_benchmark(
240242
241243 durations = []
242244 data_list = []
243- # generate input data once (local seed avoids mutating test.args)
245+ # generate input data once
244246
245247 local_seed = test .args .get ("seed" , None )
246248 for i in range (NUM_ITERATIONS_PER_BENCHMARK ):
@@ -253,8 +255,14 @@ def _run_single_benchmark(
253255 data_list .append (data )
254256
255257 check_copy = _clone_data (data_list )
256-
257- # first, one obligatory correctness check
258+ # Deterministic but hidden probe stream.
259+ # In benchmark mode we use randomized call windows and sparse probes.
260+ # In leaderboard mode we do one full sweep up front, then lightweight probes.
261+ probe_seed = _combine (int (test .args .get ("seed" , 0 ) or 0 ), 0x4D455452 )
262+ probe_rng = random .Random (probe_seed )
263+ full_calls = len (data_list )
264+
265+ # First, one obligatory correctness check on fresh clones.
258266 outputs = []
259267 try :
260268 for data in data_list :
@@ -267,45 +275,88 @@ def _run_single_benchmark(
267275 if not good :
268276 return message
269277
270- # Timing: individual per-call measurement with GPU sync between calls.
271- # This prevents "batch-and-skip" exploits where a submission defers all
272- # work to one call and returns cached/uncomputed results for the rest.
278+ # Timing: per-call intervals captured with CUDA events and one sync.
279+ # We randomize window length/order in benchmark mode to break fixed-N exploits.
273280 # Data is cloned each iteration to prevent object-identity caching.
274281
275282 bm_start_time = time .perf_counter_ns ()
276283 for i in range (max_repeats ):
284+ # Clone and shuffle data before timing to prevent both
285+ # object-identity caching and call-order caching exploits
277286 iteration_data = _clone_data (data_list )
287+ shuffle_order = list (range (len (iteration_data )))
288+ random .shuffle (shuffle_order )
289+ iteration_data = [iteration_data [j ] for j in shuffle_order ]
290+
278291 torch .cuda .synchronize ()
279- clear_l2_cache ()
280292
281- per_call_durations = []
293+ if recheck :
294+ call_indices = list (range (full_calls ))
295+ else :
296+ call_indices = list (range (full_calls ))
297+ probe_rng .shuffle (call_indices )
298+ min_calls = max (4 , full_calls - 6 )
299+ n_calls = probe_rng .randint (min_calls , full_calls )
300+ call_indices = call_indices [:n_calls ]
301+
282302 outputs = []
283- for j , data in enumerate (iteration_data ):
284- start_event = torch .cuda .Event (enable_timing = True )
285- end_event = torch .cuda .Event (enable_timing = True )
286- start_event .record ()
287- output = custom_kernel (data )
288- end_event .record ()
289- torch .cuda .synchronize ()
290- per_call_durations .append (
291- start_event .elapsed_time (end_event ) * 1e6 # Convert ms to ns
292- )
293- outputs .append (output )
303+ events = [torch .cuda .Event (enable_timing = True ) for _ in range (len (call_indices ) + 1 )]
304+ if recheck :
305+ integrity_repeat = (i == 0 ) or (i % 20 == 0 )
306+ else :
307+ integrity_repeat = (i < 3 ) or (i % 25 == 0 )
294308
295- # Per-call correctness check catches deferred-computation exploits:
296- # if a submission skips the kernel and returns uncomputed tensors,
297- # the check fails immediately.
298- if recheck :
299- good , message = check_implementation (check_copy [j ], output )
309+ if integrity_repeat and len (call_indices ) <= 1 :
310+ in_loop_probe_pos = 0 if call_indices else None
311+ elif integrity_repeat :
312+ # Probe before last call to expose deferred-until-last behavior.
313+ in_loop_probe_pos = probe_rng .randrange (0 , len (call_indices ) - 1 )
314+ else :
315+ in_loop_probe_pos = None
316+
317+ events [0 ].record ()
318+ for k , idx in enumerate (call_indices ):
319+ output = custom_kernel (iteration_data [idx ])
320+ outputs .append ((idx , output ))
321+ events [k + 1 ].record ()
322+
323+ # In-loop probe check catches deferred-until-last exploits that would
324+ # otherwise pass if outputs are only validated after the final call.
325+ if in_loop_probe_pos is not None and k == in_loop_probe_pos :
326+ torch .cuda .synchronize ()
327+ good , message = check_implementation (check_copy [idx ], output )
300328 if not good :
301329 return message
330+ torch .cuda .synchronize ()
331+
332+ per_call_durations = [
333+ events [k ].elapsed_time (events [k + 1 ]) * 1e6 for k in range (len (call_indices ))
334+ ]
302335
303- duration = sum (per_call_durations ) / NUM_ITERATIONS_PER_BENCHMARK
304- durations .append (duration )
336+ # Correctness policy:
337+ # - benchmark: sparse hidden integrity repeats + randomized windows/order.
338+ # - leaderboard: sparse integrity repeats; first repeat gets full sweep.
339+ if recheck :
340+ if i == 0 :
341+ check_positions = list (range (len (outputs )))
342+ else :
343+ check_positions = []
344+ else :
345+ check_positions = []
346+
347+ for pos in check_positions :
348+ idx , output = outputs [pos ]
349+ good , message = check_implementation (check_copy [idx ], output )
350+ if not good :
351+ return message
352+
353+ duration = sum (per_call_durations ) / len (call_indices )
354+ if not integrity_repeat :
355+ durations .append (duration )
305356
306357 total_bm_duration = time .perf_counter_ns () - bm_start_time
307358 if (
308- i > 1 and total_bm_duration > 1e8
359+ len ( durations ) > 1 and total_bm_duration > 1e8
309360 ): # at least 2 runs, and at least 100 ms total time
310361 stats = calculate_stats (durations )
311362 # stop if either
@@ -319,6 +370,9 @@ def _run_single_benchmark(
319370 ):
320371 break
321372
373+ if not durations :
374+ return "benchmark produced no timing samples"
375+
322376 return calculate_stats (durations )
323377
324378
@@ -527,8 +581,9 @@ def main():
527581 break
528582
529583 logger .log ("check" , "pass" if passed else "fail" )
584+ return 0 if passed else 112
530585 elif mode == "profile" :
531- run_profiling (logger , pool , tests )
586+ return run_profiling (logger , pool , tests )
532587 else :
533588 # TODO: Implement script mode
534589 return 2
0 commit comments