This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Improve the speed of the pointwise fusion graph pass #17114

Merged: 5 commits, Dec 20, 2019
Changes from all commits
98 changes: 64 additions & 34 deletions src/executor/simple_partition_pass.h
@@ -102,57 +102,87 @@ class BidirectionalGraph {
   std::vector<std::unordered_set<Node*>> get_subsets(FCompatible is_compatible) {
     std::vector<std::unordered_set<Node*>> subgraphs;
     std::unordered_set<Node*> incomp_set;
-    std::unordered_set<Node*> all_set(nodes.size());
-    std::vector<PairSet> separation_sets;
+    std::vector<std::pair<bool, PairSet>> separation_sets;
     // Check each node for compatibility
     // and, if it is incompatible, mark nodes
     // on each side of it as not possible to be
     // in the same subset
     for (Node& node : nodes) {
       if (!is_compatible(node.nnvmptr)) {
         incomp_set.insert(&node);
-        std::unordered_set<Node*> in_graph;
-        std::unordered_set<Node*> out_graph;
-        std::vector<Node*> dummy_head;
-        dummy_head.emplace_back(&node);
-        DFS(dummy_head, false, [&out_graph, &is_compatible](Node* node) {
-          if (is_compatible(node->nnvmptr))
-            out_graph.insert(node);
-        });
-        DFS(dummy_head, true, [&in_graph, is_compatible](Node* node) {
-          if (is_compatible(node->nnvmptr))
-            in_graph.insert(node);
-        });
-        if (!(in_graph.empty() || out_graph.empty()))
-          separation_sets.push_back(std::make_pair(in_graph, out_graph));
       }
-      all_set.emplace(&node);
     }
-    IncompMap incomp_map;
-    std::unordered_set<Node*> comp_set;
-    comp_set.insert(all_set.begin(), all_set.end());
-    for (Node* n : incomp_set) {
-      comp_set.erase(n);
+    for (Node& node : nodes) {
+      if (incomp_set.count(&node) != 0) {
+        // Check if all your inputs are incompatible too.
+        // If so, then your separation set does not matter,
+        // because it will covered by the sets of your inputs
+        bool inside_node = true;
+        for (Node* input : node.inputs) {
+          if (incomp_set.count(input) == 0) {
+            inside_node = false;
+          }
+        }
+        if (!inside_node) {
+          std::unordered_set<Node*> in_graph;
+          std::unordered_set<Node*> out_graph;
+          std::vector<Node*> dummy_head;
+          dummy_head.emplace_back(&node);
+          DFS(dummy_head, false, [&out_graph](Node* node) {
+            out_graph.insert(node);
+          });
+          DFS(dummy_head, true, [&in_graph](Node* node) {
+            in_graph.insert(node);
+          });
+          separation_sets.push_back(std::make_pair(true,
+                                                   std::make_pair(in_graph, out_graph)));
+        } else {
+          separation_sets.push_back(std::make_pair(false, PairSet()));
+        }
+      } else {
+        separation_sets.push_back(std::make_pair(false, PairSet()));
+      }
     }
+    IncompMap incomp_map;
     // For each node construct the map of nodes that cannot be in
     // the same subset
-    for (Node* n : comp_set) {
-      for (PairSet p : separation_sets) {
-        if (p.first.count(n)) {
-          incomp_map[n].insert(p.second.begin(), p.second.end());
-        } else if (p.second.count(n)) {
-          incomp_map[n].insert(p.first.begin(), p.first.end());
-        }
-      }
-      for (Node* incomp_n : incomp_set) {
-        incomp_map[n].erase(incomp_n);
-      }
+    index_t num_nodes = nodes.size();
+    for (index_t i = 0; i < num_nodes; ++i) {
+      const auto n = &(nodes[i]);
+      if (incomp_set.count(n) == 0) {
+        for (index_t j = i + 1; j < num_nodes; ++j) {
+          const auto& sep_set_pair = separation_sets[j];
+          if (sep_set_pair.first && incomp_map[n].count(&nodes[j]) == 0) {
+            const auto& p = sep_set_pair.second;
+            if (p.first.count(n)) {
+              incomp_map[n].insert(p.second.begin(), p.second.end());
+            } else if (p.second.count(n)) {
+              incomp_map[n].insert(p.first.begin(), p.first.end());
+            }
+          }
+        }
+        for (index_t j = i - 1; j >= 0; --j) {
+          const auto& sep_set_pair = separation_sets[j];
+          if (sep_set_pair.first && incomp_map[n].count(&nodes[j]) == 0) {
+            const auto& p = sep_set_pair.second;
+            if (p.first.count(n)) {
+              incomp_map[n].insert(p.second.begin(), p.second.end());
+            } else if (p.second.count(n)) {
+              incomp_map[n].insert(p.first.begin(), p.first.end());
+            }
+          }
+        }
+        for (Node* incomp_n : incomp_set) {
+          incomp_map[n].erase(incomp_n);
+        }
+      }
     }
     std::unordered_set<Node*> unused_set;
-    unused_set.reserve(comp_set.size());

-    for (auto& n : comp_set) {
-      unused_set.insert(n);
+    for (auto& n : nodes) {
+      if (incomp_set.count(&n) == 0) {
+        unused_set.insert(&n);
+      }
     }
     std::unordered_set<Node*> visited;
     std::deque<Node*> stack(outputs.begin(), outputs.end());
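The speed-up in this hunk, as far as the diff itself shows, comes from two things: an incompatible node whose inputs are all incompatible no longer gets a separation set of its own (its inputs' sets already cut the graph at that point), and separation sets are now stored per node as (valid, PairSet) entries, so the incompatibility map is built by walking node indices and skipping entries that are invalid or already recorded. Below is a minimal standalone sketch of just the "all inputs incompatible" pruning test; ToyNode, all_inputs_incompatible and the example graph are hypothetical stand-ins, not MXNet code.

#include <iostream>
#include <unordered_set>
#include <vector>

// Toy node type: just a name and its input edges.
struct ToyNode {
  const char* name;
  std::vector<ToyNode*> inputs;
};

// Mirrors the `inside_node` test added above: returns true when every input of
// `node` is itself incompatible, so the node needs no separation set of its own.
bool all_inputs_incompatible(const ToyNode& node,
                             const std::unordered_set<const ToyNode*>& incomp_set) {
  for (const ToyNode* input : node.inputs) {
    if (incomp_set.count(input) == 0) return false;
  }
  return true;
}

int main() {
  // Chain x -> a -> b -> y, where a and b are incompatible and x, y are compatible.
  ToyNode x{"x", {}}, a{"a", {&x}}, b{"b", {&a}}, y{"y", {&b}};
  std::unordered_set<const ToyNode*> incomp_set{&a, &b};
  std::cout << "skip a? " << all_inputs_incompatible(a, incomp_set) << "\n";  // 0: a still needs its own set
  std::cout << "skip b? " << all_inputs_incompatible(b, incomp_set) << "\n";  // 1: covered by a's set
  return 0;
}

In that toy chain the check keeps a's separation set (it has the compatible input x) and drops b's, which is the case the new comment in the diff describes.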
22 changes: 12 additions & 10 deletions src/imperative/cached_op.cc
@@ -1032,17 +1032,19 @@ OpStatePtr CachedOp::Forward(
   CHECK_EQ(inputs.size(), num_inputs());

   Context default_ctx = inputs[0]->ctx();
-  auto state_ptr = GetCachedOpState(default_ctx);
-  auto& state = state_ptr.get_state<CachedOpState>();
+  {
+    auto state_ptr = GetCachedOpState(default_ctx);
+    auto& state = state_ptr.get_state<CachedOpState>();

-  const auto& idx = state.info.fwd_graph.indexed_graph();
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    CHECK_EQ(inputs[i]->ctx(), default_ctx)
-        << "CachedOp requires all inputs to live on the same context. But "
-        << idx[idx.input_nodes()[0]].source->attrs.name
-        << " is on " << default_ctx << " while "
-        << idx[idx.input_nodes()[i]].source->attrs.name
-        << " is on " << inputs[i]->ctx();
+    const auto& idx = state.info.fwd_graph.indexed_graph();
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      CHECK_EQ(inputs[i]->ctx(), default_ctx)
+          << "CachedOp requires all inputs to live on the same context. But "
+          << idx[idx.input_nodes()[0]].source->attrs.name
+          << " is on " << default_ctx << " while "
+          << idx[idx.input_nodes()[i]].source->attrs.name
+          << " is on " << inputs[i]->ctx();
+    }
   }

   int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size);
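The cached_op.cc hunk is purely a scoping change: the state lookup and the per-input context check now sit in their own block, so state_ptr and state are destroyed as soon as the check finishes rather than living for the rest of Forward. Below is a generic sketch of that block-scope idiom; the mutex and names are hypothetical stand-ins and make no claim about what CachedOpState actually holds or guards.

#include <mutex>

std::mutex state_mutex;   // hypothetical stand-in for whatever the state handle pins
int cached_value = 42;    // hypothetical piece of state used only for validation

int forward_like_function() {
  int checked = 0;
  {
    // Acquire the handle only for the cheap validation step...
    std::lock_guard<std::mutex> hold(state_mutex);
    checked = cached_value;
  }
  // ...then do the expensive work without keeping the handle alive.
  return checked;
}

int main() { return forward_like_function() == 42 ? 0 : 1; }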