Skip to content

Commit 89910c6

Browse files
committed
Add feedback - round 2
1 parent 62a5cc1 commit 89910c6

File tree

3 files changed

+22
-7
lines changed

3 files changed

+22
-7
lines changed

python/tvm/relay/qnn/op/qnn.py

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -633,10 +633,12 @@ def add(
633633
The zero point of output quantized expr.
634634
635635
lhs_axis: int
636-
The channel axis for lhs quantization. Default value is -1.
636+
The channel axis for lhs quantization. Default value is -1 which corresponds
637+
to the last axis.
637638
638639
rhs_axis: int
639-
The channel axis for rhs quantization. Default value is -1.
640+
The channel axis for rhs quantization. Default value is -1 which corresponds
641+
to the last axis.
640642
641643
Returns
642644
-------
@@ -759,10 +761,12 @@ def mul(
759761
The zero point of output quantized expr.
760762
761763
lhs_axis: int
762-
The channel axis for lhs quantization. Default value is -1.
764+
The channel axis for lhs quantization. Default value is -1 which corresponds
765+
to the last axis.
763766
764767
rhs_axis: int
765-
The channel axis for rhs quantization. Default value is -1.
768+
The channel axis for rhs quantization. Default value is -1 which corresponds
769+
to the last axis.
766770
767771
Returns
768772
-------
@@ -1035,10 +1039,12 @@ def subtract(
10351039
The zero point of output quantized expr.
10361040
10371041
lhs_axis: int
1038-
The channel axis for lhs quantization. Default value is -1.
1042+
The channel axis for lhs quantization. Default value is -1 which corresponds
1043+
to the last axis.
10391044
10401045
rhs_axis: int
1041-
The channel axis for rhs quantization. Default value is -1.
1046+
The channel axis for rhs quantization. Default value is -1 which corresponds
1047+
to the last axis.
10421048
10431049
Returns
10441050
-------

src/relay/qnn/op/mul.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,7 @@ Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
5959
auto lhs_axis = broadcast_attrs->lhs_axis;
6060
auto rhs_axis = broadcast_attrs->rhs_axis;
6161

62-
if (lhs_axis == -1 && rhs_axis == -1) {
62+
if (IsConstScalar(args.lhs_scale) && IsConstScalar(args.rhs_scale)) {
6363
/*
6464
This is per-tensor quantized multiply.
6565

tests/python/relay/test_pass_fake_quantization_to_integer.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -654,6 +654,15 @@ def verify_binary_per_channel(lhs_scale, rhs_scale, lhs_zp, rhs_zp, out_zp, lhs_
654654
lhs_axis=1,
655655
rhs_axis=1,
656656
)
657+
verify_binary_per_channel(
658+
lhs_scale=np.random.uniform(1.0, 5.0, 224),
659+
rhs_scale=np.random.uniform(1.0, 5.0, 224),
660+
lhs_zp=np.random.randint(1, 3),
661+
rhs_zp=np.random.randint(1, 3),
662+
out_zp=np.random.randint(1, 3),
663+
lhs_axis=-1,
664+
rhs_axis=-1,
665+
)
657666

658667
# Different axes
659668
verify_binary_per_channel(

0 commit comments

Comments
 (0)