Skip to content

Fix incorrect rank handling in aten_repeat_interleave_self_int ONNX export #2932#2936

Open
PratikWayase wants to merge 3 commits into
microsoft:mainfrom
PratikWayase:fix-repeat-interleave-shape-inference
Open

Fix incorrect rank handling in aten_repeat_interleave_self_int ONNX export #2932#2936
PratikWayase wants to merge 3 commits into
microsoft:mainfrom
PratikWayase:fix-repeat-interleave-shape-inference

Conversation

@PratikWayase

Copy link
Copy Markdown

Summary

This PR fixes a ShapeInferenceError that occurs during the ONNX export of torch.repeat_interleave on 1D tensors. Additionally, it implements full support for the dim=None case, which was previously raising a NotImplementedError.

Problem

When exporting a model using torch.repeat_interleave on a 1D tensor, the exporter incorrectly returned an Identity node with a rank of 2 instead of 1. This caused a strict shape inference failure:

[ShapeInferenceError] Inferred shape and existing shape differ in rank: (2) vs (1)

Solution

  • Handled dim=None: Modified the logic in core.py to flatten the input tensor (op.Reshape(self, [-1])) and set dim = 0 and self_rank = 1 when dim is None. This allows the existing expansion logic to process it correctly.
  • Removed faulty shortcut: Deleted the incorrect if self_rank == 1: return op.Identity(tiled) block. This block was returning a tensor with an inflated rank. The code now correctly falls through to the final_shape calculation and op.Reshape to ensure the output rank perfectly matches the input rank.
  • Updated Tests: Removed the .skip() markers in ops_test_data.py that were intentionally bypassing tests for the dim=None case. This enables full test coverage for this newly supported scenario.

Testing

Related Issue

Fixes #2932

Comment thread onnxscript/function_libs/torch_lib/ops/core.py Fixed
Comment thread onnxscript/function_libs/torch_lib/ops/core.py Fixed

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR updates the Torch-to-ONNX export implementation of aten::repeat_interleave to fix incorrect rank handling for 1D inputs and to add support for dim=None, aligning exported graphs with PyTorch semantics and preventing strict ONNX shape inference failures.

Changes:

  • Add handling for dim=None by flattening the input and routing through the existing expand/reshape logic.
  • Remove the prior self_rank == 1 shortcut that produced an incorrect output rank.
  • Enable previously skipped dim=None test cases in ops_test_data.py.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

File Description
onnxscript/function_libs/torch_lib/ops/core.py Adjusts aten_repeat_interleave_self_int rank/dim handling and removes the problematic 1D Identity shortcut.
tests/function_libs/torch_lib/ops_test_data.py Unskips dim=None test cases so they execute as part of the existing test suite.
Comments suppressed due to low confidence (1)

onnxscript/function_libs/torch_lib/ops/core.py:8304

  • The dim is None branch introduces redundant/ineffective self_rank logic and leaves dim unnormalized. self_rank is assigned in the branch and then immediately overwritten by self_rank = len(self.shape), and negative dim values are converted to pos_dim but dim is still later used (via end=dim), which can slice the wrong prefix when dim < 0. Normalize dim to pos_dim once and compute self_rank only once (after the optional flatten).
    if dim is None:
        self = op.Reshape(self,[-1])
        dim = 0
        self_rank = 1
    else:
        self_rank = len(self.shape)

    self_rank = len(self.shape)
    pos_dim = (dim + self_rank) % self_rank
    unsqueezed = op.Unsqueeze(self, [pos_dim + 1])

@codecov

codecov Bot commented Jun 15, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 0% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 72.66%. Comparing base (5989b56) to head (8b48666).

Files with missing lines Patch % Lines
onnxscript/function_libs/torch_lib/ops/core.py 0.00% 4 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2936   +/-   ##
=======================================
  Coverage   72.66%   72.66%           
=======================================
  Files         259      259           
  Lines       31748    31748           
  Branches     3005     3004    -1     
=======================================
  Hits        23069    23069           
  Misses       7660     7660           
  Partials     1019     1019           

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

Comment thread tests/function_libs/torch_lib/ops_test.py Outdated
Comment thread onnxscript/function_libs/torch_lib/ops/core.py Fixed
Comment thread onnxscript/function_libs/torch_lib/ops/core.py Fixed
Comment thread tests/function_libs/torch_lib/ops_test.py Fixed
Comment thread tests/function_libs/torch_lib/ops_test.py Fixed
Comment thread tests/function_libs/torch_lib/ops_test.py Fixed
Comment thread tests/function_libs/torch_lib/ops_test.py Fixed
Comment thread tests/function_libs/torch_lib/ops_test.py Fixed
Comment thread tests/function_libs/torch_lib/ops_test.py Fixed
Comment thread tests/function_libs/torch_lib/ops_test.py Fixed
Comment thread tests/function_libs/torch_lib/ops_test.py Fixed

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

Comment on lines 8279 to +8283
"""repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor

The trick is to repeat in one direction orthogonal to reshape.

.. code-block:: python

x = torch.tensor([[0, 1, 2], [3, 4, 5]])
x.repeat_interleave(2, dim=0)

x = torch.tensor([[0, 1, 2], [3, 4, 5]])
x.repeat_interleave(2, dim=0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

module: torchlib Related to the torch/aten function lib in development

Projects

Development

Successfully merging this pull request may close these issues.

bug in onnx export of aten_repeat_interleave_self_int results in incorrect code

4 participants