problems/princeton/cross_entropy.py

+ #!POPCORN leaderboard princeton_cross_entropy
+
  """
  Baseline submission for the cross-entropy problem.

  Replace these functions with a faster implementation.

+ The evaluator uses:
+ - B = 4096
+ - V in {32000, 50264, 128256}
+ - V % 8 == 0
+ - finite real-valued logits (no masking with -inf)
+
  Example local bandwidth calculation for the three ranked shapes:

      def print_max_bw(batch_size, vocab_size, combined_ms):
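The diff truncates the body of this example. As a rough sketch only (not the repository's actual helper): if one assumes the dominant memory traffic is a single bf16 read of the logits in the forward pass plus a single bf16 write of the gradients in the backward pass, the effective bandwidth for a combined timing could be estimated as below. The traffic model and the timings in the loop are illustrative assumptions.

    def print_max_bw(batch_size, vocab_size, combined_ms):
        # Assumed traffic model: read the bf16 logits once (forward) and write
        # the bf16 grad_logits once (backward), 2 bytes per element each;
        # targets, per-row losses, and grad_output are negligible next to that.
        bytes_moved = 2 * (2 * batch_size * vocab_size)
        gb_per_s = bytes_moved / (combined_ms / 1e3) / 1e9
        print(f"B={batch_size}, V={vocab_size}: ~{gb_per_s:.0f} GB/s effective")

    # Hypothetical combined forward+backward timings, only to show the call pattern:
    for vocab, ms in [(32000, 0.5), (50264, 0.8), (128256, 2.0)]:
        print_max_bw(4096, vocab, ms)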
@@ -13,19 +13,27 @@ description: |
    - cross_entropy_backward(logits, targets, grad_output) -> grad_logits

    Inputs:
-   - logits: torch.bfloat16 tensor of shape (B, V)
+   - logits: torch.bfloat16 tensor of real-valued, finite logits with shape (B, V)
    - targets: torch.int64 tensor of shape (B,)
    - grad_output: torch.float32 tensor of shape (B,)

    Outputs:
    - forward output: torch.float32 tensor of shape (B,)
    - backward output: torch.bfloat16 tensor of shape (B, V)

+   Assumptions used by the evaluator and benchmark:
+   - batch size is fixed at B = 4096
+   - vocab sizes are V in {32000, 50264, 128256}
+   - vocab size is guaranteed to be divisible by 8
+   - logits are ordinary real numbers; masked values such as -inf are not used
+
  config:
    main: "eval.py"

  tests:
    - {"vocab_size": 32000}
+   - {"vocab_size": 50264}
+   - {"vocab_size": 128256}

  benchmarks:
    - {"vocab_size": 32000}