Skip to content

Commit cf6b3bd

Browse files
author
Mark Saroufim
committed
warmup all shapes and init env over all processes
1 parent 8227254 commit cf6b3bd

1 file changed

Lines changed: 11 additions & 3 deletions

File tree

problems/nvidia/eval_better_bench.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88
import math
99

1010
# Disable CuTe DSL file caching for more stable benchmarking
11-
os.environ["CUTE_DSL_DISABLE_FILE_CACHING"] = "True"
11+
os.environ["CUTE_DSL_DISABLE_FILE_CACHING"] = "1"
12+
13+
14+
def _init_worker():
15+
"""Initialize worker process with correct env vars."""
16+
os.environ["CUTE_DSL_DISABLE_FILE_CACHING"] = "1"
17+
1218

1319
from pathlib import Path
1420
from typing import Any, Optional
@@ -463,14 +469,16 @@ def main():
463469
import multiprocessing
464470

465471
mp_context = multiprocessing.get_context("spawn")
466-
with mp_context.Pool(1) as pool:
472+
with mp_context.Pool(1, initializer=_init_worker) as pool:
467473
if mode == "test":
468474
return run_testing(logger, pool, tests)
469475
if mode == "benchmark":
470476
return run_benchmarking(logger, pool, tests)
471477

472478
if mode == "leaderboard":
473-
run_single_benchmark(pool, tests[0], False, 1000, 5e8)
479+
# Warmup all test shapes to ensure consistent benchmarking
480+
for test in tests:
481+
run_single_benchmark(pool, test, False, 1000, 5e8)
474482
logger.log("benchmark-count", len(tests))
475483
passed = True
476484
for i in range(len(tests)):

0 commit comments

Comments
 (0)