Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct quantization for multiple input and output ops #48872

Merged
merged 1 commit into from
Dec 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 77 additions & 25 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,14 @@ void CPUQuantizePass::QuantizeInputs(Graph* g,
float shift,
std::string shift_attr_name) const {
auto inputs = op->inputs;
auto var_names = op->Op()->Inputs().at(input_name);
std::vector<std::string> unique_var_names;
for (unsigned i = 0; i < var_names.size(); i++)
if (std::find(unique_var_names.begin(),
unique_var_names.end(),
var_names[i]) == unique_var_names.end())
unique_var_names.push_back(var_names[i]);

auto output = op->outputs[0];
PADDLE_ENFORCE_GE(inputs.size(),
1,
Expand All @@ -163,33 +171,59 @@ void CPUQuantizePass::QuantizeInputs(Graph* g,
// create a quantize op desc prototype
OpDesc q_desc;
q_desc.SetType("quantize");

std::vector<Node*> quantize_out_nodes(inputs.size());
std::vector<std::string> quantize_out_node_names(inputs.size());

double scale_out = GetScaleValueForNode(output);
unsigned max = are_inputs_unsigned ? U8_MAX : S8_MAX;
float scale = scale_out * max;

for (size_t i = 0; i < inputs.size(); i++) {
// Create quantize output variable
for (size_t var_id = 0; var_id < unique_var_names.size(); var_id++) {
auto index = -1;
for (size_t it = 0; it < inputs.size(); it++) {
if (inputs[it]->Name() == unique_var_names[var_id]) index = it;
}

if (index == -1) {
PADDLE_ENFORCE_NE(index,
-1,
platform::errors::InvalidArgument(
"Var(%s) isn't the input of the %s operator.",
unique_var_names[var_id],
op->Op()->Type()));
}

auto* input = inputs.at(index);

VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
quantize_out_nodes[i] = g->CreateVarNode(&quantize_out_desc);
quantize_out_node_names[i] = quantize_out_nodes[i]->Name();
quantize_out_nodes[var_id] = g->CreateVarNode(&quantize_out_desc);
quantize_out_node_names[var_id] = quantize_out_nodes[var_id]->Name();

q_desc.SetAttr("Scale", scale);
q_desc.SetAttr("Shift", shift);
q_desc.SetInput("Input", std::vector<std::string>({inputs[i]->Name()}));
q_desc.SetOutput("Output",
std::vector<std::string>({quantize_out_node_names[i]}));
q_desc.SetInput("Input", std::vector<std::string>({input->Name()}));
q_desc.SetOutput(
"Output", std::vector<std::string>({quantize_out_node_names[var_id]}));
q_desc.SetAttr("is_negative_input", !are_inputs_unsigned);
auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied.

// link quantize op
UnlinkNodes(inputs[i], op);
IR_NODE_LINK_TO(inputs[i], quantize_op);
IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[i]);
IR_NODE_LINK_TO(quantize_out_nodes[i], op);
UnlinkNodes(input, op);
IR_NODE_LINK_TO(input, quantize_op);
IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[var_id]);
IR_NODE_LINK_TO(quantize_out_nodes[var_id], op);
}

// If any inputs were duplicated, now you have to enter them in the correct
// order.
for (size_t i = unique_var_names.size(); i < var_names.size(); i++) {
auto index = std::find(
unique_var_names.begin(), unique_var_names.end(), var_names[i]);
if (index != unique_var_names.end()) {
auto id = std::distance(unique_var_names.begin(), index);
quantize_out_node_names[i] = quantize_out_nodes[id]->Name();
IR_NODE_LINK_TO(quantize_out_nodes[id], op);
}
}

// update op's input
Expand Down Expand Up @@ -252,44 +286,62 @@ void CPUQuantizePass::DequantizeOutputs(Graph* g,
bool is_unsigned,
std::string scale_attr_name) const {
auto outputs = op->outputs;
auto var_names = op->Op()->Outputs().at(output_name);

PADDLE_ENFORCE_GE(outputs.size(),
1,
platform::errors::InvalidArgument(
"OP(%s)'s outputs(%d) must be equal or greater than 1.",
op->Name(),
outputs.size()));

std::vector<std::string> quantize_in_node_names(outputs.size());
std::vector<std::string> dequantize_in_node_names(outputs.size());
std::vector<Node*> dequantize_in_nodes(outputs.size());

unsigned max = is_unsigned ? U8_MAX : S8_MAX;
float scale = scale_to_one * max;

for (size_t i = 0; i < outputs.size(); i++) {
for (size_t var_id = 0; var_id < var_names.size(); var_id++) {
auto index = -1;
for (size_t it = 0; it < outputs.size(); it++) {
if (outputs[it]->Name() == var_names[var_id]) index = it;
}

if (index == -1) {
PADDLE_ENFORCE_NE(index,
-1,
platform::errors::InvalidArgument(
"Var(%s) isn't the input of the %s operator.",
var_names[var_id],
op->Op()->Type()));
}

auto* output = outputs.at(index);

// Create dequantize input variable
VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
Node* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);
quantize_in_node_names[i] = dequantize_in_node->Name();
dequantize_in_nodes[var_id] = g->CreateVarNode(&dequantize_in_desc);
dequantize_in_node_names[var_id] = dequantize_in_nodes[var_id]->Name();

// create a dequantize op node for output.
OpDesc deq_desc;
deq_desc.SetType("dequantize");
deq_desc.SetInput("Input",
std::vector<std::string>({quantize_in_node_names[i]}));
deq_desc.SetOutput("Output",
std::vector<std::string>({outputs[i]->Name()}));
deq_desc.SetInput(
"Input", std::vector<std::string>({dequantize_in_node_names[var_id]}));
deq_desc.SetOutput("Output", std::vector<std::string>({output->Name()}));
deq_desc.SetAttr("Scale", scale);
deq_desc.SetAttr("is_negative_input", !is_unsigned);
auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied.

// link dequantize op
UnlinkNodes(op, outputs[i]);
IR_NODE_LINK_TO(op, dequantize_in_node);
IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
IR_NODE_LINK_TO(dequantize_op, outputs[i]);
UnlinkNodes(op, output);
IR_NODE_LINK_TO(op, dequantize_in_nodes[var_id]);
IR_NODE_LINK_TO(dequantize_in_nodes[var_id], dequantize_op);
IR_NODE_LINK_TO(dequantize_op, output);
}

// update op's output
op->Op()->SetOutput(output_name, quantize_in_node_names);
op->Op()->SetOutput(output_name, dequantize_in_node_names);
if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
}

Expand Down
39 changes: 39 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,45 @@ TEST(CpuQuantizePass, multi_gru_3) {
MainTestMultiGru(layers);
}

// Names of all variables used by the multi-input/multi-output ops test graph
// built in BuildProgramDescMulti().
static const std::initializer_list<std::string>
variable_names_multi_inputs_outputs = {"a", "b", "c1", "c2", "d", "e"};

// a->Pool->b
// b->Split->c1, c2
// (c1, c2, c1, c2)->Concat->d
// d->Pool->e
// Builds the test program with ops that take duplicated inputs (Concat gets
// c1 and c2 twice) and produce multiple outputs (Split), used to exercise the
// multi-input/output quantization path.
ProgramDesc BuildProgramDescMulti() {
  ProgramDesc prog;

  // Declare every test variable as FP32 in block 0.
  for (const auto& name : variable_names_multi_inputs_outputs) {
    prog.MutableBlock(0)->Var(name)->SetDataType(proto::VarType::FP32);
  }

  SetOp(&prog, "pool2d", "Pool", {"a"}, {"b"}, true, "float32");
  SetOp(&prog, "split", "Split", {"b"}, {"c1", "c2"}, true, "int8");
  SetOp(
      &prog, "concat", "Concat", {"c1", "c2", "c1", "c2"}, {"d"}, true, "int8");
  SetOp(&prog, "pool2d", "Pool2", {"d"}, {"e"}, true, "float32");

  return prog;
}

TEST(CpuQuantizePass, multi_inputs_outputs_ops) {
// a->QUANT1->Split
// c1->DEQUANT->OUT->QUANT
// c2->DEQUANT->OUT->QUANT
// (c1, c2, c1, c2)->Concat->d->DEQUANT->Pool->e
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

b1,b2 are not declared in BuildProgramDescMulti and variable_names_multi_inputs_outputs

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it should be c1 and c2

int added_nodes = 6 * 2;
std::unordered_map<std::string, int> expected_operators = {{"pool2d", 2},
{"concat", 1},
{"split", 1},
{"quantize", 3},
{"dequantize", 3}};
MainTest(BuildProgramDescMulti(),
variable_names_multi_inputs_outputs,
expected_operators,
added_nodes);
}

} // namespace ir
} // namespace framework
} // namespace paddle
Expand Down
23 changes: 13 additions & 10 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
PADDLE_GET_CONST(float, quant_op->Op()->GetAttr("Scale"));
float dequant_shift = dequant_op->Op()->GetAttrIfExists<float>("Shift");
float quant_shift = quant_op->Op()->GetAttrIfExists<float>("Shift");
if (quant_op->Op()->GetAttrIfExists<bool>("is_negative_input") !=
dequant_op->Op()->GetAttrIfExists<bool>("is_negative_input")) {
return;
}

PADDLE_ENFORCE_NE(
nodes_keep_counter->find(dequant_out),
nodes_keep_counter->end(),
Expand All @@ -169,14 +174,13 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
if (dequant_scale == quant_scale && dequant_shift == quant_shift) {
// squash dequantize-quantize to nothing
auto quant_out_var_name = quant_out->Name();
auto next_op_inputs = next_op_desc->InputNames();
for (const auto& name : next_op_inputs) {
auto input_names = next_op_desc->Input(name);
for (auto input_name : next_op_desc->InputNames()) {
auto& input_names = next_op_desc->MutableInputs()->at(input_name);
std::replace(input_names.begin(),
input_names.end(),
quant_out_var_name,
dequant_in->Name());
next_op_desc->SetInput(name, input_names);
next_op_desc->SetInput(input_name, input_names);
}

if (keep_dequant)
Expand Down Expand Up @@ -413,12 +417,11 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {

// update the next operator input,
// by replacing quant_out with first_quant_out
auto last_op_names = last_op->Op()->Input(last_op_input_name);
last_op_names.erase(
std::remove(
last_op_names.begin(), last_op_names.end(), quant_out->Name()),
last_op_names.end());
last_op_names.push_back(first_quant_out->Name());
auto last_op_names = last_op->Op()->Inputs().at(last_op_input_name);
std::replace(last_op_names.begin(),
last_op_names.end(),
quant_out->Name(),
first_quant_out->Name());
last_op_op->SetInput(last_op_input_name,
std::vector<std::string>(last_op_names));

Expand Down