Commit 2479664

Add FLAGS_allow_cinn_ops & FLAGS_deny_cinn_ops for controlling op types used in training with CINN. (#36842)

* Update UT test_parallel_executor_run_cinn.py.

* Add FLAGS_allow_cinn_ops & FLAGS_deny_cinn_ops & FLAGS_cinn_ops_delim.

* Use the custom StringSplit function and remove the FLAGS_cinn_ops_delim flag.

* Add FlagController test.

* Apply lock to the cache_ only in CinnCompiler.

* Add VizGraph & ReadableKey method for CinnCompiler.

* Update the dot style of VizGraph in CinnCompiler.
wzzju authored Nov 3, 2021
1 parent fb39469 commit 2479664
Showing 9 changed files with 346 additions and 63 deletions.
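
In practice the two new flags are plain gflags strings; below is a minimal sketch of configuring them from C++. The op names are placeholders, and the matching DEFINE_string presumably lives in one of the changed files not shown here.

```cpp
#include "gflags/gflags.h"

// Declared with DECLARE_string in build_cinn_pass.cc below; defined elsewhere
// in the commit. Both flags hold a ';'-separated list of op types.
DECLARE_string(allow_cinn_ops);
DECLARE_string(deny_cinn_ops);

void ConfigureCinnOpSet() {
  // Only these op types may be handed to CINN.
  FLAGS_allow_cinn_ops = "elementwise_add;relu";
  // The deny list is consulted only when the allow list is empty.
  FLAGS_deny_cinn_ops = "";
}
```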
paddle/fluid/framework/paddle2cinn/CMakeLists.txt (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@ cc_library(cinn_cache_key SRCS cinn_cache_key.cc DEPS boost graph graph_helper l
cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector graph_pattern_detector cinn_compiler errors enforce)
cc_library(transform_desc SRCS transform_desc.cc DEPS proto_desc cinn)
cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph transform_desc cinn)
cc_library(cinn_compiler SRCS cinn_compiler.cc DEPS graph lod_tensor cinn_cache_key cinn_graph_symbolization cinn)
cc_library(cinn_compiler SRCS cinn_compiler.cc DEPS framework_proto graph lod_tensor cinn_cache_key cinn_graph_symbolization cinn)

cc_test(cinn_cache_key_test SRCS cinn_cache_key_test.cc DEPS cinn_cache_key)
cc_test(build_cinn_pass_test SRCS build_cinn_pass_test.cc DEPS build_cinn_pass cinn_compiler)
paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc (46 changes: 42 additions & 4 deletions)
@@ -17,6 +17,7 @@ limitations under the License. */
#include <algorithm>
#include <iterator>
#include <memory>
#include <regex>
#include <string>
#include <unordered_map>
#include <unordered_set>
@@ -25,6 +26,8 @@ limitations under the License. */

#include "cinn/frontend/op_mapper_registry.h"
#include "cinn/frontend/op_mappers/use_op_mappers.h"
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
@@ -34,6 +37,9 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"

DECLARE_string(allow_cinn_ops);
DECLARE_string(deny_cinn_ops);

namespace paddle {
namespace framework {
namespace paddle2cinn {
@@ -46,6 +52,20 @@ using GraphNodeSet = std::unordered_set<Node*>;
using GraphNodeMap = std::unordered_map<Node*, Node*>;

namespace {
// The delimiter (`;`) used to split FLAGS_allow_cinn_ops
// and FLAGS_deny_cinn_ops.
constexpr char kDelim[] = ";";

std::unordered_set<std::string> StringSplit(const std::string& str,
const std::string& delim) {
std::regex reg(delim);
std::unordered_set<std::string> elems{
std::sregex_token_iterator(str.begin(), str.end(), reg, -1),
std::sregex_token_iterator()};
elems.erase("");
return elems;
}
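
A quick standalone check of the splitting behavior of the StringSplit helper above (a sketch outside the Paddle build, using only the standard library): empty tokens produced by leading, trailing, or doubled delimiters are dropped by the elems.erase("") call.

```cpp
#include <iostream>
#include <regex>
#include <string>
#include <unordered_set>

// Same logic as the StringSplit helper above.
std::unordered_set<std::string> StringSplit(const std::string& str,
                                            const std::string& delim) {
  std::regex reg(delim);
  std::unordered_set<std::string> elems{
      std::sregex_token_iterator(str.begin(), str.end(), reg, -1),
      std::sregex_token_iterator()};
  elems.erase("");
  return elems;
}

int main() {
  // Prints "relu" and "elementwise_add" (in unspecified order); the empty
  // tokens from ";relu;;" do not survive.
  for (const auto& op : StringSplit(";relu;;elementwise_add", ";")) {
    std::cout << op << "\n";
  }
  return 0;
}
```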

int ExtractOpRole(const GraphNodeSet& cluster) {
std::unordered_set<int> op_roles;
std::string attr_name = OpProtoAndCheckerMaker::OpRoleAttrName();
@@ -339,10 +359,27 @@ void ReplaceSubGraphWithCinnOpNode(const GraphNodeSet& cluster,
// all of whose op nodes are supported by CINN. We use the OpMapperRegistry
// to check whether an op node is supported by CINN.
void SearchAllSubgraphs(Graph* graph) {
auto teller = [](const Node* node) {
return ::cinn::frontend::OpMapperRegistry::Global()->Find(node->Name()) !=
nullptr;
auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim);
auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim);
auto teller = [&allow_ops, &deny_ops](const Node* node) {
bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find(
node->Name()) != nullptr;
// if the op type is registered in CINN and allow_ops is not empty, return
// true only when it is in allow_ops
if (allow_ops.size()) {
return registered && allow_ops.count(node->Name());
}
// if the op type is registered in CINN and deny_ops is not empty, return
// true only when it is not in deny_ops
if (deny_ops.size()) {
return registered && !deny_ops.count(node->Name());
}
// if the user doesn't set FLAGS_allow_cinn_ops and FLAGS_deny_cinn_ops,
// return true only when it is registered in CINN
return registered;
};
VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops;
VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops;
std::vector<GraphNodeVec> clusters =
framework::ir::SubgraphDetector(graph, teller)();

@@ -375,7 +412,8 @@ void SearchAllSubgraphs(Graph* graph) {
// save it in CinnCompiler
std::string compilation_key = cinn_compiler->AddGraph(CreateNewSubGraph(
cluster_set, cluster_internals, cluster_inputs, cluster_outputs));
VLOG(4) << "Compilation Key: " << compilation_key;
VLOG(4) << "Compilation Key:\n"
<< cinn_compiler->ReadableKey(compilation_key);

// Replace the found cluster to a new cinn op node
ReplaceSubGraphWithCinnOpNode(cluster_set, cluster_inputs, cluster_outputs,
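
To summarize the teller added above: the allow list takes precedence, the deny list is consulted only when the allow list is empty, and in every case the op must be registered in CINN's OpMapperRegistry. The same precedence restated as a free function (the name ShouldLowerToCinn is invented for this sketch; the pass keeps the logic inline in the lambda):

```cpp
#include <string>
#include <unordered_set>

// Precedence used by the teller in SearchAllSubgraphs:
//   1. allow list non-empty -> op must be registered AND in the allow list
//      (the deny list is ignored in this case);
//   2. deny list non-empty  -> op must be registered AND not in the deny list;
//   3. both lists empty     -> CINN registration alone decides.
bool ShouldLowerToCinn(const std::string& op_type, bool registered_in_cinn,
                       const std::unordered_set<std::string>& allow_ops,
                       const std::unordered_set<std::string>& deny_ops) {
  if (!allow_ops.empty()) {
    return registered_in_cinn && allow_ops.count(op_type) > 0;
  }
  if (!deny_ops.empty()) {
    return registered_in_cinn && deny_ops.count(op_type) == 0;
  }
  return registered_in_cinn;
}
```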
paddle/fluid/framework/paddle2cinn/cinn_compiler.cc (134 changes: 108 additions & 26 deletions)
@@ -14,34 +14,42 @@

#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"

#include <iterator>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>

#include "cinn/common/target.h"
#include "cinn/common/type.h"
#include "cinn/frontend/decomposer/use_decomposer.h"
#include "cinn/frontend/net_builder.h" // need to remove after
#include "cinn/frontend/pass/use_program_pass.h"
#include "cinn/frontend/program_pass.h"
#include "cinn/frontend/syntax.h"
#include "cinn/hlir/framework/graph.h"
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/pass.h"
#include "cinn/hlir/pass/use_pass.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/string_helper.h"

namespace paddle {
namespace framework {
namespace paddle2cinn {

using ir::Graph;
using ir::Node;
using inference::analysis::Dot;
using ::cinn::common::Target;
using ::cinn::common::Float;
using ::cinn::hlir::framework::GraphCompiler;
@@ -54,47 +62,121 @@ CinnCompiler* CinnCompiler::GetInstance() {
return &instance;
}

const CinnCompiledObject& CinnCompiler::Compile(
const Graph& graph,
const std::map<std::string, const LoDTensor*>& input_tensors,
const Target& target) {
CinnCacheKey cur_key(graph, input_tensors, target.arch_str());
bool exist = false;
{
AutoRDLock r_guard{&rwlock_};
exist = cache_.count(cur_key) != 0;
}
if (!exist) {
real_compiled_num_++;
auto compiled_res = CompileGraph(graph, input_tensors, target);
AutoWRLock w_guard{&rwlock_};
if (!cache_.count(cur_key)) {
cache_[cur_key] = std::move(compiled_res);
}
}
AutoRDLock guard{&rwlock_};
const auto& cached_obj = *cache_[cur_key];
return cached_obj;
}
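
The rewritten Compile above uses a double-checked read/write-lock pattern: probe the cache under a read lock, compile without holding any lock, then re-check under a write lock before inserting. The same idea in isolation, with std::shared_mutex standing in for paddle's RWLock (a sketch, not framework code):

```cpp
#include <map>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <string>

class CompiledCache {
 public:
  const std::string& GetOrCompile(const std::string& key) {
    bool exist = false;
    {
      std::shared_lock<std::shared_mutex> r_guard(mutex_);  // read lock
      exist = cache_.count(key) != 0;
    }
    if (!exist) {
      // "Compile" outside any lock so concurrent readers are not blocked.
      auto compiled = std::make_unique<std::string>("compiled:" + key);
      std::unique_lock<std::shared_mutex> w_guard(mutex_);  // write lock
      // Re-check: another thread may have inserted the same key meanwhile.
      if (cache_.count(key) == 0) {
        cache_[key] = std::move(compiled);
      }
    }
    std::shared_lock<std::shared_mutex> guard(mutex_);
    return *cache_.at(key);
  }

 private:
  std::map<std::string, std::unique_ptr<std::string>> cache_;
  std::shared_mutex mutex_;
};
```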

const CinnCompiledObject& CinnCompiler::Compile(
const std::string& compilation_key,
const std::map<std::string, const LoDTensor*>& input_tensors,
const Target& target) {
VLOG(4) << "-- The graph to be compiled is:\n" << VizGraph(compilation_key);
const auto& graph = FindGraph(compilation_key);
return Compile(graph, input_tensors, target);
}

std::string CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
std::string graph_key;
ProgramDesc program;
GraphToProgram(*graph, &program);
program.Proto()->SerializeToString(&graph_key);
if (!graphs_.count(graph_key)) {
graphs_[graph_key] = std::move(graph);
} else {
LOG(WARNING)
<< "The graph being added is already in CinnCompiler. Its key is:\n"
<< graph_key;
}

PADDLE_ENFORCE_EQ(
graphs_.count(graph_key), 0,
platform::errors::PreconditionNotMet(
"The graph to be added is already in CinnCompiler, which is:\n",
VizGraph(graph_key).c_str()));
graphs_[graph_key] = std::move(graph);
VLOG(4) << "-- Add a graph into CinnCompiler, which is:\n"
<< VizGraph(graph_key);
return graph_key;
}

const Graph& CinnCompiler::FindGraph(const std::string& graph_key) const {
PADDLE_ENFORCE_NE(
graphs_.count(graph_key), 0,
platform::errors::InvalidArgument("Can not find the target graph: %s",
graph_key.c_str()));
platform::errors::PreconditionNotMet(
"Can not find the target graph, of which the key is:\n%s",
ReadableKey(graph_key).c_str()));
return *graphs_.at(graph_key);
}

const CinnCompiledObject& CinnCompiler::Compile(
const Graph& graph,
const std::map<std::string, const LoDTensor*>& input_tensors,
const Target& target) {
CinnCacheKey cur_key(graph, input_tensors, target.arch_str());
if (!cache_.count(cur_key)) {
real_compiled_num_++;
cache_[cur_key] = CompileGraph(graph, input_tensors, target);
std::string CinnCompiler::VizGraph(const std::string& key) const {
Dot dot;
std::unordered_map<const Node*, std::string> node2dot;
const Graph& graph = FindGraph(key);
int id = 0;
// Create nodes
for (const Node* n : graph.Nodes()) {
std::string node_id = "Node" + std::to_string(id++);
if (n->IsOp()) {
dot.AddNode(
node_id,
{Dot::Attr("shape", "box"), Dot::Attr("style", "rounded,filled,bold"),
Dot::Attr("color", "#303A3A"), Dot::Attr("fontcolor", "#ffffff")},
n->Name());
} else if (n->IsVar()) {
auto label = n->Name();
if (n->Var() && n->Var()->GetType() == proto::VarType::LOD_TENSOR) {
auto shape = n->Var()->GetShape();
std::vector<std::string> shape_str(shape.size());
std::transform(shape.begin(), shape.end(), shape_str.begin(),
[](const auto& val) { return std::to_string(val); });
label += "\n" + string::join_strings(shape_str, ',');
}
dot.AddNode(
node_id,
{Dot::Attr("shape", "box"), Dot::Attr("style", "rounded,filled,bold"),
Dot::Attr("color", n->Var()->IsParameter() ? "#148b97" : "#dddddd"),
Dot::Attr("fontcolor",
n->Var()->IsParameter() ? "#ffffff" : "#000000")},
label);
}
node2dot[n] = node_id;
}
// Create edges
for (const Node* n : graph.Nodes()) {
const auto& src_id = node2dot.at(n);
for (auto* out : n->outputs) {
const auto& dest_id = node2dot.at(out);
dot.AddEdge(src_id, dest_id, {});
}
}
return *cache_[cur_key];
return dot.Build();
}

const CinnCompiledObject& CinnCompiler::Compile(
const std::string& compilation_key,
const std::map<std::string, const LoDTensor*>& input_tensors,
const Target& target) {
const auto& graph = FindGraph(compilation_key);
return Compile(graph, input_tensors, target);
std::string CinnCompiler::ReadableKey(const std::string& key) const {
proto::ProgramDesc desc;
desc.ParseFromString(key);
return desc.DebugString();
}

void CinnCompiler::Clear() {
{
AutoWRLock guard{&rwlock_};
graphs_.clear();
cache_.clear();
}
real_compiled_num_ = 0;
}

std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
@@ -107,7 +189,7 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
ProgramPass::Apply(&frontend_program, target, {"Decomposer"});
auto cinn_graph = std::make_shared<::cinn::hlir::framework::Graph>(
frontend_program, target);
VLOG(4) << "The " << real_compiled_num_ << "-th compilation ("
VLOG(4) << "-- The " << real_compiled_num_ << "-th compilation ("
<< target.arch_str() << "), and its related graph:\n"
<< cinn_graph->Visualize();
ApplyPass(cinn_graph.get(), "OpFusion");
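
For the new VizGraph and ReadableKey helpers, a typical debugging use is to dump the DOT text and the decoded ProgramDesc to files and inspect or render them offline. A small sketch (the helper name and output paths are illustrative, not part of the commit):

```cpp
#include <fstream>
#include <string>

#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"

namespace p2c = paddle::framework::paddle2cinn;

// `key` is a compilation key previously returned by CinnCompiler::AddGraph.
// The .dot file can be rendered with graphviz, e.g.:
//   dot -Tpng cinn_subgraph.dot -o cinn_subgraph.png
void DumpCinnSubgraph(const std::string& key) {
  auto* compiler = p2c::CinnCompiler::GetInstance();
  std::ofstream("cinn_subgraph.dot") << compiler->VizGraph(key);
  std::ofstream("cinn_subgraph_key.txt") << compiler->ReadableKey(key);
}
```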
paddle/fluid/framework/paddle2cinn/cinn_compiler.h (8 changes: 8 additions & 0 deletions)
@@ -25,6 +25,7 @@
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h"
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/macros.h"

@@ -64,6 +65,12 @@ class CinnCompiler {

const ir::Graph& FindGraph(const std::string& key) const;

std::string VizGraph(const std::string& key) const;

std::string ReadableKey(const std::string& key) const;

void Clear();

std::int64_t real_compiled_num() const { return real_compiled_num_; }

~CinnCompiler() = default;
@@ -80,6 +87,7 @@ class CinnCompiler {
CinnCacheKey::Hash>
cache_;
std::atomic_int64_t real_compiled_num_{0};
mutable RWLock rwlock_;

DISABLE_COPY_AND_ASSIGN(CinnCompiler);
};
The diffs of the remaining 5 changed files are not shown.
