fix blockwise FP8 scaled_mm scale layout in Float8BlockwiseLinear#4229

Open
iamzainhuda wants to merge 1 commit into main from func-scaled-mm
Conversation

@iamzainhuda
Contributor

@iamzainhuda iamzainhuda commented Apr 2, 2026

Summary

  • replace the non-Triton blockwise FP8 matmul path with functional scaled_mm semantics for forward, grad_x, and grad_weight
  • fix the BlockWise128x128 RHS scale layout by padding the K-block dimension to a multiple of 4, as required by cuBLASLt
  • fix the BlockWise1x128 RHS scale orientation in grad_weight by transposing the activation scales before the matmul
  • use aten._scaled_mm_v2 only under torch.compile(fullgraph=True) so the compiler can trace the op without graph-breaking on the Python F.scaled_mm wrapper

What was going wrong?

The issue was not just that we were calling torch._scaled_mm directly: this path in our blockwise FP8 linear code also did not match the cuBLASLt scale-layout contract for blockwise scaling.
In particular:

  • grad_x = grad_output @ weight uses RHS BlockWise128x128 scales
  • cuBLASLt requires those scales to be K-major, with the K-block dimension padded to a multiple of 4
  • our code was passing the unpadded layout, which caused incorrect behavior in the scaled-mm backend
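The padding requirement above can be sketched as shape arithmetic. A minimal illustration, not the PR's actual code: the helper name and the exact (n_blocks, padded_k_blocks) ordering are assumptions, but the pad-to-multiple-of-4 step is the fix described here.

```python
import math

def blockwise_128x128_rhs_scale_shape(k: int, n: int) -> tuple:
    """Hypothetical sketch of the K-major BlockWise128x128 RHS scale shape,
    with the K-block dimension padded to a multiple of 4 as cuBLASLt requires."""
    k_blocks = math.ceil(k / 128)                  # one scale per 128-wide tile along K
    n_blocks = math.ceil(n / 128)                  # ... and per 128-wide tile along N
    padded_k_blocks = math.ceil(k_blocks / 4) * 4  # the fix: pad K-blocks to a multiple of 4
    return (n_blocks, padded_k_blocks)             # K-major: the K-block dim varies fastest

# e.g. a K=256, N=384 operand: 2 K-blocks pad up to 4, giving a (3, 4) scale grid
```

Passing the unpadded (n_blocks, k_blocks) grid is exactly the layout mismatch described above.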

There was a second layout bug in grad_weight = grad_output^T @ x:

  • the RHS uses BlockWise1x128 scaling
  • the quantized activation scales had the right values but the wrong orientation for the RHS scaled-mm call
  • transposing those scales fixes the contract mismatch
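Shape-wise, the orientation fix amounts to a plain transpose of the scale grid. A minimal illustration with nested lists (the real code would transpose a tensor; the (M, ceil(K/128)) shape is an assumption based on the 1x128 block size):

```python
def transpose_scales(scales):
    """Swap the orientation of a 2D scale grid. For BlockWise1x128 activation
    scales of shape (M, ceil(K/128)), this yields (ceil(K/128), M) -- the
    orientation assumed here for the RHS of the grad_weight scaled-mm call."""
    return [list(col) for col in zip(*scales)]

# a 3x2 grid of scales becomes 2x3; every scale value is preserved
grid = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
```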

Also to note:

  • eager mode uses torch.nn.functional.scaled_mm
  • compile mode calls aten._scaled_mm_v2 directly because, in this torch build, Dynamo fullgraph cannot trace through the Python F.scaled_mm wrapper even though it lowers to the same underlying op
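The eager/compile split above could be sketched as a tiny dispatch shim. The function and the explicit flag are hypothetical; the real code would presumably branch on something like torch.compiler.is_compiling() rather than a caller-supplied boolean.

```python
def pick_scaled_mm_op(compiling: bool) -> str:
    """Hypothetical dispatch sketch: under torch.compile(fullgraph=True), call
    aten._scaled_mm_v2 directly so Dynamo can trace it; in eager mode, use the
    Python torch.nn.functional.scaled_mm wrapper, which lowers to the same op."""
    if compiling:
        return "aten._scaled_mm_v2"          # traceable directly; avoids a graph break
    return "torch.nn.functional.scaled_mm"   # eager-friendly Python wrapper
```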

Testing

 pytest -q test/prototype/blockwise_fp8_training/test_blockwise_linear.py

cc @slayton58 @drisspg as part of #4209

@pytorch-bot

pytorch-bot bot commented Apr 2, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4229

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 551f3c7 with merge base b1ddd15:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 2, 2026
@iamzainhuda iamzainhuda marked this pull request as draft April 2, 2026 22:45
@iamzainhuda iamzainhuda added the module: training quantize_ api training flow label Apr 8, 2026
@iamzainhuda iamzainhuda changed the title [draft] use F.scaled_mm instead of torch._scaled_mm for blockwise linear GEMM fix blockwise FP8 scaled_mm scale layout in Float8BlockwiseLinear Apr 8, 2026
@iamzainhuda iamzainhuda marked this pull request as ready for review April 8, 2026 13:21
