QMoE: fail loudly when weights_prepacked=0 but PrePack did not run #28965

Merged
justinchuby merged 4 commits into main from copilot/fix-silent-wrong-output-with-prepacked-weights on Jun 11, 2026

Conversation

Copilot AI commented Jun 9, 2026

Contributor

Description

When a QMoE model sets weights_prepacked=0 (raw [E, N, K/pack] int weights) and the session sets session.disable_prepacking, PrePack() never runs, so packed_fc{1,2}_weights_ stay null and int_weights_consumed_by_prepack is false. The code then falls through to the raw initializer pointers. Those bytes are not in CUTLASS layout, so the runner consumes them as if they were prepacked and silently produces wrong output with no diagnostic.

Changes in onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc (QMoE::ComputeInternal):

  • Int path: Added a defensive INVALID_ARGUMENT guard — when is_int && !weights_prepacked_ but either prepack buffer is null, return a clear error instead of feeding non-CUTLASS bytes to the runner.
  • wfp4afp8 native path: Same fall-through (packed_fp4_fc{1,2}_weights_ ? ... : raw) replaced with an explicit guard that errors when the repacked FP4 buffers were not produced.
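
For readers who want to see the shape of the int-path guard, here is a minimal, self-contained sketch. It is not the code from moe_quantization.cc: `Status`, `QMoEState`, `CheckIntPrepackRan`, and the error text are hypothetical stand-ins chosen for illustration, and only the condition itself follows the description above.

```cpp
#include <cstdint>
#include <string>

// Hypothetical stand-in for onnxruntime's Status type; only what the sketch needs.
struct Status {
  bool ok;
  std::string message;
};

// Hypothetical, simplified view of the state the QMoE kernel consults.
struct QMoEState {
  bool is_int = false;                  // quant_type == 'int'
  bool weights_prepacked = true;        // value of the weights_prepacked attribute
  const uint8_t* packed_fc1 = nullptr;  // would be set by PrePack() when it runs
  const uint8_t* packed_fc2 = nullptr;  // would be set by PrePack() when it runs
};

// Returns an error instead of letting raw [E, N, K/pack] bytes reach the runner.
// The message text is illustrative, not the string used in the actual PR.
Status CheckIntPrepackRan(const QMoEState& s) {
  if (s.is_int && !s.weights_prepacked &&
      (s.packed_fc1 == nullptr || s.packed_fc2 == nullptr)) {
    return {false,
            "QMoE: weights_prepacked=0 requires PrePack(), but the prepacked "
            "buffers are missing (is session.disable_prepacking set?)"};
  }
  return {true, ""};
}
```

A caller would run this check before selecting weight pointers, so the raw initializer pointers are never used on the weights_prepacked=0 path.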

Also added a focused regression test in onnxruntime/test/contrib_ops/moe_test.cc covering quant_type='int' with weights_prepacked=0 and session.disable_prepacking=1, asserting that QMoE fails with an actionable error instead of producing output.

Merged the branch with the latest main.

Motivation and Context

A prior fix removed the null-pointer crash on this path but left a misleading-success outcome that is newly user-reachable via the weights_prepacked=0 contract — the exact silent-failure mode the offline-path work set out to eliminate. These guards convert that into a loud, actionable error. The wfp4afp8 branch shares the same fall-through and is hardened for consistency.

The added regression test ensures this fail-loudly behavior remains covered going forward, especially when prepacking is disabled at the session level.

Copilot AI changed the title from "[WIP] Fix silent wrong output when weights_prepacked is set to 0" to "QMoE: fail loudly when weights_prepacked=0 but PrePack did not run" Jun 9, 2026
Copilot AI requested a review from justinchuby June 9, 2026 20:45
justinchuby requested a review from Copilot June 9, 2026 21:51
justinchuby marked this pull request as ready for review June 9, 2026 21:51

Copilot AI left a comment

Contributor

Pull request overview

This PR hardens the CUDA QMoE execution path to avoid silent wrong outputs when models rely on PrePack()-produced CUTLASS layouts (int weights_prepacked=0, and native wfp4afp8) but prepacking never ran (e.g., session.disable_prepacking).

Changes:

  • Add an INVALID_ARGUMENT guard for quant_type='int' && weights_prepacked=0 when the required prepacked int-weight buffers are missing.
  • Add an INVALID_ARGUMENT guard for native wfp4afp8 when the repacked FP4 weight buffers are missing, instead of falling back to raw initializer bytes.

Comment thread onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc

tianleiwu left a comment

Contributor

Verdict: Approve.

This is a clean, well-scoped defensive fix that converts a silent wrong-output path into a loud, actionable error.

Correctness

  • The int guard is placed right after int_weights_consumed_by_prepack is computed and fires exactly when is_int && !weights_prepacked_ but one of the prepack buffers is null — i.e. the negation of int_weights_consumed_by_prepack within that subset. After the guard, the weights_prepacked=0 int path is guaranteed to have both buffers, which also closes the partial-prepack concern noted in the surrounding comment.
  • The wfp4afp8 native guard is correct: PrePackRepackFP4Weights is the only producer of packed_fp4_fc{1,2}_weights_, and the weights_prepacked attribute is int-only, so a null repacked buffer on the native path can only mean PrePack did not run. The previous fall-through to raw initializer bytes was genuinely unsafe.
  • The is_packed handling on those paths keeps the source initializer alive, so failing loudly here does not risk dangling pointers.
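
The "negation within that subset" claim in the first bullet can be checked exhaustively. The sketch below assumes that int_weights_consumed_by_prepack is equivalent to `is_int && !weights_prepacked && fc1 != null && fc2 != null`; that definition is inferred from the review text rather than copied from the source, and the function names are hypothetical.

```cpp
// Assumed definition of int_weights_consumed_by_prepack (inferred, not from the source).
// fc1/fc2 are true when the corresponding prepack buffer is non-null.
constexpr bool ConsumedByPrepack(bool is_int, bool prepacked, bool fc1, bool fc2) {
  return is_int && !prepacked && fc1 && fc2;
}

// Guard condition as described in the PR: int path, weights_prepacked=0,
// and at least one prepack buffer missing.
constexpr bool GuardFires(bool is_int, bool prepacked, bool fc1, bool fc2) {
  return is_int && !prepacked && (!fc1 || !fc2);
}

// Enumerates all 16 input combinations. Inside the subset (is_int && !prepacked)
// the guard must equal !ConsumedByPrepack; outside it the guard must never fire.
constexpr bool GuardIsNegationWithinSubset() {
  for (int bits = 0; bits < 16; ++bits) {
    const bool is_int = (bits & 1) != 0;
    const bool prepacked = (bits & 2) != 0;
    const bool fc1 = (bits & 4) != 0;
    const bool fc2 = (bits & 8) != 0;
    const bool in_subset = is_int && !prepacked;
    const bool guard = GuardFires(is_int, prepacked, fc1, fc2);
    const bool consumed = ConsumedByPrepack(is_int, prepacked, fc1, fc2);
    if (in_subset && guard == consumed) return false;
    if (!in_subset && guard) return false;
  }
  return true;
}

// Checked at compile time (requires C++14 or later for the constexpr loop).
static_assert(GuardIsNegationWithinSubset(),
              "guard must be the negation of consumed-by-prepack on the int path");
```

Under that assumption, every weights_prepacked=0 int configuration either passes with both buffers present or is rejected by the guard, which is why the partial-prepack case (only one buffer set) is covered as well.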

Tests

  • The new QMoETest_CUDA_Int4_DisablePrepackingFailsLoudly exercises the int branch via session.disable_prepacking=1 and asserts the actionable failure message. Because the guard returns before any kernel launch, it runs on any SM ≥ 700 and doesn't depend on CUTLASS shape constraints.

One optional maintainability suggestion left inline.

Comment thread onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc
justinchuby merged commit cf509d8 into main Jun 11, 2026
84 of 85 checks passed
justinchuby deleted the copilot/fix-silent-wrong-output-with-prepacked-weights branch June 11, 2026 18:40
Development

Successfully merging this pull request may close these issues.

QMoE: When a model sets weights_prepacked=0 runner can consume weights as-if-prepacked and produces silently wrong output with no diagnostic

4 participants