
Commit 62a5cc1

Add feedback
1 parent 504d1a8 commit 62a5cc1

2 files changed (+3, -7 lines changed)

src/relay/qnn/op/op_common.h

Lines changed: 2 additions & 3 deletions
@@ -261,8 +261,8 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con
   auto lhs_rank = static_cast<int>(lhs_data->shape.size());
   auto rhs_rank = static_cast<int>(rhs_data->shape.size());
 
-  lhs_axis = (lhs_axis < 0) ? ((lhs_rank > 0) ? lhs_data->shape.size() + lhs_axis : 0) : lhs_axis;
-  rhs_axis = (rhs_axis < 0) ? ((rhs_rank > 0) ? rhs_data->shape.size() + rhs_axis : 0) : rhs_axis;
+  lhs_axis = (lhs_axis < 0) ? ((lhs_rank > 0) ? lhs_rank + lhs_axis : 0) : lhs_axis;
+  rhs_axis = (rhs_axis < 0) ? ((rhs_rank > 0) ? rhs_rank + rhs_axis : 0) : rhs_axis;
 
   // If zero point and scale are scalar then axis doesn't matter.
   bool lhs_scale_is_scalar = (types[2].as<TensorTypeNode>())->shape.size() == 0;
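Why this hunk matters: lhs_data->shape.size() returns size_t, so the old expression added a negative int axis to an unsigned value and landed on the right index only through unsigned modular wraparound; the new form does plain signed arithmetic on the already-computed rank. A minimal standalone sketch (plain C++ with a std::vector standing in for the TVM shape array, values shown for 64-bit size_t; not TVM code):

// Standalone sketch of the signed/unsigned mix fixed above
// (assumed stand-in types; not TVM code).
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> shape = {1, 3, 224, 224};  // stand-in for lhs_data->shape
  int rank = static_cast<int>(shape.size());  // 4, as in lhs_rank

  int axis = -1;  // negative axis meaning "last dimension"
  // Old form: size_t + int. The axis converts to size_t (2^64 - 1) and the
  // sum wraps back to 3 -- correct, but only via unsigned modular arithmetic,
  // which -Wsign-conversion rightly warns about.
  size_t old_form = shape.size() + axis;  // 3
  // New form: int + int, ordinary signed arithmetic.
  int new_form = (axis < 0) ? ((rank > 0) ? rank + axis : 0) : axis;  // 3

  int bad_axis = -7;  // out of range for a rank-4 tensor
  size_t old_bad = shape.size() + bad_axis;  // 18446744073709551613: garbage
  int new_bad = rank + bad_axis;             // -3: still visibly out of range

  std::printf("%zu %d %zu %d\n", old_form, new_form, old_bad, new_bad);
  return 0;
}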
@@ -349,7 +349,6 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con
       .add_argument("lhs_axis", "Tensor", "The channel quantization of the lhs tensor.") \
       .add_argument("rhs_axis", "Tensor", "The channel quantization of the rhs tensor.") \
       .add_type_rel("QnnBroadcast", QnnBroadcastRel) \
-      .set_attr<TOpPattern>("TOpPattern", kOpaque) \
       .set_attr<TNonComputational>("TNonComputational", true) \
       .set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnBinaryBroadcastLayout)
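On the second hunk: dropping the explicit TOpPattern registration is presumably safe because passes that consult the pattern map generally fall back to kOpaque for ops that never register one, so the line appears redundant; the remaining TNonComputational and FInferCorrectLayout attributes are unchanged.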

tests/python/relay/test_pass_fake_quantization_to_integer.py

Lines changed: 1 addition & 4 deletions
@@ -611,17 +611,14 @@ def verify_binary_per_channel(lhs_scale, rhs_scale, lhs_zp, rhs_zp, out_zp, lhs_
         rhs_axis = lhs_axis  # TODO: Support different axes for per-channel quantized multiply
     else:
         out_scale = relay.const(0.1)
+
     x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
     x = relay.qnn.op.dequantize(x, relay.const(lhs_scale), relay.const(lhs_zp), axis=lhs_axis)
 
     y = relay.var("y", shape=[1, 3, 224, 224], dtype="int8")
     y = relay.qnn.op.dequantize(y, relay.const(rhs_scale), relay.const(rhs_zp), axis=rhs_axis)
 
     op = operator(x, y)
-    if operator == relay.op.multiply:
-        out_scale = relay.const(2.0)
-    else:
-        out_scale = relay.const(0.1)
 
     op = relay.qnn.op.quantize(op, out_scale, relay.const(out_zp), out_dtype="int8")
     x_np = np.random.randint(-25, 25, size=[1, 3, 224, 224], dtype="int8")
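The test change removes a duplicate: out_scale is already chosen in the if/else at the top of the helper (the else branch setting relay.const(0.1) is visible as context above), so re-assigning it after building op merely repeated that logic; the relay.qnn.op.quantize call now relies on the single earlier assignment, with a blank line added to separate the branch from the graph construction.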
