remove bf16 (#38133)
* remove bf16

* remove comments

* remove wrong return

* fix UT
b3602sss authored Dec 15, 2021
1 parent b28c374 commit 49108ef
Showing 4 changed files with 51 additions and 0 deletions.
17 changes: 17 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2412,6 +2412,23 @@ PDNode *patterns::OrphanedBfloat16::operator()() {
  return next_op;
}

PDNode *patterns::UnsupportedBfloat16::operator()() {
  auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
  prev_op->assert_more([&](Node *node) {
    return node->Op()->HasAttr("mkldnn_data_type") == false;
  });
  auto *prev_out = pattern->NewNode(prev_out_repr())->AsOutput();

  auto *op = pattern->NewNode(op_repr())->assert_is_op();
  op->assert_more([&](Node *node) {
    return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
           "bfloat16";
  });
  prev_op->LinksTo({prev_out});
  op->LinksFrom({prev_out});
  return op;
}
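In plain terms, the new pattern matches a two-operator chain: a producer op that carries no mkldnn_data_type attribute, its output variable, and a consumer op whose mkldnn_data_type is "bfloat16". A minimal self-contained sketch of the same two predicates, using toy stand-in types rather than Paddle's Node/OpDesc API (ToyOp and its methods are illustrative assumptions, not part of this diff):

#include <iostream>
#include <map>
#include <string>

// Toy stand-in for an operator's attribute map; not Paddle's OpDesc API.
struct ToyOp {
  std::map<std::string, std::string> attrs;
  bool HasAttr(const std::string& name) const { return attrs.count(name) > 0; }
  std::string GetAttrIfExists(const std::string& name) const {
    auto it = attrs.find(name);
    return it == attrs.end() ? std::string() : it->second;
  }
};

// Predicate asserted on prev_op: it must carry no mkldnn_data_type attribute.
bool PrevOpMatches(const ToyOp& prev_op) {
  return !prev_op.HasAttr("mkldnn_data_type");
}

// Predicate asserted on op: it must be explicitly marked bfloat16.
bool OpMatches(const ToyOp& op) {
  return op.GetAttrIfExists("mkldnn_data_type") == "bfloat16";
}

int main() {
  ToyOp prev_op;                                 // no mkldnn_data_type attribute
  ToyOp op{{{"mkldnn_data_type", "bfloat16"}}};  // marked bfloat16
  std::cout << (PrevOpMatches(prev_op) && OpMatches(op)) << "\n";  // prints 1
  return 0;
}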

PDNode *patterns::LastBfloat16Ops::operator()() {
  auto *op = pattern->NewNode(op_repr())->assert_is_op();
  op->assert_more([&](Node *node) {
10 changes: 10 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1416,6 +1416,16 @@ struct OrphanedBfloat16 : public PatternBase {
  PATTERN_DECL_NODE(next_op);
};

struct UnsupportedBfloat16 : public PatternBase {
  UnsupportedBfloat16(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "unsupported_bfloat16") {}
  PDNode* operator()();

  PATTERN_DECL_NODE(prev_op);
  PATTERN_DECL_NODE(prev_out);
  PATTERN_DECL_NODE(op);
};

struct LastBfloat16Ops : public PatternBase {
  LastBfloat16Ops(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "last_bfloat16_ops") {}
21 changes: 21 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.cc
@@ -71,10 +71,31 @@ void CPUBfloat16PlacementPass::RemoveOrphanedOperators(
  gpd(graph, handler);
}

void CPUBfloat16PlacementPass::RemoveUnsupportedOperators(
    ir::Graph* graph, int* bfloat16_operators) const {
  // Quantization currently supports FP32 inputs only, so look for
  // bfloat16 operators whose input type is not FP32 and fall them back
  // to float32.
  GraphPatternDetector gpd;
  patterns::UnsupportedBfloat16 unsupported_bfloat16_pattern{
      gpd.mutable_pattern(), "unsupported_bfloat16"};
  unsupported_bfloat16_pattern();
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_IR_NODE_FROM_SUBGRAPH(prev_out, prev_out, unsupported_bfloat16_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(op, op, unsupported_bfloat16_pattern);
    if (prev_out->Var()->GetDataType() != proto::VarType::FP32) {
      op->Op()->SetAttr("mkldnn_data_type", std::string("float32"));
      (*bfloat16_operators)--;
    }
  };
  gpd(graph, handler);
}
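For clarity, the fallback this handler applies can be restated in isolation: when the producer's output is not FP32, the consumer op is switched back to float32 and the running bfloat16 counter is decremented through the pointer. A hedged, self-contained sketch with toy types (ToyDataType and ToyOpDesc are illustrative assumptions, not Paddle's proto::VarType or OpDesc):

#include <iostream>
#include <string>

// Toy data-type tag standing in for proto::VarType; values are illustrative.
enum class ToyDataType { FP32, INT8, BF16 };

// Toy operator descriptor holding only the attribute this pass touches.
struct ToyOpDesc {
  std::string mkldnn_data_type = "bfloat16";
};

// Mirrors the handler's decision: a non-FP32 input cannot feed a bfloat16
// operator, so force the op back to float32 and decrement the counter.
void FallBackIfUnsupported(ToyDataType input_type, ToyOpDesc* op,
                           int* bfloat16_operators) {
  if (input_type != ToyDataType::FP32) {
    op->mkldnn_data_type = "float32";
    (*bfloat16_operators)--;  // decrement the counted value, not the pointer
  }
}

int main() {
  ToyOpDesc op;
  int bfloat16_operators = 3;  // pretend three ops were marked bfloat16
  FallBackIfUnsupported(ToyDataType::INT8, &op, &bfloat16_operators);
  std::cout << op.mkldnn_data_type << " " << bfloat16_operators << "\n";
  // prints: float32 2
  return 0;
}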

void CPUBfloat16PlacementPass::ApplyImpl(ir::Graph* graph) const {
  int bfloat16_operators = 0;
  SetMkldnnDataType(graph, &bfloat16_operators);
  RemoveOrphanedOperators(graph, &bfloat16_operators);
  RemoveUnsupportedOperators(graph, &bfloat16_operators);
  PrettyLogDetail("--- marked %d operators to bfloat16 ",
                  bfloat16_operators);
}
3 changes: 3 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.h
@@ -30,6 +30,9 @@ class CPUBfloat16PlacementPass : public Pass {

  void RemoveOrphanedOperators(ir::Graph* graph, int* bfloat16_operators) const;

  void RemoveUnsupportedOperators(ir::Graph* graph,
                                  int* bfloat16_operators) const;

  void ApplyImpl(ir::Graph* graph) const override;
};
