Skip to content

Commit 6c537c9

Browse files
authored
[codex] add princeton 2026 cross entropy problem (#141)
* add princeton 2026 cross entropy problem * rename princeton problem paths * remove tri naming
1 parent 8b0e3ea commit 6c537c9

5 files changed

Lines changed: 450 additions & 0 deletions

File tree

Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
import dataclasses
import math
import os
import re
import statistics
import sys
import types
from pathlib import Path

import torch
import torch.nn.functional as F

from reference import (
    ATOL,
    DTYPE,
    RTOL,
    generate_inputs,
    reference_backward,
    reference_forward,
)
20+
21+
22+
# Original eval parameters
23+
B = 4_096
24+
WARMUP_ITERS = 20
25+
BENCH_ITERS = 100
26+
27+
28+
class PopcornOutput:
    """Text-mode logger wrapping the harness-supplied file descriptor.

    The descriptor is marked non-inheritable so child processes cannot
    hold it open, and every write is flushed immediately so output is
    not lost if the process dies mid-run.
    """

    def __init__(self, fd: int):
        self.file = os.fdopen(fd, "w")
        os.set_inheritable(fd, False)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Closing the wrapper also closes the underlying descriptor.
        self.file.close()

    def print(self, *args, **kwargs):
        """Forward to builtin print(), targeting our stream and flushing."""
        print(*args, **kwargs, file=self.file, flush=True)

    def log(self, key, value):
        """Emit one `key: value` protocol line."""
        self.print(f"{key}: {value}")
44+
45+
46+
@dataclasses.dataclass
class TestCase:
    """One parsed line of the test-case file."""

    # Parsed key/value pairs; values are ints where the text was numeric.
    args: dict
    # The raw input line, echoed back verbatim in the logs.
    spec: str
50+
51+
52+
@dataclasses.dataclass
class Stats:
    """Aggregated benchmark measurements for one test case.

    Time fields are the combined fwd+bwd CUDA-event times scaled by
    1e6 from milliseconds (see benchmark_one); bandwidths are GB/s.
    """

    # Number of timed iterations contributing to the statistics.
    runs: int
    # Median combined fwd+bwd time (ms * 1e6) — the scored metric.
    mean: float
    # Population std-dev of the combined times (same scale as mean).
    std: float
    # Standard error of the mean (same scale as mean).
    err: float
    # Fastest combined iteration (same scale as mean).
    best: float
    # Slowest combined iteration (same scale as mean).
    worst: float
    # Effective forward bandwidth, GB/s.
    fwd_bw: float
    # Effective backward bandwidth, GB/s.
    bwd_bw: float
    # Effective bandwidth of the fused fwd+bwd measurement, GB/s.
    combined_bw: float
63+
64+
65+
def get_test_cases(file_name: str) -> list[TestCase]:
    """Parse the test-case file into a list of TestCase records.

    Each non-empty line consists of `;`-separated `key: value` pairs,
    where a value is an identifier or an integer.  Exits the process
    with status 113 on an unreadable file or a malformed line.
    """
    try:
        content = Path(file_name).read_text()
    except Exception as exc:
        print(f"Could not open test file `{file_name}`: {exc}", file=sys.stderr)
        sys.exit(113)

    # An identifier key paired with an identifier or integer value.
    pattern = re.compile(r"\s*([a-zA-Z_]+):\s*([a-zA-Z_]+|[+-]?[0-9]+)\s*")
    tests = []
    for line in content.splitlines():
        if not line.strip():
            continue
        case = {}
        for part in line.split(";"):
            # fullmatch both validates and captures in one pass (the
            # original ran the regex twice: re.match for the groups and
            # re.fullmatch for validation).
            matched = pattern.fullmatch(part)
            if not matched:
                print(f"invalid test case: '{line}': '{part}'", file=sys.stderr)
                sys.exit(113)
            key = matched[1]
            value = matched[2]
            try:
                value = int(value)
            except ValueError:
                pass  # non-numeric values stay as strings
            case[key] = value
        tests.append(TestCase(spec=line, args=case))
    return tests
94+
95+
96+
def load_submission():
    """Import the user's `submission` module and validate its API.

    Raises AttributeError naming the first required entry point that
    is absent.
    """
    import submission

    required = ("cross_entropy_forward", "cross_entropy_backward")
    missing = [name for name in required if not hasattr(submission, name)]
    if missing:
        raise AttributeError(f"Submission is missing function '{missing[0]}'.")
    return submission
103+
104+
105+
def check_correctness(mod, vocab_size):
    """Compare the submission's forward/backward against the reference.

    Returns (fwd_close, bwd_close, max_fwd_err, max_bwd_err).  Shape and
    dtype mismatches raise AssertionError — raised explicitly rather
    than via `assert` so the checks survive `python -O`.
    """
    logits, targets, grad_output = generate_inputs(B, vocab_size)

    ref_loss = reference_forward(logits, targets)
    sub_loss = mod.cross_entropy_forward(logits, targets)

    if sub_loss.shape != ref_loss.shape:
        raise AssertionError(
            f"Forward shape mismatch: expected {ref_loss.shape}, got {sub_loss.shape}"
        )
    if sub_loss.dtype != torch.float32:
        raise AssertionError(
            f"Forward dtype mismatch: expected float32, got {sub_loss.dtype}"
        )

    fwd_close = torch.allclose(sub_loss, ref_loss, atol=ATOL, rtol=RTOL)
    max_fwd_err = (sub_loss - ref_loss).abs().max().item()

    ref_grad = reference_backward(logits, targets, grad_output)
    sub_grad = mod.cross_entropy_backward(logits, targets, grad_output)

    if sub_grad.shape != ref_grad.shape:
        raise AssertionError(
            f"Backward shape mismatch: expected {ref_grad.shape}, got {sub_grad.shape}"
        )
    if sub_grad.dtype != DTYPE:
        raise AssertionError(
            f"Backward dtype mismatch: expected {DTYPE}, got {sub_grad.dtype}"
        )

    bwd_close = torch.allclose(sub_grad, ref_grad, atol=ATOL, rtol=RTOL)
    # Compare in float32 so the bf16 subtraction doesn't lose precision.
    max_bwd_err = (sub_grad.float() - ref_grad.float()).abs().max().item()

    return fwd_close, bwd_close, max_fwd_err, max_bwd_err
135+
136+
137+
def benchmark_one(mod, vocab_size):
    """Benchmark one submission at one vocab size and return Stats.

    Times the forward pass, the backward pass, and the combined
    fwd+bwd sequence with CUDA events (BENCH_ITERS timed iterations
    after WARMUP_ITERS warmup), then derives effective bandwidths from
    an analytic byte count.  The scored metric is the median combined
    time.
    """
    logits, targets, grad_output = generate_inputs(B, vocab_size, seed=123)

    def _run_fwd():
        mod.cross_entropy_forward(logits, targets)

    def _run_bwd():
        mod.cross_entropy_backward(logits, targets, grad_output)

    def _run_both():
        _run_fwd()
        _run_bwd()

    def _time_many(fn):
        # One CUDA-event timing loop shared by all three phases (the
        # original duplicated this loop three times).
        times = []
        for _ in range(BENCH_ITERS):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # milliseconds
        return times

    for _ in range(WARMUP_ITERS):
        _run_both()
    torch.cuda.synchronize()

    fwd_times = _time_many(_run_fwd)
    bwd_times = _time_many(_run_bwd)
    combined_times = _time_many(_run_both)

    fwd_ms = statistics.median(fwd_times)
    bwd_ms = statistics.median(bwd_times)
    combined_ms = statistics.median(combined_times)

    # Analytic traffic model: 2 B/elt (bf16 logits read) forward,
    # 4 B/elt (logits read + bf16 grad write) backward; the 12*B term
    # presumably covers the per-row targets/loss/grad_output tensors —
    # kept exactly as originally specified.
    fwd_bytes = 2 * B * vocab_size + 12 * B
    bwd_bytes = 4 * B * vocab_size + 12 * B
    total_bytes = fwd_bytes + bwd_bytes

    fwd_bw = fwd_bytes / (fwd_ms * 1e-3) / 1e9
    bwd_bw = bwd_bytes / (bwd_ms * 1e-3) / 1e9
    combined_bw = total_bytes / (combined_ms * 1e-3) / 1e9

    # Keep KernelBot scoring on the exact reported metric: median combined ms.
    combined_std = statistics.pstdev(combined_times)  # computed once, used twice
    return Stats(
        runs=BENCH_ITERS,
        mean=combined_ms * 1e6,
        std=combined_std * 1e6,
        err=(combined_std / math.sqrt(len(combined_times))) * 1e6,
        best=min(combined_times) * 1e6,
        worst=max(combined_times) * 1e6,
        fwd_bw=fwd_bw,
        bwd_bw=bwd_bw,
        combined_bw=combined_bw,
    )
200+
201+
202+
def run_testing(logger: PopcornOutput, tests: list[TestCase]) -> int:
    """Run the correctness suite, logging one status block per case.

    Returns 0 when every case passes, 112 on any failure (including a
    submission that cannot be loaded).
    """
    try:
        mod = load_submission()
    except Exception as exc:
        # Import/validation failure: report and bail before running tests.
        logger.log("check", "fail")
        logger.log("error", str(exc))
        return 112

    passed = True
    logger.log("test-count", len(tests))
    for idx, test in enumerate(tests):
        # NOTE: a missing "vocab_size" key raises here, outside the try,
        # and propagates to the caller — only check_correctness errors
        # are converted into per-test failures below.
        vocab_size = int(test.args["vocab_size"])
        logger.log(f"test.{idx}.spec", test.spec)
        try:
            fwd_ok, bwd_ok, fwd_err, bwd_err = check_correctness(mod, vocab_size)
            if fwd_ok and bwd_ok:
                logger.log(f"test.{idx}.status", "pass")
                logger.log(
                    f"test.{idx}.message",
                    f"forward max err={fwd_err:.3e}, backward max err={bwd_err:.3e}",
                )
            else:
                logger.log(f"test.{idx}.status", "fail")
                logger.log(
                    f"test.{idx}.error",
                    f"forward max err={fwd_err:.3e} {'OK' if fwd_ok else 'FAIL'}; "
                    f"backward max err={bwd_err:.3e} {'OK' if bwd_ok else 'FAIL'}",
                )
                passed = False
        except Exception as exc:
            # Shape/dtype assertion or a crash inside the submission.
            logger.log(f"test.{idx}.status", "fail")
            logger.log(f"test.{idx}.error", str(exc))
            passed = False

    logger.log("check", "pass" if passed else "fail")
    return 0 if passed else 112
238+
239+
240+
def _make_baseline():
    """Build a module-like object exposing the naive PyTorch baseline.

    Uses types.ModuleType instead of the opaque `type(sys)(...)` trick.
    """
    baseline = types.ModuleType("baseline")

    def baseline_fwd(logits, targets):
        # Plain fp32 cross-entropy, no reduction (matches the reference).
        return F.cross_entropy(logits.float(), targets, reduction="none")

    def baseline_bwd(logits, targets, grad_output):
        # d(loss)/d(logits) = softmax(logits) - one_hot(target), scaled
        # by the incoming per-row grad and cast back to the input dtype.
        probs = torch.softmax(logits.float(), dim=-1)
        probs[torch.arange(logits.shape[0], device=logits.device), targets] -= 1.0
        return (probs * grad_output.unsqueeze(1)).to(logits.dtype)

    baseline.cross_entropy_forward = baseline_fwd
    baseline.cross_entropy_backward = baseline_bwd
    return baseline


def run_benchmarking(logger: PopcornOutput, tests: list[TestCase]) -> int:
    """Benchmark the submission against the baseline for every case.

    Logs per-case timing/bandwidth stats plus speedup over the naive
    baseline; returns 0 when every case benchmarks successfully, 112
    otherwise.
    """
    try:
        mod = load_submission()
    except Exception as exc:
        logger.log("check", "fail")
        logger.log("error", str(exc))
        return 112

    baseline_mod = _make_baseline()

    passed = True
    logger.log("benchmark-count", len(tests))
    for idx, test in enumerate(tests):
        vocab_size = int(test.args["vocab_size"])
        logger.log(f"benchmark.{idx}.spec", test.spec)
        try:
            baseline = benchmark_one(baseline_mod, vocab_size)
            result = benchmark_one(mod, vocab_size)
            speedup = baseline.mean / result.mean
        except Exception as exc:
            logger.log(f"benchmark.{idx}.status", "fail")
            logger.log(f"benchmark.{idx}.error", str(exc))
            passed = False
            continue

        logger.log(f"benchmark.{idx}.runs", result.runs)
        logger.log(f"benchmark.{idx}.mean", result.mean)
        logger.log(f"benchmark.{idx}.std", result.std)
        logger.log(f"benchmark.{idx}.err", result.err)
        logger.log(f"benchmark.{idx}.best", result.best)
        logger.log(f"benchmark.{idx}.worst", result.worst)
        logger.log(f"benchmark.{idx}.fwd_bw", result.fwd_bw)
        logger.log(f"benchmark.{idx}.bwd_bw", result.bwd_bw)
        logger.log(f"benchmark.{idx}.combined_bw", result.combined_bw)
        logger.log(f"benchmark.{idx}.speedup", speedup)
        logger.log(
            f"benchmark.{idx}.message",
            (
                f"fwd+bwd={result.mean / 1e6:.3f} ms, "
                f"fwd_bw={result.fwd_bw:.1f} GB/s, "
                f"bwd_bw={result.bwd_bw:.1f} GB/s, "
                f"combined_bw={result.combined_bw:.1f} GB/s, "
                f"speedup={speedup:.2f}x"
            ),
        )

    logger.log("check", "pass" if passed else "fail")
    return 0 if passed else 112
298+
299+
300+
def main():
    """Harness entry point; returns the process exit code.

    Exit codes: 111 when POPCORN_FD is missing, 2 on bad usage or an
    unknown mode, 112 when no GPU is present, otherwise whatever the
    selected runner returns.
    """
    fd = os.getenv("POPCORN_FD")
    if not fd:
        return 111

    if len(sys.argv) < 3:
        return 2

    if not torch.cuda.is_available():
        with PopcornOutput(int(fd)) as logger:
            logger.log("check", "fail")
            logger.log("error", "No CUDA GPU available. This script requires a GPU.")
        return 112

    mode, test_file = sys.argv[1], sys.argv[2]
    # get_test_cases may sys.exit(113) before the logger is opened.
    tests = get_test_cases(test_file)

    runners = {
        "test": run_testing,
        "benchmark": run_benchmarking,
        "leaderboard": run_benchmarking,
    }
    with PopcornOutput(int(fd)) as logger:
        runner = runners.get(mode)
        if runner is not None:
            return runner(logger, tests)
        logger.log("check", "fail")
        logger.log("error", f"Unsupported mode: {mode}")
        return 2
326+
327+
328+
if __name__ == "__main__":
    # Propagate the harness exit code to the shell via SystemExit.
    raise SystemExit(main())
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import torch
2+
import torch.nn.functional as F
3+
4+
5+
DTYPE = torch.bfloat16  # dtype of the logits / backward gradient
DEVICE = "cuda"  # all generated tensors live on the GPU
ATOL = 1e-3  # absolute tolerance for allclose comparisons
RTOL = 1e-2  # relative tolerance for allclose comparisons
9+
10+
11+
def reference_forward(logits, targets):
    """Per-sample cross-entropy loss, computed in float32 (no reduction)."""
    logits_fp32 = logits.float()
    return F.cross_entropy(logits_fp32, targets, reduction="none")
13+
14+
15+
def reference_backward(logits, targets, grad_output):
    """Gradient of per-sample cross-entropy w.r.t. logits.

    Computes softmax(logits) - one_hot(target) in float32, scales each
    row by grad_output, and casts back to the logits' dtype.
    """
    grad = torch.softmax(logits.float(), dim=-1)
    rows = torch.arange(logits.shape[0], device=logits.device)
    grad[rows, targets] -= 1.0  # subtract the one-hot target in place
    scaled = grad * grad_output.unsqueeze(1)
    return scaled.to(logits.dtype)
21+
22+
23+
def generate_inputs(batch_size, vocab_size, seed=42, device=None, dtype=None):
    """Create deterministic (logits, targets, grad_output) for one case.

    Args:
        batch_size: number of rows.
        vocab_size: number of classes per row.
        seed: global torch seed, fixed so runs are reproducible.
        device: target device; defaults to the module's DEVICE ("cuda").
        dtype: logits dtype; defaults to the module's DTYPE (bfloat16).

    Returns:
        logits of shape (batch_size, vocab_size) in `dtype`, int64
        targets in [0, vocab_size), and float32 grad_output of shape
        (batch_size,), all on `device`.
    """
    # None-sentinel defaults keep the original call sites unchanged
    # while letting tests generate inputs on CPU or in other dtypes.
    device = DEVICE if device is None else device
    dtype = DTYPE if dtype is None else dtype
    torch.manual_seed(seed)
    logits = torch.randn(batch_size, vocab_size, dtype=dtype, device=device)
    targets = torch.randint(0, vocab_size, (batch_size,), device=device)
    grad_output = torch.randn(batch_size, dtype=torch.float32, device=device)
    return logits, targets, grad_output

0 commit comments

Comments
 (0)