Skip to content

Commit 80686fb

Browse files
authored
upgrade input data scope to avoid cheating
1 parent 0b74249 commit 80686fb

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

problems/amd_202602/mixed-mla/reference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,13 +306,13 @@ def generate_input(batchsize: int, qseqlen: int, kvseqlen: int, seed: int) -> in
306306
q = torch.randn(
307307
(total_q, NUM_HEADS, QK_HEAD_DIM),
308308
dtype=torch.bfloat16, device="cuda", generator=gen,
309-
) * 0.02
309+
)
310310

311311
# Compressed KV buffer: (total_kv, 1, 576) bf16 — the source of truth
312312
kv_buffer_bf16 = torch.randn(
313313
(total_kv, NUM_KV_HEADS, QK_HEAD_DIM),
314314
dtype=torch.bfloat16, device="cuda", generator=gen,
315-
) * 0.02
315+
)
316316

317317
# Quantize KV to fp8
318318
kv_buffer_fp8, kv_scale_fp8 = quantize_fp8(kv_buffer_bf16)

0 commit comments

Comments
 (0)