Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cherry-pick] Update quantization round and clip calculation methods #43829

Merged
merged 3 commits into from
Jun 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() {
.End()
.AddAttr("bit_length")
.IsIntIn({8, 16})
.End()
.AddAttr("round_type")
.IsOptional()
.IsIntIn({0, 1})
.End();
AddOpCompat(OpCompat("fake_channel_wise_quantize_dequantize_abs_max"))
.AddInput("X")
Expand All @@ -61,6 +65,10 @@ DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() {
.End()
.AddAttr("quant_axis")
.IsIntIn({0, 1})
.End()
.AddAttr("round_type")
.IsOptional()
.IsIntIn({0, 1})
.End();
}
// Delete quant_dequant_op, then quantize and dequantize weight
Expand Down Expand Up @@ -96,15 +104,18 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
auto var_map = any_op2_desc->Inputs();
std::string arg_name = "";
for (auto& name_m : var_map) {
if (std::find(name_m.second.begin(), name_m.second.end(),
if (std::find(name_m.second.begin(),
name_m.second.end(),
quant_dequant_op_out_name) != name_m.second.end()) {
arg_name = name_m.first;
break;
}
}
PADDLE_ENFORCE_GT(arg_name.size(), 0, platform::errors::InvalidArgument(
"can not find the input %s.",
quant_dequant_op_out_name));
PADDLE_ENFORCE_GT(
arg_name.size(),
0,
platform::errors::InvalidArgument("can not find the input %s.",
quant_dequant_op_out_name));
// any_op2_desc->SetAttr("enable_int8", true);
any_op2_desc->SetAttr("bit_length", bit_length);

Expand All @@ -123,7 +134,8 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
int quant_axis =
BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis"));
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1,
true,
platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but "
"the received is %d",
Expand Down Expand Up @@ -176,7 +188,8 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
}
}
for (int i = 0; i < channel; i++) {
PADDLE_ENFORCE_NE(weight_scale[i], 0,
PADDLE_ENFORCE_NE(weight_scale[i],
0,
platform::errors::InvalidArgument(
"Weight scale should be nonzero, but get zero."));
weight_scale[i] = weight_scale[i] / range;
Expand All @@ -188,7 +201,8 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
abs_max_weight =
std::max(abs_max_weight, std::abs(quantized_weight_data[j]));
}
PADDLE_ENFORCE_NE(abs_max_weight, 0,
PADDLE_ENFORCE_NE(abs_max_weight,
0,
platform::errors::InvalidArgument(
"Weight scale should be nonzero, but get zero"));
weight_scale.push_back(abs_max_weight / range);
Expand Down
11 changes: 10 additions & 1 deletion paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() {
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
AddOpCompat(OpCompat("dequantize_linear"))
.AddInput("X")
Expand All @@ -74,6 +78,10 @@ DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() {
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
}
// Delete quantize_linear_op and dequantize_linear_op, then add input_scales
Expand Down Expand Up @@ -112,7 +120,8 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
const LoDTensor& input_scale_tensor =
scope->GetVar(quantize_linear_op_scale->Name())->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(
paddle::platform::is_cpu_place(input_scale_tensor.place()), true,
paddle::platform::is_cpu_place(input_scale_tensor.place()),
true,
platform::errors::InvalidArgument(
"Input scale tensor's place should be CPU."));
const float* input_scale_data = input_scale_tensor.data<float>();
Expand Down
26 changes: 20 additions & 6 deletions paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() {
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
AddOpCompat(OpCompat("dequantize_linear"))
.AddInput("X")
Expand All @@ -72,6 +76,10 @@ DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() {
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
Expand Down Expand Up @@ -322,7 +330,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
int quant_axis = BOOST_GET_CONST(
int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis"));
if (quant_axis == -1) { // per_layer quant_dequant: all OP
PADDLE_ENFORCE_EQ(weight_scale_nums, 1,
PADDLE_ENFORCE_EQ(weight_scale_nums,
1,
platform::errors::InvalidArgument(
"When quant_axis == -1 means use per_layer "
"quant_dequant, weight_scale'number should be 1."));
Expand All @@ -335,11 +344,13 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
} else if (quant_axis == 0) { // per_channel quant_dequant: conv2d,
// depthwise_conv2d, conv2d_fusion
PADDLE_ENFORCE_EQ(
weight_scale_nums, w_dims[quant_axis],
weight_scale_nums,
w_dims[quant_axis],
platform::errors::InvalidArgument(
"When quant_axis == 0 means use per_channel quant_dequant, "
"weight_scale'numbers should be equal channels."));
PADDLE_ENFORCE_EQ(w_dims.size(), 4,
PADDLE_ENFORCE_EQ(w_dims.size(),
4,
platform::errors::InvalidArgument(
"When quant_axis == 0 means use per_channel "
"quant_dequant, (conv2d, depthwise_conv2d, "
Expand All @@ -352,15 +363,17 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
}
} else if (quant_axis == 1) {
PADDLE_ENFORCE_EQ(
weight_scale_nums, w_dims[quant_axis],
weight_scale_nums,
w_dims[quant_axis],
platform::errors::InvalidArgument(
"When quant_axis == 1 means use per_channel quant_dequant, "
"weight_scale'numbers should be equal channels."));

if (w_dims.size() == 4) { // conv2d_transpose
std::string quantized_op_type = any_op2->Op()->Type();
PADDLE_ENFORCE_EQ(
quantized_op_type, "conv2d_transpose",
quantized_op_type,
"conv2d_transpose",
platform::errors::InvalidArgument(
"When quant_axis == 1 means use per_channel quant_dequant, "
"only conv2d_transpose weight dims equal 4."));
Expand Down Expand Up @@ -388,7 +401,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
weight_tensor->Resize(phi::make_ddim(phi::vectorize(w_dims)));
float* new_quantized_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_quantized_weight_data, weight_data_tmp.data(),
memcpy(new_quantized_weight_data,
weight_data_tmp.data(),
weight_tensor->numel() * sizeof(float));

nodes2rm.insert(weight_dequantize_linear_op_scale);
Expand Down
Loading