
[QNN] Requantize - Optimize lowering for some corner cases. (apache#3864)
anijain2305 authored and wweic committed Sep 16, 2019
1 parent 65f51cf commit 48dd80a
Showing 2 changed files with 44 additions and 36 deletions.
79 changes: 43 additions & 36 deletions src/relay/qnn/op/requantize.cc
@@ -129,48 +129,55 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
     tensor = Subtract(tensor, input_zp);
   }
 
-  // 3) Multiply the integer multiplier
-  if (left_shift != 0) {
-    tensor = Multiply(tensor, MakeConstantScalar(hp_dtype, 1 << left_shift));
-  }
-  // Perform the multiplication in higher precision.
-  // The scalar is a fixed point value of int32 where the decimal point is
-  // between bits 31 and 30. After multiplying with input_tensor, the result is
-  // in int64 where the decimal point is sitting between bits 31 and 30 (from
-  // the right, rightmost bit is bit 0). The computation is performed in higher
-  // precision to avoid overflow in multiplying two int32 values.
-  Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
-  auto multiplied_t = Multiply(tensor, scalar);
+  // If the input and output scales are same, we can skip the fixed point multiplication.
+  auto scaled_int64_t = tensor;
+  if (param->input_scale != param->output_scale) {
+    // 3) Multiply the integer multiplier
+    if (left_shift != 0) {
+      tensor = Multiply(tensor, MakeConstantScalar(hp_dtype, 1 << left_shift));
+    }
+    // Perform the multiplication in higher precision.
+    // The scalar is a fixed point value of int32 where the decimal point is
+    // between bits 31 and 30. After multiplying with input_tensor, the result is
+    // in int64 where the decimal point is sitting between bits 31 and 30 (from
+    // the right, rightmost bit is bit 0). The computation is performed in higher
+    // precision to avoid overflow in multiplying two int32 values.
+    Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
+    auto multiplied_t = Multiply(tensor, scalar);
 
-  // 4) Find the rounding scalar. This depends on where the final decimal point
-  // sits. As we will be right shifting the multiplied_t, we need to first
-  // calculate the total_right_shift.
-  int total_right_shift = right_shift + 31;
-  int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
+    // 4) Find the rounding scalar. This depends on where the final decimal point
+    // sits. As we will be right shifting the multiplied_t, we need to first
+    // calculate the total_right_shift.
+    int total_right_shift = right_shift + 31;
+    int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
 
-  tensor = multiplied_t;
-  Expr round_scalar;
-  if (param->rounding == "UPWARD") {
-    round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
-  } else if (param->rounding == "TONEAREST") {
-    auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
-    auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
-    auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
-    auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
+    tensor = multiplied_t;
+    Expr round_scalar;
+    if (param->rounding == "UPWARD") {
+      round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
+    } else if (param->rounding == "TONEAREST") {
+      auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
+      auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
+      auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
+      auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
 
-    auto zero = MakeConstantScalar(hp_dtype, 0);
-    auto zero_t = Full(zero, input_shape, hp_dtype);
-    round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
-  }
-  // Add the rounding scalar.
-  tensor = Add(tensor, round_scalar);
+      auto zero = MakeConstantScalar(hp_dtype, 0);
+      auto zero_t = Full(zero, input_shape, hp_dtype);
+      round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
+    }
+    // Add the rounding scalar.
+    tensor = Add(tensor, round_scalar);
 
-  // 5) Simply right shift the result to get the final output.
-  auto scaled_int64_t = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
+    // 5) Simply right shift the result to get the final output.
+    scaled_int64_t = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
+  }
 
   // 6) Add the output zero point.
-  auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point);
-  auto shifted_int64_t = Add(output_zp, scaled_int64_t);
+  auto shifted_int64_t = scaled_int64_t;
+  if (param->output_zero_point != 0) {
+    auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point);
+    shifted_int64_t = Add(output_zp, scaled_int64_t);
+  }
 
   // 7) Clip to the out_dtype min/max.
   auto q_min = GetQmin(out_dtype);
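For context on what the new if (param->input_scale != param->output_scale) guard skips: the lowering implements the usual fixed point requantization scheme, where the ratio input_scale / output_scale is encoded as a Q31 multiplier plus a power-of-two shift. Below is a minimal plain-Python sketch of that arithmetic; the helper names are illustrative only, not TVM APIs, and the decomposition mirrors what the C++ comments describe (decimal point between bits 31 and 30, total_right_shift = right_shift + 31).

import math

def get_fixed_point_multiplier_shift(scale_ratio):
    # Decompose scale_ratio = significand * 2**exponent with significand in
    # [0.5, 1), then encode the significand as a Q31 fixed point integer.
    significand, exponent = math.frexp(scale_ratio)
    fixed_point_multiplier = round(significand * (1 << 31))
    left_shift = max(exponent, 0)
    right_shift = max(-exponent, 0)
    return fixed_point_multiplier, left_shift, right_shift

def requantize_value(value, input_scale, output_scale, rounding="UPWARD"):
    # The corner case this commit optimizes: equal scales mean a ratio of
    # exactly 1, so the multiply/round/shift sequence can be skipped entirely.
    if input_scale == output_scale:
        return value
    multiplier, left_shift, right_shift = get_fixed_point_multiplier_shift(
        input_scale / output_scale)
    # 3) Multiply in higher precision; the decimal point of the product sits
    # between bits 31 and 30.
    tensor = (value << left_shift) * multiplier
    # 4) The rounding value is half the weight of the bits shifted out below.
    total_right_shift = right_shift + 31
    pos_rounding_value = 1 << (total_right_shift - 1)
    if rounding == "UPWARD" or tensor >= 0:
        tensor += pos_rounding_value  # UPWARD: ties round toward +infinity
    else:
        tensor += pos_rounding_value - 1  # TONEAREST: ties round away from zero
    # 5) Arithmetic right shift recovers the requantized integer.
    return tensor >> total_right_shift

For example, requantize_value(23, input_scale=0.5, output_scale=0.25) returns 46: the ratio 2.0 decomposes into multiplier 2**30 with left_shift = 2, and ((23 << 2) * 2**30 + 2**30) >> 31 == 46.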
1 change: 1 addition & 0 deletions tests/python/relay/test_qnn_requantize.py
@@ -64,6 +64,7 @@ def same_scale_test():
                       input_scale=0.5,
                       output_scale=0.5,
                       rounding=rounding)
+        assert 'right_shift' not in mod.astext()
         verify(mod, (golden_data, golden_output))
 
 def downscale_test():
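The new assertion pins the optimization down from the Python side: when input_scale equals output_scale, the canonicalized graph should contain no right_shift op at all. A self-contained sketch of the same check, assuming the 2019-era TVM API (scalar scale arguments, relay.Module, relay.qnn.transform.CanonicalizeOps); later TVM versions changed these interfaces:

from tvm import relay

data = relay.var("data", shape=(200,), dtype="int32")
out = relay.qnn.op.requantize(data,
                              input_scale=0.5,
                              input_zero_point=0,
                              output_scale=0.5,
                              output_zero_point=0,
                              rounding="TONEAREST",
                              out_dtype="int8")
mod = relay.Module.from_expr(relay.Function([data], out))
# Lower the qnn op into core Relay ops, as the test harness does.
mod = relay.qnn.transform.CanonicalizeOps()(mod)
# Equal scales: the fixed point multiply and right shift were skipped.
assert 'right_shift' not in mod.astext()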
