Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[1.x] Fixed setting attributes in reviewSubgraph (#19274)
Browse files Browse the repository at this point in the history
* initial commit

* fixed mapping from top level param names to subgraph input names

* fixed sanity

* support escape characters when parsing strings

* add node to graph nodes array

* changed string allocation from new to malloc to match free

* fixed add nodes

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
samskalicky and Ubuntu committed Oct 6, 2020
1 parent 9cadff7 commit 3b69c60
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 20 deletions.
10 changes: 9 additions & 1 deletion example/extensions/lib_subgraph/subgraph_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,16 @@ MXReturnValue mySupportedOps(const mxnet::ext::Graph* graph,
}

MXReturnValue myReviewSubgraph(const mxnet::ext::Graph *subgraph, int subgraph_id, bool* accept,
const std::unordered_map<std::string, std::string>& options) {
const std::unordered_map<std::string, std::string>& options,
std::unordered_map<std::string, std::string>* attrs) {
for (auto kv : options) {
std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl;
}

std::string sg = subgraph->toString();
std::cout << "subgraph " << subgraph_id << ": " << std::endl;
std::cout << sg << std::endl;

// check if option `reject` was specified, and if so check if value is 'True'
if(options.count("reject") > 0 && options.at("reject").compare("True") == 0) {
// if specified, reject the subgraph. this is only used for testing
Expand All @@ -223,6 +228,9 @@ MXReturnValue myReviewSubgraph(const mxnet::ext::Graph *subgraph, int subgraph_i
*accept = true;
std::cout << "accepting subgraph" << std::endl;
}

attrs->emplace("myKey","myVal");

return MX_SUCCESS;
}

Expand Down
8 changes: 5 additions & 3 deletions include/mxnet/lib_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -594,10 +594,10 @@ class Graph {
static Graph* fromJson(JsonVal val);

/* \brief convert graph object back to JSON object */
JsonVal toJson();
JsonVal toJson() const;

/* \brief convert graph object to JSON string */
std::string toString();
std::string toString() const;

/* \brief visits a node "n" */
void _dfs_util(Node* n, std::unordered_set<Node*>* to_visit,
Expand Down Expand Up @@ -819,7 +819,9 @@ typedef MXReturnValue (*createSelector_t)(const mxnet::ext::Graph *graph,
typedef MXReturnValue (*reviewSubgraph_t)(const mxnet::ext::Graph *subgraph, int subgraph_id,
bool* accept,
const std::unordered_map<std::string,
std::string>& options);
std::string>& options,
std::unordered_map<std::string,
std::string>* attrs);

/*!
* \brief An abstract class for subgraph property
Expand Down
40 changes: 24 additions & 16 deletions src/lib_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,8 @@ mxnet::ext::JsonVal mxnet::ext::JsonVal::parse(const std::string& json) {
mxnet::ext::JsonVal mxnet::ext::JsonVal::parse_string(const std::string& json, unsigned int* idx) {
JsonVal ret(STR);
while (*idx < json.size()) {
if (json[*idx] == '"') {
if (json[*idx] == '"' && (ret.str.size() == 0 ||
(ret.str.size() > 0 && ret.str.back() != '\\'))) {
++(*idx);
return ret;
} else {
Expand Down Expand Up @@ -561,7 +562,7 @@ mxnet::ext::Graph* mxnet::ext::Graph::fromJson(mxnet::ext::JsonVal val) {
}

/* \brief convert graph object back to JSON object */
mxnet::ext::JsonVal mxnet::ext::Graph::toJson() {
mxnet::ext::JsonVal mxnet::ext::Graph::toJson() const {
// top level object is a map
JsonVal val(MAP);

Expand Down Expand Up @@ -646,7 +647,7 @@ mxnet::ext::JsonVal mxnet::ext::Graph::toJson() {
}

/* \brief convert graph object to JSON string */
std::string mxnet::ext::Graph::toString() {
std::string mxnet::ext::Graph::toString() const {
return toJson().dump();
}

Expand Down Expand Up @@ -725,6 +726,7 @@ void mxnet::ext::Graph::print(int indent) const {
/* \brief add a new node to this graph */
mxnet::ext::Node* mxnet::ext::Graph::addNode(const std::string& name, const std::string& op) {
Node* n = new Node();
nodes.push_back(n);
n->name = name;
n->op = op;
if (res)
Expand Down Expand Up @@ -766,10 +768,14 @@ void mxnet::ext::Graph::_setParams(std::unordered_map<std::string, mxnet::ext::M
std::unordered_map<std::string, mxnet::ext::MXTensor>* aux) {
// set params for each input node
for (Node* node : inputs) {
if (args->count(node->name) > 0)
node->tensor = &args->at(node->name);
else if (aux->count(node->name) > 0)
node->tensor = &aux->at(node->name);
std::string name = node->name;
if (node->attrs.count("isArg") > 0 && node->attrs["isArg"].compare("True") == 0)
// mapping name back to original node name from subgraph input name
name = node->attrs["argName"];
if (args->count(name) > 0)
node->tensor = &args->at(name);
else if (aux->count(name) > 0)
node->tensor = &aux->at(name);
}
}

Expand Down Expand Up @@ -1494,26 +1500,27 @@ MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph,
}

subgraph->_setParams(&args, &aux);

std::unordered_map<std::string, std::string> attrs;
mxnet::ext::MXReturnValue retval = reviewSubgraph(subgraph, subgraph_id, &accept_bool,
opts);
opts, &attrs);
if (!retval) return retval;

*accept = accept_bool;

if (subgraph->attrs.size() > 0) {
*num_attrs = subgraph->attrs.size();
if (attrs.size() > 0) {
*num_attrs = attrs.size();
// allocate space for attributes
*attr_keys = static_cast<char**>(malloc (*num_attrs * sizeof(char*))); // NOLINT
*attr_vals = static_cast<char**>(malloc (*num_attrs * sizeof(char*))); // NOLINT

// copy attributes
int i = 0;
for (auto kv : subgraph->attrs) {
for (auto kv : attrs) {
(*attr_keys)[i] = static_cast<char*>(malloc ((kv.first.size()+1) * sizeof(char))); // NOLINT
std::string val = kv.second.dump(); // convert JsonVal back to string
(*attr_vals)[i] = static_cast<char*>(malloc ((val.size()+1) * sizeof(char))); // NOLINT
(*attr_vals)[i] = static_cast<char*>(malloc ((kv.second.size()+1) * sizeof(char))); // NOLINT
snprintf((*attr_keys)[i], kv.first.size()+1, "%s", kv.first.c_str());
snprintf((*attr_vals)[i], val.size()+1, "%s", val.c_str());
snprintf((*attr_vals)[i], kv.second.size()+1, "%s", kv.second.c_str());
i++;
}
}
Expand Down Expand Up @@ -1587,8 +1594,9 @@ MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass, const char *jso
mxnet::ext::MXReturnValue retval = graphPass(graph, opts);
if (!retval) return retval;

std::string *tmp = new std::string(graph->toString());
*out_graph = const_cast<char*>(tmp->c_str());
std::string tmp = graph->toString();
*out_graph = static_cast<char*>(malloc ((tmp.size()+1) * sizeof(char))); // NOLINT
snprintf((*out_graph), tmp.size()+1, "%s", tmp.c_str());
return retval;
}

Expand Down

0 comments on commit 3b69c60

Please sign in to comment.