Commit 7dfb0c1

sfvaroglu authored and Lucien0 committed
[QNN] Add per-channel quantization to add/subtract/multiply (apache#10718)
* Add per-channel quantization to QNN add/subtract/multiply
* Add feedback
* Add feedback - round 2
* Fix for arm test
* Add params to the test
* Try again
* Try int
* Move lhs_axis and rhs_axis
* Add as an attribute
* Add quotes
1 parent 32e155b commit 7dfb0c1

File tree

11 files changed: +374 -52 lines changed


include/tvm/relay/qnn/attrs.h

Lines changed: 19 additions & 0 deletions
@@ -106,6 +106,25 @@ struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
   }
 };
 
+/*! \brief Attributes for broadcast operators */
+struct BroadcastAttrs : public tvm::AttrsNode<BroadcastAttrs> {
+  int lhs_axis;
+  int rhs_axis;
+
+  TVM_DECLARE_ATTRS(BroadcastAttrs, "relay.attrs.BroadcastAttrs") {
+    TVM_ATTR_FIELD(lhs_axis)
+        .describe(
+            "The channel axis for channel-wise broadcast. Default value is -1, "
+            "which corresponds to the last axis.")
+        .set_default(-1);
+    TVM_ATTR_FIELD(rhs_axis)
+        .describe(
+            "The channel axis for channel-wise broadcast. Default value is -1, "
+            "which corresponds to the last axis.")
+        .set_default(-1);
+  }
+};
+
 }  // namespace qnn
 }  // namespace relay
 }  // namespace tvm
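Both fields default to -1, i.e. the last axis. As a minimal sketch (not part of the commit), this is how a negative axis resolves against a tensor's rank; the same arithmetic appears in the mul.cc lowering below:

```python
def normalize_axis(axis: int, rank: int) -> int:
    # A negative axis counts from the end: -1 is the last dimension.
    return axis + rank if axis < 0 else axis

assert normalize_axis(-1, 4) == 3  # NHWC: -1 picks the trailing channel axis
assert normalize_axis(1, 4) == 1   # NCHW: the channel axis given explicitly
```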

python/tvm/relay/op/op_attrs.py

Lines changed: 5 additions & 0 deletions
@@ -494,6 +494,11 @@ class OneHotAttrs(Attrs):
     """Attributes used in one_hot operators"""
 
 
+@tvm._ffi.register_object("relay.attrs.BroadcastAttrs")
+class BroadcastAttrs(Attrs):
+    """Attributes used in broadcast operators"""
+
+
 @tvm._ffi.register_object("relay.attrs.QuantizeAttrs")
 class QuantizeAttrs(Attrs):
     """Attributes used in quantize operators"""

python/tvm/relay/qnn/op/qnn.py

Lines changed: 60 additions & 3 deletions
@@ -593,7 +593,16 @@ def conv2d_transpose(
 
 
 def add(
-    lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point
+    lhs,
+    rhs,
+    lhs_scale,
+    lhs_zero_point,
+    rhs_scale,
+    rhs_zero_point,
+    output_scale,
+    output_zero_point,
+    lhs_axis=-1,
+    rhs_axis=-1,
 ):
     """Quantized addition with numpy-style broadcasting.
 
@@ -623,6 +632,14 @@ def add(
     output_zero_point: relay.Expr
         The zero point of output quantized expr.
 
+    lhs_axis: int
+        The channel axis for lhs quantization. Default value is -1, which corresponds
+        to the last axis.
+
+    rhs_axis: int
+        The channel axis for rhs quantization. Default value is -1, which corresponds
+        to the last axis.
+
     Returns
     -------
     result : relay.Expr
@@ -638,6 +655,8 @@ def add(
        rhs_zero_point,
        output_scale,
        output_zero_point,
+       lhs_axis,
+       rhs_axis,
    )
 
 
@@ -702,7 +721,16 @@ def dense(
 
 
 def mul(
-    lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point
+    lhs,
+    rhs,
+    lhs_scale,
+    lhs_zero_point,
+    rhs_scale,
+    rhs_zero_point,
+    output_scale,
+    output_zero_point,
+    lhs_axis=-1,
+    rhs_axis=-1,
 ):
     """Quantized multiplication with numpy-style broadcasting.
 
@@ -732,6 +760,14 @@ def mul(
     output_zero_point: relay.Expr
         The zero point of output quantized expr.
 
+    lhs_axis: int
+        The channel axis for lhs quantization. Default value is -1, which corresponds
+        to the last axis.
+
+    rhs_axis: int
+        The channel axis for rhs quantization. Default value is -1, which corresponds
+        to the last axis.
+
     Returns
     -------
     result : relay.Expr
@@ -747,6 +783,8 @@ def mul(
        rhs_zero_point,
       output_scale,
        output_zero_point,
+       lhs_axis,
+       rhs_axis,
    )
 
 
@@ -961,7 +999,16 @@ def sigmoid(x, scale, zero_point, output_scale, output_zero_point):
 
 
 def subtract(
-    lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point
+    lhs,
+    rhs,
+    lhs_scale,
+    lhs_zero_point,
+    rhs_scale,
+    rhs_zero_point,
+    output_scale,
+    output_zero_point,
+    lhs_axis=-1,
+    rhs_axis=-1,
 ):
     """Quantized subtraction with numpy-style broadcasting.
 
@@ -991,6 +1038,14 @@ def subtract(
     output_zero_point: relay.Expr
         The zero point of output quantized expr.
 
+    lhs_axis: int
+        The channel axis for lhs quantization. Default value is -1, which corresponds
+        to the last axis.
+
+    rhs_axis: int
+        The channel axis for rhs quantization. Default value is -1, which corresponds
+        to the last axis.
+
     Returns
     -------
     result : relay.Expr
@@ -1006,6 +1061,8 @@ def subtract(
        rhs_zero_point,
        output_scale,
        output_zero_point,
+       lhs_axis,
+       rhs_axis,
    )
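A minimal construction sketch for the extended signature (shapes, scales, and the NCHW layout are illustrative, not taken from the commit):

```python
import numpy as np
from tvm import relay

# Two NCHW int8 tensors, each quantized per-channel along axis 1 (4 channels).
lhs = relay.var("lhs", shape=(1, 4, 8, 8), dtype="int8")
rhs = relay.var("rhs", shape=(1, 4, 8, 8), dtype="int8")

# One scale and zero point per channel for each input.
lhs_scale = relay.const(np.array([0.1, 0.2, 0.3, 0.4], dtype="float32"))
rhs_scale = relay.const(np.array([0.2, 0.2, 0.1, 0.1], dtype="float32"))
lhs_zp = relay.const(np.zeros(4, dtype="int32"))
rhs_zp = relay.const(np.zeros(4, dtype="int32"))

# Output quantization stays per-tensor here: scalar scale and zero point.
out = relay.qnn.op.add(
    lhs,
    rhs,
    lhs_scale,
    lhs_zp,
    rhs_scale,
    rhs_zp,
    relay.const(0.25),  # output_scale
    relay.const(0),     # output_zero_point
    lhs_axis=1,
    rhs_axis=1,
)
```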

python/tvm/relay/transform/fake_quantization_to_integer.py

Lines changed: 2 additions & 0 deletions
@@ -451,6 +451,8 @@ def binary(expr, type_map):
             right_t.zero_point,
             out_t.scale,
             out_t.zero_point,
+            left_t.axis,
+            right_t.axis,
         )
 
         return [out, out_t]
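With the axes forwarded, fake-quantized graphs whose operands carry per-channel affine types can lower to the qnn binary ops above. A sketch of a typical invocation of the pass:

```python
import tvm
from tvm import relay

def run_fq2i(mod: tvm.IRModule) -> tvm.IRModule:
    # FakeQuantizationToInteger rewrites dequantize -> op -> quantize regions
    # into integer qnn ops (qnn.add / qnn.subtract / qnn.mul among them).
    seq = tvm.transform.Sequential(
        [relay.transform.InferType(), relay.transform.FakeQuantizationToInteger()]
    )
    return seq(mod)
```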

src/relay/qnn/op/add.cc

Lines changed: 8 additions & 2 deletions
@@ -45,6 +45,12 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   // Get the input dtype and shape.
   QnnBinaryOpTensorType input_type(arg_types, 0);
 
+  const auto* broadcast_attrs = attrs.as<BroadcastAttrs>();
+  ICHECK(broadcast_attrs != nullptr);
+
+  auto lhs_axis = broadcast_attrs->lhs_axis;
+  auto rhs_axis = broadcast_attrs->rhs_axis;
+
   // FIXME (anijain2305) - The lowering can be further optimized. Instead of inserting requantize in
   // the start, we can insert requantize at the end if both input tensors have same qnn params. In
   // that case, we can first add the tensors, subtract the zero point, and requantize at the end.
@@ -68,11 +74,11 @@
   // Requantize LHS if necessary. Computes Q_a'
   auto requantized_lhs =
       RequantizeOrUpcast(args.lhs, args.lhs_scale, args.lhs_zero_point, args.output_scale,
-                         args.output_zero_point, input_type.shape);
+                         args.output_zero_point, input_type.shape, lhs_axis);
   // Requantize RHS if necessary. Computes Q_b'
   auto requantized_rhs =
       RequantizeOrUpcast(args.rhs, args.rhs_scale, args.rhs_zero_point, args.output_scale,
-                         args.output_zero_point, input_type.shape);
+                         args.output_zero_point, input_type.shape, rhs_axis);
   // Computes Q_a' + Q_b'
   auto output = Add(requantized_lhs, requantized_rhs);
 
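Written out, the algebra behind this requantize-first lowering is as follows; each requantized operand carries one copy of zp_c, and the unchanged remainder of the function subtracts the extra copy. With per-channel inputs, S_a, zp_a, S_b, zp_b are vectors broadcast along lhs_axis/rhs_axis:

```latex
% From S_c (Q_c - zp_c) = S_a (Q_a - zp_a) + S_b (Q_b - zp_b),
% requantize each operand into the output quantization space:
\[
Q_a' = \mathrm{requantize}\bigl(Q_a;\ S_a, zp_a \to S_c, zp_c\bigr), \qquad
Q_b' = \mathrm{requantize}\bigl(Q_b;\ S_b, zp_b \to S_c, zp_c\bigr)
\]
% Each Q' includes one copy of zp_c, so one copy is removed after the add:
\[
Q_c = Q_a' + Q_b' - zp_c
\]
```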

src/relay/qnn/op/mul.cc

Lines changed: 102 additions & 36 deletions
@@ -42,6 +42,8 @@ namespace qnn {
  */
 Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                         const Array<tvm::relay::Type>& arg_types) {
+  Expr output;
+
   // Get the attrs.
   QnnBinaryOpArguments args(new_args);
 
@@ -51,44 +53,108 @@ Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   const auto int32_dtype = DataType::Int(32);
   const auto float32_dtype = DataType::Float(32);
 
-  /*
-  A tensor multiplication c = a * b can be written in terms of respective
-  quantized tensors, scales and zero points as
-  S_c * (Q_c - zp_c) = S_a * (Q_a - zp_a) * S_b * (Q_b - zp_b).
-
-  We can consider the product (Q_a - zp_a) * (Q_b - zp_b) as a different
-  quantized tensor of c, Q', with corresponding scale S' = S_a * S_b and zp' =
-  0. The quantized multiplication then becomes
-  Q_c = S'/S_c Q' + z_c,
-  which is essentially a requantization of tensor Q' into tensor Q_c.
-  */
-
-  auto lhs_shifted = Cast(args.lhs, int32_dtype);
-  auto rhs_shifted = Cast(args.rhs, int32_dtype);
-
-  auto zero_scalar = MakeConstantScalar(int32_dtype, 0);
-  if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) {
-    lhs_shifted = Subtract(lhs_shifted, args.lhs_zero_point);
+  const auto* broadcast_attrs = attrs.as<BroadcastAttrs>();
+  ICHECK(broadcast_attrs != nullptr);
+
+  auto lhs_axis = broadcast_attrs->lhs_axis;
+  auto rhs_axis = broadcast_attrs->rhs_axis;
+
+  if (IsConstScalar(args.lhs_scale) && IsConstScalar(args.rhs_scale)) {
+    /*
+    This is per-tensor quantized multiply.
+
+    A tensor multiplication c = a * b can be written in terms of respective
+    quantized tensors, scales and zero points as
+    S_c * (Q_c - zp_c) = S_a * (Q_a - zp_a) * S_b * (Q_b - zp_b).
+
+    We can consider the product (Q_a - zp_a) * (Q_b - zp_b) as a different
+    quantized tensor of c, Q', with corresponding scale S' = S_a * S_b and zp' =
+    0. The quantized multiplication then becomes
+    Q_c = S'/S_c Q' + z_c,
+    which is essentially a requantization of tensor Q' into tensor Q_c.
+    */
+
+    auto lhs_shifted = Cast(args.lhs, int32_dtype);
+    auto rhs_shifted = Cast(args.rhs, int32_dtype);
+
+    auto zero_scalar = MakeConstantScalar(int32_dtype, 0);
+    if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) {
+      lhs_shifted = Subtract(lhs_shifted, args.lhs_zero_point);
+    }
+
+    if (!IsEqualScalar(args.rhs_zero_point, zero_scalar)) {
+      rhs_shifted = Subtract(rhs_shifted, args.rhs_zero_point);
+    }
+
+    // Create a new tensor Q'
+    output = Multiply(lhs_shifted, rhs_shifted);
+
+    // Get the adjusted new scale and zero points.
+    float lhs_scale_float = GetScalarFromConstant<float>(args.lhs_scale);
+    float rhs_scale_float = GetScalarFromConstant<float>(args.rhs_scale);
+    float new_scale_float = lhs_scale_float * rhs_scale_float;
+    auto new_input_scale = MakeConstantScalar(float32_dtype, new_scale_float);
+    auto new_input_zero_point = zero_scalar;
+
+    // Requantize to get Q_c
+    output = Requantize(output, input_type.shape, new_input_scale, new_input_zero_point,
+                        args.output_scale, args.output_zero_point, input_type.dtype);
+  } else if (lhs_axis == rhs_axis) {
+    /*
+    This is per-channel quantized multiply, assuming lhs_axis and rhs_axis are the same.
+    The subtract is done on the specified axis via broadcast. Then, we multiply lhs and rhs.
+    The output is requantized using the new scale and axis. TODO: support different axes.
+    */
+
+    auto lhs_data = Cast(args.lhs, int32_dtype);
+    auto rhs_data = Cast(args.rhs, int32_dtype);
+
+    auto zero_scalar = MakeConstantScalar(int32_dtype, 0);
+    if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) {
+      // Broadcast lhs zero point if needed
+      int rank = static_cast<int>(input_type.shape.size());
+      int axis = (lhs_axis < 0) ? ((rank > 0) ? rank + lhs_axis : 0) : lhs_axis;
+      Expr lhs_zero_broadcast = ExpandBiasToMatchAxis(Reshape(args.lhs_zero_point,
                                                              {
                                                                  -1,
                                                              }),
                                                      rank, {axis});
+      lhs_data = Subtract(lhs_data, Cast(lhs_zero_broadcast, DataType::Int(32)));
+    }
+
+    if (!IsEqualScalar(args.rhs_zero_point, zero_scalar)) {
+      // Broadcast rhs zero point if needed
+      int rank = static_cast<int>(input_type.shape.size());
+      int axis = (rhs_axis < 0) ? ((rank > 0) ? rank + rhs_axis : 0) : rhs_axis;
+      Expr rhs_zero_broadcast = ExpandBiasToMatchAxis(Reshape(args.rhs_zero_point,
                                                              {
                                                                  -1,
                                                              }),
                                                      rank, {axis});
+      rhs_data = Subtract(rhs_data, Cast(rhs_zero_broadcast, DataType::Int(32)));
    }
+
+    // Create a new tensor Q'
+    output = Multiply(lhs_data, rhs_data);
+
+    // Requantize to get Q_c
+    auto lhs_scales = GetFloatVectorFromConstant(args.lhs_scale);
+    auto rhs_scales = GetFloatVectorFromConstant(args.rhs_scale);
+    std::vector<double> output_multipliers;
+    for (size_t i = 0; i < lhs_scales.size(); i++) {
+      double multiplier = static_cast<double>(lhs_scales[i]) * static_cast<double>(rhs_scales[i]);
+      output_multipliers.push_back(multiplier);
+    }
+    auto new_input_scale = MakeConstantTensor(
+        DataType::Float(32), {(int64_t)output_multipliers.size()}, output_multipliers);
+
+    output = Requantize(output, input_type.shape, new_input_scale, zero_scalar, args.output_scale,
+                        args.output_zero_point, input_type.dtype, lhs_axis);
+
+  } else {
+    LOG(FATAL) << "Not supported: lhs_axis and rhs_axis are not the same.";
   }
 
-  if (!IsEqualScalar(args.rhs_zero_point, zero_scalar)) {
-    rhs_shifted = Subtract(rhs_shifted, args.rhs_zero_point);
-  }
-
-  // Create a new tensor Q'
-  auto output = Multiply(lhs_shifted, rhs_shifted);
-
-  // Get the adjusted new scale and zero points.
-  float lhs_scale_float = GetScalarFromConstant<float>(args.lhs_scale);
-  float rhs_scale_float = GetScalarFromConstant<float>(args.rhs_scale);
-  float new_scale_float = lhs_scale_float * rhs_scale_float;
-  auto new_input_scale = MakeConstantScalar(float32_dtype, new_scale_float);
-  auto new_input_zero_point = zero_scalar;
-
-  // Requantize to get Q_c
-  output = Requantize(output, input_type.shape, new_input_scale, new_input_zero_point,
-                      args.output_scale, args.output_zero_point, input_type.dtype);
-
   return output;
 }
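A numeric sketch of the per-channel branch (illustrative values; the real Requantize uses fixed-point multipliers and clips to the output dtype, both omitted here):

```python
import numpy as np

# Channel axis is the last axis (the -1 default); two channels.
lhs_scale = np.array([0.1, 0.2])          # S_a per channel
rhs_scale = np.array([0.4, 0.3])          # S_b per channel
out_scale = 0.05                          # S_c, per-tensor output scale
zp_a = zp_b = zp_c = 0                    # zero points, all zero here

q_a = np.array([[10, 20], [30, 40]], dtype=np.int64)
q_b = np.array([[5, 6], [7, 8]], dtype=np.int64)

# Q' = (Q_a - zp_a) * (Q_b - zp_b), with per-channel scale S' = S_a * S_b
q_prime = (q_a - zp_a) * (q_b - zp_b)
multiplier = lhs_scale * rhs_scale / out_scale   # per-channel S'/S_c

# Requantize Q' into the output space along the channel axis.
q_c = np.round(q_prime * multiplier) + zp_c

# Cross-check against float math: (S_a Q_a) * (S_b Q_b) == S_c Q_c.
assert np.allclose(q_c * out_scale, (q_a * lhs_scale) * (q_b * rhs_scale))
```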
