|
| 1 | +import dataclasses |
| 2 | +import math |
| 3 | +import os |
| 4 | +import re |
| 5 | +import statistics |
| 6 | +import sys |
| 7 | +from pathlib import Path |
| 8 | + |
| 9 | +import torch |
| 10 | +import torch.nn.functional as F |
| 11 | + |
| 12 | +from reference import ( |
| 13 | + ATOL, |
| 14 | + DTYPE, |
| 15 | + RTOL, |
| 16 | + generate_inputs, |
| 17 | + reference_backward, |
| 18 | + reference_forward, |
| 19 | +) |
| 20 | + |
| 21 | + |
# Original eval parameters
B = 4_096  # rows per batch — first dim passed to generate_inputs for every run
WARMUP_ITERS = 20  # un-timed iterations before each benchmark loop
BENCH_ITERS = 100  # timed iterations per measurement loop
| 26 | + |
| 27 | + |
class PopcornOutput:
    """Context-managed writer that streams ``key: value`` log lines to a raw fd.

    The harness hands us an already-open file descriptor (POPCORN_FD); this
    wraps it as a text stream and closes it when the ``with`` block exits.
    """

    def __init__(self, fd: int):
        # Keep the descriptor out of any child processes we might spawn,
        # then wrap it as a buffered text stream for printing.
        os.set_inheritable(fd, False)
        self.file = os.fdopen(fd, "w")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def print(self, *args, **kwargs):
        # Flush eagerly so partial results survive a crash mid-run.
        print(*args, **kwargs, file=self.file, flush=True)

    def log(self, key, value):
        self.print(f"{key}: {value}")
| 44 | + |
| 45 | + |
@dataclasses.dataclass
class TestCase:
    # Parsed key/value pairs from one spec line (e.g. {"vocab_size": 32768});
    # values are ints when numeric, otherwise strings.
    args: dict
    # The raw spec line, echoed back verbatim in the result log.
    spec: str
| 50 | + |
| 51 | + |
@dataclasses.dataclass
class Stats:
    # Number of timed iterations behind these statistics.
    runs: int
    # Median combined fwd+bwd time, in nanoseconds (CUDA-event ms * 1e6).
    mean: float
    # Population std dev of the combined times, same unit as `mean`.
    std: float
    # Standard error of the mean (std / sqrt(runs)), same unit as `mean`.
    err: float
    # Fastest combined iteration, nanoseconds.
    best: float
    # Slowest combined iteration, nanoseconds.
    worst: float
    # Effective bandwidths, GB/s.
    fwd_bw: float
    bwd_bw: float
    combined_bw: float
| 63 | + |
| 64 | + |
def get_test_cases(file_name: str) -> list[TestCase]:
    """Parse the test-spec file into :class:`TestCase` objects.

    Each non-empty line holds ``;``-separated ``key: value`` pairs where the
    value is an identifier or an integer. Exits the process with code 113 on
    an unreadable file or a malformed line (harness contract).
    """
    try:
        content = Path(file_name).read_text()
    except Exception as exc:
        print(f"Could not open test file `{file_name}`: {exc}", file=sys.stderr)
        sys.exit(113)

    # Compile once and match once per part; the original ran `re.match` and
    # `re.fullmatch` on every part, recompiling the pattern each time.
    pattern = re.compile(r"\s*([a-zA-Z_]+):\s*([a-zA-Z_]+|[+-]?[0-9]+)\s*")

    tests = []
    for line in content.splitlines():
        if not line.strip():
            continue
        case = {}
        for part in line.split(";"):
            matched = pattern.fullmatch(part)
            if not matched:
                print(f"invalid test case: '{line}': '{part}'", file=sys.stderr)
                sys.exit(113)
            key = matched[1]
            value = matched[2]
            try:
                value = int(value)
            except ValueError:
                pass  # non-numeric values stay as strings
            case[key] = value
        tests.append(TestCase(spec=line, args=case))
    return tests
| 94 | + |
| 95 | + |
def load_submission():
    """Import the user's ``submission`` module and validate its interface.

    Raises AttributeError naming the first required function that is absent.
    """
    import submission

    required = ("cross_entropy_forward", "cross_entropy_backward")
    missing = [name for name in required if not hasattr(submission, name)]
    if missing:
        raise AttributeError(f"Submission is missing function '{missing[0]}'.")
    return submission
| 103 | + |
| 104 | + |
def check_correctness(mod, vocab_size):
    """Compare the submission's forward/backward against the reference.

    Returns a 4-tuple ``(fwd_close, bwd_close, max_fwd_err, max_bwd_err)``.
    Shape or dtype mismatches raise AssertionError instead.
    """
    logits, targets, grad_output = generate_inputs(B, vocab_size)

    # Forward pass: per-row losses must be float32 and match within tolerance.
    expected_loss = reference_forward(logits, targets)
    actual_loss = mod.cross_entropy_forward(logits, targets)

    assert actual_loss.shape == expected_loss.shape, (
        f"Forward shape mismatch: expected {expected_loss.shape}, got {actual_loss.shape}"
    )
    assert actual_loss.dtype == torch.float32, (
        f"Forward dtype mismatch: expected float32, got {actual_loss.dtype}"
    )
    fwd_ok = torch.allclose(actual_loss, expected_loss, atol=ATOL, rtol=RTOL)
    fwd_err = (actual_loss - expected_loss).abs().max().item()

    # Backward pass: gradient must come back in the reference DTYPE; error is
    # measured in float32 to avoid low-precision subtraction artifacts.
    expected_grad = reference_backward(logits, targets, grad_output)
    actual_grad = mod.cross_entropy_backward(logits, targets, grad_output)

    assert actual_grad.shape == expected_grad.shape, (
        f"Backward shape mismatch: expected {expected_grad.shape}, got {actual_grad.shape}"
    )
    assert actual_grad.dtype == DTYPE, (
        f"Backward dtype mismatch: expected {DTYPE}, got {actual_grad.dtype}"
    )
    bwd_ok = torch.allclose(actual_grad, expected_grad, atol=ATOL, rtol=RTOL)
    bwd_err = (actual_grad.float() - expected_grad.float()).abs().max().item()

    return fwd_ok, bwd_ok, fwd_err, bwd_err
| 135 | + |
| 136 | + |
def benchmark_one(mod, vocab_size):
    """Time the forward, backward, and fused fwd+bwd passes of ``mod``.

    Args:
        mod: object exposing ``cross_entropy_forward`` / ``cross_entropy_backward``.
        vocab_size: number of logit columns per row.

    Returns:
        Stats with times in nanoseconds (median combined time as ``mean``)
        and effective bandwidths in GB/s.
    """
    logits, targets, grad_output = generate_inputs(B, vocab_size, seed=123)

    # Warm up kernels/allocator outside the timed region.
    for _ in range(WARMUP_ITERS):
        mod.cross_entropy_forward(logits, targets)
        mod.cross_entropy_backward(logits, targets, grad_output)
    torch.cuda.synchronize()

    def _time_ms(fn):
        # One CUDA-event-timed call of `fn` per iteration; times are in ms.
        # (Shared helper replaces three copy-pasted timing loops.)
        times = []
        for _ in range(BENCH_ITERS):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))
        return times

    def _combined():
        mod.cross_entropy_forward(logits, targets)
        mod.cross_entropy_backward(logits, targets, grad_output)

    fwd_times = _time_ms(lambda: mod.cross_entropy_forward(logits, targets))
    bwd_times = _time_ms(lambda: mod.cross_entropy_backward(logits, targets, grad_output))
    combined_times = _time_ms(_combined)

    fwd_ms = statistics.median(fwd_times)
    bwd_ms = statistics.median(bwd_times)
    combined_ms = statistics.median(combined_times)

    # Traffic model: 2 bytes/element read in forward, 4 bytes/element in
    # backward, plus 12 bytes/row of small tensors — presumably tied to
    # reference.DTYPE being 2 bytes wide; TODO confirm against reference.py.
    fwd_bytes = 2 * B * vocab_size + 12 * B
    bwd_bytes = 4 * B * vocab_size + 12 * B
    total_bytes = fwd_bytes + bwd_bytes

    fwd_bw = fwd_bytes / (fwd_ms * 1e-3) / 1e9
    bwd_bw = bwd_bytes / (bwd_ms * 1e-3) / 1e9
    combined_bw = total_bytes / (combined_ms * 1e-3) / 1e9

    # Keep KernelBot scoring on the exact reported metric: median combined ms.
    spread = statistics.pstdev(combined_times)  # computed once (was twice)
    return Stats(
        runs=BENCH_ITERS,
        mean=combined_ms * 1e6,
        std=spread * 1e6,
        err=(spread / math.sqrt(len(combined_times))) * 1e6,
        best=min(combined_times) * 1e6,
        worst=max(combined_times) * 1e6,
        fwd_bw=fwd_bw,
        bwd_bw=bwd_bw,
        combined_bw=combined_bw,
    )
| 200 | + |
| 201 | + |
def run_testing(logger: PopcornOutput, tests: list[TestCase]) -> int:
    """Run the correctness suite against the submission.

    Logs one status/message pair per test case and a final ``check`` verdict.
    Returns 0 when every case passes, 112 otherwise.
    """
    try:
        mod = load_submission()
    except Exception as exc:
        logger.log("check", "fail")
        logger.log("error", str(exc))
        return 112

    all_ok = True
    logger.log("test-count", len(tests))
    for idx, test in enumerate(tests):
        vocab_size = int(test.args["vocab_size"])
        logger.log(f"test.{idx}.spec", test.spec)

        try:
            fwd_ok, bwd_ok, fwd_err, bwd_err = check_correctness(mod, vocab_size)
        except Exception as exc:
            # Shape/dtype assertion failures and crashes land here.
            logger.log(f"test.{idx}.status", "fail")
            logger.log(f"test.{idx}.error", str(exc))
            all_ok = False
            continue

        if fwd_ok and bwd_ok:
            logger.log(f"test.{idx}.status", "pass")
            logger.log(
                f"test.{idx}.message",
                f"forward max err={fwd_err:.3e}, backward max err={bwd_err:.3e}",
            )
        else:
            all_ok = False
            logger.log(f"test.{idx}.status", "fail")
            logger.log(
                f"test.{idx}.error",
                f"forward max err={fwd_err:.3e} {'OK' if fwd_ok else 'FAIL'}; "
                f"backward max err={bwd_err:.3e} {'OK' if bwd_ok else 'FAIL'}",
            )

    logger.log("check", "pass" if all_ok else "fail")
    return 0 if all_ok else 112
| 238 | + |
| 239 | + |
def run_benchmarking(logger: PopcornOutput, tests: list[TestCase]) -> int:
    """Benchmark the submission against a stock-PyTorch baseline per test case.

    Logs timing stats, bandwidths, and a speedup figure for every case and a
    final ``check`` verdict. Returns 0 if all benchmarks ran, 112 otherwise.
    """
    try:
        mod = load_submission()
    except Exception as exc:
        logger.log("check", "fail")
        logger.log("error", str(exc))
        return 112

    # Build a throwaway in-memory module (type(sys) is ModuleType) exposing the
    # same two-function interface, backed by plain PyTorch ops, so the same
    # benchmark_one() path measures both submission and baseline.
    baseline_mod = type(sys)("baseline")
    baseline_mod.cross_entropy_forward = (
        lambda logits, targets: F.cross_entropy(logits.float(), targets, reduction="none")
    )

    def baseline_bwd(logits, targets, grad_output):
        # dCE/dlogits = softmax(logits) - one_hot(targets), scaled per row by
        # grad_output, then cast back to the input dtype. `probs` is freshly
        # allocated by softmax, so the in-place subtraction is safe.
        probs = torch.softmax(logits.float(), dim=-1)
        probs[torch.arange(logits.shape[0], device=logits.device), targets] -= 1.0
        return (probs * grad_output.unsqueeze(1)).to(logits.dtype)

    baseline_mod.cross_entropy_backward = baseline_bwd

    passed = True
    logger.log("benchmark-count", len(tests))
    for idx, test in enumerate(tests):
        vocab_size = int(test.args["vocab_size"])
        logger.log(f"benchmark.{idx}.spec", test.spec)
        try:
            # Baseline is re-timed per case so speedup reflects current load.
            baseline = benchmark_one(baseline_mod, vocab_size)
            result = benchmark_one(mod, vocab_size)
            speedup = baseline.mean / result.mean
        except Exception as exc:
            logger.log(f"benchmark.{idx}.status", "fail")
            logger.log(f"benchmark.{idx}.error", str(exc))
            passed = False
            continue

        # Times below are in nanoseconds (see benchmark_one); mean/1e6 is ms.
        logger.log(f"benchmark.{idx}.runs", result.runs)
        logger.log(f"benchmark.{idx}.mean", result.mean)
        logger.log(f"benchmark.{idx}.std", result.std)
        logger.log(f"benchmark.{idx}.err", result.err)
        logger.log(f"benchmark.{idx}.best", result.best)
        logger.log(f"benchmark.{idx}.worst", result.worst)
        logger.log(f"benchmark.{idx}.fwd_bw", result.fwd_bw)
        logger.log(f"benchmark.{idx}.bwd_bw", result.bwd_bw)
        logger.log(f"benchmark.{idx}.combined_bw", result.combined_bw)
        logger.log(f"benchmark.{idx}.speedup", speedup)
        logger.log(
            f"benchmark.{idx}.message",
            (
                f"fwd+bwd={result.mean / 1e6:.3f} ms, "
                f"fwd_bw={result.fwd_bw:.1f} GB/s, "
                f"bwd_bw={result.bwd_bw:.1f} GB/s, "
                f"combined_bw={result.combined_bw:.1f} GB/s, "
                f"speedup={speedup:.2f}x"
            ),
        )

    logger.log("check", "pass" if passed else "fail")
    return 0 if passed else 112
| 298 | + |
| 299 | + |
def main():
    """CLI entry point: ``<mode> <tests-file>``, logging via the POPCORN_FD fd.

    Exit codes: 111 missing POPCORN_FD, 2 bad usage/mode, 112 failed run,
    0 success.
    """
    fd = os.getenv("POPCORN_FD")
    if not fd:
        return 111
    if len(sys.argv) < 3:
        return 2

    # Bail out early (with a logged error) when no GPU is present.
    if not torch.cuda.is_available():
        with PopcornOutput(int(fd)) as logger:
            logger.log("check", "fail")
            logger.log("error", "No CUDA GPU available. This script requires a GPU.")
        return 112

    mode = sys.argv[1]
    tests = get_test_cases(sys.argv[2])

    with PopcornOutput(int(fd)) as logger:
        if mode == "test":
            return run_testing(logger, tests)
        if mode in ("benchmark", "leaderboard"):
            return run_benchmarking(logger, tests)

        logger.log("check", "fail")
        logger.log("error", f"Unsupported mode: {mode}")
        return 2
| 326 | + |
| 327 | + |
if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return value, same as the
    # explicit `raise SystemExit(main())` form.
    sys.exit(main())
0 commit comments