
[codex] add repro for blockwise _scaled_mm investigation#4209

Draft
iamzainhuda wants to merge 2 commits into main from blockwise_linear_failure

Conversation

@iamzainhuda
Contributor

Summary

  • add a standalone CUDA repro script for the blockwise _scaled_mm backward-style path
  • capture the current investigation findings in one place so the issue can be reproduced and discussed upstream

What Changed

  • adds test/prototype/blockwise_fp8_training/repro_scaled_mm_blockwise_issue.py
  • the script compares four results on the same quantized inputs and reciprocal scales:
    • FP32 reference matmul
    • explicit dequantized matmul
    • Triton blockwise GEMM
    • torch._scaled_mm
  • it uses a deterministic nn.Linear weight initialization and probes k in {128000, 128128, 128256}
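
The comparisons below are reported as SQNR in dB against the FP32 reference. As a hedged, self-contained sketch of that metric (the helper name `sqnr_db` and the toy tensors here are illustrative, not lifted from the repro script):

```python
import torch

def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> float:
    # Signal-to-quantization-noise ratio in dB:
    # 10 * log10(signal power / error power), measured against the FP32 reference.
    ref, test = ref.float(), test.float()
    signal = ref.pow(2).sum()
    noise = (ref - test).pow(2).sum()
    return (10 * torch.log10(signal / noise)).item()

torch.manual_seed(0)
ref = torch.randn(64, 64)
noisy = ref + 0.01 * torch.randn_like(ref)  # ~1% relative noise -> roughly 40 dB
print(f"{sqnr_db(ref, noisy):.2f} dB")
```

A healthy FP8 matmul typically lands around 30 dB on this scale, which is why the -70 dB readings below indicate an output dominated by error rather than ordinary quantization noise.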

Findings So Far

  • for k=128000 (k_blocks=1000), _scaled_mm, Triton, and the explicit dequantized matmul agree:
    • _scaled_mm SQNR vs FP32: 31.14 dB
    • Triton SQNR vs FP32: 31.14 dB
  • for k=128128 (k_blocks=1001) and k=128256 (k_blocks=1002), Triton remains healthy while _scaled_mm collapses:
    • k=128128
      • _scaled_mm SQNR vs FP32: -71.57 dB
      • Triton SQNR vs FP32: 31.21 dB
      • _scaled_mm norm: 12,406,713
      • FP32 ref norm: 3,273.87
    • k=128256
      • _scaled_mm SQNR vs FP32: -73.20 dB
      • Triton SQNR vs FP32: 31.20 dB
      • _scaled_mm norm: 14,976,587
      • FP32 ref norm: 3,275.04
  • because Triton and the explicit dequantized matmul agree on the same FP8 tensors and reciprocal scales, the evidence points away from the quantization kernels and toward the _scaled_mm CUDA path or an undocumented shape/recipe restriction for this (BlockWise1x128, BlockWise128x128) regime

Impact

  • this explains the original blockwise linear backward failure when _scaled_mm was used on the grad path
  • the current workaround of using Triton in backward remains the safe path until _scaled_mm is better understood or fixed upstream

Validation

  • python test/prototype/blockwise_fp8_training/repro_scaled_mm_blockwise_issue.py

@pytorch-bot

pytorch-bot bot commented Mar 31, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4209

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a787043 with merge base f11eff8:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Mar 31, 2026
@danielvegamyhre
Contributor

this is an interesting finding @iamzainhuda, cc @slayton58 @drisspg @ngimel who may be interested as well

@slayton58

What GPU is this run on? And does it still repro with 1x128 x 1x128, or only with 1x128 x 128x128?

There's also torch.nn.functional.scaled_mm, a newer and (now) better-tested path for scaled GEMMs -- it would be worth switching to that and seeing if the problem persists.

@iamzainhuda
Contributor Author

@slayton58 this is on H100s. It doesn't repro with 1x128 x 1x128. Ran some quick tests; it seems like I can get it to work on the failing shapes with torch.nn.functional.scaled_mm. Will formalize this a bit further.

@slayton58

128x128 scales have some... specific rules about formats, see cublas docs and padding of scales
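
For reference, the block counts behind the probed shapes are simple ceil-division arithmetic. Purely as an observation (whether cuBLAS scale-padding rules are the actual culprit is still open), the healthy case is the only probed value whose block count is a multiple of 4:

```python
import math

BLOCK = 128  # blockwise quantization granularity used by the repro

# k values probed by the repro script and their resulting block counts.
for k in (128000, 128128, 128256):
    k_blocks = math.ceil(k / BLOCK)
    print(f"k={k}: k_blocks={k_blocks}, k_blocks % 4 = {k_blocks % 4}")
```

This prints k_blocks of 1000, 1001, and 1002 respectively, matching the k_blocks values reported in the findings above.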
