Skip to content

Commit

Permalink
[cherry-pick]Update quantization round and clip calculation methods (#…
Browse files Browse the repository at this point in the history
…43829)

* update quantization clip and round

* fix quantization clip and round Attribute

* fix typo
  • Loading branch information
yghstill authored Jun 27, 2022
1 parent 9e776f6 commit ff70a26
Show file tree
Hide file tree
Showing 20 changed files with 2,406 additions and 1,538 deletions.
28 changes: 21 additions & 7 deletions paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() {
.End()
.AddAttr("bit_length")
.IsIntIn({8, 16})
.End()
.AddAttr("round_type")
.IsOptional()
.IsIntIn({0, 1})
.End();
AddOpCompat(OpCompat("fake_channel_wise_quantize_dequantize_abs_max"))
.AddInput("X")
Expand All @@ -61,6 +65,10 @@ DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() {
.End()
.AddAttr("quant_axis")
.IsIntIn({0, 1})
.End()
.AddAttr("round_type")
.IsOptional()
.IsIntIn({0, 1})
.End();
}
// Delete quant_dequant_op, then quantize and dequantize weight
Expand Down Expand Up @@ -96,15 +104,18 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
auto var_map = any_op2_desc->Inputs();
std::string arg_name = "";
for (auto& name_m : var_map) {
if (std::find(name_m.second.begin(), name_m.second.end(),
if (std::find(name_m.second.begin(),
name_m.second.end(),
quant_dequant_op_out_name) != name_m.second.end()) {
arg_name = name_m.first;
break;
}
}
PADDLE_ENFORCE_GT(arg_name.size(), 0, platform::errors::InvalidArgument(
"can not find the input %s.",
quant_dequant_op_out_name));
PADDLE_ENFORCE_GT(
arg_name.size(),
0,
platform::errors::InvalidArgument("can not find the input %s.",
quant_dequant_op_out_name));
// any_op2_desc->SetAttr("enable_int8", true);
any_op2_desc->SetAttr("bit_length", bit_length);

Expand All @@ -123,7 +134,8 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
int quant_axis =
BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis"));
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1,
true,
platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but "
"the received is %d",
Expand Down Expand Up @@ -176,7 +188,8 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
}
}
for (int i = 0; i < channel; i++) {
PADDLE_ENFORCE_NE(weight_scale[i], 0,
PADDLE_ENFORCE_NE(weight_scale[i],
0,
platform::errors::InvalidArgument(
"Weight scale should be nonzero, but get zero."));
weight_scale[i] = weight_scale[i] / range;
Expand All @@ -188,7 +201,8 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
abs_max_weight =
std::max(abs_max_weight, std::abs(quantized_weight_data[j]));
}
PADDLE_ENFORCE_NE(abs_max_weight, 0,
PADDLE_ENFORCE_NE(abs_max_weight,
0,
platform::errors::InvalidArgument(
"Weight scale should be nonzero, but get zero"));
weight_scale.push_back(abs_max_weight / range);
Expand Down
11 changes: 10 additions & 1 deletion paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() {
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
AddOpCompat(OpCompat("dequantize_linear"))
.AddInput("X")
Expand All @@ -74,6 +78,10 @@ DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() {
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
}
// Delete quantize_linear_op dequantize_linear_op, then add input_scales
Expand Down Expand Up @@ -112,7 +120,8 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
const LoDTensor& input_scale_tensor =
scope->GetVar(quantize_linear_op_scale->Name())->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(
paddle::platform::is_cpu_place(input_scale_tensor.place()), true,
paddle::platform::is_cpu_place(input_scale_tensor.place()),
true,
platform::errors::InvalidArgument(
"Input scale tensor's place should be CPU."));
const float* input_scale_data = input_scale_tensor.data<float>();
Expand Down
26 changes: 20 additions & 6 deletions paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() {
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
AddOpCompat(OpCompat("dequantize_linear"))
.AddInput("X")
Expand All @@ -72,6 +76,10 @@ DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() {
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
Expand Down Expand Up @@ -322,7 +330,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
int quant_axis = BOOST_GET_CONST(
int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis"));
if (quant_axis == -1) { // per_layer quant_dequant: all OP
PADDLE_ENFORCE_EQ(weight_scale_nums, 1,
PADDLE_ENFORCE_EQ(weight_scale_nums,
1,
platform::errors::InvalidArgument(
"When quant_axis == -1 means use per_layer "
"quant_dequant, weight_scale'number should be 1."));
Expand All @@ -335,11 +344,13 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
} else if (quant_axis == 0) { // per_channel quant_dequant: conv2d,
// depthwise_conv2d, conv2d_fusion
PADDLE_ENFORCE_EQ(
weight_scale_nums, w_dims[quant_axis],
weight_scale_nums,
w_dims[quant_axis],
platform::errors::InvalidArgument(
"When quant_axis == 0 means use per_channel quant_dequant, "
"weight_scale'numbers should be equal channels."));
PADDLE_ENFORCE_EQ(w_dims.size(), 4,
PADDLE_ENFORCE_EQ(w_dims.size(),
4,
platform::errors::InvalidArgument(
"When quant_axis == 0 means use per_channel "
"quant_dequant, (conv2d, depthwise_conv2d, "
Expand All @@ -352,15 +363,17 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
}
} else if (quant_axis == 1) {
PADDLE_ENFORCE_EQ(
weight_scale_nums, w_dims[quant_axis],
weight_scale_nums,
w_dims[quant_axis],
platform::errors::InvalidArgument(
"When quant_axis == 1 means use per_channel quant_dequant, "
"weight_scale'numbers should be equal channels."));

if (w_dims.size() == 4) { // conv2d_transpose
std::string quantized_op_type = any_op2->Op()->Type();
PADDLE_ENFORCE_EQ(
quantized_op_type, "conv2d_transpose",
quantized_op_type,
"conv2d_transpose",
platform::errors::InvalidArgument(
"When quant_axis == 1 means use per_channel quant_dequant, "
"only conv2d_transpose weight dims equal 4."));
Expand Down Expand Up @@ -388,7 +401,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
weight_tensor->Resize(phi::make_ddim(phi::vectorize(w_dims)));
float* new_quantized_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_quantized_weight_data, weight_data_tmp.data(),
memcpy(new_quantized_weight_data,
weight_data_tmp.data(),
weight_tensor->numel() * sizeof(float));

nodes2rm.insert(weight_dequantize_linear_op_scale);
Expand Down
Loading

0 comments on commit ff70a26

Please sign in to comment.