@@ -318,6 +318,37 @@ def _impl_v13(cls, bb, inputs, attr, params):
318318 return relax .op .matmul (inputs [0 ], inputs [1 ])
319319
320320
class MatMulInteger16(OnnxOpConverter):
    """Converts an ONNX MatMulInteger16 node into an equivalent Relax expression."""

    @classmethod
    def _impl_v1(cls, bb, inputs, attr, params):
        # The op takes exactly the two matrix operands; it has no optional inputs.
        if len(inputs) != 2:
            raise ValueError(f"MatMulInteger16 expects two inputs, but got {len(inputs)}")
        lhs, rhs = inputs
        allowed = ["int16", "uint16"]
        if lhs.struct_info.dtype not in allowed:
            raise ValueError(
                "MatMulInteger16 expects input A to have int16 or uint16 dtype, "
                f"but got {lhs.struct_info.dtype}"
            )
        if rhs.struct_info.dtype not in allowed:
            raise ValueError(
                "MatMulInteger16 expects input B to have int16 or uint16 dtype, "
                f"but got {rhs.struct_info.dtype}"
            )

        # The 32-bit accumulator is unsigned only when both operands are unsigned.
        if lhs.struct_info.dtype == "uint16" and rhs.struct_info.dtype == "uint16":
            out_dtype = "uint32"
        else:
            out_dtype = "int32"
        # Widen both operands up front so the matmul accumulates in 32 bits.
        widened_lhs = relax.op.astype(lhs, out_dtype)
        widened_rhs = relax.op.astype(rhs, out_dtype)
        return relax.op.matmul(widened_lhs, widened_rhs)
351+
321352def _to_numpy (x ):
322353 if isinstance (x , relax .PrimValue ):
323354 x = x .value
@@ -328,6 +359,19 @@ def _to_numpy(x):
328359 return x .data .numpy ()
329360
330361
class _EmptyOptional:
    """Sentinel object that preserves an empty ONNX Optional during import.

    An ONNX ``Optional`` with no payload has no Relax expression form, so the
    importer threads this marker object through the node map instead.
    """

    def __init__(self, type_proto: onnx.onnx_ml_pb2.TypeProto):
        # The ``type`` attribute of the Optional node, kept so downstream
        # consumers can still inspect the declared element type.
        self.type_proto = type_proto
368+
def _is_empty_optional(value: Any) -> bool:
    """Returns whether the given value represents an empty ONNX Optional."""
    # Only the _EmptyOptional sentinel marks a value-less Optional; every
    # other node value (Relax expr, constant, None, ...) is not one.
    is_sentinel = isinstance(value, _EmptyOptional)
    return is_sentinel
374+
331375class BinaryBase (OnnxOpConverter ):
332376 """Converts an onnx BinaryBase node into an equivalent Relax expression."""
333377
@@ -3686,6 +3730,50 @@ def _impl_v1(cls, bb, inputs, attr, params):
36863730 )
36873731
36883732
class Optional_(OnnxOpConverter):
    """Converts an ONNX Optional node into an erased or empty Optional representation."""

    @classmethod
    def _impl_v15(cls, bb, inputs, attr, params):
        if len(inputs) > 1:
            raise ValueError(f"Optional accepts at most one input, but got {len(inputs)}")
        # A populated Optional is transparently erased to its payload value.
        if len(inputs) == 1 and inputs[0] is not None:
            return inputs[0]
        # No payload: the node must declare the element type, which we keep
        # on the sentinel so later ops can still inspect it.
        if "type" not in attr:
            raise ValueError("Optional without an input must specify the type attribute.")
        return _EmptyOptional(attr["type"])

    _impl_v18 = _impl_v15
3748+
class OptionalHasElement(OnnxOpConverter):
    """Converts an ONNX OptionalHasElement node into a boolean constant.

    Whether the optional holds a value is known at import time (empty
    optionals are tracked via the ``_EmptyOptional`` sentinel), so the
    result folds to a constant.
    """

    @classmethod
    def _impl_v15(cls, bb, inputs, attr, params):
        if len(inputs) != 1:
            raise ValueError(f"OptionalHasElement expects one input, but got {len(inputs)}")
        if inputs[0] is None or _is_empty_optional(inputs[0]):
            return relax.const(False, dtype="bool")
        return relax.const(True, dtype="bool")

    @classmethod
    def _impl_v18(cls, bb, inputs, attr, params):
        # Since opset 18 the input itself is optional: an absent input means
        # "no element", so the result is False rather than an error.
        if len(inputs) > 1:
            raise ValueError(
                f"OptionalHasElement expects at most one input, but got {len(inputs)}"
            )
        if len(inputs) == 0 or inputs[0] is None or _is_empty_optional(inputs[0]):
            return relax.const(False, dtype="bool")
        return relax.const(True, dtype="bool")
3761+
3762+
class OptionalGetElement(OnnxOpConverter):
    """Converts an ONNX OptionalGetElement node by unwrapping a non-empty Optional."""

    @classmethod
    def _impl_v15(cls, bb, inputs, attr, params):
        if len(inputs) != 1:
            raise ValueError(f"OptionalGetElement expects one input, but got {len(inputs)}")
        value = inputs[0]
        # Empty optionals (tracked via the sentinel) have nothing to unwrap.
        if value is None or _is_empty_optional(value):
            raise ValueError("OptionalGetElement cannot access an empty optional.")
        # Non-empty optionals were already erased to their payload by Optional_.
        return value

    _impl_v18 = _impl_v15
3775+
3776+
36893777class SequenceConstruct (OnnxOpConverter ):
36903778 """Operator converter for sequence construction op."""
36913779
@@ -4111,9 +4199,9 @@ def _impl_v10(cls, bb, inputs, attr, params):
41114199def _get_convert_map ():
41124200 return {
41134201 # defs/experimental
4114- # "Optional": Optional_,
4115- # "OptionalHasElement": OptionalHasElement,
4116- # "OptionalGetElement": OptionalGetElement,
4202+ "Optional" : Optional_ ,
4203+ "OptionalHasElement" : OptionalHasElement ,
4204+ "OptionalGetElement" : OptionalGetElement ,
41174205 # Binary operators
41184206 "Add" : Add ,
41194207 "Sub" : Sub ,
@@ -4184,7 +4272,7 @@ def _get_convert_map():
41844272 "Gemm" : Gemm ,
41854273 "MatMul" : MatMul ,
41864274 "MatMulInteger" : MatMulInteger ,
4187- # "MatMulInteger16": MatMulInteger16,
4275+ "MatMulInteger16" : MatMulInteger16 ,
41884276 "Reshape" : Reshape ,
41894277 "Sigmoid" : Sigmoid ,
41904278 "Softmax" : Softmax ,
@@ -4343,7 +4431,18 @@ def from_onnx(self, graph: onnx.onnx_ml_pb2.ModelProto, opset: int) -> IRModule:
43434431 self ._check_for_unsupported_ops (graph )
43444432 self ._construct_nodes (graph )
43454433
4346- outputs = [self ._nodes [self ._parse_value_proto (i )] for i in graph .output ]
4434+ # now return the outputs
4435+ output_names = [self ._parse_value_proto (output ) for output in graph .output ]
4436+ outputs = []
4437+ for output_name in output_names :
4438+ output_value = self ._nodes [output_name ]
4439+ if _is_empty_optional (output_value ):
4440+ raise ValueError (
4441+ "ONNX graph output "
4442+ f"{ output_name } is an empty optional. Empty optional graph outputs "
4443+ "are not supported by the Relax ONNX frontend."
4444+ )
4445+ outputs .append (output_value )
43474446 outputs = outputs [0 ] if len (outputs ) == 1 else relax .Tuple (outputs )
43484447
43494448 if has_if :
@@ -4515,6 +4614,8 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto):
45154614 "Squeeze" ,
45164615 ]
45174616 return_tuple_ops = [
4617+ "Optional" ,
4618+ "OptionalGetElement" ,
45184619 "SequenceConstruct" ,
45194620 "SequenceEmpty" ,
45204621 "SequenceErase" ,
@@ -4533,7 +4634,8 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto):
45334634 try :
45344635 op = self ._convert_operator (op_name , inputs , attr , self .opset )
45354636 # Create struct information for the new operator.
4536- op = self .bb .normalize (op )
4637+ if isinstance (op , relax .Expr ):
4638+ op = self .bb .normalize (op )
45374639 except TVMError as err :
45384640 print (f"Error converting operator { op_name } , with inputs: { inputs } " )
45394641 raise err
@@ -4585,11 +4687,11 @@ def _parse_attr(self, attr_proto: onnx.onnx_ml_pb2.AttributeProto) -> dict[str,
45854687 if list (getattr (a , f )):
45864688 assert a .name not in attrs , "Only one type of attr is allowed"
45874689 attrs [a .name ] = tuple (getattr (a , f ))
4588- for f in ["t" ]:
4589- if a .HasField (f ):
4690+ for f in ["t" , "tp" ]:
4691+ if hasattr ( a , f ) and a .HasField (f ):
45904692 attrs [a .name ] = getattr (a , f )
4591- for f in ["tensors" ]:
4592- if list (getattr (a , f )):
4693+ for f in ["tensors" , "type_protos" ]:
4694+ if hasattr ( a , f ) and list (getattr (a , f )):
45934695 assert a .name not in attrs , "Only one type of attr is allowed"
45944696 attrs [a .name ] = tuple (getattr (a , f ))
45954697 for f in ["graphs" ]:
0 commit comments