Fix Shape→Gather→TopK regression: preserve rank-1 single-element index output in data propagation #29084
titaiwangms wants to merge 1 commit into
Pull request overview
This PR fixes a shape-inference data-propagation regression where Shape → Gather([-1]) could incorrectly drop a rank-1 single-element output to a scalar, causing valid models that feed TopK’s K input to fail load-time shape inference. The change updates Gather’s custom propagation to route based on index rank, makes scalar-only elementwise propagation (Add/Sub/Mul/Div) tolerant to rank-1 [1] single-element values, and adds targeted regression tests + shared helpers.
Changes:
- Teach Graph::SaveShapeValuesFromDataPropagation’s initializer-reader to also report initializer rank (num dims), enabling rank-based routing in custom propagation.
- Preserve rank for single-element propagated values via a new shared helper and update Gather/Add/Sub/Mul/Div propagation accordingly.
- Add shape-inference regression tests and new testdata model generators for the affected patterns.
Reviewed changes
Copilot reviewed 23 out of 26 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| onnxruntime/core/graph/graph.cc | Extend initializer-value reader to also report initializer rank (num dims). |
| onnxruntime/core/graph/data_propagation/custom_data_propagation.h | Introduce GetInitializedInputValuesFunc typedef carrying values + rank. |
| onnxruntime/core/graph/data_propagation/custom_data_propagation.cc | Thread the new initializer-reader signature through custom propagation factory. |
| onnxruntime/core/graph/data_propagation/data_propagation_value_utils.h | New shared helpers for reading/writing single-element propagated values while preserving rank. |
| onnxruntime/core/graph/data_propagation/gather_op_data_propagation.{h,cc} | Fix Gather propagation to preserve rank-1 [1] vs scalar based on indices rank. |
| onnxruntime/core/graph/data_propagation/add_op_data_propagation.{h,cc} | Allow propagation through Add when operands are scalar or rank-1 [1]. |
| onnxruntime/core/graph/data_propagation/sub_op_data_propagation.{h,cc} | Allow propagation through Sub when operands are scalar or rank-1 [1]. |
| onnxruntime/core/graph/data_propagation/mul_op_data_propagation.{h,cc} | Allow propagation through Mul when operands are scalar or rank-1 [1]. |
| onnxruntime/core/graph/data_propagation/div_op_data_propagation.{h,cc} | Allow propagation through Div when operands are scalar or rank-1 [1] and guard div-by-zero. |
| onnxruntime/core/graph/data_propagation/squeeze_op_data_propagation.{h,cc} | Update signature usage for initializer-reader (axes). |
| onnxruntime/core/graph/data_propagation/unsqueeze_op_data_propagation.{h,cc} | Update signature usage for initializer-reader (axes). |
| onnxruntime/core/graph/data_propagation/size_op_data_propagation.h | Update signature type in ctor to match base class. |
| onnxruntime/test/framework/shape_inference_test.cc | Add regression + guard tests covering Gather→TopK, Gather→Mul→TopK, and multi-element non-collapse. |
| onnxruntime/test/testdata/test_shape_data_propagation_gather_topk.py | Generator for the core Gather→TopK regression model. |
| onnxruntime/test/testdata/test_shape_data_propagation_gather_mul_topk.py | Generator for the Gather→Mul→TopK chain regression model. |
| onnxruntime/test/testdata/test_shape_data_propagation_shape_mul_constantofshape.py | Generator for multi-element guard model (no scalar collapse). |
Review summary: multi-model review team (readability · code · critical · deep · integration)
Solid, well-contained fix. The shared … Findings below, deduplicated and prioritized. None are hard blockers, but two …
Major
1. Rank ≥ 2 single-element index is collapsed to rank-1 ( …
2. Cross-module ripple into Unsqueeze/Squeeze data propagation ( …
Minor
3. Signed-integer overflow (UB) in Add/Sub/Mul/Div.
4. …
5. Reader/writer channel precedence.
6. Doc/behavior drift on …
7. Test coverage gap. End-to-end rank-preservation is exercised only via the …
8. Readability nits
Praise
Sourcing the index rank from the same …
Reviewed by a 5-model team (Claude Sonnet · GPT-5.3-Codex · GPT-5.5 · Claude Opus · Gemini 3.1 Pro). Findings are advisory.
1f862c2 to 14865d9
Thanks for the thorough review! Pushed an update (now at 14865d9) addressing the feedback. Summary of changes: …
Full …
Intentionally deferred (happy to fold either in if you'd prefer): …
Unrelated CI note: the …
GatherOpDataPropagation::infer() guarded on indices.size() == 1 (element
count) and called SetInferredShapeScalarValue() unconditionally. That guard
is true for BOTH a 0-D scalar index and a 1-D single-element index, so the
1-D case had its rank dropped: Graph::getInputData() then emitted a 0-D
(dimensionless) TensorProto for the propagated value.
For the common Shape -> Gather([-1]) -> TopK exporter pattern this produced a
0-D K initializer, which ONNX TopK shape inference correctly rejects ("K
input must be a one-dimensional tensor of size 1.") at Graph::Resolve time.
The model is spec-valid (a 1-D Gather index yields a rank-1 Gather output),
so this was an ORT rank-preservation bug. It reproduces even at
ORT_DISABLE_ALL, where constant folding never runs, confirming the cause is
shape-inference data propagation rather than constant folding.
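The ambiguity described above can be seen in a minimal standalone sketch. This is not ORT code: `Dims`, `ElementCount`, `RouteByElementCount` and `RouteByRank` are hypothetical names, and the index tensor is modeled only by its dims vector. A 0-D shape and a 1-D `[1]` shape both contain exactly one element, so a guard on element count cannot tell them apart, while a guard on rank can.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <string>
#include <vector>

// Hypothetical stand-in for an index tensor's shape; not an ORT type.
using Dims = std::vector<int64_t>;

// Element count is the product of the dims; the empty product (0-D) is 1.
inline int64_t ElementCount(const Dims& dims) {
  return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// Sketch of the old guard: every single-element index is stored as a scalar,
// so the 1-D [1] case loses its rank.
inline std::string RouteByElementCount(const Dims& index_dims) {
  return ElementCount(index_dims) == 1 ? "scalar" : "decline";
}

// Sketch of rank-based routing: 0-D -> scalar, 1-D single element -> rank-1,
// anything else declines (assumed here to mean "fall back to ONNX").
inline std::string RouteByRank(const Dims& index_dims) {
  if (ElementCount(index_dims) != 1) return "decline";
  if (index_dims.empty()) return "scalar";
  if (index_dims.size() == 1) return "rank1";
  return "decline";  // rank >= 2, e.g. [1, 1]
}
```

With this model, `RouteByElementCount({1})` yields `"scalar"` (the bug), whereas `RouteByRank({1})` yields `"rank1"`.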
Changes:
- Gather: route by the index rank instead of element count. A 0-D scalar
index stores a scalar value; a 1-D single-element index stores a rank-1
value, so getInputData() emits a TensorProto with dims=[1] and downstream
TopK sees a valid 1-D size-1 K; a rank >= 2 index (or an index whose rank is
unknown) declines and falls back to ONNX data propagation, because the
single-value channel cannot faithfully represent a rank >= 2 Gather output.
The decision is a pure classifier (ClassifySingleValueRank). The index rank
is sourced from the same constant initializer the index value comes from
(via get_initialized_input_values now reporting the initializer rank),
instead of a second, independently resolved NodeArg shape -- removing a
potential source-of-truth drift.
- Elementwise consumers (Add/Sub/Mul/Div): previously scalar-only, they would
silently stop propagating once an operand became a rank-1 value (e.g. a
Shape -> Gather(1-D idx) -> Mul -> TopK chain), since the custom propagation
result replaces ONNX's rank-correct fallback. They now accept a single
element carried as either a rank-0 scalar or a rank-1 [1] value and keep the
output rank consistent with ONNX broadcasting (rank-1 if any operand is
rank-1, else scalar), so such chains keep propagating end to end. Div also
guards against division by zero.
- Unsqueeze: decline rather than propagate a dubious value when unsqueezing a
single-element (scalar-like, rank-1 [1]) value, whose result is a rank >= 2
tensor the values channel cannot faithfully represent (it would otherwise
fabricate a misleading [1, value]). Multi-element shape vectors are
unaffected. Squeeze is left unchanged -- it already converts a rank-1 [1]
value to the correct scalar.
- Add shared helpers (data_propagation_value_utils.h) for reading/writing a
single-element shape value while preserving its rank, used by Gather, the
elementwise ops and Unsqueeze so producers and consumers cannot disagree on
rank. The reader declines a rank-1 multi-element value (it must never
collapse to element[0]), so a multi-element value cannot be mistaken for a
single one. The setter is correct-by-construction: it populates exactly one
channel and clears the other, so the scalar-first reader and the values-first
getInputData() can never disagree on rank even if the output carried a stale
value from a prior pass.
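The helper and elementwise rules listed above can be sketched roughly as follows. This is a toy model under stated assumptions, not the code in data_propagation_value_utils.h: `PropagatedValue`, `SetSingle`, `ReadSingle`, `Op` and `Combine` are hypothetical names, and the two "channels" are plain struct fields. It shows a setter that fills exactly one channel, a scalar-first reader that declines multi-element values, and an elementwise step whose output is rank-1 if either operand is rank-1. Signed overflow is deliberately not handled here.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical model of the two channels a propagated shape value can use.
struct PropagatedValue {
  bool has_scalar = false;
  int64_t scalar = 0;            // rank-0 channel
  bool has_values = false;
  std::vector<int64_t> values;   // rank-1 channel
};

// Writer: populate exactly one channel and clear the other, so a stale value
// from an earlier pass cannot contradict the new rank.
inline void SetSingle(PropagatedValue& out, int64_t v, bool is_rank1) {
  out.has_scalar = !is_rank1;
  out.scalar = is_rank1 ? 0 : v;
  out.has_values = is_rank1;
  out.values.clear();
  if (is_rank1) out.values.push_back(v);
}

// Reader: scalar channel first, then a rank-1 value with exactly one element.
// A multi-element value is declined; it never collapses to element[0].
inline bool ReadSingle(const PropagatedValue& in, int64_t& v, bool& is_rank1) {
  if (in.has_scalar) { v = in.scalar; is_rank1 = false; return true; }
  if (in.has_values && in.values.size() == 1) {
    v = in.values[0]; is_rank1 = true; return true;
  }
  return false;
}

enum class Op { kAdd, kSub, kMul, kDiv };

// Elementwise step: returns false to decline. The output is rank-1 if either
// operand is rank-1, otherwise scalar. Overflow is not guarded in this sketch.
inline bool Combine(Op op, const PropagatedValue& a, const PropagatedValue& b,
                    PropagatedValue& out) {
  int64_t x = 0, y = 0;
  bool x_rank1 = false, y_rank1 = false;
  if (!ReadSingle(a, x, x_rank1) || !ReadSingle(b, y, y_rank1)) return false;
  int64_t r = 0;
  switch (op) {
    case Op::kAdd: r = x + y; break;
    case Op::kSub: r = x - y; break;
    case Op::kMul: r = x * y; break;
    case Op::kDiv:
      if (y == 0) return false;  // guard against division by zero
      r = x / y;
      break;
  }
  SetSingle(out, r, x_rank1 || y_rank1);
  return true;
}
```

In this model, multiplying a rank-1 `[5]` by a scalar `2` produces a rank-1 `[10]`, which is the shape a downstream TopK `K` input needs.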
Tests: add ShapeInferenceV2Test.GatherToTopKRankPreservationTest and
GatherMulToTopKRankPreservationTest (with fixtures and generators) that load
the model at the disabled, basic, and all-optimization levels and assert the
rank-1 K is preserved and propagated through the chain, plus
GatherSqueezeRangeRankPreservationTest, an
observable end-to-end lock that a Shape -> Gather([-1]) -> Squeeze -> Range
chain resolves Range's length to the concrete propagated K (locking Squeeze's
correct-scalar behavior through a real downstream consumer). Add unit tests
pinning the pure helpers directly: SinglePropagatedShapeValueGuardTest (reader
behavior on each channel), SetSinglePropagatedShapeValueKeepsSingleChannelTest
(the setter clears the opposite channel), ClassifySingleValueRankRoutingTest
(scalar/rank-1/decline routing) and ShouldDeclineUnsqueezeSingleValueTest (the
Unsqueeze decline decision; a graph-level test is structurally infeasible
because a rank >= 2 result cannot flow to any legal ONNX shape-consumer). Add
ShapeMulMultiElementNoScalarCollapseTest, an end-to-end check that a
Shape -> Mul -> ConstantOfShape multi-element chain still resolves to its full
rank-2 shape. The data-propagation regression tests run at the disabled, basic,
and all-optimization levels (data propagation executes in the pre-optimization
Graph::Resolve pass, so it is independent of the graph-optimization level).
Existing scalar-index data-prop fixtures continue to exercise the scalar path
unchanged.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: titaiwangms <titaiwang@microsoft.com>
14865d9 to 90ff823
Pushed a small follow-up (now at …
Root cause: the failure was an AddressSanitizer allocator footprint effect, not a build break or a test assertion failure. The full ASan test binary exhausted ASan's 8 GB …
Change in this commit: trims the four data-propagation opt-level test loops from five levels back to …
No production code changed from the previous revision. Full …
effective_num_dims = (indices_shape != nullptr) ? indices_shape->dim_size() : -1;
}

switch (ClassifySingleValueRank(effective_num_dims)) {
nit: This enum and helper are only consumed by the switch in gather_op_data_propagation.cc, where it just maps num_dims to a boolean is_rank1 with a decline for ≥ 2. This can be replaced with a simple if/else inline in the Gather propagation:
if (effective_num_dims == 0) {
  SetSinglePropagatedShapeValue(output_def_, dim.dim_value(), /*is_rank1=*/false);
} else if (effective_num_dims == 1) {
  SetSinglePropagatedShapeValue(output_def_, dim.dim_value(), /*is_rank1=*/true);
}
// else (rank >= 2 or unknown): decline
It makes the code much simpler.
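To see that the two forms discussed in this comment make the same decision, here is a small standalone sketch. Only the name `ClassifySingleValueRank` comes from the PR; its body, the `SingleValueRank` enumerators, `RouteInline` and `SameDecision` are assumptions made for illustration, with `-1` standing for an unknown rank as in the quoted diff.

```cpp
#include <cassert>

// Hypothetical enumerators; the PR's actual enum is not shown on this page.
enum class SingleValueRank { kScalar, kRank1, kDecline };

// Assumed behavior: 0 -> scalar, 1 -> rank-1, anything else
// (rank >= 2, or -1 for unknown) -> decline.
inline SingleValueRank ClassifySingleValueRank(int num_dims) {
  switch (num_dims) {
    case 0: return SingleValueRank::kScalar;
    case 1: return SingleValueRank::kRank1;
    default: return SingleValueRank::kDecline;
  }
}

// The inline if/else alternative, reduced to its decision: returns false to
// decline, otherwise reports whether the value should be stored as rank-1.
inline bool RouteInline(int num_dims, bool& is_rank1) {
  if (num_dims == 0) { is_rank1 = false; return true; }
  if (num_dims == 1) { is_rank1 = true; return true; }
  return false;
}

// True when both forms agree for the given num_dims.
inline bool SameDecision(int num_dims) {
  bool is_rank1 = false;
  const bool propagate = RouteInline(num_dims, is_rank1);
  switch (ClassifySingleValueRank(num_dims)) {
    case SingleValueRank::kScalar: return propagate && !is_rank1;
    case SingleValueRank::kRank1: return propagate && is_rank1;
    case SingleValueRank::kDecline: return !propagate;
  }
  return false;
}
```

Under these assumptions the choice between the enum classifier and the inline branch is one of readability and testability rather than behavior.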
// predicate ShouldDeclineUnsqueezeSingleValue (see ShouldDeclineUnsqueezeSingleValueTest);
// full-suite-green is only the no-regression backstop. Multi-element shape vectors are
// unaffected.
if (ShouldDeclineUnsqueezeSingleValue(tensor_shape_proto.dim_size())) {
nit: this is Unsqueeze-specific logic ("should Unsqueeze decline propagating this value?") and is only called from unsqueeze_op_data_propagation.cc. It belongs alongside the Unsqueeze implementation, not in the shared utils.
Summary
A spec-valid Shape → Gather(1-D index [-1]) → TopK model fails to load since ORT 1.25.0 with: …
The model is valid: a rank-1 (single-element) Gather index produces a rank-1 Gather output, so the value feeding TopK's K input is a 1-D size-1 tensor, exactly what TopK requires. The failure was an ORT rank-preservation bug in shape-inference data propagation, not a problem with the model.
Root cause. GatherOpDataPropagation::infer() routed by element count rather than index rank: it guarded on indices.size() == 1, which is true for both a 0-D scalar index and a 1-D single-element index, and then unconditionally called SetInferredShapeScalarValue(). That dropped the rank of the spec-valid 1-D size-1 case, so Graph::getInputData() emitted a 0-D (dimensionless) propagated value. ONNX TopK shape inference then correctly rejected the 0-D K. This path was introduced by #26269 (partial data propagation to enhance shape inference).
This reproduces even at GraphOptimizationLevel.ORT_DISABLE_ALL, where constant folding never runs, confirming the cause is data propagation in shape inference, not constant folding (#26345 was an earlier mis-attribution; see the corrected analysis).
Fixes the regression reported in #29072. Corrected root-cause analysis: #29072 (comment)
The fix
- … getInputData() emits a TensorProto with dims=[1] and downstream TopK sees a valid 1-D size-1 K. The index rank is taken from the same constant initializer the index value comes from (via get_initialized_input_values now reporting the initializer rank), rather than a second, independently-resolved NodeArg shape, removing a potential source-of-truth drift (EDGE #2).
- … Shape → Gather(1-D idx) → Mul → TopK chain), because the custom-propagation result replaces ONNX's rank-correct fallback. They now accept a single element carried as either a rank-0 scalar or a rank-1 [1] value and keep the output rank consistent with ONNX broadcasting (rank-1 if any operand is rank-1, else scalar), so such chains keep propagating end-to-end. Div additionally guards against division by zero.
- … data_propagation_value_utils.h). Centralizes reading/writing a single-element shape value while preserving its rank, used by both the Gather producer and the elementwise consumers so they cannot disagree on rank. The reader declines a rank-1 multi-element value (it must never collapse to element[0]), so a multi-element value can never be mistaken for a single one.
Testing
Five ShapeInferenceV2Test cases (with fixtures + generators), all loading the model at every optimization level (including ORT_DISABLE_ALL):
- GatherToTopKRankPreservationTest: the core Shape → Gather([-1]) → TopK regression; asserts the rank-1 K is preserved.
- GatherMulToTopKRankPreservationTest: the … → Gather(1-D idx) → Mul → TopK chain; asserts propagation survives the elementwise op.
- SinglePropagatedShapeValueGuardTest: a direct unit test pinning the shared reader's behavior on each channel (scalar, rank-1 single-element, rank-1 multi-element, symbolic, empty). Mutation-proven: relaxing the dim_size()==1 guard makes this test fail, restoring it makes it pass, so the guard the whole fix hinges on is test-locked.
- ShapeMulMultiElementNoScalarCollapseTest: end-to-end check that a multi-element Shape → Mul → ConstantOfShape chain still resolves to its full rank-2 shape (no bogus scalar collapse).
- PartialDataPropagationTest: pre-existing scalar-index coverage, unchanged.
Full onnxruntime_test_all suite passes (0 failures) on top of the current main (opset-27 / ONNX 1.22.0 integration). The constant-folding memory path (#26345) is untouched: the diff is confined to data_propagation/, a small graph.cc change, and tests.
Follow-ups (intentionally out of scope for this PR)
- … [1,1]) to decline rather than route as rank-1: needs its own discriminating unit test; pathological/non-exporter, worst case is degraded inference rather than a crash.
DCO
Commit is DCO signed-off.