
Commit 8418026

[FQ2I] Add leaky relu to FQ2I (#10378)

Authored by Margaret Qian (margaretqian)

* add leaky relu op + passing unit test
* passing test
* format
* clean up
* leaky relu qnn op
* wip
* qnn op
* add comment
* lint

Co-authored-by: Margaret Qian <[email protected]>
1 parent 7d5ef84 · commit 8418026

File tree

5 files changed: +244, -0 lines changed

python/tvm/relay/qnn/op/qnn.py

Lines changed: 27 additions & 0 deletions

@@ -1050,3 +1050,30 @@ def batch_matmul(x, y, x_zero_point, y_zero_point, x_scale, y_scale, out_dtype="
 # register fuse pattern for qnn ops
 reg.register_pattern("qnn.quantize", OpPattern.OPAQUE)
 reg.register_pattern("qnn.dequantize", OpPattern.OPAQUE)
+
+
+def leaky_relu(x, alpha, scale, zero_point):
+    """Quantized leaky relu.
+
+    Parameters
+    ----------
+    x : relay.Expr
+        The quantized input tensor.
+    alpha : float
+        The alpha value.
+    scale : relay.Expr
+        The scale of the quantized expr.
+    zero_point : relay.Expr
+        The zero point of the quantized expr.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.leaky_relu(
+        x,
+        alpha,
+        scale,
+        zero_point,
+    )
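
For a sense of how the new API is used, here is a minimal sketch mirroring the unit test added later in this commit (the scale, zero point, and alpha values are just the test's illustrative constants):

import tvm
from tvm import relay

# A quantized uint8 input with scale 0.125 and zero point 60.
x = relay.var("x", shape=(1, 4), dtype="uint8")
y = relay.qnn.op.leaky_relu(
    x=x,
    alpha=0.9,
    scale=relay.const(0.125, "float32"),
    zero_point=relay.const(60, "int32"),
)

mod = tvm.IRModule.from_expr(relay.Function([x], y))
mod = relay.transform.InferType()(mod)
# CanonicalizeOps lowers qnn.leaky_relu to integer-only Relay ops
# (the lowering lives in src/relay/qnn/op/leaky_relu.cc below).
mod = relay.qnn.transform.CanonicalizeOps()(mod)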

python/tvm/relay/transform/fake_quantization_to_integer.py

Lines changed: 10 additions & 0 deletions

@@ -346,6 +346,16 @@ def relu(expr, type_map):
     return [relay.op.maximum(arg, fold_constant(zero)), t]


+@register_fake_quantization_to_integer("nn.leaky_relu")
+def leaky_relu(expr, type_map):
+    """Rewrite a leaky relu op"""
+    arg = expr.args[0]
+    t = type_map[arg]
+    alpha = expr.attrs.alpha
+    output = relay.qnn.op.leaky_relu(arg, alpha, t.scale, t.zero_point)
+    return [output, t]
+
+
 @register_fake_quantization_to_integer("nn.pad")
 def pad(expr, type_map):
     """Rewrite an nn.pad op"""
src/relay/qnn/op/leaky_relu.cc

Lines changed: 130 additions & 0 deletions

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/qnn/op/leaky_relu.cc
 * \brief QNN leaky relu operator.
 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>

#include "op_common.h"

namespace tvm {
namespace relay {
namespace qnn {

bool QnnLeakyReluRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
  // Expected Types: data, scale, zero_point, out_type
  ICHECK_EQ(types.size(), 4);
  const auto* x = types[0].as<TensorTypeNode>();
  if (x == nullptr) return false;
  ICHECK(x->dtype == DataType::Int(8) || x->dtype == DataType::UInt(8))
      << "Expected quantized leaky_relu type(int8, uint8) for input but was " << x->dtype;
  const auto* param = attrs.as<LeakyReluAttrs>();
  ICHECK(param != nullptr) << "LeakyReluAttrs cannot be nullptr.";

  // Check the types of scale and zero points.
  for (size_t i = 1; i < 3; ++i) {
    if (types[i].as<IncompleteTypeNode>()) {
      return false;
    }
  }

  ICHECK(IsScalarType(types[1], DataType::Float(32)));  // scale
  ICHECK(IsScalarType(types[2], DataType::Int(32)));    // zero_point

  // Assign types for scale and zero points.
  reporter->Assign(types[1], TensorType({}, DataType::Float(32)));  // scale
  reporter->Assign(types[2], TensorType({}, DataType::Int(32)));    // zero_point

  // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
  // IdentityRel infer type function.
  Array<Type> tensor_types = {types[0], types[3]};
  return IdentityRel(tensor_types, 2, attrs, reporter);
}

// Positional relay function to create quantized leaky relu operator used by frontend FFI.
Expr MakeQuantizedLeakyRelu(Expr x, double alpha, Expr scale, Expr zero_point) {
  auto attrs = make_object<LeakyReluAttrs>();
  attrs->alpha = alpha;
  static const Op& op = Op::Get("qnn.leaky_relu");
  return Call(op, {x, scale, zero_point}, Attrs(attrs), {});
}

/*
 * \brief Canonicalizes the QNN leaky relu op.
 * \param attrs The leaky relu attributes (alpha).
 * \param new_args The new mutated args to the call node.
 * \param arg_types The types of input and output.
 * \return The sequence of Relay ops for leaky relu op.
 */
Expr QnnLeakyReluCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                              const Array<tvm::relay::Type>& arg_types) {
  // We rely on fixed point arithmetic to preserve the precision of multiplication
  // by a small alpha value < 1.
  //
  // We assume the same scale and zero point for alpha and the input tensor.
  // Let T = s(q_t - z) where q_t is the input arg[0].
  // Then, the quantized value of alpha * T is:
  //   q(a * T, s, z) = [(a * T) / s] + z = a * s(q_t - z) / s + z = a * (q_t - z) + z
  //                  = a * q_t + (1 - a) * z
  //
  // We return the quantized value of alpha * T for all values q_t < input_zero_point.

  ICHECK_EQ(new_args.size(), 3);
  Expr quantized_data = Cast(new_args[0], DataType::Int(32));
  Expr input_zero_point = Cast(new_args[2], DataType::Int(32));

  const auto* q_attrs = attrs.as<LeakyReluAttrs>();
  auto alpha = q_attrs->alpha;

  int32_t fixed_point_multiplier, shift;
  std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(alpha);
  auto prod = FixedPointMultiply(quantized_data, fixed_point_multiplier, shift);

  int32_t fixed_point_multiplier_z, shift_z;
  std::tie(fixed_point_multiplier_z, shift_z) = GetFixedPointMultiplierShift(1 - alpha);
  auto scaled_z = FixedPointMultiply(input_zero_point, fixed_point_multiplier_z, shift_z);

  auto add = Add(prod, scaled_z);
  auto output = Where(Less(quantized_data, input_zero_point), add, quantized_data);

  const auto* input_type = arg_types[0].as<TensorTypeNode>();
  return ConvertDtype(output, input_type->dtype);
}

RELAY_REGISTER_OP("qnn.leaky_relu")
    .describe("Leaky relu for quantized tensors.")
    .set_attrs_type<LeakyReluAttrs>()
    .set_num_inputs(3)
    .add_argument("data", "Quantized Tensor", "The input data.")
    .add_argument("scale", "Tensor", "The quantization scale of the input tensor.")
    .add_argument("zero_point", "Tensor", "The quantization zero_point of the input tensor.")
    .set_support_level(11)
    .add_type_rel("QLeakyRelu", QnnLeakyReluRel)
    .set_attr<TNonComputational>("TNonComputational", true)
    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnLeakyReluCanonicalize);

TVM_REGISTER_GLOBAL("relay.qnn.op._make.leaky_relu").set_body_typed(MakeQuantizedLeakyRelu);

}  // namespace qnn
}  // namespace relay
}  // namespace tvm
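
The identity derived in the canonicalization comment, q(a * T, s, z) = a * q_t + (1 - a) * z, is easy to sanity-check numerically. A quick NumPy sketch (not part of the commit; plain floating point stands in for the two FixedPointMultiply steps):

import numpy as np

s, z, a = 0.125, 60, 0.9                 # scale, zero point, alpha
q_t = np.array([255, 133, 0, 9])         # quantized input values

T = s * (q_t - z)                        # dequantized input
lhs = np.around(a * T / s + z)           # quantize alpha * T directly
rhs = np.around(a * q_t + (1 - a) * z)   # the simplified integer-domain form
np.testing.assert_array_equal(lhs, rhs)  # identical, as derived

# The op applies the scaled branch only where q_t < z; otherwise q_t passes through.
out = np.where(q_t < z, rhs, q_t)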
tests/python/relay/test_op_qnn_leaky_relu.py

Lines changed: 65 additions & 0 deletions
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
import numpy as np
from tvm import relay


def dequantize(data, scale, zp):
    return scale * (np.asarray(data) - zp)


def generate_golden_output(x_data, dequantized_x, alpha, scale, zero_point):
    prod = np.multiply(dequantized_x, alpha)
    prod = np.around(prod / scale + zero_point)

    output = np.where(x_data < zero_point, prod, x_data)
    return output


def test_qnn_leaky_relu():
    data_dtype = "uint8"
    scale = 0.125
    zero_point = 60
    alpha = 0.9

    x = relay.var("x", shape=(1, 4), dtype=data_dtype)
    y = relay.qnn.op.leaky_relu(
        x=x,
        alpha=alpha,
        scale=relay.const(scale, "float32"),
        zero_point=relay.const(zero_point, "int32"),
    )

    func = relay.Function([x], y)
    mod = tvm.IRModule.from_expr(func)
    mod = relay.transform.InferType()(mod)
    mod = relay.qnn.transform.CanonicalizeOps()(mod)
    func = mod["main"]

    x_data = np.array((255, 133, 0, 9)).reshape((1, 4))
    x_dequantized = dequantize(x_data, scale, zero_point)
    golden_output = generate_golden_output(x_data, x_dequantized, alpha, scale, zero_point)

    op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)(x_data)

    np.testing.assert_equal(op_res.numpy(), golden_output)


if __name__ == "__main__":
    test_qnn_leaky_relu()
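
Working the numbers by hand for this test: with zero point 60, the inputs 255 and 133 sit above the zero point and pass through unchanged, while 0 becomes around(0.9 * 0 + 0.1 * 60) = 6 and 9 becomes around(0.9 * 9 + 0.1 * 60) = 14, so the golden output is [255, 133, 6, 14].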

tests/python/relay/test_pass_fake_quantization_to_integer.py

Lines changed: 12 additions & 0 deletions

@@ -551,6 +551,18 @@ def test_fake_quantize_relu_per_channel():
     compare_fq_to_int(op, [x_np])


+def test_fake_quantize_leaky_relu():
+    x = relay.var("x", shape=[1, 3, 224, 224], dtype="uint8")
+
+    x = relay.qnn.op.dequantize(x, relay.const(2.0), relay.const(114))
+    op = relay.op.nn.leaky_relu(x, 0.1)
+    op = relay.qnn.op.quantize(op, relay.const(2.0), relay.const(114), out_dtype="uint8")
+
+    x_np = np.random.randint(0, 255, size=[1, 3, 224, 224], dtype="uint8")
+
+    compare_fq_to_int(op, [x_np], True)
+
+
 @pytest.mark.parametrize(
     "operator",
     [relay.op.add, relay.op.multiply, relay.op.subtract, relay.op.minimum, relay.op.maximum],
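
One note on the assertion: the trailing True passed to compare_fq_to_int switches the helper into its rounding-tolerance mode, accepting off-by-one differences between the integer qnn path and the float reference; that tolerance matters here because the leaky relu lowering rounds through two fixed-point multiplies.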
