From ed852942cd59c656779c899e77e5c30fd2ad9388 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 8 Nov 2019 13:27:49 -0800 Subject: [PATCH] Fix --- src/executor/pointwise_fusion_pass.cc | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/src/executor/pointwise_fusion_pass.cc b/src/executor/pointwise_fusion_pass.cc index c6e2405cb2a4..6fe21402cb3a 100644 --- a/src/executor/pointwise_fusion_pass.cc +++ b/src/executor/pointwise_fusion_pass.cc @@ -83,15 +83,7 @@ namespace { auto node = nnvm::Node::Create(); subgraph_sym.outputs = subgraph.outputs; node->attrs.subgraphs.emplace_back(std::make_shared(subgraph_sym)); - std::ostringstream name_oss; - // the name of the new node will be the concatenation of all the node names in the subgraph - DFSVisit(subgraph.outputs, [&name_oss](const nnvm::NodePtr n) { - if (n->op() != nullptr) - name_oss << n->op()->name << "_"; - }); - auto subgraph_name = name_oss.str(); - subgraph_name.pop_back(); - node->attrs.name = subgraph_name; + node->attrs.name = "FusedOp"; node->attrs.dict["num_inputs"] = std::to_string(inputs_size); node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size()); node->attrs.op = Op::Get("_FusedOp"); @@ -152,7 +144,8 @@ Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector& sub auto it = node->control_deps.begin(); static auto& is_fusion = Op::GetAttr("TIsFusionHelper"); std::vector new_control_deps; - while (it != node->control_deps.end()) { + // Use the first control dependency to get the inferattr helper + if (it != node->control_deps.end()) { if (subgraph_set.count(it->get())) { new_control_deps.push_back(*it); } else { @@ -160,8 +153,7 @@ Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector& sub uint32_t node_id = subgraph_node->control_deps.size(); subgraph_node->control_deps.push_back(*it); auto helper_node = op::MakeNode("_FusedOpOutHelper", - subgraph_node->attrs.name + "_" - + node->attrs.name + "_outhelper", + "FusedOp_" + node->attrs.name + "_outhelper", nullptr, nullptr, nullptr); @@ -180,6 +172,17 @@ Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector& sub } }); + std::ostringstream name_oss; + // the name of the new node will be the concatenation of all the node names in the subgraph + DFSVisit(subgraph.outputs, [&name_oss](const nnvm::NodePtr n) { + if (n->op() != nullptr) { + name_oss << n->op()->name << "_"; + } + }); + auto subgraph_name = name_oss.str(); + subgraph_name.pop_back(); + subgraph_node->attrs.name = subgraph_name; + const auto& index = subgraph.indexed_graph(); DFSVisit(g.outputs, [&subgraph_node, &subgraph_set, &index](const nnvm::NodePtr& node) { for (auto &e : node->control_deps) {