Skip to content

Commit cb5e290

Browse files
authored
[Relax][ONNX] Add Optional and MatMulInteger16 frontend support (#18950)
## Summary This PR adds Relax ONNX frontend support for: - `Optional` - `OptionalHasElement` - `OptionalGetElement` - `MatMulInteger16` from the `com.microsoft` domain The implementation follows existing TVM ONNX frontend patterns and keeps Optional handling explicit through an empty-Optional sentinel during import. ## Changes - add ONNX frontend converters for `Optional`, `OptionalHasElement`, and `OptionalGetElement` - add ONNX frontend converter for `MatMulInteger16` - extend ONNX attribute parsing to handle `TYPE_PROTO` - preserve empty Optional values during import and unwrap them consistently - register Optional-related ops and `MatMulInteger16` in the ONNX converter map - handle Optional outputs correctly in importer output counting and normalization - tighten converter docstrings and input validation for better consistency with nearby TVM code ## Tests Added or updated tests in `tests/python/relax/test_frontend_onnx.py` to cover: - numerical correctness for `MatMulInteger16` - structural IR checks for `MatMulInteger16` - invalid dtype rejection for `MatMulInteger16` - tensor and sequence Optional round-trips - empty Optional behavior for `OptionalHasElement` - structural IR checks ensuring Optional ops are erased as expected - missing `type` attribute rejection for empty `Optional` - empty `OptionalGetElement` rejection ## Validation Validated with: - `python -m ruff check python/tvm/relax/frontend/onnx/onnx_frontend.py tests/python/relax/test_frontend_onnx.py` - `python -m pytest -n 1 tests/python/relax/test_frontend_onnx.py -k "optional or matmulinteger16" -v` Result: - `13 passed` This PR completes the ONNX `MatMulInteger16` and `Optional` work tracked in #18945.
1 parent e229bda commit cb5e290

2 files changed

Lines changed: 455 additions & 10 deletions

File tree

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 112 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,37 @@ def _impl_v13(cls, bb, inputs, attr, params):
318318
return relax.op.matmul(inputs[0], inputs[1])
319319

320320

321+
class MatMulInteger16(OnnxOpConverter):
322+
"""Converts an ONNX MatMulInteger16 node into an equivalent Relax expression."""
323+
324+
@classmethod
325+
def _impl_v1(cls, bb, inputs, attr, params):
326+
if len(inputs) != 2:
327+
raise ValueError(f"MatMulInteger16 expects two inputs, but got {len(inputs)}")
328+
a, b = inputs
329+
valid_types = ["int16", "uint16"]
330+
if a.struct_info.dtype not in valid_types:
331+
raise ValueError(
332+
"MatMulInteger16 expects input A to have int16 or uint16 dtype, "
333+
f"but got {a.struct_info.dtype}"
334+
)
335+
if b.struct_info.dtype not in valid_types:
336+
raise ValueError(
337+
"MatMulInteger16 expects input B to have int16 or uint16 dtype, "
338+
f"but got {b.struct_info.dtype}"
339+
)
340+
341+
out_dtype = (
342+
"uint32"
343+
if a.struct_info.dtype == "uint16" and b.struct_info.dtype == "uint16"
344+
else "int32"
345+
)
346+
return relax.op.matmul(
347+
relax.op.astype(a, out_dtype),
348+
relax.op.astype(b, out_dtype),
349+
)
350+
351+
321352
def _to_numpy(x):
322353
if isinstance(x, relax.PrimValue):
323354
x = x.value
@@ -328,6 +359,19 @@ def _to_numpy(x):
328359
return x.data.numpy()
329360

330361

362+
class _EmptyOptional:
363+
"""Sentinel object that preserves an empty ONNX Optional during import."""
364+
365+
def __init__(self, type_proto: onnx.onnx_ml_pb2.TypeProto):
366+
self.type_proto = type_proto
367+
368+
369+
def _is_empty_optional(value: Any) -> bool:
370+
"""Returns whether the given value represents an empty ONNX Optional."""
371+
372+
return isinstance(value, _EmptyOptional)
373+
374+
331375
class BinaryBase(OnnxOpConverter):
332376
"""Converts an onnx BinaryBase node into an equivalent Relax expression."""
333377

@@ -3686,6 +3730,50 @@ def _impl_v1(cls, bb, inputs, attr, params):
36863730
)
36873731

36883732

3733+
class Optional_(OnnxOpConverter):
3734+
"""Converts an ONNX Optional node into an erased or empty Optional representation."""
3735+
3736+
@classmethod
3737+
def _impl_v15(cls, bb, inputs, attr, params):
3738+
if len(inputs) > 1:
3739+
raise ValueError(f"Optional accepts at most one input, but got {len(inputs)}")
3740+
if len(inputs) == 0 or inputs[0] is None:
3741+
if "type" not in attr:
3742+
raise ValueError("Optional without an input must specify the type attribute.")
3743+
return _EmptyOptional(attr["type"])
3744+
return inputs[0]
3745+
3746+
_impl_v18 = _impl_v15
3747+
3748+
3749+
class OptionalHasElement(OnnxOpConverter):
3750+
"""Converts an ONNX OptionalHasElement node into a boolean constant."""
3751+
3752+
@classmethod
3753+
def _impl_v15(cls, bb, inputs, attr, params):
3754+
if len(inputs) != 1:
3755+
raise ValueError(f"OptionalHasElement expects one input, but got {len(inputs)}")
3756+
if inputs[0] is None or _is_empty_optional(inputs[0]):
3757+
return relax.const(False, dtype="bool")
3758+
return relax.const(True, dtype="bool")
3759+
3760+
_impl_v18 = _impl_v15
3761+
3762+
3763+
class OptionalGetElement(OnnxOpConverter):
3764+
"""Converts an ONNX OptionalGetElement node by unwrapping a non-empty Optional."""
3765+
3766+
@classmethod
3767+
def _impl_v15(cls, bb, inputs, attr, params):
3768+
if len(inputs) != 1:
3769+
raise ValueError(f"OptionalGetElement expects one input, but got {len(inputs)}")
3770+
if inputs[0] is None or _is_empty_optional(inputs[0]):
3771+
raise ValueError("OptionalGetElement cannot access an empty optional.")
3772+
return inputs[0]
3773+
3774+
_impl_v18 = _impl_v15
3775+
3776+
36893777
class SequenceConstruct(OnnxOpConverter):
36903778
"""Operator converter for sequence construction op."""
36913779

@@ -4111,9 +4199,9 @@ def _impl_v10(cls, bb, inputs, attr, params):
41114199
def _get_convert_map():
41124200
return {
41134201
# defs/experimental
4114-
# "Optional": Optional_,
4115-
# "OptionalHasElement": OptionalHasElement,
4116-
# "OptionalGetElement": OptionalGetElement,
4202+
"Optional": Optional_,
4203+
"OptionalHasElement": OptionalHasElement,
4204+
"OptionalGetElement": OptionalGetElement,
41174205
# Binary operators
41184206
"Add": Add,
41194207
"Sub": Sub,
@@ -4184,7 +4272,7 @@ def _get_convert_map():
41844272
"Gemm": Gemm,
41854273
"MatMul": MatMul,
41864274
"MatMulInteger": MatMulInteger,
4187-
# "MatMulInteger16": MatMulInteger16,
4275+
"MatMulInteger16": MatMulInteger16,
41884276
"Reshape": Reshape,
41894277
"Sigmoid": Sigmoid,
41904278
"Softmax": Softmax,
@@ -4343,7 +4431,18 @@ def from_onnx(self, graph: onnx.onnx_ml_pb2.ModelProto, opset: int) -> IRModule:
43434431
self._check_for_unsupported_ops(graph)
43444432
self._construct_nodes(graph)
43454433

4346-
outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
4434+
# now return the outputs
4435+
output_names = [self._parse_value_proto(output) for output in graph.output]
4436+
outputs = []
4437+
for output_name in output_names:
4438+
output_value = self._nodes[output_name]
4439+
if _is_empty_optional(output_value):
4440+
raise ValueError(
4441+
"ONNX graph output "
4442+
f"{output_name} is an empty optional. Empty optional graph outputs "
4443+
"are not supported by the Relax ONNX frontend."
4444+
)
4445+
outputs.append(output_value)
43474446
outputs = outputs[0] if len(outputs) == 1 else relax.Tuple(outputs)
43484447

43494448
if has_if:
@@ -4515,6 +4614,8 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto):
45154614
"Squeeze",
45164615
]
45174616
return_tuple_ops = [
4617+
"Optional",
4618+
"OptionalGetElement",
45184619
"SequenceConstruct",
45194620
"SequenceEmpty",
45204621
"SequenceErase",
@@ -4533,7 +4634,8 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto):
45334634
try:
45344635
op = self._convert_operator(op_name, inputs, attr, self.opset)
45354636
# Create struct information for the new operator.
4536-
op = self.bb.normalize(op)
4637+
if isinstance(op, relax.Expr):
4638+
op = self.bb.normalize(op)
45374639
except TVMError as err:
45384640
print(f"Error converting operator {op_name}, with inputs: {inputs}")
45394641
raise err
@@ -4585,11 +4687,11 @@ def _parse_attr(self, attr_proto: onnx.onnx_ml_pb2.AttributeProto) -> dict[str,
45854687
if list(getattr(a, f)):
45864688
assert a.name not in attrs, "Only one type of attr is allowed"
45874689
attrs[a.name] = tuple(getattr(a, f))
4588-
for f in ["t"]:
4589-
if a.HasField(f):
4690+
for f in ["t", "tp"]:
4691+
if hasattr(a, f) and a.HasField(f):
45904692
attrs[a.name] = getattr(a, f)
4591-
for f in ["tensors"]:
4592-
if list(getattr(a, f)):
4693+
for f in ["tensors", "type_protos"]:
4694+
if hasattr(a, f) and list(getattr(a, f)):
45934695
assert a.name not in attrs, "Only one type of attr is allowed"
45944696
attrs[a.name] = tuple(getattr(a, f))
45954697
for f in ["graphs"]:

0 commit comments

Comments
 (0)