From 9a614c04815a5be2479b116e1b6220e9a623bccd Mon Sep 17 00:00:00 2001 From: root Date: Wed, 14 Oct 2020 01:20:35 +0000 Subject: [PATCH 1/9] Make BERT-GPU deploy compatible with MXNet 1.8 --- scripts/bert/bertpass_gpu.cc | 453 +++++++++++++++++++---------------- scripts/bert/deploy.py | 34 ++- scripts/bert/setup.py | 11 +- 3 files changed, 282 insertions(+), 216 deletions(-) diff --git a/scripts/bert/bertpass_gpu.cc b/scripts/bert/bertpass_gpu.cc index a773698454..0b79d53319 100644 --- a/scripts/bert/bertpass_gpu.cc +++ b/scripts/bert/bertpass_gpu.cc @@ -30,6 +30,7 @@ #include #include "mxnet/lib_api.h" +#if MXNET_1_7 class Node; struct NodeEntry { Node* node; @@ -47,7 +48,7 @@ class Node { class Graph { public: Graph() {} - static Graph fromString(const std::string& json) { + static Graph* fromString(const std::string& json) { JsonParser parser; JsonVal val = parser.parse_to_json(json); return fromJson(val); @@ -56,16 +57,16 @@ class Graph { for(int i=0; i nodeMap; // loop over nodes for(int i=0; inodes.push_back(n); JsonVal node = nodes.list[i]; // set the op info @@ -74,8 +75,8 @@ class Graph { // if op is null its an input to the graph if(n->op.compare("null") == 0) - g.inputs.push_back(n); - + g->inputs.push_back(n); + // set attrs JsonVal attributes = node.map[JsonVal("attrs")]; for(auto& kv : attributes.map) { @@ -99,20 +100,20 @@ class Graph { } JsonVal& heads = val.map[JsonVal("heads")]; - g.outputs.resize(heads.list.size()); + g->outputs.resize(heads.list.size()); for(int i=0; ioutputs[i].node = nodeMap[head.list[0].num]; + g->outputs[i].entry = head.list[1].num; } - + JsonParser parser; for(auto& kv : val.map) { if(kv.first.str.compare("nodes") != 0 && - kv.first.str.compare("heads") != 0 && - kv.first.str.compare("node_row_ptr") != 0 && - kv.first.str.compare("arg_nodes") != 0) { - g.attrs[kv.first.str] = kv.second; + kv.first.str.compare("heads") != 0 && + kv.first.str.compare("node_row_ptr") != 0 && + kv.first.str.compare("arg_nodes") != 0) { + g->attrs[kv.first.str] = kv.second; } } return g; @@ -138,7 +139,7 @@ class Graph { JsonVal& arg_nodes = val.map[JsonVal("arg_nodes")]; for(int i=0; i=0; i--) { nodes_.list.push_back(JsonVal(MAP)); Node* n = sorted[i]; JsonVal& n_ = nodes_.list[nodes_.list.size()-1]; - + n_.map[JsonVal("op")] = JsonVal(n->op); n_.map[JsonVal("name")] = JsonVal(n->name); n_.map[JsonVal("inputs")] = JsonVal(LIST); @@ -218,114 +219,97 @@ class Graph { return sorted; } - void print() { - std::cout << "########### Graph #############" << std::endl; - std::cout << "inputs: " << inputs.size() << std::endl; - std::cout << "outputs: " << outputs.size() << std::endl; - std::cout << "nodes: " << nodes.size() << std::endl; - std::vector sorted; - auto handler = [&](Node* n) { - sorted.push_back(n); - }; - DFS(handler); - - for(int i=sorted.size()-1; i>=0; i--) { - std::cout << "Node: " << sorted[i]->name << std::endl; - for(int j=0; jinputs.size(); j++) { - std::cout << "\tInput: " << sorted[i]->inputs[j].node->name << " " << sorted[i]->inputs[j].entry << std::endl; - } - for(int j=0; joutputs.size(); j++) { - std::cout << "\tOutput: " << sorted[i]->outputs[j].node->name << " " << sorted[i]->outputs[j].entry << std::endl; - } - } - std::cout << "###############################" << std::endl; - } - std::vector nodes; std::vector inputs; std::vector outputs; std::map attrs; }; +#else + using namespace mxnet::ext; +#endif -// example Sam: https://gist.github.com/samskalicky/5f44e159e9f1b04237eed8d20e5d9f28 +#if MXNET_1_7 MXReturnValue custom_pass(const std::string& 
in_graph, const std::string** out_graph, const std::unordered_map& options, const std::unordered_map& args, const std::unordered_map& aux, const PassResource& res) { - - for (auto kv : options) - std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl; - //convert graph from JSON string to Graph/Node data structure - Graph g = Graph::fromString(in_graph); - //g.print(); - + Graph *g = Graph::fromString(in_graph); + for(Node* n : g->nodes) { +#else +MXReturnValue custom_pass(mxnet::ext::Graph *g, + const std::unordered_map& options) { + for(int i=0; i < g->size(); i++) { + mxnet::ext::Node* n = g->getNode(i); +#endif + /////////////////////// AddBias + GELU ////////////////////////// - std::string str_ffn1 = "ffn_1_fwd"; - for(Node* n : g.nodes){ - if (n->name.find(str_ffn1) != std::string::npos) { - Node* node_ffn1_fwd = n; - Node* node_ffn1_bias = node_ffn1_fwd->inputs[2].node; - Node* node_gelu = node_ffn1_fwd->outputs[0].node; - - std::size_t pos = n->name.find("fwd"); - std::string base_name = n->name.substr(0,pos-1); - - // remove Bias terms in FC - node_ffn1_fwd->attrs["no_bias"]="True"; - node_ffn1_fwd->inputs.pop_back(); - - // create 2 expand_dims nodes to expand bias dimensions - Node* node_expand_1_bias = new Node(); - node_expand_1_bias->name = base_name + "_expand_1_bias"; - node_expand_1_bias->op = "expand_dims"; - node_expand_1_bias->attrs["axis"]="0"; - node_expand_1_bias->inputs.resize(1); - node_expand_1_bias->inputs[0].node = node_ffn1_bias; - node_expand_1_bias->inputs[0].entry = 0; - Node* node_expand_2_bias = new Node(); - node_expand_2_bias->name = base_name + "_expand_2_bias"; - node_expand_2_bias->op = "expand_dims"; - node_expand_2_bias->attrs["axis"]="0"; - node_expand_2_bias->inputs.resize(1); - node_expand_2_bias->inputs[0].node = node_expand_1_bias; - node_expand_2_bias->inputs[0].entry = 0; - g.nodes.push_back(node_expand_1_bias); - g.nodes.push_back(node_expand_2_bias); - - // create broadcast_like node - Node* node_bcst_like = new Node(); - node_bcst_like->name = base_name + "_broadcast_like"; - node_bcst_like->op = "broadcast_like";; - node_bcst_like->inputs.resize(2); - node_bcst_like->inputs[0].node = node_expand_2_bias; - node_bcst_like->inputs[0].entry = 0; - node_bcst_like->inputs[1].node = node_ffn1_fwd; - node_bcst_like->inputs[1].entry = 0; - g.nodes.push_back(node_bcst_like); - - // create BiasAdd Node - Node* node_add_bias = new Node(); - node_add_bias->name = base_name + "_add_bias"; - node_add_bias->op = "elemwise_add"; - node_add_bias->inputs.resize(2); - node_add_bias->inputs[0].node = node_ffn1_fwd; - node_add_bias->inputs[0].entry = 0; - node_add_bias->inputs[1].node = node_bcst_like; - node_add_bias->inputs[1].entry = 0; - g.nodes.push_back(node_add_bias); - - //set BiasAdd node as gelu input - node_gelu->inputs[0].node = node_add_bias; - node_gelu->inputs[0].entry = 0; - } + std::string str_ffn1 = "ffn_1_fwd"; + if (n->name.find(str_ffn1) != std::string::npos) { + Node* node_ffn1_fwd = n; + Node* node_ffn1_bias = node_ffn1_fwd->inputs[2].node; + Node* node_gelu = node_ffn1_fwd->outputs[0].node; + + std::size_t pos = n->name.find("fwd"); + std::string base_name = n->name.substr(0, pos - 1); + + // remove Bias terms in FC + node_ffn1_fwd->attrs["no_bias"]="True"; + node_ffn1_fwd->inputs.pop_back(); + + // create 4 new nodes: 2 expand_dims nodes to expand bias dimensions + // a broadcast_like node, and BiasAdd +#if MXNET_1_7 + Node* node_expand_1_bias = new Node(); + node_expand_1_bias->name = base_name + "_expand_1_bias"; + 
node_expand_1_bias->op = "expand_dims"; + Node* node_expand_2_bias = new Node(); + node_expand_2_bias->name = base_name + "_expand_2_bias"; + node_expand_2_bias->op = "expand_dims"; + Node* node_bcst_like = new Node(); + node_bcst_like->name = base_name + "_broadcast_like"; + node_bcst_like->op = "broadcast_like"; + Node* node_add_bias = new Node(); + node_add_bias->name = base_name + "_add_bias"; + node_add_bias->op = "elemwise_add"; + g->nodes.push_back(node_expand_1_bias); + g->nodes.push_back(node_expand_2_bias); + g->nodes.push_back(node_bcst_like); + g->nodes.push_back(node_add_bias); +#else + Node* node_expand_1_bias = g->addNode(base_name + "_expand_1_bias", "expand_dims"); + Node* node_expand_2_bias = g->addNode(base_name + "_expand_2_bias", "expand_dims"); + Node* node_bcst_like = g->addNode(base_name + "_broadcast_like", "broadcast_like"); + Node* node_add_bias = g->addNode(base_name + "_add_bias", "elemwise_add"); +#endif + node_expand_1_bias->attrs["axis"]="0"; + node_expand_1_bias->inputs.resize(1); + node_expand_1_bias->inputs[0].node = node_ffn1_bias; + node_expand_1_bias->inputs[0].entry = 0; + node_expand_2_bias->attrs["axis"]="0"; + node_expand_2_bias->inputs.resize(1); + node_expand_2_bias->inputs[0].node = node_expand_1_bias; + node_expand_2_bias->inputs[0].entry = 0; + node_bcst_like->inputs.resize(2); + node_bcst_like->inputs[0].node = node_expand_2_bias; + node_bcst_like->inputs[0].entry = 0; + node_bcst_like->inputs[1].node = node_ffn1_fwd; + node_bcst_like->inputs[1].entry = 0; + node_add_bias->inputs.resize(2); + node_add_bias->inputs[0].node = node_ffn1_fwd; + node_add_bias->inputs[0].entry = 0; + node_add_bias->inputs[1].node = node_bcst_like; + node_add_bias->inputs[1].entry = 0; + //set BiasAdd node as gelu input + node_gelu->inputs[0].node = node_add_bias; + node_gelu->inputs[0].entry = 0; + } } ///////////////////////////////////////////////////////////////// - - //////////////// MHA remove reshapes & concat /////////////////// - // find shape of weight / bias, number of heads, and count number of MHA layers + /////// MHA prepare weight & bias for interleaved strategy ////// + // find shapes, number of heads, and count number of MHA layers std::string query0_weight = "bertencoder0_transformer0_dotproductselfattentioncell0_query_weight"; std::string mult_qk0 = "bertencoder0_transformer0_dotproductselfattentioncell0_interleaved_matmul_selfatt_qk0"; std::string str_projection = "_dotproductselfattentioncell0_fullyconnected0"; @@ -333,114 +317,179 @@ MXReturnValue custom_pass(const std::string& in_graph, const std::string** out_g int num_heads = 0; int head_dimension = 0; int shape0, shape1; - for(Node* n : g.nodes){ - if (n->name.find(query0_weight) != std::string::npos) { - std::string shape = n->attrs["__shape__"]; - int pos_comma = shape.find(","); - shape0 = stoi(shape.substr(1, pos_comma-1)); - shape1 = stoi(shape.substr(pos_comma+2, shape.length()-pos_comma-3)); - } - if (n->name.find(mult_qk0) != std::string::npos) { - std::string h = n->attrs["heads"]; - num_heads = stoi(h); - } - if (n->name.find(str_projection) != std::string::npos) { - num_mha_layers++; - } +#if MXNET_1_7 + for(Node* n : g->nodes) { +#else + for(int i=0; i < g->size(); i++) { + mxnet::ext::Node* n = g->getNode(i); +#endif + if (n->name.find(query0_weight) != std::string::npos) { + std::string shape = n->attrs["__shape__"]; + int pos_comma = shape.find(","); + shape0 = stoi(shape.substr(1, pos_comma-1)); + shape1 = stoi(shape.substr(pos_comma + 2, shape.length() - pos_comma - 3)); + 
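+        // NOTE (illustrative): __shape__ holds the textual shape of the query
+        // weight, e.g. "[768, 768]" for bert_12_768_12, so shape0 = shape1 =
+        // 768 here; with 12 heads this gives head_dimension = 768 / 12 = 64.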
} + if (n->name.find(mult_qk0) != std::string::npos) { + std::string h = n->attrs["heads"]; + num_heads = stoi(h); + } + if (n->name.find(str_projection) != std::string::npos) { + num_mha_layers++; + } } head_dimension = shape0 / num_heads; // find projection nodes and set new interleaved intputs - for(Node* n : g.nodes){ - if (n->name.find("_dotproductselfattentioncell0_fullyconnected0") != std::string::npos) { - Node* node_projection = n; - std::size_t pos = node_projection->name.find("_fullyconnected0"); - std::string base_name = n->name.substr(0,pos); - - //////////////////// WEIGHTS //////////////////// - // create new arg with interleaved weights - std::string name_qkv_weights_interleaved = base_name + "_qkv_weights_interleaved"; - MXTensor* qkv_weights_interleaved = res.alloc_arg(name_qkv_weights_interleaved, {3*shape0,shape1}, MXContext::CPU(0), kFloat32); - float* qkv_w_data = qkv_weights_interleaved->data(); - // read from previous values and interleave them - MXTensor query_w = args.at(base_name+"_query_weight"); - MXTensor key_w = args.at(base_name+"_key_weight"); - MXTensor value_w = args.at(base_name+"_value_weight"); - float* query_w_data = query_w.data(); - float* key_w_data = key_w.data(); - float* value_w_data = value_w.data(); - for(int h=0; hname = name_qkv_weights_interleaved; - node_qkv_weights->op = "null"; - //add a new node in graph, also as input - g.nodes.push_back(node_qkv_weights); - g.inputs.push_back(node_qkv_weights); - // set connection with new input - node_projection->inputs[1].node = node_qkv_weights; - node_projection->inputs[1].entry = 0; - - //////////////////// BIAS //////////////////// - // create new arg with all bias - std::string name_qkv_bias = base_name + "_qkv_bias"; - MXTensor* qkv_bias = res.alloc_arg(name_qkv_bias, {3*shape0,}, MXContext::CPU(0), kFloat32); - float* qkv_bias_data = qkv_bias->data(); - // read from previous values and join them - MXTensor query_bias = args.at(base_name+"_query_bias"); - MXTensor key_bias = args.at(base_name+"_key_bias"); - MXTensor value_bias = args.at(base_name+"_value_bias"); - float* query_bias_data = query_bias.data(); - float* key_bias_data = key_bias.data(); - float* value_bias_data = value_bias.data(); - for(int e=0; ename = name_qkv_bias; - node_qkv_bias->op = "null"; - //add a new node in graph, also as input - g.nodes.push_back(node_qkv_bias); - g.inputs.push_back(node_qkv_bias); - // set connection with new input - node_projection->inputs[2].node = node_qkv_bias; - node_projection->inputs[2].entry = 0; +#if MXNET_1_7 + for(Node* n : g->nodes) { +#else + for(int i=0; i < g->size(); i++) { + mxnet::ext::Node* n = g->getNode(i); +#endif + if (n->name.find("_dotproductselfattentioncell0_fullyconnected0") != std::string::npos) { + Node* node_projection = n; + std::size_t pos = node_projection->name.find("_fullyconnected0"); + std::string base_name = n->name.substr(0, pos); + + //////////////////// WEIGHTS //////////////////// + // create new input node with interleaved weights + std::string name_qkv_weights_interleaved = base_name + "_qkv_weights_interleaved"; +#if MXNET_1_7 + // create a new input Node + Node* node_qkv_weights = new Node(); + node_qkv_weights->name = name_qkv_weights_interleaved; + node_qkv_weights->op = "null"; + MXTensor* qkv_weights_interleaved = res.alloc_arg(name_qkv_weights_interleaved, + {3 * shape0, shape1}, + MXContext::CPU(0), kFloat32); + float* qkv_w_data = qkv_weights_interleaved->data(); + // read from previous values and interleave them + MXTensor query_w = 
args.at(base_name + "_query_weight");
+      MXTensor key_w = args.at(base_name + "_key_weight");
+      MXTensor value_w = args.at(base_name + "_value_weight");
+      float* query_w_data = query_w.data<float>();
+      float* key_w_data = key_w.data<float>();
+      float* value_w_data = value_w.data<float>();
+      //add a new node in graph, also as input
+      g->nodes.push_back(node_qkv_weights);
+      g->inputs.push_back(node_qkv_weights);
+#else
+      Node* node_qkv_weights = g->addNode(name_qkv_weights_interleaved + "_input", "null");
+      node_qkv_weights->alloc_arg({3 * shape0, shape1}, MXContext::CPU(0), kFloat32);
+      float* qkv_w_data = node_qkv_weights->tensor->data<float>();
+      // look back for query, key value weights original data
+      float *query_w_data, *key_w_data, *value_w_data;
+      int found = 0;
+      for (int j=i; j >= 0; j--) {
+        Node* n2 = g->getNode(j);
+        if (n2->name.find(base_name + "_query_weight") != std::string::npos) {
+          query_w_data = n2->tensor->data<float>();
+          found++;
+        }
+        if (n2->name.find(base_name + "_key_weight") != std::string::npos) {
+          key_w_data = n2->tensor->data<float>();
+          found++;
+        }
+        if (n2->name.find(base_name + "_value_weight") != std::string::npos) {
+          value_w_data = n2->tensor->data<float>();
+          found++;
+        }
+        if (found >= 3)
+          break;
+      }
+#endif
+      // interleave weights
+      for (int h=0; h<num_heads; h++) {
+        for (int e=0; e<head_dimension*shape1; ++e) {
+          qkv_w_data[h*3*head_dimension*shape1 + e] = query_w_data[h*head_dimension*shape1 + e];
+          qkv_w_data[h*3*head_dimension*shape1 + head_dimension*shape1 + e] = key_w_data[h*head_dimension*shape1 + e];
+          qkv_w_data[h*3*head_dimension*shape1 + 2*head_dimension*shape1 + e] = value_w_data[h*head_dimension*shape1 + e];
+        }
+      }
+      // set connection with new input
+      node_projection->inputs[1].node = node_qkv_weights;
+      node_projection->inputs[1].entry = 0;
+
+      //////////////////// BIAS ////////////////////
+      // create new input node with all bias
+      std::string name_qkv_bias = base_name + "_qkv_bias";
+#if MXNET_1_7
+      Node* node_qkv_bias = new Node();
+      node_qkv_bias->name = name_qkv_bias;
+      node_qkv_bias->op = "null";
+      MXTensor* qkv_bias = res.alloc_arg(name_qkv_bias, {3 * shape0, },
+                                         MXContext::CPU(0), kFloat32);
+      float* qkv_bias_data = qkv_bias->data<float>();
+      MXTensor query_bias = args.at(base_name + "_query_bias");
+      MXTensor key_bias = args.at(base_name + "_key_bias");
+      MXTensor value_bias = args.at(base_name + "_value_bias");
+      float* query_bias_data = query_bias.data<float>();
+      float* key_bias_data = key_bias.data<float>();
+      float* value_bias_data = value_bias.data<float>();
+      //add a new node in graph, also as input
+      g->nodes.push_back(node_qkv_bias);
+      g->inputs.push_back(node_qkv_bias);
+#else
+      Node* node_qkv_bias = g->addNode(name_qkv_bias + "_input", "null");
+      node_qkv_bias->alloc_arg({3 * shape0, }, MXContext::CPU(0), kFloat32);
+      float* qkv_bias_data = node_qkv_bias->tensor->data<float>();
+      // look back for query, key value weights original data
+      float *query_bias_data, *key_bias_data, *value_bias_data;
+      found = 0;
+      for (int j=i; j >= 0; j--) {
+        Node* n2 = g->getNode(j);
+        if (n2->name.find(base_name + "_query_bias") != std::string::npos) {
+          query_bias_data = n2->tensor->data<float>();
+          found++;
+        }
+        if (n2->name.find(base_name + "_key_bias") != std::string::npos) {
+          key_bias_data = n2->tensor->data<float>();
+          found++;
+        }
+        if (n2->name.find(base_name + "_value_bias") != std::string::npos) {
+          value_bias_data = n2->tensor->data<float>();
+          found++;
+        }
+        if (found >= 3)
+          break;
+      }
+#endif
+      // concatenate bias terms
+      for (int e=0; e < shape0; ++e) {
+        qkv_bias_data[e] = query_bias_data[e];
+      }
+      for (int e=0; e < shape0; ++e) {
+        qkv_bias_data[shape0 + e] = key_bias_data[e];
+      }
+      for (int e=0; e < shape0; ++e) {
+        qkv_bias_data[2 * shape0 + e] = value_bias_data[e];
+      }
+      // set connection with new input
+      node_projection->inputs[2].node = node_qkv_bias;
+      node_projection->inputs[2].entry = 0;
+    }
  }

  //////////////////////////////////////////////////////////////////
+#if MXNET_1_7
  //convert back to JSON string from Graph/Node
-  *out_graph = new 
std::string(g.toString()); + *out_graph = new std::string(g->toString()); +#endif return MX_SUCCESS; - } - REGISTER_PASS(custom_pass) .setBody(custom_pass); MXReturnValue initialize(int version) { - if (version >= 10400) { + printf("VERSION %i\n", version); + if (version >= 10700) { std::cout << "MXNet version " << version << " supported" << std::endl; return MX_SUCCESS; } else { diff --git a/scripts/bert/deploy.py b/scripts/bert/deploy.py index 8978e0c120..9cda76f885 100644 --- a/scripts/bert/deploy.py +++ b/scripts/bert/deploy.py @@ -321,17 +321,29 @@ def export(prefix): arg_array['data2'] = mx.nd.ones((test_batch_size, ), dtype='float32') custom_sym = sym.optimize_for('custom_pass', arg_array, aux_params) - nheads = 12 - if args.bert_model == 'bert_24_1024_16': - nheads = 24 - for i in range(nheads): - basename = 'bertencoder0_transformer' + str(i) + '_dotproductselfattentioncell0' - arg_array.pop(basename + '_query_weight') - arg_array.pop(basename + '_key_weight') - arg_array.pop(basename + '_value_weight') - arg_array.pop(basename + '_query_bias') - arg_array.pop(basename + '_key_bias') - arg_array.pop(basename + '_value_bias') + is_mxnet_1_7 = (mx.__version__ == '1.7.0') + '''if not is_mxnet_1_7: + shape_dict = {'data0': (test_batch_size, seq_length), + 'data1': (test_batch_size, seq_length), + 'data2': (test_batch_size, )} + type_dict = {'data0': 'float32', + 'data1': 'float32', + 'data2': 'float32'} + custom_sym = sym.optimize_for('custom_pass', arg_params, + aux_params, shape_dict=shape_dict, + type_dict=type_dict)''' + if is_mxnet_1_7: + nheads = 12 + if args.bert_model == 'bert_24_1024_16': + nheads = 24 + for i in range(nheads): + basename = 'bertencoder0_transformer' + str(i) + '_dotproductselfattentioncell0' + arg_array.pop(basename + '_query_weight') + arg_array.pop(basename + '_key_weight') + arg_array.pop(basename + '_value_weight') + arg_array.pop(basename + '_query_bias') + arg_array.pop(basename + '_key_bias') + arg_array.pop(basename + '_value_bias') arg_array.pop('data0') arg_array.pop('data1') arg_array.pop('data2') diff --git a/scripts/bert/setup.py b/scripts/bert/setup.py index 5872faa0f9..5f601790c2 100644 --- a/scripts/bert/setup.py +++ b/scripts/bert/setup.py @@ -26,9 +26,14 @@ def CompileBERTCustomPass(): pass_path = os.path.dirname(os.path.realpath(__file__)) source = os.path.join(pass_path, input_pass_file) target = os.path.join(pass_path, out_lib_file) - os.system('g++ -shared -fPIC -std=c++11 ' + str(source) + - ' -o ' + str(target) + ' -I ' + - str(mxnet_include_path)) + lib_api_cc = pathlib.Path.joinpath(mxnet_path, 'src/lib_api.cc') + is_mxnet_1_7 = (mxnet.__version__ == '1.7.0') + if is_mxnet_1_7: + os.system('g++ -shared -fPIC -std=c++11 ' + str(source) + + ' -o ' + str(target) + ' -I ' + str(mxnet_include_path) + ' -DMXNET_1_7') + else: + os.system('g++ -shared -fPIC -std=c++11 ' + str(source) + ' ' + str(lib_api_cc) + + ' -o ' + str(target) + ' -I ' + str(mxnet_include_path)) class CompileBERTPass(install): def run(self): From 6c012c3a7e23638afd2d06dae41716426d2411ca Mon Sep 17 00:00:00 2001 From: root Date: Fri, 16 Oct 2020 20:56:32 +0000 Subject: [PATCH 2/9] Activate CUDA Graphs and clean up --- scripts/bert/bertpass_gpu.cc | 31 ++++++++++++++----------------- scripts/bert/deploy.py | 17 ++++------------- scripts/bert/setup.py | 11 ++++------- 3 files changed, 22 insertions(+), 37 deletions(-) diff --git a/scripts/bert/bertpass_gpu.cc b/scripts/bert/bertpass_gpu.cc index 0b79d53319..64d64f64ca 100644 --- a/scripts/bert/bertpass_gpu.cc +++ 
b/scripts/bert/bertpass_gpu.cc @@ -30,7 +30,7 @@ #include #include "mxnet/lib_api.h" -#if MXNET_1_7 +#if MX_LIBRARY_VERSION <= 7 class Node; struct NodeEntry { Node* node; @@ -228,7 +228,7 @@ class Graph { using namespace mxnet::ext; #endif -#if MXNET_1_7 +#if MX_LIBRARY_VERSION <= 7 MXReturnValue custom_pass(const std::string& in_graph, const std::string** out_graph, const std::unordered_map& options, const std::unordered_map& args, @@ -255,12 +255,12 @@ MXReturnValue custom_pass(mxnet::ext::Graph *g, std::string base_name = n->name.substr(0, pos - 1); // remove Bias terms in FC - node_ffn1_fwd->attrs["no_bias"]="True"; + node_ffn1_fwd->attrs["no_bias"] = "True"; node_ffn1_fwd->inputs.pop_back(); // create 4 new nodes: 2 expand_dims nodes to expand bias dimensions // a broadcast_like node, and BiasAdd -#if MXNET_1_7 +#if MX_LIBRARY_VERSION <= 7 Node* node_expand_1_bias = new Node(); node_expand_1_bias->name = base_name + "_expand_1_bias"; node_expand_1_bias->op = "expand_dims"; @@ -283,11 +283,11 @@ MXReturnValue custom_pass(mxnet::ext::Graph *g, Node* node_bcst_like = g->addNode(base_name + "_broadcast_like", "broadcast_like"); Node* node_add_bias = g->addNode(base_name + "_add_bias", "elemwise_add"); #endif - node_expand_1_bias->attrs["axis"]="0"; + node_expand_1_bias->attrs["axis"] = "0"; node_expand_1_bias->inputs.resize(1); node_expand_1_bias->inputs[0].node = node_ffn1_bias; node_expand_1_bias->inputs[0].entry = 0; - node_expand_2_bias->attrs["axis"]="0"; + node_expand_2_bias->attrs["axis"] = "0"; node_expand_2_bias->inputs.resize(1); node_expand_2_bias->inputs[0].node = node_expand_1_bias; node_expand_2_bias->inputs[0].entry = 0; @@ -317,7 +317,7 @@ MXReturnValue custom_pass(mxnet::ext::Graph *g, int num_heads = 0; int head_dimension = 0; int shape0, shape1; -#if MXNET_1_7 +#if MX_LIBRARY_VERSION <= 7 for(Node* n : g->nodes) { #else for(int i=0; i < g->size(); i++) { @@ -340,7 +340,7 @@ MXReturnValue custom_pass(mxnet::ext::Graph *g, head_dimension = shape0 / num_heads; // find projection nodes and set new interleaved intputs -#if MXNET_1_7 +#if MX_LIBRARY_VERSION <= 7 for(Node* n : g->nodes) { #else for(int i=0; i < g->size(); i++) { @@ -354,7 +354,7 @@ MXReturnValue custom_pass(mxnet::ext::Graph *g, //////////////////// WEIGHTS //////////////////// // create new input node with interleaved weights std::string name_qkv_weights_interleaved = base_name + "_qkv_weights_interleaved"; -#if MXNET_1_7 +#if MX_LIBRARY_VERSION <= 7 // create a new input Node Node* node_qkv_weights = new Node(); node_qkv_weights->name = name_qkv_weights_interleaved; @@ -370,14 +370,13 @@ MXReturnValue custom_pass(mxnet::ext::Graph *g, float* query_w_data = query_w.data(); float* key_w_data = key_w.data(); float* value_w_data = value_w.data(); - //add a new node in graph, also as input g->nodes.push_back(node_qkv_weights); g->inputs.push_back(node_qkv_weights); #else Node* node_qkv_weights = g->addNode(name_qkv_weights_interleaved + "_input", "null"); node_qkv_weights->alloc_arg({3 * shape0, shape1}, MXContext::CPU(0), kFloat32); float* qkv_w_data = node_qkv_weights->tensor->data(); - // look back for query, key value weights original data + // look back for query, key, and value weights: original data float *query_w_data, *key_w_data, *value_w_data; int found = 0; for (int j=i; j >= 0; j--) { @@ -398,7 +397,7 @@ MXReturnValue custom_pass(mxnet::ext::Graph *g, break; } #endif - // interleave weights + // set interleave weights for (int h=0; hname = name_qkv_bias; node_qkv_bias->op = "null"; @@ -433,14 
+432,13 @@ MXReturnValue custom_pass(mxnet::ext::Graph *g, float* query_bias_data = query_bias.data(); float* key_bias_data = key_bias.data(); float* value_bias_data = value_bias.data(); - //add a new node in graph, also as input g->nodes.push_back(node_qkv_bias); g->inputs.push_back(node_qkv_bias); #else Node* node_qkv_bias = g->addNode(name_qkv_bias + "_input", "null"); node_qkv_bias->alloc_arg({3 * shape0, }, MXContext::CPU(0), kFloat32); float* qkv_bias_data = node_qkv_bias->tensor->data(); - // look back for query, key value weights original data + // look back for query, key, and value bias: original data float *query_bias_data, *key_bias_data, *value_bias_data; found = 0; for (int j=i; j >= 0; j--) { @@ -477,7 +475,7 @@ MXReturnValue custom_pass(mxnet::ext::Graph *g, } } -#if MXNET_1_7 +#if MX_LIBRARY_VERSION <= 7 //convert back to JSON string from Graph/Node *out_graph = new std::string(g->toString()); #endif @@ -488,7 +486,6 @@ REGISTER_PASS(custom_pass) .setBody(custom_pass); MXReturnValue initialize(int version) { - printf("VERSION %i\n", version); if (version >= 10700) { std::cout << "MXNet version " << version << " supported" << std::endl; return MX_SUCCESS; diff --git a/scripts/bert/deploy.py b/scripts/bert/deploy.py index 9cda76f885..5cd608e304 100644 --- a/scripts/bert/deploy.py +++ b/scripts/bert/deploy.py @@ -228,6 +228,9 @@ ############################################################################### # Hybridize the model # ############################################################################### + if (mx.__version__ > '1.7.0'): + os.environ['MXNET_ENABLE_CUDA_GRAPHS'] = '1' + log.info('CUDA Graphs enabled ') export_ctx = mx.cpu() seq_length = args.seq_length do_lower_case = 'uncased' in args.dataset_name @@ -320,19 +323,7 @@ def export(prefix): arg_array['data1'] = mx.nd.ones((test_batch_size, seq_length), dtype='float32') arg_array['data2'] = mx.nd.ones((test_batch_size, ), dtype='float32') custom_sym = sym.optimize_for('custom_pass', arg_array, aux_params) - - is_mxnet_1_7 = (mx.__version__ == '1.7.0') - '''if not is_mxnet_1_7: - shape_dict = {'data0': (test_batch_size, seq_length), - 'data1': (test_batch_size, seq_length), - 'data2': (test_batch_size, )} - type_dict = {'data0': 'float32', - 'data1': 'float32', - 'data2': 'float32'} - custom_sym = sym.optimize_for('custom_pass', arg_params, - aux_params, shape_dict=shape_dict, - type_dict=type_dict)''' - if is_mxnet_1_7: + if (mx.__version__ <= '1.7.0'): nheads = 12 if args.bert_model == 'bert_24_1024_16': nheads = 24 diff --git a/scripts/bert/setup.py b/scripts/bert/setup.py index 5f601790c2..e688342109 100644 --- a/scripts/bert/setup.py +++ b/scripts/bert/setup.py @@ -27,13 +27,10 @@ def CompileBERTCustomPass(): source = os.path.join(pass_path, input_pass_file) target = os.path.join(pass_path, out_lib_file) lib_api_cc = pathlib.Path.joinpath(mxnet_path, 'src/lib_api.cc') - is_mxnet_1_7 = (mxnet.__version__ == '1.7.0') - if is_mxnet_1_7: - os.system('g++ -shared -fPIC -std=c++11 ' + str(source) + - ' -o ' + str(target) + ' -I ' + str(mxnet_include_path) + ' -DMXNET_1_7') - else: - os.system('g++ -shared -fPIC -std=c++11 ' + str(source) + ' ' + str(lib_api_cc) + - ' -o ' + str(target) + ' -I ' + str(mxnet_include_path)) + if (mxnet.__version__ > '1.7.0'): + source = source + ' ' + str(lib_api_cc) + os.system('g++ -shared -fPIC -std=c++11 ' + str(source) + + ' -o ' + str(target) + ' -I ' + str(mxnet_include_path)) class CompileBERTPass(install): def run(self): From 17446f42778a49a910d6a7323bf9c38adb50c261 Mon 
Sep 17 00:00:00 2001 From: root Date: Thu, 4 Feb 2021 21:20:21 +0000 Subject: [PATCH 3/9] debugging lib_api.cc path --- scripts/bert/setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/bert/setup.py b/scripts/bert/setup.py index e688342109..3507afdada 100644 --- a/scripts/bert/setup.py +++ b/scripts/bert/setup.py @@ -29,6 +29,8 @@ def CompileBERTCustomPass(): lib_api_cc = pathlib.Path.joinpath(mxnet_path, 'src/lib_api.cc') if (mxnet.__version__ > '1.7.0'): source = source + ' ' + str(lib_api_cc) + print("MXNET ver: ", mxnet.__version__, 'source:', source, 'mxnet_include_path', mxnet_include_path) + print ("lib_api_cc Exist:"+str(os.path.exists(lib_api_cc))) os.system('g++ -shared -fPIC -std=c++11 ' + str(source) + ' -o ' + str(target) + ' -I ' + str(mxnet_include_path)) From 656145e922906692b7c18c82a9ddb0d561138faf Mon Sep 17 00:00:00 2001 From: root Date: Thu, 4 Feb 2021 21:29:29 +0000 Subject: [PATCH 4/9] fix lint --- scripts/bert/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/bert/setup.py b/scripts/bert/setup.py index 3507afdada..b176d0d644 100644 --- a/scripts/bert/setup.py +++ b/scripts/bert/setup.py @@ -30,7 +30,7 @@ def CompileBERTCustomPass(): if (mxnet.__version__ > '1.7.0'): source = source + ' ' + str(lib_api_cc) print("MXNET ver: ", mxnet.__version__, 'source:', source, 'mxnet_include_path', mxnet_include_path) - print ("lib_api_cc Exist:"+str(os.path.exists(lib_api_cc))) + print("lib_api_cc Exist:"+str(os.path.exists(lib_api_cc))) os.system('g++ -shared -fPIC -std=c++11 ' + str(source) + ' -o ' + str(target) + ' -I ' + str(mxnet_include_path)) From 2ebf53fba25909816f73474710252f73baca53d9 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 4 Feb 2021 21:53:32 +0000 Subject: [PATCH 5/9] fix lint --- scripts/bert/setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/bert/setup.py b/scripts/bert/setup.py index b176d0d644..b274bcb0e4 100644 --- a/scripts/bert/setup.py +++ b/scripts/bert/setup.py @@ -29,8 +29,8 @@ def CompileBERTCustomPass(): lib_api_cc = pathlib.Path.joinpath(mxnet_path, 'src/lib_api.cc') if (mxnet.__version__ > '1.7.0'): source = source + ' ' + str(lib_api_cc) - print("MXNET ver: ", mxnet.__version__, 'source:', source, 'mxnet_include_path', mxnet_include_path) - print("lib_api_cc Exist:"+str(os.path.exists(lib_api_cc))) + print('MXNET ver: ', mxnet.__version__, 'source:', source, 'mxnet_include_path', mxnet_include_path) + print('lib_api_cc Exist:' + str(os.path.exists(lib_api_cc))) os.system('g++ -shared -fPIC -std=c++11 ' + str(source) + ' -o ' + str(target) + ' -I ' + str(mxnet_include_path)) From 2086998fd8e2debe02a61e99f797dbffbd6613f5 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 4 Feb 2021 21:59:12 +0000 Subject: [PATCH 6/9] fix lint --- scripts/bert/setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/bert/setup.py b/scripts/bert/setup.py index b274bcb0e4..2a168269ee 100644 --- a/scripts/bert/setup.py +++ b/scripts/bert/setup.py @@ -29,7 +29,8 @@ def CompileBERTCustomPass(): lib_api_cc = pathlib.Path.joinpath(mxnet_path, 'src/lib_api.cc') if (mxnet.__version__ > '1.7.0'): source = source + ' ' + str(lib_api_cc) - print('MXNET ver: ', mxnet.__version__, 'source:', source, 'mxnet_include_path', mxnet_include_path) + print('MXNET ver: ', mxnet.__version__, 'source:', source, 'mxnet_include_path', + mxnet_include_path) print('lib_api_cc Exist:' + str(os.path.exists(lib_api_cc))) os.system('g++ -shared -fPIC -std=c++11 ' + str(source) + ' -o ' + 
str(target) + ' -I ' + str(mxnet_include_path)) From 7cea97c9f4c9c8e71f88cf29e5f548197fcc80a2 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 5 Feb 2021 00:12:11 +0000 Subject: [PATCH 7/9] debugging lib_api.cc --- scripts/bert/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/bert/setup.py b/scripts/bert/setup.py index 2a168269ee..23696e24d6 100644 --- a/scripts/bert/setup.py +++ b/scripts/bert/setup.py @@ -31,7 +31,7 @@ def CompileBERTCustomPass(): source = source + ' ' + str(lib_api_cc) print('MXNET ver: ', mxnet.__version__, 'source:', source, 'mxnet_include_path', mxnet_include_path) - print('lib_api_cc Exist:' + str(os.path.exists(lib_api_cc))) + print('lib_api_cc Exist:' + str(os.path.exists(str(lib_api_cc)))) os.system('g++ -shared -fPIC -std=c++11 ' + str(source) + ' -o ' + str(target) + ' -I ' + str(mxnet_include_path)) From 3826e0e7188676ab733307121f1ac81600d8f70c Mon Sep 17 00:00:00 2001 From: root Date: Fri, 5 Feb 2021 17:43:48 +0000 Subject: [PATCH 8/9] include lib_api.cc within GluonNLP --- scripts/bert/lib_api.cc | 1624 +++++++++++++++++++++++++++++++++++++++ scripts/bert/setup.py | 9 +- 2 files changed, 1629 insertions(+), 4 deletions(-) create mode 100644 scripts/bert/lib_api.cc diff --git a/scripts/bert/lib_api.cc b/scripts/bert/lib_api.cc new file mode 100644 index 0000000000..c273678dcd --- /dev/null +++ b/scripts/bert/lib_api.cc @@ -0,0 +1,1624 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2019 by Contributors + * \file lib_api.cc + * \brief APIs to interact with libraries + * This API specifies function prototypes to + * register custom ops, partitioner, and passes + * for library authors + * See example/extension/lib_custom_op/README.md + * See example/extension/lib_subgraph/README.md + * See example/extension/lib_pass/README.md + */ + +#include + +mxnet::ext::MXerrorMsgs::~MXerrorMsgs() { + for (auto &ss : messages) + delete ss; +} + +mxnet::ext::MXerrorMsgs* mxnet::ext::MXerrorMsgs::get() { + static MXerrorMsgs inst; + return &inst; + } + +std::stringstream& mxnet::ext::MXerrorMsgs::add(const char* file, int line) { + messages.push_back(new std::stringstream()); + *messages.back() << file << "[" << line << "]: "; + return *messages.back(); +} + +int mxnet::ext::MXerrorMsgs::size() { + return messages.size(); +} + +const std::string* mxnet::ext::MXerrorMsgs::get(int idx) { + return new std::string(messages.at(idx)->str()); +} + +mxnet::ext::MXContext::MXContext() : dev_type("error"), dev_id(-1) {} + +mxnet::ext::MXContext::MXContext(std::string dev_type_, int dev_id_) + : dev_type(std::move(dev_type_)), dev_id(dev_id_) {} + +mxnet::ext::MXContext::MXContext(const char* dev_type_, int dev_id_) + : dev_type(dev_type_), dev_id(dev_id_) {} + +mxnet::ext::MXContext mxnet::ext::MXContext::CPU() { return MXContext("cpu", 0); } + +mxnet::ext::MXContext mxnet::ext::MXContext::GPU() { return MXContext("gpu", 0); } + +mxnet::ext::MXContext mxnet::ext::MXContext::CPU(int dev_id) { return MXContext("cpu", dev_id); } + +mxnet::ext::MXContext mxnet::ext::MXContext::GPU(int dev_id) { return MXContext("gpu", dev_id); } + +void mxnet::ext::MXSparse::set(void *data_ptr, const int64_t* dims, int ndims, void *idx, + int64_t num_idx, void *idx_ptr, int64_t num_idx_ptr) { + data = data_ptr; + // If CSR, num of non-zero elemets is num_idx, + // If row sparse, num of elements is num_idx * width. 
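+  // NOTE (illustrative): for CSR, idx_ptr is non-null and data_len stays
+  // num_idx (one stored value per index); for row-sparse it is null and the
+  // loop below multiplies in the trailing dims, e.g. 3 stored rows of a
+  // (N, 4) tensor give data_len = 3 * 4 = 12.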
+ data_len = num_idx; + if (!idx_ptr) { + for (int i = 1; i < ndims; ++i) + data_len *= dims[i]; + } + + indices = reinterpret_cast(idx); + indices_len = num_idx; + + if (idx_ptr) { + indptr = reinterpret_cast(idx_ptr); + indptr_len = num_idx_ptr; + } +} + +mxnet::ext::MXTensor::MXTensor() : data_ptr(nullptr), dtype(kUNSET), verID(0), + stype(kDefaultStorage) {} +mxnet::ext::MXTensor::MXTensor(const MXTensor& oth) : data_ptr(oth.data_ptr), shape(oth.shape), + dtype(oth.dtype), verID(oth.verID), + ctx(oth.ctx), stype(oth.stype) { + setDLTensor(); +} + +mxnet::ext::MXTensor::MXTensor(void *data_ptr, std::vector shape, MXDType dtype, + size_t vID, MXContext mx_ctx, MXStorageType stype) + : data_ptr(data_ptr), shape(std::move(shape)), dtype(dtype), verID(vID), ctx(std::move(mx_ctx)), + stype(stype) { + setDLTensor(); +} + +void mxnet::ext::MXTensor::setTensor(void *dptr, MXDType type, const int64_t* dims, int ndims, + size_t vID, MXContext mx_ctx, MXStorageType storage_type) { + data_ptr = dptr; dtype = type; verID = vID; ctx = mx_ctx; stype = storage_type; + shape.clear(); + for (int j = 0; j < ndims; j++) { + shape.push_back(dims[j]); + } + setDLTensor(); +} + +void mxnet::ext::MXTensor::setDLTensor() { + dltensor.data = data_ptr; + dltensor.ndim = shape.size(); + dltensor.shape = const_cast(shape.data()); + dltensor.strides = nullptr; + dltensor.byte_offset = 0; + dltensor.dtype.lanes = 1; + dltensor.ctx.device_id = ctx.dev_id; + if (ctx.dev_type == "cpu") + dltensor.ctx.device_type = kDLCPU; + else if (ctx.dev_type == "gpu") + dltensor.ctx.device_type = kDLGPU; + else if (ctx.dev_type == "opencl") + dltensor.ctx.device_type = kDLOpenCL; + else if (ctx.dev_type == "vulcan") + dltensor.ctx.device_type = kDLVulkan; + else if (ctx.dev_type == "metal") + dltensor.ctx.device_type = kDLMetal; + else if (ctx.dev_type == "vpi") + dltensor.ctx.device_type = kDLVPI; + else if (ctx.dev_type == "rocm") + dltensor.ctx.device_type = kDLROCM; + else + dltensor.ctx.device_type = kDLExtDev; + switch (dtype) { + case kFloat32: + dltensor.dtype.code = kDLFloat; + dltensor.dtype.bits = 32; + break; + case kFloat64: + dltensor.dtype.code = kDLFloat; + dltensor.dtype.bits = 64; + break; + case kFloat16: + dltensor.dtype.code = kDLFloat; + dltensor.dtype.bits = 16; + break; + case kUint8: + dltensor.dtype.code = kDLUInt; + dltensor.dtype.bits = 8; + break; + case kInt32: + dltensor.dtype.code = kDLInt; + dltensor.dtype.bits = 32; + break; + case kInt8: + dltensor.dtype.code = kDLInt; + dltensor.dtype.bits = 8; + break; + case kInt64: + dltensor.dtype.code = kDLInt; + dltensor.dtype.bits = 64; + break; + default: + dltensor.dtype.code = 0; + dltensor.dtype.bits = 0; + throw std::runtime_error("Error! 
Invalid dtype flag: " + + std::to_string(static_cast(dtype)) + + " when constructing MXTensor"); + } +} + +int64_t mxnet::ext::MXTensor::size() const { + int64_t size = 1; + for (auto &s : shape) + size *= s; + return size; +} + +bool mxnet::ext::MXTensor::isSame(const MXTensor &oth) const { + return data_ptr == oth.data_ptr && + dtype == oth.dtype && + verID == oth.verID && + ctx.dev_type == oth.ctx.dev_type && + ctx.dev_id == oth.ctx.dev_id && + shape == oth.shape && + stype == oth.stype; +} + +mxnet::ext::PassResource::PassResource(std::unordered_map* new_args, + std::unordered_map* new_aux, + nd_malloc_t nd_malloc, const void* nd_alloc) + : new_args_(new_args), new_aux_(new_aux), nd_malloc_(nd_malloc), nd_alloc_(nd_alloc) {} + +mxnet::ext::MXTensor* mxnet::ext::PassResource::alloc_arg(const std::string& name, + const std::vector& shapes, + const mxnet::ext::MXContext &ctx, + mxnet::ext::MXDType dtype) const { + void* data; + nd_malloc_(nd_alloc_, shapes.data(), shapes.size(), ctx.dev_type.c_str(), ctx.dev_id, + dtype, name.c_str(), 1, &data); + MXTensor tensor(data, shapes, dtype, 0, ctx, kDefaultStorage); + (*new_args_)[name] = tensor; + return &(new_args_->at(name)); +} + +mxnet::ext::MXTensor* mxnet::ext::PassResource::alloc_aux(const std::string& name, + const std::vector& shapes, + const mxnet::ext::MXContext &ctx, + mxnet::ext::MXDType dtype) const { + void* data; + nd_malloc_(nd_alloc_, shapes.data(), shapes.size(), ctx.dev_type.c_str(), ctx.dev_id, + dtype, name.c_str(), 0, &data); + MXTensor tensor(data, shapes, dtype, 0, ctx, kDefaultStorage); + (*new_aux_)[name] = tensor; + return &(new_aux_->at(name)); +} + +mxnet::ext::OpResource::OpResource(xpu_malloc_t cpu_malloc_fp, void* cpu_alloc_fp, + xpu_malloc_t gpu_malloc_fp, void* gpu_alloc_fp, void* stream, + sparse_malloc_t sparse_malloc_fp, void* sparse_alloc_fp, + void* rng_cpu_states, void* rng_gpu_states) + : cpu_malloc(cpu_malloc_fp), gpu_malloc(gpu_malloc_fp), + cpu_alloc(cpu_alloc_fp), gpu_alloc(gpu_alloc_fp), cuda_stream(stream), + sparse_malloc(sparse_malloc_fp), sparse_alloc(sparse_alloc_fp), + rand_cpu_states(rng_cpu_states), rand_gpu_states(rng_gpu_states) {} + +void* mxnet::ext::OpResource::alloc_cpu(int size) const { + return cpu_malloc(cpu_alloc, size); +} + +void* mxnet::ext::OpResource::alloc_gpu(int size) const { + return gpu_malloc(gpu_alloc, size); +} + +void mxnet::ext::OpResource::alloc_sparse(mxnet::ext::MXSparse* sparse, int index, + int indices_len, int indptr_len) const { + sparse_malloc(sparse_alloc, index, indices_len, indptr_len, + &(sparse->data), &(sparse->indices), &(sparse->indptr)); +} + +mxnet::ext::mx_cpu_rand_t* mxnet::ext::OpResource::get_cpu_rand_states() const { + return static_cast(rand_cpu_states); +} + +std::string mxnet::ext::getShapeAt(const std::string& shape, unsigned index) { + int idx = 1; // start at 1 to skip the first square bracket [ + // find the beginning of the output shape for the particular output index + for (unsigned x=0; x < index; x++) + idx = shape.find("[", idx+1); + int stop = shape.find("]", idx); // find stop index for this output shape + // add this shape to the list + return shape.substr(idx, stop-idx+1); +} + +std::string mxnet::ext::getDtypeAt(const std::string& dtype, unsigned index) { + // find the beginning of the output dtype for the particular output index + int idx = 0; + for (unsigned x=0; x < index; x++) + idx = dtype.find(",", idx+1); + int stop = dtype.find(",", idx+1); // find stop index for this output dtype + if (stop == -1) stop = 
dtype.find("]", idx+1); + return dtype.substr(idx+1, stop-idx-1); +} + +mxnet::ext::JsonVal::JsonVal() : type(ERR), num(-1), str("") {} +mxnet::ext::JsonVal::JsonVal(mxnet::ext::JsonType t) : type(t), num(-1), str("") {} +mxnet::ext::JsonVal::JsonVal(std::string s) : type(STR), num(-1), str(std::move(s)) {} +mxnet::ext::JsonVal::JsonVal(int n) : type(NUM), num(n), str(std::to_string(n)) {} +mxnet::ext::JsonVal::JsonVal(JsonType t, int n, std::string s) : type(t), num(n), + str(std::move(s)) {} + +bool mxnet::ext::JsonVal::operator<(const mxnet::ext::JsonVal &o) const { + // for string JSON objects compare the string + if (type == STR) return type == o.type && str < o.str; + // for number JSON objects compare the number + if (type == NUM) return type == o.type && num < o.num; + // for list JSON objects, compare the size of list, and then each object in the list + if (type == LIST) { + if (list.size() != o.list.size()) return false; + for (unsigned int i=0; i< list.size(); i++) + if (list[i] < o.list[i]) + return false; // if we find an object that doesnt match return + return true; // all objects in lists matched + } + // for map JSON objects, compare the size of map, and then each key/value in the maps + if (type == MAP) { + if (map.size() != o.map.size()) return false; + for (auto &item : map) { + // if one map is missing a key in another return + if (o.map.find(item.first) == o.map.end()) return false; + if (item.second < o.map.at(item.first)) return false; + } + return true; + } + return type < o.type; +} + +std::string mxnet::ext::JsonVal::dump() const { + std::string ret; + switch (type) { + case ERR: + ret = "json(Error)"; + break; + case STR: + ret = "\"" + str + "\""; + break; + case NUM: + ret = str; + break; + case LIST: + ret = "["; + for (unsigned i=0; i < list.size(); i++) { + auto &item = list[i]; + ret += item.dump(); + if (i < list.size()-1) + ret += ","; + } + ret += "]"; + break; + case MAP: + ret = "{"; + unsigned cnt = 0; + for (auto &item : map) { + ret += item.first.dump() + " : " + item.second.dump(); + if (cnt++ < map.size()-1) + ret += ","; + } + ret += "}"; + break; + } + return ret; +} + +mxnet::ext::JsonVal mxnet::ext::JsonVal::parse(const std::string& json) { + unsigned int idx = 0; + return JsonVal::parse(json, &idx); +} + +mxnet::ext::JsonVal mxnet::ext::JsonVal::parse_string(const std::string& json, unsigned int* idx) { + JsonVal ret(STR); + while (*idx < json.size()) { + if (json[*idx] == '"' && (ret.str.size() == 0 || + (ret.str.size() > 0 && ret.str.back() != '\\'))) { + ++(*idx); + return ret; + } else { + ret.str += json[*idx]; + ++(*idx); + } + } + MX_ERROR_MSG << "Error! Unable to parse string: '" << json.substr(*idx) << "'" << std::endl; + return JsonVal(); +} + +mxnet::ext::JsonVal mxnet::ext::JsonVal::parse_num(const std::string& json, unsigned int* idx) { + JsonVal ret(NUM); + while (*idx < json.size()) { + if (json[*idx] >= '0' && json[*idx] <= '9') { + ret.str += json[*idx]; + ++(*idx); + } else { + break; + } + } + ret.num = std::stoi(ret.str); + return ret; +} + +mxnet::ext::JsonVal mxnet::ext::JsonVal::parse_list(const std::string& json, unsigned int* idx) { + JsonVal ret(LIST); + while (*idx < json.size()) { + if (json[*idx] == ']') { + ++(*idx); + return ret; + } else { + JsonVal item = JsonVal::parse(json, idx); + if (item.type != ERR) + ret.list.push_back(item); + } + } + MX_ERROR_MSG << "Error! 
Unable to parse list: '" << json.substr(*idx) << "'" << std::endl; + return JsonVal(); +} + +mxnet::ext::JsonVal mxnet::ext::JsonVal::parse_map(const std::string& json, unsigned int* idx) { + JsonVal ret(MAP), key; + while (*idx < json.size()) { + if (json[*idx] == '}') { + ++(*idx); + return ret; + } else { + JsonVal item = JsonVal::parse(json, idx); + if (key.type == ERR) { + key = item; + } else { + ret.map[key] = item; + key.type = ERR; + } + } + } + MX_ERROR_MSG << "Error! Unable to parse map: '" << json.substr(*idx) << "'" << std::endl; + return mxnet::ext::JsonVal(); +} + +mxnet::ext::JsonVal mxnet::ext::JsonVal::parse(const std::string& json, unsigned int *idx) { + JsonVal ret; + while (*idx < json.size()) { + if (json[*idx] == '"') { + ++(*idx); + ret = JsonVal::parse_string(json, idx); + } else if (json[*idx] >= '0' && json[*idx] <= '9') { + ret = JsonVal::parse_num(json, idx); + } else if (json[*idx] == '[') { + ++(*idx); + ret = JsonVal::parse_list(json, idx); + } else if (json[*idx] == '{') { + ++(*idx); + ret = JsonVal::parse_map(json, idx); + } else if (json[*idx] == ']' || json[*idx] == '}') {return ret;} + if (ret.type != ERR) return ret; + ++(*idx); + } + return ret; +} + +std::string mxnet::ext::JsonVal::toString() const { + std::string ret; + switch (type) { + case ERR: + ret = "json(Error)"; + break; + case STR: + ret = "json(STR:" + str + ")"; + break; + case NUM: + ret = "json(INT:" + str + ")"; + break; + case LIST: + ret = "json(LIST:["; + for (auto &item : list) + ret += item.toString() + ","; + ret += "])"; + break; + case MAP: + ret = "json(MAP:{"; + for (auto &item : map) + ret += item.first.toString() + " : " + item.second.toString() + ","; + ret += "})"; + break; + } + return ret; +} + +mxnet::ext::Node::Node() {tensor = nullptr;} + +void mxnet::ext::Node::_setPassResource(mxnet::ext::PassResource* res_) {res = res_;} + +void mxnet::ext::Node::alloc_arg(const std::vector& shapes, + const mxnet::ext::MXContext &ctx, mxnet::ext::MXDType dtype) { + if (!res) + throw std::runtime_error("Node not initialized. Cannot use alloc_arg outside of graph passes."); + tensor = res->alloc_arg(name, shapes, ctx, dtype); +} + +void mxnet::ext::Node::alloc_aux(const std::vector& shapes, + const mxnet::ext::MXContext &ctx, mxnet::ext::MXDType dtype) { + if (!res) + throw std::runtime_error("Node not initialized. 
Cannot use alloc_aux outside of graph passes."); + tensor = res->alloc_aux(name, shapes, ctx, dtype); +} + +mxnet::ext::Graph::Graph() : res(nullptr) {} + +mxnet::ext::Graph::~Graph() { + for (auto &node : nodes) + delete node; +} + +mxnet::ext::Graph* mxnet::ext::Graph::fromString(const std::string& json) { + JsonVal val = JsonVal::parse(json); + return fromJson(val); +} + +mxnet::ext::Graph* mxnet::ext::Graph::fromJson(mxnet::ext::JsonVal val) { + // get nodes list + JsonVal nodes = val.map[JsonVal("nodes")]; + Graph *g = new Graph(); + + std::map nodeMap; + // loop over nodes + for (size_t i = 0; i < nodes.list.size(); i++) { + Node* n = new Node(); + g->nodes.push_back(n); + JsonVal node = nodes.list[i]; + + // set the op info + n->op = node.map[JsonVal("op")].str; + n->name = node.map[JsonVal("name")].str; + + // if op is null it is an input to the graph + if (n->op.compare("null") == 0) + g->inputs.push_back(n); + + // set attrs + JsonVal attributes = node.map[JsonVal("attrs")]; + for (auto& kv : attributes.map) { + n->attrs[kv.first.str] = kv.second.str; + } + + // set subgraphs, parsing each into a graph + if (node.map.count(JsonVal("subgraphs")) > 0) { + JsonVal subgraphs = node.map[JsonVal("subgraphs")]; + for (auto &subgraph : subgraphs.list) { + n->subgraphs.push_back(fromJson(subgraph)); + } + } + + // set node inputs + JsonVal node_inputs = node.map[JsonVal("inputs")]; + n->inputs.resize(node_inputs.list.size()); + for (size_t j = 0; j < node_inputs.list.size(); j++) { + JsonVal input = node_inputs.list[j]; + NodeEntry& entry = n->inputs[j]; + // get pointer to other node + entry.node = nodeMap[input.list[0].num]; + // get the other node's output index + entry.entry = input.list[1].num; + // set other nodes output as connected to this node + entry.node->outputs.push_back({n, static_cast(j)}); + } + nodeMap[i] = n; + } + + // set graph level outputs + JsonVal& heads = val.map[JsonVal("heads")]; + g->outputs.resize(heads.list.size()); + for (size_t i = 0; i < heads.list.size(); i++) { + JsonVal head = heads.list[i]; + g->outputs[i].node = nodeMap[head.list[0].num]; + g->outputs[i].entry = head.list[1].num; + } + + // add all attributes to the graph + for (auto& kv : val.map) { + if (kv.first.str.compare("nodes") != 0 && + kv.first.str.compare("heads") != 0 && + kv.first.str.compare("node_row_ptr") != 0 && + kv.first.str.compare("arg_nodes") != 0) { + g->attrs[kv.first.str] = kv.second; + } + } + return g; +} + +/* \brief convert graph object back to JSON object */ +mxnet::ext::JsonVal mxnet::ext::Graph::toJson() const { + // top level object is a map + JsonVal val(MAP); + + // add attributes + for (auto& kv : attrs) { + val.map[JsonVal(kv.first)] = kv.second; + } + + // sort graph nodes in topological order, create mapping of node to index + std::map nodeMap; + std::vector sorted = topological_sort(); + // nodes are in reverse topological order in the vector (back is first) + // so loop from end to front over the vector 'sorted' + for (int i = sorted.size()-1; i >= 0; i--) { + nodeMap[sorted[i]] = sorted.size()-1-i; + } + + // create node_row_ptr entry + val.map[JsonVal("node_row_ptr")] = JsonVal(LIST); + JsonVal& node_row_ptr = val.map[JsonVal("node_row_ptr")]; + for (size_t i = 0; i < nodes.size(); i++) + node_row_ptr.list.emplace_back(i); + + // add all input nodes + val.map[JsonVal("arg_nodes")] = JsonVal(LIST); + JsonVal& arg_nodes = val.map[JsonVal("arg_nodes")]; + for (auto &input : inputs) + arg_nodes.list.emplace_back(nodeMap[input]); + + // add all output nodes + 
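+  // NOTE (illustrative): each "heads" entry is a [node_id, output_index,
+  // version] triple, e.g. a single head [[42, 0, 0]] marks output 0 of
+  // node 42 as the graph output; fromJson() above reads the triple back
+  // the same way via head.list[0] and head.list[1].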
val.map[JsonVal("heads")] = JsonVal(LIST); + JsonVal& heads = val.map[JsonVal("heads")]; + for (size_t i = 0; i < outputs.size(); i++) { + heads.list.emplace_back(LIST); + JsonVal& out = heads.list[i]; + out.list.emplace_back(nodeMap[outputs[i].node]); + out.list.emplace_back(outputs[i].entry); + out.list.emplace_back(0); + } + + // add all graph nodes + val.map[JsonVal("nodes")] = JsonVal(LIST); + JsonVal& nodes_ = val.map[JsonVal("nodes")]; + for (int i = sorted.size()-1; i >= 0; i--) { + // each node is a map + nodes_.list.emplace_back(MAP); + Node* n = sorted[i]; + JsonVal& n_ = nodes_.list[nodes_.list.size()-1]; + + n_.map[JsonVal("op")] = JsonVal(n->op); + n_.map[JsonVal("name")] = JsonVal(n->name); + n_.map[JsonVal("inputs")] = JsonVal(LIST); + + // add inputs for this node + JsonVal& inputs_ = n_.map[JsonVal("inputs")]; + for (size_t j = 0; j < n->inputs.size(); j++) { + inputs_.list.emplace_back(LIST); + NodeEntry& entry = n->inputs[j]; + JsonVal& in = inputs_.list[j]; + in.list.emplace_back(nodeMap[entry.node]); + in.list.emplace_back(entry.entry); + in.list.emplace_back(0); + } + + // add subgraphs for this node, convert each back to JSON + if (n->subgraphs.size() > 0) { + n_.map[JsonVal("subgraphs")] = JsonVal(LIST); + JsonVal &subgraphs_ = n_.map[JsonVal("subgraphs")]; + for (Graph *subgraph : n->subgraphs) { + subgraphs_.list.push_back(subgraph->toJson()); + } + } + + // add attributes for this node + n_.map[JsonVal("attrs")] = JsonVal(MAP); + JsonVal& attrs_ = n_.map[JsonVal("attrs")]; + for (auto& kv : n->attrs) { + attrs_.map[JsonVal(kv.first)] = JsonVal(kv.second); + } + } + return val; +} + +/* \brief convert graph object to JSON string */ +std::string mxnet::ext::Graph::toString() const { + return toJson().dump(); +} + + /* \brief visits a node "n" */ +void mxnet::ext::Graph::_dfs_util(Node* n, std::unordered_set* to_visit, + std::function handler) const { + to_visit->erase(n); // remove node now that we're visiting it + for (NodeEntry& e : n->outputs) { + Node* o = e.node; + if (to_visit->count(o) != 0) { + _dfs_util(o, to_visit, handler); // visit neighbor + } + } + handler(n); // post-order visit this node +} + +/* \brief post-order DFS graph traversal */ +void mxnet::ext::Graph::DFS(std::function handler) const { + std::unordered_set to_visit; + // put all nodes in set to visit + for (auto& n : nodes) + to_visit.insert(n); + // visit all inputs first + for (auto& i : inputs) + if (to_visit.count(i) != 0) + _dfs_util(i, &to_visit, handler); + // visit any nodes left + while (to_visit.size() > 0) + _dfs_util(*(to_visit.begin()), &to_visit, handler); +} + +/* \brief sort graph nodes in topological order */ +std::vector mxnet::ext::Graph::topological_sort() const { + std::vector sorted; + auto handler = [&](mxnet::ext::Node* n) { + sorted.push_back(n); // when visiting each node, add it in order to the vector + }; + DFS(handler); + return sorted; +} + +/* \brief print out graph details */ +void mxnet::ext::Graph::print(int indent) const { + std::string space = ""; + for (int i = 0; i < indent; i++) space+=" "; + + std::cout << space << "########### Graph #############" << std::endl; + std::cout << space << "attributes: " << std::endl; + for (auto &kv : attrs) + std::cout << space << "\t" << kv.first << " : " << kv.second.str << std::endl; + std::cout << space << "inputs: " << inputs.size() << std::endl; + std::cout << space << "outputs: " << outputs.size() << std::endl; + std::cout << space << "nodes: " << nodes.size() << std::endl; + std::vector sorted = 
topological_sort(); + // loop over each node and print out its inputs/outputs + for (int i = sorted.size()-1; i >= 0; i--) { + std::cout << space << "Node: " << sorted[i]->name << std::endl; + for (auto &input : sorted[i]->inputs) { + std::cout << space << "\tInput: " << input.node->name << " " + << input.entry << std::endl; + } + for (auto &output : sorted[i]->outputs) { + std::cout << space << "\tOutput: " << output.node->name << " " + << output.entry << std::endl; + } + if (sorted[i]->subgraphs.size() > 0) { + for (auto &subgraph : sorted[i]->subgraphs) { + std::cout << space << "\tSubgraph:" << std::endl; + subgraph->print(indent+2); + } + } + } + std::cout << space << "###############################" << std::endl; +} + +/* \brief add a new node to this graph */ +mxnet::ext::Node* mxnet::ext::Graph::addNode(const std::string& name, const std::string& op) { + Node* n = new Node(); + nodes.push_back(n); + n->name = name; + n->op = op; + if (res) + n->_setPassResource(res); + return n; +} + +/* \brief get node at index in graph */ +mxnet::ext::Node* mxnet::ext::Graph::getNode(size_t idx) { + return nodes[idx]; +} + +/* \brief get const node at index in const graph */ +const mxnet::ext::Node* mxnet::ext::Graph::getNode(size_t idx) const { + return nodes.at(idx); +} + +/* \brief get attribute on graph */ +const mxnet::ext::JsonVal& mxnet::ext::Graph::getAttr(const std::string& key) const { + return attrs.at(key); +} + +/* \brief get number of nodes in the graph */ +size_t mxnet::ext::Graph::size() const { + return nodes.size(); +} + +// internally set passResource to enable tensor allocation for graph passes +void mxnet::ext::Graph::_setPassResource(PassResource* res_) { + res = res_; + // set passResource for each node + for (Node* node : nodes) { + node->_setPassResource(res); + } +} + +// internally set arg/aux params when available +void mxnet::ext::Graph::_setParams(std::unordered_map* args, + std::unordered_map* aux) { + // set params for each input node + for (Node* node : inputs) { + std::string name = node->name; + if (node->attrs.count("isArg") > 0 && node->attrs["isArg"].compare("True") == 0) + // mapping name back to original node name from subgraph input name + name = node->attrs["argName"]; + if (args->count(name) > 0) + node->tensor = &args->at(name); + else if (aux->count(name) > 0) + node->tensor = &aux->at(name); + } +} + +mxnet::ext::CustomOp::CustomOp(const char* op_name) + : name(op_name), parse_attrs(nullptr), infer_type(nullptr), infer_storage_type(nullptr), + infer_shape(nullptr), mutate_inputs(nullptr), isSGop(false) {} + +mxnet::ext::CustomOp& mxnet::ext::CustomOp::setForward(mxnet::ext::fcomp_t fcomp, const char* ctx) { + if (forward_ctx_map.count(ctx) > 0) + raiseDuplicateContextError(); + forward_ctx_map[ctx] = fcomp; + return *this; +} + +mxnet::ext::CustomOp& mxnet::ext::CustomOp::setBackward(mxnet::ext::fcomp_t fgrad, + const char* ctx) { + if (backward_ctx_map.count(ctx) > 0) + raiseDuplicateContextError(); + backward_ctx_map[ctx] = fgrad; + return *this; +} + +mxnet::ext::CustomOp& mxnet::ext::CustomOp::setParseAttrs(mxnet::ext::parseAttrs_t func) { + parse_attrs = func; + return *this; +} + +mxnet::ext::CustomOp& mxnet::ext::CustomOp::setInferType(mxnet::ext::inferType_t func) { + infer_type = func; + return *this; +} + +mxnet::ext::CustomOp& mxnet::ext::CustomOp::setInferSType(mxnet::ext::inferSType_t func) { + infer_storage_type = func; + return *this; +} + +mxnet::ext::CustomOp& mxnet::ext::CustomOp::setInferShape(mxnet::ext::inferShape_t func) { + 
+mxnet::ext::CustomOp& mxnet::ext::CustomOp::setInferShape(mxnet::ext::inferShape_t func) {
+  infer_shape = func;
+  return *this;
+}
+
+mxnet::ext::CustomOp& mxnet::ext::CustomOp::setMutateInputs(mxnet::ext::mutateInputs_t func) {
+  mutate_inputs = func;
+  return *this;
+}
+
+mxnet::ext::CustomOp& mxnet::ext::CustomOp::setCreateOpState(mxnet::ext::createOpState_t func,
+                                                             const char* ctx) {
+  if (create_op_ctx_map.count(ctx) > 0)
+    raiseDuplicateContextError();
+  create_op_ctx_map[ctx] = func;
+  return *this;
+}
+
+mxnet::ext::CustomOp& mxnet::ext::CustomOp::setIsSubgraphOp() {
+  isSGop = true;
+  return *this;
+}
+
+void mxnet::ext::CustomOp::mapToVector() {
+  for (auto kv : forward_ctx_map) {
+    forward_ctx_cstr.push_back(kv.first);
+    forward_fp.push_back(kv.second);
+  }
+  for (auto kv : backward_ctx_map) {
+    backward_ctx_cstr.push_back(kv.first);
+    backward_fp.push_back(kv.second);
+  }
+  for (auto kv : create_op_ctx_map) {
+    create_op_ctx_cstr.push_back(kv.first);
+    create_op_fp.push_back(kv.second);
+  }
+}
+
+void mxnet::ext::CustomOp::raiseDuplicateContextError() {
+  std::string op_name_str(name);
+  throw std::runtime_error(
+    "Error! Cannot register multiple functions under same context for operator '"
+    + op_name_str + "'");
+}
+
+mxnet::ext::CustomPass::CustomPass() : name("ERROR") {}
+mxnet::ext::CustomPass::CustomPass(const char* pass_name)
+  : name(pass_name) {}
+mxnet::ext::CustomPass& mxnet::ext::CustomPass::setBody(graphPass_t fn) {
+  pass = fn;
+  return *this;
+}
+
+mxnet::ext::CustomPartitioner::CustomPartitioner() : name("ERROR") {}
+mxnet::ext::CustomPartitioner::CustomPartitioner(const char* backend_name) :
+  name(backend_name) {}
+
+mxnet::ext::CustomPartitioner& mxnet::ext::CustomPartitioner::addStrategy(const char* prop_name,
+                                                                          const char* sg_name) {
+  strategies.push_back(prop_name);
+  op_names.push_back(sg_name);
+  return *this;
+}
+
+mxnet::ext::CustomPartitioner& mxnet::ext::CustomPartitioner::setSupportedOps(
+    const char* prop_name, mxnet::ext::supportedOps_t fn) {
+  supported_map[std::string(prop_name)] = fn;
+  return *this;
+}
+
+mxnet::ext::CustomPartitioner& mxnet::ext::CustomPartitioner::setCreateSelector(
+    const char* prop_name, mxnet::ext::createSelector_t fn) {
+  selector_map[std::string(prop_name)] = fn;
+  return *this;
+}
+
+mxnet::ext::CustomPartitioner& mxnet::ext::CustomPartitioner::setReviewSubgraph(
+    const char* prop_name, mxnet::ext::reviewSubgraph_t fn) {
+  review_map[std::string(prop_name)] = fn;
+  return *this;
+}
+
+mxnet::ext::supportedOps_t mxnet::ext::CustomPartitioner::getSupportedOps(int stg_id) {
+  std::string prop(strategies[stg_id]);
+  if (supported_map.count(prop) > 0)
+    return supported_map[prop];
+  else
+    return nullptr;
+}
+
+mxnet::ext::createSelector_t mxnet::ext::CustomPartitioner::getCreateSelector(int stg_id) {
+  std::string prop(strategies[stg_id]);
+  if (selector_map.count(prop) > 0)
+    return selector_map[prop];
+  else
+    return nullptr;
+}
+
+mxnet::ext::reviewSubgraph_t mxnet::ext::CustomPartitioner::getReviewSubgraph(int stg_id) {
+  std::string prop(strategies[stg_id]);
+  if (review_map.count(prop) > 0)
+    return review_map[prop];
+  else
+    return nullptr;
+}
+
+/*! \brief returns MXNet library version */
+MX_INT_RET _opVersion() {
+  return MX_LIBRARY_VERSION;
+}
+
+/*! \brief returns number of ops registered in this library */
+MX_INT_RET _opRegSize() {
+  return mxnet::ext::Registry<mxnet::ext::CustomOp>::get()->size();
+}
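+
+// The _op*/_part*/_pass*/_msg* entry points below are resolved by symbol name
+// (e.g. via dlsym) when MXNet loads the library, which is why they use
+// C-compatible MX_INT_RET/MX_VOID_RET signatures and raw pointers rather than
+// the C++ registration classes above.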
+/*! \brief returns operator registration at specified index */
+MX_VOID_RET _opRegGet(int idx, const char** name, int *isSGop,
+                      const char*** forward_ctx, mxnet::ext::fcomp_t** forward_fp,
+                      int* forward_count, const char*** backward_ctx,
+                      mxnet::ext::fcomp_t** backward_fp, int* backward_count,
+                      const char*** create_op_ctx, mxnet::ext::createOpState_t** create_op_fp,
+                      int* create_op_count, mxnet::ext::parseAttrs_t* parse,
+                      mxnet::ext::inferType_t* type, mxnet::ext::inferSType_t* stype,
+                      mxnet::ext::inferShape_t* shape, mxnet::ext::mutateInputs_t* mutate) {
+  mxnet::ext::CustomOp &op = mxnet::ext::Registry<mxnet::ext::CustomOp>::get()->get(idx);
+  *name = op.name;
+  *parse = op.parse_attrs;
+  *type = op.infer_type;
+  *stype = op.infer_storage_type;
+  *shape = op.infer_shape;
+  *mutate = op.mutate_inputs;
+  *isSGop = op.isSGop;
+  op.mapToVector();
+  *forward_ctx = op.forward_ctx_cstr.data();
+  *forward_fp = op.forward_fp.data();
+  *forward_count = op.forward_fp.size();
+  *backward_ctx = op.backward_ctx_cstr.data();
+  *backward_fp = op.backward_fp.data();
+  *backward_count = op.backward_fp.size();
+  *create_op_ctx = op.create_op_ctx_cstr.data();
+  *create_op_fp = op.create_op_fp.data();
+  *create_op_count = op.create_op_fp.size();
+}
+
+/*! \brief calls free from the external library for library allocated arrays */
+MX_VOID_RET _opCallFree(void* ptr) {
+  free(ptr);  // NOLINT
+}
+
+/*! \brief returns status of calling parse attributes function for operator from library */
+MX_INT_RET _opCallParseAttrs(mxnet::ext::parseAttrs_t parseAttrs, const char* const* keys,
+                             const char* const* vals, int num,
+                             int* num_in, int* num_out) {
+  // create map of attributes from list
+  std::unordered_map<std::string, std::string> attrs;
+  for (int i = 0; i < num; i++) {
+    attrs[std::string(keys[i])] = std::string(vals[i]);
+  }
+  return parseAttrs(attrs, num_in, num_out);
+}
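+
+// Sketch (illustrative, not part of this file) of the user-side function that
+// _opCallParseAttrs dispatches to:
+//
+//   mxnet::ext::MXReturnValue parseAttrs(
+//       const std::unordered_map<std::string, std::string>& attrs,
+//       int* num_in, int* num_out) {
+//     *num_in = 1;   // e.g. an elementwise op: one input,
+//     *num_out = 1;  // one output
+//     return MX_SUCCESS;
+//   }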
+/*! \brief returns status of calling inferShape function for operator from library */
+MX_INT_RET _opCallInferShape(mxnet::ext::inferShape_t inferShape, const char* const* keys,
+                             const char* const* vals, int num,
+                             unsigned int** inshapes, int* indims, int num_in,
+                             unsigned int*** mod_inshapes, int** mod_indims,
+                             unsigned int*** outshapes, int** outdims, int num_out) {
+  // create map of attributes from list
+  std::unordered_map<std::string, std::string> attrs;
+  for (int i = 0; i < num; i++) {
+    attrs[std::string(keys[i])] = std::string(vals[i]);
+  }
+
+  // create a vector of shapes for inputs
+  std::vector<std::vector<unsigned int> > in_shapes(num_in);
+  for (int i = 0; i < num_in; i++) {
+    for (int j = 0; j < indims[i]; j++) {
+      in_shapes[i].push_back(inshapes[i][j]);
+    }
+  }
+
+  // create a vector of shapes for outputs
+  std::vector<std::vector<unsigned int> > out_shapes(num_out);
+
+  int retval = inferShape(attrs, &in_shapes, &out_shapes);
+  if (!retval) return retval;
+
+  // allocate space for modified input dims, shape
+  *mod_indims = static_cast<int*>(malloc (num_in * sizeof(int)));  // NOLINT
+  *mod_inshapes = static_cast<unsigned**>(malloc (num_in * sizeof(unsigned*)));  // NOLINT
+
+  // copy modified input shapes
+  for (int i = 0; i < num_in; i++) {
+    (*mod_indims)[i] = in_shapes[i].size();
+    (*mod_inshapes)[i] = static_cast<unsigned*>(
+        malloc ((*mod_indims)[i] * sizeof(unsigned)));  // NOLINT
+    for (int j = 0; j < (*mod_indims)[i]; j++) {
+      (*mod_inshapes)[i][j] = in_shapes[i][j];
+    }
+  }
+
+  // allocate space for output dims, shape
+  *outdims = static_cast<int*>(malloc (num_out * sizeof(int)));  // NOLINT
+  *outshapes = static_cast<unsigned**>(malloc (num_out * sizeof(unsigned*)));  // NOLINT
+
+  // copy output shapes
+  for (int i = 0; i < num_out; i++) {
+    (*outdims)[i] = out_shapes[i].size();
+    (*outshapes)[i] = static_cast<unsigned*>(malloc ((*outdims)[i] * sizeof(unsigned)));  // NOLINT
+    for (int j = 0; j < (*outdims)[i]; j++) {
+      (*outshapes)[i][j] = out_shapes[i][j];
+    }
+  }
+  return retval;
+}
+
+/*! \brief returns status of calling inferType function for operator from library */
+MX_INT_RET _opCallInferType(mxnet::ext::inferType_t inferType, const char* const* keys,
+                            const char* const* vals, int num,
+                            int* intypes, int num_in, int* outtypes, int num_out) {
+  // create map of attributes from list
+  std::unordered_map<std::string, std::string> attrs;
+  for (int i = 0; i < num; i++) {
+    attrs[std::string(keys[i])] = std::string(vals[i]);
+  }
+
+  // create a vector of types for inputs
+  std::vector<int> in_types(num_in);
+  for (int i = 0; i < num_in; i++) {
+    in_types[i] = intypes[i];
+  }
+
+  // create a vector of types for outputs
+  std::vector<int> out_types(num_out, -1);
+
+  int retval = inferType(attrs, &in_types, &out_types);
+  if (!retval)
+    return retval;
+
+  // copy modified input types
+  for (int i = 0; i < num_in; i++) {
+    intypes[i] = in_types[i];
+  }
+  // copy output types
+  for (int i = 0; i < num_out; i++) {
+    outtypes[i] = out_types[i];
+  }
+
+  return retval;
+}
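+
+// Sketch (illustrative) of a matching user-side inferType for an op whose
+// output dtype simply follows its first input:
+//
+//   mxnet::ext::MXReturnValue inferType(
+//       const std::unordered_map<std::string, std::string>& attrs,
+//       std::vector<int>* intypes, std::vector<int>* outtypes) {
+//     (*outtypes)[0] = (*intypes)[0];
+//     return MX_SUCCESS;
+//   }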
+/*! \brief returns status of calling inferSType function for operator from library */
+MX_INT_RET _opCallInferSType(mxnet::ext::inferSType_t inferSType, const char* const* keys,
+                             const char* const* vals, int num,
+                             int* instypes, int num_in, int* outstypes, int num_out) {
+  // create map of attributes from list
+  std::unordered_map<std::string, std::string> attrs;
+  for (int i = 0; i < num; i++) {
+    attrs[std::string(keys[i])] = std::string(vals[i]);
+  }
+
+  // create a vector of types for inputs
+  std::vector<int> in_stypes(num_in);
+  for (int i = 0; i < num_in; i++) {
+    in_stypes[i] = instypes[i];
+  }
+
+  // create a vector of types for outputs
+  std::vector<int> out_stypes(num_out, -1);
+
+  int retval = inferSType(attrs, &in_stypes, &out_stypes);
+
+  if (!retval)
+    return retval;
+
+  // copy modified input storage types
+  for (int i = 0; i < num_in; i++) {
+    instypes[i] = in_stypes[i];
+  }
+  // copy output storage types
+  for (int i = 0; i < num_out; i++) {
+    outstypes[i] = out_stypes[i];
+  }
+
+  return retval;
+}
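+
+// Integer storage-type convention used by the compute wrappers below:
+// 0 = dense (kDefaultStorage), 1 = row-sparse (kRowSparseStorage), and any
+// other value = CSR (kCSRStorage).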
+/*! \brief returns status of calling Forward/Backward function for operator from library */
+MX_INT_RET _opCallFCompute(mxnet::ext::fcomp_t fcomp, const char* const* keys,
+                           const char* const* vals,
+                           int num, const int64_t** inshapes, int* indims, void** indata,
+                           int* intypes, size_t* inIDs, const char** indev_type, int* indev_id,
+                           int num_in, const int64_t** outshapes, int* outdims, void** outdata,
+                           int* outtypes, size_t* outIDs, const char** outdev_type,
+                           int* outdev_id, int num_out, mxnet::ext::xpu_malloc_t cpu_malloc,
+                           void* cpu_alloc,
+                           mxnet::ext::xpu_malloc_t gpu_malloc, void* gpu_alloc,
+                           void* cuda_stream,
+                           mxnet::ext::sparse_malloc_t sparse_malloc, void* sparse_alloc,
+                           int* instypes, int* outstypes, void** in_indices, void** out_indices,
+                           void** in_indptr, void** out_indptr,
+                           int64_t* in_indices_shapes, int64_t* out_indices_shapes,
+                           int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
+                           void* rng_cpu_states, void* rng_gpu_states) {
+  // create map of attributes from list
+  std::unordered_map<std::string, std::string> attrs;
+  for (int i = 0; i < num; i++) {
+    attrs[std::string(keys[i])] = std::string(vals[i]);
+  }
+
+  // create a vector of tensors for inputs
+  std::vector<mxnet::ext::MXTensor> inputs(num_in);
+  // create a vector for sparse inputs
+  std::vector<mxnet::ext::MXSparse> in_sparse(num_in);
+
+  for (int i = 0; i < num_in; i++) {
+    // Dense representation.
+    if (instypes[i] == 0) {
+      inputs[i].setTensor(indata[i], (mxnet::ext::MXDType)intypes[i], inshapes[i], indims[i],
+                          inIDs[i], mxnet::ext::MXContext(indev_type[i], indev_id[i]),
+                          mxnet::ext::kDefaultStorage);
+    } else {
+      // Sparse representation.
+      mxnet::ext::MXStorageType type;
+      if (instypes[i] == 1) {
+        type = mxnet::ext::kRowSparseStorage;
+        in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], in_indices_shapes[i]);
+      } else {
+        type = mxnet::ext::kCSRStorage;
+        in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i],
+                         in_indices_shapes[i], in_indptr[i], in_indptr_shapes[i]);
+      }
+      inputs[i].setTensor(reinterpret_cast<void*>(&in_sparse[i]), (mxnet::ext::MXDType)intypes[i],
+                          inshapes[i], indims[i], inIDs[i],
+                          mxnet::ext::MXContext(indev_type[i], indev_id[i]), type);
+    }
+  }
+
+  // create a vector of tensors for outputs
+  std::vector<mxnet::ext::MXTensor> outputs(num_out);
+  std::vector<mxnet::ext::MXSparse> out_sparse(num_out);
+
+  for (int i = 0; i < num_out; i++) {
+    // Dense representation.
+    if (outstypes[i] == 0) {
+      outputs[i].setTensor(outdata[i], (mxnet::ext::MXDType)outtypes[i], outshapes[i], outdims[i],
+                           outIDs[i], mxnet::ext::MXContext(outdev_type[i], outdev_id[i]),
+                           mxnet::ext::kDefaultStorage);
+    } else {
+      // Sparse representation.
+      mxnet::ext::MXStorageType type;
+      if (outstypes[i] == 1) {
+        type = mxnet::ext::kRowSparseStorage;
+        out_sparse[i].set(outdata[i], outshapes[i], outdims[i],
+                          out_indices[i], out_indices_shapes[i]);
+      } else {
+        type = mxnet::ext::kCSRStorage;
+        out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i],
+                          out_indices_shapes[i], out_indptr[i], out_indptr_shapes[i]);
+      }
+      outputs[i].setTensor(reinterpret_cast<void*>(&out_sparse[i]),
+                           (mxnet::ext::MXDType)outtypes[i],
+                           outshapes[i], outdims[i], outIDs[i],
+                           mxnet::ext::MXContext(outdev_type[i], outdev_id[i]), type);
+    }
+  }
+
+  mxnet::ext::OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
+                             cuda_stream, sparse_malloc, sparse_alloc,
+                             rng_cpu_states, rng_gpu_states);
+  return fcomp(attrs, &inputs, &outputs, res);
+}
+
+/*! \brief returns status of calling mutateInputs function for operator from library */
+MX_INT_RET _opCallMutateInputs(mxnet::ext::mutateInputs_t mutate, const char* const* keys,
+                               const char* const* vals, int num,
+                               int** mutate_indices, int* indices_size) {
+  // create map of attributes from list
+  std::unordered_map<std::string, std::string> attrs;
+  for (int i = 0; i < num; i++) {
+    attrs[std::string(keys[i])] = std::string(vals[i]);
+  }
+
+  // create a vector of mutate input indices
+  std::vector<int> mut_ind;
+
+  int retval = mutate(attrs, &mut_ind);
+  if (!retval)
+    return retval;
+
+  // output the input indices
+  *indices_size = mut_ind.size();
+  *mutate_indices = static_cast<int*>(malloc (*indices_size * sizeof(int)));  // NOLINT
+  for (int i = 0; i < *indices_size; i++) {
+    (*mutate_indices)[i] = mut_ind[i];
+  }
+
+  return retval;
+}
+
+/*! \brief returns status of calling createStatefulOp function for operator from library */
+MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const char* const* keys,
+                                const char* const* vals, int num, const char* dev_type,
+                                int dev_id, unsigned int** inshapes, int* indims,
+                                int num_in, const int* intypes, void** state_op) {
+  // create map of attributes from list
+  std::unordered_map<std::string, std::string> attrs;
+  for (int i = 0; i < num; i++) {
+    attrs[std::string(keys[i])] = std::string(vals[i]);
+  }
+
+  mxnet::ext::MXContext ctx(dev_type, dev_id);
+
+  // create a vector of shapes for inputs
+  std::vector<std::vector<unsigned int> > in_shapes(num_in);
+  for (int i = 0; i < num_in; i++) {
+    for (int j = 0; j < indims[i]; j++) {
+      in_shapes[i].push_back(inshapes[i][j]);
+    }
+  }
+
+  // create a vector of types for inputs
+  std::vector<int> in_types(num_in);
+  for (int i = 0; i < num_in; i++) {
+    in_types[i] = intypes[i];
+  }
+
+  // void pointer to hold custom state op instance created in custom library
+  // eventually state_op pointer is populated by instance from custom library
+  mxnet::ext::CustomStatefulOp** op_ptr =
+      reinterpret_cast<mxnet::ext::CustomStatefulOp**>(state_op);
+  return create_op(attrs, ctx, in_shapes, in_types, op_ptr);
+}
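+
+// Sketch (illustrative) of a stateful op that createOpState instantiates; its
+// Forward/Backward are then dispatched through _opCallFStatefulCompute below:
+//
+//   class MyStatefulOp : public mxnet::ext::CustomStatefulOp {
+//    public:
+//     mxnet::ext::MXReturnValue Forward(std::vector<mxnet::ext::MXTensor>* in,
+//                                       std::vector<mxnet::ext::MXTensor>* out,
+//                                       const mxnet::ext::OpResource& res) {
+//       /* compute out from in, using res for scratch space */
+//       return MX_SUCCESS;
+//     }
+//   };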
+/*! \brief returns status of calling Stateful Forward/Backward for operator from library */
+MX_INT_RET _opCallFStatefulCompute(int is_forward, void* state_op, const int64_t** inshapes,
+                                   int* indims, void** indata, int* intypes, size_t* inIDs,
+                                   const char** indev_type, int* indev_id, int num_in,
+                                   const int64_t** outshapes, int* outdims, void** outdata,
+                                   int* outtypes, size_t* outIDs, const char** outdev_type,
+                                   int* outdev_id, int num_out,
+                                   mxnet::ext::xpu_malloc_t cpu_malloc,
+                                   void* cpu_alloc, mxnet::ext::xpu_malloc_t gpu_malloc,
+                                   void* gpu_alloc,
+                                   void* stream, mxnet::ext::sparse_malloc_t sparse_malloc,
+                                   void* sparse_alloc, int* instypes, int* outstypes,
+                                   void** in_indices, void** out_indices, void** in_indptr,
+                                   void** out_indptr, int64_t* in_indices_shapes,
+                                   int64_t* out_indices_shapes, int64_t* in_indptr_shapes,
+                                   int64_t* out_indptr_shapes,
+                                   void* rng_cpu_states, void* rng_gpu_states) {
+  // create a vector of tensors for inputs
+  std::vector<mxnet::ext::MXTensor> inputs(num_in);
+  // create a vector for sparse inputs
+  std::vector<mxnet::ext::MXSparse> in_sparse(num_in);
+
+  for (int i = 0; i < num_in; i++) {
+    if (instypes[i] == 0) {
+      // Dense representation.
+      inputs[i].setTensor(indata[i], (mxnet::ext::MXDType)intypes[i], inshapes[i], indims[i],
+                          inIDs[i], mxnet::ext::MXContext(indev_type[i], indev_id[i]),
+                          mxnet::ext::kDefaultStorage);
+    } else {
+      // Sparse representation.
+      mxnet::ext::MXStorageType type;
+      if (instypes[i] == 1) {
+        type = mxnet::ext::kRowSparseStorage;
+        in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], in_indices_shapes[i]);
+      } else {
+        type = mxnet::ext::kCSRStorage;
+        in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i],
+                         in_indices_shapes[i], in_indptr[i], in_indptr_shapes[i]);
+      }
+      inputs[i].setTensor(reinterpret_cast<void*>(&in_sparse[i]), (mxnet::ext::MXDType)intypes[i],
+                          inshapes[i], indims[i], inIDs[i],
+                          mxnet::ext::MXContext(indev_type[i], indev_id[i]), type);
+    }
+  }
+
+  // create a vector of tensors for outputs
+  std::vector<mxnet::ext::MXTensor> outputs(num_out);
+  // create a vector for sparse outputs
+  std::vector<mxnet::ext::MXSparse> out_sparse(num_out);
+
+  for (int i = 0; i < num_out; i++) {
+    if (outstypes[i] == 0) {
+      // Dense representation.
+      outputs[i].setTensor(outdata[i], (mxnet::ext::MXDType)outtypes[i], outshapes[i], outdims[i],
+                           outIDs[i], mxnet::ext::MXContext(outdev_type[i], outdev_id[i]),
+                           mxnet::ext::kDefaultStorage);
+    } else {
+      // Sparse representation.
+      mxnet::ext::MXStorageType type;
+      if (outstypes[i] == 1) {
+        type = mxnet::ext::kRowSparseStorage;
+        out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i],
+                          out_indices_shapes[i]);
+      } else {
+        type = mxnet::ext::kCSRStorage;
+        out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i],
+                          out_indices_shapes[i], out_indptr[i], out_indptr_shapes[i]);
+      }
+      outputs[i].setTensor(reinterpret_cast<void*>(&out_sparse[i]),
+                           (mxnet::ext::MXDType)outtypes[i],
+                           outshapes[i], outdims[i], outIDs[i],
+                           mxnet::ext::MXContext(outdev_type[i], outdev_id[i]), type);
+    }
+  }
+
+  mxnet::ext::OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
+                             stream, sparse_malloc, sparse_alloc, rng_cpu_states, rng_gpu_states);
+
+  mxnet::ext::CustomStatefulOp* op_ptr =
+      reinterpret_cast<mxnet::ext::CustomStatefulOp*>(state_op);
+  if (is_forward) {
+    return op_ptr->Forward(&inputs, &outputs, res);
+  }
+  return op_ptr->Backward(&inputs, &outputs, res);
+}
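+
+// Sketch (illustrative; exact signature is an assumption) of a user-side
+// supportedOps function for the partitioner wrappers below. Nodes marked with
+// the same non-negative id are grouped into one subgraph:
+//
+//   mxnet::ext::MXReturnValue supportedOps(
+//       const mxnet::ext::Graph* graph, std::vector<int>* ids,
+//       const std::unordered_map<std::string, std::string>& options) {
+//     for (int i = 0; i < (int)graph->size(); i++)
+//       if (graph->getNode(i)->op == "FullyConnected") (*ids)[i] = 0;
+//     return MX_SUCCESS;
+//   }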
+/*! \brief returns number of partitioners registered in this library */
+MX_INT_RET _partRegSize() {
+  return mxnet::ext::Registry<mxnet::ext::CustomPartitioner>::get()->size();
+}
+
+/* returns number of strategies registered for partitioner
+ * at specified index */
+MX_INT_RET _partRegGetCount(int idx, const char** name) {
+  mxnet::ext::CustomPartitioner part =
+      mxnet::ext::Registry<mxnet::ext::CustomPartitioner>::get()->get(idx);
+  *name = part.name;
+  return part.strategies.size();
+}
+
+/*! \brief returns partitioner registration at specified index */
+MX_VOID_RET _partRegGet(int part_idx, int stg_idx, const char** strategy,
+                        mxnet::ext::supportedOps_t* supportedOps,
+                        mxnet::ext::createSelector_t* createSelector,
+                        mxnet::ext::reviewSubgraph_t* reviewSubgraph, const char** op_name) {
+  mxnet::ext::CustomPartitioner part =
+      mxnet::ext::Registry<mxnet::ext::CustomPartitioner>::get()->get(part_idx);
+  *strategy = part.strategies[stg_idx];
+  *op_name = part.op_names[stg_idx];
+  *supportedOps = part.getSupportedOps(stg_idx);
+  *createSelector = part.getCreateSelector(stg_idx);
+  *reviewSubgraph = part.getReviewSubgraph(stg_idx);
+}
+
+/*! \brief returns status of calling supported ops function from library */
+MX_INT_RET _partCallSupportedOps(mxnet::ext::supportedOps_t supportedOps, const char *json,
+                                 int num_ids, int *ids, const char* const* opt_keys,
+                                 const char* const* opt_vals, int num_opts) {
+  mxnet::ext::Graph *graph = mxnet::ext::Graph::fromString(json);
+  // create map of options from list
+  std::unordered_map<std::string, std::string> opts;
+  for (int i = 0; i < num_opts; i++)
+    opts[std::string(opt_keys[i])] = std::string(opt_vals[i]);
+
+  // create array of subgraph IDs for operator support
+  std::vector<int> _ids(num_ids, -2);
+  // call user's supportedOps function
+  mxnet::ext::MXReturnValue retval = supportedOps(graph, &_ids, opts);
+  if (!retval) return retval;
+
+  // copy subgraph IDs back out to the ids array
+  for (int i = 0; i < num_ids; i++)
+    ids[i] = _ids[i];
+
+  return retval;
+}
+
+/*! \brief returns status of calling create selector function from library */
+MX_INT_RET _partCallCreateSelector(mxnet::ext::createSelector_t createSelector, const char *json,
+                                   void** selector, const char* const* opt_keys,
+                                   const char* const* opt_vals, int num_opts) {
+  mxnet::ext::Graph *graph = mxnet::ext::Graph::fromString(json);
+  // create map of options from list
+  std::unordered_map<std::string, std::string> opts;
+  for (int i = 0; i < num_opts; i++)
+    opts[std::string(opt_keys[i])] = std::string(opt_vals[i]);
+
+  // void pointer to hold selector instance created in custom library
+  // eventually pointer is populated by instance from custom library
+  mxnet::ext::CustomOpSelector** sel_ptr =
+      reinterpret_cast<mxnet::ext::CustomOpSelector**>(selector);
+
+  // call user's createSelector function
+  return createSelector(graph, sel_ptr, opts);
+}
+
+/*! \brief returns status of calling select function from library */
+MX_VOID_RET _partCallSelect(void* sel_inst, int nodeID, int* selected) {
+  mxnet::ext::CustomOpSelector* sel_ptr =
+      reinterpret_cast<mxnet::ext::CustomOpSelector*>(sel_inst);
+  *selected = sel_ptr->Select(nodeID);
+}
+
+/*! \brief returns status of calling select input function from library */
+MX_VOID_RET _partCallSelectInput(void* sel_inst, int nodeID,
+                                 int input_nodeID, int* selected) {
+  mxnet::ext::CustomOpSelector* sel_ptr =
+      reinterpret_cast<mxnet::ext::CustomOpSelector*>(sel_inst);
+  *selected = sel_ptr->SelectInput(nodeID, input_nodeID);
+}
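+
+// Sketch (illustrative) of the CustomOpSelector these wrappers drive;
+// Select/SelectInput/SelectOutput return true to grow the current subgraph:
+//
+//   class MySelector : public mxnet::ext::CustomOpSelector {
+//    public:
+//     bool Select(int nodeID) { return true; }
+//     bool SelectInput(int nodeID, int input_nodeID) { return true; }
+//     bool SelectOutput(int nodeID, int output_nodeID) { return true; }
+//   };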
+/*! \brief returns status of calling select output function from library */
+MX_VOID_RET _partCallSelectOutput(void* sel_inst, int nodeID,
+                                  int output_nodeID, int* selected) {
+  mxnet::ext::CustomOpSelector* sel_ptr =
+      reinterpret_cast<mxnet::ext::CustomOpSelector*>(sel_inst);
+  *selected = sel_ptr->SelectOutput(nodeID, output_nodeID);
+}
+
+/*! \brief returns status of calling filter function from library */
+MX_VOID_RET _partCallFilter(void* sel_inst, int* candidates, int num_candidates,
+                            int** keep, int* num_keep) {
+  mxnet::ext::CustomOpSelector* sel_ptr =
+      reinterpret_cast<mxnet::ext::CustomOpSelector*>(sel_inst);
+  std::vector<int> candidates_(num_candidates);
+  for (int i=0; i < num_candidates; i++) {
+    candidates_[i] = candidates[i];
+  }
+  std::vector<int> keep_;
+
+  sel_ptr->Filter(candidates_, &keep_);
+
+  *num_keep = keep_.size();
+  *keep = static_cast<int*>(malloc(keep_.size() * sizeof(int)));  // NOLINT
+  for (unsigned i=0; i < keep_.size(); i++)
+    (*keep)[i] = keep_[i];
+}
+
+/*! \brief returns status of calling reset selector function from library */
+MX_VOID_RET _partCallReset(void* sel_inst) {
+  mxnet::ext::CustomOpSelector* sel_ptr =
+      reinterpret_cast<mxnet::ext::CustomOpSelector*>(sel_inst);
+  sel_ptr->Reset();
+}
+
+/*! \brief returns status of calling review subgraph function from library */
+MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph, const char *json,
+                                   int subgraph_id, int *accept, const char* const* opt_keys,
+                                   const char* const* opt_vals, int num_opts,
+                                   char*** attr_keys, char*** attr_vals, int *num_attrs,
+                                   const char* const* arg_names, int num_args,
+                                   void* const* arg_data, const int64_t* const* arg_shapes,
+                                   const int* arg_dims, const int* arg_types,
+                                   const size_t* arg_IDs, const char* const* arg_dev_type,
+                                   const int* arg_dev_id,
+                                   const char* const* aux_names, int num_aux,
+                                   void* const* aux_data, const int64_t* const* aux_shapes,
+                                   const int* aux_dims, const int* aux_types,
+                                   const size_t* aux_IDs, const char* const* aux_dev_type,
+                                   const int* aux_dev_id) {
+  mxnet::ext::Graph *subgraph = mxnet::ext::Graph::fromString(json);
+  bool accept_bool = false;
+  // create map of attributes from list
+  std::unordered_map<std::string, std::string> opts;
+  for (int i = 0; i < num_opts; i++)
+    opts[std::string(opt_keys[i])] = std::string(opt_vals[i]);
+
+  // create a map of named tensors for args
+  std::unordered_map<std::string, mxnet::ext::MXTensor> args;
+  for (int i = 0; i < num_args; i++) {
+    std::vector<int64_t> shapes;
+    for (int j = 0; j < arg_dims[i]; j++)
+      shapes.push_back(arg_shapes[i][j]);
+
+    mxnet::ext::MXTensor tensor(arg_data[i], shapes, (mxnet::ext::MXDType)arg_types[i],
+                                arg_IDs[i], mxnet::ext::MXContext(arg_dev_type[i], arg_dev_id[i]));
+    args[arg_names[i]] = tensor;
+  }
+  // create a map of named tensors for aux
+  std::unordered_map<std::string, mxnet::ext::MXTensor> aux;
+  for (int i = 0; i < num_aux; i++) {
+    std::vector<int64_t> shapes;
+    for (int j = 0; j < aux_dims[i]; j++)
+      shapes.push_back(aux_shapes[i][j]);
+
+    mxnet::ext::MXTensor tensor(aux_data[i], shapes, (mxnet::ext::MXDType)aux_types[i],
+                                aux_IDs[i], mxnet::ext::MXContext(aux_dev_type[i],
+                                                                  aux_dev_id[i]));
+    aux[aux_names[i]] = tensor;
+  }
+
+  subgraph->_setParams(&args, &aux);
+
+  std::unordered_map<std::string, std::string> attrs;
+  mxnet::ext::MXReturnValue retval = reviewSubgraph(subgraph, subgraph_id, &accept_bool,
+                                                    opts, &attrs);
+  if (!retval) return retval;
+
+  *accept = accept_bool;
+
+  if (attrs.size() > 0) {
+    *num_attrs = attrs.size();
+    // allocate space for attributes
+    *attr_keys = static_cast<char**>(malloc (*num_attrs * sizeof(char*)));  // NOLINT
+    *attr_vals = static_cast<char**>(malloc (*num_attrs * sizeof(char*)));  // NOLINT
+
+    // copy attributes
+    int i = 0;
+    for (auto kv : attrs) {
+      (*attr_keys)[i] = static_cast<char*>(malloc ((kv.first.size()+1) * sizeof(char)));  // NOLINT
+      (*attr_vals)[i] = static_cast<char*>(malloc ((kv.second.size()+1) * sizeof(char)));  // NOLINT
+      snprintf((*attr_keys)[i], kv.first.size()+1, "%s", kv.first.c_str());
+      snprintf((*attr_vals)[i], kv.second.size()+1, "%s", kv.second.c_str());
+      i++;
+    }
+  }
+
+  return retval;
+}
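+
+// Note on ownership: the char* buffers malloc'd above are allocated by this
+// library and later released when MXNet calls back into _opCallFree, so that
+// allocation and deallocation happen on the same side of the library boundary.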
+/*! \brief returns number of graph passes registered in this library */
+MX_INT_RET _passRegSize() {
+  return mxnet::ext::Registry<mxnet::ext::CustomPass>::get()->size();
+}
+
+/*! \brief returns pass registration at specified index */
+MX_VOID_RET _passRegGet(int pass_idx, mxnet::ext::graphPass_t* graphPass,
+                        const char** pass_name) {
+  mxnet::ext::CustomPass pass =
+      mxnet::ext::Registry<mxnet::ext::CustomPass>::get()->get(pass_idx);
+  *graphPass = pass.pass;
+  *pass_name = pass.name;
+}
+
+/*! \brief returns status of calling graph pass function from library */
+MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass, const char *json,
+                              char** out_graph, const char* const* opt_keys,
+                              const char* const* opt_vals, int num_opts,
+                              const char* pass_name, const char* const* arg_names, int num_args,
+                              void* const* arg_data, const int64_t* const* arg_shapes,
+                              const int* arg_dims, const int* arg_types,
+                              const size_t* arg_IDs, const char* const* arg_dev_type,
+                              const int* arg_dev_id, const char* const* aux_names, int num_aux,
+                              void* const* aux_data, const int64_t* const* aux_shapes,
+                              const int* aux_dims, const int* aux_types,
+                              const size_t* aux_IDs, const char* const* aux_dev_type,
+                              const int* aux_dev_id, mxnet::ext::nd_malloc_t nd_malloc,
+                              const void* nd_alloc) {
+  mxnet::ext::Graph *graph = mxnet::ext::Graph::fromString(json);
+  // create map of attributes from list
+  std::unordered_map<std::string, std::string> opts;
+  for (int i = 0; i < num_opts; i++)
+    opts[std::string(opt_keys[i])] = std::string(opt_vals[i]);
+
+  // create a map of named tensors for args
+  std::unordered_map<std::string, mxnet::ext::MXTensor> args;
+  for (int i = 0; i < num_args; i++) {
+    std::vector<int64_t> shapes;
+    for (int j = 0; j < arg_dims[i]; j++)
+      shapes.push_back(arg_shapes[i][j]);
+
+    mxnet::ext::MXTensor tensor(arg_data[i], shapes, (mxnet::ext::MXDType)arg_types[i],
+                                arg_IDs[i], mxnet::ext::MXContext(arg_dev_type[i],
+                                                                  arg_dev_id[i]));
+    args[arg_names[i]] = tensor;
+  }
+  // create a map of named tensors for aux
+  std::unordered_map<std::string, mxnet::ext::MXTensor> aux;
+  for (int i = 0; i < num_aux; i++) {
+    std::vector<int64_t> shapes;
+    for (int j = 0; j < aux_dims[i]; j++)
+      shapes.push_back(aux_shapes[i][j]);
+
+    mxnet::ext::MXTensor tensor(aux_data[i], shapes, (mxnet::ext::MXDType)aux_types[i],
+                                aux_IDs[i], mxnet::ext::MXContext(aux_dev_type[i],
+                                                                  aux_dev_id[i]));
+    aux[aux_names[i]] = tensor;
+  }
+
+  std::unordered_map<std::string, mxnet::ext::MXTensor> new_args, new_aux;
+  mxnet::ext::PassResource res(&new_args, &new_aux, nd_malloc, nd_alloc);
+  graph->_setParams(&args, &aux);
+  graph->_setPassResource(&res);
+  mxnet::ext::MXReturnValue retval = graphPass(graph, opts);
+  if (!retval) return retval;
+
+  std::string tmp = graph->toString();
+  *out_graph = static_cast<char*>(malloc ((tmp.size()+1) * sizeof(char)));  // NOLINT
+  snprintf((*out_graph), tmp.size()+1, "%s", tmp.c_str());
+  return retval;
+}
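+
+// Sketch (illustrative) of the initialize() a library defines to satisfy the
+// declaration below, here accepting MXNet 1.7.0 and newer:
+//
+//   mxnet::ext::MXReturnValue initialize(int version) {
+//     if (version >= 10700)
+//       return MX_SUCCESS;
+//     MX_ERROR_MSG << "MXNet version " << version << " not supported";
+//     return MX_FAIL;
+//   }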
+/*!
+ * \brief Checks if the MXNet version is supported by the library.
+ * If supported, initializes the library.
+ * \param version MXNet version number passed to library and defined as:
+ *                MXNET_VERSION = (MXNET_MAJOR*10000 + MXNET_MINOR*100 + MXNET_PATCH)
+ * \return Non-zero value on error i.e. library incompatible with passed MXNet version
+ */
+#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
+__declspec(dllexport) mxnet::ext::MXReturnValue __cdecl
+#else
+mxnet::ext::MXReturnValue
+#endif
+initialize(int version);
+
+/*! \brief returns number of error messages stored in this library */
+MX_INT_RET _msgSize() {
+  return mxnet::ext::MXerrorMsgs::get()->size();
+}
+
+/*! \brief returns error message at specified index */
+MX_VOID_RET _msgGet(int idx, const char** msg) {
+  *msg = mxnet::ext::MXerrorMsgs::get()->get(idx)->c_str();
+}
diff --git a/scripts/bert/setup.py b/scripts/bert/setup.py
index 23696e24d6..0ba33a85e6 100644
--- a/scripts/bert/setup.py
+++ b/scripts/bert/setup.py
@@ -26,12 +26,13 @@ def CompileBERTCustomPass():
     pass_path = os.path.dirname(os.path.realpath(__file__))
     source = os.path.join(pass_path, input_pass_file)
     target = os.path.join(pass_path, out_lib_file)
-    lib_api_cc = pathlib.Path.joinpath(mxnet_path, 'src/lib_api.cc')
+    # lib_api_cc = pathlib.Path.joinpath(mxnet_path, 'src/lib_api.cc')
+    lib_api_cc = os.path.join(pass_path, 'lib_api.cc')
     if (mxnet.__version__ > '1.7.0'):
         source = source + ' ' + str(lib_api_cc)
-    print('MXNET ver: ', mxnet.__version__, 'source:', source, 'mxnet_include_path',
-          mxnet_include_path)
-    print('lib_api_cc Exist:' + str(os.path.exists(str(lib_api_cc))))
+    # print('MXNET ver: ', mxnet.__version__, 'source:', source, 'mxnet_include_path',
+    #       mxnet_include_path)
+    # print('lib_api_cc Exist:' + str(os.path.exists(str(lib_api_cc))))
     os.system('g++ -shared -fPIC -std=c++11 ' + str(source) + ' -o ' + str(target) +
               ' -I ' + str(mxnet_include_path))

From 309351d3983077bac905d9df4f4575594f1633b5 Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 10 Feb 2021 17:08:26 +0000
Subject: [PATCH 9/9] remove debugging prints

---
 scripts/bert/setup.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/scripts/bert/setup.py b/scripts/bert/setup.py
index 0ba33a85e6..04b285bc8d 100644
--- a/scripts/bert/setup.py
+++ b/scripts/bert/setup.py
@@ -30,9 +30,6 @@ def CompileBERTCustomPass():
     lib_api_cc = os.path.join(pass_path, 'lib_api.cc')
     if (mxnet.__version__ > '1.7.0'):
         source = source + ' ' + str(lib_api_cc)
-    # print('MXNET ver: ', mxnet.__version__, 'source:', source, 'mxnet_include_path',
-    #       mxnet_include_path)
-    # print('lib_api_cc Exist:' + str(os.path.exists(str(lib_api_cc))))
     os.system('g++ -shared -fPIC -std=c++11 ' + str(source) + ' -o ' + str(target) +
               ' -I ' + str(mxnet_include_path))