Fix QMoE INT GEMV NaN and harden against partial K tiles #29067

Open
tianleiwu wants to merge 3 commits into main from tlwu/20260615/address_qmoe_feedback

PR: Fix QMoE INT GEMV NaN and harden against partial K tiles

Description

The symmetric INT4/INT8 MoE decode GEMV could emit NaN/garbage when a GEMM
reduction dimension was not a whole multiple of the 64-element interleaved-weight
K tile (e.g. intermediate_size such as 544). The interleaved weight layout's
CUTLASS K iterator reads K in whole tiles of 64; a partial final tile makes
threads read past the valid activation range. This PR fixes the decode GEMV
selection gate to reject such shapes, adds an explicit up-front validation in the
QMoE op so the grouped GEMM path fails with a clear error instead of silently
producing wrong results, and folds in several QMoE review-feedback cleanups
(checked size arithmetic, env-var parsing, and documentation).
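
The tile arithmetic behind the failure can be sketched in a few lines. This is an illustrative Python model, not code from the PR: `tile_overread` is a hypothetical helper, and the only value taken from the description is the 64-element K tile.

```python
K_TILE = 64  # interleaved-weight K tile size stated in the description


def tile_overread(k: int, tile: int = K_TILE) -> int:
    """Return how many elements past K a whole-tile reader would touch.

    Hypothetical helper for illustration; it models a reader that always
    consumes complete tiles, so a partial final tile is still read in full.
    """
    if k <= 0:
        raise ValueError("k must be positive")
    tiles_read = -(-k // tile)  # ceiling division
    return tiles_read * tile - k


for k in (512, 544, 576):
    print(k, tile_overread(k))
# 512 0
# 544 32
# 576 0
```

For `k = 544` the reader covers nine tiles (576 elements), so 32 elements fall outside the valid range, which is consistent with the partial-tile over-read described above.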

Summary of Changes

NaN fix and hardening (INT weight-only path)

| File | Change |
| --- | --- |
| onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu | is_moe_gemv_supported now rejects shapes where k % kTileSizeK != 0 (kTileSizeK = 64), so the decode GEMV is not selected for a partial final K tile and the path falls back to the grouped GEMM. |
| onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc | Added an up-front guard for quant_type == "int": hidden_size (fc1.K) and inter_size (fc2.K) must be multiples of 64 (the interleaved-weight K tile); otherwise the op returns INVALID_ARGUMENT with a clear message instead of computing garbage. |
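
A minimal Python sketch of the two guards follows, under stated assumptions. The real checks live in C++/CUDA; only the name `is_moe_gemv_supported`, the 64-element tile, the `quant_type == "int"` scoping, and the message fragment "inter_size to be a multiple of 64" come from this PR. `validate_int_qmoe_shapes`, the rest of the message wording, and the use of `ValueError` in place of INVALID_ARGUMENT are hypothetical, and the real gate likely checks other conditions that are omitted here.

```python
K_TILE = 64  # interleaved-weight K tile (kTileSizeK in the PR)


def is_moe_gemv_supported(k: int) -> bool:
    """Dispatch gate: allow the decode GEMV only for whole K tiles.

    Other eligibility conditions of the real gate are omitted (assumption).
    """
    return k > 0 and k % K_TILE == 0


def validate_int_qmoe_shapes(quant_type: str, hidden_size: int, inter_size: int) -> None:
    """Up-front guard: reject partial K tiles on the INT weight-only path."""
    if quant_type != "int":
        return  # the guard is scoped to the INT path; other types are untouched
    for name, k in (("hidden_size", hidden_size), ("inter_size", inter_size)):
        if k % K_TILE != 0:
            # ValueError stands in for the INVALID_ARGUMENT status.
            raise ValueError(
                f"QMoE int path expects {name} to be a multiple of {K_TILE}, got {k}"
            )


print(is_moe_gemv_supported(544))  # False: the gate falls back to grouped GEMM
validate_int_qmoe_shapes("int", 2880, 2880)  # 2880 = 45 * 64, accepted
```

Keeping both checks gives defense in depth: the op-level guard reports a clear error, while the gate keeps the GEMV from being selected if a caller reaches the dispatch with an unaligned K some other way.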

Review-feedback cleanups

| File | Change |
| --- | --- |
| onnxruntime/contrib_ops/cuda/moe/moe.cc | Use SafeInt<size_t> for scratch byte-count arithmetic (expanded rows × element sizes) feeding a single allocation. |
| onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc | Same SafeInt<size_t> scratch-size hardening. |
| onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu | Parse ORT_MOE_GEMV_FP16_ACCUM via ParseEnvironmentVariableWithDefault<int>; include env_var_utils.h after dispatcher.h (SHARED_PROVIDER guard ordering, documented inline). |
| onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu | Parse ORT_DISABLE_MOE_GEMV via the same helper; clarify fast-path comments (symmetric INT4/INT8, per-column or block-wise). |
| onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h | Comment explaining the worst-case sizing of kFinegrainedScaleRowsPerStage for the smallest fine-grained group size. |
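
The idea behind the checked scratch-size arithmetic can be modeled as below. This is a Python stand-in, not the PR's C++: Python integers never overflow, so the sketch compares against an assumed 64-bit `size_t` limit explicitly. `checked_mul`, `scratch_bytes`, and the `row_width` parameter are hypothetical names chosen for illustration.

```python
SIZE_MAX = 2**64 - 1  # assumes a 64-bit size_t


def checked_mul(a: int, b: int) -> int:
    """Multiply two sizes, failing loudly instead of wrapping around."""
    if a < 0 or b < 0:
        raise OverflowError("negative operand cannot be represented as size_t")
    result = a * b
    if result > SIZE_MAX:
        raise OverflowError(f"{a} * {b} exceeds size_t")
    return result


def scratch_bytes(expanded_rows: int, row_width: int, element_size: int) -> int:
    """Byte count for one scratch buffer, with every product checked."""
    return checked_mul(checked_mul(expanded_rows, row_width), element_size)


print(scratch_bytes(8, 544, 2))  # 8704
```

The point of checking each product is that a wrapped byte count would yield an allocation smaller than the kernel expects, so failing before the allocation is the safer behavior.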

Documentation

| File | Change |
| --- | --- |
| docs/contrib_ops/cuda/moe_qmoe.md | Document the swiglu_fusion=0 + SwiGLU backward-compatibility remap (gpt-oss-20b interleaved layout) and the one-time warning. |
| docs/contrib_ops/cuda/qmoe_gemv_experiments.md | Note that recorded numbers are point-in-time baselines tied to the listed GPU/driver/CUDA/ORT build. |

Testing

  • python -m pytest onnxruntime/test/python/transformers/test_qmoe_cuda.py -k "gemv or swiglu or block or PrePack or prepack" — 84 passed, 6 skipped on H200 (sm90).
  • New TestSwigluQMoE::test_swiglu_qmoe_int_partial_ktile_rejected builds an INT8 SwiGLU QMoE with inter_size=544 (= 17×32, so the final 64-element K tile is only half full) and asserts that the run raises an error containing "inter_size to be a multiple of 64".
  • New TestSwigluQMoE::test_swiglu_qmoe_fusion0_remap_parity exercises the swiglu_fusion=0 → interleaved remap parity.
  • TestQMoEIntPrePackSmoke::test_int4_swiglu_interleaved_small bumped from inter_size=32 (a now-rejected partial K tile) to 64.
  • ORT_ENABLE_FP4_GEMV=1 python -m pytest onnxruntime/test/python/transformers/test_qmoe_fp4_cuda.py — no failures (the new guard is scoped to quant_type == "int", so FP4/FP8 are unaffected).
  • lintrunner clean on the changed C++ and Python files.

Motivation and Context

The interleaved column-major weight layout (ColumnMajorTileInterleave<64, …>)
requires the GEMM reduction dim K to be a whole multiple of ThreadblockK (64
for fp16/bf16 activations). The single-matrix fpA_intB GEMM already throws on
this, but the grouped MoE GEMM and the decode GEMV had no equivalent guard and
silently produced NaN/garbage. This PR closes that gap at the QMoE boundary
(clear error) and in the GEMV dispatch gate (safe fallback). No supported,
64-aligned shape changes behavior.

Checklist

  • Tests added/updated
  • Documentation updated (if applicable)
  • No breaking changes (rejected shapes were already producing incorrect output)
  • CI passes
