Skip to content

Commit c0684d1

Browse files
author
Mark Saroufim
committed
Fix moe-mxfp4: use local eval.py that skips tensor cloning
aiter's fused_moe produces incorrect results when weight tensors are cloned to different memory addresses. The eval harness clones all data before passing to the submission, which breaks fused_moe. Since fused_moe does not mutate its inputs, skipping the clone is safe and fixes the correctness check failures.
1 parent df582dd commit c0684d1

3 files changed

Lines changed: 385 additions & 14 deletions

File tree

Lines changed: 382 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,382 @@
1+
import base64
2+
import dataclasses
3+
import multiprocessing
4+
import re
5+
import time
6+
import os
7+
import sys
8+
import math
9+
from pathlib import Path
10+
from typing import Any, Optional
11+
12+
import torch.cuda
13+
14+
from utils import set_seed, clear_l2_cache_large as clear_l2_cache
15+
try:
16+
from task import TestSpec
17+
except ImportError:
18+
TestSpec = dict
19+
20+
from reference import check_implementation, generate_input
21+
22+
23+
class PopcornOutput:
24+
def __init__(self, fd: int):
25+
self.file = os.fdopen(fd, 'w')
26+
os.set_inheritable(fd, False)
27+
28+
def __enter__(self):
29+
return self
30+
31+
def __exit__(self, exc_type, exc_val, exc_tb):
32+
self.file.close()
33+
34+
def print(self, *args, **kwargs):
35+
print(*args, **kwargs, file=self.file, flush=True)
36+
37+
def log(self, key, value):
38+
self.print(f"{key}: {value}")
39+
40+
41+
@dataclasses.dataclass
42+
class TestCase:
43+
args: dict
44+
spec: str
45+
46+
47+
def _combine(a: int, b: int) -> int:
48+
# combine two integers into one:
49+
# we need this to generate a secret seed based on the test-level seed and
50+
# the global secret seed.
51+
# the test-level seeds are public knowledge, and typically relatively small numbers,
52+
# so we need to make sure they don't provide any useful info for the full seed.
53+
# This Cantor construction ensures that if the secret seed is a large number,
54+
# then so is the overall seed.
55+
return int(a + (a+b)*(a+b+1)//2)
56+
57+
58+
def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]:
59+
try:
60+
content = Path(file_name).read_text()
61+
except Exception as E:
62+
print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr)
63+
exit(113)
64+
65+
tests = []
66+
lines = content.splitlines()
67+
match = r"\s*([a-zA-Z_]\w*):\s*([a-zA-Z_]\w*|[+-]?[0-9]+)\s*"
68+
for line in lines:
69+
parts = line.split(";")
70+
case = {}
71+
for part in parts:
72+
matched = re.match(match, part)
73+
if not re.fullmatch(match, part):
74+
print(f"invalid test case: '{line}': '{part}'", file=sys.stderr)
75+
exit(113)
76+
key = matched[1]
77+
val = matched[2]
78+
try:
79+
val = int(val)
80+
except ValueError:
81+
if val == "true":
82+
val = True
83+
elif val == "false":
84+
val = False
85+
86+
case[key] = val
87+
tests.append(TestCase(spec=line, args=case))
88+
89+
if seed is not None:
90+
for test in tests:
91+
if "seed" in test.args:
92+
test.args["seed"] = _combine(test.args["seed"], seed)
93+
94+
return tests
95+
96+
97+
@dataclasses.dataclass
98+
class Stats:
99+
runs: int
100+
mean: float
101+
std: float
102+
err: float
103+
best: float
104+
worst: float
105+
106+
107+
def calculate_stats(durations: list[int]):
108+
"""
109+
Calculate statistical data from a list of durations.
110+
111+
@param durations: A list of durations in nanoseconds.
112+
@return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations.
113+
"""
114+
runs = len(durations)
115+
total = sum(durations)
116+
best = min(durations)
117+
worst = max(durations)
118+
119+
avg = total / runs
120+
variance = sum(map(lambda x: (x - avg)**2, durations))
121+
std = math.sqrt(variance / (runs - 1))
122+
err = std / math.sqrt(runs)
123+
124+
return Stats(runs=runs, mean=avg, std=std, err=err, best=float(best),
125+
worst=float(worst))
126+
127+
128+
def _clone_data(data):
129+
"""
130+
Return data as-is (no cloning).
131+
132+
aiter's fused_moe produces incorrect results when weight tensors are
133+
cloned to different memory addresses (same values, different output).
134+
Since fused_moe does not mutate its inputs, skipping the clone is safe.
135+
"""
136+
return data
137+
138+
139+
def wrap_check_implementation(data, submission_output):
140+
# Old version returned just a single string, new version
141+
# returns (bool, str); this function ensures compatibility with old
142+
# problem definitions.
143+
result = check_implementation(data, submission_output)
144+
if isinstance(result, tuple):
145+
return result
146+
else:
147+
return not bool(result), result
148+
149+
150+
def _run_single_test(test: TestCase):
151+
"""
152+
Runs a single test case. Do not call directly
153+
"""
154+
from submission import custom_kernel
155+
data = generate_input(**test.args)
156+
torch.cuda.synchronize()
157+
submission_output = custom_kernel(_clone_data(data))
158+
torch.cuda.synchronize()
159+
return wrap_check_implementation(data, submission_output)
160+
161+
162+
def run_single_test(pool: multiprocessing.Pool, test: TestCase):
163+
"""
164+
Runs a single test in another process.
165+
"""
166+
return pool.apply(_run_single_test, (test,))
167+
168+
169+
def run_testing(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]):
170+
"""
171+
Executes the actual test case code and checks for correctness.
172+
173+
@param logger: A PopcornOutput object used for logging test results.
174+
@param tests: A list of TestCase objects representing the test cases to be executed.
175+
@return: An integer representing the exit status: 0 if all tests pass, otherwise 112.
176+
"""
177+
passed = True
178+
logger.log("test-count", len(tests))
179+
for idx, test in enumerate(tests):
180+
logger.log(f"test.{idx}.spec", test.spec)
181+
good, message = run_single_test(pool, test)
182+
if not good:
183+
logger.log(f"test.{idx}.status", "fail")
184+
logger.log(f"test.{idx}.error", message)
185+
passed = False
186+
else:
187+
logger.log(f"test.{idx}.status", "pass")
188+
if message:
189+
logger.log(f"test.{idx}.message", message)
190+
191+
if passed:
192+
logger.log("check", "pass")
193+
return 0
194+
else:
195+
logger.log("check", "fail")
196+
return 112
197+
198+
199+
def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float) -> Stats | Any:
200+
"""
201+
Runs one benchmark. Do not call directly.
202+
"""
203+
from submission import custom_kernel
204+
205+
durations = []
206+
# generate input data once
207+
data = generate_input(**test.args)
208+
check_copy = _clone_data(data)
209+
# first, one obligatory correctness check
210+
output = custom_kernel(data)
211+
good, message = wrap_check_implementation(check_copy, output)
212+
if not good:
213+
return message
214+
215+
# now, do multiple timing runs without further correctness testing
216+
# there is an upper bound of 100 runs, and a lower bound of 3 runs;
217+
# otherwise, we repeat until we either measure at least 10 full seconds,
218+
# or the relative error of the mean is below 1%.
219+
220+
bm_start_time = time.perf_counter_ns()
221+
for i in range(max_repeats):
222+
if recheck:
223+
# ensure we use a different seed for every benchmark
224+
if "seed" in test.args:
225+
test.args["seed"] += 13
226+
227+
data = generate_input(**test.args)
228+
check_copy = _clone_data(data)
229+
torch.cuda.synchronize()
230+
clear_l2_cache()
231+
start_event = torch.cuda.Event(enable_timing=True)
232+
end_event = torch.cuda.Event(enable_timing=True)
233+
start_event.record()
234+
output = custom_kernel(data)
235+
end_event.record()
236+
torch.cuda.synchronize()
237+
238+
if recheck:
239+
good, message = check_implementation(check_copy, output)
240+
if not good:
241+
return message
242+
243+
del output
244+
durations.append(start_event.elapsed_time(end_event) * 1e6)
245+
246+
if i > 1:
247+
total_bm_duration = time.perf_counter_ns() - bm_start_time
248+
stats = calculate_stats(durations)
249+
# stop if either
250+
# a) relative error dips below 0.1%
251+
# b) we exceed the total time limit for benchmarking the kernel
252+
# c) we exceed 2 minutes of total wallclock time.
253+
if stats.err / stats.mean < 0.001 or stats.mean * stats.runs > max_time_ns or total_bm_duration > 120e9:
254+
break
255+
256+
return calculate_stats(durations)
257+
258+
259+
def run_single_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, max_repeats: int,
260+
max_time_ns: float):
261+
"""
262+
For a particular test case, check correctness (if applicable) and grab runtime results.
263+
264+
@param pool: Process on which the benchmark will be launched.
265+
@param test: TestCase object.
266+
@param recheck: Flag for whether to explicitly check functional correctness.
267+
@param max_repeats: Number of trials to repeat.
268+
@param max_time_ns: Timeout time in nanoseconds.
269+
@return: A Stats object for this particular benchmark case or an error if the test fails.
270+
"""
271+
return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns))
272+
273+
274+
def run_benchmarking(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]):
275+
"""
276+
Executes benchmarking code for a CUDA Kernel and logs runtimes.
277+
278+
@param logger: A PopcornOutput object used for logging benchmark results.
279+
@param pool: Process on which the benchmarks will be launched.
280+
@param tests: A list of TestCase objects representing the test cases to be benchmarked.
281+
@return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112.
282+
"""
283+
# warm up
284+
run_single_benchmark(pool, tests[0], False, 100, 10e7)
285+
286+
passed = True
287+
logger.log("benchmark-count", len(tests))
288+
for idx, test in enumerate(tests):
289+
logger.log(f"benchmark.{idx}.spec", test.spec)
290+
result = run_single_benchmark(pool, test, False, 1000, 50e9)
291+
if isinstance(result, Stats):
292+
for field in dataclasses.fields(Stats):
293+
logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name))
294+
else:
295+
passed = False
296+
logger.log(f"benchmark.{idx}.status", "fail")
297+
logger.log(f"benchmark.{idx}.error", result)
298+
299+
if passed:
300+
logger.log("check", "pass")
301+
return 0
302+
else:
303+
logger.log("check", "fail")
304+
return 112
305+
306+
307+
def run_single_profile(test: TestCase) -> str:
308+
"""
309+
Runs a single test case. Do not call directly
310+
"""
311+
from submission import custom_kernel
312+
from torch.profiler import profile, record_function, ProfilerActivity
313+
data = generate_input(**test.args)
314+
torch.cuda.synchronize()
315+
316+
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
317+
submission_output = custom_kernel(_clone_data(data))
318+
torch.cuda.synchronize()
319+
return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20)
320+
321+
322+
def run_profiling(logger: PopcornOutput, tests: list[TestCase]):
323+
logger.log("benchmark-count", len(tests))
324+
for idx, test in enumerate(tests):
325+
logger.log(f"benchmark.{idx}.spec", test.spec)
326+
report = run_single_profile(test)
327+
logger.log(f"benchmark.{idx}.report", base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8"))
328+
logger.log("check", "pass")
329+
return 0
330+
331+
332+
def main():
333+
fd = os.getenv("POPCORN_FD")
334+
if not fd:
335+
return 111
336+
337+
if len(sys.argv) < 3:
338+
return 2
339+
340+
mode = sys.argv[1]
341+
seed = os.getenv("POPCORN_SEED")
342+
os.unsetenv("POPCORN_SEED")
343+
seed = int(seed) if seed else None
344+
set_seed(seed or 42)
345+
tests = get_test_cases(sys.argv[2], seed)
346+
347+
with PopcornOutput(int(fd)) as logger:
348+
import multiprocessing
349+
mp_context = multiprocessing.get_context('spawn')
350+
with mp_context.Pool(1) as pool:
351+
if mode == "test":
352+
return run_testing(logger, pool, tests)
353+
if mode == "benchmark":
354+
return run_benchmarking(logger, pool, tests)
355+
356+
if mode == "leaderboard":
357+
# warmup
358+
run_single_benchmark(pool, tests[0], False, 100, 1e7)
359+
logger.log("benchmark-count", len(tests))
360+
passed = True
361+
for i in range(len(tests)):
362+
result = run_single_benchmark(pool, tests[i], True, 100, 30e9)
363+
logger.log(f"benchmark.{i}.spec", tests[i].spec)
364+
if isinstance(result, Stats):
365+
for field in dataclasses.fields(Stats):
366+
logger.log(f"benchmark.{i}.{field.name}", getattr(result, field.name))
367+
else:
368+
passed = False
369+
logger.log(f"benchmark.{i}.status", "fail")
370+
logger.log(f"benchmark.{i}.error", str(result)) # TODO: Make sure result implements __str__?
371+
break
372+
373+
logger.log("check", "pass" if passed else "fail")
374+
elif mode == "profile":
375+
run_profiling(logger, tests)
376+
else:
377+
# TODO: Implement script mode
378+
return 2
379+
380+
381+
if __name__ == "__main__":
382+
sys.exit(main())

0 commit comments

Comments
 (0)