Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Support extra inputs for subgraph ops #18779

Merged
merged 27 commits into from
Aug 14, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions example/extensions/lib_pass/pass_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ MXReturnValue jsonPass(const std::string& in_graph, const std::string** out_grap
MXTensor* aux_ = res.alloc_aux("test_aux",{1},MXContext::CPU(0),kFloat32);

// convert json string to json object
JsonParser parser;
JsonVal json_val = parser.parse_to_json(in_graph);
JsonVal json_val = JsonVal::parse(in_graph);

// get nodes list
JsonVal nodes = json_val.map[JsonVal("nodes")];
Expand All @@ -86,7 +85,7 @@ MXReturnValue jsonPass(const std::string& in_graph, const std::string** out_grap
}
}

*out_graph = new std::string(parser.dump(json_val));
*out_graph = new std::string(json_val.dump());
return MX_SUCCESS;
}

Expand Down
53 changes: 46 additions & 7 deletions example/extensions/lib_subgraph/subgraph_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ MXReturnValue myExecutor(std::vector<MXTensor>* inputs,
std::cout << subgraph_sym << std::endl;

// convert json string to json object
JsonParser parser;
JsonVal json_val = parser.parse_to_json(subgraph_sym);
JsonVal json_val = JsonVal::parse(subgraph_sym);
// get nodes list
JsonVal nodes = json_val.map[JsonVal("nodes")];
//counter for inputs
Expand Down Expand Up @@ -148,6 +147,9 @@ class MyStatefulOp : public CustomStatefulOp {
MXReturnValue Forward(std::vector<MXTensor>* inputs,
std::vector<MXTensor>* outputs,
const OpResource& op_res) {
if(attrs_.count(MX_STR_EXTRA_INPUTS) > 0 && std::stoi(attrs_.at(MX_STR_EXTRA_INPUTS)) > 0)
std::cout << "forward::extra_inputs(" << attrs_.at(MX_STR_EXTRA_INPUTS) << ")::inputs ["
<< inputs->size() << "]" << std::endl;
return myExecutor(inputs, outputs, subgraph_sym);
}

Expand Down Expand Up @@ -183,8 +185,7 @@ MXReturnValue mySupportedOps(const std::string& json,
std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl;
}
//convert json string to json object
JsonParser parser;
JsonVal json_val = parser.parse_to_json(json);
JsonVal json_val = JsonVal::parse(json);
//get nodes list
JsonVal nodes = json_val.map[JsonVal("nodes")];

Expand Down Expand Up @@ -249,7 +250,6 @@ MXReturnValue myReviewSubgraph(const std::string& json, int subgraph_id, bool* a
} else {
*accept = true;
std::cout << "accepting subgraph" << std::endl;
attrs->insert(std::pair<std::string,std::string>("myKey","myVal"));
}
return MX_SUCCESS;
}
Expand All @@ -269,8 +269,7 @@ class MySelector : public CustomOpSelector {
<< " ==> " << kv.second << std::endl;
}
//convert json string to json object
JsonParser parser;
JsonVal json_val = parser.parse_to_json(json);
JsonVal json_val = JsonVal::parse(json);
//get nodes list
nodes = json_val.map[JsonVal("nodes")];
}
Expand Down Expand Up @@ -331,6 +330,46 @@ REGISTER_PARTITIONER(mySelect)
.setCreateSelector("strategy1", createSelector)
.setReviewSubgraph("strategy1", myReviewSubgraph);

/*! \brief A basic graph pass that appends one extra input to every
 *         '_custom_subgraph_op' node in the graph.
 *
 * \param in_graph  graph serialized as a JSON string
 * \param out_graph receives a newly allocated JSON string for the rewritten
 *                  graph (ownership passes to the caller/framework)
 * \param options   user-supplied pass options (unused here)
 * \param args      model arg tensors by name (unused here)
 * \param aux       model aux tensors by name (unused here)
 * \param res       resource handle used to allocate backing tensors for
 *                  any new graph inputs
 * \return MX_SUCCESS on completion
 */
MXReturnValue addInputPass(const std::string& in_graph, const std::string** out_graph,
                           const std::unordered_map<std::string, std::string>& options,
                           const std::unordered_map<std::string, MXTensor>& args,
                           const std::unordered_map<std::string, MXTensor>& aux,
                           const PassResource& res) {
  // convert graph from JSON string to Graph/Node data structure
  Graph* g = Graph::fromString(in_graph);
  // find every node with the '_custom_subgraph_op' op type
  for (Node* n : g->nodes) {
    if (n->op == "_custom_subgraph_op") {
      // record that this subgraph op now carries one extra input
      n->attrs[MX_STR_EXTRA_INPUTS] = std::to_string(1);

      // create a new input Node; op "null" marks it as a graph variable
      Node* input = new Node();
      std::string input_name = n->name + "_input";
      input->name = input_name;
      input->op = "null";
      // add the new node to the graph and register it as a graph input
      g->nodes.push_back(input);
      g->inputs.push_back(input);
      // connect new input to node; the edge index is the position the new
      // entry will occupy in n->inputs (its current size, pre-push)
      input->outputs.push_back({n, (int)(n->inputs.size())});
      // connect node to new input
      n->inputs.push_back({input, 0});
      // allocate a backing tensor (shape {1}, fp32, CPU) for this input;
      // the tensor is registered with the framework by name, so the return
      // value does not need to be kept
      res.alloc_arg(input_name, {1}, MXContext::CPU(0), kFloat32);
    }
  }

  // convert back to JSON string from Graph/Node
  // NOTE(review): `g` and the new Node are raw-pointer allocations whose
  // ownership appears to rest with the extension framework — confirm
  // Graph::fromString's contract before adding any delete here.
  *out_graph = new std::string(g->toString());
  return MX_SUCCESS;
}

// register the pass with the extension library under the name "addInputPass"
// so MXNet can invoke it via optimize_for(..., backend="addInputPass")
REGISTER_PASS(addInputPass)
.setBody(addInputPass);


MXReturnValue initialize(int version) {
if (version >= 10700) {
std::cout << "MXNet version " << version << " supported" << std::endl;
Expand Down
68 changes: 19 additions & 49 deletions example/extensions/lib_subgraph/test_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,45 +55,22 @@ def test(backend):
###############################################
#execute in MXNet
print('-------------------------------')
print('Testing regular MXNet execution')
exe = sym.bind(ctx=mx.cpu(), args=args)
out = exe.forward()
print('Testing regular Gluon execution')
inputs = [a,b]
sym_block = nn.SymbolBlock(sym, inputs)
sym_block.initialize()
out = sym_block(mx.nd.ones((3,2)),mx.nd.ones((3,2)))
print(out)

# with propagating shapes/types
print('-------------------------------')
print('Testing %s partitioning with shapes/types' % backend)
mysym2 = sym.optimize_for(backend,args)
print(mysym2.tojson())
exe2 = mysym2.bind(ctx=mx.cpu(), args=args)
out2 = exe2.forward()
print(out2)

# with propagating shapes/types, rejecting subgraph
print('-------------------------------')
print('Testing %s partitioning with shapes/types - rejecting subgraph' % backend)
mysym2 = sym.optimize_for(backend, args, reject=True)
exe2 = mysym2.bind(ctx=mx.cpu(), args=args)
out2 = exe2.forward()
print(out2)

# without propagating shapes/types
print('-------------------------------')
print('Testing %s partitioning without shapes/types' % backend)
mysym3 = sym.optimize_for(backend, myOpt='yello')
exe3 = mysym3.bind(ctx=mx.cpu(), args=args)
out3 = exe3.forward()
print(out3)

# Gluon Hybridize partitioning with shapes/types
print('-------------------------------')
print('Testing %s Gluon Hybridize partitioning with shapes/types' % backend)
inputs = [a,b]
sym_block = nn.SymbolBlock(sym, inputs)
sym_block.initialize()
sym_block.hybridize(backend=backend)
out4 = sym_block(mx.nd.ones((3,2)),mx.nd.ones((3,2)))
print(out4)
out2 = sym_block(mx.nd.ones((3,2)),mx.nd.ones((3,2)))
print(out2)

# Gluon Hybridize partitioning with shapes/types without inference
print('-------------------------------')
Expand All @@ -104,34 +81,27 @@ def test(backend):
sym_block2.optimize_for(mx.nd.ones((3,2)), mx.nd.ones((3,2)), backend=backend)
sym_block2.export('partitioned')

# Test with additional input to subgraph op
print('-------------------------------')
print('Testing %s Gluon Hybridize partitioning with extra input' % backend)
sym_block2.optimize_for(mx.nd.ones((3,2)), mx.nd.ones((3,2)), backend="addInputPass", clear=False)
out3 = sym_block2(mx.nd.ones((3,2)),mx.nd.ones((3,2)))
print(out3)


###############################################
# Test with subgraph directly consuming params
###############################################
args = {'a':mx.nd.ones((3,2))}
#execute in MXNet
print('-------------------------------')
print('Testing regular MXNet execution')
exe5 = sym2.bind(ctx=mx.cpu(), args=args)
out5 = exe5.forward()
inputs = [a]
sym2_block = nn.SymbolBlock(sym2, inputs)
sym2_block.initialize()
out5 = sym2_block(mx.nd.ones((3,2)))
print(out5)

# with propagating shapes/types
print('-------------------------------')
print('Testing %s partitioning with shapes/types' % backend)
mysym6 = sym2.optimize_for(backend, args, reqArgs=True)
print(mysym6.tojson())
exe6 = mysym6.bind(ctx=mx.cpu(), args=args)
out6 = exe6.forward()
print(out6)

# without propagating shapes/types
print('-------------------------------')
print('Testing %s partitioning without shapes/types' % backend)
mysym7 = sym2.optimize_for(backend, reqArgs=True)
exe7 = mysym7.bind(ctx=mx.cpu(), args=args)
out7 = exe7.forward()
print(out7)

# Gluon Hybridize partitioning with shapes/types
print('-------------------------------')
print('Testing %s Gluon Hybridize partitioning with shapes/types' % backend)
Expand Down
Loading