fix blockwise FP8 scaled_mm scale layout in Float8BlockwiseLinear#4229
Open
iamzainhuda wants to merge 1 commit intomainfrom
Open
fix blockwise FP8 scaled_mm scale layout in Float8BlockwiseLinear#4229iamzainhuda wants to merge 1 commit intomainfrom
iamzainhuda wants to merge 1 commit intomainfrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4229
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (2 Unrelated Failures)As of commit 551f3c7 with merge base b1ddd15 ( FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
What was going wrong?
The issue was not just that we were calling torch._scaled_mm, but that this path in our blockwise FP8 linear code was not matching the cuBLASLt scale-layout contract for blockwise scaling.
In particular:
There was a second layout bug in grad_weight = grad_output^T @ x:
Also to note:
Testing
cc @slayton58 @drisspg as part of #4209