diff --git a/example/extensions/lib_api/init_lib.cc b/example/extensions/lib_api/init_lib.cc index fb3a10457cf5..0ed43761fe53 100644 --- a/example/extensions/lib_api/init_lib.cc +++ b/example/extensions/lib_api/init_lib.cc @@ -26,12 +26,14 @@ #include #include "lib_api.h" +using namespace mxnet::ext; + MXReturnValue initialize(int version) { if (version >= 10700) { std::cout << "MXNet version " << version << " supported" << std::endl; return MX_SUCCESS; } else { - std::cout << "MXNet version " << version << " not supported" << std::endl; + MX_ERROR_MSG << "MXNet version " << version << " not supported"; return MX_FAIL; } } diff --git a/example/extensions/lib_custom_op/gemm_lib.cc b/example/extensions/lib_custom_op/gemm_lib.cc index 4f8dabadc6a1..59905c896bef 100644 --- a/example/extensions/lib_custom_op/gemm_lib.cc +++ b/example/extensions/lib_custom_op/gemm_lib.cc @@ -26,6 +26,8 @@ #include #include "lib_api.h" +using namespace mxnet::ext; + // main matrix multiplication routine void gemm(const float* A, const float* B, float* C, const unsigned n, const unsigned k, const unsigned m) { @@ -127,12 +129,12 @@ MXReturnValue inferType(const std::unordered_map& attr std::vector *outtypes) { // validate inputs if (intypes->size() != 2) { - std::cout << "Expected 2 inputs to inferType" << std::endl; + MX_ERROR_MSG << "Expected 2 inputs to inferType"; return MX_FAIL; } for (unsigned i = 0; i < intypes->size(); i++) { if (intypes->at(i) != kFloat32) { - std::cout << "Expected input " << i << " to have float32 type" << std::endl; + MX_ERROR_MSG << "Expected input " << i << " to have float32 type"; return MX_FAIL; } } @@ -146,11 +148,11 @@ MXReturnValue inferShape(const std::unordered_map& att std::vector>* outshapes) { // validate inputs if (inshapes->size() != 2) { - std::cout << "Expected 2 inputs to inferShape" << std::endl; + MX_ERROR_MSG << "Expected 2 inputs to inferShape"; return MX_FAIL; } if (inshapes->at(0).size() != 2 || inshapes->at(1).size() != 2) { - std::cout << "Expected 2D matrices for both inputs to inferShape" << std::endl; + MX_ERROR_MSG << "Expected 2D matrices for both inputs to inferShape"; return MX_FAIL; } @@ -159,7 +161,7 @@ MXReturnValue inferShape(const std::unordered_map& att unsigned kk = inshapes->at(1)[0]; unsigned m = inshapes->at(1)[1]; if (k != kk) { - std::cout << "Exected first input axis 1 equals to second input axis 0" << std::endl; + MX_ERROR_MSG << "Exected first input axis 1 equals to second input axis 0"; return MX_FAIL; } @@ -195,8 +197,6 @@ class MyStatefulGemm : public CustomStatefulOp { return backward(attrs_, inputs, outputs, op_res); } - ~MyStatefulGemm() {} - private: int count; const std::unordered_map attrs_; @@ -230,7 +230,7 @@ MXReturnValue initialize(int version) { std::cout << "MXNet version " << version << " supported" << std::endl; return MX_SUCCESS; } else { - std::cout << "MXNet version " << version << " not supported" << std::endl; + MX_ERROR_MSG << "MXNet version " << version << " not supported"; return MX_FAIL; } } diff --git a/example/extensions/lib_custom_op/relu_lib.cu b/example/extensions/lib_custom_op/relu_lib.cu index a4711cbeab67..7022c76e6999 100644 --- a/example/extensions/lib_custom_op/relu_lib.cu +++ b/example/extensions/lib_custom_op/relu_lib.cu @@ -26,6 +26,8 @@ #include #include "lib_api.h" +using namespace mxnet::ext; + #define NumThreadPerBlock 256 // mxnet recommended cuda thread number per block __global__ void relu_gpu_forward(float *out, float *in, int64_t N) { @@ -263,7 +265,7 @@ MXReturnValue initialize(int version) { std::cout << "MXNet version " << version << " supported" << std::endl; return MX_SUCCESS; } else { - std::cout << "MXNet version " << version << " not supported" << std::endl; + MX_ERROR_MSG << "MXNet version " << version << " not supported"; return MX_FAIL; } } diff --git a/example/extensions/lib_custom_op/transposecsr_lib.cc b/example/extensions/lib_custom_op/transposecsr_lib.cc index 224cd6aa81b6..d3941d74c969 100644 --- a/example/extensions/lib_custom_op/transposecsr_lib.cc +++ b/example/extensions/lib_custom_op/transposecsr_lib.cc @@ -26,6 +26,8 @@ #include #include "lib_api.h" +using namespace mxnet::ext; + void transpose(MXTensor& src, MXTensor& dst, const OpResource& res) { MXSparse* A = src.data(); MXSparse* B = dst.data(); @@ -70,11 +72,11 @@ MXReturnValue forward(const std::unordered_map& attrs, // The data types and storage types of inputs and outputs should be the same. if(inputs->at(0).dtype != outputs->at(0).dtype || inputs->at(0).stype != outputs->at(0).stype) { - std::cout << "Error! Expected all inputs and outputs to be the same type." - << "Found input storage type:" << inputs->at(0).stype - << " Found output storage type:" << outputs->at(0).stype - << " Found input data type:" << inputs->at(0).dtype - << " Found output data type:" << outputs->at(0).dtype << std::endl; + MX_ERROR_MSG << "Error! Expected all inputs and outputs to be the same type." + << "Found input storage type:" << inputs->at(0).stype + << " Found output storage type:" << outputs->at(0).stype + << " Found input data type:" << inputs->at(0).dtype + << " Found output data type:" << outputs->at(0).dtype; return MX_FAIL; } @@ -101,11 +103,11 @@ MXReturnValue inferType(const std::unordered_map& attr std::vector* outtypes) { // validate inputs if (intypes->size() != 1) { - std::cout << "Expected 1 inputs to inferType" << std::endl; + MX_ERROR_MSG << "Expected 1 inputs to inferType"; return MX_FAIL; } if (intypes->at(0) != kFloat32) { - std::cout << "Expected input to have float32 type" << std::endl; + MX_ERROR_MSG << "Expected input to have float32 type"; return MX_FAIL; } @@ -117,7 +119,7 @@ MXReturnValue inferSType(const std::unordered_map& att std::vector* instypes, std::vector* outstypes) { if (instypes->at(0) != kCSRStorage) { - std::cout << "Expected storage type is kCSRStorage" << std::endl; + MX_ERROR_MSG << "Expected storage type is kCSRStorage"; return MX_FAIL; } outstypes->at(0) = instypes->at(0); @@ -129,7 +131,7 @@ MXReturnValue inferShape(const std::unordered_map& att std::vector>* outshapes) { // validate inputs if (inshapes->size() != 1) { - std::cout << "Expected 1 inputs to inferShape" << std::endl; + MX_ERROR_MSG << "Expected 1 inputs to inferShape"; return MX_FAIL; } @@ -194,7 +196,7 @@ MXReturnValue initialize(int version) { std::cout << "MXNet version " << version << " supported" << std::endl; return MX_SUCCESS; } else { - std::cout << "MXNet version " << version << " not supported" << std::endl; + MX_ERROR_MSG << "MXNet version " << version << " not supported"; return MX_FAIL; } } diff --git a/example/extensions/lib_custom_op/transposerowsp_lib.cc b/example/extensions/lib_custom_op/transposerowsp_lib.cc index 46d3c4d41a4c..90ad594d556b 100644 --- a/example/extensions/lib_custom_op/transposerowsp_lib.cc +++ b/example/extensions/lib_custom_op/transposerowsp_lib.cc @@ -26,6 +26,8 @@ #include #include "lib_api.h" +using namespace mxnet::ext; + void transpose(MXTensor& src, MXTensor& dst, const OpResource& res) { MXSparse* A = src.data(); MXSparse* B = dst.data(); @@ -73,11 +75,11 @@ MXReturnValue forward(const std::unordered_map& attrs, // The data types and storage types of inputs and outputs should be the same. if(inputs->at(0).dtype != outputs->at(0).dtype || inputs->at(0).stype != outputs->at(0).stype) { - std::cout << "Error! Expected all inputs and outputs to be the same type." - << "Found input storage type:" << inputs->at(0).stype - << " Found output storage type:" << outputs->at(0).stype - << " Found input data type:" << inputs->at(0).dtype - << " Found output data type:" << outputs->at(0).dtype << std::endl; + MX_ERROR_MSG << "Error! Expected all inputs and outputs to be the same type." + << "Found input storage type:" << inputs->at(0).stype + << " Found output storage type:" << outputs->at(0).stype + << " Found input data type:" << inputs->at(0).dtype + << " Found output data type:" << outputs->at(0).dtype; return MX_FAIL; } transpose(inputs->at(0), outputs->at(0), res); @@ -103,11 +105,11 @@ MXReturnValue inferType(const std::unordered_map& attr std::vector* outtypes) { // validate inputs if (intypes->size() != 1) { - std::cout << "Expected 1 inputs to inferType" << std::endl; + MX_ERROR_MSG << "Expected 1 inputs to inferType"; return MX_FAIL; } if (intypes->at(0) != kFloat32) { - std::cout << "Expected input to have float32 type" << std::endl; + MX_ERROR_MSG << "Expected input to have float32 type"; return MX_FAIL; } @@ -119,7 +121,7 @@ MXReturnValue inferSType(const std::unordered_map& att std::vector* instypes, std::vector* outstypes) { if (instypes->at(0) != kRowSparseStorage) { - std::cout << "Expected storage type is kRowSparseStorage" << std::endl; + MX_ERROR_MSG << "Expected storage type is kRowSparseStorage"; return MX_FAIL; } outstypes->at(0) = instypes->at(0); @@ -131,7 +133,7 @@ MXReturnValue inferShape(const std::unordered_map& att std::vector>* outshapes) { // validate inputs if (inshapes->size() != 1) { - std::cout << "Expected 1 inputs to inferShape" << std::endl; + MX_ERROR_MSG << "Expected 1 inputs to inferShape"; return MX_FAIL; } @@ -196,7 +198,7 @@ MXReturnValue initialize(int version) { std::cout << "MXNet version " << version << " supported" << std::endl; return MX_SUCCESS; } else { - std::cout << "MXNet version " << version << " not supported" << std::endl; + MX_ERROR_MSG << "MXNet version " << version << " not supported"; return MX_FAIL; } } diff --git a/example/extensions/lib_pass/README.md b/example/extensions/lib_pass/README.md index c2771242440f..18272c0be436 100644 --- a/example/extensions/lib_pass/README.md +++ b/example/extensions/lib_pass/README.md @@ -32,22 +32,21 @@ To run the following example, the build type of MXNet doesn’t matter since the ### Run An Example -You can start getting familiar with custom passes by running an example provided in the **example/extensions/lib_pass** directory. The `myPass` example just copies the input graph to the output. Go to the **lib_pass** directory and follow these steps: +You can start getting familiar with custom passes by running an example provided in the **example/extensions/lib_pass** directory. The `myPass` example just prints out the graph. Go to the **lib_pass** directory and follow these steps: 1. Run `make`. The Makefile will generate the dynamic library **libpass_lib.so** which is compiled from the `pass_lib.cc` file. This is the library you are going to load that contains everything for the custom pass. -2. Run `python test_pass.py`. It’ll first load the above library, find the components, register them in the MXNet backend, then execute the pass on the model and execute the operators like a regular MXNet operator and output the result. Below is the output when running the `python test_pass.py` command. Notice that it loads 2 passes: myPass and jsonPass. +2. Run `python test_pass.py`. It’ll first load the above library, find the components, register them in the MXNet backend, then execute the pass on the model and execute the operators like a regular MXNet operator and output the result. Below is the output when running the `python test_pass.py` command. Notice that it loads 1 pass: `myPass`. ``` [10:38:03] src/c_api/c_api.cc:286: Found 0 operators in library [10:38:03] src/c_api/c_api.cc:785: Found 0 partitioners in library -[07:14:00] src/c_api/c_api.cc:887: Found 2 graph passes in library +[07:14:00] src/c_api/c_api.cc:887: Found 1 graph passes in library [07:14:00] src/c_api/c_api.cc:902: Graph Pass [0] myPass -[07:14:00] src/c_api/c_api.cc:902: Graph Pass [1] jsonPass ``` ### Basic Files For Custom Pass Library * **lib_pass/pass_lib.cc**: This file has a source code implementation of all required components to make a custom pass, it also shows registration of them so that they can be loaded by MXNet. -* **lib_pass/Makefile**: This file compiles the source code to a dynamic shared library, with a header file `include/mxnet/lib_api.h` from MXNet source code. Currently the custom pass is compatible with C++11 onwards. +* **lib_pass/Makefile**: This file compiles the source code to a dynamic shared library, with a header file `include/mxnet/lib_api.h` from MXNet source code. Currently the custom pass is compatible with C++11 and above. * **lib_pass/test_pass.py**: This file calls `mx.library.load(‘libpass_lib.so’)` to load the library containing the custom components, executes the pass on the model using the `optimize_for` API, and prints outputs of the forward passes. The outputs should be the same as the regular MXNet forward pass without running the pass. * **include/mxnet/lib_api.h**: This file from MXNet source code is the single header file needed to include all necessary data types and function prototypes for writing a custom library. You can either specify the include path in the `Makefile`, or copy the header file over to `example/extensions/lib_pass` folder. Note that apart from this header, the custom library is independent of MXNet source. ## Writing Custom Pass Library @@ -78,38 +77,38 @@ sym_block.optimize_for(x, backend='myPass') ### Using a Custom Pass Library -APIs in MXNet are available in both Symbol and Gluon APIs. For the Symbol API, the `optimize_for` API can be called on Symbol objects to return a new Symbol post graph pass. +APIs in MXNet are available in both Symbol and Gluon APIs. For the Symbol API, `optimize_for` can be called on Symbol objects to run the graph pass and return a new Symbol. -``` -optimize_for(backend, args=None, aux=None, ctx=None, **kwargs) +```python +sym.optimize_for(backend, args=None, aux=None, ctx=None, **kwargs) ``` The `optimize_for` API takes at least 1 argument, `backend` which is a string that identifies which backend to use to optimize the model. The `args` and `aux` arguments are optional and take a list of NDArray or dict of str to NDArray. They are used to infer shapes and types and before executing the graph pass. The `ctx` argument is optional and takes a device context to infer storage types. It also takes any other user-specified options that will be passed to the backend APIs. -For the Gluon API, the `hybridize` API can be called on HybridBlocks to execute a graph pass on the internal CachedOp Symbol. +For the Gluon API, `hybridize` can be called on HybridBlocks to execute a graph pass on the internal CachedOp Symbol. -``` -hybridize(backend=None, backend_opts=None, **kwargs) +```python +block.hybridize(backend=None, backend_opts=None, **kwargs) ``` The `hybridize` function prepares the HybridBlock to be converted into a backend symbol. The `backend` argument is a string that identifies which pass that will be executed on the model. The `backend_opts` takes other user-specified options that will be passed to the backend APIs. The actual pass runs once just before the first the forward pass. If you just want to run a graph pass on the HybridBlock but not run a complete forward pass, you can use the `optimize_for` API that combines the work done in the `hybridize` API with part of the work done in the forward pass. -``` -optimize_for(x, backend=None, backend_opts=None, **kwargs) +```python +block.optimize_for(x, backend=None, backend_opts=None, **kwargs) ``` When the `optimize_for` API is called on a HybridBlock it runs the graph pass immediately. This lets users export the modified model without running a complete forward pass. -``` +```python block.optimize_for(x, backend='myPass') block.export('optimized') ``` But you can also use `optimize_for` in place of `hybridize` and run inference immediately after too. -``` +```python block.optimize_for(x, backend='myPass') block(x) ``` @@ -120,50 +119,80 @@ There are several essential building blocks for making a custom pass: * [initialize](./pass_lib.cc#44): * This function is the library initialization function necessary for any dynamic libraries. It lets you check if the user is using a compatible version of MXNet. Note that this `version` parameter is passed from MXNet when library is loaded. - +```c++ MXReturnValue initialize(int version) - +``` * [graphPass](./pass_lib.cc#31): - * This function provides a copy of the model graph as a JSON string, and provides an interface for returning a modified model JSON string. Also this is where a custom pass can validate the options specified by the user. - + * This function provides a copy of the model graph, and any specific options from the user. +```c++ MXReturnValue graphPass( - const std::string& in_graph, - const std::string** out_graph, - const std::unordered_map& options, - const std::unordered_map& args, - const std::unordered_map& aux, - const PassResource& res) - + mxnet::ext::Graph *g, + const std::unordered_map& options) +``` * [REGISTER_PASS(my_pass_name)](./pass_lib.cc#L41): * This macro registers the custom pass and its properties to MXNet by its name. The argument to `setBody` is the `graphPass` function. - +```c++ REGISTER_PASS(my_pass_name) .setBody(graphPass); - +``` Let’s take a closer look at those registry functions: -* **graphPass**: This function takes six arguments. The 1st argument is a JSON string of the model architecture graph, where nodes are inputs/params/weights and edges are data dependencies. The graph is pre-sorted in topological order. The 2nd argument is a pointer to a pointer of a JSON model string. It is expected users will dereference and assign the address of their output string allocated with `new` and `delete` will be called on it automatically. The third argument is the map of options specified by the user. Users can pass custom options to the pass and they are passed to this function in the `options` map. The fourth and fifth arguments are the named tensor mapping for the args and aux params for the model. They will contain the model params if the user provides them to the `optimize_for` API. The last argument is the `PassResource` object for memory allocation and other utilities. The details of `PassResource` are covered in the section below +* **graphPass**: This function takes two arguments. The first argument is the Graph of the model architecture, where nodes are inputs/params/weights and edges are data dependencies. The second argument is the map of options specified by the user. Users can pass custom options to the pass and they are passed to this function in the `options` map. + +### Graph representation -### Pass Resource +The `Graph` class represents the model's architecture. Each `Node` in the graph represents an operator or weight (ie. args/aux param). Since an operator in MXNet can take multiple inputs and produce multiple outputs, each input/output is represented by a `NodeEntry`. A `Node` contains the following: +- `op` - [string] operator name +- `name` - [string] unique node name +- `inputs` - [vector of NodeEntry] set of inputs to the node +- `outputs` - [vector of NodeEntry] set of outputs from the node +- `subgraph` - [vector of Graph] set of subgraphs in the node +- `attrs` - [map of string to string] set of attributes for the node -Some graph passes require allocating new NDArrays to add/replace model params. The `alloc_arg` and `alloc_aux` APIs enabling allocating new NDArrays and integrating them with the user-provide args and aux params. Both APIs have the following signature: +The `inputs` are a set of `NodeEntry` where each contains a pointer to a `Node` that produces the data, and an `entry` that is the index of the output on the other `Node`. Conversely, the `outputs` are a set of `NodeEntry` where each contains a pointer to a`Node` that consumes the data, and and `entry` that is the index of the input on the other `Node`. This bidirectional dependency will enable you to easily traverse the graph. +A `Graph` contains the following: +- `nodes` - [vector of Node] set of nodes in the graph +- `inputs` - [vector of Node] set of inputs to the graph +- `outputs` - [vector of NodeEntry] set of outputs from the graph +- `attrs` - [map of string to JSON object] set of attributes for the graph + +The `nodes` are all the nodes in the graph (superset). The `inputs` are only those nodes that are model inputs (ie. input image) or weights (ie. arg/aux params). The `outputs` are the outputs from the operators in the model that are true outputs of the model (ie. prediction results). + +Heres an example creating a new node and adding it to the graph: +```c++ +g->addNode("myConv","Convolution"); ``` - MXTensor* alloc_xxx(const std::string& name, - const std::vector& shapes, +Heres an example creating an edge between two nodes: +```c++ +n1->outputs.push_back({n2,1}); +n2->inputs.push_back({n1,0}); +``` +Here node `n1` produces an output at index 0 that is consumed by node `n2` on the input at index 1. + +![example connection](example_connection.png) + +Some graph passes require allocating new NDArrays to add/replace model params. The `alloc_arg` and `alloc_aux` APIs enable allocating new NDArrays and integrate them with the model args and aux params. Both APIs have the following signature: + +```c++ + MXTensor* alloc_xxx(const std::vector& shapes, const MXContext &ctx, MXDType dtype) ``` -If the `name` provided matches the name of an existing param it replaces the previous one. Otherwise it adds a new param to the appropriate arg/aux set. +This function can be called on a node in the graph to allocate a tensor for that node like: + +```c++ +node->alloc_arg({1},MXContext::CPU(0),kFloat32); +``` +It adds a new param to the appropriate arg/aux set when the graph pass returns. If you wish to remove an existing param, just remove the node in the graph corresponding to that param. It will be deleted after the pass completes and removed from the dictionary of args or aux (whichever it is a member of). ### Parsing a JSON string To simplify custom libraries, basic JSON parsing utility functions have been implemented in the `lib_api.h` header file. You create a `JsonParser` object and parse the string by calling the `parse_to_json` API like: ```c++ -JsonParser parser; -JsonVal json_val = parser.parse_to_json(json_string); +JsonVal json_val = JsonVal::parse(json); ``` A `JsonVal` is a class that represents the nodes in a JSON structure. You can check the type of a node (num, str, list, or map) by comparing the `JsonVal.type` to `STR`, `NUM`, `LIST`, or `MAP`. Then you can get that value from the node like: @@ -187,4 +216,4 @@ switch(json_val.type) { } ``` -There are also convenience constructors for creating `JsonVal` objects for strings and numbers like `JsonVal("myKey")` or `JsonVal(42)`. This makes it easy to get specific keys from a map like `json_val.map[JsonVal("nodes")]`. +You call the `dump` function on a `JsonVal` object like `json_val.dump()` to get a JSON-compatible string. There are also convenience constructors for creating `JsonVal` objects for strings and numbers like `JsonVal("myKey")` or `JsonVal(42)`. This makes it easy to get specific keys from a map like `json_val.map[JsonVal("nodes")]`. diff --git a/example/extensions/lib_pass/example_connection.png b/example/extensions/lib_pass/example_connection.png new file mode 100644 index 000000000000..ef56c6228a6f Binary files /dev/null and b/example/extensions/lib_pass/example_connection.png differ diff --git a/example/extensions/lib_pass/pass_lib.cc b/example/extensions/lib_pass/pass_lib.cc index bbdcd73a7a0b..5f5137319999 100644 --- a/example/extensions/lib_pass/pass_lib.cc +++ b/example/extensions/lib_pass/pass_lib.cc @@ -28,77 +28,27 @@ #include #include "lib_api.h" -/* \brief a basic pass that copies the input to the output */ -MXReturnValue myPass(const std::string& in_graph, const std::string** out_graph, - const std::unordered_map& options, - const std::unordered_map& args, - const std::unordered_map& aux, - const PassResource& res) { +using namespace mxnet::ext; + +/* \brief a basic pass that prints out the options and the graph */ +MXReturnValue myPass(mxnet::ext::Graph *g, + const std::unordered_map& options) { for (auto kv : options) { std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl; } - - *out_graph = new std::string(in_graph); + g->print(); return MX_SUCCESS; } REGISTER_PASS(myPass) .setBody(myPass); -/* \brief a basic pass that parses the input string to JSON and then dumps it back */ -MXReturnValue jsonPass(const std::string& in_graph, const std::string** out_graph, - const std::unordered_map& options, - const std::unordered_map& args, - const std::unordered_map& aux, - const PassResource& res) { - for (auto kv : options) - std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl; - - // add test arg/aux - - MXTensor* arg_ = res.alloc_arg("test_arg",{1},MXContext::CPU(0),kFloat32); - MXTensor* aux_ = res.alloc_aux("test_aux",{1},MXContext::CPU(0),kFloat32); - - // convert json string to json object - JsonParser parser; - JsonVal json_val = parser.parse_to_json(in_graph); - - // get nodes list - JsonVal nodes = json_val.map[JsonVal("nodes")]; - - // loop over nodes - for(int i=0; i= 10700) { std::cout << "MXNet version " << version << " supported" << std::endl; return MX_SUCCESS; } else { - std::cout << "MXNet version " << version << " not supported" << std::endl; + MX_ERROR_MSG << "MXNet version " << version << " not supported" << std::endl; return MX_FAIL; } } diff --git a/example/extensions/lib_pass/test_pass.py b/example/extensions/lib_pass/test_pass.py index 8930c9478152..01d6eddac67b 100644 --- a/example/extensions/lib_pass/test_pass.py +++ b/example/extensions/lib_pass/test_pass.py @@ -51,6 +51,7 @@ def test_model(pass_name): # execute in MXNet print('-------------------------------') print('Testing regular MXNet execution') + exe = sym.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) out = exe.forward() print(out) @@ -95,4 +96,3 @@ def test_model(pass_name): sym_block2.export('modified') test_model('myPass') -test_model('jsonPass') diff --git a/example/extensions/lib_subgraph/README.md b/example/extensions/lib_subgraph/README.md index 6644a1fdc8ff..2752d27a67f4 100644 --- a/example/extensions/lib_subgraph/README.md +++ b/example/extensions/lib_subgraph/README.md @@ -38,11 +38,16 @@ You can start getting familiar with custom partitioners by running an example pr 2. Run `python test_subgraph.py`. It’ll first load the above library, find the components, register them in the MXNet backend, then partition the model and execute the operators like a regular MXNet operator and output the result. Below is the output when running the `python test_subgraph.py` command. Notice that it loads 2 operators: my_gemm and state_gemm. ``` -[10:38:03] src/c_api/c_api.cc:286: Found 1 operators in library -[10:38:03] src/c_api/c_api.cc:350: Op[0] _custom_subgraph_op -[10:38:03] src/c_api/c_api.cc:785: Found 1 partitioners in library -[10:38:03] src/c_api/c_api.cc:801: Partitioner[0] myProp -[10:38:03] src/c_api/c_api.cc:821: Strategy[0] strategy1 subgraphOp: '_custom_subgraph_op' +[02:01:18] src/c_api/c_api.cc:515: Found 1 operators in library +[02:01:18] src/c_api/c_api.cc:580: Op[0] _custom_subgraph_op +[02:01:18] src/c_api/c_api.cc:581: isSubgraphOp +[02:01:18] src/c_api/c_api.cc:1121: Found 2 partitioners in library +[02:01:18] src/c_api/c_api.cc:1137: Partitioner[0] myProp +[02:01:18] src/c_api/c_api.cc:1159: Strategy[0] strategy1 subgraphOp: '_custom_subgraph_op' +[02:01:18] src/c_api/c_api.cc:1137: Partitioner[1] mySelect +[02:01:18] src/c_api/c_api.cc:1159: Strategy[0] strategy1 subgraphOp: '_custom_subgraph_op' +[02:01:18] src/c_api/c_api.cc:1182: Found 1 graph passes in library +[02:01:18] src/c_api/c_api.cc:1197: Graph Pass [0] addInputPass ``` ### Basic Files For Custom Partitioner Library @@ -91,38 +96,39 @@ In the Gluon hybridize flow, the model is actually hybridized during the first i ### Using a Custom Partitioner Library -Partitioning APIs in MXNet are available in both Symbol and Gluon APIs. For the Symbol API, the `optimize_for` API can be called on Symbol objects to return a partitioned Symbol. +Partitioning APIs in MXNet are available in both Symbol and Gluon APIs. For the Symbol API, `optimize_for` can be called on Symbol objects to return a partitioned Symbol. -``` -optimize_for(backend, args=None, aux=None, ctx=None, **kwargs) +```python +sym.optimize_for(backend, args=None, aux=None, ctx=None, **kwargs) ``` The `optimize_for` API takes at least 1 argument, `backend` which is a string that identifies which backend to partition the model for. The `args` and `aux` arguments are optional and take a list of NDArray or dict of str to NDArray. They are used to infer shapes and types and before partitioning, and passed to the backend to use during compilation. The `ctx` argument is optional and takes a device context to infer storage types. It also takes any other user-specified options that will be passed to the backend partitioning APIs. -For the Gluon API, the `hybridize` API can be called on HybridBlocks to partition the internal CachedOp Symbol. +For the Gluon API, `hybridize` can be called on HybridBlocks to partition the internal CachedOp Symbol. -``` -hybridize(backend=None, backend_opts=None, **kwargs) +```python +block.hybridize(backend=None, backend_opts=None, clear=True, **kwargs) ``` -The `hybridize` function prepares the HybridBlock to be converted into a backend symbol. The `backend` argument is a string that identifies which backend that will partition the model. The `backend_opts` takes other user-specified options that will be passed to the backend partitioning APIs. The actual partitioning takes place during the forward pass. +The `hybridize` function prepares the HybridBlock to be converted into a backend symbol. The `backend` argument is a string that identifies which backend that will partition the model. The `backend_opts` are other user-specified options (as a Python dictionary of strings mapped to strings) that will be passed to the backend partitioning APIs. The `clear` argument defaults to `True` and clears any previous optimizations done on the block. If you want to chain optimizations together, set `clear` to `False`. The actual partitioning takes place during the forward pass. If you want to use `hybridize` to chain multiple optimizations, be sure to execute a forward pass after each call to `hybridize`. If you just want to partition the HybridBlock but not run a complete forward pass, you can use the `optimize_for` API that combines the work done in the `hybridize` API with part of the work done in the forward pass. -``` -optimize_for(x, backend=None, backend_opts=None, **kwargs) +```python +block.optimize_for(x, backend=None, backend_opts=None, clear=True, **kwargs) ``` -When the `optimize_for` API is called on a HybridBlock it partitions immediately. This lets users export the partitioned model without running a complete forward pass. +When the `optimize_for` API is called on a HybridBlock it partitions immediately. This lets users export the partitioned model without running a complete forward pass. Chaining multiple optimizations is as simple as calling `optimize_for` multiple times, no need to execute a forward pass (as opposed to `hybridize`). -``` +```python block.optimize_for(x, backend='myPart') +block.optimize_for(x, backend='myOtherPart', clear=False) block.export('partitioned') ``` But you can also use `optimize_for` in place of `hybridize` and run inference immediately after too. -``` +```python block.optimize_for(x, backend='myPart') block(x) ``` @@ -133,44 +139,105 @@ There are several essential building blocks for making a custom partitioner: * [initialize](./subgraph_lib.cc#L261): * This function is the library initialization function necessary for any dynamic libraries. It lets you check if the user is using a compatible version of MXNet. Note that this `version` parameter is passed from MXNet when library is loaded. - +```c++ MXReturnValue initialize(int version) - +``` * [supportedOps](./subgraph_lib.cc#L179): - * This function provides a copy of the model graph as a JSON string, and provides an interface for identifying which operators should be partitioned into a subgraph. Also this is where a custom partitioner can validate the options specified by the user. - + * This function provides a copy of the model Graph, and an interface for identifying which operators should be partitioned into a subgraph. Also this is where a custom partitioner can validate the options specified by the user. +```c++ MXReturnValue supportedOps( - std::string json, - std::vector& ids, - std::unordered_map& options) - + const mxnet::ext::Graph* graph, + std::vector* ids, + const std::unordered_map& options) +``` * [REGISTER_PARTITIONER(my_part_name)](./subgraph_lib.cc#L257): - * This macro registers the custom partitioner and its properties to MXNet by its name. Notice that a partitioner can have multiple partitioning strategies. This enables multiple *passes* to be run in a single partitioning call from the user. The first argument to `addStrategy` is a user-specified name. The second argument is the `supportedOps` function. The third argument is the name of the subgraph operator to create for each subgraph created during partitioning (see below for more info about subgraph operators). The `setReviewSubgraph` API registers a callback function that is called for each subgraph created during partitioning (more on this below). Notice that the first argument to this function is the strategy to associate with and the second argument is the `reviewSubgraph` function. - + * This macro registers the custom partitioner and its properties to MXNet by its name. Notice that a partitioner can have multiple partitioning strategies. This enables multiple *passes* to be run in a single partitioning call from the user. The first argument to `addStrategy` is a user-specified name. The second argument is the name of the subgraph operator to create for each subgraph created during partitioning (see below for more info about subgraph operators). The `setSupportedOps` API registers the `supportedOps` function. The `setReviewSubgraph` API registers a callback function that is called for each subgraph created during partitioning (more on this below). Notice that the first argument to this function is the strategy to associate with and the second argument is the `reviewSubgraph` function. +```c++ REGISTER_PARTITIONER(my_part_name) - .addStrategy("strategy1", supportedOps, "_custom_subgraph_op") + .addStrategy("strategy1", "_custom_subgraph_op") + .setSupportedOps("strategy1", supportedOps) .setReviewSubgraph("strategy1", reviewSubgraph); - - +``` Also there are some optional functions you can specify: * [reviewSubgraph](./subgraph_lib.cc#L219): * This function provides an opportunity to accept/reject a subgraph after MXNet partitions it. It also allows specifying custom attributes on the subgraph (ie. user-generated IDs). If you do not register this function, subgraphs will be accepted by default. - +```c++ MXReturnValue reviewSubgraph( - std::string json, + const mxnet::ext::Graph* subgraph, int subgraph_id, bool* accept, - std::unordered_map& options, - std::unordered_map& attrs, - std::map& args, - std::map& aux) - + const std::unordered_map& options) +``` Let’s take a closer look at those registry functions: -* **supportedOps**: This function takes four arguments. The 1st argument is a JSON string of the model architecture graph, where nodes are inputs/params/weights and edges are data dependencies. The graph is pre-sorted in topological order. The 2nd argument is an array of booleans, one for each operator in the model. When traversing the graph, operators to be partitioned into subgraphs are identified and an entry is set to `true` for the index in the `ids` array corresponding to the node ID. The last argument is the map of options specified by the user. Users can pass custom options to the partitioner and they are passed to this function in the `options` map. +* **supportedOps**: This function takes 3 arguments. The 1st argument is the model architecture graph, where nodes are inputs/params/weights and edges are data dependencies. The graph is pre-sorted in topological order. The 2nd argument is an array of integers, one for each operator in the model. When traversing the graph, operators to be partitioned into subgraphs are identified and an entry is set to a value for the index in the `ids` array corresponding to the node ID. Setting a non-negative value (ie. [0, MAX_INT]) indicates the operator should be partitioned into that specific subgraph. Setting a value of -1 indicates that the operator can be partitioned into any subgraph. The last argument is the map of options specified by the user. Users can pass custom options to the partitioner and they are passed to this function in the `options` map. + +* **reviewSubgraph**: This function takes four arguments. The 1st argument is the newly partitioned subgraph. The 2nd argument is the subgraph ID, this is just a number MXNet uses to identify this particular subgraph (it starts at zero and increments, unique for each subgraph in the model). The 3rd argument is an output to be set in this function to tell MXNet whether to accept (value: `true`) or reject (value: `false`) the subgraph. You might want to reject a subgraph if it doesnt include all the operators you want, for example. The `options` map is the same one passed to the `supportedOps` API. The 4th argument is the map of options specified by the user. Any custom attributes set on the Graph object will be available later at runtime, and provides a mechanisn to pass info from partition-time to runtime. For inputs to the subgraph that come directly from the params/weights of the model, you can access the raw tensor data directly from that node in the graph. + +### Writing a Custom Selector +Instead of implementing the `supportedOps` API, you can choose to implement a custom selector class for more control over partitioning instead. -* **reviewSubgraph**: This function takes five arguments. The 1st argument is a JSON string of the newly partitioned subgraph. The 2nd argument is the subgraph ID, this is just a number MXNet uses to identify this particular subgraph (it starts at zero and increments, unique for each subgraph in the model). The 3rd argument is an output to be set in this function to tell MXNet whether to accept (value: `true`) or reject (value: `false`) the subgraph. You might want to reject a subgraph if it doesnt include all the operators you want, for example. The `options` map is the same one passed to the `supportedOps` API. The 4th argument is the map of options specified by the user. The 5th argument is a map of attributes that should be set on the created subgraph. These attributes will be available later at runtime, and provides a mechanisn to pass info from partition-time to runtime. The last argument is the map of params/weights/args to the model and the associated names. For inputs the the subgraph that come directly from the params/weights of the model, you can look up the name of the input in this map to get the actual tensor values. +* [createSelector](./subgraph_lib.cc#L321): + * This function provides a copy of the model graph as the first argument. The 2nd argument is a placeholder for CustomOpSelector object. You must define a class that inherits from the `CustomOpSelector` class and override the required functions. Then you need to create an instance of your class and assign it to the placeholder. The last argument is a map of user-specified options. +```c++ + MXReturnValue createSelector( + const mxnet::ext::Graph *graph, + CustomOpSelector** sel_inst, + const std::unordered_map& options) +``` +Instead of registering a `supportedOps` API, register the `setCreateSelector` API. +```c++ + REGISTER_PARTITIONER(my_part_name) + .addStrategy("strategy1", "_custom_subgraph_op") + .setCreateSelector("strategy1", createSelector) + .setReviewSubgraph("strategy1", reviewSubgraph); +``` +When implementing your own selector class, you must inherit from the `CustomOpSelector` class and define the following APIs: +* [Select](./subgraph_lib.cc#L301): + * This function selects a node to include in a subgraph by the index of the node (`nodeID`) in the graph. Return `true` to include this node or `false` to reject this node. +```c++ + bool Select( + int nodeID) +``` +* [SelectInput](./subgraph_lib.cc#L304): + * This function grows the subgraph from a node (`nodeID`) to a node that produces one of its inputs (`input_nodeID`). Return `true` to include this node (`input_nodeID`) or `false` to reject this node. +```c++ + bool SelectInput( + int nodeID, + int input_nodeID) +``` +* [SelectOutput](./subgraph_lib.cc#L304): + * This function grows the subgraph from a node (`nodeID`) to a node that consumes one of its outputs (`output_nodeID`). Return `true` to include this node (`output_nodeID`) or `false` to reject this node. +```c++ + bool SelectOutput( + int nodeID, + int output_nodeID) +``` +All of these APIs refer to the model's graph that is provided to the `createSelector` API. When you implement your custom `createSelector` function, you can pass the graph and options to the constructor of your class like this: +```c++ +MXReturnValue myCreateSelector(const mxnet::ext::Graph *graph, + CustomOpSelector** sel_inst, + const std::unordered_map& options) { + *sel_inst = new MySelector(graph, options); + return MX_SUCCESS; +} +``` +In addition to the 3 required APIs shown above, you can also implement the following optional APIs for your `CustomOpSelector` class: +* [Filter](./subgraph_lib.cc#L310): + * This function enables reviewing the candidate nodes to include in subgraph. The `candidates` are the indices of nodes in the graph to be included in the subgraph. The 2nd argument `keep` is an empty vector to be filled with the indices of nodes you wish to keep in the subgraph. Any remaining candidate nodes not added to `keep` will be excluded from the subgraph. The following function body shows the default behavior when not overloaded, to keep all candidates: +```c++ + void Filter( + std::vector& candidates, + std::vector& keep) { + keep.insert(keep.end(), candidates.begin(), candidates.end()); + } +``` +* [Reset](./subgraph_lib.cc#L314): + * This function provides an opportunity to reset any selector state between subgraphs. It is called after growing subgraph, and before `Filter`. There is no default behavior. +```c++ + virtual void Reset() {} +``` ### Writing A Custom Subgraph Operator @@ -178,19 +245,31 @@ A partitioning strategy specifies how to partition a model and isolate operators When registering a custom subgraph operator, all thats needed is to register a `createOpState` function and to set that the operator is a subgraph operator by calling the `setIsSubgraphOp` API like: -``` +```c++ REGISTER_OP(my_subgraph_op) .setIsSubgraphOp() .setCreateOpState(createOpState, "cpu"); ``` +### Converting a JSON string encoded graph + +A Graph object can be created from a JSON string containing a graph/subgraph like: + +```c++ +mxnet::ext::Graph* g = mxnet::ext::Graph::fromString(json); +``` + +It can be converted back to a JSON string just as easily: +```c++ +std::string json = g->toString(); +``` + ### Parsing a JSON string To simplify custom partitioner libraries, basic JSON parsing utility functions have been implemented in the `lib_api.h` header file. You create a `JsonParser` object and parse the string by calling the `parse_to_json` API like: ```c++ -JsonParser parser; -JsonVal json_val = parser.parse_to_json(json_string); +JsonVal json_val = JsonVal::parse(json); ``` A `JsonVal` is a class that represents the nodes in a JSON structure. You can check the type of a node (num, str, list, or map) by comparing the `JsonVal.type` to `STR`, `NUM`, `LIST`, or `MAP`. Then you can get that value from the node like: @@ -214,4 +293,4 @@ switch(json_val.type) { } ``` -There are also convenience constructors for creating `JsonVal` objects for strings and numbers like `JsonVal("myKey")` or `JsonVal(42)`. This makes it easy to get specific keys from a map like `json_val.map[JsonVal("nodes")]`. \ No newline at end of file +You call the `dump` function on a `JsonVal` object like `json_val.dump()` to get a JSON-compatible string. There are also convenience constructors for creating `JsonVal` objects for strings and numbers like `JsonVal("myKey")` or `JsonVal(42)`. This makes it easy to get specific keys from a map like `json_val.map[JsonVal("nodes")]`. diff --git a/example/extensions/lib_subgraph/subgraph_lib.cc b/example/extensions/lib_subgraph/subgraph_lib.cc index 28442078ebe6..2f954e092152 100644 --- a/example/extensions/lib_subgraph/subgraph_lib.cc +++ b/example/extensions/lib_subgraph/subgraph_lib.cc @@ -28,19 +28,21 @@ #include #include "lib_api.h" +using namespace mxnet::ext; + /* function to execute log operator on floats */ -void myLog(MXTensor &in, MXTensor &out) { - float* inp = in.data(); - float* outp = out.data(); - for (int64_t i = 0; i < in.size(); i++) { +void myLog(MXTensor *in, MXTensor *out) { + float* inp = in->data(); + float* outp = out->data(); + for (int64_t i = 0; i < in->size(); i++) { outp[i] = logf(inp[i]); } } /* function to execute exp operator on floats */ -void myExp(MXTensor &in, MXTensor &out) { - float* inp = in.data(); - float* outp =out.data(); - for (int64_t i = 0; i < in.size(); i++) { +void myExp(MXTensor *in, MXTensor *out) { + float* inp = in->data(); + float* outp =out->data(); + for (int64_t i = 0; i < in->size(); i++) { outp[i] = expf(inp[i]); } } @@ -52,15 +54,10 @@ void myExp(MXTensor &in, MXTensor &out) { */ MXReturnValue myExecutor(std::vector* inputs, std::vector* outputs, - const std::string& subgraph_sym) { - std::cout << "Info: subgraph symbol is: " << std::endl; - std::cout << subgraph_sym << std::endl; + mxnet::ext::Graph *subgraph) { + std::cout << "Info: subgraph is: " << std::endl; + subgraph->print(); - // convert json string to json object - JsonParser parser; - JsonVal json_val = parser.parse_to_json(subgraph_sym); - // get nodes list - JsonVal nodes = json_val.map[JsonVal("nodes")]; //counter for inputs int input_cnt = 0; // temporary tensor storage @@ -69,41 +66,40 @@ MXReturnValue myExecutor(std::vector* inputs, std::vector to_free; // loop over nodes - for(int i=0; isize(); i++) { + mxnet::ext::Node* node = subgraph->getNode(i); // handle each op type - if (op.compare("null") == 0) { - // null is an input data to the subgraph, add to data storage - data.push_back(inputs->at(input_cnt++)); - } else if (op.compare("log") == 0) { + if (node->op.compare("null") == 0) { + // set tensor for this input to the subgraph + node->tensor = &inputs->at(input_cnt++); + } else if (node->op.compare("log") == 0) { // get input tensor based on node ID inputs from data storage - MXTensor &input = data[node_inputs.list[0].list[0].num]; + MXTensor *input = node->inputs.at(0).node->tensor; // create temporary storage - MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, MXContext::CPU(0), kDefaultStorage); + MXTensor tmp(malloc(input->size()*4), input->shape, input->dtype, 0, MXContext::CPU(0), kDefaultStorage); // NOLINT // save allocated ptr to free later to_free.push_back(tmp.data_ptr); // execute log operator - myLog(input,tmp); + myLog(input,&tmp); // add output tensor to data storage data.push_back(tmp); - } else if (op.compare("exp") == 0) { + // set tensor for this node so we can read it later + node->tensor = &data.back(); + } else if (node->op.compare("exp") == 0) { // get input tensor based on node ID inputs from data storage - MXTensor &input = data[node_inputs.list[0].list[0].num]; + MXTensor *input = node->inputs.at(0).node->tensor; // create temporary storage - MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, MXContext::CPU(0), kDefaultStorage); + MXTensor tmp(malloc(input->size()*4), input->shape, input->dtype, 0, MXContext::CPU(0), kDefaultStorage); // NOLINT // save allocated ptr to free later to_free.push_back(tmp.data_ptr); // execute exp operator - myExp(input,tmp); + myExp(input,&tmp); // add output tensor to data storage data.push_back(tmp); + // set tensor for this node so we can read it later + node->tensor = &data.back(); } else { - std::cout << "Error! Unsupported op '" << op << "' found in myExecutor"; + MX_ERROR_MSG << "Error! Unsupported op '" << node->op << "' found in myExecutor"; // free allocated temporary storage for (void* ptr : to_free) free(ptr); @@ -111,18 +107,16 @@ MXReturnValue myExecutor(std::vector* inputs, } } - // get list of outputs from subgraph - JsonVal heads = json_val.map[JsonVal("heads")]; // copy all operator results to outputs of subgraph - for (int j = 0; j < heads.list.size(); j++) { + for (int j = 0; j < subgraph->outputs.size(); j++) { // get computed result - MXTensor &result = data[heads.list[0].list[0].num]; + MXTensor *result = subgraph->outputs[j].node->tensor; // get output tensor to pass to MX MXTensor &out = outputs->at(j); float *out_data = out.data(); - float *res_data = result.data(); + float *res_data = result->data(); // loop and copy data - for (int64_t i = 0; i < result.size(); i++) { + for (int64_t i = 0; i < result->size(); i++) { out_data[i] = res_data[i]; } } @@ -137,22 +131,26 @@ MXReturnValue myExecutor(std::vector* inputs, class MyStatefulOp : public CustomStatefulOp { public: - explicit MyStatefulOp(const std::string& sym, + explicit MyStatefulOp(std::string json, const std::unordered_map& attrs) - : subgraph_sym(sym), attrs_(attrs) { - for (auto kv : attrs) { + : attrs_(attrs) { + for (const auto &kv : attrs) { std::cout << "subgraphOp attributes: " << kv.first << " ==> " << kv.second << std::endl; } + subgraph_ = mxnet::ext::Graph::fromString(json); } MXReturnValue Forward(std::vector* inputs, std::vector* outputs, - const OpResource& op_res) { - return myExecutor(inputs, outputs, subgraph_sym); + const OpResource& op_res) override { + if(attrs_.count(MX_STR_EXTRA_INPUTS) > 0 && std::stoi(attrs_.at(MX_STR_EXTRA_INPUTS)) > 0) + std::cout << "forward::extra_inputs(" << attrs_.at(MX_STR_EXTRA_INPUTS) << ")::inputs [" + << inputs->size() << "]" << std::endl; + return myExecutor(inputs, outputs, subgraph_); } private: - const std::string subgraph_sym; + mxnet::ext::Graph *subgraph_; const std::unordered_map attrs_; }; @@ -176,39 +174,30 @@ REGISTER_OP(_custom_subgraph_op) const std::vector op_names({"exp","log"}); -MXReturnValue mySupportedOps(const std::string& json, +MXReturnValue mySupportedOps(const mxnet::ext::Graph* graph, std::vector* ids, const std::unordered_map& options) { for (auto kv : options) { std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl; } - //convert json string to json object - JsonParser parser; - JsonVal json_val = parser.parse_to_json(json); - //get nodes list - JsonVal nodes = json_val.map[JsonVal("nodes")]; //loop over nodes - for(int i=0; isize(); i++) { + const mxnet::ext::Node *node = graph->getNode(i); //get shape/type if available std::string shape; int dtype = -1; - if(node.map.find(JsonVal("attrs")) != node.map.end()) { - JsonVal attrs = node.map[JsonVal("attrs")]; - if(attrs.map.find(JsonVal("shape")) != attrs.map.end()) - shape = attrs.map[JsonVal("shape")].str; - if(attrs.map.find(JsonVal("dtype")) != attrs.map.end()) - dtype = std::stoi(attrs.map[JsonVal("dtype")].str); - } + if(node->attrs.count("shape") > 0) + shape = node->attrs.at("shape"); + if(node->attrs.count("dtype") > 0) + dtype = std::stoi(node->attrs.at("dtype")); //check if op dtype is float, and if option was specified to require float types if((dtype == kFloat32 && options.count("reqFloat") > 0) || options.count("reqFloat") == 0) { - //check if op is in whitelist - if(std::find(op_names.begin(),op_names.end(),op.str.c_str()) != op_names.end()) { - // found op in whitelist, set value to -1 to include op in any subgraph + //check if op is in allowlist + if(std::find(op_names.begin(),op_names.end(),node->op.c_str()) != op_names.end()) { + // found op in allowlist, set value to -1 to include op in any subgraph ids->at(i) = -1; } } @@ -216,30 +205,11 @@ MXReturnValue mySupportedOps(const std::string& json, return MX_SUCCESS; } -MXReturnValue myReviewSubgraph(const std::string& json, int subgraph_id, bool* accept, - const std::unordered_map& options, - std::unordered_map* attrs, - const std::unordered_map& args, - const std::unordered_map& aux) { +MXReturnValue myReviewSubgraph(const mxnet::ext::Graph *subgraph, int subgraph_id, bool* accept, + const std::unordered_map& options) { for (auto kv : options) { std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl; } - for (auto kv : args) { - std::cout << "arg: " << kv.first << " ==> ("; - for (auto s : kv.second.shape) - std::cout << s << ","; - std::cout << ") ["; - for (int i=0; i()[i] << ", "; - std::cout << "]" << std::endl; - } - - // check if option `reqArgs` was specified, and if so check if args were provided - if(options.count("reqArgs") > 0 && args.size() == 0) { - *accept = false; - std::cout << "rejecting subgraph since args were not provided" << std::endl; - return MX_SUCCESS; - } // check if option `reject` was specified, and if so check if value is 'True' if(options.count("reject") > 0 && options.at("reject").compare("True") == 0) { @@ -249,7 +219,6 @@ MXReturnValue myReviewSubgraph(const std::string& json, int subgraph_id, bool* a } else { *accept = true; std::cout << "accepting subgraph" << std::endl; - attrs->insert(std::pair("myKey","myVal")); } return MX_SUCCESS; } @@ -261,39 +230,30 @@ REGISTER_PARTITIONER(myProp) class MySelector : public CustomOpSelector { public: - MySelector(const std::string& json, + MySelector(const mxnet::ext::Graph *graph, const std::unordered_map& options) : - graph_json(json), options_(options) { + graph_(graph), options_(options) { for (auto kv : options) { std::cout << "selector options: " << kv.first << " ==> " << kv.second << std::endl; } - //convert json string to json object - JsonParser parser; - JsonVal json_val = parser.parse_to_json(json); - //get nodes list - nodes = json_val.map[JsonVal("nodes")]; } bool chooseNode(int nodeID) { - JsonVal node = nodes.list[nodeID]; - JsonVal op = node.map[JsonVal("op")]; + const mxnet::ext::Node *node = graph_->getNode(nodeID); //get shape/type if available std::string shape; int dtype = -1; - if(node.map.find(JsonVal("attrs")) != node.map.end()) { - JsonVal attrs = node.map[JsonVal("attrs")]; - if(attrs.map.find(JsonVal("shape")) != attrs.map.end()) - shape = attrs.map[JsonVal("shape")].str; - if(attrs.map.find(JsonVal("dtype")) != attrs.map.end()) - dtype = std::stoi(attrs.map[JsonVal("dtype")].str); - } + if(node->attrs.count("shape") > 0) + shape = node->attrs.at("shape"); + if(node->attrs.count("dtype") > 0) + dtype = std::stoi(node->attrs.at("dtype")); //check if op dtype is float, and if option was specified to require float types if((dtype == kFloat32 && options_.count("reqFloat") > 0) || options_.count("reqFloat") == 0) { - //check if op is in whitelist - if(std::find(op_names.begin(),op_names.end(),op.str.c_str()) != op_names.end()) { - // found op in whitelist, return true to include op subgraph + //check if op is in allowlist + if(std::find(op_names.begin(),op_names.end(),node->op.c_str()) != op_names.end()) { + // found op in allowlist, return true to include op subgraph return true; } } @@ -314,14 +274,13 @@ class MySelector : public CustomOpSelector { } virtual void Reset() {} private: - std::string graph_json; - JsonVal nodes; + const mxnet::ext::Graph *graph_; const std::unordered_map options_; }; -MXReturnValue createSelector(const std::string& json, CustomOpSelector** sel_inst, +MXReturnValue createSelector(const mxnet::ext::Graph *graph, CustomOpSelector** sel_inst, const std::unordered_map& options) { - *sel_inst = new MySelector(json, options); + *sel_inst = new MySelector(graph, options); std::cout << "Info: selector created" << std::endl; return MX_SUCCESS; } @@ -331,12 +290,41 @@ REGISTER_PARTITIONER(mySelect) .setCreateSelector("strategy1", createSelector) .setReviewSubgraph("strategy1", myReviewSubgraph); +/* \brief a basic pass that adds a new input for subgraph ops */ +MXReturnValue addInputPass(mxnet::ext::Graph *graph, + const std::unordered_map& options) { + //find node with '_custom_subgraph_op' op type + for(int i=0; isize(); i++) { + mxnet::ext::Node* n = graph->getNode(i); + if(n->op.compare("_custom_subgraph_op") == 0) { + //set extra input + n->attrs[MX_STR_EXTRA_INPUTS] = std::to_string(1); + + //create a new input Node + Node* input = graph->addNode(n->name + "_input", "null"); + //set this node as an input in the graph + graph->inputs.push_back(input); + //connect new input to node + input->outputs.push_back({n,(int)(n->inputs.size())}); + //connect node to new input + n->inputs.push_back({input,0}); + // add a corresponding tensor for this input + input->alloc_arg({1},MXContext::CPU(0),kFloat32); + } + } + + return MX_SUCCESS; +} + +REGISTER_PASS(addInputPass) +.setBody(addInputPass); + MXReturnValue initialize(int version) { if (version >= 10700) { std::cout << "MXNet version " << version << " supported" << std::endl; return MX_SUCCESS; } else { - std::cout << "MXNet version " << version << " not supported" << std::endl; + MX_ERROR_MSG << "MXNet version " << version << " not supported by custom library" << std::endl; return MX_FAIL; } } diff --git a/example/extensions/lib_subgraph/test_subgraph.py b/example/extensions/lib_subgraph/test_subgraph.py index eb7102a1511c..267a417d92f2 100644 --- a/example/extensions/lib_subgraph/test_subgraph.py +++ b/example/extensions/lib_subgraph/test_subgraph.py @@ -93,8 +93,8 @@ def test(backend): sym_block = nn.SymbolBlock(sym, inputs) sym_block.initialize() sym_block.hybridize(backend=backend) - out4 = sym_block(mx.nd.ones((3,2)),mx.nd.ones((3,2))) - print(out4) + out2 = sym_block(mx.nd.ones((3,2)),mx.nd.ones((3,2))) + print(out2) # Gluon Hybridize partitioning with shapes/types without inference print('-------------------------------') @@ -105,6 +105,13 @@ def test(backend): sym_block2.optimize_for(mx.nd.ones((3,2)), mx.nd.ones((3,2)), backend=backend) sym_block2.export('partitioned') + # Test with additional input to subgraph op + print('-------------------------------') + print('Testing %s Gluon Hybridize partitioning with extra input' % backend) + sym_block2.optimize_for(mx.nd.ones((3,2)), mx.nd.ones((3,2)), backend="addInputPass", clear=False) + out3 = sym_block2(mx.nd.ones((3,2)),mx.nd.ones((3,2))) + print(out3) + ############################################### # Test with subgraph directly consuming params ############################################### diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h index 5674733d250a..3367bc661c12 100644 --- a/include/mxnet/lib_api.h +++ b/include/mxnet/lib_api.h @@ -37,12 +37,15 @@ #include #include #include +#include #include #include #include #include #include +#include #include +#include #if defined(__NVCC__) #include @@ -210,6 +213,9 @@ extern "C" { #endif #endif +namespace mxnet { +namespace ext { + /*! * \brief Tensor data type, consistent with mshadow data type */ @@ -462,12 +468,14 @@ typedef std::mt19937 mx_cpu_rand_t; #define MX_NUM_CPU_RANDOM_STATES 1024 #define MX_NUM_GPU_RANDOM_STATES 32768 +/* \brief Class to help allocate new args/aux params in graph passes */ class PassResource { public: PassResource(std::unordered_map* new_args, std::unordered_map* new_aux, nd_malloc_t nd_malloc, const void* nd_alloc) : new_args_(new_args), new_aux_(new_aux), nd_malloc_(nd_malloc), nd_alloc_(nd_alloc) {} + // allocate new arg param, adds to args map, returns newly allocated tensor MXTensor* alloc_arg(const std::string& name, const std::vector& shapes, const MXContext &ctx, MXDType dtype) const { void* data; @@ -477,6 +485,7 @@ class PassResource { (*new_args_)[name] = tensor; return &(new_args_->at(name)); } + // allocate new aux param, adds to aux map, returns newly allocated tensor MXTensor* alloc_aux(const std::string& name, const std::vector& shapes, const MXContext &ctx, MXDType dtype) const { void* data; @@ -557,10 +566,14 @@ class OpResource { void *rand_cpu_states, *rand_gpu_states; }; -/*! \brief Macro to help passing serialized subgraph through attribute dict */ +/*! \brief attribute key to help passing serialized subgraph through subgraph op attribute */ #define MX_STR_SUBGRAPH_SYM_JSON "subgraph_sym_json" +/*! \brief dtype attribute key for ops after type propagation */ #define MX_STR_DTYPE "__ext_dtype__" +/*! \brief shape attribute key for ops after shape propagation */ #define MX_STR_SHAPE "__ext_shape__" +/*! \brief extra input attribute key for ops */ +#define MX_STR_EXTRA_INPUTS "__ext_extra_inputs__" /* \brief get shape value from list of shapes string * @@ -638,52 +651,50 @@ struct JsonVal { } return type < o.type; } - JsonType type; - int num; - std::string str; - std::vector list; - std::map map; -}; -/*! \brief functions used for parsing JSON */ -struct JsonParser { - JsonVal parse_to_json(const std::string& json) { - unsigned int idx = 0; - return parse(json, &idx); - } - void print_json_val(const JsonVal& val) { - std::cout << json_val_string(val) << std::endl; - } - // debug function to dump data structure to string - std::string json_val_string(const JsonVal &val) { + // convert JSON object back to JSON-compatible string + std::string dump() const { std::string ret; - switch (val.type) { + switch (type) { case ERR: ret = "json(Error)"; break; case STR: - ret = "json(STR:" + val.str + ")"; + ret = "\"" + str + "\""; break; case NUM: - ret = "json(INT:" + val.str + ")"; + ret = str; break; case LIST: - ret = "json(LIST:["; - for (auto &item : val.list) - ret += json_val_string(item) + ","; - ret += "])"; + ret = "["; + for (unsigned i=0; i < list.size(); i++) { + auto &item = list[i]; + ret += item.dump(); + if (i < list.size()-1) + ret += ","; + } + ret += "]"; break; case MAP: - ret = "json(MAP:{"; - for (auto &item : val.map) - ret += json_val_string(item.first) + " : " + json_val_string(item.second) + ","; - ret += "})"; + ret = "{"; + unsigned cnt = 0; + for (auto &item : map) { + ret += item.first.dump() + " : " + item.second.dump(); + if (cnt++ < map.size()-1) + ret += ","; + } + ret += "}"; break; } return ret; } + // convert JSON-compatible string to JSON object + static JsonVal parse(const std::string& json) { + unsigned int idx = 0; + return JsonVal::parse(json, &idx); + } // parse a string JSON object - JsonVal parse_string(const std::string& json, unsigned int* idx) { + static JsonVal parse_string(const std::string& json, unsigned int* idx) { JsonVal ret(STR); while (*idx < json.size()) { if (json[*idx] == '"') { @@ -698,7 +709,7 @@ struct JsonParser { return JsonVal(); } // parse a number JSON object - JsonVal parse_num(const std::string& json, unsigned int* idx) { + static JsonVal parse_num(const std::string& json, unsigned int* idx) { JsonVal ret(NUM); while (*idx < json.size()) { if (json[*idx] >= '0' && json[*idx] <= '9') { @@ -712,14 +723,14 @@ struct JsonParser { return ret; } // parse a list of JSON objects - JsonVal parse_list(const std::string& json, unsigned int* idx) { + static JsonVal parse_list(const std::string& json, unsigned int* idx) { JsonVal ret(LIST); while (*idx < json.size()) { if (json[*idx] == ']') { ++(*idx); return ret; } else { - JsonVal item = parse(json, idx); + JsonVal item = JsonVal::parse(json, idx); if (item.type != ERR) ret.list.push_back(item); } @@ -728,14 +739,14 @@ struct JsonParser { return JsonVal(); } // parse a map of JSON objects - JsonVal parse_map(const std::string& json, unsigned int* idx) { + static JsonVal parse_map(const std::string& json, unsigned int* idx) { JsonVal ret(MAP), key; while (*idx < json.size()) { if (json[*idx] == '}') { ++(*idx); return ret; } else { - JsonVal item = parse(json, idx); + JsonVal item = JsonVal::parse(json, idx); if (key.type == ERR) { key = item; } else { @@ -748,62 +759,410 @@ struct JsonParser { return JsonVal(); } // generic parse function - JsonVal parse(const std::string& json, unsigned int *idx) { + static JsonVal parse(const std::string& json, unsigned int *idx) { JsonVal ret; while (*idx < json.size()) { if (json[*idx] == '"') { ++(*idx); - ret = parse_string(json, idx); + ret = JsonVal::parse_string(json, idx); } else if (json[*idx] >= '0' && json[*idx] <= '9') { - ret = parse_num(json, idx); + ret = JsonVal::parse_num(json, idx); } else if (json[*idx] == '[') { ++(*idx); - ret = parse_list(json, idx); + ret = JsonVal::parse_list(json, idx); } else if (json[*idx] == '{') { ++(*idx); - ret = parse_map(json, idx); + ret = JsonVal::parse_map(json, idx); } else if (json[*idx] == ']' || json[*idx] == '}') {return ret;} if (ret.type != ERR) return ret; ++(*idx); } return ret; } - // convert JSON object back to JSON-compatible string - std::string dump(const JsonVal &val) { + // debug function to convert data structure to a debugstring + std::string toString() const { std::string ret; - switch (val.type) { + switch (type) { case ERR: ret = "json(Error)"; break; case STR: - ret = "\"" + val.str + "\""; + ret = "json(STR:" + str + ")"; break; case NUM: - ret = val.str; + ret = "json(INT:" + str + ")"; break; case LIST: - ret = "["; - for (unsigned i=0; i < val.list.size(); i++) { - auto &item = val.list[i]; - ret += dump(item); - if (i < val.list.size()-1) - ret += ","; - } - ret += "]"; + ret = "json(LIST:["; + for (auto &item : list) + ret += item.toString() + ","; + ret += "])"; break; case MAP: - ret = "{"; - unsigned cnt = 0; - for (auto &item : val.map) { - ret += dump(item.first) + " : " + dump(item.second); - if (cnt++ < val.map.size()-1) - ret += ","; - } - ret += "}"; + ret = "json(MAP:{"; + for (auto &item : map) + ret += item.first.toString() + " : " + item.second.toString() + ","; + ret += "})"; break; } return ret; } + JsonType type; + int num; + std::string str; + std::vector list; + std::map map; +}; + +/*! + * \brief Graph utility to parse serialized subgraph symbol + */ +class Node; +class Graph; + +// Representation of an input/output to a node +struct NodeEntry { + Node* node; // other node thats producing/consuming inputs/outputs + int entry; // entry index from other node (ie. output index from producing node) +}; + +// Representation of a node in the graph +class Node { + public: + Node() {tensor = nullptr;} + // internally set passResource to enable tensor allocation for graph passes + void _setPassResource(PassResource* res_) {res = res_;} + /* \brief allocate an arg tensor for this node */ + void alloc_arg(const std::vector& shapes, + const MXContext &ctx, MXDType dtype) { + if (!res) + throw std::runtime_error( + "Node not initialized. Cannot use alloc_arg outside of graph passes."); + tensor = res->alloc_arg(name, shapes, ctx, dtype); + } + /* \brief allocate an aux tensor for this node */ + void alloc_aux(const std::vector& shapes, + const MXContext &ctx, MXDType dtype) { + if (!res) + throw std::runtime_error( + "Node not initialized. Cannot use alloc_aux outside of graph passes."); + tensor = res->alloc_aux(name, shapes, ctx, dtype); + } + std::string op; // operator name (ie. Convolution) + std::string name; // unique node name (ie. conv_0 or conv_1) + MXTensor* tensor; // tensor data for input nodes + std::vector inputs; // set of inputs to the node + std::vector outputs; // set of outputs from the node + std::vector subgraphs; // set of subgraphs within this node + std::unordered_map attrs; // node attributes + + private: + PassResource* res; +}; + +// Representation of the graph +class Graph { + public: + Graph() : res(nullptr) {} + /* \brief deleted nodes when deleting the graph */ + ~Graph() { + for (int i = 0; i < nodes.size(); i++) + delete nodes[i]; + } + + /* \brief create a graph object from an unparsed string */ + static Graph* fromString(const std::string& json) { + JsonVal val = JsonVal::parse(json); + return fromJson(val); + } + + /* \brief create a graph object from a parsed JSON object */ + static Graph* fromJson(JsonVal val) { + // get nodes list + JsonVal nodes = val.map[JsonVal("nodes")]; + Graph *g = new Graph(); + + std::map nodeMap; + // loop over nodes + for (int i = 0; i < nodes.list.size(); i++) { + Node* n = new Node(); + g->nodes.push_back(n); + JsonVal node = nodes.list[i]; + + // set the op info + n->op = node.map[JsonVal("op")].str; + n->name = node.map[JsonVal("name")].str; + + // if op is null it is an input to the graph + if (n->op.compare("null") == 0) + g->inputs.push_back(n); + + // set attrs + JsonVal attributes = node.map[JsonVal("attrs")]; + for (auto& kv : attributes.map) { + n->attrs[kv.first.str] = kv.second.str; + } + + // set subgraphs, parsing each into a graph + if (node.map.count(JsonVal("subgraphs")) > 0) { + JsonVal subgraphs = node.map[JsonVal("subgraphs")]; + for (auto &subgraph : subgraphs.list) { + n->subgraphs.push_back(fromJson(subgraph)); + } + } + + // set node inputs + JsonVal node_inputs = node.map[JsonVal("inputs")]; + n->inputs.resize(node_inputs.list.size()); + for (int j = 0; j < node_inputs.list.size(); j++) { + JsonVal input = node_inputs.list[j]; + NodeEntry& entry = n->inputs[j]; + // get pointer to other node + entry.node = nodeMap[input.list[0].num]; + // get the other node's output index + entry.entry = input.list[1].num; + // set other nodes output as connected to this node + entry.node->outputs.push_back({n, j}); + } + nodeMap[i] = n; + } + + // set graph level outputs + JsonVal& heads = val.map[JsonVal("heads")]; + g->outputs.resize(heads.list.size()); + for (int i = 0; i < heads.list.size(); i++) { + JsonVal head = heads.list[i]; + g->outputs[i].node = nodeMap[head.list[0].num]; + g->outputs[i].entry = head.list[1].num; + } + + // add all attributes to the graph + for (auto& kv : val.map) { + if (kv.first.str.compare("nodes") != 0 && + kv.first.str.compare("heads") != 0 && + kv.first.str.compare("node_row_ptr") != 0 && + kv.first.str.compare("arg_nodes") != 0) { + g->attrs[kv.first.str] = kv.second; + } + } + return g; + } + + /* \brief convert graph object back to JSON object */ + JsonVal toJson() { + // top level object is a map + JsonVal val(MAP); + + // add attributes + for (auto& kv : attrs) { + val.map[JsonVal(kv.first)] = kv.second; + } + + // sort graph nodes in topological order, create mapping of node to index + std::map nodeMap; + std::vector sorted = topological_sort(); + // nodes are in reverse topological order in the vector (back is first) + // so loop from end to front over the vector 'sorted' + for (int i = sorted.size()-1; i >= 0; i--) { + nodeMap[sorted[i]] = sorted.size()-1-i; + } + + // create node_row_ptr entry + val.map[JsonVal("node_row_ptr")] = JsonVal(LIST); + JsonVal& node_row_ptr = val.map[JsonVal("node_row_ptr")]; + for (int i = 0; i < nodes.size(); i++) + node_row_ptr.list.push_back(JsonVal(i)); + + // add all input nodes + val.map[JsonVal("arg_nodes")] = JsonVal(LIST); + JsonVal& arg_nodes = val.map[JsonVal("arg_nodes")]; + for (int i = 0; i < inputs.size(); i++) + arg_nodes.list.push_back(JsonVal(nodeMap[inputs[i]])); + + // add all output nodes + val.map[JsonVal("heads")] = JsonVal(LIST); + JsonVal& heads = val.map[JsonVal("heads")]; + for (int i = 0; i < outputs.size(); i++) { + heads.list.push_back(JsonVal(LIST)); + JsonVal& out = heads.list[i]; + out.list.push_back(JsonVal(nodeMap[outputs[i].node])); + out.list.push_back(JsonVal(outputs[i].entry)); + out.list.push_back(JsonVal(0)); + } + + // add all graph nodes + val.map[JsonVal("nodes")] = JsonVal(LIST); + JsonVal& nodes_ = val.map[JsonVal("nodes")]; + for (int i = sorted.size()-1; i >= 0; i--) { + // each node is a map + nodes_.list.push_back(JsonVal(MAP)); + Node* n = sorted[i]; + JsonVal& n_ = nodes_.list[nodes_.list.size()-1]; + + n_.map[JsonVal("op")] = JsonVal(n->op); + n_.map[JsonVal("name")] = JsonVal(n->name); + n_.map[JsonVal("inputs")] = JsonVal(LIST); + + // add inputs for this node + JsonVal& inputs_ = n_.map[JsonVal("inputs")]; + for (int j = 0; j < n->inputs.size(); j++) { + inputs_.list.push_back(JsonVal(LIST)); + NodeEntry& entry = n->inputs[j]; + JsonVal& in = inputs_.list[j]; + in.list.push_back(JsonVal(nodeMap[entry.node])); + in.list.push_back(JsonVal(entry.entry)); + in.list.push_back(JsonVal(0)); + } + + // add subgraphs for this node, convert each back to JSON + if (n->subgraphs.size() > 0) { + n_.map[JsonVal("subgraphs")] = JsonVal(LIST); + JsonVal &subgraphs_ = n_.map[JsonVal("subgraphs")]; + for (Graph *subgraph : n->subgraphs) { + subgraphs_.list.push_back(subgraph->toJson()); + } + } + + // add attributes for this node + n_.map[JsonVal("attrs")] = JsonVal(MAP); + JsonVal& attrs_ = n_.map[JsonVal("attrs")]; + for (auto& kv : n->attrs) { + attrs_.map[JsonVal(kv.first)] = JsonVal(kv.second); + } + } + return val; + } + + /* \brief convert graph object to JSON string */ + std::string toString() { + return toJson().dump(); + } + + /* \brief visits a node "n" */ + void _dfs_util(Node* n, std::unordered_set* to_visit, + std::function handler) const { + to_visit->erase(n); // remove node now that we're visiting it + for (NodeEntry& e : n->outputs) { + Node* o = e.node; + if (to_visit->count(o) != 0) { + _dfs_util(o, to_visit, handler); // visit neighbor + } + } + handler(n); // post-order visit this node + } + + /* \brief post-order DFS graph traversal */ + void DFS(std::function handler) const { + std::unordered_set to_visit; + // put all nodes in set to visit + for (auto& n : nodes) + to_visit.insert(n); + // visit all inputs first + for (auto& i : inputs) + if (to_visit.count(i) != 0) + _dfs_util(i, &to_visit, handler); + // visit any nodes left + while (to_visit.size() > 0) + _dfs_util(*(to_visit.begin()), &to_visit, handler); + } + + /* \brief sort graph nodes in topological order */ + std::vector topological_sort() const { + std::vector sorted; + auto handler = [&](Node* n) { + sorted.push_back(n); // when visiting each node, add it in order to the vector + }; + DFS(handler); + return sorted; + } + + /* \brief print out graph details */ + void print(int indent = 0) const { + std::string space = ""; + for (int i = 0; i < indent; i++) space+=" "; + + std::cout << space << "########### Graph #############" << std::endl; + std::cout << space << "attributes: " << std::endl; + for (auto &kv : attrs) + std::cout << space << "\t" << kv.first << " : " << kv.second.str << std::endl; + std::cout << space << "inputs: " << inputs.size() << std::endl; + std::cout << space << "outputs: " << outputs.size() << std::endl; + std::cout << space << "nodes: " << nodes.size() << std::endl; + std::vector sorted = topological_sort(); + // loop over each node and print out its inputs/outputs + for (int i = sorted.size()-1; i >= 0; i--) { + std::cout << space << "Node: " << sorted[i]->name << std::endl; + for (int j = 0; j < sorted[i]->inputs.size(); j++) { + std::cout << space << "\tInput: " << sorted[i]->inputs[j].node->name << " " + << sorted[i]->inputs[j].entry << std::endl; + } + for (int j = 0; j < sorted[i]->outputs.size(); j++) { + std::cout << space << "\tOutput: " << sorted[i]->outputs[j].node->name << " " + << sorted[i]->outputs[j].entry << std::endl; + } + if (sorted[i]->subgraphs.size() > 0) { + for (auto &subgraph : sorted[i]->subgraphs) { + std::cout << space << "\tSubgraph:" << std::endl; + subgraph->print(indent+2); + } + } + } + std::cout << space << "###############################" << std::endl; + } + + /* \brief add a new node to this graph */ + Node* addNode(const std::string& name, const std::string& op) { + Node* n = new Node(); + n->name = name; + n->op = op; + if (res) + n->_setPassResource(res); + return n; + } + /* \brief get node at index in graph */ + Node* getNode(size_t idx) { + return nodes[idx]; + } + /* \brief get const node at index in const graph */ + const Node* getNode(size_t idx) const { + return nodes.at(idx); + } + /* \brief get attribute on graph */ + const JsonVal& getAttr(const std::string& key) const { + return attrs.at(key); + } + /* \brief get number of nodes in the graph */ + size_t size() const { + return nodes.size(); + } + // internally set passResource to enable tensor allocation for graph passes + void _setPassResource(PassResource* res_) {res = res_;} + // internally set arg/aux params when available + void _setParams(std::unordered_map* args, + std::unordered_map* aux) { + // set params for each input node + for (Node* node : inputs) { + if (args->count(node->name) > 0) + node->tensor = &args->at(node->name); + else if (aux->count(node->name) > 0) + node->tensor = &aux->at(node->name); + } + + if (res) { + // set passResource for each node + for (Node* node : nodes) { + node->_setPassResource(res); + } + } + } + + std::vector inputs; + std::vector outputs; + std::map attrs; + + private: + std::vector nodes; + PassResource* res; }; /* \brief An abstract class for library authors creating custom @@ -993,11 +1352,8 @@ class CustomOp { }; /*! \brief Custom Pass Create function template */ -typedef MXReturnValue (*graphPass_t)(const std::string& in_graph, const std::string** out_graph, - const std::unordered_map& options, - const std::unordered_map& args, - const std::unordered_map& aux, - const PassResource& res); +typedef MXReturnValue (*graphPass_t)(mxnet::ext::Graph* graph, + const std::unordered_map& options); /*! * \brief An abstract class for graph passes @@ -1019,18 +1375,17 @@ class CustomPass { }; /*! \brief Custom Subgraph Create function template */ -typedef MXReturnValue (*supportedOps_t)(const std::string& json, std::vector* ids, +typedef MXReturnValue (*supportedOps_t)(const mxnet::ext::Graph *graph, std::vector* ids, const std::unordered_map& options); -typedef MXReturnValue (*createSelector_t)(const std::string& json, CustomOpSelector** sel_inst, +typedef MXReturnValue (*createSelector_t)(const mxnet::ext::Graph *graph, + CustomOpSelector** sel_inst, const std::unordered_map& options); -typedef MXReturnValue (*reviewSubgraph_t)(const std::string& json, int subgraph_id, bool* accept, +typedef MXReturnValue (*reviewSubgraph_t)(const mxnet::ext::Graph *subgraph, int subgraph_id, + bool* accept, const std::unordered_map& options, - std::unordered_map* attrs, - const std::unordered_map& args, - const std::unordered_map& aux); + std::string>& options); /*! * \brief An abstract class for subgraph property @@ -1165,6 +1520,44 @@ class Registry { MX_STR_CONCAT(MX_REGISTER_PASS_DEF_(Name), __COUNTER__) = \ Registry::get()->add(MX_TOSTRING(Name)) +/* \brief Class to store error messages from extensions to pass to MXNet */ +class MXerrorMsgs { + public: + /*! + * \brief get singleton pointer to class + * \returns pointer to class + */ + static MXerrorMsgs* get() { + static MXerrorMsgs inst; + return &inst; + } + /*! + * \brief add a new error message + */ + std::stringstream& add(const char* file, int line) { + messages.push_back(std::stringstream()); + messages.back() << file << "[" << line << "]: "; + return messages.back(); + } + int size() { + return messages.size(); + } + const std::string* get(int idx) { + return new std::string(messages.at(idx).str()); + } + + private: + /*! \brief constructor */ + MXerrorMsgs() {} + /*! \brief destructor */ + ~MXerrorMsgs() {} + /*! \brief map of entries in registry */ + std::vector messages; +}; + +// Add a new error message, example: MX_ERROR_MSG << "my error msg"; +#define MX_ERROR_MSG MXerrorMsgs::get()->add(__FILE__, __LINE__) + /* -------------- BELOW ARE CTYPE FUNCTIONS PROTOTYPES --------------- */ /*! @@ -1177,12 +1570,13 @@ typedef int (*opRegSize_t)(void); #define MXLIB_OPREGGET_STR "_opRegGet" typedef int (*opRegGet_t)(int idx, const char** name, int *isSGop, - const char*** forward_ctx, fcomp_t** forward_fp, int* forward_count, - const char*** backward_ctx, fcomp_t** backward_fp, int* backward_count, - const char*** create_op_ctx, createOpState_t** create_op_fp, - int* create_op_count, - parseAttrs_t* parse, inferType_t* type, inferSType_t* stype, - inferShape_t* shape, mutateInputs_t* mutate); + const char*** forward_ctx, mxnet::ext::fcomp_t** forward_fp, + int* forward_count, const char*** backward_ctx, + mxnet::ext::fcomp_t** backward_fp, int* backward_count, + const char*** create_op_ctx, mxnet::ext::createOpState_t** create_op_fp, + int* create_op_count, mxnet::ext::parseAttrs_t* parse, + mxnet::ext::inferType_t* type, mxnet::ext::inferSType_t* stype, + mxnet::ext::inferShape_t* shape, mxnet::ext::mutateInputs_t* mutate); #define MXLIB_OPCALLFREE_STR "_opCallFree" typedef int (*opCallFree_t)(void* ptr); @@ -1343,6 +1737,12 @@ typedef int (*initialize_t)(int version); #define MXLIB_OPVERSION_STR "_opVersion" typedef int (*opVersion_t)(); +#define MXLIB_MSGSIZE_STR "_msgSize" +typedef int (*msgSize_t)(void); + +#define MXLIB_MSGGET_STR "_msgGet" +typedef int (*msgGet_t)(int idx, const char** msg); + #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) #define MX_INT_RET __declspec(dllexport) int __cdecl #define MX_VOID_RET __declspec(dllexport) void __cdecl @@ -1351,6 +1751,9 @@ typedef int (*opVersion_t)(); #define MX_VOID_RET void #endif +} // namespace ext +} // namespace mxnet + extern "C" { /*! \brief returns MXNet library version */ MX_INT_RET _opVersion() { @@ -1359,18 +1762,19 @@ extern "C" { /*! \brief returns number of ops registered in this library */ MX_INT_RET _opRegSize() { - return Registry::get()->size(); + return mxnet::ext::Registry::get()->size(); } /*! \brief returns operator registration at specified index */ MX_VOID_RET _opRegGet(int idx, const char** name, int *isSGop, - const char*** forward_ctx, fcomp_t** forward_fp, + const char*** forward_ctx, mxnet::ext::fcomp_t** forward_fp, int* forward_count, const char*** backward_ctx, - fcomp_t** backward_fp, int* backward_count, - const char*** create_op_ctx, createOpState_t** create_op_fp, - int* create_op_count, parseAttrs_t* parse, inferType_t* type, - inferSType_t* stype, inferShape_t* shape, mutateInputs_t* mutate) { - CustomOp &op = Registry::get()->get(idx); + mxnet::ext::fcomp_t** backward_fp, int* backward_count, + const char*** create_op_ctx, mxnet::ext::createOpState_t** create_op_fp, + int* create_op_count, mxnet::ext::parseAttrs_t* parse, + mxnet::ext::inferType_t* type, mxnet::ext::inferSType_t* stype, + mxnet::ext::inferShape_t* shape, mxnet::ext::mutateInputs_t* mutate) { + mxnet::ext::CustomOp &op = mxnet::ext::Registry::get()->get(idx); *name = op.name; *parse = op.parse_attrs; *type = op.infer_type; @@ -1396,7 +1800,7 @@ extern "C" { } /*! \brief returns status of calling parse attributes function for operator from library */ - MX_INT_RET _opCallParseAttrs(parseAttrs_t parseAttrs, const char* const* keys, + MX_INT_RET _opCallParseAttrs(mxnet::ext::parseAttrs_t parseAttrs, const char* const* keys, const char* const* vals, int num, int* num_in, int* num_out) { // create map of attributes from list @@ -1409,7 +1813,7 @@ extern "C" { } /*! \brief returns status of calling inferShape function for operator from library */ - MX_INT_RET _opCallInferShape(inferShape_t inferShape, const char* const* keys, + MX_INT_RET _opCallInferShape(mxnet::ext::inferShape_t inferShape, const char* const* keys, const char* const* vals, int num, unsigned int** inshapes, int* indims, int num_in, unsigned int*** mod_inshapes, int** mod_indims, @@ -1464,7 +1868,7 @@ extern "C" { } /*! \brief returns status of calling inferType function for operator from library */ - MX_INT_RET _opCallInferType(inferType_t inferType, const char* const* keys, + MX_INT_RET _opCallInferType(mxnet::ext::inferType_t inferType, const char* const* keys, const char* const* vals, int num, int* intypes, int num_in, int* outtypes, int num_out) { // create map of attributes from list @@ -1499,7 +1903,7 @@ extern "C" { } /*! \brief returns status of calling inferSType function for operator from library */ - MX_INT_RET _opCallInferSType(inferSType_t inferSType, const char* const* keys, + MX_INT_RET _opCallInferSType(mxnet::ext::inferSType_t inferSType, const char* const* keys, const char* const* vals, int num, int* instypes, int num_in, int* outstypes, int num_out) { // create map of attributes from list @@ -1535,14 +1939,17 @@ extern "C" { } /*! \brief returns status of calling Forward/Backward function for operator from library */ - MX_INT_RET _opCallFCompute(fcomp_t fcomp, const char* const* keys, const char* const* vals, + MX_INT_RET _opCallFCompute(mxnet::ext::fcomp_t fcomp, const char* const* keys, + const char* const* vals, int num, const int64_t** inshapes, int* indims, void** indata, int* intypes, size_t* inIDs, const char** indev_type, int* indev_id, int num_in, const int64_t** outshapes, int* outdims, void** outdata, int* outtypes, size_t* outIDs, const char** outdev_type, - int* outdev_id, int num_out, xpu_malloc_t cpu_malloc, void* cpu_alloc, - xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream, - sparse_malloc_t sparse_malloc, void* sparse_alloc, + int* outdev_id, int num_out, mxnet::ext::xpu_malloc_t cpu_malloc, + void* cpu_alloc, + mxnet::ext::xpu_malloc_t gpu_malloc, void* gpu_alloc, + void* cuda_stream, + mxnet::ext::sparse_malloc_t sparse_malloc, void* sparse_alloc, int* instypes, int* outstypes, void** in_indices, void** out_indices, void** in_indptr, void** out_indptr, int64_t* in_indices_shapes, int64_t* out_indices_shapes, @@ -1555,66 +1962,70 @@ extern "C" { } // create a vector of tensors for inputs - std::vector inputs(num_in); + std::vector inputs(num_in); // create a vector for sparse inputs - std::vector in_sparse(num_in); + std::vector in_sparse(num_in); for (int i = 0; i < num_in; i++) { // Dense representation. if (instypes[i] == 0) { - inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i], - inIDs[i], MXContext(indev_type[i], indev_id[i]), kDefaultStorage); + inputs[i].setTensor(indata[i], (mxnet::ext::MXDType)intypes[i], inshapes[i], indims[i], + inIDs[i], mxnet::ext::MXContext(indev_type[i], indev_id[i]), + mxnet::ext::kDefaultStorage); } else { // Sparse representation. - MXStorageType type; + mxnet::ext::MXStorageType type; if (instypes[i] == 1) { - type = kRowSparseStorage; + type = mxnet::ext::kRowSparseStorage; in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], in_indices_shapes[i]); } else { - type = kCSRStorage; + type = mxnet::ext::kCSRStorage; in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], in_indices_shapes[i], in_indptr[i], in_indptr_shapes[i]); } - inputs[i].setTensor(reinterpret_cast(&in_sparse[i]), (MXDType)intypes[i], + inputs[i].setTensor(reinterpret_cast(&in_sparse[i]), (mxnet::ext::MXDType)intypes[i], inshapes[i], indims[i], inIDs[i], - MXContext(indev_type[i], indev_id[i]), type); + mxnet::ext::MXContext(indev_type[i], indev_id[i]), type); } } // create a vector of tensors for outputs - std::vector outputs(num_out); - std::vector out_sparse(num_out); + std::vector outputs(num_out); + std::vector out_sparse(num_out); for (int i = 0; i < num_out; i++) { // Dense representation. if (outstypes[i] == 0) { - outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], outdims[i], - outIDs[i], MXContext(outdev_type[i], outdev_id[i]), kDefaultStorage); + outputs[i].setTensor(outdata[i], (mxnet::ext::MXDType)outtypes[i], outshapes[i], outdims[i], + outIDs[i], mxnet::ext::MXContext(outdev_type[i], outdev_id[i]), + mxnet::ext::kDefaultStorage); } else { // Sparse representation. - MXStorageType type; + mxnet::ext::MXStorageType type; if (outstypes[i] == 1) { - type = kRowSparseStorage; + type = mxnet::ext::kRowSparseStorage; out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i], out_indices_shapes[i]); } else { - type = kCSRStorage; + type = mxnet::ext::kCSRStorage; out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i], out_indices_shapes[i], out_indptr[i], out_indptr_shapes[i]); } - outputs[i].setTensor(reinterpret_cast(&out_sparse[i]), (MXDType)outtypes[i], + outputs[i].setTensor(reinterpret_cast(&out_sparse[i]), + (mxnet::ext::MXDType)outtypes[i], outshapes[i], outdims[i], outIDs[i], - MXContext(outdev_type[i], outdev_id[i]), type); + mxnet::ext::MXContext(outdev_type[i], outdev_id[i]), type); } } - OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc, - cuda_stream, sparse_malloc, sparse_alloc, rng_cpu_states, rng_gpu_states); + mxnet::ext::OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc, + cuda_stream, sparse_malloc, sparse_alloc, + rng_cpu_states, rng_gpu_states); return fcomp(attrs, &inputs, &outputs, res); } /*! \brief returns status of calling mutateInputs function for operator from library */ - MX_INT_RET _opCallMutateInputs(mutateInputs_t mutate, const char* const* keys, + MX_INT_RET _opCallMutateInputs(mxnet::ext::mutateInputs_t mutate, const char* const* keys, const char* const* vals, int num, int** mutate_indices, int* indices_size) { // create map of attributes from list @@ -1641,7 +2052,7 @@ extern "C" { } /*! \brief returns status of calling createStatefulOp function for operator from library */ - MX_INT_RET _opCallCreateOpState(createOpState_t create_op, const char* const* keys, + MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const char* const* keys, const char* const* vals, int num, void** state_op) { // create map of attributes from list @@ -1652,7 +2063,8 @@ extern "C" { // void pointer to hold custom state op instance created in custom library // eventually state_op pointer is populated by instance from custom library - CustomStatefulOp** op_ptr = reinterpret_cast(state_op); + mxnet::ext::CustomStatefulOp** op_ptr = + reinterpret_cast(state_op); return create_op(attrs, op_ptr); } @@ -1662,9 +2074,11 @@ extern "C" { const char** indev_type, int* indev_id, int num_in, const int64_t** outshapes, int* outdims, void** outdata, int* outtypes, size_t* outIDs, const char** outdev_type, - int* outdev_id, int num_out, xpu_malloc_t cpu_malloc, - void* cpu_alloc, xpu_malloc_t gpu_malloc, void* gpu_alloc, - void* stream, sparse_malloc_t sparse_malloc, + int* outdev_id, int num_out, + mxnet::ext::xpu_malloc_t cpu_malloc, + void* cpu_alloc, mxnet::ext::xpu_malloc_t gpu_malloc, + void* gpu_alloc, + void* stream, mxnet::ext::sparse_malloc_t sparse_malloc, void* sparse_alloc, int* instypes, int* outstypes, void** in_indices, void** out_indices, void** in_indptr, void** out_indptr, int64_t* in_indices_shapes, @@ -1672,64 +2086,68 @@ extern "C" { int64_t* out_indptr_shapes, void* rng_cpu_states, void* rng_gpu_states) { // create a vector of tensors for inputs - std::vector inputs(num_in); + std::vector inputs(num_in); // create a vector for sparse inputs - std::vector in_sparse(num_in); + std::vector in_sparse(num_in); for (int i = 0; i < num_in; i++) { if (instypes[i] == 0) { // Dense representation. - inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i], - inIDs[i], MXContext(indev_type[i], indev_id[i]), kDefaultStorage); + inputs[i].setTensor(indata[i], (mxnet::ext::MXDType)intypes[i], inshapes[i], indims[i], + inIDs[i], mxnet::ext::MXContext(indev_type[i], indev_id[i]), + mxnet::ext::kDefaultStorage); } else { // Sparse representation. - MXStorageType type; + mxnet::ext::MXStorageType type; if (instypes[i] == 1) { - type = kRowSparseStorage; + type = mxnet::ext::kRowSparseStorage; in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], in_indices_shapes[i]); } else { - type = kCSRStorage; + type = mxnet::ext::kCSRStorage; in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], in_indices_shapes[i], in_indptr[i], in_indptr_shapes[i]); } - inputs[i].setTensor(reinterpret_cast(&in_sparse[i]), (MXDType)intypes[i], + inputs[i].setTensor(reinterpret_cast(&in_sparse[i]), (mxnet::ext::MXDType)intypes[i], inshapes[i], indims[i], inIDs[i], - MXContext(indev_type[i], indev_id[i]), type); + mxnet::ext::MXContext(indev_type[i], indev_id[i]), type); } } // create a vector of tensors for outputs - std::vector outputs(num_out); + std::vector outputs(num_out); // create a vector for sparse outputs - std::vector out_sparse(num_out); + std::vector out_sparse(num_out); for (int i = 0; i < num_out; i++) { if (outstypes[i] == 0) { // Dense representation. - outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], outdims[i], - outIDs[i], MXContext(outdev_type[i], outdev_id[i]), kDefaultStorage); + outputs[i].setTensor(outdata[i], (mxnet::ext::MXDType)outtypes[i], outshapes[i], outdims[i], + outIDs[i], mxnet::ext::MXContext(outdev_type[i], outdev_id[i]), + mxnet::ext::kDefaultStorage); } else { // Sparse representation. - MXStorageType type; + mxnet::ext::MXStorageType type; if (outstypes[i] == 1) { - type = kRowSparseStorage; + type = mxnet::ext::kRowSparseStorage; out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i], out_indices_shapes[i]); } else { - type = kCSRStorage; + type = mxnet::ext::kCSRStorage; out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i], out_indices_shapes[i], out_indptr[i], out_indptr_shapes[i]); } - outputs[i].setTensor(reinterpret_cast(&out_sparse[i]), (MXDType)outtypes[i], + outputs[i].setTensor(reinterpret_cast(&out_sparse[i]), + (mxnet::ext::MXDType)outtypes[i], outshapes[i], outdims[i], outIDs[i], - MXContext(outdev_type[i], outdev_id[i]), type); + mxnet::ext::MXContext(outdev_type[i], outdev_id[i]), type); } } - OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc, - stream, sparse_malloc, sparse_alloc, rng_cpu_states, rng_gpu_states); + mxnet::ext::OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc, + stream, sparse_malloc, sparse_alloc, rng_cpu_states, rng_gpu_states); - CustomStatefulOp* op_ptr = reinterpret_cast(state_op); + mxnet::ext::CustomStatefulOp* op_ptr = + reinterpret_cast(state_op); if (is_forward) { return op_ptr->Forward(&inputs, &outputs, res); } @@ -1738,22 +2156,25 @@ extern "C" { /*! \brief returns number of partitioners registered in this library */ MX_INT_RET _partRegSize() { - return Registry::get()->size(); + return mxnet::ext::Registry::get()->size(); } /* returns number of strategies registered for partitioner * at specified index */ MX_INT_RET _partRegGetCount(int idx, const char** name) { - CustomPartitioner part = Registry::get()->get(idx); + mxnet::ext::CustomPartitioner part = + mxnet::ext::Registry::get()->get(idx); *name = part.name; return part.strategies.size(); } /*! \brief returns partitioner registration at specified index */ MX_VOID_RET _partRegGet(int part_idx, int stg_idx, const char** strategy, - supportedOps_t* supportedOps, createSelector_t* createSelector, - reviewSubgraph_t* reviewSubgraph, const char** op_name) { - CustomPartitioner part = Registry::get()->get(part_idx); + mxnet::ext::supportedOps_t* supportedOps, + mxnet::ext::createSelector_t* createSelector, + mxnet::ext::reviewSubgraph_t* reviewSubgraph, const char** op_name) { + mxnet::ext::CustomPartitioner part = + mxnet::ext::Registry::get()->get(part_idx); *strategy = part.strategies[stg_idx]; *op_name = part.op_names[stg_idx]; *supportedOps = part.getSupportedOps(stg_idx); @@ -1762,10 +2183,10 @@ extern "C" { } /*! \brief returns status of calling supported ops function from library */ - MX_INT_RET _partCallSupportedOps(supportedOps_t supportedOps, const char *json, + MX_INT_RET _partCallSupportedOps(mxnet::ext::supportedOps_t supportedOps, const char *json, int num_ids, int *ids, const char* const* opt_keys, const char* const* opt_vals, int num_opts) { - std::string subgraph_json(json); + mxnet::ext::Graph *graph = mxnet::ext::Graph::fromString(json); // create map of options from list std::unordered_map opts; for (int i = 0; i < num_opts; i++) @@ -1774,7 +2195,7 @@ extern "C" { // create array of subgraph IDs for operator support std::vector _ids(num_ids, -2); // call user's supportedOps function - MXReturnValue retval = supportedOps(subgraph_json, &_ids, opts); + mxnet::ext::MXReturnValue retval = supportedOps(graph, &_ids, opts); if (!retval) return retval; // copy bools in ids to ints @@ -1785,10 +2206,10 @@ extern "C" { } /*! \brief returns status of calling create selector function from library */ - MX_INT_RET _partCallCreateSelector(createSelector_t createSelector, const char *json, + MX_INT_RET _partCallCreateSelector(mxnet::ext::createSelector_t createSelector, const char *json, void** selector, const char* const* opt_keys, const char* const* opt_vals, int num_opts) { - std::string symbol_json(json); + mxnet::ext::Graph *graph = mxnet::ext::Graph::fromString(json); // create map of options from list std::unordered_map opts; for (int i = 0; i < num_opts; i++) @@ -1796,36 +2217,41 @@ extern "C" { // void pointer to hold selector instance created in custom library // eventually pointer is populated by instance from custom library - CustomOpSelector** sel_ptr = reinterpret_cast(selector); + mxnet::ext::CustomOpSelector** sel_ptr = + reinterpret_cast(selector); // call user's createSelector function - return createSelector(symbol_json, sel_ptr, opts); + return createSelector(graph, sel_ptr, opts); } /*! \brief returns status of calling select function from library */ MX_VOID_RET _partCallSelect(void* sel_inst, int nodeID, int* selected) { - CustomOpSelector* sel_ptr = reinterpret_cast(sel_inst); + mxnet::ext::CustomOpSelector* sel_ptr = + reinterpret_cast(sel_inst); *selected = sel_ptr->Select(nodeID); } /*! \brief returns status of calling select input function from library */ MX_VOID_RET _partCallSelectInput(void* sel_inst, int nodeID, int input_nodeID, int* selected) { - CustomOpSelector* sel_ptr = reinterpret_cast(sel_inst); + mxnet::ext::CustomOpSelector* sel_ptr = + reinterpret_cast(sel_inst); *selected = sel_ptr->SelectInput(nodeID, input_nodeID); } /*! \brief returns status of calling select output function from library */ MX_VOID_RET _partCallSelectOutput(void* sel_inst, int nodeID, int output_nodeID, int* selected) { - CustomOpSelector* sel_ptr = reinterpret_cast(sel_inst); + mxnet::ext::CustomOpSelector* sel_ptr = + reinterpret_cast(sel_inst); *selected = sel_ptr->SelectOutput(nodeID, output_nodeID); } /*! \brief returns status of calling filter function from library */ MX_VOID_RET _partCallFilter(void* sel_inst, int* candidates, int num_candidates, int** keep, int* num_keep) { - CustomOpSelector* sel_ptr = reinterpret_cast(sel_inst); + mxnet::ext::CustomOpSelector* sel_ptr = + reinterpret_cast(sel_inst); std::vector candidates_(num_candidates); for (int i=0; i < num_candidates; i++) { candidates_[i] = candidates[i]; @@ -1842,12 +2268,13 @@ extern "C" { /*! \brief returns status of calling reset selector function from library */ MX_VOID_RET _partCallReset(void* sel_inst) { - CustomOpSelector* sel_ptr = reinterpret_cast(sel_inst); + mxnet::ext::CustomOpSelector* sel_ptr = + reinterpret_cast(sel_inst); sel_ptr->Reset(); } /*! \brief returns status of calling review subgraph function from library */ - MX_INT_RET _partCallReviewSubgraph(reviewSubgraph_t reviewSubgraph, const char *json, + MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph, const char *json, int subgraph_id, int *accept, const char* const* opt_keys, const char* const* opt_vals, int num_opts, char*** attr_keys, char*** attr_vals, int *num_attrs, @@ -1861,7 +2288,7 @@ extern "C" { const int* aux_dims, const int* aux_types, const size_t* aux_IDs, const char* const* aux_dev_type, const int* aux_dev_id) { - std::string subgraph_json(json); + mxnet::ext::Graph *subgraph = mxnet::ext::Graph::fromString(json); bool accept_bool = false; // create map of attributes from list std::unordered_map opts; @@ -1869,50 +2296,50 @@ extern "C" { opts[std::string(opt_keys[i])] = std::string(opt_vals[i]); // create a map of named tensors for args - std::unordered_map args; + std::unordered_map args; for (int i = 0; i < num_args; i++) { std::vector shapes; for (int j = 0; j < arg_dims[i]; j++) shapes.push_back(arg_shapes[i][j]); - MXTensor tensor(arg_data[i], shapes, (MXDType)arg_types[i], - arg_IDs[i], MXContext(arg_dev_type[i], arg_dev_id[i])); + mxnet::ext::MXTensor tensor(arg_data[i], shapes, (mxnet::ext::MXDType)arg_types[i], + arg_IDs[i], mxnet::ext::MXContext(arg_dev_type[i], arg_dev_id[i])); args[arg_names[i]] = tensor; } // create a map of named tensors for aux - std::unordered_map aux; + std::unordered_map aux; for (int i = 0; i < num_aux; i++) { std::vector shapes; for (int j = 0; j < aux_dims[i]; j++) shapes.push_back(aux_shapes[i][j]); - MXTensor tensor(aux_data[i], shapes, (MXDType)aux_types[i], - aux_IDs[i], MXContext(aux_dev_type[i], aux_dev_id[i])); + mxnet::ext::MXTensor tensor(aux_data[i], shapes, (mxnet::ext::MXDType)aux_types[i], + aux_IDs[i], mxnet::ext::MXContext(aux_dev_type[i], + aux_dev_id[i])); aux[aux_names[i]] = tensor; } - // attributes to set on subgraph node - std::unordered_map attrs; - - MXReturnValue retval = reviewSubgraph(subgraph_json, subgraph_id, &accept_bool, - opts, &attrs, args, aux); + subgraph->_setParams(&args, &aux); + mxnet::ext::MXReturnValue retval = reviewSubgraph(subgraph, subgraph_id, &accept_bool, + opts); if (!retval) return retval; *accept = accept_bool; - if (attrs.size() > 0) { - *num_attrs = attrs.size(); + if (subgraph->attrs.size() > 0) { + *num_attrs = subgraph->attrs.size(); // allocate space for attributes - *attr_keys = static_cast(malloc (attrs.size() * sizeof(char*))); - *attr_vals = static_cast(malloc (attrs.size() * sizeof(char*))); + *attr_keys = static_cast(malloc (*num_attrs * sizeof(char*))); + *attr_vals = static_cast(malloc (*num_attrs * sizeof(char*))); // copy attributes int i = 0; - for (auto kv : attrs) { + for (auto kv : subgraph->attrs) { (*attr_keys)[i] = static_cast(malloc ((kv.first.size()+1) * sizeof(char))); - (*attr_vals)[i] = static_cast(malloc ((kv.second.size()+1) * sizeof(char))); + std::string val = kv.second.dump(); // convert JsonVal back to string + (*attr_vals)[i] = static_cast(malloc ((val.size()+1) * sizeof(char))); snprintf((*attr_keys)[i], kv.first.size()+1, "%s", kv.first.c_str()); - snprintf((*attr_vals)[i], kv.second.size()+1, "%s", kv.second.c_str()); + snprintf((*attr_vals)[i], val.size()+1, "%s", val.c_str()); i++; } } @@ -1922,20 +2349,21 @@ extern "C" { /*! \brief returns number of graph passes registered in this library */ MX_INT_RET _passRegSize() { - return Registry::get()->size(); + return mxnet::ext::Registry::get()->size(); } /*! \brief returns pass registration at specified index */ - MX_VOID_RET _passRegGet(int pass_idx, graphPass_t* graphPass, + MX_VOID_RET _passRegGet(int pass_idx, mxnet::ext::graphPass_t* graphPass, const char** pass_name) { - CustomPass pass = Registry::get()->get(pass_idx); + mxnet::ext::CustomPass pass = + mxnet::ext::Registry::get()->get(pass_idx); *graphPass = pass.pass; *pass_name = pass.name; } /*! \brief returns status of calling graph pass function from library */ - MX_INT_RET _passCallGraphPass(graphPass_t graphPass, const char *json, - char** graph, const char* const* opt_keys, + MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass, const char *json, + char** out_graph, const char* const* opt_keys, const char* const* opt_vals, int num_opts, const char* pass_name, const char* const* arg_names, int num_args, void* const* arg_data, const int64_t* const* arg_shapes, @@ -1945,51 +2373,48 @@ extern "C" { void* const* aux_data, const int64_t* const* aux_shapes, const int* aux_dims, const int* aux_types, const size_t* aux_IDs, const char* const* aux_dev_type, - const int* aux_dev_id, nd_malloc_t nd_malloc, + const int* aux_dev_id, mxnet::ext::nd_malloc_t nd_malloc, const void* nd_alloc) { - std::string graph_json(json); - const std::string* out_graph = nullptr; + mxnet::ext::Graph *graph = mxnet::ext::Graph::fromString(json); // create map of attributes from list std::unordered_map opts; for (int i = 0; i < num_opts; i++) opts[std::string(opt_keys[i])] = std::string(opt_vals[i]); // create a map of named tensors for args - std::unordered_map args; + std::unordered_map args; for (int i = 0; i < num_args; i++) { std::vector shapes; for (int j = 0; j < arg_dims[i]; j++) shapes.push_back(arg_shapes[i][j]); - MXTensor tensor(arg_data[i], shapes, (MXDType)arg_types[i], - arg_IDs[i], MXContext(arg_dev_type[i], arg_dev_id[i])); + mxnet::ext::MXTensor tensor(arg_data[i], shapes, (mxnet::ext::MXDType)arg_types[i], + arg_IDs[i], mxnet::ext::MXContext(arg_dev_type[i], + arg_dev_id[i])); args[arg_names[i]] = tensor; } // create a map of named tensors for aux - std::unordered_map aux; + std::unordered_map aux; for (int i = 0; i < num_aux; i++) { std::vector shapes; for (int j = 0; j < aux_dims[i]; j++) shapes.push_back(aux_shapes[i][j]); - MXTensor tensor(aux_data[i], shapes, (MXDType)aux_types[i], - aux_IDs[i], MXContext(aux_dev_type[i], aux_dev_id[i])); + mxnet::ext::MXTensor tensor(aux_data[i], shapes, (mxnet::ext::MXDType)aux_types[i], + aux_IDs[i], mxnet::ext::MXContext(aux_dev_type[i], + aux_dev_id[i])); aux[aux_names[i]] = tensor; } - std::unordered_map new_args, new_aux; - PassResource res(&new_args, &new_aux, nd_malloc, nd_alloc); - MXReturnValue retval = graphPass(graph_json, &out_graph, opts, args, aux, res); + std::unordered_map new_args, new_aux; + mxnet::ext::PassResource res(&new_args, &new_aux, nd_malloc, nd_alloc); + graph->_setParams(&args, &aux); + graph->_setPassResource(&res); + mxnet::ext::MXReturnValue retval = graphPass(graph, opts); if (!retval) return retval; - if (out_graph == nullptr) { - std::cout << "Error calling graph pass '" << pass_name - << "' returned out_graph string is null" << std::endl; - return MX_FAIL; - } - *graph = static_cast(malloc((out_graph->length()+1) * sizeof(char))); - out_graph->copy(*graph, out_graph->size()+1); - delete out_graph; + std::string *tmp = new std::string(graph->toString()); + *out_graph = const_cast(tmp->c_str()); return retval; } @@ -2001,10 +2426,19 @@ extern "C" { * \return Non-zero value on error i.e. library incompatible with passed MXNet version */ #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - __declspec(dllexport) MXReturnValue __cdecl + __declspec(dllexport) mxnet::ext::MXReturnValue __cdecl #else - MXReturnValue + mxnet::ext::MXReturnValue #endif initialize(int version); -} + + MX_INT_RET _msgSize() { + return mxnet::ext::MXerrorMsgs::get()->size(); + } + + /*! \brief returns operator registration at specified index */ + MX_VOID_RET _msgGet(int idx, const char** msg) { + *msg = mxnet::ext::MXerrorMsgs::get()->get(idx)->c_str(); + } +} // extern "C" #endif // MXNET_LIB_API_H_ diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 968c78760af9..d7afd8a787b4 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -1035,7 +1035,7 @@ def _call_cached_op(self, *args): out = [out] return _regroup(out, self._out_format) - def optimize_for(self, x, *args, backend=None, backend_opts=None, **kwargs): + def optimize_for(self, x, *args, backend=None, backend_opts=None, clear=True, **kwargs): """Partitions the current HybridBlock and optimizes it for a given backend without executing a forward pass. Modifies the HybridBlock in-place. @@ -1065,6 +1065,7 @@ def optimize_for(self, x, *args, backend=None, backend_opts=None, **kwargs): The name of backend, as registered in `SubgraphBackendRegistry`, default None backend_opts : dict of user-specified options to pass to the backend for partitioning, optional Passed on to `PrePartition` and `PostPartition` functions of `SubgraphProperty` + clear : clears any previous optimizations static_alloc : bool, default False Statically allocate memory to improve speed. Memory usage may increase. static_shape : bool, default False @@ -1074,7 +1075,7 @@ def optimize_for(self, x, *args, backend=None, backend_opts=None, **kwargs): """ # do hybrize API call - self.hybridize(True, backend, backend_opts, **kwargs) + self.hybridize(True, backend, backend_opts, clear, **kwargs) # do part of forward API call has_symbol, has_ndarray, ctx_set, _ = _gather_type_ctx_info([x] + list(args)) @@ -1112,7 +1113,7 @@ def register_child(self, block, name=None): super(HybridBlock, self).register_child(block, name) self._clear_cached_op() - def hybridize(self, active=True, backend=None, backend_opts=None, **kwargs): + def hybridize(self, active=True, backend=None, backend_opts=None, clear=True, **kwargs): """Activates or deactivates :py:class:`HybridBlock` s recursively. Has no effect on non-hybrid children. @@ -1124,6 +1125,7 @@ def hybridize(self, active=True, backend=None, backend_opts=None, **kwargs): The name of backend, as registered in `SubgraphBackendRegistry`, default None backend_opts : dict of user-specified options to pass to the backend for partitioning, optional Passed on to `PrePartition` and `PostPartition` functions of `SubgraphProperty` + clear : clears any previous optimizations static_alloc : bool, default False Statically allocate memory to improve speed. Memory usage may increase. static_shape : bool, default False @@ -1140,7 +1142,8 @@ def hybridize(self, active=True, backend=None, backend_opts=None, **kwargs): self._active = active self._flags = list(kwargs.items()) - self._clear_cached_op() + if clear: + self._clear_cached_op() if active and self._forward_hooks or self._forward_pre_hooks: warnings.warn('"{block}" is being hybridized while still having forward hook/pre-hook. ' 'If "{block}" is a child of HybridBlock, the hooks will not take effect.' diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index fdc794231465..06eea017d6c5 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -97,21 +97,41 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e, // NOTE: return value is added in API_END +std::string getExtensionMsgs(mxnet::ext::msgSize_t msgSize, + mxnet::ext::msgGet_t msgGet) { + std::string str; + if (msgSize() > 0) { + str = "\nExtension Traceback:\n"; + for (int i = 0; i < msgSize(); i++) { + const char* tmp; + msgGet(i, &tmp); + // format: [i] message + str += std::string("\t[") + std::to_string(i) + std::string("] ") + + std::string(tmp) + std::string("\n"); + } + } + return str; +} + /*! * \brief Common compute function dispatcher for forward/backward and stateful forward/backward * state_ptr will be nullptr for regular ops; fcomp_fp is nullptr for stateful ops */ void CustomFComputeDispatcher(const std::string op_name, - const opCallFComp_t callFComp, - const fcomp_t fcomp_fp, + const mxnet::ext::opCallFComp_t callFComp, + const mxnet::ext::fcomp_t fcomp_fp, const nnvm::NodeAttrs* attrs, - const opCallFStatefulComp_t callFStatefulComp, + const mxnet::ext::opCallFStatefulComp_t callFStatefulComp, int stateful_forward_flag, const OpStatePtr* state_ptr, const OpContext& ctx, const std::vector& inputs, const std::vector& req, - const std::vector& outputs) { + const std::vector& outputs, + mxnet::ext::msgSize_t msgSize, + mxnet::ext::msgGet_t msgGet) { + using namespace mxnet::ext; + std::vector in_data, out_data; std::vector in_shapes, out_shapes; std::vector in_dims, out_dims; @@ -280,47 +300,225 @@ void CustomFComputeDispatcher(const std::string op_name, } // call fcompute function - CHECK(callFComp(fcomp_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), - in_shapes.data(), in_dims.data(), in_data.data(), in_types.data(), - in_verIDs.data(), in_dev_type.data(), in_dev_id.data(), in_data.size(), - out_shapes.data(), out_dims.data(), out_data.data(), out_types.data(), - out_verIDs.data(), out_dev_type.data(), out_dev_id.data(), out_data.size(), - cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream, - sparse_malloc, &sparse_alloc, in_stypes.data(), out_stypes.data(), - in_indices.data(), out_indices.data(), in_indptr.data(), out_indptr.data(), - in_indices_shapes.data(), out_indices_shapes.data(), - in_indptr_shapes.data(), out_indptr_shapes.data(), - rng_cpu_states, rng_gpu_states)) - << "Error calling FCompute for custom operator '" << op_name << "'"; + int retval = callFComp(fcomp_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), + in_shapes.data(), in_dims.data(), in_data.data(), in_types.data(), + in_verIDs.data(), in_dev_type.data(), in_dev_id.data(), in_data.size(), + out_shapes.data(), out_dims.data(), out_data.data(), out_types.data(), + out_verIDs.data(), out_dev_type.data(), out_dev_id.data(), + out_data.size(), + cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream, + sparse_malloc, &sparse_alloc, in_stypes.data(), out_stypes.data(), + in_indices.data(), out_indices.data(), in_indptr.data(), + out_indptr.data(), + in_indices_shapes.data(), out_indices_shapes.data(), + in_indptr_shapes.data(), out_indptr_shapes.data(), + rng_cpu_states, rng_gpu_states); + std::string msgs = getExtensionMsgs(msgSize, msgGet); + CHECK(retval) << "Error calling FCompute for custom operator '" << op_name << "'" << msgs; } if (state_ptr != nullptr) { // retrieve op state object created from CreateOpState CustomStatefulOpWrapper& op = state_ptr->get_state(); CustomStatefulOp* state_op_inst = op.get_instance(); + std::string msgs = getExtensionMsgs(msgSize, msgGet); CHECK(state_op_inst != nullptr) - << "Error custom stateful operator is null for operator '" << op_name << "'"; + << "Error custom stateful operator is null for operator '" << op_name << "'" << msgs; // call fcompute function - CHECK(callFStatefulComp(stateful_forward_flag, state_op_inst, - in_shapes.data(), in_dims.data(), in_data.data(), in_types.data(), - in_verIDs.data(), in_dev_type.data(), in_dev_id.data(), - in_data.size(), - out_shapes.data(), out_dims.data(), out_data.data(), out_types.data(), - out_verIDs.data(), out_dev_type.data(), out_dev_id.data(), - out_data.size(), - cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream, - sparse_malloc, &sparse_alloc, in_stypes.data(), out_stypes.data(), - in_indices.data(), out_indices.data(), - in_indptr.data(), out_indptr.data(), - in_indices_shapes.data(), out_indices_shapes.data(), - in_indptr_shapes.data(), out_indptr_shapes.data(), - rng_cpu_states, rng_gpu_states)) - << "Error calling FStatefulCompute for custom operator '" << op_name << "'"; - } -} - -void registerOperators(void *lib, int verbose) { + int retval = callFStatefulComp(stateful_forward_flag, state_op_inst, + in_shapes.data(), in_dims.data(), in_data.data(), + in_types.data(), + in_verIDs.data(), in_dev_type.data(), in_dev_id.data(), + in_data.size(), + out_shapes.data(), out_dims.data(), out_data.data(), + out_types.data(), + out_verIDs.data(), out_dev_type.data(), out_dev_id.data(), + out_data.size(), + cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream, + sparse_malloc, &sparse_alloc, in_stypes.data(), + out_stypes.data(), in_indices.data(), out_indices.data(), + in_indptr.data(), out_indptr.data(), + in_indices_shapes.data(), out_indices_shapes.data(), + in_indptr_shapes.data(), out_indptr_shapes.data(), + rng_cpu_states, rng_gpu_states); + msgs = getExtensionMsgs(msgSize, msgGet); + CHECK(retval) << "Error calling FStatefulCompute for custom operator '" << op_name << "'" + << msgs; + } +} + +template +void registerOp(const char* name, const std::string& name_str, bool isSubgraphOp, + RescReq resc_req, AttrParser attr_parser, NumInputs num_inputs, + NumOutputs num_outputs, NumInOuts num_inouts, InferType infer_type, + InferShape infer_shape, InferSType infer_storage_type, + MutateInputs mutate_inputs, SubgraphNumInputs num_subgraph_inputs, + SubgraphInferType infer_subgraph_type, SubgraphInferShape infer_subgraph_shape, + SubgraphInferSType infer_subgraph_storage_type, CreateOpState create_opstate, + GradReg grad_reg, mxnet::ext::mutateInputs_t mutate_fp, + const std::unordered_map &createop_map, + const std::unordered_map &forward_ctx_map, + const std::unordered_map &backward_ctx_map, + mxnet::ext::opCallFComp_t callFComp, + mxnet::ext::opCallFStatefulComp_t callFStatefulComp, + mxnet::ext::msgSize_t msgSize, + mxnet::ext::msgGet_t msgGet) { + using namespace mxnet::ext; + + // check if operator is already registered + const nnvm::Op *regOpPtr = dmlc::Registry::Get()->Find(name); + nnvm::Op ®Op = dmlc::Registry::Get()->__REGISTER_OR_GET__(name); + int plevel = 10; + if (regOpPtr != nullptr) { + // overwrite registration of existing op with custom op + regOp.arguments.clear(); + // set attribute with higher plevel (11) to allow re-registering once + // TODO(samskalicky): enable constant overwriting of registertion multiple times + plevel++; + } + // define supported resources for both subgraph ops and regular ops + regOp.set_attr("FResourceRequest", resc_req, plevel); + if (!isSubgraphOp) { + regOp.set_attr_parser(attr_parser); + regOp.set_num_inputs(num_inputs); + regOp.set_num_outputs(num_outputs); + regOp.set_attr("FInferType", infer_type, plevel); + regOp.set_attr("FInferStorageType", infer_storage_type, plevel); + regOp.set_attr("FInferShape", infer_shape, plevel); + // optionally add fmutate inputs if user specified a function + if (mutate_fp != nullptr) + regOp.set_attr("FMutateInputs", mutate_inputs, plevel); + } else { + using namespace mxnet::op; + regOp.set_num_inputs(num_subgraph_inputs); + regOp.set_num_outputs(DefaultSubgraphOpNumOutputs); + regOp.set_attr("FInferType", infer_subgraph_type, plevel); + regOp.set_attr("FInferShape", infer_subgraph_shape, plevel); + regOp.set_attr("FInferStorageType", + infer_subgraph_storage_type, plevel); + regOp.set_attr("FMutateInputs", + DefaultSubgraphOpMutableInputs, plevel); + } + // optionally add stateful forward + if (createop_map.size() != 0) { + regOp.set_attr("FCreateOpState", create_opstate, plevel); + auto fstate_forward = [=](const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CustomFComputeDispatcher(name_str, nullptr, nullptr, nullptr, + callFStatefulComp, 1, &state_ptr, ctx, inputs, req, outputs, + msgSize, msgGet); + }; + if (createop_map.count("cpu") > 0) + regOp.set_attr("FStatefulComputeEx", fstate_forward, plevel); + if (createop_map.count("gpu") > 0) + regOp.set_attr("FStatefulComputeEx", fstate_forward, plevel); + } else { + auto forward_lambda = [=](const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (ctx.run_ctx.ctx.dev_mask() == Context::kCPU) { + CHECK_GT(forward_ctx_map.count("cpu"), 0); + fcomp_t fcomp = forward_ctx_map.at("cpu"); + CustomFComputeDispatcher(name_str, callFComp, fcomp, &attrs, + nullptr, 0, nullptr, ctx, inputs, req, outputs, msgSize, msgGet); + } else if (ctx.run_ctx.ctx.dev_mask() == Context::kGPU) { + CHECK_GT(forward_ctx_map.count("gpu"), 0); + fcomp_t fcomp = forward_ctx_map.at("gpu"); + CustomFComputeDispatcher(name_str, callFComp, fcomp, &attrs, + nullptr, 0, nullptr, ctx, inputs, req, outputs, msgSize, msgGet); + } + }; + if (forward_ctx_map.count("cpu") > 0) + regOp.set_attr("FComputeEx", forward_lambda, plevel); + if (forward_ctx_map.count("gpu") > 0) + regOp.set_attr("FComputeEx", forward_lambda, plevel); + } + // optionally add fgradient if user specified a function, or for stateful ops + if (backward_ctx_map.size() != 0 || createop_map.size() != 0) { + std::string grad_name = "_backward_" + name_str; + nnvm::Op &gradOp = dmlc::Registry::Get()->__REGISTER_OR_GET__(grad_name); + regOp.set_attr("FGradient", grad_reg, plevel); + gradOp.set_attr("TIsBackward", true, plevel); + gradOp.set_attr("FInferStorageType", infer_storage_type, plevel); + gradOp.set_attr("FResourceRequest", resc_req, plevel); + + if (!isSubgraphOp) { + // register attr parser and standard functions for non-subgraph ops + gradOp.set_attr_parser(attr_parser); + gradOp.set_num_inputs(num_inouts); + gradOp.set_num_outputs(num_inputs); + } else { + // for subgraph ops use special functions that do not invoke attr_parser + using namespace mxnet::op; + auto grad_inouts = [=](const nnvm::NodeAttrs& attrs) { + // for backward passes, inputs + outputs + input gradients (one for each output) + uint32_t cnt = num_subgraph_inputs(attrs); + cnt += 2 * DefaultSubgraphOpNumOutputs(attrs); + return cnt; + }; + gradOp.set_num_inputs(grad_inouts); + gradOp.set_num_outputs(num_subgraph_inputs); + } + + if (createop_map.size() != 0) { + // for stateful operators + gradOp.set_attr("TIsLayerOpBackward", true, plevel); + auto fstate_backward = [=](const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CustomFComputeDispatcher(name_str, nullptr, nullptr, nullptr, + callFStatefulComp, 0, &state_ptr, ctx, inputs, req, outputs, + msgSize, msgGet); + }; + gradOp.set_attr("FStatefulComputeEx", fstate_backward, plevel); + gradOp.set_attr("FStatefulComputeEx", fstate_backward, plevel); + } else { + // for stateless operators + if (backward_ctx_map.count("cpu") > 0) { + fcomp_t fcomp_back_cpu = backward_ctx_map.at("cpu"); + auto backward_cpu_lambda = [=](const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CustomFComputeDispatcher(name_str, callFComp, fcomp_back_cpu, &attrs, + nullptr, 0, nullptr, ctx, inputs, req, outputs, msgSize, msgGet); + }; + gradOp.set_attr("FComputeEx", backward_cpu_lambda, plevel); + } + if (backward_ctx_map.count("gpu") > 0) { + fcomp_t fcomp_back_gpu = backward_ctx_map.at("gpu"); + auto backward_gpu_lambda = [=](const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CustomFComputeDispatcher(name_str, callFComp, fcomp_back_gpu, &attrs, + nullptr, 0, nullptr, ctx, inputs, req, outputs, msgSize, msgGet); + }; + gradOp.set_attr("FComputeEx", backward_gpu_lambda, plevel); + } + } + } + regOp.add_argument("data", "NDArray[]", "Source inputs"); +} + +void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, + mxnet::ext::msgGet_t msgGet) { + using namespace mxnet::ext; + // get C type interface functions opCallFree_t callFree = get_func(lib, const_cast(MXLIB_OPCALLFREE_STR)); @@ -446,9 +644,10 @@ void registerOperators(void *lib, int verbose) { int num_in = -1; int num_out = -1; - CHECK(callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), - &num_in, &num_out)) - << "Error calling ParseAttrs for custom operator '" << name_str << "'"; + int retval = callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), + &num_in, &num_out); + std::string msgs = getExtensionMsgs(msgSize, msgGet); + CHECK(retval) << "Error calling ParseAttrs for custom operator '" << name_str << "'" << msgs; // return type void }; @@ -464,11 +663,31 @@ void registerOperators(void *lib, int verbose) { int num_in = -1; int num_out = -1; - CHECK(callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), - &num_in, &num_out)) - << "Error calling ParseAttrs::num_inputs for custom operator '" << name_str << "'"; + int retval = callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), + &num_in, &num_out); + std::string msgs = getExtensionMsgs(msgSize, msgGet); + CHECK(retval) << "Error calling ParseAttrs::num_inputs for custom operator '" << name_str + << "'" << msgs; + + // get extra inputs, if exists + int extra_inputs = 0; + if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0) + extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS)); + + return num_in + extra_inputs; + }; + + // lambda function to call parse attributes and return the number of inputs for subgraph ops + auto num_subgraph_inputs = [=](const NodeAttrs& attrs) { + // get number of inputs for subgraph + int num_in = mxnet::op::DefaultSubgraphOpNumInputs(attrs); - return num_in; + // get extra inputs, if exists + int extra_inputs = 0; + if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0) + extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS)); + + return num_in + extra_inputs; }; // lambda function to call parse attributes and return the number of outputs @@ -482,9 +701,11 @@ void registerOperators(void *lib, int verbose) { int num_in = -1; int num_out = -1; - CHECK(callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), - &num_in, &num_out)) - << "Error calling ParseAttrs::num_outputs for custom operator '" << name_str << "'"; + int retval = callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), + &num_in, &num_out); + std::string msgs = getExtensionMsgs(msgSize, msgGet); + CHECK(retval) << "Error calling ParseAttrs::num_outputs for custom operator '" << name_str + << "'" << msgs; return num_out; }; @@ -501,11 +722,19 @@ void registerOperators(void *lib, int verbose) { int num_in = -1; int num_out = -1; - CHECK(callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), - &num_in, &num_out)) - << "Error calling ParseAttrs::num_outputs for custom operator '" << name_str << "'"; + int retval = callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), + &num_in, &num_out); + std::string msgs = getExtensionMsgs(msgSize, msgGet); + CHECK(retval) << "Error calling ParseAttrs::num_outputs for custom operator '" << name_str + << "'" << msgs; // for backward passes, inputs + outputs + input gradients (one for each output) - return num_in + 2 * num_out; + + // get extra inputs, if exists + int extra_inputs = 0; + if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0) + extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS)); + + return num_in + extra_inputs + 2 * num_out; }; // lambda function to call infer shape @@ -519,17 +748,24 @@ void registerOperators(void *lib, int verbose) { attr_vals.push_back(kv.second.c_str()); } - std::vector inshapes(in_shape->size()); - std::vector indims(in_shape->size()); + // get extra inputs, if exists + int extra_inputs = 0; + if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0) + extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS)); + int num_inputs = in_shape->size() - extra_inputs; + + std::vector inshapes(num_inputs); + std::vector indims(num_inputs); // determine amount of memory needed to store all the input shapes size_t buff_size = 0; - for (const auto& i : *in_shape) buff_size += i.ndim(); + for (size_t i = 0; i < num_inputs; ++i) + buff_size += (*in_shape)[i].ndim(); // copy input shapes from ShapeVector to raw memory layout std::vector inbuff(buff_size); uint32_t *ptr = inbuff.data(); - for (size_t i = 0; i < in_shape->size(); ++i) { + for (size_t i = 0; i < num_inputs; ++i) { inshapes[i] = ptr; indims[i] = (*in_shape)[i].ndim(); for (int j = 0; j < (*in_shape)[i].ndim(); ++j, ++ptr) { @@ -544,23 +780,24 @@ void registerOperators(void *lib, int verbose) { uint32_t** outshapes = nullptr; int* outdims = nullptr; - CHECK(callInferShape(shape_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), - inshapes.data(), indims.data(), in_shape->size(), - &mod_inshapes, &mod_indims, - &outshapes, &outdims, out_shape->size())) - << "Error calling InferShape for custom operator '" << name_str << "'"; + int retval = callInferShape(shape_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), + inshapes.data(), indims.data(), num_inputs, + &mod_inshapes, &mod_indims, + &outshapes, &outdims, out_shape->size()); + std::string msgs = getExtensionMsgs(msgSize, msgGet); + CHECK(retval) << "Error calling InferShape for custom operator '" << name_str << "'" << msgs; - std::vector in_shapes(in_shape->size()); + std::vector in_shapes(num_inputs); // determine amount of memory needed to store all the modified input shapes buff_size = 0; - for (unsigned i = 0; i < in_shape->size(); i++) { + for (unsigned i = 0; i < num_inputs; i++) { buff_size += mod_indims[i]; } // copy modified input shapes from custom op memory to MXNet memory std::vector mod_inbuff(buff_size); ptr = mod_inbuff.data(); - for (unsigned i = 0; i < in_shape->size(); ++i) { + for (unsigned i = 0; i < num_inputs; ++i) { in_shapes[i] = ptr; for (int j = 0; j < mod_indims[i]; ++j, ++ptr) { *ptr = static_cast(mod_inshapes[i][j]); @@ -568,7 +805,7 @@ void registerOperators(void *lib, int verbose) { } // assign modified input shapes to ShapeVector - for (unsigned i = 0; i < in_shape->size(); ++i) { + for (unsigned i = 0; i < num_inputs; ++i) { SHAPE_ASSIGN_CHECK(*in_shape, i, mxnet::TShape(in_shapes[i], in_shapes[i]+mod_indims[i])); } @@ -598,7 +835,7 @@ void registerOperators(void *lib, int verbose) { // free memory used by custom op to allocate shapes/dims callFree(mod_indims); - for (unsigned i = 0; i < in_shape->size(); i++) { + for (unsigned i = 0; i < num_inputs; i++) { callFree(mod_inshapes[i]); } callFree(mod_inshapes); @@ -612,6 +849,28 @@ void registerOperators(void *lib, int verbose) { return true; }; + // lambda function to call infer shape for subgraph ops + auto infer_subgraph_shape = [=] (const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_shape, + mxnet::ShapeVector *out_shape) { + // convert attributes to vector of char* + std::vector attr_keys, attr_vals; + for (auto &kv : attrs.dict) { + attr_keys.push_back(kv.first.c_str()); + attr_vals.push_back(kv.second.c_str()); + } + + // get extra inputs, if exists + int extra_inputs = 0; + if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0) + extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS)); + + auto in_first = in_shape->begin(); + auto in_last = in_first + in_shape->size() - extra_inputs; + mxnet::ShapeVector *sg_in_shapes = new mxnet::ShapeVector(in_first, in_last); + return mxnet::op::DefaultSubgraphOpShape(attrs, sg_in_shapes, out_shape); + }; + // lambda function to call infer type auto infer_type = [=] (const nnvm::NodeAttrs& attrs, std::vector *in_type, @@ -623,19 +882,26 @@ void registerOperators(void *lib, int verbose) { attr_vals.push_back(kv.second.c_str()); } + // get extra inputs, if exists + int extra_inputs = 0; + if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0) + extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS)); + int num_inputs = in_type->size() - extra_inputs; + // copy input types from in_type std::vector intypes(*in_type); // output types will be populated by inferType function std::vector outtypes(out_type->size()); - CHECK(callInferType(type_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), - intypes.data(), in_type->size(), - outtypes.data(), out_type->size())) - << "Error calling InferType for custom operator '" << name_str << "'"; + int retval = callInferType(type_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), + intypes.data(), num_inputs, + outtypes.data(), out_type->size()); + std::string msgs = getExtensionMsgs(msgSize, msgGet); + CHECK(retval) << "Error calling InferType for custom operator '" << name_str << "'" << msgs; // copy and assign modified input types from custom op to MXNet memory - for (size_t i = 0; i < in_type->size(); i++) { + for (size_t i = 0; i < num_inputs; i++) { TYPE_ASSIGN_CHECK(*in_type, i, intypes[i]); } // copy and assign output types from custom op to MXNet memory @@ -646,6 +912,29 @@ void registerOperators(void *lib, int verbose) { return true; }; + // lambda function to call infer type for subgraph ops + auto infer_subgraph_type = [=] (const nnvm::NodeAttrs& attrs, + std::vector *in_type, + std::vector *out_type) { + // convert attributes to vector of char* + std::vector attr_keys, attr_vals; + for (auto &kv : attrs.dict) { + attr_keys.push_back(kv.first.c_str()); + attr_vals.push_back(kv.second.c_str()); + } + + // get extra inputs, if exists + int extra_inputs = 0; + if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0) + extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS)); + + auto in_first = in_type->begin(); + auto in_last = in_first + in_type->size() - extra_inputs; + std::vector *sg_in_types = new std::vector(in_first, in_last); + + return mxnet::op::DefaultSubgraphOpType(attrs, sg_in_types, out_type); + }; + // lambda function to convert from external mutate_inputs to internal MXNet types auto mutate_inputs = [=](const nnvm::NodeAttrs& attrs) { // convert attributes to vector of char* @@ -660,9 +949,11 @@ void registerOperators(void *lib, int verbose) { int indices_size = 0; // call mutate inputs function - CHECK(callMutateInputs(mutate_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), - &mutate_indices, &indices_size)) - << "Error calling MutateInputs for custom operator '" << name_str << "'"; + int retval = callMutateInputs(mutate_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), + &mutate_indices, &indices_size); + std::string msgs = getExtensionMsgs(msgSize, msgGet); + CHECK(retval) << "Error calling MutateInputs for custom operator '" << name_str << "'" + << msgs; std::vector mutate_indices_list(indices_size); for (int i=0; i < indices_size; i++) { @@ -679,7 +970,7 @@ void registerOperators(void *lib, int verbose) { std::vector* in_stypes, std::vector* out_stypes) { if (stype_fp == nullptr) { - // InferSType is not defineid in customized lib. + // InferSType is not defined in customized lib. CHECK(mxnet::common::ContainsOnlyStorage(*in_stypes, mxnet::kDefaultStorage)) << "Error input tensors are not dense for custom operator '" << name_str << "'"; // set outputs as dense @@ -693,18 +984,27 @@ void registerOperators(void *lib, int verbose) { attr_keys.push_back(kv.first.c_str()); attr_vals.push_back(kv.second.c_str()); } + + // get extra inputs, if exists + int extra_inputs = 0; + if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0) + extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS)); + int num_inputs = in_stypes->size() - extra_inputs; + // copy input types from in_stype std::vector instypes(*in_stypes); // output types will be populated by inferType function std::vector outstypes(out_stypes->size()); - CHECK(callInferSType(stype_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), - instypes.data(), in_stypes->size(), - outstypes.data(), out_stypes->size())) - << "Error calling InferSType for custom operator '" << name_str << "'"; + int retval = callInferSType(stype_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), + instypes.data(), num_inputs, + outstypes.data(), out_stypes->size()); + std::string msgs = getExtensionMsgs(msgSize, msgGet); + CHECK(retval) << "Error calling InferSType for custom operator '" << name_str << "'" + << msgs; // copy and assign modified input storage types from custom op to MXNet memory. - for (size_t i = 0; i < in_stypes->size(); i++) { + for (size_t i = 0; i < num_inputs; i++) { STORAGE_TYPE_ASSIGN_CHECK(*in_stypes, i, instypes[i]); } // copy and assign output storage types from custom op to MXNet memory. @@ -717,6 +1017,25 @@ void registerOperators(void *lib, int verbose) { } }; + // lambda function to set storage types for subgraph ops + auto infer_subgraph_storage_type = [=](const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_stypes, + std::vector* out_stypes) { + // get extra inputs, if exists + int extra_inputs = 0; + if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0) + extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS)); + + auto in_first = in_stypes->begin(); + auto in_last = in_first + in_stypes->size() - extra_inputs; + std::vector *sg_in_stypes = new std::vector(in_first, in_last); + + return mxnet::op::DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, + sg_in_stypes, out_stypes); + }; + // FGradient register lambda auto grad_reg = [=](const nnvm::ObjectPtr& n, const std::vector& ograds) { // create node for gradient @@ -789,19 +1108,24 @@ void registerOperators(void *lib, int verbose) { if (ctx.dev_mask() == Context::kCPU) { CHECK(createop_map.count("cpu") > 0) << "CPU CreateOpState not implemented for '" << name_str << "'"; - CHECK(callCreateOpState(createop_map.at("cpu"), attr_keys.data(), attr_vals.data(), - attr_keys.size(), &state_op_inst)) - << "Error calling CreateOpState CPU for custom operator '" << name_str << "'"; + int retval = callCreateOpState(createop_map.at("cpu"), attr_keys.data(), attr_vals.data(), + attr_keys.size(), &state_op_inst); + std::string msgs = getExtensionMsgs(msgSize, msgGet); + CHECK(retval) << "Error calling CreateOpState CPU for custom operator '" << name_str << "'" + << msgs; } else if (ctx.dev_mask() == Context::kGPU) { CHECK(createop_map.count("gpu") > 0) << "GPU CreateOpState not implemented for '" << name_str << "'"; - CHECK(callCreateOpState(createop_map.at("gpu"), attr_keys.data(), attr_vals.data(), - attr_keys.size(), &state_op_inst)) - << "Error calling CreateOpState GPU for custom operator '" << name_str << "'"; + int retval = callCreateOpState(createop_map.at("gpu"), attr_keys.data(), attr_vals.data(), + attr_keys.size(), &state_op_inst); + std::string msgs = getExtensionMsgs(msgSize, msgGet); + CHECK(retval) << "Error calling CreateOpState GPU for custom operator '" << name_str << "'" + << msgs; } + std::string msgs = getExtensionMsgs(msgSize, msgGet); CHECK(state_op_inst != nullptr) - << "Error custom library failed to create stateful operator '" << name_str << "'"; + << "Error custom library failed to create stateful operator '" << name_str << "'" << msgs; CustomStatefulOp* state_op = reinterpret_cast(state_op_inst); return OpStatePtr::Create(state_op); @@ -809,151 +1133,19 @@ void registerOperators(void *lib, int verbose) { /* -------------- BELOW IS THE REGISTRATION FOR CUSTOM OPERATORS --------------- */ - // check if operator is already registered - const nnvm::Op *regOpPtr = dmlc::Registry::Get()->Find(name); - nnvm::Op ®Op = dmlc::Registry::Get()->__REGISTER_OR_GET__(name); - int plevel = 10; - if (regOpPtr != nullptr) { - // overwrite registration of existing op with custom op - regOp.arguments.clear(); - // set attribute with higher plevel (11) to allow re-registering once - // TODO(samskalicky): enable constant overwriting of registertion multiple times - plevel++; - } - // define supported resources for both subgraph ops and regular ops - regOp.set_attr("FResourceRequest", resc_req, plevel); - if (!isSubgraphOp) { - regOp.set_attr_parser(attr_parser); - regOp.set_num_inputs(num_inputs); - regOp.set_num_outputs(num_outputs); - regOp.set_attr("FInferType", infer_type, plevel); - regOp.set_attr("FInferStorageType", infer_storage_type, plevel); - regOp.set_attr("FInferShape", infer_shape, plevel); - // optionally add fmutate inputs if user specified a function - if (mutate_fp != nullptr) - regOp.set_attr("FMutateInputs", mutate_inputs, plevel); - } else { - using namespace mxnet::op; - regOp.set_num_inputs(DefaultSubgraphOpNumInputs); - regOp.set_num_outputs(DefaultSubgraphOpNumOutputs); - regOp.set_attr("FInferType", DefaultSubgraphOpType, plevel); - regOp.set_attr("FInferShape", DefaultSubgraphOpShape, plevel); - regOp.set_attr("FInferStorageType", - DefaultSubgraphOpStorageType, plevel); - regOp.set_attr("FMutateInputs", - DefaultSubgraphOpMutableInputs, plevel); - } - // optionally add stateful forward - if (createop_map.size() != 0) { - regOp.set_attr("FCreateOpState", create_opstate, plevel); - auto fstate_forward = [=](const OpStatePtr& state_ptr, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - CustomFComputeDispatcher(name_str, nullptr, nullptr, nullptr, - callFStatefulComp, 1, &state_ptr, ctx, inputs, req, outputs); - }; - if (createop_map.count("cpu") > 0) - regOp.set_attr("FStatefulComputeEx", fstate_forward, plevel); - if (createop_map.count("gpu") > 0) - regOp.set_attr("FStatefulComputeEx", fstate_forward, plevel); - } else { - auto forward_lambda = [=](const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - if (ctx.run_ctx.ctx.dev_mask() == Context::kCPU) { - CHECK_GT(forward_ctx_map.count("cpu"), 0); - fcomp_t fcomp = forward_ctx_map.at("cpu"); - CustomFComputeDispatcher(name_str, callFComp, fcomp, &attrs, - nullptr, 0, nullptr, ctx, inputs, req, outputs); - } else if (ctx.run_ctx.ctx.dev_mask() == Context::kGPU) { - CHECK_GT(forward_ctx_map.count("gpu"), 0); - fcomp_t fcomp = forward_ctx_map.at("gpu"); - CustomFComputeDispatcher(name_str, callFComp, fcomp, &attrs, - nullptr, 0, nullptr, ctx, inputs, req, outputs); - } - }; - if (forward_ctx_map.count("cpu") > 0) - regOp.set_attr("FComputeEx", forward_lambda, plevel); - if (forward_ctx_map.count("gpu") > 0) - regOp.set_attr("FComputeEx", forward_lambda, plevel); - } - // optionally add fgradient if user specified a function, or for stateful ops - if (backward_ctx_map.size() != 0 || createop_map.size() != 0) { - std::string grad_name = "_backward_" + name_str; - nnvm::Op &gradOp = dmlc::Registry::Get()->__REGISTER_OR_GET__(grad_name); - regOp.set_attr("FGradient", grad_reg, plevel); - gradOp.set_attr("TIsBackward", true, plevel); - gradOp.set_attr("FInferStorageType", infer_storage_type, plevel); - gradOp.set_attr("FResourceRequest", resc_req, plevel); - - if (!isSubgraphOp) { - // register attr parser and standard functions for non-subgraph ops - gradOp.set_attr_parser(attr_parser); - gradOp.set_num_inputs(num_inouts); - gradOp.set_num_outputs(num_inputs); - } else { - // for subgraph ops use special functions that do not invoke attr_parser - using namespace mxnet::op; - auto grad_inouts = [=](const nnvm::NodeAttrs& attrs) { - // for backward passes, inputs + outputs + input gradients (one for each output) - uint32_t cnt = DefaultSubgraphOpNumInputs(attrs); - cnt += 2 * DefaultSubgraphOpNumOutputs(attrs); - return cnt; - }; - gradOp.set_num_inputs(grad_inouts); - gradOp.set_num_outputs(DefaultSubgraphOpNumInputs); - } - - if (createop_map.size() != 0) { - // for stateful operators - gradOp.set_attr("TIsLayerOpBackward", true, plevel); - auto fstate_backward = [=](const OpStatePtr& state_ptr, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - CustomFComputeDispatcher(name_str, nullptr, nullptr, nullptr, - callFStatefulComp, 0, &state_ptr, ctx, inputs, req, outputs); - }; - gradOp.set_attr("FStatefulComputeEx", fstate_backward, plevel); - gradOp.set_attr("FStatefulComputeEx", fstate_backward, plevel); - } else { - // for stateless operators - if (backward_ctx_map.count("cpu") > 0) { - fcomp_t fcomp_back_cpu = backward_ctx_map.at("cpu"); - auto backward_cpu_lambda = [=](const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - CustomFComputeDispatcher(name_str, callFComp, fcomp_back_cpu, &attrs, - nullptr, 0, nullptr, ctx, inputs, req, outputs); - }; - gradOp.set_attr("FComputeEx", backward_cpu_lambda, plevel); - } - if (backward_ctx_map.count("gpu") > 0) { - fcomp_t fcomp_back_gpu = backward_ctx_map.at("gpu"); - auto backward_gpu_lambda = [=](const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - CustomFComputeDispatcher(name_str, callFComp, fcomp_back_gpu, &attrs, - nullptr, 0, nullptr, ctx, inputs, req, outputs); - }; - gradOp.set_attr("FComputeEx", backward_gpu_lambda, plevel); - } - } - } - regOp.add_argument("data", "NDArray[]", "Source inputs"); + registerOp(name, name_str, isSubgraphOp, resc_req, attr_parser, num_inputs, num_outputs, + num_inouts, infer_type, infer_shape, infer_storage_type, mutate_inputs, + num_subgraph_inputs, infer_subgraph_type, infer_subgraph_shape, + infer_subgraph_storage_type, create_opstate, grad_reg, mutate_fp, + createop_map, forward_ctx_map, backward_ctx_map, callFComp, callFStatefulComp, + msgSize, msgGet); } } -void registerPartitioners(void *lib, int verbose) { +void registerPartitioners(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, + mxnet::ext::msgGet_t msgGet) { + using namespace mxnet::ext; + // get C type interface functions opCallFree_t callFree = get_func(lib, const_cast(MXLIB_OPCALLFREE_STR)); @@ -1035,7 +1227,10 @@ void registerPartitioners(void *lib, int verbose) { } } -void registerPasses(void *lib, int verbose) { +void registerPasses(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, + mxnet::ext::msgGet_t msgGet) { + using namespace mxnet::ext; + // get C type interface functions opCallFree_t callFree = get_func(lib, const_cast(MXLIB_OPCALLFREE_STR)); @@ -1230,17 +1425,18 @@ void registerPasses(void *lib, int verbose) { }; char* out_json; - CHECK(callGraphPass(pass_fp, in_json.c_str(), &out_json, opt_keys.data(), - opt_vals.data(), opt_keys.size(), pass_name, - arg_names.data(), arg_names.size(), arg_data.data(), - arg_shapes.data(), arg_dims.data(), arg_types.data(), - arg_verIDs.data(), arg_dev_type.data(), - arg_dev_id.data(), aux_names.data(), aux_names.size(), - aux_data.data(), aux_shapes.data(), aux_dims.data(), - aux_types.data(), aux_verIDs.data(), - aux_dev_type.data(), aux_dev_id.data(), - ndarray_malloc, &ndarray_alloc)) - << "Error calling graph pass for '" << pass_name << "'"; + int retval = callGraphPass(pass_fp, in_json.c_str(), &out_json, opt_keys.data(), + opt_vals.data(), opt_keys.size(), pass_name, + arg_names.data(), arg_names.size(), arg_data.data(), + arg_shapes.data(), arg_dims.data(), arg_types.data(), + arg_verIDs.data(), arg_dev_type.data(), + arg_dev_id.data(), aux_names.data(), aux_names.size(), + aux_data.data(), aux_shapes.data(), aux_dims.data(), + aux_types.data(), aux_verIDs.data(), + aux_dev_type.data(), aux_dev_id.data(), + ndarray_malloc, &ndarray_alloc); + std::string msgs = getExtensionMsgs(msgSize, msgGet); + CHECK(retval) << "Error calling graph pass for '" << pass_name << "'" << msgs; std::string out_string(out_json); nnvm::Graph out_graph = nnvm::pass::LoadJSON(out_string); @@ -1271,21 +1467,31 @@ int MXLoadLib(const char *path, unsigned verbose) { LOG(FATAL) << "Unable to load library"; // check that library and MXNet use same version of library API - opVersion_t opVersion = get_func(lib, const_cast(MXLIB_OPVERSION_STR)); + mxnet::ext::opVersion_t opVersion = + get_func(lib, const_cast(MXLIB_OPVERSION_STR)); int libVersion = opVersion(); if (MX_LIBRARY_VERSION != libVersion) LOG(FATAL) << "Library version (" << libVersion << ") does not match MXNet version (" << MX_LIBRARY_VERSION << ")"; + // get error messaging APIs + mxnet::ext::msgSize_t msgSize = + get_func(lib, const_cast(MXLIB_MSGSIZE_STR)); + mxnet::ext::msgGet_t msgGet = + get_func(lib, const_cast(MXLIB_MSGGET_STR)); + // initialize library by passing MXNet version - initialize_t initialize = get_func(lib, const_cast(MXLIB_INITIALIZE_STR)); - if (!initialize(static_cast(MXNET_VERSION))) - LOG(FATAL) << "Library failed to initialize"; + mxnet::ext::initialize_t initialize = + get_func(lib, const_cast(MXLIB_INITIALIZE_STR)); + if (!initialize(static_cast(MXNET_VERSION))) { + std::string msgs = getExtensionMsgs(msgSize, msgGet); + LOG(FATAL) << "Library failed to initialize" << msgs; + } // find ops, partitioners, and passes in library - registerOperators(lib, verbose); - registerPartitioners(lib, verbose); - registerPasses(lib, verbose); + registerOperators(lib, verbose, msgSize, msgGet); + registerPartitioners(lib, verbose, msgSize, msgGet); + registerPasses(lib, verbose, msgSize, msgGet); API_END(); } diff --git a/src/operator/subgraph/partitioner/custom_subgraph_property.h b/src/operator/subgraph/partitioner/custom_subgraph_property.h index ea721c5aa71a..b936b050cdba 100644 --- a/src/operator/subgraph/partitioner/custom_subgraph_property.h +++ b/src/operator/subgraph/partitioner/custom_subgraph_property.h @@ -49,12 +49,12 @@ namespace op { class CustomContainOpSelector: public SubgraphSelector { public: explicit CustomContainOpSelector(std::unordered_map supported_nodes, - void* sel_inst, partCallSelect_t callSelect, - partCallSelectInput_t callSelectInput, - partCallSelectOutput_t callSelectOutput, - partCallFilter_t callFilter, - partCallReset_t callReset, - opCallFree_t callFree, + void* sel_inst, mxnet::ext::partCallSelect_t callSelect, + mxnet::ext::partCallSelectInput_t callSelectInput, + mxnet::ext::partCallSelectOutput_t callSelectOutput, + mxnet::ext::partCallFilter_t callFilter, + mxnet::ext::partCallReset_t callReset, + mxnet::ext::opCallFree_t callFree, std::unordered_map node2id) : supported_nodes_(supported_nodes), sel_inst_(sel_inst), callSelect_(callSelect), callSelectInput_(callSelectInput), callSelectOutput_(callSelectOutput), @@ -123,12 +123,12 @@ class CustomContainOpSelector: public SubgraphSelector { std::unordered_map supported_nodes_; void* sel_inst_; - partCallSelect_t callSelect_; - partCallSelectInput_t callSelectInput_; - partCallSelectOutput_t callSelectOutput_; - partCallFilter_t callFilter_; - partCallReset_t callReset_; - opCallFree_t callFree_; + mxnet::ext::partCallSelect_t callSelect_; + mxnet::ext::partCallSelectInput_t callSelectInput_; + mxnet::ext::partCallSelectOutput_t callSelectOutput_; + mxnet::ext::partCallFilter_t callFilter_; + mxnet::ext::partCallReset_t callReset_; + mxnet::ext::opCallFree_t callFree_; std::unordered_map node2id_; }; @@ -155,18 +155,18 @@ class CustomSubgraphProperty: public SubgraphProperty { review_subgraph_(nullptr), subgraph_op_name("error") {} CustomSubgraphProperty(std::string subgraph_prop_name, - partCallSupportedOps_t call_supported_ops, - supportedOps_t supported_ops, - partCallCreateSelector_t call_create_selector, - createSelector_t create_selector, - partCallSelect_t callSelect, - partCallSelectInput_t callSelectInput, - partCallSelectOutput_t callSelectOutput, - partCallFilter_t callFilter, - partCallReset_t callReset, - partCallReviewSubgraph_t call_review_subgraph, - reviewSubgraph_t review_subgraph, - opCallFree_t call_free, + mxnet::ext::partCallSupportedOps_t call_supported_ops, + mxnet::ext::supportedOps_t supported_ops, + mxnet::ext::partCallCreateSelector_t call_create_selector, + mxnet::ext::createSelector_t create_selector, + mxnet::ext::partCallSelect_t callSelect, + mxnet::ext::partCallSelectInput_t callSelectInput, + mxnet::ext::partCallSelectOutput_t callSelectOutput, + mxnet::ext::partCallFilter_t callFilter, + mxnet::ext::partCallReset_t callReset, + mxnet::ext::partCallReviewSubgraph_t call_review_subgraph, + mxnet::ext::reviewSubgraph_t review_subgraph, + mxnet::ext::opCallFree_t call_free, std::string op_name) : subgraph_prop(subgraph_prop_name), call_supported_ops_(call_supported_ops), @@ -429,7 +429,7 @@ class CustomSubgraphProperty: public SubgraphProperty { if (e.node->attrs.dict.count(MX_STR_SHAPE) > 0) { std::string& shape = e.node->attrs.dict[MX_STR_SHAPE]; // add this shape to the list - ss << getShapeAt(shape, e.index); + ss << mxnet::ext::getShapeAt(shape, e.index); } if (i < sym.outputs.size()-1) ss << ","; @@ -446,7 +446,7 @@ class CustomSubgraphProperty: public SubgraphProperty { if (e.node->attrs.dict.count(MX_STR_DTYPE) > 0) { std::string& dtype = e.node->attrs.dict[MX_STR_DTYPE]; // add this dtype to the list - ss << getDtypeAt(dtype, e.index); + ss << mxnet::ext::getDtypeAt(dtype, e.index); } if (i < sym.outputs.size()-1) ss << ","; @@ -489,7 +489,7 @@ class CustomSubgraphProperty: public SubgraphProperty { // get dtype string from other node std::string& dtype = orig.node->attrs.dict[MX_STR_DTYPE]; std::stringstream ss; - ss << "[" << getDtypeAt(dtype, orig.index) << "]"; + ss << "[" << mxnet::ext::getDtypeAt(dtype, orig.index) << "]"; e->node->attrs.dict[MX_STR_DTYPE] = ss.str(); } @@ -498,7 +498,7 @@ class CustomSubgraphProperty: public SubgraphProperty { std::string& shape = orig.node->attrs.dict[MX_STR_SHAPE]; // create new shape string for this node std::stringstream ss; - ss << "[" << getShapeAt(shape, orig.index) << "]"; + ss << "[" << mxnet::ext::getShapeAt(shape, orig.index) << "]"; e->node->attrs.dict[MX_STR_SHAPE] = ss.str(); } } @@ -512,18 +512,18 @@ class CustomSubgraphProperty: public SubgraphProperty { } std::string subgraph_prop; - partCallSupportedOps_t call_supported_ops_; - supportedOps_t supported_ops_; - partCallCreateSelector_t call_create_selector_; - createSelector_t create_selector_; - partCallSelect_t callSelect_; - partCallSelectInput_t callSelectInput_; - partCallSelectOutput_t callSelectOutput_; - partCallFilter_t callFilter_; - partCallReset_t callReset_; - partCallReviewSubgraph_t call_review_subgraph_; - reviewSubgraph_t review_subgraph_; - opCallFree_t call_free_; + mxnet::ext::partCallSupportedOps_t call_supported_ops_; + mxnet::ext::supportedOps_t supported_ops_; + mxnet::ext::partCallCreateSelector_t call_create_selector_; + mxnet::ext::createSelector_t create_selector_; + mxnet::ext::partCallSelect_t callSelect_; + mxnet::ext::partCallSelectInput_t callSelectInput_; + mxnet::ext::partCallSelectOutput_t callSelectOutput_; + mxnet::ext::partCallFilter_t callFilter_; + mxnet::ext::partCallReset_t callReset_; + mxnet::ext::partCallReviewSubgraph_t call_review_subgraph_; + mxnet::ext::reviewSubgraph_t review_subgraph_; + mxnet::ext::opCallFree_t call_free_; std::unordered_map supported_nodes; std::string subgraph_op_name; std::vector> options_map_;