[cherry-pick] Fix quantize model deploy bug in MKLDNN (#47119)
* Fix quantize model deploy bugs when using MKLDNN (#45920)

* fix immutable op quantize bugs

* fix

* fix build bug

* fix test

* notest,test=inference

* fix ppyoloe acc drop bugs

* fix test

* fix test

* add test

* fix

* fix

* fix test

* fix refined name bug

* fix test

* bias fix

* fix matmul weight dequant bug

* re-ci

* fix tester

* fix test

* fix tester

* update weight dequantize func

* update code

* update test for coverage

* update test

* update cmake

* update cmakelist

* update code

* rerun ci

* remove useless code

* re-ci

* update code

* update code

* fix header

* update code for log
yeliang2258 authored Oct 20, 2022
1 parent 68c4ac3 commit c2d344d
Showing 11 changed files with 156 additions and 164 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/graph_pattern_detector.h
100644 → 100755
@@ -1073,7 +1073,7 @@ struct ResidualElementwise : public PatternBase {
};

// General struct for immutable ops:
// reshape, transpose, slice, shape, nearest-interp
// reshape, transpose, slice, nearest-interp
// Forward pass for no weights-op.
// immutable_out is a result of the operator.
struct Immutable : public PatternBase {
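(Note: dropping `shape` from the immutable-op list lines up with the FP32 guard added in cpu_quantize_pass.cc below — `shape` emits an integer tensor of dimensions, so there are presumably no float values for which a quantization scale is meaningful.)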
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc
100644 → 100755
@@ -67,7 +67,7 @@ std::vector<float> ComputePropagateScalesMkldnnPass::GetScales(Tensor* tensor,
for (int i = 0; i < columns; i++) {
float max_value = FLT_MIN;
for (int j = 0; j < rows; j++) {
max_value = std::max(max_value, std::abs(data[i + j * columns]));
max_value = std::max(max_value, std::abs(data[j + i * rows]));
}
max_value = 1.0 / max_value;
if (std::isinf(max_value) || std::isnan(max_value)) {
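For context, the corrected reduction can be read as a standalone helper (a minimal sketch with hypothetical names; it assumes, as the new `data[j + i * rows]` indexing implies, that each output channel's `rows` values are contiguous in memory):

#include <algorithm>
#include <cmath>
#include <vector>

// Abs-max per output channel for a rows x columns weight tensor whose
// channels are stored contiguously: element (row j, channel i) sits at
// flat index j + i * rows. The old i + j * columns indexing assumed the
// opposite layout, mixing values from different channels into one
// maximum and thus producing wrong propagated scales.
std::vector<float> AbsMaxPerChannel(const std::vector<float>& data,
                                    int rows, int columns) {
  std::vector<float> max_abs(columns, 0.0f);
  for (int i = 0; i < columns; ++i) {
    for (int j = 0; j < rows; ++j) {
      max_abs[i] = std::max(max_abs[i], std::abs(data[j + i * rows]));
    }
  }
  return max_abs;
}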
32 changes: 26 additions & 6 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
@@ -337,8 +337,10 @@ void CPUQuantizePass::GetQuantInfo(Graph* graph) const {
graph, "has_quant_info", "var_quant_scales", var_quant_scales_);
}

void CPUQuantizePass::QuantizeConv(Graph* graph,
bool with_residual_data) const {
void CPUQuantizePass::QuantizeConv(
Graph* graph,
bool with_residual_data,
std::vector<std::string>* changed_weight) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::ConvResidual conv_pattern{pattern, name_scope_};
@@ -411,7 +413,16 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
auto filter_scale_tensor = GetScaleTensorForNode(conv_filter);
EigenVectorArrayMap eigen_tensor{filter_scale_tensor.data<double>(),
filter_scale_tensor.numel()};
eigen_tensor *= static_cast<double>(S8_MAX);

// If the scale value of a weight is already multiplied by S8_MAX, it does
// not need to be multiplied again
if (std::find(changed_weight->begin(),
changed_weight->end(),
conv_filter->Name()) == changed_weight->end()) {
eigen_tensor *= static_cast<double>(S8_MAX);
changed_weight->push_back(conv_filter->Name());
}

std::vector<float> filter_scale{
filter_scale_tensor.data<double>(),
filter_scale_tensor.data<double>() + filter_scale_tensor.numel()};
@@ -697,6 +708,13 @@ void CPUQuantizePass::QuantizeImmutable(Graph* graph,
return;
}

// skip if the dtype of immutable_in is not float32
auto dtype = immutable_in->Var()->GetDataType();
if (dtype != proto::VarType::FP32) {
MarkAndLogCannotQuantizeOp(immutable_op, "The input dtype is not float.");
return;
}

if (!AreScalesPresentForNodes({immutable_out})) {
MarkAndLogCannotQuantizeOp(immutable_op,
"No scale available for the operator");
@@ -1156,9 +1174,12 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
param_scope(),
platform::errors::InvalidArgument("Scope cannot be nullptr."));

// Save the scale values of which weights have been processed to avoid
// secondary processing
std::vector<std::string> changed_weight = {};
GetQuantInfo(graph);
QuantizeConv(graph, false /* with_residual_data */);
QuantizeConv(graph, true /* with_residual_data */);
QuantizeConv(graph, false /* with_residual_data */, &changed_weight);
QuantizeConv(graph, true /* with_residual_data */, &changed_weight);
QuantizePool(graph);
QuantizeConcat(graph);
QuantizePriorBox(graph);
@@ -1168,7 +1189,6 @@
QuantizeImmutable(graph, "reshape2", "X");
QuantizeImmutable(graph, "transpose2", "X");
QuantizeImmutable(graph, "slice", "Input");
QuantizeImmutable(graph, "shape", "Input");
QuantizeImmutable(graph, "nearest_interp", "X");
QuantizeImmutable(graph, "nearest_interp_v2", "X");
QuantizeElementwise(graph, "elementwise_add");
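Two things change in this pass: inputs are only quantized when their dtype is FP32, and each weight's scale is multiplied by S8_MAX at most once, tracked per filter name so a filter shared by several convs is not rescaled on every visit. A minimal sketch of that scale-once guard (hypothetical free function; the pass itself does this inline on the filter's scale tensor):

#include <algorithm>
#include <string>
#include <vector>

constexpr double S8_MAX = 127.0;

// Rescale a named weight at most once, even if several conv ops share
// it. Mirrors the changed_weight bookkeeping added to QuantizeConv.
void ScaleWeightOnce(const std::string& filter_name,
                     double* scale,
                     std::vector<std::string>* changed_weight) {
  if (std::find(changed_weight->begin(), changed_weight->end(),
                filter_name) == changed_weight->end()) {
    *scale *= S8_MAX;
    changed_weight->push_back(filter_name);
  }
}

A std::unordered_set<std::string> would give O(1) membership tests; the std::vector plus std::find mirrors the pass as written, which is fine for the small number of distinct filters involved.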
4 changes: 3 additions & 1 deletion paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h
@@ -49,7 +49,9 @@ class CPUQuantizePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;

void QuantizeConv(Graph* graph, bool with_residual_data = false) const;
void QuantizeConv(Graph* graph,
bool with_residual_data = false,
std::vector<std::string>* changed_weight = nullptr) const;
void QuantizeFc(Graph* graph) const;
void QuantizePool(Graph* graph) const;
void QuantizeConcat(Graph* graph) const;
14 changes: 5 additions & 9 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc
100644 → 100755
@@ -66,7 +66,7 @@ void SetOp(ProgramDesc* prog,
type == "nearest_interp" || type == "nearest_interp_v2") {
op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
} else if (type == "slice" || type == "shape") {
} else if (type == "slice") {
op->SetInput("Input", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
} else if (type == "dropout") {
@@ -467,7 +467,7 @@ static const std::initializer_list<std::string> variable_names_immutable_ops = {
void TestImmutableOp(const std::string tested_op) {
ProgramDesc prog;
for (auto& v : variable_names_immutable_ops) {
prog.MutableBlock(0)->Var(v);
prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::FP32);
}
SetOp(&prog, "dequantize", "Dequantize1", {"a"}, {"b"}, true);
SetOp(&prog, tested_op, tested_op, {"b"}, {"c"}, true, "int8");
@@ -520,7 +520,7 @@ void TestImmutableOpBetweenNonQuantizedOp(const std::string tested_op) {
void TestImmutableOpWithManyOutputs(const std::string tested_op) {
ProgramDesc prog;
for (auto& v : variable_names_immutable_ops) {
prog.MutableBlock(0)->Var(v);
prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::FP32);
}

SetOp(&prog, "dropout", "Dropout1", {"a"}, {"b"}, true, "float32");
@@ -556,12 +556,8 @@ void TestImmutableOpWithManyOutputs(const std::string tested_op) {
SCALE * S8_MAX);
}

const std::vector<std::string> immutables = {"reshape2",
"transpose2",
"slice",
"shape",
"nearest_interp",
"nearest_interp_v2"};
const std::vector<std::string> immutables = {
"reshape2", "transpose2", "slice", "nearest_interp", "nearest_interp_v2"};

class TestImmutables : public testing::TestWithParam<std::string> {};

37 changes: 13 additions & 24 deletions paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass.cc
100644 → 100755
@@ -52,35 +52,23 @@ bool HasBias(ir::Node* conv_op) {
conv_op->Op()->Input("Bias").size() > 0;
}

bool ShouldSkipConv(ir::Node* conv_op, Scope* scope, ir::Node* conv_filter) {
if (!platform::HasOpINT8DataType(conv_op->Op())) {
VLOG(4) << "Skipping non-int8 convolution (id: " << conv_op->id() << ").";
return true;
}

auto filter_var = scope->GetVar(conv_filter->Name());
if (filter_var->Get<LoDTensor>().dtype() != phi::DataType::FLOAT32) {
VLOG(4) << "Skipping convolution (id: " << conv_op->id()
<< ") because it's a bug that it is detected again.";
return true;
}

VLOG(4) << "Not skipping convolution (id: " << conv_op->id() << ")";
return false;
}

template <typename T>
void QuantizeConvInput(Scope* scope,
ir::Graph* g,
ir::Node* conv_op,
const std::string& input_name,
const std::string& scales_attr_name) {
const auto scales =
conv_op->Op()->GetAttrIfExists<std::vector<float>>(scales_attr_name);

auto* tensor = scope->GetVar(input_name)->GetMutable<LoDTensor>();
QuantizeParams<T>(tensor, scales);

auto var = scope->GetVar(input_name);
if (var->Get<LoDTensor>().dtype() != phi::DataType::FLOAT32) {
VLOG(1) << "Skipping quantize the input: " << input_name
<< " of convolution because it is detected again.";
} else {
const auto scales =
conv_op->Op()->GetAttrIfExists<std::vector<float>>(scales_attr_name);

auto* tensor = scope->GetVar(input_name)->GetMutable<LoDTensor>();
QuantizeParams<T>(tensor, scales);
}
conv_op->Op()->SetAttr(scales_attr_name, std::vector<float>(1, 1));
}

@@ -151,7 +139,8 @@ void ParamsQuantizationMkldnnPass::QuantizeConv(ir::Graph* graph,
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));

if (ShouldSkipConv(conv_op, scope, conv_filter)) {
// If not a quantized OP
if (!platform::HasOpINT8DataType(conv_op->Op())) {
return;
}

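The removed ShouldSkipConv bailed out of the whole conv once its filter was no longer FP32, which also skipped normalizing the conv's scale attributes. The dtype check now lives inside QuantizeConvInput, so a shared, already-quantized tensor is left untouched while the scales attribute is still reset. A self-contained model of the hazard being guarded against (stand-in Tensor type, not the real LoDTensor API):

#include <iostream>
#include <vector>

// Two convs pointing at one persistable weight buffer: quantizing it
// on both visits would scale the values by s * s instead of s.
struct Tensor {
  bool is_fp32;
  std::vector<float> data;
};

void QuantizeOnce(Tensor* t, float scale) {
  if (!t->is_fp32) return;  // already int8 from an earlier visit
  for (auto& v : t->data) v *= scale;
  t->is_fp32 = false;
}

int main() {
  Tensor shared_filter{true, {1.5f, -1.5f}};
  QuantizeOnce(&shared_filter, 2.0f);  // first conv: values become 3, -3
  QuantizeOnce(&shared_filter, 2.0f);  // second conv: no-op
  std::cout << shared_filter.data[0] << "\n";  // prints 3, not 6
}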
59 changes: 50 additions & 9 deletions paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass_tester.cc
100644 → 100755
@@ -89,17 +89,29 @@ struct ProgramStrategy {

virtual void CheckOp(const OpDesc& op) const = 0;

VarDesc* AddInput(OpDesc* op, std::string input_name, const Data& data) {
const std::string var_name = input_name + "_var";
VarDesc* AddInput(OpDesc* op,
std::string input_name,
const Data& data,
const std::string user_var_name = "") {
std::string var_name = user_var_name;
if (var_name.empty()) {
var_name = input_name + "_var";
}
op->SetInput(input_name, {var_name});
auto var = program.MutableBlock(0)->Var(var_name);
var->SetShape(data.getShape());
test_scope.CreateTensor(var_name, data);
return var;
}

void AddOutput(OpDesc* op, std::string output_name, const Data& data) {
const std::string var_name = output_name + "_var";
void AddOutput(OpDesc* op,
std::string output_name,
const Data& data,
const std::string user_var_name = "") {
std::string var_name = user_var_name;
if (var_name.empty()) {
var_name = output_name + "_var";
}
op->SetOutput(output_name, {var_name});
program.MutableBlock(0)->Var(var_name);
test_scope.CreateTensor(var_name, data);
@@ -117,21 +129,23 @@ struct ConvProgramStrategy : public ProgramStrategy {
std::vector<float>&& scale_weights,
int groups = 1,
Data&& bias = Data(),
std::vector<float>&& scale_bias = {})
std::vector<float>&& scale_bias = {},
bool share_weight = false)
: input(std::move(input)),
filter(std::move(filter)),
output(std::move(output)),
scale_weights(std::move(scale_weights)),
groups(std::move(groups)),
bias(std::move(bias)),
scale_bias(std::move(scale_bias)) {}
scale_bias(std::move(scale_bias)),
share_weight(std::move(share_weight)) {}

protected:
OpDesc* CreateBasicConvOp() {
OpDesc* CreateBasicConvOp(const std::string conv_name = "Conv1") {
auto op = program.MutableBlock(0)->AppendOp();
op->SetType("conv2d");
op->SetAttr("use_mkldnn", true);
op->SetAttr("name", std::string{"Conv1"});
op->SetAttr("name", conv_name);
op->SetAttr("mkldnn_data_type", std::string{"int8"});
op->SetAttr("data_format", std::string{"NCHW"});
op->SetAttr("dilations", std::vector<int>({1, 1}));
@@ -155,6 +169,20 @@ struct ConvProgramStrategy : public ProgramStrategy {
AddInput(op, "Bias", bias);
op->SetAttr("Bias_scales", scale_bias);
}

if (share_weight) {
OpDesc* op2 = CreateBasicConvOp("Conv2");
AddInput(op2, "Input", input);
AddInput(op2, "Filter", filter)->SetPersistable(true);
AddOutput(op2, "Output", output, "output2");
op2->SetAttr("Scale_weights", scale_weights);
op2->SetAttr("Scale_in", 1.0f);
op2->SetAttr("groups", groups);
if (HasBias()) {
AddInput(op2, "Bias", bias, "Bias2");
op2->SetAttr("Bias_scales", scale_bias);
}
}
}

void CheckOp(const OpDesc& op) const override {
@@ -210,9 +238,9 @@ struct ConvProgramStrategy : public ProgramStrategy {
const Data output;
const std::vector<float> scale_weights;
const int groups;

const Data bias;
const std::vector<float> scale_bias;
const bool share_weight;
};

struct ParamsQuantizationMkldnnPassTestFixture : public ::testing::Test {
@@ -340,6 +368,19 @@ TEST_F(ParamsQuantizationMkldnnPassTestFixture, conv_with_bias_2g2o2i1h1w) {
RunPassTest(std::move(program));
}

TEST_F(ParamsQuantizationMkldnnPassTestFixture, conv_with_bias_2g2o2i1h1ws) {
auto program = std::make_unique<ConvProgramStrategy>(
GenericInput(),
Data({2, 2, 2, 1, 1}, {1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f}),
GenericOutput(),
std::vector<float>{2.f, 2.f, 4.f, 4.f},
2,
Data({2, 2, 1, 1, 1}, {1.5f, 1.5f, 1.5f, 1.5f}),
std::vector<float>{2.f, 2.f, 4.f, 4.f},
true);
RunPassTest(std::move(program));
}

} // namespace
} // namespace ir
} // namespace framework
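The new conv_with_bias_2g2o2i1h1ws test exercises exactly this path: with share_weight set, the second conv reuses the same Filter variable (AddInput derives the name Filter_var from the slot when no user_var_name is passed) while getting its own output2 and Bias2 variables, so the pass must leave the filter untouched on its second visit.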