diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index e86a475e59add..0664a63c2b72b 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -114,7 +114,8 @@ void AddOutputVar(const std::unordered_set& output_vars, // var node are from internal nodes std::unique_ptr CreateNewSubGraph(const GraphNodeSet& cluster, const GraphNodeSet& cluster_internals, - const GraphNodeSet& cluster_inputs) { + const GraphNodeSet& cluster_inputs, + const GraphNodeSet& cluster_outputs) { // Graph's constructor must has one parameter, and in our code, // the ProgramDesc is useless, so here we pass a temporary object. auto subgraph = std::make_unique(framework::ProgramDesc()); @@ -127,7 +128,12 @@ std::unique_ptr CreateNewSubGraph(const GraphNodeSet& cluster, std::unordered_map old_var2new_var; for (auto* var : cluster_internals) { - auto sub_node = subgraph->CreateVarNode(var->Var()); + Node* sub_node; + if (var->Var() == nullptr) { + sub_node = subgraph->CreateEmptyNode(var->Name(), var->NodeType()); + } else { + sub_node = subgraph->CreateVarNode(var->Var()); + } old_var2new_var[var] = sub_node; } @@ -140,7 +146,7 @@ std::unique_ptr CreateNewSubGraph(const GraphNodeSet& cluster, for (auto* var : op->inputs) { if (cluster_internals.count(var)) { old_op2new_op[op]->inputs.emplace_back(old_var2new_var[var]); - } else if (cluster_inputs.count(var)) { + } else if (cluster_inputs.count(var) && var->Var() != nullptr) { if (var->Var()->IsParameter()) { // Parameters have been preserved in scope, compared to feed var, // param just need add new var and don't need add feed op. @@ -157,7 +163,7 @@ std::unique_ptr CreateNewSubGraph(const GraphNodeSet& cluster, for (auto* var : op->outputs) { if (cluster_internals.count(var)) { old_op2new_op[op]->outputs.emplace_back(old_var2new_var[var]); - } else { + } else if (cluster_outputs.count(var) && var->Var() != nullptr) { // Create new output var node to guarantee the independency of // subgraph. In other words, the subgraph has no connection with // other graph, even the input graph. @@ -239,14 +245,20 @@ Node* AddSpecialOpToGraph(const GraphNodeSet& cluster_inputs, framework::OpDesc special_op_desc; special_op_desc.SetType(kCinnLaunchOp); std::vector input_names; - std::transform(cluster_inputs.begin(), cluster_inputs.end(), - std::back_inserter(input_names), - [](Node* n) { return n->Name(); }); + std::for_each(cluster_inputs.begin(), cluster_inputs.end(), + [&input_names](Node* n) { + if (n->Var() != nullptr) { + input_names.emplace_back(n->Name()); + } + }); special_op_desc.SetInput("X", input_names); std::vector output_names; - std::transform(cluster_outputs.begin(), cluster_outputs.end(), - std::back_inserter(output_names), - [](Node* n) { return n->Name(); }); + std::for_each(cluster_outputs.begin(), cluster_outputs.end(), + [&output_names](Node* n) { + if (n->Var() != nullptr) { + output_names.emplace_back(n->Name()); + } + }); special_op_desc.SetOutput("Out", output_names); special_op_desc.SetAttr(kCompilationKey, compilation_key); special_op_desc.Flush(); @@ -362,8 +374,8 @@ void SearchAllSubgraphs(Graph* graph) { &cluster_internals); // Create a new subgraph according to the found cluster and // save it in CinnCompiler - std::string compilation_key = cinn_compiler->AddGraph( - CreateNewSubGraph(cluster_set, cluster_internals, cluster_inputs)); + std::string compilation_key = cinn_compiler->AddGraph(CreateNewSubGraph( + cluster_set, cluster_internals, cluster_inputs, cluster_outputs)); // Replace the found cluster to a new special op node ReplaceSubGraphWithSpecialOpNode(cluster_set, cluster_inputs, cluster_outputs, cluster_internals, diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc index ab5768e0b2be3..79a27dccb4b00 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include #include "gtest/gtest.h" @@ -50,9 +51,10 @@ inline int CountNode(const std::unordered_set& nodes, inline Node* GetNode(const std::unordered_set& nodes, const std::string& op_name) { - return *std::find_if( - nodes.begin(), nodes.end(), - [&op_name](const Node* node) { return node->Name() == op_name; }); + return *std::find_if(nodes.begin(), nodes.end(), + [&op_name](const Node* node) { + return node->Name().find(op_name) != std::string::npos; + }); } inline bool CheckGraphIndependence(const std::unordered_set& nodes) { @@ -185,22 +187,25 @@ std::unique_ptr BuildAllOpSupportCinnGraph() { ir::Node* mul = g->CreateOpNode(&mul_op); ir::Node* relu = g->CreateOpNode(&relu_op); + ir::Node* v0 = g->CreateEmptyNode("var0", Node::Type::kVariable); ir::Node* v1 = g->CreateVarNode(&var1); ir::Node* v2 = g->CreateVarNode(&var2); ir::Node* v3 = g->CreateVarNode(&var3); ir::Node* v4 = g->CreateVarNode(&var4); ir::Node* v5 = g->CreateVarNode(&var5); ir::Node* v6 = g->CreateVarNode(&var6); + ir::Node* v7 = g->CreateControlDepVar(); // fill op node - mul->inputs = {v1, v2}; + mul->inputs = {v0, v1, v2}; mul->outputs = {v3}; add->inputs = {v3, v4}; add->outputs = {v5}; relu->inputs = {v5}; - relu->outputs = {v6}; + relu->outputs = {v6, v7}; // fill variable node + v0->outputs = {mul}; v1->outputs = {mul}; v2->outputs = {mul}; @@ -213,6 +218,7 @@ std::unique_ptr BuildAllOpSupportCinnGraph() { v5->outputs = {relu}; v6->inputs = {relu}; + v7->inputs = {relu}; return g; } @@ -225,25 +231,28 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { pass->Apply(g.get()); // After search, the graph should as following - // v1 --| - // v2 --| --> kCinnLaunchOp --> v6 + // v0 --| + // v1 --| |--> v6 + // v2 --| --> kCinnLaunchOp |--> v7 // v4 --| const auto& nodes = g->Nodes(); - ASSERT_EQ(nodes.size(), static_cast(5)); + ASSERT_EQ(nodes.size(), static_cast(7)); ASSERT_TRUE(CheckGraphIndependence(nodes)); // A new op named kCinnLaunchOp should be added ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp)); auto* cinn_op = GetNode(nodes, kCinnLaunchOp); + auto* v0 = GetNode(nodes, "var0"); auto* v1 = GetNode(nodes, "var1"); auto* v2 = GetNode(nodes, "var2"); auto* v4 = GetNode(nodes, "var4"); auto* v6 = GetNode(nodes, "var6"); + auto* v7 = GetNode(nodes, Node::kControlDepVarName); ASSERT_EQ( std::unordered_set(cinn_op->inputs.begin(), cinn_op->inputs.end()), - std::unordered_set({v1, v2, v4})); - ASSERT_EQ(cinn_op->outputs, std::vector({v6})); + std::unordered_set({v0, v1, v2, v4})); + ASSERT_EQ(cinn_op->outputs, std::vector({v6, v7})); ASSERT_EQ(v1->outputs, std::vector({cinn_op})); ASSERT_EQ(v6->inputs, std::vector({cinn_op}));