Skip to content

Fix Shape→Gather→TopK regression: preserve rank-1 single-element index output in data propagation#29084

Open
titaiwangms wants to merge 1 commit into
mainfrom
fix/topk-gather-dataprop-rank
Open

Fix Shape→Gather→TopK regression: preserve rank-1 single-element index output in data propagation#29084
titaiwangms wants to merge 1 commit into
mainfrom
fix/topk-gather-dataprop-rank

Conversation

@titaiwangms

@titaiwangms titaiwangms commented Jun 16, 2026

Copy link
Copy Markdown
Contributor

Summary

A spec-valid Shape → Gather(1-D index [-1]) → TopK model fails to load since ORT 1.25.0 with:

K input must be a one-dimensional tensor of size 1.

The model is valid: a rank-1 (single-element) Gather index produces a rank-1 Gather output, so the value feeding TopK's K input is a 1-D size-1 tensor — exactly what TopK requires. The failure was an ORT rank-preservation bug in shape-inference data propagation, not a problem with the model.

Root cause. GatherOpDataPropagation::infer() routed by element count rather than index rank: it guarded on indices.size() == 1, which is true for both a 0-D scalar index and a 1-D single-element index, and then unconditionally called SetInferredShapeScalarValue(). That dropped the rank of the spec-valid 1-D size-1 case, so Graph::getInputData() emitted a 0-D (dimensionless) propagated value. ONNX TopK shape inference then correctly rejected the 0-D K. This path was introduced by #26269 (partial data propagation to enhance shape inference).

This reproduces even at GraphOptimizationLevel.ORT_DISABLE_ALL, where constant folding never runs — confirming the cause is data propagation in shape inference, not constant folding (#26345 was an earlier mis-attribution; see the corrected analysis).

Fixes the regression reported in #29072. Corrected root-cause analysis: #29072 (comment)

The fix

  • Gather — rank-based routing. Distinguish the index rank instead of its element count. A genuine 0-D scalar index still stores a scalar value; a rank-1 single-element index now stores a rank-1 value, so getInputData() emits a TensorProto with dims=[1] and downstream TopK sees a valid 1-D size-1 K. The index rank is taken from the same constant initializer the index value comes from (via get_initialized_input_values now reporting the initializer rank), rather than a second, independently-resolved NodeArg shape — removing a potential source-of-truth drift (EDGE Remove vsts test runner in cmake file #2).
  • Rank-tolerant elementwise companion (Add/Sub/Mul/Div). These ops were scalar-only and would silently stop propagating once an operand became a rank-1 value (e.g. a Shape → Gather(1-D idx) → Mul → TopK chain), because the custom-propagation result replaces ONNX's rank-correct fallback. They now accept a single element carried as either a rank-0 scalar or a rank-1 [1] value and keep the output rank consistent with ONNX broadcasting (rank-1 if any operand is rank-1, else scalar), so such chains keep propagating end-to-end. Div additionally guards against division by zero.
  • Shared helper (data_propagation_value_utils.h). Centralizes reading/writing a single-element shape value while preserving its rank, used by both the Gather producer and the elementwise consumers so they cannot disagree on rank. The reader declines a rank-1 multi-element value (it must never collapse to element[0]), so a multi-element value can never be mistaken for a single one.

Testing

Five ShapeInferenceV2Test cases (with fixtures + generators), all loading the model at every optimization level (including ORT_DISABLE_ALL):

  • GatherToTopKRankPreservationTest — the core Shape → Gather([-1]) → TopK regression; asserts the rank-1 K is preserved.
  • GatherMulToTopKRankPreservationTest — the … → Gather(1-D idx) → Mul → TopK chain; asserts propagation survives the elementwise op.
  • SinglePropagatedShapeValueGuardTest — a direct unit test pinning the shared reader's behavior on each channel (scalar, rank-1 single-element, rank-1 multi-element, symbolic, empty). Mutation-proven: relaxing the dim_size()==1 guard makes this test fail, restoring it makes it pass — so the guard the whole fix hinges on is test-locked.
  • ShapeMulMultiElementNoScalarCollapseTest — end-to-end check that a multi-element Shape → Mul → ConstantOfShape chain still resolves to its full rank-2 shape (no bogus scalar collapse).
  • PartialDataPropagationTest — pre-existing scalar-index coverage, unchanged.

Full onnxruntime_test_all suite passes (0 failures) on top of the current main (opset-27 / ONNX 1.22.0 integration). The constant-folding memory path (#26345) is untouched — the diff is confined to data_propagation/, a small graph.cc change, and tests.

Follow-ups (intentionally out of scope for this PR)

  • Hardening for a rank ≥ 2 single-element index (e.g. shape [1,1]) to decline rather than route as rank-1 — needs its own discriminating unit test; pathological/non-exporter, worst case is degraded inference rather than a crash.
  • Explicit end-to-end coverage for Add/Sub/Div rank-1 chains (the shared-reader unit test already covers the read path for all four ops; only Mul is currently exercised end-to-end).
  • Minor readability nits.

DCO

Commit is DCO signed-off.

@titaiwangms titaiwangms added regression issues that demonstrate a regression in ORT functionality and need to be addressed immediately and removed regression issues that demonstrate a regression in ORT functionality and need to be addressed immediately labels Jun 16, 2026

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes a shape-inference data-propagation regression where Shape → Gather([-1]) could incorrectly drop a rank-1 single-element output to a scalar, causing valid models that feed TopK’s K input to fail load-time shape inference. The change updates Gather’s custom propagation to route based on index rank, makes scalar-only elementwise propagation (Add/Sub/Mul/Div) tolerant to rank-1 [1] single-element values, and adds targeted regression tests + shared helpers.

Changes:

  • Teach Graph::SaveShapeValuesFromDataPropagation’s initializer-reader to also report initializer rank (num dims), enabling rank-based routing in custom propagation.
  • Preserve rank for single-element propagated values via a new shared helper and update Gather/Add/Sub/Mul/Div propagation accordingly.
  • Add shape-inference regression tests and new testdata model generators for the affected patterns.

Reviewed changes

Copilot reviewed 23 out of 26 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
onnxruntime/core/graph/graph.cc Extend initializer-value reader to also report initializer rank (num dims).
onnxruntime/core/graph/data_propagation/custom_data_propagation.h Introduce GetInitializedInputValuesFunc typedef carrying values + rank.
onnxruntime/core/graph/data_propagation/custom_data_propagation.cc Thread the new initializer-reader signature through custom propagation factory.
onnxruntime/core/graph/data_propagation/data_propagation_value_utils.h New shared helpers for reading/writing single-element propagated values while preserving rank.
onnxruntime/core/graph/data_propagation/gather_op_data_propagation.{h,cc} Fix Gather propagation to preserve rank-1 [1] vs scalar based on indices rank.
onnxruntime/core/graph/data_propagation/add_op_data_propagation.{h,cc} Allow propagation through Add when operands are scalar or rank-1 [1].
onnxruntime/core/graph/data_propagation/sub_op_data_propagation.{h,cc} Allow propagation through Sub when operands are scalar or rank-1 [1].
onnxruntime/core/graph/data_propagation/mul_op_data_propagation.{h,cc} Allow propagation through Mul when operands are scalar or rank-1 [1].
onnxruntime/core/graph/data_propagation/div_op_data_propagation.{h,cc} Allow propagation through Div when operands are scalar or rank-1 [1] and guard div-by-zero.
onnxruntime/core/graph/data_propagation/squeeze_op_data_propagation.{h,cc} Update signature usage for initializer-reader (axes).
onnxruntime/core/graph/data_propagation/unsqueeze_op_data_propagation.{h,cc} Update signature usage for initializer-reader (axes).
onnxruntime/core/graph/data_propagation/size_op_data_propagation.h Update signature type in ctor to match base class.
onnxruntime/test/framework/shape_inference_test.cc Add regression + guard tests covering Gather→TopK, Gather→Mul→TopK, and multi-element non-collapse.
onnxruntime/test/testdata/test_shape_data_propagation_gather_topk.py Generator for the core Gather→TopK regression model.
onnxruntime/test/testdata/test_shape_data_propagation_gather_mul_topk.py Generator for the Gather→Mul→TopK chain regression model.
onnxruntime/test/testdata/test_shape_data_propagation_shape_mul_constantofshape.py Generator for multi-element guard model (no scalar collapse).

Comment thread onnxruntime/core/graph/data_propagation/gather_op_data_propagation.cc Outdated
Comment thread onnxruntime/core/graph/data_propagation/squeeze_op_data_propagation.cc Outdated
Comment thread onnxruntime/core/graph/data_propagation/unsqueeze_op_data_propagation.cc Outdated
Comment thread onnxruntime/test/framework/shape_inference_test.cc
Comment thread onnxruntime/test/framework/shape_inference_test.cc
Comment thread onnxruntime/test/framework/shape_inference_test.cc
@titaiwangms

Copy link
Copy Markdown
Contributor Author

Review summary — multi-model review team (readability · code · critical · deep · integration)

Solid, well-contained fix. The shared data_propagation_value_utils.h extraction is the right architectural move, the rank-1 representation is correctly grounded in how Graph::getInputData() reconstructs dims=[1], the GetInitializedInputValuesFunc signature change was threaded through every callsite (gather/squeeze/unsqueeze/size + factory), and the SinglePropagatedShapeValueGuardTest is genuinely strong (mutation-locks the load-bearing dim_size()==1 guard, covers scalar/rank-1/multi-element/symbolic/empty). Spec arithmetic checks out against the ONNX Gather rank rule (out_rank = q + (r−1), data rank r=1out_rank = q) and multidirectional broadcasting.

Findings below, deduplicated and prioritized. None are hard blockers, but two Major items are worth addressing before merge.

Major

1. Rank ≥ 2 single-element index is collapsed to rank-1 (gather_op_data_propagation.cc).
indices_is_scalar = (indices_num_dims == 0); is_rank1 = !indices_is_scalar routes every non-scalar single-element index as rank-1. For a [1,1] (q=2) single-element index, true Gather output rank is q = 2, but it's emitted as rank-1 [1]. The two-channel representation can't hold rank ≥ 2 anyway, so this can make an otherwise invalid model (e.g. rank-2 K into TopK) silently pass with a fabricated rank-1 value. The PR already lists this as an out-of-scope follow-up, but the fix is tiny and conservative: decline propagation when indices_num_dims >= 2 (only set a value when index rank ≤ 1). Declining falls back to a symbolic dim — strictly safer than emitting a rank the representation can't honestly carry.

2. Cross-module ripple into Unsqueeze/Squeeze data propagation (unsqueeze_op_data_propagation.cc:34, same for squeeze).
Now that Gather correctly emits a rank-1 [value] instead of a scalar, a Gather(1-D idx) → Unsqueeze chain stops hitting the scalar branch (:23) and instead falls into the else if (GetInferredShapeValues().has_value()) branch (:34), which inserts a 1 and produces [1, value] — i.e. it treats the elements of a 1-D shape value as a shape. A rank-1 single-element input to Unsqueeze truly yields a rank-2 tensor, which inferred_shape_values_ (rank-1 only) cannot represent. Models that previously worked around the old Gather scalar bug via Gather → Unsqueeze → Concat could now get a corrupted propagated shape. Recommend the else if branch return Status::OK() (decline) when it can't faithfully represent the result, in both Unsqueeze and Squeeze. Please confirm with a Gather(1-D) → Unsqueeze(0) → Concat fixture.

Minor

3. Signed-integer overflow (UB) in Add/Sub/Mul/Div. lhs+rhs / lhs-rhs / lhs*rhs on int64_t, plus INT64_MIN / -1 (the new Div guard only checks rhs != 0). Largely pre-existing (the scalar path already did this) and bounded in practice for shape values, but the helper now routes more cases through it. Consider checked arithmetic that declines propagation on overflow, and an explicit INT64_MIN && -1 guard in Div.

4. SetSinglePropagatedShapeValue scalar path does not clear a stale values channel. Safe today only because Graph::CleanUpShapeValuesFromDataPropagation() resets both channels between passes, and getInputData() prefers the values channel — so a stale rank-1 value would win over a fresh scalar. Strong consensus across reviewers that this is fragile cross-file coupling. Consider clearing the opposite channel in the setter (or ORT_ENFORCE the freshness precondition) to make the invariant local.

5. Reader/writer channel precedence. TryGetSinglePropagatedShapeValue() prefers the scalar channel first; getInputData() prefers the values channel. Only matters if a NodeArg ever carries both (cleanup currently prevents this) — worth aligning or documenting the single-channel invariant alongside #4.

6. Doc/behavior drift on GetInitializedInputValuesFunc. The doxygen says input_values is "empty if the input is not a constant initializer," but the lambda leaves it unchanged; same caveat for num_dims. Either clear at the top of the lambda or correct the comment to "left unchanged."

7. Test coverage gap. End-to-end rank-preservation is exercised only via the Mul chain. Add direct fixtures for Gather → Add/Sub/Div → TopK, including a Div with RHS 0 asserting propagation is skipped.

8. GatherMulToTopKRankPreservationTest comment is slightly off. Once Gather emits rank-1 values, both Mul operands are on the values channel, so ONNX's own MathOpDataPropagator handles this model — ORT's custom Mul rank-1 branch is not dispatched here (it dispatches only when an operand is on the scalar channel). The custom Mul path is genuinely needed for mixed scalar × rank-1 inputs (e.g. Size → scalar feeding Mul with a Gather rank-1), which would be a more faithful fixture. Test still validly locks end-to-end behavior.

Readability nits

  • Prefer std::optional<int>& num_dims over the -1 sentinel out-param — self-documenting, and -1 collides with the ONNX "last axis" convention. Callers that don't need the rank pass a throwaway optional.
  • The 3-line broadcasting rationale comment is copy-pasted verbatim into add/sub/mul/div — centralize it on SetSinglePropagatedShapeValue and cross-reference.
  • axes_num_dims in squeeze/unsqueeze is declared, passed, never read — add a // rank unused here note (or use the throwaway optional from above).
  • Consider a tiny struct PropagatedValue { int64_t value; bool is_rank1; } to retire the naked bool is_rank1.

Praise

Sourcing the index rank from the same TensorProto the values came from (rather than a second NodeArg lookup) is exactly how to keep value/rank from drifting; the dim_size()==1 decline-on-multi-element guard is the right load-bearing invariant and it's mutation-tested; the Div rhs != 0 guard is a clean correctness improvement.

Reviewed by a 5-model team (Claude Sonnet · GPT-5.3-Codex · GPT-5.5 · Claude Opus · Gemini 3.1 Pro). Findings are advisory.

@titaiwangms titaiwangms force-pushed the fix/topk-gather-dataprop-rank branch from 1f862c2 to 14865d9 Compare June 16, 2026 23:57
@titaiwangms

Copy link
Copy Markdown
Contributor Author

Thanks for the thorough review! Pushed an update (now at 14865d9) addressing the feedback. Summary of changes:

  • Setter now correct-by-construction (was: scalar path didn't clear a stale values channel): SetSinglePropagatedShapeValue populates exactly one channel and clears the other, so the scalar-first reader and values-first getInputData() can never disagree on rank. Locked by SetSinglePropagatedShapeValueKeepsSingleChannelTest (asserts each write clears the opposite channel).
  • Gather declines on rank>=2 / unknown single-element index (was: could fabricate a rank-1 value): routing is now a pure classifier ClassifySingleValueRank (0->scalar, 1->rank-1 [1], >=2 or unknown->decline to ONNX fallback). Locked by ClassifySingleValueRankRoutingTest (mutation-proven). The original rank-1 fix is unchanged: a 1-D [-1] index still routes to rank-1 [1] so TopK gets a concrete K.
  • Unsqueeze declines on the unrepresentable rank>=2 single-element result (was: emitted a misleading [1, K]): extracted into the pure predicate ShouldDeclineUnsqueezeSingleValue, locked by ShouldDeclineUnsqueezeSingleValueTest. Note: a graph-level test is not constructible here — the rank-2 single-element result can't legally feed any ONNX shape-consumer (all require 1-D/scalar), and channel values are cleared at the end of Graph::Resolve(), so the result is unobservable end-to-end; the pure-predicate unit test is the lock. An observable companion (GatherSqueezeRangeRankPreservationTest, Shape->Gather->Squeeze->Range) locks Squeeze's correct scalar behavior through a real consumer.
  • Opt-level coverage: the data-propagation regression tests now run at all five graph optimization levels (added ORT_ENABLE_EXTENDED and ORT_ENABLE_LAYOUT).
  • [[maybe_unused]] on the axes_num_dims sentinel in squeeze/unsqueeze.
  • Testdata generators: removed the duplicate onnx import (CodeQL); the regenerated .onnx files are byte-identical.

Full onnxruntime_test_all suite passes (1824 run / 0 failed), and the original Shape->Gather->TopK non-regression is verified at every optimization level including ORT_DISABLE_ALL.

Intentionally deferred (happy to fold either in if you'd prefer):

  • Signed-integer overflow/UB hardening (e.g. INT64_MIN / -1) in the Add/Sub/Mul/Div data propagation — this is a pre-existing pattern not introduced by this PR, so it seems better suited to a focused follow-up.
  • Additional Add/Sub/Div end-to-end coverage — the shared-reader unit test already exercises the read path for all four elementwise ops, so e2e for each is structurally redundant.

Unrelated CI note: the wasm_Debug check appears to be failing on a Dawn/emdawnwebgpu header-codegen infra step; this PR touches no WASM/WebGPU files. Let me know if a re-run is warranted.

GatherOpDataPropagation::infer() guarded on indices.size() == 1 (element
count) and called SetInferredShapeScalarValue() unconditionally. That guard
is true for BOTH a 0-D scalar index and a 1-D single-element index, so the
1-D case had its rank dropped: Graph::getInputData() then emitted a 0-D
(dimensionless) TensorProto for the propagated value.

For the common Shape -> Gather([-1]) -> TopK exporter pattern this produced a
0-D K initializer, which ONNX TopK shape inference correctly rejects ("K
input must be a one-dimensional tensor of size 1.") at Graph::Resolve time.
The model is spec-valid (a 1-D Gather index yields a rank-1 Gather output),
so this was an ORT rank-preservation bug. It reproduces even at
ORT_DISABLE_ALL, where constant folding never runs, confirming the cause is
shape-inference data propagation rather than constant folding.

Changes:
- Gather: route by the index rank instead of element count. A 0-D scalar
  index stores a scalar value; a 1-D single-element index stores a rank-1
  value, so getInputData() emits a TensorProto with dims=[1] and downstream
  TopK sees a valid 1-D size-1 K; a rank >= 2 index (or an index whose rank is
  unknown) declines and falls back to ONNX data propagation, because the
  single-value channel cannot faithfully represent a rank >= 2 Gather output.
  The decision is a pure classifier (ClassifySingleValueRank). The index rank
  is sourced from the same constant initializer the index value comes from
  (via get_initialized_input_values now reporting the initializer rank),
  instead of a second, independently resolved NodeArg shape -- removing a
  potential source-of-truth drift.
- Elementwise consumers (Add/Sub/Mul/Div): previously scalar-only, they would
  silently stop propagating once an operand became a rank-1 value (e.g. a
  Shape -> Gather(1-D idx) -> Mul -> TopK chain), since the custom propagation
  result replaces ONNX's rank-correct fallback. They now accept a single
  element carried as either a rank-0 scalar or a rank-1 [1] value and keep the
  output rank consistent with ONNX broadcasting (rank-1 if any operand is
  rank-1, else scalar), so such chains keep propagating end to end. Div also
  guards against division by zero.
- Unsqueeze: decline rather than propagate a dubious value when unsqueezing a
  single-element (scalar-like, rank-1 [1]) value, whose result is a rank >= 2
  tensor the values channel cannot faithfully represent (it would otherwise
  fabricate a misleading [1, value]). Multi-element shape vectors are
  unaffected. Squeeze is left unchanged -- it already converts a rank-1 [1]
  value to the correct scalar.
- Add shared helpers (data_propagation_value_utils.h) for reading/writing a
  single-element shape value while preserving its rank, used by Gather, the
  elementwise ops and Unsqueeze so producers and consumers cannot disagree on
  rank. The reader declines a rank-1 multi-element value (it must never
  collapse to element[0]), so a multi-element value cannot be mistaken for a
  single one. The setter is correct-by-construction: it populates exactly one
  channel and clears the other, so the scalar-first reader and the values-first
  getInputData() can never disagree on rank even if the output carried a stale
  value from a prior pass.

Tests: add ShapeInferenceV2Test.GatherToTopKRankPreservationTest and
GatherMulToTopKRankPreservationTest (with fixtures and generators) that load
the model at the disabled, basic, and all-optimization levels and assert the
rank-1 K is preserved and propagated through the chain, plus
GatherSqueezeRangeRankPreservationTest, an
observable end-to-end lock that a Shape -> Gather([-1]) -> Squeeze -> Range
chain resolves Range's length to the concrete propagated K (locking Squeeze's
correct-scalar behavior through a real downstream consumer). Add unit tests
pinning the pure helpers directly: SinglePropagatedShapeValueGuardTest (reader
behavior on each channel), SetSinglePropagatedShapeValueKeepsSingleChannelTest
(the setter clears the opposite channel), ClassifySingleValueRankRoutingTest
(scalar/rank-1/decline routing) and ShouldDeclineUnsqueezeSingleValueTest (the
Unsqueeze decline decision; a graph-level test is structurally infeasible
because a rank >= 2 result cannot flow to any legal ONNX shape-consumer). Add
ShapeMulMultiElementNoScalarCollapseTest, an end-to-end check that a
Shape -> Mul -> ConstantOfShape multi-element chain still resolves to its full
rank-2 shape. The data-propagation regression tests run at the disabled, basic,
and all-optimization levels (data propagation executes in the pre-optimization
Graph::Resolve pass, so it is independent of the graph-optimization level).
Existing scalar-index data-prop fixtures continue to exercise the scalar path
unchanged.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: titaiwangms <titaiwang@microsoft.com>
@titaiwangms titaiwangms force-pushed the fix/topk-gather-dataprop-rank branch from 14865d9 to 90ff823 Compare June 17, 2026 18:13
@titaiwangms

Copy link
Copy Markdown
Contributor Author

Pushed a small follow-up (now at 90ff823) addressing the windows_x64_asan CI failure.

Root cause: the failure was an AddressSanitizer allocator footprint effect, not a build break or a test assertion failure — the full ASan test binary exhausted ASan's 8 GB SizeClassAllocator ceiling and aborted in an unrelated later test. The build compiled cleanly and the data-propagation tests passed before the abort.

Change in this commit: trims the four data-propagation opt-level test loops from five levels back to {ORT_DISABLE_ALL, ORT_ENABLE_BASIC, ORT_ENABLE_ALL}, removing the extra ORT_ENABLE_EXTENDED + ORT_ENABLE_LAYOUT session constructions that were the largest contributor to peak ASan allocation. This is coverage-neutral: data propagation runs in the optimization-level-independent Graph::Resolve pass, and ORT_ENABLE_ALL already applies the extended and layout transformers, so the dropped levels exercised no additional behavior. ORT_DISABLE_ALL — the level that proves the regression is data-propagation rather than constant folding — is retained in every loop.

No production code changed from the previous revision. Full onnxruntime_test_all is green locally (1807 passed / 0 failed). The windows_x64_asan job has been re-triggered on 90ff823 to confirm peak allocation now stays under the cap.

effective_num_dims = (indices_shape != nullptr) ? indices_shape->dim_size() : -1;
}

switch (ClassifySingleValueRank(effective_num_dims)) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: This enum and helper are only consumed by the switch in gather_op_data_propagation.cc, where it just maps num_dims to a boolean is_rank1 with a decline for ≥ 2. This can be replaced with a simple if/else inline in the Gather propagation:

 if (effective_num_dims == 0) {
   SetSinglePropagatedShapeValue(output_def_, dim.dim_value(), /*is_rank1=*/false);
 } else if (effective_num_dims == 1) {
   SetSinglePropagatedShapeValue(output_def_, dim.dim_value(), /*is_rank1=*/true);
 }
 // else: rank >= 2 or unknown — decline

It makes code much simpler.

// predicate ShouldDeclineUnsqueezeSingleValue (see ShouldDeclineUnsqueezeSingleValueTest);
// full-suite-green is only the no-regression backstop. Multi-element shape vectors are
// unaffected.
if (ShouldDeclineUnsqueezeSingleValue(tensor_shape_proto.dim_size())) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this is Unsqueeze-specific logic ("should Unsqueeze decline propagating this value?") and is only called from unsqueeze_op_data_propagation.cc. It belongs alongside the Unsqueeze implementation, not in the shared utils.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants