From c2d344dd05875c3aa275ca7d7d591f49fd30214c Mon Sep 17 00:00:00 2001 From: yeliang2258 <30516196+yeliang2258@users.noreply.github.com> Date: Thu, 20 Oct 2022 15:57:13 +0800 Subject: [PATCH] [cherry-pick] Fix quantize model deploy bug in MKLDNN (#47119) * Fix quantize model deploy bugs when using MKLDNN (#45920) * fix immutable op quantize bugs * fix * fix build bug * fix test * notest,test=inference * fix ppyoloe acc drop bugs * fix test * fix test * add test * fix * fix * fix test * fix refined name bug * fix test * bias fix * fix matmul weight dequant bug * re-ci * fix tester * fix test * fix tester * update weight dequantize func * update code * update test for converage * update test * update cmake * update cmakelist * update code * rerun ci * remove useless code * re-ci * update code * update code * fix header * update code for log --- .../framework/ir/graph_pattern_detector.h | 2 +- .../compute_propagate_scales_mkldnn_pass.cc | 2 +- .../framework/ir/mkldnn/cpu_quantize_pass.cc | 32 ++++-- .../framework/ir/mkldnn/cpu_quantize_pass.h | 4 +- .../ir/mkldnn/cpu_quantize_pass_tester.cc | 14 +-- .../mkldnn/params_quantization_mkldnn_pass.cc | 37 +++---- .../params_quantization_mkldnn_pass_tester.cc | 59 +++++++++-- .../ir/mkldnn/quant_dequant_mkldnn_pass.cc | 68 ++++++++----- .../ir/mkldnn/quant_dequant_mkldnn_pass.h | 3 +- .../fluid/inference/api/mkldnn_quantizer.cc | 2 +- ...st_onnx_format_quantization_mobilenetv1.py | 97 +++---------------- 11 files changed, 156 insertions(+), 164 deletions(-) mode change 100644 => 100755 paddle/fluid/framework/ir/graph_pattern_detector.h mode change 100644 => 100755 paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc mode change 100644 => 100755 paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc mode change 100644 => 100755 paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass.cc mode change 100644 => 100755 paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass_tester.cc mode change 100644 => 100755 paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.h mode change 100644 => 100755 paddle/fluid/inference/api/mkldnn_quantizer.cc diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h old mode 100644 new mode 100755 index 27bb69b0506a7..4e658d535be9c --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1073,7 +1073,7 @@ struct ResidualElementwise : public PatternBase { }; // General struct for immutable ops: -// reshape, transpose, slice, shape, nearest-interp +// reshape, transpose, slice, nearest-interp // Forward pass for no weights-op. // immutable_out is a result of the operator. struct Immutable : public PatternBase { diff --git a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc old mode 100644 new mode 100755 index 954a579015847..9e9387dbeb6d6 --- a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc @@ -67,7 +67,7 @@ std::vector ComputePropagateScalesMkldnnPass::GetScales(Tensor* tensor, for (int i = 0; i < columns; i++) { float max_value = FLT_MIN; for (int j = 0; j < rows; j++) { - max_value = std::max(max_value, std::abs(data[i + j * columns])); + max_value = std::max(max_value, std::abs(data[j + i * rows])); } max_value = 1.0 / max_value; if (std::isinf(max_value) || std::isnan(max_value)) { diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index 4b91686ffab10..b8eddad1ce026 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -337,8 +337,10 @@ void CPUQuantizePass::GetQuantInfo(Graph* graph) const { graph, "has_quant_info", "var_quant_scales", var_quant_scales_); } -void CPUQuantizePass::QuantizeConv(Graph* graph, - bool with_residual_data) const { +void CPUQuantizePass::QuantizeConv( + Graph* graph, + bool with_residual_data, + std::vector* changed_weight) const { GraphPatternDetector gpd; auto pattern = gpd.mutable_pattern(); patterns::ConvResidual conv_pattern{pattern, name_scope_}; @@ -411,7 +413,16 @@ void CPUQuantizePass::QuantizeConv(Graph* graph, auto filter_scale_tensor = GetScaleTensorForNode(conv_filter); EigenVectorArrayMap eigen_tensor{filter_scale_tensor.data(), filter_scale_tensor.numel()}; - eigen_tensor *= static_cast(S8_MAX); + + // If the scale value of a weight is already multiplied by S8_MAX, it does + // not need to be multiplied again + if (std::find(changed_weight->begin(), + changed_weight->end(), + conv_filter->Name()) == changed_weight->end()) { + eigen_tensor *= static_cast(S8_MAX); + changed_weight->push_back(conv_filter->Name()); + } + std::vector filter_scale{ filter_scale_tensor.data(), filter_scale_tensor.data() + filter_scale_tensor.numel()}; @@ -697,6 +708,13 @@ void CPUQuantizePass::QuantizeImmutable(Graph* graph, return; } + // skip if the dtype of immutable_in is not float32 + auto dtype = immutable_in->Var()->GetDataType(); + if (dtype != proto::VarType::FP32) { + MarkAndLogCannotQuantizeOp(immutable_op, "The input dtype is not float."); + return; + } + if (!AreScalesPresentForNodes({immutable_out})) { MarkAndLogCannotQuantizeOp(immutable_op, "No scale available for the operator"); @@ -1156,9 +1174,12 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { param_scope(), platform::errors::InvalidArgument("Scope cannot be nullptr.")); + // Save the scale values of which weights have been processed to avoid + // secondary processing + std::vector changed_weight = {}; GetQuantInfo(graph); - QuantizeConv(graph, false /* with_residual_data */); - QuantizeConv(graph, true /* with_residual_data */); + QuantizeConv(graph, false /* with_residual_data */, &changed_weight); + QuantizeConv(graph, true /* with_residual_data */, &changed_weight); QuantizePool(graph); QuantizeConcat(graph); QuantizePriorBox(graph); @@ -1168,7 +1189,6 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { QuantizeImmutable(graph, "reshape2", "X"); QuantizeImmutable(graph, "transpose2", "X"); QuantizeImmutable(graph, "slice", "Input"); - QuantizeImmutable(graph, "shape", "Input"); QuantizeImmutable(graph, "nearest_interp", "X"); QuantizeImmutable(graph, "nearest_interp_v2", "X"); QuantizeElementwise(graph, "elementwise_add"); diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h index f26d8bfc84c15..a7470520af197 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h @@ -49,7 +49,9 @@ class CPUQuantizePass : public FusePassBase { protected: void ApplyImpl(ir::Graph* graph) const override; - void QuantizeConv(Graph* graph, bool with_residual_data = false) const; + void QuantizeConv(Graph* graph, + bool with_residual_data = false, + std::vector* changed_weight = nullptr) const; void QuantizeFc(Graph* graph) const; void QuantizePool(Graph* graph) const; void QuantizeConcat(Graph* graph) const; diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc old mode 100644 new mode 100755 index 4dabdd6bed0bd..70623214503d8 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc @@ -66,7 +66,7 @@ void SetOp(ProgramDesc* prog, type == "nearest_interp" || type == "nearest_interp_v2") { op->SetInput("X", {inputs[0]}); op->SetOutput("Out", {outputs[0]}); - } else if (type == "slice" || type == "shape") { + } else if (type == "slice") { op->SetInput("Input", {inputs[0]}); op->SetOutput("Out", {outputs[0]}); } else if (type == "dropout") { @@ -467,7 +467,7 @@ static const std::initializer_list variable_names_immutable_ops = { void TestImmutableOp(const std::string tested_op) { ProgramDesc prog; for (auto& v : variable_names_immutable_ops) { - prog.MutableBlock(0)->Var(v); + prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::FP32); } SetOp(&prog, "dequantize", "Dequantize1", {"a"}, {"b"}, true); SetOp(&prog, tested_op, tested_op, {"b"}, {"c"}, true, "int8"); @@ -520,7 +520,7 @@ void TestImmutableOpBetweenNonQuantizedOp(const std::string tested_op) { void TestImmutableOpWithManyOutputs(const std::string tested_op) { ProgramDesc prog; for (auto& v : variable_names_immutable_ops) { - prog.MutableBlock(0)->Var(v); + prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::FP32); } SetOp(&prog, "dropout", "Dropout1", {"a"}, {"b"}, true, "float32"); @@ -556,12 +556,8 @@ void TestImmutableOpWithManyOutputs(const std::string tested_op) { SCALE * S8_MAX); } -const std::vector immutables = {"reshape2", - "transpose2", - "slice", - "shape", - "nearest_interp", - "nearest_interp_v2"}; +const std::vector immutables = { + "reshape2", "transpose2", "slice", "nearest_interp", "nearest_interp_v2"}; class TestImmutables : public testing::TestWithParam {}; diff --git a/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass.cc old mode 100644 new mode 100755 index 177309376e825..e2065648ac5d0 --- a/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass.cc @@ -52,35 +52,23 @@ bool HasBias(ir::Node* conv_op) { conv_op->Op()->Input("Bias").size() > 0; } -bool ShouldSkipConv(ir::Node* conv_op, Scope* scope, ir::Node* conv_filter) { - if (!platform::HasOpINT8DataType(conv_op->Op())) { - VLOG(4) << "Skipping non-int8 convolution (id: " << conv_op->id() << ")."; - return true; - } - - auto filter_var = scope->GetVar(conv_filter->Name()); - if (filter_var->Get().dtype() != phi::DataType::FLOAT32) { - VLOG(4) << "Skipping convolution (id: " << conv_op->id() - << ") because it's a bug that it is detected again."; - return true; - } - - VLOG(4) << "Not skipping convolution (id: " << conv_op->id() << ")"; - return false; -} - template void QuantizeConvInput(Scope* scope, ir::Graph* g, ir::Node* conv_op, const std::string& input_name, const std::string& scales_attr_name) { - const auto scales = - conv_op->Op()->GetAttrIfExists>(scales_attr_name); - - auto* tensor = scope->GetVar(input_name)->GetMutable(); - QuantizeParams(tensor, scales); - + auto var = scope->GetVar(input_name); + if (var->Get().dtype() != phi::DataType::FLOAT32) { + VLOG(1) << "Skipping quantize the input: " << input_name + << " of convolution because it is detected again."; + } else { + const auto scales = + conv_op->Op()->GetAttrIfExists>(scales_attr_name); + + auto* tensor = scope->GetVar(input_name)->GetMutable(); + QuantizeParams(tensor, scales); + } conv_op->Op()->SetAttr(scales_attr_name, std::vector(1, 1)); } @@ -151,7 +139,8 @@ void ParamsQuantizationMkldnnPass::QuantizeConv(ir::Graph* graph, PADDLE_ENFORCE_NOT_NULL( scope, platform::errors::InvalidArgument("Scope cannot be nullptr.")); - if (ShouldSkipConv(conv_op, scope, conv_filter)) { + // If not a quantized OP + if (!platform::HasOpINT8DataType(conv_op->Op())) { return; } diff --git a/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass_tester.cc old mode 100644 new mode 100755 index 507f25d92d8bc..e04cf388ac0d7 --- a/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass_tester.cc @@ -89,8 +89,14 @@ struct ProgramStrategy { virtual void CheckOp(const OpDesc& op) const = 0; - VarDesc* AddInput(OpDesc* op, std::string input_name, const Data& data) { - const std::string var_name = input_name + "_var"; + VarDesc* AddInput(OpDesc* op, + std::string input_name, + const Data& data, + const std::string user_var_name = "") { + std::string var_name = user_var_name; + if (var_name.empty()) { + var_name = input_name + "_var"; + } op->SetInput(input_name, {var_name}); auto var = program.MutableBlock(0)->Var(var_name); var->SetShape(data.getShape()); @@ -98,8 +104,14 @@ struct ProgramStrategy { return var; } - void AddOutput(OpDesc* op, std::string output_name, const Data& data) { - const std::string var_name = output_name + "_var"; + void AddOutput(OpDesc* op, + std::string output_name, + const Data& data, + const std::string user_var_name = "") { + std::string var_name = user_var_name; + if (var_name.empty()) { + var_name = output_name + "_var"; + } op->SetOutput(output_name, {var_name}); program.MutableBlock(0)->Var(var_name); test_scope.CreateTensor(var_name, data); @@ -117,21 +129,23 @@ struct ConvProgramStrategy : public ProgramStrategy { std::vector&& scale_weights, int groups = 1, Data&& bias = Data(), - std::vector&& scale_bias = {}) + std::vector&& scale_bias = {}, + bool share_weight = false) : input(std::move(input)), filter(std::move(filter)), output(std::move(output)), scale_weights(std::move(scale_weights)), groups(std::move(groups)), bias(std::move(bias)), - scale_bias(std::move(scale_bias)) {} + scale_bias(std::move(scale_bias)), + share_weight(std::move(share_weight)) {} protected: - OpDesc* CreateBasicConvOp() { + OpDesc* CreateBasicConvOp(const std::string conv_name = "Conv1") { auto op = program.MutableBlock(0)->AppendOp(); op->SetType("conv2d"); op->SetAttr("use_mkldnn", true); - op->SetAttr("name", std::string{"Conv1"}); + op->SetAttr("name", conv_name); op->SetAttr("mkldnn_data_type", std::string{"int8"}); op->SetAttr("data_format", std::string{"NCHW"}); op->SetAttr("dilations", std::vector({1, 1})); @@ -155,6 +169,20 @@ struct ConvProgramStrategy : public ProgramStrategy { AddInput(op, "Bias", bias); op->SetAttr("Bias_scales", scale_bias); } + + if (share_weight) { + OpDesc* op2 = CreateBasicConvOp("Conv2"); + AddInput(op2, "Input", input); + AddInput(op2, "Filter", filter)->SetPersistable(true); + AddOutput(op2, "Output", output, "output2"); + op2->SetAttr("Scale_weights", scale_weights); + op2->SetAttr("Scale_in", 1.0f); + op2->SetAttr("groups", groups); + if (HasBias()) { + AddInput(op2, "Bias", bias, "Bias2"); + op2->SetAttr("Bias_scales", scale_bias); + } + } } void CheckOp(const OpDesc& op) const override { @@ -210,9 +238,9 @@ struct ConvProgramStrategy : public ProgramStrategy { const Data output; const std::vector scale_weights; const int groups; - const Data bias; const std::vector scale_bias; + const bool share_weight; }; struct ParamsQuantizationMkldnnPassTestFixture : public ::testing::Test { @@ -340,6 +368,19 @@ TEST_F(ParamsQuantizationMkldnnPassTestFixture, conv_with_bias_2g2o2i1h1w) { RunPassTest(std::move(program)); } +TEST_F(ParamsQuantizationMkldnnPassTestFixture, conv_with_bias_2g2o2i1h1ws) { + auto program = std::make_unique( + GenericInput(), + Data({2, 2, 2, 1, 1}, {1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f}), + GenericOutput(), + std::vector{2.f, 2.f, 4.f, 4.f}, + 2, + Data({2, 2, 1, 1, 1}, {1.5f, 1.5f, 1.5f, 1.5f}), + std::vector{2.f, 2.f, 4.f, 4.f}, + true); + RunPassTest(std::move(program)); +} + } // namespace } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc index b674ef52183c0..f0577bab7fb7a 100644 --- a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc @@ -109,27 +109,34 @@ void QuantDequantMkldnnPass::CollectWeightScalesInfoFromONNXFormatDequantize( if (op_node->Name() == "dequantize_linear") { auto* op_desc = op_node->Op(); + + auto scale_name = op_desc->Input("Scale")[0]; + auto* var = scope->FindVar(scale_name); + PADDLE_ENFORCE_NOT_NULL( + var, + platform::errors::NotFound( + "The Scales variable [%s] of dequantize op is not found.", var)); + + auto* scale_tensor = var->GetMutable(); + auto* scale_data = scale_tensor->data(); + auto x_var_name = op_desc->Input("X")[0]; auto* weight_var = scope->FindVar(x_var_name); if (!weight_var) { auto out_var_name = op_desc->Output("Y")[0]; - if (var_quant_scales->count(x_var_name) && - !var_quant_scales->count(out_var_name)) { - std::vector scale_v = var_quant_scales->at(x_var_name); + float scale = 1.0 / scale_data[0]; + if (std::isinf(scale) || std::isnan(scale)) { + scale = 0.0; + } + std::vector scale_v = {scale}; + if (!var_quant_scales->count(out_var_name)) { var_quant_scales->insert(std::make_pair(out_var_name, scale_v)); } + if (!var_quant_scales->count(x_var_name)) { + var_quant_scales->insert(std::make_pair(x_var_name, scale_v)); + } } else { *onnx_format_quantize_model = true; - auto scale_name = op_desc->Input("Scale")[0]; - auto* var = scope->FindVar(scale_name); - PADDLE_ENFORCE_NOT_NULL( - var, - platform::errors::NotFound( - "The Scales variable [%s] of dequantize op is not found.", - var)); - - auto* scale_tensor = var->GetMutable(); - auto* scale_data = scale_tensor->data(); std::vector thresholds(scale_data, scale_data + scale_tensor->numel()); weight_thresholds->insert(std::make_pair(x_var_name, thresholds)); @@ -182,7 +189,7 @@ void QuantDequantMkldnnPass::CollectInputScalesFromQuantize( auto* scale_data = scale_tensor->data(); float scale = 1.0 / scale_data[0]; if (std::isinf(scale) || std::isnan(scale)) { - scale = 0.0; + continue; } if (!var_quant_scales->count(x_var_name)) { @@ -520,12 +527,10 @@ void QuantDequantMkldnnPass::ConvertFromINT8ToFP32( int step_c = step_n / size; for (int i = 0; i < weight_dims[0]; i++) { int begin_n = i * step_n; - for (int j = begin_n; j < begin_n + step_n; j++) { - for (int k = 0; k < size; k++) { - int begin_c = k * step_c; - for (int m = begin_c; m < begin_c + step_c; m++) { - weight_data[m] *= scales[k]; - } + for (int j = 0; j < size; j++) { + int begin_c = begin_n + j * step_c; + for (int k = 0; k < step_c; k++) { + weight_data[begin_c + k] *= scales[j]; } } } @@ -588,7 +593,8 @@ void QuantDequantMkldnnPass::DequantizeOpWeightsFromONNXFormat( Scope* scope, const std::string& weight_name, const std::unordered_map>& - weight_thresholds) const { + weight_thresholds, + std::vector* dequantized_weights_names) const { auto* op_desc = op_node->Op(); std::string weight_var_name = op_desc->Input(weight_name)[0]; @@ -596,6 +602,13 @@ void QuantDequantMkldnnPass::DequantizeOpWeightsFromONNXFormat( auto iter = weight_thresholds.find(weight_var_name); if (iter != weight_thresholds.end()) { scales = iter->second; + auto name_iter = std::find(dequantized_weights_names->begin(), + dequantized_weights_names->end(), + weight_var_name); + // Has been dequantized + if (name_iter != dequantized_weights_names->end()) { + return; + } } else { if (!IsInt8Weight(op_node, scope, weight_name)) { return; @@ -605,7 +618,7 @@ void QuantDequantMkldnnPass::DequantizeOpWeightsFromONNXFormat( "the model is correct.", weight_var_name)); } - + dequantized_weights_names->push_back(weight_var_name); auto* var = scope->FindVar(weight_var_name); PADDLE_ENFORCE_NOT_NULL( var, @@ -634,14 +647,17 @@ void QuantDequantMkldnnPass::DequantizeWeights( << "No need to dequantize weights because weight_thresholds is empty."; return; } - + std::vector dequantized_weights_names; for (auto* op_node : ir::TopologyVarientSort(*graph, static_cast(0))) { if (!op_node->IsOp()) continue; if (op_node->Name() == "conv2d" || op_node->Name() == "depthwise_conv2d") { if (onnx_format_quantize_model) { - DequantizeOpWeightsFromONNXFormat( - op_node, scope, "Filter", weight_thresholds); + DequantizeOpWeightsFromONNXFormat(op_node, + scope, + "Filter", + weight_thresholds, + &dequantized_weights_names); } else if (IsInt8Weight(op_node, scope, "Filter")) { DequantizeOpWeights( op_node, scope, "Filter", "Output", weight_thresholds); @@ -650,7 +666,7 @@ void QuantDequantMkldnnPass::DequantizeWeights( op_node->Name() == "matmul_v2") { if (onnx_format_quantize_model) { DequantizeOpWeightsFromONNXFormat( - op_node, scope, "Y", weight_thresholds); + op_node, scope, "Y", weight_thresholds, &dequantized_weights_names); } else if (IsInt8Weight(op_node, scope, "Y")) { DequantizeOpWeights(op_node, scope, "Y", "Out", weight_thresholds); } diff --git a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.h b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.h old mode 100644 new mode 100755 index eee7fc96ed1d4..b89b393cc329c --- a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.h @@ -125,7 +125,8 @@ class QuantDequantMkldnnPass : public FusePassBase { Scope* scope, const std::string& weight_name, const std::unordered_map>& - weight_thresholds) const; + weight_thresholds, + std::vector* dequantized_weights_names) const; void DequantizeWeights( ir::Graph* graph, diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc old mode 100644 new mode 100755 index cef7402e6c061..bca2cde0fc2c6 --- a/paddle/fluid/inference/api/mkldnn_quantizer.cc +++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc @@ -142,7 +142,7 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs( scales_[var_name] = scales_[input_var_name]; } compute_scale = false; - } else if (op->Type() == "slice" || op->Type() == "shape") { + } else if (op->Type() == "slice") { auto input_var_name = op->Input("Input")[0]; PADDLE_ENFORCE_NE(scales_.find(input_var_name), scales_.end(), diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_onnx_format_quantization_mobilenetv1.py b/python/paddle/fluid/tests/unittests/mkldnn/test_onnx_format_quantization_mobilenetv1.py index aa1d35f50c56b..5667253e5cb73 100755 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_onnx_format_quantization_mobilenetv1.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_onnx_format_quantization_mobilenetv1.py @@ -18,8 +18,6 @@ import random import math import functools -import contextlib -import tempfile import numpy as np from PIL import Image, ImageEnhance import paddle @@ -151,13 +149,11 @@ def setUp(self): self.infer_iterations = 50000 if os.environ.get( 'DATASET') == 'full' else 2 - self.root_path = tempfile.TemporaryDirectory() - self.int8_model = os.path.join(self.root_path.name, - "post_training_quantization") - print("self.int8_model: ", self.int8_model) + self.int8_model = "post_training_quantization" def tearDown(self): - self.root_path.cleanup() + cmd = 'rm -rf post_training_quantization' + os.system(cmd) pass def cache_unzipping(self, target_folder, zip_path): @@ -268,16 +264,8 @@ def generate_quantized_model(self, is_use_cache_file=False, is_optimize_model=False, onnx_format=False): - try: - os.system("mkdir " + self.int8_model) - except Exception as e: - print("Failed to create {} due to {}".format( - self.int8_model, str(e))) - sys.exit(-1) - place = fluid.CPUPlace() exe = fluid.Executor(place) - scope = fluid.global_scope() val_reader = val() ptq = PostTrainingQuantization(executor=exe, @@ -311,12 +299,6 @@ def run_test(self, model_cache_folder = self.download_data(data_urls, data_md5s, model) - print("Start FP32 inference for {0} on {1} images ...".format( - model, infer_iterations * batch_size)) - (fp32_throughput, fp32_latency, fp32_acc1) = self.run_program( - os.path.join(model_cache_folder, "model"), batch_size, - infer_iterations) - print("Start INT8 post training quantization for {0} on {1} images ...". format(model, sample_iterations * batch_size)) self.generate_quantized_model(os.path.join(model_cache_folder, "model"), @@ -324,6 +306,12 @@ def run_test(self, is_full_quantize, is_use_cache_file, is_optimize_model, onnx_format) + print("Start FP32 inference for {0} on {1} images ...".format( + model, infer_iterations * batch_size)) + (fp32_throughput, fp32_latency, fp32_acc1) = self.run_program( + os.path.join(model_cache_folder, "model"), batch_size, + infer_iterations) + print("Start INT8 inference for {0} on {1} images ...".format( model, infer_iterations * batch_size)) (int8_throughput, int8_latency, @@ -347,10 +335,10 @@ def run_test(self, self.assertLess(delta_value, diff_threshold) -class TestMKLDNNInt8ForMobilenetv1AvgONNXFormat(TestPostTrainingQuantization): +class TestMKLDNNInt8ForResnet50AvgONNXFormat(TestPostTrainingQuantization): - def test_onnx_format_avg_mobilenetv1(self): - model = "MobileNet-V1" + def test_onnx_format_avg_resnet50(self): + model = "resnet50" algo = "avg" round_type = "round" data_urls = [ @@ -379,66 +367,5 @@ def test_onnx_format_avg_mobilenetv1(self): onnx_format=True) -class TestMKLDNNInt8ForMobilenetv1Avg(TestPostTrainingQuantization): - - def test_avg_mobilenetv1(self): - model = "MobileNet-V1" - algo = "avg" - round_type = "round" - data_urls = [ - 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' - ] - data_md5s = ['13892b0716d26443a8cdea15b3c6438b'] - quantizable_op_type = [ - "conv2d", - "depthwise_conv2d", - "mul", - ] - is_full_quantize = False - is_use_cache_file = False - is_optimize_model = False - diff_threshold = 0 - self.run_test(model, - algo, - round_type, - data_urls, - data_md5s, - quantizable_op_type, - is_full_quantize, - is_use_cache_file, - is_optimize_model, - diff_threshold, - onnx_format=False) - - -class TestMKLDNNInt8ForMobilenetv1AbsMax(TestPostTrainingQuantization): - - def test_abs_max_mobilenetv1(self): - model = "MobileNet-V1" - algo = "abs_max" - round_type = "round" - data_urls = [ - 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' - ] - data_md5s = ['13892b0716d26443a8cdea15b3c6438b'] - quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] - is_full_quantize = False - is_use_cache_file = False - is_optimize_model = False - # The accuracy diff of post-training quantization (abs_max) maybe bigger - diff_threshold = 0 - self.run_test(model, - algo, - round_type, - data_urls, - data_md5s, - quantizable_op_type, - is_full_quantize, - is_use_cache_file, - is_optimize_model, - diff_threshold, - onnx_format=False) - - if __name__ == '__main__': unittest.main()