diff --git a/src/executor/simple_partition_pass.h b/src/executor/simple_partition_pass.h index 5b26a4523c13..ea1dcf39b8ba 100644 --- a/src/executor/simple_partition_pass.h +++ b/src/executor/simple_partition_pass.h @@ -102,8 +102,7 @@ class BidirectionalGraph { std::vector> get_subsets(FCompatible is_compatible) { std::vector> subgraphs; std::unordered_set incomp_set; - std::unordered_set all_set(nodes.size()); - std::vector separation_sets; + std::vector> separation_sets; // Check each node for compatibility // and, if it is incompatible, mark nodes // on each side of it as not possible to be @@ -111,48 +110,79 @@ class BidirectionalGraph { for (Node& node : nodes) { if (!is_compatible(node.nnvmptr)) { incomp_set.insert(&node); - std::unordered_set in_graph; - std::unordered_set out_graph; - std::vector dummy_head; - dummy_head.emplace_back(&node); - DFS(dummy_head, false, [&out_graph, &is_compatible](Node* node) { - if (is_compatible(node->nnvmptr)) - out_graph.insert(node); - }); - DFS(dummy_head, true, [&in_graph, is_compatible](Node* node) { - if (is_compatible(node->nnvmptr)) - in_graph.insert(node); - }); - if (!(in_graph.empty() || out_graph.empty())) - separation_sets.push_back(std::make_pair(in_graph, out_graph)); } - all_set.emplace(&node); } - IncompMap incomp_map; - std::unordered_set comp_set; - comp_set.insert(all_set.begin(), all_set.end()); - for (Node* n : incomp_set) { - comp_set.erase(n); + for (Node& node : nodes) { + if (incomp_set.count(&node) != 0) { + // Check if all your inputs are incompatible too. + // If so, then your separation set does not matter, + // because it will covered by the sets of your inputs + bool inside_node = true; + for (Node* input : node.inputs) { + if (incomp_set.count(input) == 0) { + inside_node = false; + } + } + if (!inside_node) { + std::unordered_set in_graph; + std::unordered_set out_graph; + std::vector dummy_head; + dummy_head.emplace_back(&node); + DFS(dummy_head, false, [&out_graph](Node* node) { + out_graph.insert(node); + }); + DFS(dummy_head, true, [&in_graph](Node* node) { + in_graph.insert(node); + }); + separation_sets.push_back(std::make_pair(true, + std::make_pair(in_graph, out_graph))); + } else { + separation_sets.push_back(std::make_pair(false, PairSet())); + } + } else { + separation_sets.push_back(std::make_pair(false, PairSet())); + } } + IncompMap incomp_map; // For each node construct the map of nodes that cannot be in // the same subset - for (Node* n : comp_set) { - for (PairSet p : separation_sets) { - if (p.first.count(n)) { - incomp_map[n].insert(p.second.begin(), p.second.end()); - } else if (p.second.count(n)) { - incomp_map[n].insert(p.first.begin(), p.first.end()); + index_t num_nodes = nodes.size(); + for (index_t i = 0; i < num_nodes; ++i) { + const auto n = &(nodes[i]); + if (incomp_set.count(n) == 0) { + for (index_t j = i + 1; j < num_nodes; ++j) { + const auto& sep_set_pair = separation_sets[j]; + if (sep_set_pair.first && incomp_map[n].count(&nodes[j]) == 0) { + const auto& p = sep_set_pair.second; + if (p.first.count(n)) { + incomp_map[n].insert(p.second.begin(), p.second.end()); + } else if (p.second.count(n)) { + incomp_map[n].insert(p.first.begin(), p.first.end()); + } + } + } + for (index_t j = i - 1; j >= 0; --j) { + const auto& sep_set_pair = separation_sets[j]; + if (sep_set_pair.first && incomp_map[n].count(&nodes[j]) == 0) { + const auto& p = sep_set_pair.second; + if (p.first.count(n)) { + incomp_map[n].insert(p.second.begin(), p.second.end()); + } else if (p.second.count(n)) { + incomp_map[n].insert(p.first.begin(), p.first.end()); + } + } + } + for (Node* incomp_n : incomp_set) { + incomp_map[n].erase(incomp_n); } - } - for (Node* incomp_n : incomp_set) { - incomp_map[n].erase(incomp_n); } } std::unordered_set unused_set; - unused_set.reserve(comp_set.size()); - for (auto& n : comp_set) { - unused_set.insert(n); + for (auto& n : nodes) { + if (incomp_set.count(&n) == 0) { + unused_set.insert(&n); + } } std::unordered_set visited; std::deque stack(outputs.begin(), outputs.end()); diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index ec5a79a2e675..1edd9897ec82 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -1032,17 +1032,19 @@ OpStatePtr CachedOp::Forward( CHECK_EQ(inputs.size(), num_inputs()); Context default_ctx = inputs[0]->ctx(); - auto state_ptr = GetCachedOpState(default_ctx); - auto& state = state_ptr.get_state(); + { + auto state_ptr = GetCachedOpState(default_ctx); + auto& state = state_ptr.get_state(); - const auto& idx = state.info.fwd_graph.indexed_graph(); - for (size_t i = 0; i < inputs.size(); ++i) { - CHECK_EQ(inputs[i]->ctx(), default_ctx) - << "CachedOp requires all inputs to live on the same context. But " - << idx[idx.input_nodes()[0]].source->attrs.name - << " is on " << default_ctx << " while " - << idx[idx.input_nodes()[i]].source->attrs.name - << " is on " << inputs[i]->ctx(); + const auto& idx = state.info.fwd_graph.indexed_graph(); + for (size_t i = 0; i < inputs.size(); ++i) { + CHECK_EQ(inputs[i]->ctx(), default_ctx) + << "CachedOp requires all inputs to live on the same context. But " + << idx[idx.input_nodes()[0]].source->attrs.name + << " is on " << default_ctx << " while " + << idx[idx.input_nodes()[i]].source->attrs.name + << " is on " << inputs[i]->ctx(); + } } int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size);