From cc4b8ec68b6ec9dae73046e9c34ac97439efda83 Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Tue, 6 Oct 2020 12:24:40 -0700 Subject: [PATCH] [1.8.x] Backporting: Fixed setting attributes in reviewSubgraph (#19278) * initial commit * fixed mapping from top level param names to subgraph input names * fixed sanity * support escape characters when parsing strings * changed string allocation from new to malloc to match free * add node to graph nodes array * fixed add nodes Co-authored-by: Ubuntu --- .../extensions/lib_subgraph/subgraph_lib.cc | 10 ++++- include/mxnet/lib_api.h | 8 ++-- src/lib_api.cc | 40 +++++++++++-------- 3 files changed, 38 insertions(+), 20 deletions(-) diff --git a/example/extensions/lib_subgraph/subgraph_lib.cc b/example/extensions/lib_subgraph/subgraph_lib.cc index f47109343007..98508c793de3 100644 --- a/example/extensions/lib_subgraph/subgraph_lib.cc +++ b/example/extensions/lib_subgraph/subgraph_lib.cc @@ -209,11 +209,16 @@ MXReturnValue mySupportedOps(const mxnet::ext::Graph* graph, } MXReturnValue myReviewSubgraph(const mxnet::ext::Graph *subgraph, int subgraph_id, bool* accept, - const std::unordered_map& options) { + const std::unordered_map& options, + std::unordered_map* attrs) { for (auto kv : options) { std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl; } + std::string sg = subgraph->toString(); + std::cout << "subgraph " << subgraph_id << ": " << std::endl; + std::cout << sg << std::endl; + // check if option `reject` was specified, and if so check if value is 'True' if(options.count("reject") > 0 && options.at("reject").compare("True") == 0) { // if specified, reject the subgraph. this is only used for testing @@ -223,6 +228,9 @@ MXReturnValue myReviewSubgraph(const mxnet::ext::Graph *subgraph, int subgraph_i *accept = true; std::cout << "accepting subgraph" << std::endl; } + + attrs->emplace("myKey","myVal"); + return MX_SUCCESS; } diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h index 0213557fdc92..db93dbe6ff41 100644 --- a/include/mxnet/lib_api.h +++ b/include/mxnet/lib_api.h @@ -594,10 +594,10 @@ class Graph { static Graph* fromJson(JsonVal val); /* \brief convert graph object back to JSON object */ - JsonVal toJson(); + JsonVal toJson() const; /* \brief convert graph object to JSON string */ - std::string toString(); + std::string toString() const; /* \brief visits a node "n" */ void _dfs_util(Node* n, std::unordered_set* to_visit, @@ -819,7 +819,9 @@ typedef MXReturnValue (*createSelector_t)(const mxnet::ext::Graph *graph, typedef MXReturnValue (*reviewSubgraph_t)(const mxnet::ext::Graph *subgraph, int subgraph_id, bool* accept, const std::unordered_map& options); + std::string>& options, + std::unordered_map* attrs); /*! * \brief An abstract class for subgraph property diff --git a/src/lib_api.cc b/src/lib_api.cc index 20ae280acf6c..c273678dcd1a 100644 --- a/src/lib_api.cc +++ b/src/lib_api.cc @@ -348,7 +348,8 @@ mxnet::ext::JsonVal mxnet::ext::JsonVal::parse(const std::string& json) { mxnet::ext::JsonVal mxnet::ext::JsonVal::parse_string(const std::string& json, unsigned int* idx) { JsonVal ret(STR); while (*idx < json.size()) { - if (json[*idx] == '"') { + if (json[*idx] == '"' && (ret.str.size() == 0 || + (ret.str.size() > 0 && ret.str.back() != '\\'))) { ++(*idx); return ret; } else { @@ -561,7 +562,7 @@ mxnet::ext::Graph* mxnet::ext::Graph::fromJson(mxnet::ext::JsonVal val) { } /* \brief convert graph object back to JSON object */ -mxnet::ext::JsonVal mxnet::ext::Graph::toJson() { +mxnet::ext::JsonVal mxnet::ext::Graph::toJson() const { // top level object is a map JsonVal val(MAP); @@ -646,7 +647,7 @@ mxnet::ext::JsonVal mxnet::ext::Graph::toJson() { } /* \brief convert graph object to JSON string */ -std::string mxnet::ext::Graph::toString() { +std::string mxnet::ext::Graph::toString() const { return toJson().dump(); } @@ -725,6 +726,7 @@ void mxnet::ext::Graph::print(int indent) const { /* \brief add a new node to this graph */ mxnet::ext::Node* mxnet::ext::Graph::addNode(const std::string& name, const std::string& op) { Node* n = new Node(); + nodes.push_back(n); n->name = name; n->op = op; if (res) @@ -766,10 +768,14 @@ void mxnet::ext::Graph::_setParams(std::unordered_map* aux) { // set params for each input node for (Node* node : inputs) { - if (args->count(node->name) > 0) - node->tensor = &args->at(node->name); - else if (aux->count(node->name) > 0) - node->tensor = &aux->at(node->name); + std::string name = node->name; + if (node->attrs.count("isArg") > 0 && node->attrs["isArg"].compare("True") == 0) + // mapping name back to original node name from subgraph input name + name = node->attrs["argName"]; + if (args->count(name) > 0) + node->tensor = &args->at(name); + else if (aux->count(name) > 0) + node->tensor = &aux->at(name); } } @@ -1494,26 +1500,27 @@ MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph, } subgraph->_setParams(&args, &aux); + + std::unordered_map attrs; mxnet::ext::MXReturnValue retval = reviewSubgraph(subgraph, subgraph_id, &accept_bool, - opts); + opts, &attrs); if (!retval) return retval; *accept = accept_bool; - if (subgraph->attrs.size() > 0) { - *num_attrs = subgraph->attrs.size(); + if (attrs.size() > 0) { + *num_attrs = attrs.size(); // allocate space for attributes *attr_keys = static_cast(malloc (*num_attrs * sizeof(char*))); // NOLINT *attr_vals = static_cast(malloc (*num_attrs * sizeof(char*))); // NOLINT // copy attributes int i = 0; - for (auto kv : subgraph->attrs) { + for (auto kv : attrs) { (*attr_keys)[i] = static_cast(malloc ((kv.first.size()+1) * sizeof(char))); // NOLINT - std::string val = kv.second.dump(); // convert JsonVal back to string - (*attr_vals)[i] = static_cast(malloc ((val.size()+1) * sizeof(char))); // NOLINT + (*attr_vals)[i] = static_cast(malloc ((kv.second.size()+1) * sizeof(char))); // NOLINT snprintf((*attr_keys)[i], kv.first.size()+1, "%s", kv.first.c_str()); - snprintf((*attr_vals)[i], val.size()+1, "%s", val.c_str()); + snprintf((*attr_vals)[i], kv.second.size()+1, "%s", kv.second.c_str()); i++; } } @@ -1587,8 +1594,9 @@ MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass, const char *jso mxnet::ext::MXReturnValue retval = graphPass(graph, opts); if (!retval) return retval; - std::string *tmp = new std::string(graph->toString()); - *out_graph = const_cast(tmp->c_str()); + std::string tmp = graph->toString(); + *out_graph = static_cast(malloc ((tmp.size()+1) * sizeof(char))); // NOLINT + snprintf((*out_graph), tmp.size()+1, "%s", tmp.c_str()); return retval; }