Fix QMoE INT GEMV NaN and harden against partial K tiles #29067
Open
tianleiwu wants to merge 3 commits into
PR: Fix QMoE INT GEMV NaN and harden against partial K tiles
Description
The symmetric INT4/INT8 MoE decode GEMV could emit `NaN`/garbage when a GEMM
reduction dimension was not a whole multiple of the 64-element interleaved-weight
K tile (e.g. `intermediate_size` such as 544). The interleaved weight layout's
CUTLASS K iterator reads K in whole tiles of 64; a partial final tile makes
threads read past the valid activation range. This PR fixes the decode GEMV
selection gate to reject such shapes, adds an explicit up-front validation in the
QMoE op so the grouped GEMM path fails with a clear error instead of silently
producing wrong results, and folds in several QMoE review-feedback cleanups
(checked size arithmetic, env-var parsing, and documentation).
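
The arithmetic behind the failure can be sketched in a few lines. This is a plain Python model, not the CUDA code: `K_TILE` and `partial_tile_overread` are illustrative names, and the model assumes only what the description states, namely that K is consumed in whole 64-element tiles.

```python
K_TILE = 64  # interleaved-weight K tile size stated in this PR


def partial_tile_overread(k: int, k_tile: int = K_TILE) -> int:
    """Return how many elements past K a whole-tile reader would touch.

    Hypothetical helper for illustration; it is not part of onnxruntime.
    """
    if k <= 0:
        raise ValueError("k must be positive")
    tiles = -(-k // k_tile)  # ceiling division: tiles needed to cover K
    return tiles * k_tile - k


# Columns: K, K % 64, elements read beyond K under the whole-tile assumption.
for k in (512, 544, 576):
    print(k, k % K_TILE, partial_tile_overread(k))
```

Under this model, K = 544 leaves a remainder of 32, so the final tile is only half valid, while 512 and 576 divide evenly and read nothing beyond K.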
Summary of Changes
NaN fix and hardening (INT weight-only path)
- `onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu`: `is_moe_gemv_supported` now rejects `k % kTileSizeK != 0` (64), so the decode GEMV is not selected for a partial final K tile and the path falls back to the grouped GEMM.
- `onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc`: for `quant_type == "int"`, `hidden_size` (fc1.K) and `inter_size` (fc2.K) must be multiples of 64 (the interleaved-weight K tile); otherwise return `INVALID_ARGUMENT` with a clear message instead of computing garbage.

Review-feedback cleanups
- `onnxruntime/contrib_ops/cuda/moe/moe.cc`: `SafeInt<size_t>` for scratch byte-count arithmetic (expanded rows × element sizes) feeding a single allocation.
- `onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc`: `SafeInt<size_t>` scratch-size hardening.
- `onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu`: `ORT_MOE_GEMV_FP16_ACCUM` via `ParseEnvironmentVariableWithDefault<int>`; include `env_var_utils.h` after `dispatcher.h` (SHARED_PROVIDER guard ordering, documented inline).
- `onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu`: `ORT_DISABLE_MOE_GEMV` via the same helper; clarify fast-path comments (symmetric INT4/INT8, per-column or block-wise).
- `onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h`: `kFinegrainedScaleRowsPerStage` for the smallest fine-grained group size.

Documentation
- `docs/contrib_ops/cuda/moe_qmoe.md`: `swiglu_fusion=0` + SwiGLU backward-compatibility remap (gpt-oss-20b interleaved layout) and the one-time warning.
- `docs/contrib_ops/cuda/qmoe_gemv_experiments.md`

Testing
- `python -m pytest onnxruntime/test/python/transformers/test_qmoe_cuda.py -k "gemv or swiglu or block or PrePack or prepack"`: 84 passed, 6 skipped on H200 (sm90).
- `TestSwigluQMoE::test_swiglu_qmoe_int_partial_ktile_rejected` builds an `inter_size=544` (= 17×32, partial 64 tile) INT8 SwiGLU QMoE and asserts the run raises `"inter_size to be a multiple of 64"`.
- `TestSwigluQMoE::test_swiglu_qmoe_fusion0_remap_parity` exercises the `swiglu_fusion=0` → interleaved remap parity.
- `TestQMoEIntPrePackSmoke::test_int4_swiglu_interleaved_small` bumped from `inter_size=32` (a now-rejected partial K tile) to `64`.
- `ORT_ENABLE_FP4_GEMV=1 python -m pytest onnxruntime/test/python/transformers/test_qmoe_fp4_cuda.py`: no failures (the new guard is scoped to `quant_type == "int"`, so FP4/FP8 are unaffected).
- `lintrunner` clean on the changed C++ and Python files.

Motivation and Context
The interleaved column-major weight layout (`ColumnMajorTileInterleave<64, …>`)
requires the GEMM reduction dim K to be a whole multiple of `ThreadblockK`
(`64` for fp16/bf16 activations). The single-matrix `fpA_intB` GEMM already
throws on this, but the grouped MoE GEMM and the decode GEMV had no equivalent
guard and silently produced `NaN`/garbage. This PR closes that gap at the QMoE
boundary (clear error) and in the GEMV dispatch gate (safe fallback). No
supported, 64-aligned shape changes behavior.
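
The two guards play different roles, which a small Python model may make easier to see. Everything here is a sketch under stated assumptions: `gemv_supported`, `pick_kernel`, and `validate_int_qmoe` are hypothetical stand-ins for the C++ gate and the op-level check, `ValueError` stands in for the `INVALID_ARGUMENT` status, and the error wording is illustrative apart from the substring the new test asserts on.

```python
K_TILE = 64  # ThreadblockK for fp16/bf16 activations, per this PR


def gemv_supported(k: int) -> bool:
    """Dispatch gate: only a whole-tile K may take the decode GEMV path."""
    return k % K_TILE == 0


def pick_kernel(k: int) -> str:
    """Safe fallback: unsupported shapes go to the grouped GEMM instead."""
    return "gemv" if gemv_supported(k) else "grouped_gemm"


def validate_int_qmoe(hidden_size: int, inter_size: int) -> None:
    """Op-boundary check for quant_type == "int": fail before any kernel runs."""
    for name, value in (("hidden_size", hidden_size), ("inter_size", inter_size)):
        if value % K_TILE != 0:
            # Message text is an assumption, not the actual onnxruntime string.
            raise ValueError(
                f"QMoE int path expects {name} to be a multiple of {K_TILE}, got {value}"
            )
```

In this model the gate alone would only reroute K = 544 to the grouped GEMM, which has the same tile requirement; the boundary check is what turns that shape into an explicit error.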
Checklist