|
| 1 | +import base64 |
| 2 | +import dataclasses |
| 3 | +import multiprocessing |
| 4 | +import re |
| 5 | +import time |
| 6 | +import os |
| 7 | +import sys |
| 8 | +import math |
| 9 | +from pathlib import Path |
| 10 | +from typing import Any, Optional |
| 11 | + |
| 12 | +import torch.cuda |
| 13 | + |
| 14 | +from utils import set_seed, clear_l2_cache_large as clear_l2_cache |
| 15 | +try: |
| 16 | + from task import TestSpec |
| 17 | +except ImportError: |
| 18 | + TestSpec = dict |
| 19 | + |
| 20 | +from reference import check_implementation, generate_input |
| 21 | + |
| 22 | + |
| 23 | +class PopcornOutput: |
| 24 | + def __init__(self, fd: int): |
| 25 | + self.file = os.fdopen(fd, 'w') |
| 26 | + os.set_inheritable(fd, False) |
| 27 | + |
| 28 | + def __enter__(self): |
| 29 | + return self |
| 30 | + |
| 31 | + def __exit__(self, exc_type, exc_val, exc_tb): |
| 32 | + self.file.close() |
| 33 | + |
| 34 | + def print(self, *args, **kwargs): |
| 35 | + print(*args, **kwargs, file=self.file, flush=True) |
| 36 | + |
| 37 | + def log(self, key, value): |
| 38 | + self.print(f"{key}: {value}") |
| 39 | + |
| 40 | + |
| 41 | +@dataclasses.dataclass |
| 42 | +class TestCase: |
| 43 | + args: dict |
| 44 | + spec: str |
| 45 | + |
| 46 | + |
| 47 | +def _combine(a: int, b: int) -> int: |
| 48 | + # combine two integers into one: |
| 49 | + # we need this to generate a secret seed based on the test-level seed and |
| 50 | + # the global secret seed. |
| 51 | + # the test-level seeds are public knowledge, and typically relatively small numbers, |
| 52 | + # so we need to make sure they don't provide any useful info for the full seed. |
| 53 | + # This Cantor construction ensures that if the secret seed is a large number, |
| 54 | + # then so is the overall seed. |
| 55 | + return int(a + (a+b)*(a+b+1)//2) |
| 56 | + |
| 57 | + |
| 58 | +def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: |
| 59 | + try: |
| 60 | + content = Path(file_name).read_text() |
| 61 | + except Exception as E: |
| 62 | + print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) |
| 63 | + exit(113) |
| 64 | + |
| 65 | + tests = [] |
| 66 | + lines = content.splitlines() |
| 67 | + match = r"\s*([a-zA-Z_]\w*):\s*([a-zA-Z_]\w*|[+-]?[0-9]+)\s*" |
| 68 | + for line in lines: |
| 69 | + parts = line.split(";") |
| 70 | + case = {} |
| 71 | + for part in parts: |
| 72 | + matched = re.match(match, part) |
| 73 | + if not re.fullmatch(match, part): |
| 74 | + print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) |
| 75 | + exit(113) |
| 76 | + key = matched[1] |
| 77 | + val = matched[2] |
| 78 | + try: |
| 79 | + val = int(val) |
| 80 | + except ValueError: |
| 81 | + if val == "true": |
| 82 | + val = True |
| 83 | + elif val == "false": |
| 84 | + val = False |
| 85 | + |
| 86 | + case[key] = val |
| 87 | + tests.append(TestCase(spec=line, args=case)) |
| 88 | + |
| 89 | + if seed is not None: |
| 90 | + for test in tests: |
| 91 | + if "seed" in test.args: |
| 92 | + test.args["seed"] = _combine(test.args["seed"], seed) |
| 93 | + |
| 94 | + return tests |
| 95 | + |
| 96 | + |
| 97 | +@dataclasses.dataclass |
| 98 | +class Stats: |
| 99 | + runs: int |
| 100 | + mean: float |
| 101 | + std: float |
| 102 | + err: float |
| 103 | + best: float |
| 104 | + worst: float |
| 105 | + |
| 106 | + |
| 107 | +def calculate_stats(durations: list[int]): |
| 108 | + """ |
| 109 | + Calculate statistical data from a list of durations. |
| 110 | +
|
| 111 | + @param durations: A list of durations in nanoseconds. |
| 112 | + @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. |
| 113 | + """ |
| 114 | + runs = len(durations) |
| 115 | + total = sum(durations) |
| 116 | + best = min(durations) |
| 117 | + worst = max(durations) |
| 118 | + |
| 119 | + avg = total / runs |
| 120 | + variance = sum(map(lambda x: (x - avg)**2, durations)) |
| 121 | + std = math.sqrt(variance / (runs - 1)) |
| 122 | + err = std / math.sqrt(runs) |
| 123 | + |
| 124 | + return Stats(runs=runs, mean=avg, std=std, err=err, best=float(best), |
| 125 | + worst=float(worst)) |
| 126 | + |
| 127 | + |
| 128 | +def _clone_data(data): |
| 129 | + """ |
| 130 | + Return data as-is (no cloning). |
| 131 | +
|
| 132 | + aiter's fused_moe produces incorrect results when weight tensors are |
| 133 | + cloned to different memory addresses (same values, different output). |
| 134 | + Since fused_moe does not mutate its inputs, skipping the clone is safe. |
| 135 | + """ |
| 136 | + return data |
| 137 | + |
| 138 | + |
| 139 | +def wrap_check_implementation(data, submission_output): |
| 140 | + # Old version returned just a single string, new version |
| 141 | + # returns (bool, str); this function ensures compatibility with old |
| 142 | + # problem definitions. |
| 143 | + result = check_implementation(data, submission_output) |
| 144 | + if isinstance(result, tuple): |
| 145 | + return result |
| 146 | + else: |
| 147 | + return not bool(result), result |
| 148 | + |
| 149 | + |
| 150 | +def _run_single_test(test: TestCase): |
| 151 | + """ |
| 152 | + Runs a single test case. Do not call directly |
| 153 | + """ |
| 154 | + from submission import custom_kernel |
| 155 | + data = generate_input(**test.args) |
| 156 | + torch.cuda.synchronize() |
| 157 | + submission_output = custom_kernel(_clone_data(data)) |
| 158 | + torch.cuda.synchronize() |
| 159 | + return wrap_check_implementation(data, submission_output) |
| 160 | + |
| 161 | + |
| 162 | +def run_single_test(pool: multiprocessing.Pool, test: TestCase): |
| 163 | + """ |
| 164 | + Runs a single test in another process. |
| 165 | + """ |
| 166 | + return pool.apply(_run_single_test, (test,)) |
| 167 | + |
| 168 | + |
| 169 | +def run_testing(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): |
| 170 | + """ |
| 171 | + Executes the actual test case code and checks for correctness. |
| 172 | +
|
| 173 | + @param logger: A PopcornOutput object used for logging test results. |
| 174 | + @param tests: A list of TestCase objects representing the test cases to be executed. |
| 175 | + @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. |
| 176 | + """ |
| 177 | + passed = True |
| 178 | + logger.log("test-count", len(tests)) |
| 179 | + for idx, test in enumerate(tests): |
| 180 | + logger.log(f"test.{idx}.spec", test.spec) |
| 181 | + good, message = run_single_test(pool, test) |
| 182 | + if not good: |
| 183 | + logger.log(f"test.{idx}.status", "fail") |
| 184 | + logger.log(f"test.{idx}.error", message) |
| 185 | + passed = False |
| 186 | + else: |
| 187 | + logger.log(f"test.{idx}.status", "pass") |
| 188 | + if message: |
| 189 | + logger.log(f"test.{idx}.message", message) |
| 190 | + |
| 191 | + if passed: |
| 192 | + logger.log("check", "pass") |
| 193 | + return 0 |
| 194 | + else: |
| 195 | + logger.log("check", "fail") |
| 196 | + return 112 |
| 197 | + |
| 198 | + |
| 199 | +def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float) -> Stats | Any: |
| 200 | + """ |
| 201 | + Runs one benchmark. Do not call directly. |
| 202 | + """ |
| 203 | + from submission import custom_kernel |
| 204 | + |
| 205 | + durations = [] |
| 206 | + # generate input data once |
| 207 | + data = generate_input(**test.args) |
| 208 | + check_copy = _clone_data(data) |
| 209 | + # first, one obligatory correctness check |
| 210 | + output = custom_kernel(data) |
| 211 | + good, message = wrap_check_implementation(check_copy, output) |
| 212 | + if not good: |
| 213 | + return message |
| 214 | + |
| 215 | + # now, do multiple timing runs without further correctness testing |
| 216 | + # there is an upper bound of 100 runs, and a lower bound of 3 runs; |
| 217 | + # otherwise, we repeat until we either measure at least 10 full seconds, |
| 218 | + # or the relative error of the mean is below 1%. |
| 219 | + |
| 220 | + bm_start_time = time.perf_counter_ns() |
| 221 | + for i in range(max_repeats): |
| 222 | + if recheck: |
| 223 | + # ensure we use a different seed for every benchmark |
| 224 | + if "seed" in test.args: |
| 225 | + test.args["seed"] += 13 |
| 226 | + |
| 227 | + data = generate_input(**test.args) |
| 228 | + check_copy = _clone_data(data) |
| 229 | + torch.cuda.synchronize() |
| 230 | + clear_l2_cache() |
| 231 | + start_event = torch.cuda.Event(enable_timing=True) |
| 232 | + end_event = torch.cuda.Event(enable_timing=True) |
| 233 | + start_event.record() |
| 234 | + output = custom_kernel(data) |
| 235 | + end_event.record() |
| 236 | + torch.cuda.synchronize() |
| 237 | + |
| 238 | + if recheck: |
| 239 | + good, message = check_implementation(check_copy, output) |
| 240 | + if not good: |
| 241 | + return message |
| 242 | + |
| 243 | + del output |
| 244 | + durations.append(start_event.elapsed_time(end_event) * 1e6) |
| 245 | + |
| 246 | + if i > 1: |
| 247 | + total_bm_duration = time.perf_counter_ns() - bm_start_time |
| 248 | + stats = calculate_stats(durations) |
| 249 | + # stop if either |
| 250 | + # a) relative error dips below 0.1% |
| 251 | + # b) we exceed the total time limit for benchmarking the kernel |
| 252 | + # c) we exceed 2 minutes of total wallclock time. |
| 253 | + if stats.err / stats.mean < 0.001 or stats.mean * stats.runs > max_time_ns or total_bm_duration > 120e9: |
| 254 | + break |
| 255 | + |
| 256 | + return calculate_stats(durations) |
| 257 | + |
| 258 | + |
| 259 | +def run_single_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, max_repeats: int, |
| 260 | + max_time_ns: float): |
| 261 | + """ |
| 262 | + For a particular test case, check correctness (if applicable) and grab runtime results. |
| 263 | +
|
| 264 | + @param pool: Process on which the benchmark will be launched. |
| 265 | + @param test: TestCase object. |
| 266 | + @param recheck: Flag for whether to explicitly check functional correctness. |
| 267 | + @param max_repeats: Number of trials to repeat. |
| 268 | + @param max_time_ns: Timeout time in nanoseconds. |
| 269 | + @return: A Stats object for this particular benchmark case or an error if the test fails. |
| 270 | + """ |
| 271 | + return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) |
| 272 | + |
| 273 | + |
| 274 | +def run_benchmarking(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): |
| 275 | + """ |
| 276 | + Executes benchmarking code for a CUDA Kernel and logs runtimes. |
| 277 | +
|
| 278 | + @param logger: A PopcornOutput object used for logging benchmark results. |
| 279 | + @param pool: Process on which the benchmarks will be launched. |
| 280 | + @param tests: A list of TestCase objects representing the test cases to be benchmarked. |
| 281 | + @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. |
| 282 | + """ |
| 283 | + # warm up |
| 284 | + run_single_benchmark(pool, tests[0], False, 100, 10e7) |
| 285 | + |
| 286 | + passed = True |
| 287 | + logger.log("benchmark-count", len(tests)) |
| 288 | + for idx, test in enumerate(tests): |
| 289 | + logger.log(f"benchmark.{idx}.spec", test.spec) |
| 290 | + result = run_single_benchmark(pool, test, False, 1000, 50e9) |
| 291 | + if isinstance(result, Stats): |
| 292 | + for field in dataclasses.fields(Stats): |
| 293 | + logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) |
| 294 | + else: |
| 295 | + passed = False |
| 296 | + logger.log(f"benchmark.{idx}.status", "fail") |
| 297 | + logger.log(f"benchmark.{idx}.error", result) |
| 298 | + |
| 299 | + if passed: |
| 300 | + logger.log("check", "pass") |
| 301 | + return 0 |
| 302 | + else: |
| 303 | + logger.log("check", "fail") |
| 304 | + return 112 |
| 305 | + |
| 306 | + |
| 307 | +def run_single_profile(test: TestCase) -> str: |
| 308 | + """ |
| 309 | + Runs a single test case. Do not call directly |
| 310 | + """ |
| 311 | + from submission import custom_kernel |
| 312 | + from torch.profiler import profile, record_function, ProfilerActivity |
| 313 | + data = generate_input(**test.args) |
| 314 | + torch.cuda.synchronize() |
| 315 | + |
| 316 | + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: |
| 317 | + submission_output = custom_kernel(_clone_data(data)) |
| 318 | + torch.cuda.synchronize() |
| 319 | + return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) |
| 320 | + |
| 321 | + |
| 322 | +def run_profiling(logger: PopcornOutput, tests: list[TestCase]): |
| 323 | + logger.log("benchmark-count", len(tests)) |
| 324 | + for idx, test in enumerate(tests): |
| 325 | + logger.log(f"benchmark.{idx}.spec", test.spec) |
| 326 | + report = run_single_profile(test) |
| 327 | + logger.log(f"benchmark.{idx}.report", base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8")) |
| 328 | + logger.log("check", "pass") |
| 329 | + return 0 |
| 330 | + |
| 331 | + |
| 332 | +def main(): |
| 333 | + fd = os.getenv("POPCORN_FD") |
| 334 | + if not fd: |
| 335 | + return 111 |
| 336 | + |
| 337 | + if len(sys.argv) < 3: |
| 338 | + return 2 |
| 339 | + |
| 340 | + mode = sys.argv[1] |
| 341 | + seed = os.getenv("POPCORN_SEED") |
| 342 | + os.unsetenv("POPCORN_SEED") |
| 343 | + seed = int(seed) if seed else None |
| 344 | + set_seed(seed or 42) |
| 345 | + tests = get_test_cases(sys.argv[2], seed) |
| 346 | + |
| 347 | + with PopcornOutput(int(fd)) as logger: |
| 348 | + import multiprocessing |
| 349 | + mp_context = multiprocessing.get_context('spawn') |
| 350 | + with mp_context.Pool(1) as pool: |
| 351 | + if mode == "test": |
| 352 | + return run_testing(logger, pool, tests) |
| 353 | + if mode == "benchmark": |
| 354 | + return run_benchmarking(logger, pool, tests) |
| 355 | + |
| 356 | + if mode == "leaderboard": |
| 357 | + # warmup |
| 358 | + run_single_benchmark(pool, tests[0], False, 100, 1e7) |
| 359 | + logger.log("benchmark-count", len(tests)) |
| 360 | + passed = True |
| 361 | + for i in range(len(tests)): |
| 362 | + result = run_single_benchmark(pool, tests[i], True, 100, 30e9) |
| 363 | + logger.log(f"benchmark.{i}.spec", tests[i].spec) |
| 364 | + if isinstance(result, Stats): |
| 365 | + for field in dataclasses.fields(Stats): |
| 366 | + logger.log(f"benchmark.{i}.{field.name}", getattr(result, field.name)) |
| 367 | + else: |
| 368 | + passed = False |
| 369 | + logger.log(f"benchmark.{i}.status", "fail") |
| 370 | + logger.log(f"benchmark.{i}.error", str(result)) # TODO: Make sure result implements __str__? |
| 371 | + break |
| 372 | + |
| 373 | + logger.log("check", "pass" if passed else "fail") |
| 374 | + elif mode == "profile": |
| 375 | + run_profiling(logger, tests) |
| 376 | + else: |
| 377 | + # TODO: Implement script mode |
| 378 | + return 2 |
| 379 | + |
| 380 | + |
| 381 | +if __name__ == "__main__": |
| 382 | + sys.exit(main()) |
0 commit comments