Skip to content

Commit

Permalink
Enable partitioning algo to eliminate cycles
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed Jun 20, 2018
1 parent 31cc8fc commit 1293828
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 7 deletions.
142 changes: 137 additions & 5 deletions src/operator/subgraph/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,23 @@ void PrintSubgraph(const std::vector<SimpleNode*>& simple_nodes) {
LOG(INFO) << "Subgraph node names: " << op_names;
}

/*
 * Puts every node held in the given list back into the unlabeled state
 * (label == -1) and then empties the list itself.
 */
void ResetSubgraphNodes(std::vector<SimpleNode*>* subgraph_nodes) {
  std::vector<SimpleNode*>& nodes = *subgraph_nodes;
  for (size_t i = 0; i < nodes.size(); ++i) {
    nodes[i]->label = -1;
  }
  nodes.clear();
}

/*
* This function traverses the nodes in a computation graph from a starting
* node following the input links and output links, and marks all nodes that
* can be accessed from the starting node.
*/
#if 0
void LabelSubgraph(const Graph&g,
SubgraphSelectorPtr select_func,
const int label,
const size_t snid, // simple node id
const size_t snid, // simple node id, this is a seed
const std::vector<SimpleNodePtr>& simple_nodes,
std::vector<SimpleNode*>* subgraph_nodes) {
const auto& indexed_graph = g.indexed_graph();
Expand Down Expand Up @@ -142,6 +150,130 @@ void LabelSubgraph(const Graph&g,
}
}
}
#endif

/*
 * Runs a breadth-first search from the seed node `snid`, following both input
 * and output links, and assigns `label` to every reachable node that is
 * accepted by `select_func` and is not listed in `excluded_nodes`. Each labeled
 * node is appended to `subgraph_nodes` exactly once.
 *
 * Neighbors that were rejected (by `select_func` or by `excluded_nodes`) are
 * recorded during the traversal. If one of them both feeds the subgraph and
 * consumes from it, there is a loop between that neighbor and the subgraph.
 * In that case the subgraph node with the largest node id among those touching
 * such a loop is inserted into `excluded_nodes` (when it is not null), every
 * node collected by this call gets its label reset to -1, `subgraph_nodes` is
 * cleared, and false is returned so that the caller can retry.
 *
 * \return true if the labeled subgraph has no loop with its neighbors,
 *         false otherwise.
 */
bool LabelSubgraph(const Graph& g,
                   SubgraphSelectorPtr select_func,
                   const int label,
                   const size_t snid,  // simple node id, this is a seed
                   const std::vector<SimpleNodePtr>& simple_nodes,
                   std::vector<SimpleNode*>* subgraph_nodes,
                   std::unordered_set<const nnvm::Node*>* excluded_nodes = nullptr) {
  CHECK_LT(snid, simple_nodes.size());
  const auto& indexed_graph = g.indexed_graph();
  std::queue<SimpleNode*> node_queue;
  // Nodes are labeled at the moment they are enqueued, not when they are
  // dequeued. Otherwise a node adjacent to several subgraph nodes would still
  // look unvisited while waiting in the queue, would be enqueued again, and
  // would end up in subgraph_nodes more than once.
  if (!excluded_nodes || !excluded_nodes->count(simple_nodes[snid]->node)) {
    simple_nodes[snid]->label = label;
    node_queue.push(simple_nodes[snid].get());
  }
  // key: nodes that serve as input/output nodes to the subgraph
  // value: pair of vectors of nodes in the subgraph. The first vector contains
  // the output nodes of the key in the subgraph, and the second vector contains
  // the input nodes of the key in the subgraph.
  // If both vectors are non-empty, it means there is a loop between the subgraph
  // and the key node. When breaking the loop, we want to start removing the node
  // with the largest node id.
  std::unordered_map<const nnvm::Node*,
                     std::pair<std::vector<const nnvm::Node*>,
                               std::vector<const nnvm::Node*>>> non_subgraph_node_map;
  while (!node_queue.empty()) {
    SimpleNode* cur_node = node_queue.front();
    node_queue.pop();
    subgraph_nodes->push_back(cur_node);
    // get qualified adjacent input nodes
    for (auto& e : cur_node->node->inputs) {
      const bool select_input = (!excluded_nodes || !excluded_nodes->count(e.node.get()))
        && select_func->SelectInput(*cur_node->node, *e.node);
      if (select_input) {
        // e.node is a subgraph node
        const auto nid = indexed_graph.node_id(e.node.get());
        CHECK_LT(nid, simple_nodes.size());
        // this node has not been visited yet
        if (simple_nodes[nid]->label == -1) {
          simple_nodes[nid]->label = label;
          node_queue.push(simple_nodes[nid].get());
        }
      } else {
        // e.node is an input node of the subgraph
        non_subgraph_node_map[e.node.get()].first.push_back(cur_node->node);
      }
    }
    // get qualified output nodes
    for (const auto& out : cur_node->outputs) {
      const bool select_output = (!excluded_nodes || !excluded_nodes->count(out.first))
        && select_func->SelectOutput(*cur_node->node, *out.first);
      if (select_output) {
        // out.first is a subgraph node
        const auto nid = indexed_graph.node_id(out.first);
        CHECK_LT(nid, simple_nodes.size());
        // this node has not been visited yet
        if (simple_nodes[nid]->label == -1) {
          simple_nodes[nid]->label = label;
          node_queue.push(simple_nodes[nid].get());
        }
      } else {
        // out.first is an output node of the subgraph
        non_subgraph_node_map[out.first].second.push_back(cur_node->node);
      }
    }
  }
  auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
    return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
  };
  // check whether there is a loop between the subgraph and its input/output nodes
  int excluded_node_id = -1;
  for (const auto& kv : non_subgraph_node_map) {
    const auto& output_nodes = kv.second.first;
    const auto& input_nodes = kv.second.second;
    if (output_nodes.empty() || input_nodes.empty()) {
      continue;
    }
    // there is a loop between kv.first and the subgraph; only the largest node
    // id on either side is needed, so there is no reason to sort the vectors
    const auto max_output_id = indexed_graph.node_id(
        *std::max_element(output_nodes.begin(), output_nodes.end(), node_cmp));
    const auto max_input_id = indexed_graph.node_id(
        *std::max_element(input_nodes.begin(), input_nodes.end(), node_cmp));
    const auto node_id = std::max(max_output_id, max_input_id);
    excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
  }
  if (excluded_node_id != -1) {
    CHECK_LT(excluded_node_id, static_cast<int>(simple_nodes.size()));
    CHECK_NE(excluded_node_id, static_cast<int>(snid))
      << "A cycle is found in the computational graph between nodes "
      << simple_nodes[excluded_node_id]->node->attrs.name << " and "
      << simple_nodes[snid]->node->attrs.name;
    // excluded_nodes is optional (defaults to nullptr); without it the caller
    // still learns about the loop through the return value
    if (excluded_nodes != nullptr) {
      excluded_nodes->insert(simple_nodes[excluded_node_id]->node);
    }
    ResetSubgraphNodes(subgraph_nodes);
    return false;
  }
  return true;
}

/*
 * Finds the subgraph grown from the seed node `snid` and stores its nodes,
 * all carrying `label`, in `subgraph_nodes`.
 *
 * LabelSubgraph is called repeatedly. Each time it reports a loop between the
 * candidate subgraph and a neighboring node, it adds one more node to the
 * excluded set and the search is retried with that node left out. The number
 * of attempts is capped at simple_nodes.size() squared. If every attempt
 * fails, the seed node alone is labeled and returned as a one-node subgraph.
 */
void FindSubgraph(const Graph& g,
                  SubgraphSelectorPtr select_func,
                  const int label,
                  const size_t snid,  // simple node id, this is a seed
                  const std::vector<SimpleNodePtr>& simple_nodes,
                  std::vector<SimpleNode*>* subgraph_nodes) {
  // The seed is indexed below even when the retry loop is skipped (e.g. when
  // simple_nodes is empty), so validate it up front.
  CHECK(snid < simple_nodes.size());
  const auto& seed_name = simple_nodes[snid]->node->attrs.name;
  std::unordered_set<const nnvm::Node*> excluded_nodes;
  const size_t max_num_retry = simple_nodes.size() * simple_nodes.size();
  size_t count = 0;
  bool success = false;
  while (!success && count < max_num_retry) {
    success = LabelSubgraph(g, select_func, label, snid, simple_nodes,
                            subgraph_nodes, &excluded_nodes);
    if (!success) {
      // a failed attempt must have recorded at least one node to leave out
      CHECK(!excluded_nodes.empty());
      std::string excluded_node_names;
      for (const auto* node : excluded_nodes) {
        excluded_node_names += node->attrs.name;
        excluded_node_names += ", ";
      }
      LOG(INFO) << "Found a cycle when BFS from node " << seed_name
                << ". Excluding nodes " << excluded_node_names << "and retry";
    }
    ++count;
  }
  if (!success) {
    LOG(INFO) << "Tried " << count << " times of finding subgraphs starting from node "
              << seed_name << " without success because a loop "
                 "is always found between the subgraph and some other nodes. Will treat "
                 "seed node " << seed_name << " as a subgraph with one node";
    // LabelSubgraph clears the list whenever it fails
    CHECK(subgraph_nodes->empty());
    simple_nodes[snid]->label = label;
    subgraph_nodes->push_back(simple_nodes[snid].get());
  }
}

/*
* This function finds subgraphs with all nodes that meet certain criteria.
Expand All @@ -161,8 +293,8 @@ void FindSubgraphs(const Graph& g,
auto select_func = subg_prop.CreateSubgraphSelector();
if (select_func->Select(*node) && simple_nodes[i]->label == -1) {
subgraph_nodes->emplace_back();
LabelSubgraph(g, select_func, subgraph_nodes->size() - 1, i, simple_nodes,
&subgraph_nodes->back());
FindSubgraph(g, select_func, subgraph_nodes->size() - 1, i, simple_nodes,
&subgraph_nodes->back());
}
}
}
Expand Down Expand Up @@ -357,8 +489,7 @@ Graph CreateSubgraphNode(Graph&& g,
// update entry_top_order_map with newly created orig_input_entries
auto it = entry_top_order_map->find(input_entries[i]);
CHECK(it != entry_top_order_map->end());
auto entry_it = entry_top_order_map->find(&e);
CHECK(entry_it == entry_top_order_map->end());
CHECK_EQ(entry_top_order_map->count(&e), 0U);
entry_top_order_map->emplace(&e, it->second);
// update input entries' source simple nodes' outputs map
nnvm::Node* node = e.node.get();
Expand All @@ -385,6 +516,7 @@ Graph CreateSubgraphNode(Graph&& g,
void TopSortEntries(const Graph& g, std::unordered_map<const nnvm::NodeEntry*, size_t>* entry_top_order_map) {
CHECK(entry_top_order_map != nullptr);
std::unordered_set<const nnvm::Node*> visited;
// meaning of tuple: (graph node, index of node's inputs, node entry as the output of the graph node)
std::stack<std::tuple<nnvm::Node*, size_t, const nnvm::NodeEntry*>> s;
auto in_degree = [] (const nnvm::Node* node)->size_t {
if (!node) {
Expand Down
27 changes: 25 additions & 2 deletions tests/python/unittest/test_subgraph_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ def check_input_order(sym, op_names):
c_str_array(op_names), ctypes.byref(out)))

new_sym = Symbol(out)
print(sym.list_inputs())
print(new_sym.list_inputs())
#print(sym.list_inputs())
#print(new_sym.list_inputs())
assert new_sym.list_inputs() == sym.list_inputs()
print('original outputs: %s' % sym.list_outputs())
print('new sym outputs: %s' % new_sym.list_outputs())

def test_network_structure_1():
data1 = mx.sym.var('data1')
Expand All @@ -65,7 +67,28 @@ def test_network_structure_1():
out = mx.sym.Group([conv1, conv2])
check_input_order(out, ['Convolution'])

def test_network_structure_2():
    # two convolutions read the same pair of variables with the data/weight
    # roles swapped, and their results are added together
    first_var = mx.sym.var('data1')
    second_var = mx.sym.var('data2')
    conv_a = mx.sym.Convolution(data=first_var, weight=second_var, no_bias=True,
                                kernel=(2, 2), num_filter=1)
    conv_b = mx.sym.Convolution(data=second_var, weight=first_var, no_bias=True,
                                kernel=(2, 2), num_filter=1)
    summed = conv_a + conv_b
    # first partition on the convolutions only, then on the addition as well
    for op_names in (['Convolution'],
                     ['Convolution', '_Plus', 'elemwise_add', '_plus']):
        check_input_order(summed, op_names)

def test_network_structure_3():
    # this tests whether the partitioning algorithm can deal with cycles:
    # exp feeds both cos and sin, and their results are added, so selecting
    # exp, the addition and only one of cos/sin leaves the other op both
    # consuming from and feeding into the selected nodes
    source = mx.sym.var('data')
    exp_out = mx.sym.exp(source)
    cos_out = mx.sym.cos(exp_out)
    sin_out = mx.sym.sin(exp_out)
    total = cos_out + sin_out
    for branch_op in ('sin', 'cos'):
        check_input_order(total, ['exp', branch_op, '_Plus', 'elemwise_add', '_plus'])

test_network_structure_1()
test_network_structure_2()
test_network_structure_3()


if __name__ == '__main__':
Expand Down

0 comments on commit 1293828

Please sign in to comment.