[codex] add repro for blockwise _scaled_mm investigation #4209
iamzainhuda wants to merge 2 commits into main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4209

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures as of commit a787043 with merge base f11eff8.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
this is an interesting finding @iamzainhuda, cc @slayton58 @drisspg @ngimel who may be interested as well
What GPU is this run on? And does it still repro with 1x128 x 1x128, or only with 1x128 x 128x128? There's also
@slayton58 this is on H100s. it doesn't repro with 1x128 x 1x128. ran some quick tests, it seems like i can get it to work on the failing shapes with
128x128 scales have some... specific rules about formats, see cublas docs and padding of scales
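The scale-tensor geometry behind those rules can be sketched on CPU. This is a minimal illustration (the helper name, `m`, and `n` are made up for the example, and it does not model the cuBLAS padding/format constraints mentioned above): for a `(BlockWise1x128, BlockWise128x128)` recipe, the operand of shape `(m, k)` carries one scale per 1x128 row block, and the weight of shape `(n, k)` carries one scale per 128x128 tile.

```python
import math

BLOCK = 128

def scale_shapes(m: int, n: int, k: int):
    """Scale tensor shapes for a (BlockWise1x128, BlockWise128x128) recipe.

    Illustrative only: ignores any layout/padding rules the cuBLAS
    block-scaled kernels impose on the 128x128 scale format.
    """
    k_blocks = math.ceil(k / BLOCK)
    a_scales = (m, k_blocks)                      # one scale per 1x128 row block of A
    b_scales = (math.ceil(n / BLOCK), k_blocks)   # one scale per 128x128 tile of B
    return k_blocks, a_scales, b_scales

for k in (128000, 128128, 128256):
    k_blocks, a_scales, b_scales = scale_shapes(m=512, n=4096, k=k)
    print(k, k_blocks, a_scales, b_scales)
```

Note that all three probed `k` values are exact multiples of 128 (`k_blocks` = 1000, 1001, 1002), so plain divisibility by the block size does not separate the healthy shape from the failing ones.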
## Summary

Adds a standalone repro showing `_scaled_mm` producing collapsed results on a backward-style path with blockwise FP8 scales.

## What Changed
- Adds `test/prototype/blockwise_fp8_training/repro_scaled_mm_blockwise_issue.py`, which exercises `torch._scaled_mm` with `nn.Linear`-style weight initialization and probes `k in {128000, 128128, 128256}`

## Findings So Far
- For `k=128000` (`k_blocks=1000`), `_scaled_mm`, Triton, and the explicit dequantized matmul agree: `_scaled_mm` SQNR vs FP32 is 31.14 dB (Triton: 31.14 dB)
- For `k=128128` (`k_blocks=1001`) and `k=128256` (`k_blocks=1002`), Triton remains healthy while `_scaled_mm` collapses:
  - `k=128128`: `_scaled_mm` SQNR vs FP32 is -71.57 dB (Triton: 31.21 dB); `_scaled_mm` output norm is 12,406,713 vs 3,273.87 for the reference
  - `k=128256`: `_scaled_mm` SQNR vs FP32 is -73.20 dB (Triton: 31.20 dB); `_scaled_mm` output norm is 14,976,587 vs 3,275.04 for the reference
- This points to a bug in the `_scaled_mm` CUDA path or an undocumented shape/recipe restriction for this `(BlockWise1x128, BlockWise128x128)` regime

## Impact
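The SQNR figures above (≈31 dB when healthy, ≈-72 dB when collapsed) follow the usual definition of signal-to-quantization-noise ratio: signal power over error power, in decibels. A minimal sketch (the repro script may compute it slightly differently, e.g. with a mean instead of a sum, which gives the same ratio):

```python
import numpy as np

def sqnr_db(ref: np.ndarray, test: np.ndarray) -> float:
    """Signal-to-quantization-noise ratio in dB: ref power over error power."""
    err = ref - test
    return float(10.0 * np.log10(np.sum(ref.astype(np.float64) ** 2)
                                 / np.sum(err.astype(np.float64) ** 2)))

rng = np.random.default_rng(0)
ref = rng.standard_normal(1024)
print(sqnr_db(ref, ref + 0.01 * rng.standard_normal(1024)))  # small error -> large positive dB
print(sqnr_db(ref, 1e4 * rng.standard_normal(1024)))         # garbage output -> large negative dB
```

A large negative SQNR like -71 dB means the error term dwarfs the signal, i.e. the output is numerically unrelated to the reference, consistent with the huge output norms reported above.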
- `_scaled_mm` was used on the grad path; that path should be treated as unreliable for these shapes until `_scaled_mm` is better understood or fixed upstream

## Validation
```
python test/prototype/blockwise_fp8_training/repro_scaled_mm_blockwise_issue.py
```
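For context, the "explicit dequantized matmul" reference that the findings compare against can be emulated on CPU. This is an assumption about the shape of that reference, not the repro's actual code: it substitutes symmetric int8 codes for FP8 (NumPy has no FP8 dtype), and all helper names are illustrative. The structure is what matters: quantize A with per-(row, 1x128) scales, quantize B with per-128x128-tile scales, dequantize both, then matmul in full precision.

```python
import numpy as np

BLOCK = 128

def quant_1x128(a: np.ndarray):
    """Symmetric per-(row, 128-col-block) quantization; returns codes + scales."""
    m, k = a.shape
    blocks = a.reshape(m, k // BLOCK, BLOCK)
    scales = np.maximum(np.abs(blocks).max(axis=2, keepdims=True) / 127.0, 1e-12)
    return np.round(blocks / scales).astype(np.int8), scales

def quant_128x128(b: np.ndarray):
    """Symmetric per-128x128-tile quantization; returns codes + scales."""
    n, k = b.shape
    tiles = b.reshape(n // BLOCK, BLOCK, k // BLOCK, BLOCK)
    scales = np.maximum(np.abs(tiles).max(axis=(1, 3), keepdims=True) / 127.0, 1e-12)
    return np.round(tiles / scales).astype(np.int8), scales

def dequant_matmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Reference path: blockwise quantize, dequantize, matmul in fp32.

    a: (m, k) activations, b: (n, k) weights; returns a @ b.T of shape (m, n).
    """
    qa, sa = quant_1x128(a)
    qb, sb = quant_128x128(b)
    a_dq = (qa * sa).reshape(a.shape).astype(np.float32)
    b_dq = (qb * sb).reshape(b.shape).astype(np.float32)
    return a_dq @ b_dq.T
```

Because dequantization happens explicitly before the matmul, this path cannot mishandle scale layout, which is why it stays in agreement with Triton (~31 dB) for every probed `k` and serves as the ground truth the `_scaled_mm` output is judged against.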