Skip to content

Commit c79caf0

Browse files
authored
[Relax][ONNX] Complete ShapeExpr reshape handling in ONNX frontend (#18956)
## Summary

Complete `Reshape` handling for shape values in the Relax ONNX frontend.

## Changes

- Keep `ShapeExpr -> Reshape([-1])` on the shape-specialized path.
- Materialize `ShapeExpr` to an `int64` tensor for other reshape targets and apply regular tensor reshape semantics.
- Add frontend coverage for `Shape -> Reshape([-1])`.
- Add frontend coverage for reshaping shape outputs to non-`[-1]` targets such as `[1, 3]` and `[3, 1]`.
- Extend symbolic shape deduction coverage to include the common `Shape -> Reshape([-1]) -> Gather -> Unsqueeze` shape-construction pattern.

## Validation

- `pytest -k 'test_symbolic_shape_deduction or test_reshape_shape_output or test_reshape'`

This PR completes the `Reshape` limitation in the Relax ONNX frontend operator work tracked in #18945.
1 parent cb5e290 commit c79caf0

2 files changed

Lines changed: 63 additions & 9 deletions

File tree

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,11 +1016,16 @@ def _impl_v13(cls, bb, inputs, attr, params):
10161016
data = inputs[0]
10171017
new_shape = get_constant(inputs[1], params)
10181018

1019-
if isinstance(data, relax.ShapeExpr) and isinstance(new_shape, relax.Constant):
1020-
new_shape = new_shape.data.numpy().tolist()
1021-
if new_shape != [-1]:
1022-
raise NotImplementedError("Need to fix this case")
1023-
return data
1019+
if isinstance(data, relax.ShapeExpr):
1020+
# Preserve identity flatten for shape values to keep shape-specialized
1021+
# handling in downstream shape-construction patterns.
1022+
if isinstance(new_shape, relax.Constant):
1023+
new_shape_values = new_shape.data.numpy().tolist()
1024+
if new_shape_values == [-1]:
1025+
return data
1026+
1027+
# Other reshape targets follow regular int64 tensor reshape semantics.
1028+
data = bb.normalize(relax.op.shape_to_tensor(data))
10241029

10251030
if isinstance(data, relax.Constant) and isinstance(new_shape, relax.Constant):
10261031
out = _np.reshape(data.data.numpy(), new_shape.data.numpy().tolist())

tests/python/relax/test_frontend_onnx.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,38 @@ def test_reshape(in_shape, shape, out_shape):
971971
check_correctness(model, inputs=input_values)
972972

973973

974+
@pytest.mark.parametrize(
    "target_shape, output_shape",
    [
        ([-1], [3]),
        ([1, 3], [1, 3]),
        ([3, 1], [3, 1]),
    ],
)
def test_reshape_shape_output(target_shape, output_shape):
    """Reshape applied to the output of a Shape node.

    Builds ``Shape -> Reshape`` graphs where the reshape target is an int64
    initializer, covering both the identity-flatten ``[-1]`` case and
    non-flatten targets (``[1, 3]``, ``[3, 1]``) that force the shape value
    to be materialized as a regular int64 tensor.
    """
    data_shape = [2, 3, 4]

    # Shape of the input feeds directly into Reshape.
    nodes = [
        helper.make_node("Shape", ["data"], ["shape_out"]),
        helper.make_node("Reshape", ["shape_out", "target_shape"], ["reshaped"]),
    ]

    # The reshape target is supplied as a constant initializer, so the
    # frontend sees it as a relax.Constant.
    target_init = helper.make_tensor(
        "target_shape", TensorProto.INT64, [len(target_shape)], target_shape
    )

    graph = helper.make_graph(
        nodes,
        "reshape_shape_output",
        inputs=[
            helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape),
        ],
        initializer=[target_init],
        outputs=[helper.make_tensor_value_info("reshaped", TensorProto.INT64, output_shape)],
    )

    model = helper.make_model(graph, producer_name="reshape_shape_output")
    check_correctness(
        model,
        inputs={"data": np.random.randn(*data_shape).astype("float32")},
    )
1004+
1005+
9741006
def test_transpose():
9751007
verify_unary("Transpose", [32, 32, 32], attrs={"perm": [1, 2, 0]})
9761008

@@ -3630,29 +3662,46 @@ def test_optional_get_element_empty_raises():
36303662
from_onnx(model, opset=18, keep_params_in_input=True)
36313663

36323664

3633-
def test_symbolic_shape_deduction():
3665+
@pytest.mark.parametrize("with_reshape_flatten", [False, True])
3666+
def test_symbolic_shape_deduction(with_reshape_flatten):
36343667
index_node = helper.make_node(
36353668
"Constant",
36363669
inputs=[],
36373670
outputs=["indices"],
36383671
value=helper.make_tensor("indices", TensorProto.INT64, [], [0]),
36393672
)
36403673
shape_node = helper.make_node("Shape", ["data"], ["shape_output"])
3641-
gather_node = helper.make_node("Gather", ["shape_output", "indices"], ["gather_output"])
3674+
nodes = [index_node, shape_node]
3675+
gather_input = "shape_output"
3676+
3677+
if with_reshape_flatten:
3678+
reshape_node = helper.make_node(
3679+
"Reshape", ["shape_output", "target_shape"], ["reshaped_shape"]
3680+
)
3681+
nodes.append(reshape_node)
3682+
gather_input = "reshaped_shape"
3683+
3684+
gather_node = helper.make_node("Gather", [gather_input, "indices"], ["gather_output"])
36423685
unsqueeze_node = helper.make_node("Unsqueeze", ["gather_output", "axes"], ["unsqueeze_output"])
36433686
constant_of_shape_node = helper.make_node(
36443687
"ConstantOfShape",
36453688
["unsqueeze_output"],
36463689
["output"],
36473690
value=helper.make_tensor("value", TensorProto.FLOAT, [], [1]),
36483691
)
3692+
nodes.extend([gather_node, unsqueeze_node, constant_of_shape_node])
3693+
3694+
initializers = [helper.make_tensor("axes", TensorProto.INT64, [1], vals=[0])]
3695+
if with_reshape_flatten:
3696+
initializers.append(helper.make_tensor("target_shape", TensorProto.INT64, [1], vals=[-1]))
3697+
36493698
graph = helper.make_graph(
3650-
[index_node, shape_node, gather_node, unsqueeze_node, constant_of_shape_node],
3699+
nodes,
36513700
"test_shape_deduction",
36523701
inputs=[
36533702
helper.make_tensor_value_info("data", TensorProto.FLOAT, ["batch", "seq"]),
36543703
],
3655-
initializer=[helper.make_tensor("axes", TensorProto.INT64, [1], vals=[0])],
3704+
initializer=initializers,
36563705
outputs=[helper.make_tensor_value_info("output", TensorProto.INT64, [1])],
36573706
)
36583707
model = helper.make_model(graph, producer_name="test_shape_deduction")

0 commit comments

Comments
 (0)