Skip to content

Commit 5cca18b

Browse files
author
Krzysztof Parzyszek
authored
[Frontend] Add ONNX importer for QLinearSoftmax (#14425)
Add a missing importer for quantized softmax.
1 parent 9f6ce7c commit 5cca18b

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5715,6 +5715,26 @@ def _impl_v10(cls, inputs, attr, params):
57155715
return _qnn.op.quantize(out, y_scale, y_zero_point, out_dtype=dtype)
57165716

57175717

5718+
class QLinearSoftmax(OnnxOpConverter):
    """Operator converter for QLinearSoftmax from Microsoft onnxruntime contrib opset.

    The quantized input is dequantized, softmax is evaluated on the
    dequantized tensor, and the result is quantized again using the output
    scale and zero point, keeping the dtype of the original input.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        """Convert a QLinearSoftmax node.

        Parameters
        ----------
        inputs : list
            Read positionally as: ``x``, ``x_scale``, ``x_zero_point``,
            ``y_scale``, ``y_zero_point``.
        attr : dict
            Node attributes. ``axis`` is looked up here and falls back to -1
            when it is not present.
        params : dict
            Forwarded to ``get_scalar`` when resolving the scale and
            zero-point inputs.

        Returns
        -------
        The result of ``_qnn.op.quantize`` applied to the softmax output,
        with ``out_dtype`` set to the dtype inferred for ``x``.
        """
        # `axis` is an optional attribute of the contrib op (schema default
        # -1), so a plain attr["axis"] would raise KeyError for models that
        # leave it out.
        # NOTE(review): the contrib op also carries an `opset` attribute that
        # selects pre-/post-opset-13 Softmax semantics; it is not consulted
        # here -- confirm whether the older "coerce to 2D" behaviour is needed.
        axis = attr.get("axis", -1)

        x = inputs[0]
        x_scale = get_scalar(inputs[1], params)
        x_zero_point = get_scalar(inputs[2], params, "int32")
        y_scale = fold_constant(get_scalar(inputs[3], params))
        y_zero_point = get_scalar(inputs[4], params, "int32")

        # Capture the dtype before `x` is rebound to the dequantized tensor,
        # so the output is quantized back to the same dtype as the input.
        dtype = infer_type(x).checked_type.dtype

        x = _qnn.op.dequantize(x, x_scale, x_zero_point)
        out = _op.nn.softmax(x, axis)
        return _qnn.op.quantize(out, y_scale, y_zero_point, out_dtype=dtype)
5736+
5737+
57185738
class QLinearConcat(OnnxOpConverter):
57195739
"""Operator converter for QLinearConcat from Microsoft onnxruntime contrib opset."""
57205740

@@ -6812,6 +6832,7 @@ def _get_convert_map(opset):
68126832
"QLinearMatMul": QLinearMatMul.get_converter(opset),
68136833
"QLinearMul": QLinearMul.get_converter(opset),
68146834
"QLinearSigmoid": QLinearSigmoid.get_converter(opset),
6835+
"QLinearSoftmax": QLinearSoftmax.get_converter(opset),
68156836
"ConvInteger": ConvInteger.get_converter(opset),
68166837
"QLinearAveragePool": QLinearAveragePool.get_converter(opset),
68176838
"QLinearGlobalAveragePool": QLinearGlobalAveragePool.get_converter(opset),

tests/python/frontend/onnx/test_forward.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6990,6 +6990,31 @@ def verify_qlinearsigmoid(a_shape):
69906990
verify_qlinearsigmoid([])
69916991

69926992

6993+
@tvm.testing.parametrize_targets
def test_qlinearsoftmax(target, dev):
    """Build single-node float Softmax models of several shapes and pass each
    one to quantize_and_verify_with_ort for the given target and device."""

    def check_shape(a_shape):
        # The drawn array is discarded; the call is kept so the global NumPy
        # random state advances exactly as it did before.
        _ = np.random.random(a_shape).astype("float32")

        def float_info(tensor_name):
            # FLOAT value-info with a fresh copy of the shape for each tensor.
            return helper.make_tensor_value_info(tensor_name, TensorProto.FLOAT, list(a_shape))

        graph_inputs = [float_info("a")]
        softmax_node = helper.make_node("Softmax", ["a"], ["B"])
        softmax_graph = helper.make_graph(
            [softmax_node],
            "qlinearsoftmax_test",
            inputs=graph_inputs,
            outputs=[float_info("B")],
        )
        onnx_model = helper.make_model(softmax_graph, producer_name="qlinearsoftmax_test")

        # presumably quantizes the float model and compares TVM against
        # onnxruntime -- helper is defined elsewhere; verify there.
        quantize_and_verify_with_ort(onnx_model, ["a"], [a_shape], target, dev)

    # 2-D, 1-D and 3-D inputs, checked in this order.
    for shape in ([4, 2], [5], [3, 4, 5]):
        check_shape(shape)
7016+
7017+
69937018
@tvm.testing.parametrize_targets("llvm")
69947019
def test_random_bernoulli(target, dev):
69957020
"""test_random_bernoulli"""

0 commit comments

Comments
 (0)