Skip to content

Commit

Permalink
generalize graph partitioning. (zheng-da#11)
Browse files Browse the repository at this point in the history
* add functions for cutting edges.

* construct subgraphs.

* generalize graph partition.

* restructure the code.

* create SubgraphOpState.

* register subgraph property.

* rename.

* address comments.

* update select API.

* rename.

* set subgraph property.

* fix bugs.

* fix bugs.
  • Loading branch information
zheng-da authored and reminisce committed Jun 13, 2018
1 parent 47bbf9f commit 43d0a75
Show file tree
Hide file tree
Showing 4 changed files with 380 additions and 118 deletions.
7 changes: 6 additions & 1 deletion src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "./c_api_common.h"
#include "../operator/operator_common.h"
#include "../executor/exec_pass.h"
#include "../operator/subgraph/subgraph_op.h"

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -636,7 +637,11 @@ int MXPartitionGraph(SymbolHandle sym_handle,
nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(sym_handle);
*s = sym->Copy();
nnvm::Graph g = Symbol2Graph(*s);
g.attrs["subgraph_op_names"] = std::make_shared<nnvm::any>(std::move(op_name_set));
if (!op_name_set.empty()) {
mxnet::op::SubgraphPropertyPtr property
= std::make_shared<mxnet::op::SimpleSubgraphProperty>(op_name_set);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
}
g = ApplyPass(std::move(g), "PartitionGraph");
s->outputs = g.outputs;
*ret_sym_handle = s;
Expand Down
140 changes: 89 additions & 51 deletions src/operator/subgraph/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,14 @@
#include <mxnet/op_attr_types.h>
#include <unordered_set>

#include "./subgraph_op.h"

namespace nnvm {
NodePtr CreateVariableNode(const std::string& name);
}

namespace mxnet {

namespace op {

using nnvm::Symbol;
Expand All @@ -50,24 +57,6 @@ NodePtr CloneVariableNode(const nnvm::Node& src) {

namespace sg { // sg stands for subgraph

struct SimpleNode;
using SimpleNodePtr = std::shared_ptr<SimpleNode>;

struct SimpleNode {
static SimpleNodePtr Create() {
return std::make_shared<SimpleNode>();
}
SimpleNode() : label(-1), node(nullptr) {}
int label;
nnvm::Node* node;
// key is node ptr
// value is the index array standing for the entry indices
// in key->inputs that use this->node as input node
std::unordered_map<Node*, std::vector<int>> outputs;
//std::unordered_map<SimpleNodePtr, int> inputs;
//std::unordered_map<SimpleNodePtr, int> outputs;
};

void CreateSimpleGraph(const Graph& g,
std::vector<SimpleNodePtr>* simple_nodes) {
const auto& indexed_graph = g.indexed_graph();
Expand Down Expand Up @@ -111,8 +100,13 @@ void PrintSubgraph(const std::vector<SimpleNode*>& simple_nodes) {
LOG(INFO) << "Subgraph node names: " << op_names;
}

/*
* This function traverses the nodes in a computation graph from a starting
* node following the input links and output links, and marks all nodes that
* can be accessed from the starting node.
*/
void LabelSubgraph(const Graph&g,
const std::unordered_set<std::string>& op_names,
SubgraphSelectorPtr select_func,
const int label,
const size_t snid, // simple node id
const std::vector<SimpleNodePtr>& simple_nodes,
Expand All @@ -126,46 +120,52 @@ void LabelSubgraph(const Graph&g,
cur_node->label = label;
subgraph_nodes->push_back(cur_node);
// get qualified adjacent input nodes
for (auto& e : cur_node->node->inputs) {
if (!e.node->is_variable() && op_names.count(e.node->op()->name)) {
const auto nid = indexed_graph.node_id(e.node.get());
CHECK_LT(nid, simple_nodes.size());
if (simple_nodes[nid]->label == -1) { // this node has not been visited yet
node_queue.push(simple_nodes[nid].get());
} else {
CHECK_EQ(simple_nodes[nid]->label, label);
if (select_func->UseIncomingEdges()) {
for (auto& e : cur_node->node->inputs) {
if (select_func->Select(*e.node)) {
const auto nid = indexed_graph.node_id(e.node.get());
CHECK_LT(nid, simple_nodes.size());
// this node has not been visited yet
if (simple_nodes[nid]->label == -1)
node_queue.push(simple_nodes[nid].get());
}
}
}
// get qualified output nodes
for (auto it = cur_node->outputs.begin(); it != cur_node->outputs.end(); ++it) {
CHECK(!it->first->is_variable());
if (op_names.count(it->first->op()->name)) {
const auto nid = indexed_graph.node_id(it->first);
CHECK_LT(nid, simple_nodes.size());
if (simple_nodes[nid]->label == -1) { // this node has not been visited yet
node_queue.push(simple_nodes[nid].get());
} else {
CHECK_EQ(simple_nodes[nid]->label, label);
if (select_func->UseOutgoingEdges()) {
for (auto it = cur_node->outputs.begin(); it != cur_node->outputs.end(); ++it) {
if (select_func->Select(*it->first)) {
const auto nid = indexed_graph.node_id(it->first);
CHECK_LT(nid, simple_nodes.size());
// this node has not been visited yet
if (simple_nodes[nid]->label == -1)
node_queue.push(simple_nodes[nid].get());
}
}
}
}
}

// number of subgraphs found
/*
* This function finds subgraphs with all nodes that meet certain criteria.
* All nodes in a subgraph are marked with the same label.
* All nodes in a subgraph have to be connected with each other. If a node
* doesn't meet the given criteria, it will be marked with a separate label.
*/
void FindSubgraphs(const Graph& g,
const std::unordered_set<std::string>& op_names,
const SubgraphProperty &subg_prop,
const std::vector<SimpleNodePtr>& simple_nodes,
std::vector<std::vector<SimpleNode*>>* subgraph_nodes) {
//CHECK(simple_nodes != nullptr);
const auto& indexed_graph = g.indexed_graph();
CHECK_EQ(indexed_graph.num_nodes(), simple_nodes.size());
for (size_t i = 0; i < simple_nodes.size(); ++i) {
nnvm::Node* node = simple_nodes[i]->node;
if (!node->is_variable() && simple_nodes[i]->label == -1 && op_names.count(node->op()->name)) {
auto select_func = subg_prop.CreateSubgraphSelector();
if (select_func->Select(*node) && simple_nodes[i]->label == -1) {
subgraph_nodes->emplace_back();
LabelSubgraph(g, op_names, subgraph_nodes->size() - 1, i, simple_nodes, &subgraph_nodes->back());
LabelSubgraph(g, select_func, subgraph_nodes->size() - 1, i, simple_nodes,
&subgraph_nodes->back());
}
}
}
Expand All @@ -185,11 +185,9 @@ void FindInputEntries(const Graph& g,
}
for (auto& e : subgraph_nodes[i]->node->inputs) {
const auto nid = indexed_graph.node_id(e.node.get());
if (simple_nodes[nid]->label == -1) { // this is a node not belonging to the subgraph
// this is a node not belonging to the subgraph
if (simple_nodes[nid]->label != label)
input_entries->push_back(&e);
} else {
CHECK_EQ(simple_nodes[nid]->label, label);
}
}
}
}
Expand All @@ -208,14 +206,15 @@ void FindOutputEntries(Graph* g,
} else {
CHECK_EQ(subgraph_nodes[i]->label, label);
}
for (auto it = subgraph_nodes[i]->outputs.begin(); it != subgraph_nodes[i]->outputs.end(); ++it) {
for (auto it = subgraph_nodes[i]->outputs.begin();
it != subgraph_nodes[i]->outputs.end(); ++it) {
const auto nid = indexed_graph.node_id(it->first);
if (simple_nodes[nid]->label == -1) { // this is a node not belonging to the subgraph
// this is a node not belonging to the subgraph
if (simple_nodes[nid]->label != label) {
// TODO(zhengda) I need to test this.
for (int idx : it->second) {
output_entries->push_back(&simple_nodes[nid]->node->inputs[idx]);
}
} else {
CHECK_EQ(simple_nodes[nid]->label, label);
}
}
}
Expand All @@ -241,6 +240,29 @@ void PrintNodeEntries(const std::vector<nnvm::NodeEntry*>& entries) {
}
}

/*
* Given a computation graph and a set of input node entries, this function cuts
* the node entries and creates new variable nodes as the input nodes of the
* subgraph. It returns the nodes that connect to the subgraph directly and
* the names of the new variable nodes.
*/
void CutGraphInputs(const std::vector<nnvm::NodeEntry *> &input_entries,
bool skip_var, std::vector<nnvm::NodeEntry> *orig_entries) {
orig_entries->reserve(input_entries.size());
for (size_t i = 0; i < input_entries.size(); i++) {
nnvm::NodeEntry *e = input_entries[i];
// If the node is a variable itself, we may want to skip the node.
if (e->node->is_variable() && skip_var)
continue;

orig_entries->push_back(*e);
nnvm::Symbol sym;
sym.outputs.push_back(*e);
nnvm::NodePtr n = nnvm::CreateVariableNode(sym.ListOutputNames()[0]);
*e = nnvm::NodeEntry{n, 0, 0};
}
}

} // namespace sg

Graph PartitionGraph(Graph&& g) {
Expand All @@ -252,8 +274,7 @@ Graph PartitionGraph(Graph&& g) {
}
});
#endif
const std::unordered_set<std::string>& op_names = g.GetAttr<std::unordered_set<std::string>>("subgraph_op_names");
if (op_names.empty()) { // treat the whole graph as a subgraph
if (!g.HasAttr("subgraph_property")) { // treat the whole graph as a subgraph
Symbol whole_graph_sym;
whole_graph_sym.outputs = g.outputs;
// DO NOT define node name for subgraph op because it would serve
Expand Down Expand Up @@ -283,23 +304,40 @@ Graph PartitionGraph(Graph&& g) {
return ret;
} else {
using namespace sg;
SubgraphPropertyPtr subg_prop = g.GetAttr<SubgraphPropertyPtr>("subgraph_property");
std::vector<SimpleNodePtr> simple_nodes;
CreateSimpleGraph(g, &simple_nodes);
std::vector<std::vector<SimpleNode*>> subgraph_nodes;
FindSubgraphs(g, op_names, simple_nodes, &subgraph_nodes);
FindSubgraphs(g, *subg_prop, simple_nodes, &subgraph_nodes);
std::vector<nnvm::NodeEntry*> entries;
// TODO(junwu): take care of the situation when the op is the last op
for (size_t i = 0; i < subgraph_nodes.size(); ++i) {
PrintSubgraph(subgraph_nodes[i]);

// Break the input links.
LOG(INFO) << "Searching for input entries...";
entries.clear();
FindInputEntries(g, simple_nodes, subgraph_nodes[i], &entries);
std::vector<nnvm::NodeEntry> orig_input_entries;
sg::CutGraphInputs(entries, false, &orig_input_entries);
PrintNodeEntries(entries);

LOG(INFO) << "Searching for output entries...";
entries.clear();
FindOutputEntries(&g, simple_nodes, subgraph_nodes[i], &entries);

// Create a subgraph.
nnvm::Symbol sym;
sym.outputs.resize(entries.size());
for (size_t i = 0; i < entries.size(); i++)
sym.outputs[i] = *entries[i];
nnvm::NodePtr n = subg_prop->CreateSubgraphNode(sym);

// Connect the external nodes to the subgraph node.
for (uint32_t i = 0; i < entries.size(); i++)
*entries[i] = nnvm::NodeEntry{n, i, 0};
// TODO(zhengda) this may not be the right order for input entries of a subgraph?
n->inputs = orig_input_entries;
PrintNodeEntries(entries);
}
return g;
Expand Down
Loading

0 comments on commit 43d0a75

Please sign in to comment.