
Commit feaed53

Add qnn.rsqrt op

1 parent: 2a91f0d

File tree: 8 files changed, +364 −0 lines changed

python/tvm/relay/qnn/op/qnn.py

Lines changed: 35 additions & 0 deletions

@@ -656,6 +656,41 @@ def mul(
    )


def rsqrt(x, scale, zero_point, output_scale, output_zero_point):
    """Quantized reciprocal square root.

    Parameters
    ----------
    x : relay.Expr
        The quantized input tensor.

    scale : relay.Expr
        The scale of the quantized expr.

    zero_point : relay.Expr
        The zero point of the quantized expr.

    output_scale : relay.Expr
        The scale of the output quantized expr.

    output_zero_point : relay.Expr
        The zero point of the output quantized expr.

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.rsqrt(
        x,
        scale,
        zero_point,
        output_scale,
        output_zero_point,
    )


def subtract(
    lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point
):
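For reference, a minimal usage sketch of the new binding, with illustrative quantization parameters (the unit test at the end of this commit exercises the same call):

from tvm import relay

x = relay.var("x", shape=(1, 4), dtype="uint8")
y = relay.qnn.op.rsqrt(
    x=x,
    scale=relay.const(0.125, "float32"),         # input scale
    zero_point=relay.const(0, "int32"),          # input zero point
    output_scale=relay.const(0.125, "float32"),  # output scale
    output_zero_point=relay.const(0, "int32"),   # output zero point
)
func = relay.Function([x], y)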

python/tvm/relay/transform/fake_quantization_to_integer.py

Lines changed: 17 additions & 0 deletions

@@ -126,6 +126,22 @@ def global_avgpool2d(expr, type_map):
    return [out, t]


@register_fake_quantization_to_integer("rsqrt")
def rsqrt(expr, type_map):
    """Rewrite a rsqrt op"""
    arg = expr.args[0]
    x_t = type_map[arg]
    out_t = type_map[expr]
    out = relay.qnn.op.rsqrt(
        arg,
        x_t.scale,
        x_t.zero_point,
        out_t.scale,
        out_t.zero_point,
    )
    return [out, x_t]


@register_fake_quantization_to_integer("nn.bias_add")
def bias_add(expr, type_map):
    """Rewrite a bias_add op"""

@@ -394,6 +410,7 @@ def binary(expr, type_map):
            out_t.scale,
            out_t.zero_point,
        )

        return [out, out_t]

    return register_fake_quantization_to_integer(op_name, binary)
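This registration is what lets the FakeQuantizationToInteger pass absorb a float32 rsqrt sitting between a dequantize and a quantize. A minimal sketch of that rewrite, with illustrative shapes and constants:

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 4), dtype="int8")
deq = relay.qnn.op.dequantize(x, relay.const(0.125), relay.const(0))
out = relay.rsqrt(deq)  # the float op targeted by the rewrite above
req = relay.qnn.op.quantize(out, relay.const(0.125), relay.const(0), out_dtype="int8")

mod = tvm.IRModule.from_expr(relay.Function([x], req))
mod = relay.transform.InferType()(mod)
mod = relay.transform.FakeQuantizationToInteger()(mod)
# The dequantize/rsqrt/quantize chain should now be a single qnn.rsqrt call.
print(mod["main"])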

src/relay/qnn/op/op_common.h

Lines changed: 53 additions & 0 deletions

@@ -82,6 +82,59 @@ struct QnnBinaryOpArguments {
  }
};

/*
 * Number of inputs for the Qnn unary operators.
 */
static constexpr int kNumQnnUnaryOpInputs = 5;

/*
 * Number of expected arg types.
 */
static constexpr int kNumQnnUnaryOpArgTypes = 6;

/*
 * \brief Simple struct to organize the inputs to the Qnn
 * unary operators. The main reason to have a struct
 * is to be able to perform the common checks needed at a
 * central location.
 */
struct QnnUnaryOpArguments {
  Expr x;
  Expr scale;
  Expr zero_point;
  Expr output_scale;
  Expr output_zero_point;

  explicit QnnUnaryOpArguments(const Array<Expr>& new_args) {
    ICHECK_EQ(new_args.size(), kNumQnnUnaryOpInputs);
    int idx = 0;
    x = new_args[idx++];
    scale = new_args[idx++];
    zero_point = new_args[idx++];
    output_scale = new_args[idx++];
    output_zero_point = new_args[idx++];
    ICHECK_EQ(idx, kNumQnnUnaryOpInputs);
  }
};

/*
 * \brief Simple structure to hold the input tensor's dtype
 * and shape. This structure allows a common point to do
 * all the validation checks for Qnn unary operators.
 */
struct QnnUnaryOpTensorType {
  DataType dtype;
  Array<PrimExpr> shape;

  explicit QnnUnaryOpTensorType(const Array<tvm::relay::Type>& arg_types, const int32_t arg_idx) {
    ICHECK_EQ(arg_types.size(), kNumQnnUnaryOpArgTypes);
    auto tensor_type = arg_types[arg_idx].as<TensorTypeNode>();
    ICHECK(tensor_type != nullptr);
    dtype = tensor_type->dtype;
    shape = tensor_type->shape;
  }
};

/*
 * \brief Simple structure to hold the input tensor's dtype
 * and shape. This structure allows a common point to do

src/relay/qnn/op/rsqrt.cc

Lines changed: 121 additions & 0 deletions

@@ -0,0 +1,121 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/qnn/op/rsqrt.cc
 * \brief QNN rsqrt operator.
 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>

#include "op_common.h"

namespace tvm {
namespace relay {
namespace qnn {

bool QnnRsqrtRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                 const TypeReporter& reporter) {
  // Expected Types: data, scale, zero_point, output_scale, output_zero_point
  ICHECK_EQ(types.size(), 6);
  const auto* x = types[0].as<TensorTypeNode>();
  if (x == nullptr) return false;
  ICHECK(x->dtype == DataType::Int(8) || x->dtype == DataType::UInt(8))
      << "Expected quantized rsqrt type(int8, uint8) for input but was " << x->dtype;

  // Check the types of scale and zero points.
  for (size_t i = 1; i < 5; ++i) {
    if (types[i].as<IncompleteTypeNode>()) {
      return false;
    }
  }
  ICHECK(IsScalarType(types[1], DataType::Float(32)));  // scale
  ICHECK(IsScalarType(types[2], DataType::Int(32)));    // zero_point
  ICHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
  ICHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point

  // Assign types for scale and zero points.
  reporter->Assign(types[1], TensorType({}, DataType::Float(32)));  // scale
  reporter->Assign(types[2], TensorType({}, DataType::Int(32)));    // zero_point
  reporter->Assign(types[3], TensorType({}, DataType::Float(32)));  // output_scale
  reporter->Assign(types[4], TensorType({}, DataType::Int(32)));    // output_zero_point

  // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
  // IdentityRel infer type function.
  Array<Type> tensor_types = {types[0], types[5]};
  return IdentityRel(tensor_types, 2, attrs, reporter);
}

// Positional relay function to create quantized rsqrt operator used by frontend FFI.
Expr MakeQuantizedRsqrt(Expr x, Expr scale, Expr zero_point, Expr output_scale,
                        Expr output_zero_point) {
  static const Op& op = Op::Get("qnn.rsqrt");
  return Call(op, {x, scale, zero_point, output_scale, output_zero_point}, Attrs(), {});
}

/*
 * \brief Canonicalizes the QNN rsqrt op.
 * \param attrs The empty attribute.
 * \param new_args The new mutated args to the call node.
 * \param arg_types The types of input and output.
 * \return The sequence of Relay ops for the rsqrt op.
 */
Expr QnnRsqrtCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                          const Array<tvm::relay::Type>& arg_types) {
  // Get the args.
  QnnUnaryOpArguments args(new_args);

  // Get the input dtype and shape.
  QnnUnaryOpTensorType input_type(arg_types, 0);

  // Get the types for dequantize/quantize.
  Array<tvm::relay::Type> types;
  for (size_t i = 1; i < 5; ++i) {
    types.push_back(arg_types[i]);
  }

  // Dequantize input.
  auto dequantized_arg = Dequantize(args.x, args.scale, args.zero_point, types, -1);

  // Compute Rsqrt(Q_x')
  auto output = Rsqrt(dequantized_arg);

  // Quantize output.
  return Quantize(output, args.output_scale, args.output_zero_point, input_type.dtype, types, -1);
}

RELAY_REGISTER_OP("qnn.rsqrt")
    .describe("Elementwise rsqrt for quantized tensors.")
    .set_num_inputs(5)
    .add_argument("data", "Quantized Tensor", "The input data.")
    .add_argument("scale", "Tensor", "The quantization scale of the input tensor.")
    .add_argument("zero_point", "Tensor", "The quantization zero_point of the input tensor.")
    .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
    .add_argument("output_zero_point", "Tensor",
                  "The quantization zero_point of the output tensor.")
    .set_support_level(11)
    .add_type_rel("QRsqrt", QnnRsqrtRel)
    .set_attr<TNonComputational>("TNonComputational", true)
    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnRsqrtCanonicalize);

TVM_REGISTER_GLOBAL("relay.qnn.op._make.rsqrt").set_body_typed(MakeQuantizedRsqrt);

}  // namespace qnn
}  // namespace relay
}  // namespace tvm
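The canonicalization lowers qnn.rsqrt to dequantize, then rsqrt, then quantize, so numerically the op behaves like the NumPy sketch below (the function name is illustrative, uint8 output is assumed, and relay's quantize rounding may differ from np.around in edge cases):

import numpy as np

def qnn_rsqrt_reference(x_q, scale, zero_point, output_scale, output_zero_point):
    # Dequantize to float32, as QnnRsqrtCanonicalize does.
    x = scale * (x_q.astype("float32") - zero_point)
    # Elementwise reciprocal square root in float.
    y = 1.0 / np.sqrt(x)
    # Quantize back, saturating to the uint8 range.
    y_q = np.around(y / output_scale + output_zero_point)
    return np.clip(y_q, 0, 255).astype("uint8")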

src/relay/qnn/utils.h

Lines changed: 27 additions & 0 deletions

@@ -109,6 +109,33 @@ static inline Expr Requantize(const Expr& data, const Array<IndexExpr>& input_sh
                              attrs.operator->(), input_shape, attrs->out_dtype);
}

Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
                     const Expr& input_zero_point, const Array<tvm::relay::Type>& types,
                     const DequantizeAttrs* attrs);

static inline Expr Dequantize(const Expr& data, const Expr& input_scale,
                              const Expr& input_zero_point, const Array<tvm::relay::Type>& types,
                              const int& axis = -1) {
  auto attrs = make_object<DequantizeAttrs>();
  attrs->axis = std::move(axis);

  return DequantizeLower(data, input_scale, input_zero_point, types, attrs.operator->());
}

Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale,
                   const Expr& output_zero_point, const Array<tvm::relay::Type>& types,
                   const QuantizeAttrs* attrs);

static inline Expr Quantize(const Expr& data, const Expr& output_scale,
                            const Expr& output_zero_point, const DataType& out_dtype,
                            const Array<tvm::relay::Type>& types, const int& axis = -1) {
  auto attrs = make_object<QuantizeAttrs>();
  attrs->axis = std::move(axis);
  attrs->out_dtype = std::move(out_dtype);

  return QuantizeLower(data, output_scale, output_zero_point, types, attrs.operator->());
}

static inline int64_t get_const_int(const tvm::PrimExpr& x) {
  auto* value_ptr = tir::as_const_int(x);
  ICHECK(value_ptr) << "Expr is not a constant int";
src/relay/transforms/pattern_utils.h

Lines changed: 5 additions & 0 deletions

@@ -550,6 +550,11 @@ inline Expr Sqrt(Expr x) {
  return Call(op, {x}, Attrs(), {});
}

inline Expr Rsqrt(Expr x) {
  static const Op& op = Op::Get("rsqrt");
  return Call(op, {x}, Attrs(), {});
}

inline Expr Relu(Expr x) {
  static const Op& op = Op::Get("nn.relu");
  return Call(op, {x}, Attrs(), {});
Lines changed: 93 additions & 0 deletions

@@ -0,0 +1,93 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
import numpy as np
from tvm import relay


def dequantize(data, scale, zp):
    return scale * (np.asarray(data) - zp)


def generate_golden_output(dequantized_x, output_scale, output_zero_point):
    rsqrt = 1 / np.sqrt(dequantized_x)
    output = np.around(rsqrt / output_scale + output_zero_point)

    q_min = np.iinfo(np.uint8).min
    q_max = np.iinfo(np.uint8).max
    return np.clip(output, q_min, q_max)


def test_saturation():
    # Same params
    data_dtype = "uint8"
    scale = output_scale = 0.125
    zero_point = output_zero_point = 0

    x = relay.var("x", shape=(1, 4), dtype=data_dtype)
    y = relay.qnn.op.rsqrt(
        x=x,
        scale=relay.const(scale, "float32"),
        zero_point=relay.const(zero_point, "int32"),
        output_scale=relay.const(output_scale, "float32"),
        output_zero_point=relay.const(output_zero_point, "int32"),
    )

    func = relay.Function([x], y)
    mod = tvm.IRModule.from_expr(func)
    mod = relay.transform.InferType()(mod)
    mod = relay.qnn.transform.CanonicalizeOps()(mod)
    func = mod["main"]

    x_data = np.array((255, 133, 0, 9)).reshape((1, 4))
    x_dequantized = dequantize(x_data, scale, zero_point)
    golden_output = generate_golden_output(x_dequantized, output_scale, output_zero_point)

    op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)(x_data)

    np.testing.assert_equal(op_res.numpy(), np.uint8(golden_output))

    # Different scale
    scale = 0.125
    output_scale = 0.25

    y = relay.qnn.op.rsqrt(
        x=x,
        scale=relay.const(scale, "float32"),
        zero_point=relay.const(zero_point, "int32"),
        output_scale=relay.const(output_scale, "float32"),
        output_zero_point=relay.const(output_zero_point, "int32"),
    )

    func = relay.Function([x], y)
    mod = tvm.IRModule.from_expr(func)
    mod = relay.transform.InferType()(mod)
    mod = relay.qnn.transform.CanonicalizeOps()(mod)
    func = mod["main"]

    x_data = np.array((255, 133, 0, 9)).reshape((1, 4))
    x_dequantized = dequantize(x_data, scale, zero_point)
    golden_output = generate_golden_output(x_dequantized, output_scale, output_zero_point)

    op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)(x_data)

    np.testing.assert_equal(op_res.numpy(), golden_output)


if __name__ == "__main__":
    test_saturation()
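Working one element of the same-params case by hand (scale = output_scale = 0.125, zero points 0) shows where the expected values come from:

x_q = 9  ->  x = 0.125 * 9 = 1.125;  1/sqrt(1.125) ~= 0.9428;  0.9428 / 0.125 ~= 7.54, which rounds to 8
x_q = 0  ->  x = 0.0;  1/sqrt(0) = inf, which clips to 255; this saturating behavior is what the test name refers to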
