
Commit 645fcf9

[Relax][ONNX] Add frontend support for QuantizeLinear, DequantizeLinear, and DynamicQuantizeLinear (#19391)
## Summary

This PR adds Relax ONNX frontend support for:

- `QuantizeLinear`
- `DequantizeLinear`
- `DynamicQuantizeLinear`

The implementation follows existing TVM ONNX frontend patterns and keeps QDQ handling consistent for singleton quantization parameters and optional zero-point inputs.

## Changes

- add ONNX frontend converters for `QuantizeLinear`, `DequantizeLinear`, and `DynamicQuantizeLinear`
- register Q/DQ-related ops in the ONNX converter map
- handle optional zero-point inputs consistently during import
- preserve singleton quantization parameter semantics in the QDQ legalization path
- improve QDQ legalization behavior for imported ONNX models
- add and update frontend tests for Q/DQ and `DynamicQuantizeLinear`

## Tests

Added or updated tests in `tests/python/relax/test_frontend_onnx.py` to cover:

- singleton-qparam `QuantizeLinear` in opset 10
- singleton-qparam `DequantizeLinear` in opset 10
- optional-zero-point `QuantizeLinear` in opset 13
- `DynamicQuantizeLinear` in opset 11

## Validation

Validated with:

- `python -m pytest -n 1 tests/python/relax/test_frontend_onnx.py -k "quantizelinear or dequantizelinear or dynamicquantizelinear" -v`

Result:

- `4 passed`
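For reference, the converters in this commit target the ONNX reference semantics for these three operators. Here is a minimal NumPy sketch of the spec (an illustrative aid, not code from this commit; it ignores the degenerate all-zero-input case where the dynamic scale would be 0):

```python
import numpy as np

def quantize_linear(x, scale, zero_point):
    # y = saturate(round(x / scale) + zero_point); ONNX rounds half to even,
    # which is also np.round's behavior.
    info = np.iinfo(zero_point.dtype)
    q = np.round(x / scale) + zero_point
    return np.clip(q, info.min, info.max).astype(zero_point.dtype)

def dequantize_linear(y, scale, zero_point):
    # x = (y - zero_point) * scale, computed in float32.
    return (y.astype(np.float32) - np.float32(zero_point)) * np.float32(scale)

def dynamic_quantize_linear(x):
    # Opset 11: widen the data range to include 0, map it onto [0, 255],
    # and take the zero point as the rounded, clipped projection of 0.
    x_min, x_max = min(x.min(), 0.0), max(x.max(), 0.0)
    y_scale = np.float32((x_max - x_min) / 255.0)
    y_zero_point = np.uint8(np.clip(np.round(-x_min / y_scale), 0, 255))
    return quantize_linear(x, y_scale, y_zero_point), y_scale, y_zero_point
```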
1 parent b9ced1a commit 645fcf9

4 files changed: 272 additions & 18 deletions

File tree

- python/tvm/relax/frontend/onnx/onnx_frontend.py
- python/tvm/relax/transform/legalize_ops/qdq.py
- src/relax/op/tensor/qdq.cc
- tests/python/relax/test_frontend_onnx.py

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 71 additions & 0 deletions
```diff
@@ -311,6 +311,73 @@ def get_converter(cls, opset):
             return getattr(cls, f"_impl_v{version}")
         raise NotImplementedError(f"opset version {version} of {cls.__name__} not implemented")
 
+class QuantizeLinear(OnnxOpConverter):
+    @classmethod
+    def _impl_v10(cls, bb, inputs, attr, params):
+        x, scale = inputs[0], inputs[1]
+        zp = inputs[2] if len(inputs) > 2 and inputs[2] is not None else None
+        axis = attr.get("axis", 1)
+        if hasattr(x.struct_info, "ndim") and x.struct_info.ndim <= 1 and axis == 1:
+            axis = 0
+        out_dtype = "uint8" if zp is None else zp.struct_info.dtype
+        if zp is None:
+            zp = relax.const(0, out_dtype)
+        return relax.op.quantize(x, scale, zp, axis=axis, out_dtype=out_dtype)
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        x, scale = inputs[0], inputs[1]
+        zp = inputs[2] if len(inputs) > 2 and inputs[2] is not None else None
+        axis = attr.get("axis", 1)
+        if hasattr(x.struct_info, "ndim") and x.struct_info.ndim <= 1 and axis == 1:
+            axis = 0
+        out_dtype = "uint8" if zp is None else zp.struct_info.dtype
+        if zp is None:
+            zp = relax.const(0, out_dtype)
+        return relax.op.quantize(x, scale, zp, axis=axis, out_dtype=out_dtype)
+
+
+class DequantizeLinear(OnnxOpConverter):
+    @classmethod
+    def _impl_v10(cls, bb, inputs, attr, params):
+        x, scale = inputs[0], inputs[1]
+        zp = inputs[2] if len(inputs) > 2 and inputs[2] is not None else None
+        axis = attr.get("axis", 1)
+        if hasattr(x.struct_info, "ndim") and x.struct_info.ndim <= 1 and axis == 1:
+            axis = 0
+        if zp is None:
+            zp = relax.const(0, x.struct_info.dtype)
+        return relax.op.dequantize(x, scale, zp, axis=axis, out_dtype="float32")
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        x, scale = inputs[0], inputs[1]
+        zp = inputs[2] if len(inputs) > 2 and inputs[2] is not None else None
+        axis = attr.get("axis", 1)
+        if hasattr(x.struct_info, "ndim") and x.struct_info.ndim <= 1 and axis == 1:
+            axis = 0
+        if zp is None:
+            zp = relax.const(0, x.struct_info.dtype)
+        return relax.op.dequantize(x, scale, zp, axis=axis, out_dtype="float32")
+
+
+class DynamicQuantizeLinear(OnnxOpConverter):
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        x = inputs[0]
+        x_dtype = x.struct_info.dtype
+        qmin = relax.const(0, x_dtype)
+        qmax = relax.const(255, x_dtype)
+
+        x_max = relax.op.maximum(qmin, relax.op.max(x))
+        x_min = relax.op.minimum(qmin, relax.op.min(x))
+        y_scale = relax.op.divide(relax.op.subtract(x_max, x_min), qmax)
+
+        zp_fp = relax.op.subtract(qmin, relax.op.divide(x_min, y_scale))
+        y_zero_point = relax.op.astype(relax.op.round(relax.op.clip(zp_fp, 0, 255)), "uint8")
+
+        y = relax.op.quantize(x, y_scale, y_zero_point, axis=0, out_dtype="uint8")
+        return relax.Tuple([y, y_scale, y_zero_point])
 
 class MatMul(OnnxOpConverter):
     """Converts an onnx MatMul node into an equivalent Relax expression."""
@@ -4812,6 +4879,10 @@ def _get_convert_map():
         "ConcatFromSequence": ConcatFromSequence,
         "SplitToSequence": SplitToSequence,
         "SequenceAt": SequenceAt,
+        # Quantization
+        "QuantizeLinear": QuantizeLinear,
+        "DequantizeLinear": DequantizeLinear,
+        "DynamicQuantizeLinear": DynamicQuantizeLinear,
     }
```
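With the converters registered in the table above, QDQ models flow through the frontend's usual entry point. A minimal usage sketch (assuming a TVM build that includes this commit; the graph mirrors the new optional-zero-point test below):

```python
from onnx import TensorProto, helper
from tvm.relax.frontend.onnx import from_onnx

# QuantizeLinear with the zero_point input omitted; the converter above
# defaults it to uint8 zero.
node = helper.make_node("QuantizeLinear", ["x", "scale"], ["y"])
graph = helper.make_graph(
    [node],
    "quantize_example",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 5])],
    [helper.make_tensor_value_info("y", TensorProto.UINT8, [2, 5])],
    initializer=[helper.make_tensor("scale", TensorProto.FLOAT, [], [0.2])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])

mod = from_onnx(model)  # Relax IRModule whose main function calls relax.op.quantize
mod.show()
```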
python/tvm/relax/transform/legalize_ops/qdq.py

Lines changed: 49 additions & 8 deletions
```diff
@@ -17,6 +17,7 @@
 # pylint: disable=invalid-name
 """Default legalization function for quantize/dequantize operators."""
 
+from typing import Union
 import tvm
 from tvm import te, tir
 
@@ -35,6 +36,18 @@ def is_const_scalar(x):
     return isinstance(x, tvm.tir.IntImm | tvm.tir.FloatImm)
 
 
+def _is_singleton_qparam(qparam: te.Tensor) -> bool:
+    """Return True if qparam is a tensor with all dimensions equal to 1."""
+    if not isinstance(qparam, te.Tensor):
+        return False
+    if len(qparam.shape) == 0:
+        return True
+    for dim in qparam.shape:
+        if not isinstance(dim, tir.IntImm) or dim.value != 1:
+            return False
+    return True
+
+
 @register_legalize("relax.quantize")
 def _quantize(bb: BlockBuilder, call: Call) -> Expr:
     """
@@ -46,12 +59,26 @@ def _quantize(bb: BlockBuilder, call: Call) -> Expr:
 
     def te_quantize(
         data: te.Tensor,
-        scale: te.Tensor | tir.IntImm | tir.FloatImm,
-        zp: te.Tensor | tir.IntImm | tir.FloatImm,
+        scale: Union[te.Tensor, tir.IntImm, tir.FloatImm],
+        zp: Union[te.Tensor, tir.IntImm, tir.FloatImm],
     ):
+        scale_singleton = _is_singleton_qparam(scale) if isinstance(scale, te.Tensor) else False
+        zp_singleton = _is_singleton_qparam(zp) if isinstance(zp, te.Tensor) else False
+
         def quantize_compute(*indices):
-            scale_value = scale if is_const_scalar(scale) else scale[indices[axis]]
-            zp_value = zp if is_const_scalar(zp) else zp[indices[axis]]
+            if is_const_scalar(scale):
+                scale_value = scale
+            elif scale_singleton:
+                scale_value = scale[(0,) * len(scale.shape)]
+            else:
+                scale_value = scale[indices[axis]]
+
+            if is_const_scalar(zp):
+                zp_value = zp
+            elif zp_singleton:
+                zp_value = zp[(0,) * len(zp.shape)]
+            else:
+                zp_value = zp[indices[axis]]
             scaled = data[indices] / scale_value
             round_val = (te.round(scaled) if "int" in out_dtype else scaled) + zp_value
             return clip_cast(round_val, out_dtype)
@@ -94,12 +121,26 @@ def _dequantize(bb: BlockBuilder, call: Call) -> Expr:
 
     def te_dequantize(
         data: te.Tensor,
-        scale: te.Tensor | tir.IntImm | tir.FloatImm,
-        zp: te.Tensor | tir.IntImm | tir.FloatImm,
+        scale: Union[te.Tensor, tir.IntImm, tir.FloatImm],
+        zp: Union[te.Tensor, tir.IntImm, tir.FloatImm],
     ):
+        scale_singleton = _is_singleton_qparam(scale) if isinstance(scale, te.Tensor) else False
+        zp_singleton = _is_singleton_qparam(zp) if isinstance(zp, te.Tensor) else False
+
         def dequantize_compute(*indices):
-            scale_value = scale if is_const_scalar(scale) else scale[indices[axis]]
-            zp_value = zp if is_const_scalar(zp) else zp[indices[axis]]
+            if is_const_scalar(scale):
+                scale_value = scale
+            elif scale_singleton:
+                scale_value = scale[(0,) * len(scale.shape)]
+            else:
+                scale_value = scale[indices[axis]]
+
+            if is_const_scalar(zp):
+                zp_value = zp
+            elif zp_singleton:
+                zp_value = zp[(0,) * len(zp.shape)]
+            else:
+                zp_value = zp[indices[axis]]
             dtype = "float32" if "float" in data.dtype else "int32"
             sub = te.subtract(data[indices].astype(dtype), zp_value)
             out = te.multiply(sub, scale_value.astype("float32"))
```
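The point of `_is_singleton_qparam` is easiest to see in NumPy terms: a shape-`[1]` scale carries per-tensor data, so indexing it with the running axis index (the old behavior) walks off the end of the tensor, while pinning the access to an all-zeros coordinate is the fix. An illustrative sketch (not code from this commit; in generated TE code the out-of-bounds read would silently produce garbage rather than raise):

```python
import numpy as np

x = np.arange(6, dtype="float32").reshape(2, 3)
scale = np.array([0.5], dtype="float32")  # shape [1]: per-tensor, not per-axis

# Old behavior, per-axis indexing along axis=1, reads scale[0..2]:
# scale[2] -> IndexError in NumPy; undefined values in lowered TE code.

# New behavior for singleton qparams: pin the access to element 0.
idx = (0,) * scale.ndim           # mirrors scale[(0,) * len(scale.shape)]
q = np.round(x / scale[idx])      # every element uses the lone scale value
```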

src/relax/op/tensor/qdq.cc

Lines changed: 44 additions & 10 deletions
```diff
@@ -79,10 +79,14 @@ StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) {
   }
 
   // Check datatype of zero_point param:
-  if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::Float(16)) {
+  if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::UInt(8) &&
+      zp_sinfo->dtype != DataType::Int(16) && zp_sinfo->dtype != DataType::UInt(16) &&
+      zp_sinfo->dtype != DataType::Int(32) && zp_sinfo->dtype != DataType::UInt(32) &&
+      zp_sinfo->dtype != DataType::Float(16)) {
     ctx->ReportFatal(Diagnostic::Error(call)
-                     << "zero_point param datatype should be 'int8' or 'float16', but got "
-                     << zp_sinfo->dtype);
+                     << "zero_point param datatype should be one of "
+                     << "['int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32', 'float16'], "
+                     << "but got " << zp_sinfo->dtype);
   }
 
   // Check that "axis" attribute is not out of range:
@@ -104,9 +108,22 @@ StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) {
     }
   };
 
+  auto is_scalar_or_singleton_vector = [&](const TensorStructInfo& param_sinfo) {
+    if (IsScalarTensor(param_sinfo)) return true;
+    if (param_sinfo->shape.defined() && param_sinfo->shape->IsInstance<ShapeExprNode>()) {
+      const auto& values = param_sinfo->shape.as<ShapeExprNode>()->values;
+      if (!values.empty()) {
+        return std::all_of(values.begin(), values.end(), [&](const PrimExpr& dim) {
+          return ctx->GetAnalyzer()->CanProveEqual(dim, 1);
+        });
+      }
+    }
+    return false;
+  };
+
   // Check size matching of scale/zp params with input shape at dim = attrs->axis.
-  if (!IsScalarTensor(scale_sinfo)) check_param_size(scale_sinfo, input_sinfo, "scale");
-  if (!IsScalarTensor(zp_sinfo)) check_param_size(zp_sinfo, input_sinfo, "zero_point");
+  if (!is_scalar_or_singleton_vector(scale_sinfo)) check_param_size(scale_sinfo, input_sinfo, "scale");
+  if (!is_scalar_or_singleton_vector(zp_sinfo)) check_param_size(zp_sinfo, input_sinfo, "zero_point");
 
   auto output_sinfo = ffi::make_object<TensorStructInfoNode>(*input_sinfo.get());
   output_sinfo->dtype = attrs->out_dtype;
@@ -167,10 +184,14 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx)
   }
 
   // Check datatype of zero_point param:
-  if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::Float(16)) {
+  if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::UInt(8) &&
+      zp_sinfo->dtype != DataType::Int(16) && zp_sinfo->dtype != DataType::UInt(16) &&
+      zp_sinfo->dtype != DataType::Int(32) && zp_sinfo->dtype != DataType::UInt(32) &&
+      zp_sinfo->dtype != DataType::Float(16)) {
     ctx->ReportFatal(Diagnostic::Error(call)
-                     << "zero_point param datatype should be 'int8' or 'float16', but got "
-                     << zp_sinfo->dtype);
+                     << "zero_point param datatype should be one of "
+                     << "['int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32', 'float16'], "
+                     << "but got " << zp_sinfo->dtype);
   }
 
   // Check that "axis" attribute is not out of range:
@@ -192,9 +213,22 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx)
     }
   };
 
+  auto is_scalar_or_singleton_vector = [&](const TensorStructInfo& param_sinfo) {
+    if (IsScalarTensor(param_sinfo)) return true;
+    if (param_sinfo->shape.defined() && param_sinfo->shape->IsInstance<ShapeExprNode>()) {
+      const auto& values = param_sinfo->shape.as<ShapeExprNode>()->values;
+      if (!values.empty()) {
+        return std::all_of(values.begin(), values.end(), [&](const PrimExpr& dim) {
+          return ctx->GetAnalyzer()->CanProveEqual(dim, 1);
+        });
+      }
+    }
+    return false;
+  };
+
   // Check size matching of scale/zp params with input shape at dim = attrs->axis.
-  if (!IsScalarTensor(scale_sinfo)) check_param_size(scale_sinfo, input_sinfo, "scale");
-  if (!IsScalarTensor(zp_sinfo)) check_param_size(zp_sinfo, input_sinfo, "zero_point");
+  if (!is_scalar_or_singleton_vector(scale_sinfo)) check_param_size(scale_sinfo, input_sinfo, "scale");
+  if (!is_scalar_or_singleton_vector(zp_sinfo)) check_param_size(zp_sinfo, input_sinfo, "zero_point");
 
   auto output_sinfo = ffi::make_object<TensorStructInfoNode>(*input_sinfo.get());
   output_sinfo->dtype = attrs->out_dtype;
```
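The widened zero-point dtype whitelist (and the singleton-shape escape hatch) can be exercised straight from the Python op builder; `InferStructInfoQuantize` runs when the `BlockBuilder` normalizes the call. A minimal sketch, assuming a build that includes this commit:

```python
import tvm
from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((2, 3), "float32"))
scale = relax.const(0.1, "float32")  # scalar (0-d) tensor
zp = relax.const(127, "uint8")       # uint8 zero_point: rejected before this change

with bb.function("main", [x]):
    # Struct info inference for relax.quantize happens as this call is emitted.
    y = bb.emit(relax.op.quantize(x, scale, zp, axis=1, out_dtype="uint8"))
    bb.emit_func_output(y)

mod = bb.get()
mod.show()
```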

tests/python/relax/test_frontend_onnx.py

Lines changed: 108 additions & 0 deletions
```diff
@@ -5599,6 +5599,114 @@ def test_split_to_sequence_uneven_last_chunk(axis: int):
     model = helper.make_model(graph, producer_name="test_split_to_sequence_uneven")
     check_correctness(model)
 
+def test_quantizelinear_singleton_qparams_opset10():
+    """QuantizeLinear must treat shape-[1] scale/zp as scalar in opset10."""
+    node = helper.make_node("QuantizeLinear", ["x", "scale", "zero_point"], ["y"])
+    graph = helper.make_graph(
+        [node],
+        "quantizelinear_singleton_qparams_opset10",
+        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [4, 3, 2, 2])],
+        [helper.make_tensor_value_info("y", TensorProto.UINT8, [4, 3, 2, 2])],
+        initializer=[
+            helper.make_tensor("scale", TensorProto.FLOAT, [1], [0.03125]),
+            helper.make_tensor("zero_point", TensorProto.UINT8, [1], [127]),
+        ],
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])
+
+    x = rg.standard_normal((4, 3, 2, 2)).astype("float32")
+    check_correctness(model, inputs={"x": x}, opset=10, check_dtypes=True)
+
+
+def test_dequantizelinear_singleton_qparams_opset10():
+    """DequantizeLinear must treat shape-[1] scale/zp as scalar in opset10."""
+    node = helper.make_node("DequantizeLinear", ["x", "scale", "zero_point"], ["y"])
+    graph = helper.make_graph(
+        [node],
+        "dequantizelinear_singleton_qparams_opset10",
+        [helper.make_tensor_value_info("x", TensorProto.UINT8, [64])],
+        [helper.make_tensor_value_info("y", TensorProto.FLOAT, [64])],
+        initializer=[
+            helper.make_tensor("scale", TensorProto.FLOAT, [1], [0.125]),
+            helper.make_tensor("zero_point", TensorProto.UINT8, [1], [1]),
+        ],
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])
+
+    x = rg.integers(low=0, high=255, size=(64,), dtype=np.uint8)
+    check_correctness(model, inputs={"x": x}, opset=10, check_dtypes=True)
+
+
+def test_quantizelinear_optional_zero_point_opset13():
+    """ONNX allows missing zero_point input; importer should default it to 0 (uint8)."""
+    node = helper.make_node("QuantizeLinear", ["x", "scale"], ["y"])
+    graph = helper.make_graph(
+        [node],
+        "quantizelinear_optional_zero_point_opset13",
+        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 5])],
+        [helper.make_tensor_value_info("y", TensorProto.UINT8, [2, 5])],
+        initializer=[helper.make_tensor("scale", TensorProto.FLOAT, [], [0.2])],
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+
+    x = rg.standard_normal((2, 5)).astype("float32")
+    check_correctness(model, inputs={"x": x}, opset=13, check_dtypes=True)
+
+
+def test_dynamicquantizelinear_opset11():
+    """DynamicQuantizeLinear returns (y, y_scale, y_zero_point) with ORT parity."""
+    node = helper.make_node("DynamicQuantizeLinear", ["x"], ["y", "y_scale", "y_zero_point"])
+    graph = helper.make_graph(
+        [node],
+        "dynamicquantizelinear_opset11",
+        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3, 4])],
+        [
+            helper.make_tensor_value_info("y", TensorProto.UINT8, [2, 3, 4]),
+            helper.make_tensor_value_info("y_scale", TensorProto.FLOAT, []),
+            helper.make_tensor_value_info("y_zero_point", TensorProto.UINT8, []),
+        ],
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)])
+
+    x = rg.standard_normal((2, 3, 4)).astype("float32")
+    check_correctness(model, inputs={"x": x}, opset=11, atol=1e-5, rtol=1e-5, check_dtypes=True)
+
+def test_quantizelinear_default_axis_opset10():
+    """opset10 QuantizeLinear should honor default axis=1 (not hardcode axis=0)."""
+    node = helper.make_node("QuantizeLinear", ["x", "scale", "zero_point"], ["y"])
+    graph = helper.make_graph(
+        [node],
+        "quantizelinear_axis_opset10",
+        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3, 4])],
+        [helper.make_tensor_value_info("y", TensorProto.UINT8, [2, 3, 4])],
+        initializer=[
+            helper.make_tensor("scale", TensorProto.FLOAT, [3], [0.05, 0.1, 0.2]),
+            helper.make_tensor("zero_point", TensorProto.UINT8, [3], [1, 127, 250]),
+        ],
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])
+
+    x = rg.standard_normal((2, 3, 4)).astype("float32")
+    check_correctness(model, inputs={"x": x}, opset=10, check_dtypes=True)
+
+
+def test_dequantizelinear_default_axis_opset10():
+    """opset10 DequantizeLinear should honor default axis=1 (not hardcode axis=0)."""
+    node = helper.make_node("DequantizeLinear", ["x", "scale", "zero_point"], ["y"])
+    graph = helper.make_graph(
+        [node],
+        "dequantizelinear_axis_opset10",
+        [helper.make_tensor_value_info("x", TensorProto.UINT8, [2, 3, 4])],
+        [helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3, 4])],
+        initializer=[
+            helper.make_tensor("scale", TensorProto.FLOAT, [3], [0.05, 0.1, 0.2]),
+            helper.make_tensor("zero_point", TensorProto.UINT8, [3], [1, 127, 250]),
+        ],
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])
+
+    x = rg.integers(low=0, high=255, size=(2, 3, 4), dtype=np.uint8)
+    check_correctness(model, inputs={"x": x}, opset=10, check_dtypes=True)
 
 if __name__ == "__main__":
     tvm.testing.main()
```
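`check_correctness` compares TVM's output against onnxruntime, so the ORT-parity claim in `test_dynamicquantizelinear_opset11` can also be sanity-checked by hand. A standalone sketch (assumes `onnxruntime` is installed; all names are local to this example):

```python
import numpy as np
import onnxruntime as ort
from onnx import TensorProto, helper

node = helper.make_node("DynamicQuantizeLinear", ["x"], ["y", "y_scale", "y_zero_point"])
graph = helper.make_graph(
    [node],
    "dql_reference_check",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3, 4])],
    [
        helper.make_tensor_value_info("y", TensorProto.UINT8, [2, 3, 4]),
        helper.make_tensor_value_info("y_scale", TensorProto.FLOAT, []),
        helper.make_tensor_value_info("y_zero_point", TensorProto.UINT8, []),
    ],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)])

x = np.random.default_rng(0).standard_normal((2, 3, 4)).astype("float32")
sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
y, y_scale, y_zp = sess.run(None, {"x": x})

# The scale spans the zero-inclusive data range over [0, 255], matching the
# converter's x_max/x_min/qmax arithmetic above.
assert np.isclose(y_scale, (max(x.max(), 0.0) - min(x.min(), 0.0)) / 255.0, rtol=1e-5)
```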
