[CUDA] Fix QMoE int4/int8 weight prepack to always use SM80 layout #28978

Merged
tianleiwu merged 6 commits into main from tlwu/refactor_qmoe_prepack_sm
Jun 11, 2026
@tianleiwu tianleiwu commented Jun 10, 2026

Contributor

Summary

The CUDA QMoE INT4/INT8 grouped GEMM always dispatches to the Ampere (SM80) CUTLASS kernel, even on Hopper (SM90), because mixed int-weight + fp16/bf16 activation is not a valid Hopper TMA warp-specialized specialization. This PR makes weight prepacking always emit the SM80 (column-interleaved) fpA_intB layout regardless of the runtime device SM, fixing silently wrong output on Hopper, and centralizes the arch-clamping logic in a single shared helper. It also cleans up the related tests and tightens MoE parity tolerances that were too loose to catch the layout bug.

Motivation

#28749 uses arch 90 when prepacking weights on SM90 devices.

On SM90, isValidHopperMOESpecialisation<half_t, uint4b_t/uint8_t>() is false, so the grouped MoE GEMM falls back to the SM80 kernel. The weight preprocessor, however, skips column interleaving for arch == 90, so an auto-detected (force_arch=-1) pack on an H200 produced the non-interleaved SM90 layout that the SM80 kernel cannot consume — yielding wrong results. The previous PrePackIntExpertWeights logic clamped to sm_ (passing SM90 through), and the test that exercised the offline packer used auto-detect, so both could emit the wrong layout.
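The mismatch described above can be illustrated with a toy model. This is only a sketch: the interleave pattern, the group size, and the function names `interleave_columns`, `deinterleave_columns`, and `toy_sm80_kernel_dot` are hypothetical and far simpler than the real fpA_intB layout. It shows the general failure mode, where a consumer that assumes an interleaved layout silently reads permuted weights when handed a non-interleaved pack.

```python
# Toy model of a packer/kernel layout mismatch.
# The permutation here is an assumption for illustration, not the CUTLASS layout.

def interleave_columns(row, group=4):
    """Hypothetical 'SM80-style' pack: in each group, even-indexed columns first, then odd."""
    out = []
    for start in range(0, len(row), group):
        chunk = row[start:start + group]
        out.extend(chunk[0::2] + chunk[1::2])
    return out


def deinterleave_columns(row, group=4):
    """Inverse of interleave_columns; the toy kernel applies this on load."""
    half = group // 2
    out = []
    for start in range(0, len(row), group):
        chunk = row[start:start + group]
        restored = [0] * len(chunk)
        restored[0::2] = chunk[:half]
        restored[1::2] = chunk[half:]
        out.extend(restored)
    return out


def toy_sm80_kernel_dot(activations, packed_weights):
    """Toy kernel: always assumes the interleaved layout, whatever the packer emitted."""
    weights = deinterleave_columns(packed_weights)
    return sum(a * w for a, w in zip(activations, weights))


weights = [1, 2, 3, 4, 5, 6, 7, 8]
activations = [1, 2, 3, 4, 5, 6, 7, 8]
reference = sum(a * w for a, w in zip(activations, weights))

# Packer and kernel agree on the layout: result matches the reference.
matched = toy_sm80_kernel_dot(activations, interleave_columns(weights))

# Packer skipped interleaving (the 'arch == 90' case): the kernel reads permuted weights.
mismatched = toy_sm80_kernel_dot(activations, weights)
```

No error is raised in the mismatched case; the result is simply a different number, which is why the bug only shows up in a parity comparison.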

Key Changes

| Area | Change |
| --- | --- |
| fpA_intB_gemm_preprocessors{.h,_impl.cu} | Extracted get_arch_for_mixed_gemm_weight_preprocess(int arch) as a shared, declared helper (clamps SM to the layout group: <80→75, 90→90, else 80). |
| fpA_intB_gemm_preprocessors_impl.h | getLayoutDetailsForTransform now routes through the shared helper instead of duplicating the arch-range logic. |
| moe_quantization.cc (PrePackIntExpertWeights) | Always packs INT4/INT8 expert weights for the SM80 layout (get_arch_for_mixed_gemm_weight_preprocess(80)) instead of clamping to the runtime sm_, since the SM80 kernel runs on every GPU. |
| onnxruntime_pybind_quant.cc (PackWeightsForMixedGemm) | Replaced the ad-hoc {75,80,90} allowlist with the shared helper, so force_arch is clamped consistently with the runtime dispatch (removes the now-unused <set> include). |
| contrib_defs.cc / moe_quantization.h | Updated weights_prepacked schema/field docs: layouts for -1/1 are EP-determined; for the CUDA EP -1 and 1 are equivalent today (both SM80), 1 reserved for a future Hopper-specific layout. |
| test_qmoe_cuda.py | Removed the dead, never-called preprocess_weights_for_mixed_gemm helper; the real path (quant_dequant_blockwise) already pins sm=80. |
| test_moe_cuda.py | Pinned the offline packer to arch=80, and tightened FP16 QMoE parity tolerance from atol 3.0 (4-bit) / 2.0 (8-bit) to 0.5 now that the layout is correct. |
| docs/ | Regenerated ContribOperators.md and updated moe_qmoe.md to match the new schema docs and SM80-always packing rationale. |
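The clamping rule in the first row can be restated as a short Python sketch. This is not the C++ implementation from the PR, only a mirror of the rule as the table states it (<80→75, 90→90, else 80). The function `qmoe_prepack_arch` is a hypothetical name used here to show the behavior described for PrePackIntExpertWeights, which no longer depends on the runtime SM.

```python
def get_arch_for_mixed_gemm_weight_preprocess(arch: int) -> int:
    """Python restatement of the rule in the table: <80 -> 75, 90 -> 90, else 80."""
    if arch < 80:
        return 75
    if arch == 90:
        return 90
    return 80


def qmoe_prepack_arch(runtime_sm: int) -> int:
    """Hypothetical helper: layout arch for QMoE INT4/INT8 expert weights after this PR.

    The runtime SM is deliberately ignored, because the grouped GEMM
    always runs the SM80 kernel.
    """
    return get_arch_for_mixed_gemm_weight_preprocess(80)


# Layout group chosen by the shared helper for a few SM values.
layout_groups = {sm: get_arch_for_mixed_gemm_weight_preprocess(sm) for sm in (70, 75, 80, 86, 89, 90)}

# Prepack arch for QMoE is the same on Ampere and Hopper.
prepack_on_a100 = qmoe_prepack_arch(80)
prepack_on_h200 = qmoe_prepack_arch(90)
```

Note that the helper by itself still maps 90 to 90; the fix for QMoE comes from the caller passing a fixed 80 rather than the device SM.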

Testing Notes

On an H200 (SM90), with the CUDA 12.x/13.x Python wheel:

python -m pytest onnxruntime/test/python/transformers/test_qmoe_cuda.py
python -m pytest onnxruntime/test/python/transformers/test_moe_cuda.py -k "PhiQMoE or qmoe"
  • test_qmoe_cuda.py SwiGLU parity: SM80 layout → max diff ~0.001 (pass, tol 0.1); the prior SM90 layout produced max diff ~1.2 (fail), confirming the fix.
  • test_moe_cuda.py TestPhiQMoE (4-bit and 8-bit, all batch/seq combinations): worst observed max_diff ≈ 0.375 with the fixed layout, comfortably under the new atol=0.5.
  • ruff check passes on both edited test files.
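Why the tolerance change matters can be shown with a minimal, standard-library-only sketch of an absolute-tolerance parity check. The helper names and sample values are made up for illustration, and the real tests may compare tensors differently. The 0.375 figure is the worst case reported above for test_moe_cuda.py; the 1.2 figure was reported for the SwiGLU test in test_qmoe_cuda.py and is reused here only as an example of the error magnitude a wrong layout can produce.

```python
def max_abs_diff(expected, actual):
    """Largest element-wise absolute difference between two equal-length sequences."""
    return max(abs(e - a) for e, a in zip(expected, actual))


def passes_parity(expected, actual, atol):
    """True when every element is within the absolute tolerance."""
    return max_abs_diff(expected, actual) <= atol


expected = [0.0, 1.0, -2.0]

# Illustrative outputs, not measured data.
correct_layout_output = [0.001, 1.0, -2.375]  # max diff 0.375
wrong_layout_output = [0.0, 2.2, -2.0]        # max diff about 1.2

old_atol_4bit = 3.0
new_atol = 0.5

fixed_passes_new = passes_parity(expected, correct_layout_output, new_atol)
bug_passes_old = passes_parity(expected, wrong_layout_output, old_atol_4bit)
bug_passes_new = passes_parity(expected, wrong_layout_output, new_atol)
```

Under these assumptions an error of about 1.2 slips through atol 3.0 but is rejected at atol 0.5, while the correct layout still passes with some margin.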

@tianleiwu tianleiwu requested a review from justinchuby June 10, 2026 19:27
@tianleiwu tianleiwu changed the title Fix QMoE int4/int8 weight prepack to always use SM80 layout [CUDA] Fix QMoE int4/int8 weight prepack to always use SM80 layout Jun 10, 2026
justinchuby previously approved these changes Jun 10, 2026

@justinchuby justinchuby left a comment

Contributor

Thanks for catching this!

Copilot AI left a comment

Contributor

Pull request overview

This PR fixes incorrect INT4/INT8 QMoE results on Hopper by ensuring weight prepacking always produces the SM80 (column-interleaved) fpA_intB layout that the grouped MoE GEMM actually consumes (since the runtime dispatch always routes to the SM80 kernel even on SM90). It also centralizes the “arch → layout group” clamping logic and updates tests/docs to reflect the SM80-always behavior.

Changes:

  • Add a shared get_arch_for_mixed_gemm_weight_preprocess() helper and route layout selection through it.
  • Force CUDA QMoE INT4/INT8 expert-weight PrePack to emit SM80 layout irrespective of runtime SM.
  • Update Python tests and documentation to pin/describe SM80 layout usage and tighten parity tolerances.

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| onnxruntime/test/python/transformers/test_qmoe_cuda.py | Removes dead mixed-GEMM prepack helper; minor comment text update. |
| onnxruntime/test/python/transformers/test_moe_cuda.py | Pins offline packing to arch=80 and tightens FP16 INT4/INT8 tolerances. |
| onnxruntime/python/onnxruntime_pybind_quant.cc | Replaces ad-hoc SM allowlist with centralized arch-clamping helper. |
| onnxruntime/core/graph/contrib_ops/contrib_defs.cc | Updates QMoE weights_prepacked attribute docs to be EP-determined. |
| onnxruntime/contrib_ops/cuda/moe/moe_quantization.h | Expands/clarifies CUDA EP semantics around weights_prepacked. |
| onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc | Forces SM80 layout packing for int weights; updates comments and gating text. |
| onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h | Declares shared get_arch_for_mixed_gemm_weight_preprocess() helper. |
| onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h | Uses the shared helper to select layout details for transforms. |
| onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu | Implements the shared helper that clamps SM to layout groups. |
| docs/ContribOperators.md | Regenerates schema docs for updated weights_prepacked description. |
| docs/contrib_ops/cuda/moe_qmoe.md | Updates CUDA MoE/QMoE documentation for weights_prepacked and SM80-always rationale. |

Comment thread onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc
Comment thread onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc
Comment thread docs/contrib_ops/cuda/moe_qmoe.md Outdated
Comment thread docs/contrib_ops/cuda/moe_qmoe.md Outdated
@tianleiwu tianleiwu enabled auto-merge (squash) June 10, 2026 22:05
@tianleiwu tianleiwu merged commit 2cf6c6c into main Jun 11, 2026
87 checks passed
@tianleiwu tianleiwu deleted the tlwu/refactor_qmoe_prepack_sm branch June 11, 2026 00:26
4 participants