Skip to content

Commit

Permalink
Add input data type checking in BF16 placement pass (#38702)
Browse files Browse the repository at this point in the history
  • Loading branch information
wozna authored Jan 5, 2022
1 parent bbe83ed commit 60c51de
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 9 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2441,11 +2441,13 @@ PDNode *patterns::Bfloat16Placement::operator()(
if (!bfloat16_enabled_op_types.empty()) {
supported_op_types = bfloat16_enabled_op_types;
}
auto *op_in = pattern->NewNode(op_in_repr())->AsInput();
auto *op = pattern->NewNode(op_repr())->assert_is_ops(supported_op_types);
op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<bool>("use_mkldnn") ||
node->Op()->Type() == "reshape2";
});
op->LinksFrom({op_in});
return op;
}

Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -1446,6 +1446,7 @@ struct Bfloat16Placement : public PatternBase {
PDNode* operator()(
const std::unordered_set<std::string>& bfloat16_enabled_op_types);

PATTERN_DECL_NODE(op_in);
PATTERN_DECL_NODE(op);
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,12 @@ void CPUBfloat16PlacementPass::SetMkldnnDataType(

auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op_in, op_in, bfloat16_placement_pattern);
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_placement_pattern);

// Only float input can be converted to bfloat16
if (op_in->Var()->GetDataType() != proto::VarType::FP32) return;

if ((op->Op()->HasAttr("mkldnn_data_type") ||
op->Op()->HasProtoAttr("mkldnn_data_type")) &&
!platform::HasOpINT8DataType(op->Op())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ ProgramDesc BuildProgramDesc() {
for (auto& v :
std::vector<std::string>({"a", "b", "c", "f", "g", "h", "k", "l", "m",
"n", "o", "p", "r", "s"})) {
prog.MutableBlock(0)->Var(v);
prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::FP32);
}

SetOp(&prog, "concat", "concat1", {"a", "b"}, {"c"});
Expand All @@ -86,9 +86,8 @@ ProgramDesc BuildProgramDesc() {
}

void MainTest(std::initializer_list<std::string> bfloat16_enabled_op_types,
unsigned expected_bfloat16_data_type_count) {
auto prog = BuildProgramDesc();

unsigned expected_bfloat16_data_type_count,
const ProgramDesc& prog) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

auto pass = PassRegistry::Instance().Get("cpu_bfloat16_placement_pass");
Expand All @@ -110,8 +109,8 @@ void MainTest(std::initializer_list<std::string> bfloat16_enabled_op_types,
EXPECT_EQ(bfloat16_data_type_count, expected_bfloat16_data_type_count);
}

void DefaultAttrTest(unsigned expected_bfloat16_data_type_count) {
auto prog = BuildProgramDesc();
void DefaultAttrTest(unsigned expected_bfloat16_data_type_count,
const ProgramDesc& prog) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("cpu_bfloat16_placement_pass");
graph.reset(pass->Apply(graph.release()));
Expand All @@ -128,15 +127,39 @@ void DefaultAttrTest(unsigned expected_bfloat16_data_type_count) {
}

TEST(Bfloat16PlacementPass, enable_all) {
MainTest({"conv2d", "pool2d", "gelu", "concat", "sum"}, 8);
MainTest({"conv2d", "pool2d", "gelu", "concat", "sum"}, 8,
BuildProgramDesc());
}

TEST(Bfloat16PlacementPass, enabled_conv_and_pool) {
// 2 conv2d + 2 pool2 - 1 orphaned conv2d
MainTest({"conv2d", "pool2d"}, 3);
MainTest({"conv2d", "pool2d"}, 3, BuildProgramDesc());
}

TEST(Bfloat16PlacementPass, default_attr_value) {
DefaultAttrTest(10, BuildProgramDesc());
}

ProgramDesc BuildProgramDescWithDataType() {
ProgramDesc prog;

for (auto& v : std::vector<std::string>({"a", "b", "c", "d", "e"})) {
if (v == "a") {
prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::INT32);
} else {
prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::FP32);
}
}

SetOp(&prog, "conv2d", "conv1", {"a"}, {"b"});
SetOp(&prog, "pool2d", "pool1", {"b"}, {"c"});
SetOp(&prog, "concat", "concat1", {"c", "d"}, {"e"});
return prog;
}

TEST(Bfloat16PlacementPass, default_attr_value) { DefaultAttrTest(10); }
TEST(Bfloat16PlacementPass, check_data_types) {
DefaultAttrTest(2, BuildProgramDescWithDataType());
}

} // namespace ir
} // namespace framework
Expand Down

0 comments on commit 60c51de

Please sign in to comment.