From 0f0f667b396081764c8689d924ecb12607701214 Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Tue, 21 Apr 2020 09:50:52 -0700 Subject: [PATCH] [WIP] MXNet Extensions enhancements (#17885) * add debug prints to debug error in CI * add debug prints to debug error in CI * remove prints * initial commit * enabled calling create for selector * connected selector to call external class * added code to remove temp graph attrs * fixed build issues * changed shape inference to use different attr names * fixed selector class * cleaned up APIs * fixed sanity * updated build for extensions * sanity fix * refactored MXLoadLib into separate functions * undo rebase * finished merge * enabled verbose in library loading * fixed example * added passing args/aux down to graph pass * added creating new args/aux for graph passes * fixed return args/aux * fixed sanity * whitespace * fixed lint * updated perl API, README, added pass_lib to cmake build flow * fixed mistake with relu example lib * fixed perl syntax * addressed comments * addressed more comments * fixed compile issues Co-authored-by: Ubuntu Co-authored-by: Ubuntu --- CMakeLists.txt | 18 +- Makefile | 17 +- example/extensions/lib_api/init_lib.cc | 2 +- example/extensions/lib_api/test_loading.py | 10 + example/extensions/lib_custom_op/README.md | 57 +- example/extensions/lib_custom_op/gemm_lib.cc | 118 +-- example/extensions/lib_custom_op/relu_lib.cu | 278 +++--- .../lib_custom_op/transposecsr_lib.cc | 113 +-- .../lib_custom_op/transposerowsp_lib.cc | 113 +-- example/extensions/lib_pass/Makefile | 24 + example/extensions/lib_pass/README.md | 190 ++++ example/extensions/lib_pass/pass_lib.cc | 104 +++ example/extensions/lib_pass/test_pass.py | 98 ++ example/extensions/lib_subgraph/README.md | 4 +- .../extensions/lib_subgraph/subgraph_lib.cc | 133 ++- .../extensions/lib_subgraph/test_subgraph.py | 191 ++-- include/mxnet/c_api.h | 11 +- include/mxnet/lib_api.h | 855 ++++++++++++------ perl-package/AI-MXNetCAPI/mxnet.i | 26 +- python/mxnet/library.py | 13 +- python/mxnet/symbol/symbol.py | 40 +- src/c_api/c_api.cc | 386 +++++++- src/c_api/c_api_executor.cc | 1 - src/c_api/c_api_symbolic.cc | 76 +- src/executor/infer_graph_attr_pass.cc | 1 + .../partitioner/custom_subgraph_property.h | 196 +++- 26 files changed, 2246 insertions(+), 829 deletions(-) create mode 100644 example/extensions/lib_pass/Makefile create mode 100644 example/extensions/lib_pass/README.md create mode 100644 example/extensions/lib_pass/pass_lib.cc create mode 100644 example/extensions/lib_pass/test_pass.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 103faf086483..437d01668246 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -733,18 +733,34 @@ endif() # extension libraries (custom operators, custom subgraphs) are built by default add_library(customop_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/gemm_lib.cc) +add_library(transposecsr_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/transposecsr_lib.cc) +add_library(transposerowsp_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/transposerowsp_lib.cc) add_library(subgraph_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_subgraph/subgraph_lib.cc) +add_library(pass_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_pass/pass_lib.cc) target_include_directories(customop_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet) +target_include_directories(transposecsr_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet) +target_include_directories(transposerowsp_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet) target_include_directories(subgraph_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet) +target_include_directories(pass_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet) if(USE_CUDA) add_library(customop_gpu_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/relu_lib.cu) target_include_directories(customop_gpu_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet) endif() -if(MSVC) +if(UNIX) + if (USE_CUDA) + target_compile_options(customop_gpu_lib PUBLIC -shared) + endif() +elseif(MSVC) target_compile_options(customop_lib PUBLIC /LD) + target_compile_options(transposecsr_lib PUBLIC /LD) + target_compile_options(transposerowsp_lib PUBLIC /LD) target_compile_options(subgraph_lib PUBLIC /LD) + target_compile_options(pass_lib PUBLIC /LD) set_target_properties(customop_lib PROPERTIES PREFIX "lib") + set_target_properties(transposecsr_lib PROPERTIES PREFIX "lib") + set_target_properties(transposerowsp_lib PROPERTIES PREFIX "lib") set_target_properties(subgraph_lib PROPERTIES PREFIX "lib") + set_target_properties(pass_lib PROPERTIES PREFIX "lib") if(USE_CUDA) target_compile_options(customop_gpu_lib PUBLIC "$<$:-Xcompiler=-fPIC>") set_target_properties(customop_gpu_lib PROPERTIES PREFIX "lib") diff --git a/Makefile b/Makefile index 8c478d61d11a..c050dae5e45a 100644 --- a/Makefile +++ b/Makefile @@ -667,7 +667,7 @@ pylint: python3 -m pylint --rcfile=$(ROOTDIR)/ci/other/pylintrc --ignore-patterns=".*\.so$$,.*\.dll$$,.*\.dylib$$" python/mxnet # MXNet extension dynamically loading libraries -EXT_LIBS = build/libcustomop_lib.so build/libsubgraph_lib.so +EXT_LIBS = build/libcustomop_lib.so build/libtransposecsr_lib.so build/libtransposerowsp_lib.so build/libsubgraph_lib.so build/libpass_lib.so ifeq ($(USE_CUDA), 1) EXT_LIBS += build/libcustomop_gpu_lib.so endif @@ -682,6 +682,21 @@ build/libcustomop_gpu_lib.so: build/libsubgraph_lib.so: @mkdir -p $(@D) $(CXX) -shared -fPIC -std=c++11 example/extensions/lib_subgraph/subgraph_lib.cc -o $@ -I include/mxnet +build/libtransposecsr_lib.so: + @mkdir -p $(@D) + $(CXX) -shared -fPIC -std=c++11 example/extensions/lib_custom_op/transposecsr_lib.cc -o $@ -I include/mxnet +build/libtransposerowsp_lib.so: + @mkdir -p $(@D) + $(CXX) -shared -fPIC -std=c++11 example/extensions/lib_custom_op/transposerowsp_lib.cc -o $@ -I include/mxnet +build/libcustomop_gpu_lib.so: + @mkdir -p $(@D) + $(NVCC) -shared -std=c++11 -Xcompiler -fPIC example/extensions/lib_custom_op/relu_lib.cu -o $@ -I include/mxnet +build/libsubgraph_lib.so: + @mkdir -p $(@D) + $(CXX) -shared -fPIC -std=c++11 example/extensions/lib_subgraph/subgraph_lib.cc -o $@ -I include/mxnet +build/libpass_lib.so: + @mkdir -p $(@D) + $(CXX) -shared -fPIC -std=c++11 example/extensions/lib_pass/pass_lib.cc -o $@ -I include/mxnet # Cython build cython: diff --git a/example/extensions/lib_api/init_lib.cc b/example/extensions/lib_api/init_lib.cc index 6a040ffa2ecb..fb3a10457cf5 100644 --- a/example/extensions/lib_api/init_lib.cc +++ b/example/extensions/lib_api/init_lib.cc @@ -27,7 +27,7 @@ #include "lib_api.h" MXReturnValue initialize(int version) { - if (version >= 10400) { + if (version >= 10700) { std::cout << "MXNet version " << version << " supported" << std::endl; return MX_SUCCESS; } else { diff --git a/example/extensions/lib_api/test_loading.py b/example/extensions/lib_api/test_loading.py index d2fc2185716c..7026ec6db6a9 100644 --- a/example/extensions/lib_api/test_loading.py +++ b/example/extensions/lib_api/test_loading.py @@ -25,9 +25,19 @@ import mxnet as mx import os +# test loading library if (os.name=='posix'): path = os.path.abspath('libinit_lib.so') mx.library.load(path) elif (os.name=='nt'): path = os.path.abspath('libinit_lib.dll') mx.library.load(path) + +# test loading library with verbose=False +if (os.name=='posix'): + path = os.path.abspath('libinit_lib.so') + mx.library.load(path, False) +elif (os.name=='nt'): + path = os.path.abspath('libinit_lib.dll') + mx.library.load(path, False) + diff --git a/example/extensions/lib_custom_op/README.md b/example/extensions/lib_custom_op/README.md index 26fdc4f0e70d..856c303a6025 100644 --- a/example/extensions/lib_custom_op/README.md +++ b/example/extensions/lib_custom_op/README.md @@ -22,15 +22,13 @@ C++ Custom Operator Example and Tutorial Adding new operators in MXNet requires understanding of MXNet backend operator registration and recompiling of MXNet with all its dependencies. Users can use the old Python custom operator to add new operators, but it is slow, complicated and has poor adoption rate. So our approach for adding custom operators is to enable dynamic loading of C++ custom operators compiled in external libraries at runtime. -Custom operators (CustomOp) enable users to write new operators without compiling against all of MXNet header files and dependencies. When a library containing custom operators is loaded dynamically, the operators found in the library will be re-registered in MXNet so that users can call those operators natively just like other built-in operators. +Custom operators (CustomOp) enable users to write new operators without compiling against all of MXNet header files and dependencies. When a library containing custom operators is loaded dynamically, the operators found in the library will be registered in MXNet so that users can call those operators natively just like other built-in operators. ## Getting Started ### Have MXNet Ready -Custom Operator support was merged (#15921, #17270) and is not available in versions of MXNet prior to v1.7.0. -To access the feature now, please install MXNet by compiling from source using master or using the previously mentioned commits, downloading one of the nightly builds, or from a release of MXNet 1.7.0+. -For running the following example, it doesn’t matter if it is a CUDA, MKLDNN or plain MXNet build; the custom operator doesn’t interact with the execution of other native MXNet operators. +To run the following example, the build type of MXNet doesn’t matter since the custom operator doesn’t interact with the execution of other native MXNet operators. Note that if you want to run GPU examples and write your custom operators running on GPU, you still need an MXNet CUDA build. ### Run An Example @@ -117,8 +115,7 @@ There are several required building blocks for making a custom operator: ```c++ MXReturnValue parseAttrs( - std::map attrs, + const std::unordered_map& attrs, int* num_in, int* num_out) ``` @@ -129,9 +126,9 @@ There are several required building blocks for making a custom operator: ```c++ MXReturnValue inferType( - std::map attrs, - std::vector &intypes, - std::vector &outtypes) + const std::unordered_map& attrs, + std::vector* intypes, + std::vector* outtypes) ``` * [inferShape](./gemm_lib.cc#L143): @@ -139,9 +136,9 @@ There are several required building blocks for making a custom operator: ```c++ MXReturnValue inferShape( - std::map attrs, - std::vector> &inshapes, - std::vector> &outshapes) + const std::unordered_map& attrs, + std::vector>* inshapes, + std::vector>* outshapes) ``` * [forward](./gemm_lib.cc#L56): @@ -149,10 +146,10 @@ There are several required building blocks for making a custom operator: ```c++ MXReturnValue forward( - std::map attrs, - std::vector inputs, - std::vector outputs, - OpResource res) + const std::unordered_map& attrs, + std::vector* inputs, + std::vector* outputs, + const OpResource& res) ``` Also there are some optional functions you can specify: @@ -162,10 +159,21 @@ Also there are some optional functions you can specify: ```c++ MXReturnValue backward( - std::map attrs, - std::vector inputs, - std::vector outputs, - OpResource res) + const std::unordered_map& attrs, + std::vector* inputs, + std::vector* outputs, + const OpResource& res) +``` + +* [inferSType](./transposecsr_lib.cc#168) - Storage Type Inference: + * This function specifies how the custom operator infers storage types for inputs and outputs. + +```c++ + MXReturnValue inferSType( + const std::unordered_map& attrs, + std::vector* inputs, + std::vector* outputs, + const OpResource& res) ``` * [mutateInputs](./gemm_lib.cc#L214) - Specify mutable input: @@ -173,8 +181,8 @@ Also there are some optional functions you can specify: ```c++ MXReturnValue mutateInputs( - std::map attrs, - std::vector &input_indices) + const std::unordered_map& attrs, + std::vector* input_indices) ``` After specifying those functions, register the custom opeartor with MXNet: @@ -200,6 +208,9 @@ If the number of input and output tensors are fixed, you can use hard-coded numb * **inferType**: This function takes three arguments. The 1st argument is the attributes (same as above). The 2nd argument is the a list of input data types corresponding to the input tensors. The 3rd argument is the placeholder for output tensor data types you need to assign. For example, if this operator has one input and one output, and data type doesn’t change, then you can do `outtypes[0] = intypes[0]` to populate the data type. +* **inferSType**: This function takes three arguments. The 1st argument is the attributes (same as above). The 2nd argument is the a list of input storage types corresponding to the input tensors. The 3rd argument is the placeholder for output storage types you need to assign. +For example, if this operator has one input and one output, and data type doesn’t change, then you can do `outtypes[0] = intypes[0]` to populate the data type. + * **inferShape**: This function is similar to the `inferType` function, except it is used for populating the output data shapes. You need to figure out the shapes of each output tensors for this computation. For example, if the inputs are images with shape (224,224,3) and you write a padding operator to make 10px borders for the images, then your output shape will be (234,234,3). @@ -285,7 +296,7 @@ As a result, you don’t need to call `cudaMemcpy` to move the tensor data to th } ``` -Note that the `cuda_stream` object used for launching kernels is passed from MXNet backend via `OpResource` object. See below for details of `Operator Resource`. +Note that the `cuda_stream` object used for launching kernels is passed from MXNet backend via `OpResource` object. See below for details of `Operator Resource`. You need to compile the `lib_api.h` header file with `nvcc` if you plan to create a custom GPU operator to enable the GPU support in the APIs. Also, `in_data` and `out_data` are pointers to the tensor data allocated on the GPU, so you can pass them directly to your CUDA kernel. At this point all the attribute functions for each operator (`parseAttrs`, `inferShape`, etc.) run on the CPU, including the `forwardGPU` function. The only part that will actually run on the GPU is the launched CUDA kernel function. diff --git a/example/extensions/lib_custom_op/gemm_lib.cc b/example/extensions/lib_custom_op/gemm_lib.cc index daeac337f4d6..4f8dabadc6a1 100644 --- a/example/extensions/lib_custom_op/gemm_lib.cc +++ b/example/extensions/lib_custom_op/gemm_lib.cc @@ -53,23 +53,23 @@ void transpose(const float* A, float* At, const unsigned n, const unsigned m) { * Executes C = A * B * inputs[0] = A; inputs[1] = B; outputs[0] = C */ -MXReturnValue forward(std::map attrs, - std::vector inputs, - std::vector outputs, - OpResource res) { +MXReturnValue forward(const std::unordered_map& attrs, + std::vector* inputs, + std::vector* outputs, + const OpResource& res) { // simple example of using runtime data type - if (inputs[0].dtype == kFloat32) { + if (inputs->at(0).dtype == kFloat32) { typedef float DType; // extract data pointers from tensors // if using dltensor repr, below lines can be changed to something like // DType* A = reinterpret_cast(inputs[0].dltensor.data); - DType* A = inputs[0].data(); - DType* B = inputs[1].data(); - DType* C = outputs[0].data(); + DType* A = inputs->at(0).data(); + DType* B = inputs->at(1).data(); + DType* C = outputs->at(0).data(); // set tensor shapes - unsigned n = inputs[0].shape[0]; - unsigned k = inputs[0].shape[1]; - unsigned m = inputs[1].shape[1]; + unsigned n = inputs->at(0).shape[0]; + unsigned k = inputs->at(0).shape[1]; + unsigned m = inputs->at(1).shape[1]; gemm(A, B, C, n, k, m); } @@ -87,20 +87,20 @@ MXReturnValue forward(std::map attrs, ***** gradient outputs * outputs[0] = dA; outputs[1] = dB */ -MXReturnValue backward(std::map attrs, - std::vector inputs, - std::vector outputs, - OpResource res) { +MXReturnValue backward(const std::unordered_map& attrs, + std::vector* inputs, + std::vector* outputs, + const OpResource& res) { // extract data pointers from tensors - float* dC = inputs[0].data(); - float* A = inputs[1].data(); - float* B = inputs[2].data(); - float* dA = outputs[0].data(); - float* dB = outputs[1].data(); + float* dC = inputs->at(0).data(); + float* A = inputs->at(1).data(); + float* B = inputs->at(2).data(); + float* dA = outputs->at(0).data(); + float* dB = outputs->at(1).data(); // set tensor shapes - unsigned n = inputs[1].shape[0]; - unsigned k = inputs[1].shape[1]; - unsigned m = inputs[2].shape[1]; + unsigned n = inputs->at(1).shape[0]; + unsigned k = inputs->at(1).shape[1]; + unsigned m = inputs->at(2).shape[1]; // allocate temporary workspace memory through resource manager // for multiple arrays better to request a big memory pool void *workspace = res.alloc_cpu((k*n + m*k) * sizeof(float)); @@ -115,54 +115,55 @@ MXReturnValue backward(std::map attrs, return MX_SUCCESS; } -MXReturnValue parseAttrs(std::map attrs, int* num_in, int* num_out) { +MXReturnValue parseAttrs(const std::unordered_map& attrs, + int* num_in, int* num_out) { *num_in = 2; *num_out = 1; return MX_SUCCESS; } -MXReturnValue inferType(std::map attrs, - std::vector &intypes, - std::vector &outtypes) { +MXReturnValue inferType(const std::unordered_map& attrs, + std::vector *intypes, + std::vector *outtypes) { // validate inputs - if (intypes.size() != 2) { + if (intypes->size() != 2) { std::cout << "Expected 2 inputs to inferType" << std::endl; return MX_FAIL; } - for (unsigned i = 0; i < intypes.size(); i++) { - if (intypes[i] != kFloat32) { + for (unsigned i = 0; i < intypes->size(); i++) { + if (intypes->at(i) != kFloat32) { std::cout << "Expected input " << i << " to have float32 type" << std::endl; return MX_FAIL; } } - outtypes[0] = intypes[0]; + outtypes->at(0) = intypes->at(0); return MX_SUCCESS; } -MXReturnValue inferShape(std::map attrs, - std::vector> &inshapes, - std::vector> &outshapes) { +MXReturnValue inferShape(const std::unordered_map& attrs, + std::vector>* inshapes, + std::vector>* outshapes) { // validate inputs - if (inshapes.size() != 2) { + if (inshapes->size() != 2) { std::cout << "Expected 2 inputs to inferShape" << std::endl; return MX_FAIL; } - if (inshapes[0].size() != 2 || inshapes[1].size() != 2) { + if (inshapes->at(0).size() != 2 || inshapes->at(1).size() != 2) { std::cout << "Expected 2D matrices for both inputs to inferShape" << std::endl; return MX_FAIL; } - unsigned n = inshapes[0][0]; - unsigned k = inshapes[0][1]; - unsigned kk = inshapes[1][0]; - unsigned m = inshapes[1][1]; + unsigned n = inshapes->at(0)[0]; + unsigned k = inshapes->at(0)[1]; + unsigned kk = inshapes->at(1)[0]; + unsigned m = inshapes->at(1)[1]; if (k != kk) { std::cout << "Exected first input axis 1 equals to second input axis 0" << std::endl; return MX_FAIL; } - outshapes[0] = {n, m}; + outshapes->at(0) = {n, m}; return MX_SUCCESS; } @@ -177,41 +178,42 @@ REGISTER_OP(my_gemm) class MyStatefulGemm : public CustomStatefulOp { public: - explicit MyStatefulGemm(int count) : count(count) {} + explicit MyStatefulGemm(int count, + const std::unordered_map& attrs) + : count(count), attrs_(attrs) {} - MXReturnValue Forward(std::vector inputs, - std::vector outputs, - OpResource op_res) { + MXReturnValue Forward(std::vector* inputs, + std::vector* outputs, + const OpResource& op_res) { std::cout << "Info: keyword + number of forward: " << ++count << std::endl; - std::map attrs; - return forward(attrs, inputs, outputs, op_res); + return forward(attrs_, inputs, outputs, op_res); } - MXReturnValue Backward(std::vector inputs, - std::vector outputs, - OpResource op_res) { - std::map attrs; - return backward(attrs, inputs, outputs, op_res); + MXReturnValue Backward(std::vector* inputs, + std::vector* outputs, + const OpResource& op_res) { + return backward(attrs_, inputs, outputs, op_res); } ~MyStatefulGemm() {} private: int count; + const std::unordered_map attrs_; }; -MXReturnValue createOpState(std::map attrs, +MXReturnValue createOpState(const std::unordered_map& attrs, CustomStatefulOp** op_inst) { // testing passing of keyword arguments - int count = attrs.count("test_kw") > 0 ? std::stoi(attrs["test_kw"]) : 0; + int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0; // creating stateful operator instance - *op_inst = new MyStatefulGemm(count); + *op_inst = new MyStatefulGemm(count, attrs); std::cout << "Info: stateful operator created" << std::endl; return MX_SUCCESS; } -MXReturnValue mutateInputs(std::map attrs, - std::vector &input_indices) { +MXReturnValue mutateInputs(const std::unordered_map& attrs, + std::vector* input_indices) { // input_indices.push_back(1); // mark mutate input return MX_SUCCESS; } @@ -224,7 +226,7 @@ REGISTER_OP(state_gemm) .setCreateOpState(createOpState, "cpu"); MXReturnValue initialize(int version) { - if (version >= 10400) { + if (version >= 10700) { std::cout << "MXNet version " << version << " supported" << std::endl; return MX_SUCCESS; } else { diff --git a/example/extensions/lib_custom_op/relu_lib.cu b/example/extensions/lib_custom_op/relu_lib.cu index 60112ee4b1e5..a4711cbeab67 100644 --- a/example/extensions/lib_custom_op/relu_lib.cu +++ b/example/extensions/lib_custom_op/relu_lib.cu @@ -29,93 +29,93 @@ #define NumThreadPerBlock 256 // mxnet recommended cuda thread number per block __global__ void relu_gpu_forward(float *out, float *in, int64_t N) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < N) - out[tid] = in[tid] > 0 ? in[tid] : 0; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < N) + out[tid] = in[tid] > 0 ? in[tid] : 0; } __global__ void relu_gpu_backward(float *ingrad, float *outgrad, float *indata, int64_t N) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < N) - ingrad[tid] = indata[tid] > 0 ? 1 * outgrad[tid] : 0; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < N) + ingrad[tid] = indata[tid] > 0 ? 1 * outgrad[tid] : 0; } -MXReturnValue forwardCPU(std::map attrs, - std::vector inputs, - std::vector outputs, - OpResource res) { - float* in_data = inputs[0].data(); - float* out_data = outputs[0].data(); - for (int i=0; i 0 ? in_data[i] : 0; - } - return MX_SUCCESS; +MXReturnValue forwardCPU(const std::unordered_map& attrs, + std::vector* inputs, + std::vector* outputs, + const OpResource& res) { + float* in_data = inputs->at(0).data(); + float* out_data = outputs->at(0).data(); + for (int i=0; iat(0).size(); i++) { + out_data[i] = in_data[i] > 0 ? in_data[i] : 0; + } + return MX_SUCCESS; } -MXReturnValue backwardCPU(std::map attrs, - std::vector inputs, - std::vector outputs, - OpResource res) { - float* out_grad = inputs[0].data(); - float* in_data = inputs[1].data(); - float* in_grad = outputs[0].data(); - for (int i=0; i 0 ? 1 * out_grad[i] : 0; - } - return MX_SUCCESS; +MXReturnValue backwardCPU(const std::unordered_map& attrs, + std::vector* inputs, + std::vector* outputs, + const OpResource& res) { + float* out_grad = inputs->at(0).data(); + float* in_data = inputs->at(1).data(); + float* in_grad = outputs->at(0).data(); + for (int i=0; iat(1).size(); i++) { + in_grad[i] = in_data[i] > 0 ? 1 * out_grad[i] : 0; + } + return MX_SUCCESS; } -MXReturnValue forwardGPU(std::map attrs, - std::vector inputs, - std::vector outputs, - OpResource res) { - float* in_data = inputs[0].data(); - float* out_data = outputs[0].data(); +MXReturnValue forwardGPU(const std::unordered_map& attrs, + std::vector* inputs, + std::vector* outputs, + const OpResource& res) { + float* in_data = inputs->at(0).data(); + float* out_data = outputs->at(0).data(); - mx_stream_t cuda_stream = res.get_cuda_stream(); - int64_t N = inputs[0].size(); - int num_block = (N + NumThreadPerBlock - 1) / NumThreadPerBlock; + mx_stream_t cuda_stream = res.get_cuda_stream(); + int64_t N = inputs->at(0).size(); + int num_block = (N + NumThreadPerBlock - 1) / NumThreadPerBlock; - relu_gpu_forward<<>>(out_data, in_data, N); + relu_gpu_forward<<>>(out_data, in_data, N); - return MX_SUCCESS; + return MX_SUCCESS; } -MXReturnValue backwardGPU(std::map attrs, - std::vector inputs, - std::vector outputs, - OpResource res) { - float* out_grad = inputs[0].data(); - float* in_data = inputs[1].data(); - float* in_grad = outputs[0].data(); - - mx_stream_t cuda_stream = res.get_cuda_stream(); - int64_t N = inputs[0].size(); - int num_block = (N + NumThreadPerBlock - 1) / NumThreadPerBlock; +MXReturnValue backwardGPU(const std::unordered_map& attrs, + std::vector* inputs, + std::vector* outputs, + const OpResource& res) { + float* out_grad = inputs->at(0).data(); + float* in_data = inputs->at(1).data(); + float* in_grad = outputs->at(0).data(); - relu_gpu_backward<<>>(in_grad, out_grad, in_data, N); + mx_stream_t cuda_stream = res.get_cuda_stream(); + int64_t N = inputs->at(0).size(); + int num_block = (N + NumThreadPerBlock - 1) / NumThreadPerBlock; + relu_gpu_backward<<>>(in_grad, out_grad, in_data, N); - return MX_SUCCESS; + return MX_SUCCESS; } -MXReturnValue parseAttrs(std::map attrs, int* num_in, int* num_out) { - *num_in = 1; - *num_out = 1; - return MX_SUCCESS; +MXReturnValue parseAttrs(const std::unordered_map& attrs, + int* num_in, int* num_out) { + *num_in = 1; + *num_out = 1; + return MX_SUCCESS; } -MXReturnValue inferType(std::map attrs, - std::vector &intypes, - std::vector &outtypes) { - outtypes[0] = intypes[0]; - return MX_SUCCESS; +MXReturnValue inferType(const std::unordered_map& attrs, + std::vector* intypes, + std::vector* outtypes) { + outtypes->at(0) = intypes->at(0); + return MX_SUCCESS; } -MXReturnValue inferShape(std::map attrs, - std::vector> &inshapes, - std::vector> &outshapes) { - outshapes[0] = inshapes[0]; - return MX_SUCCESS; +MXReturnValue inferShape(const std::unordered_map& attrs, + std::vector>* inshapes, + std::vector>* outshapes) { + outshapes->at(0) = inshapes->at(0); + return MX_SUCCESS; } REGISTER_OP(my_relu) @@ -128,51 +128,53 @@ REGISTER_OP(my_relu) .setBackward(backwardGPU, "gpu"); class MyStatefulReluCPU : public CustomStatefulOp { -public: - explicit MyStatefulReluCPU() {} - MXReturnValue Forward(std::vector inputs, - std::vector outputs, - OpResource op_res) { - std::map attrs; - return forwardCPU(attrs, inputs, outputs, op_res); + public: + explicit MyStatefulReluCPU(const std::unordered_map& attrs) + : attrs_(attrs) {} + MXReturnValue Forward(std::vector* inputs, + std::vector* outputs, + const OpResource& op_res) { + return forwardCPU(attrs_, inputs, outputs, op_res); } - MXReturnValue Backward(std::vector inputs, - std::vector outputs, - OpResource op_res) { - std::map attrs; - return backwardCPU(attrs, inputs, outputs, op_res); + MXReturnValue Backward(std::vector* inputs, + std::vector* outputs, + const OpResource& op_res) { + return backwardCPU(attrs_, inputs, outputs, op_res); } ~MyStatefulReluCPU() {} + private: + const std::unordered_map attrs_; }; class MyStatefulReluGPU : public CustomStatefulOp { -public: - explicit MyStatefulReluGPU() {} - MXReturnValue Forward(std::vector inputs, - std::vector outputs, - OpResource op_res) { - std::map attrs; - return forwardGPU(attrs, inputs, outputs, op_res); + public: + explicit MyStatefulReluGPU(const std::unordered_map& attrs) + : attrs_(attrs) {} + MXReturnValue Forward(std::vector* inputs, + std::vector* outputs, + const OpResource& op_res) { + return forwardGPU(attrs_, inputs, outputs, op_res); } - MXReturnValue Backward(std::vector inputs, - std::vector outputs, - OpResource op_res) { - std::map attrs; - return backwardGPU(attrs, inputs, outputs, op_res); + MXReturnValue Backward(std::vector* inputs, + std::vector* outputs, + const OpResource& op_res) { + return backwardGPU(attrs_, inputs, outputs, op_res); } ~MyStatefulReluGPU() {} + private: + const std::unordered_map attrs_; }; -MXReturnValue createOpStateCPU(std::map attrs, +MXReturnValue createOpStateCPU(const std::unordered_map& attrs, CustomStatefulOp** op_inst) { - *op_inst = new MyStatefulReluCPU(); - return MX_SUCCESS; + *op_inst = new MyStatefulReluCPU(attrs); + return MX_SUCCESS; } -MXReturnValue createOpStateGPU(std::map attrs, +MXReturnValue createOpStateGPU(const std::unordered_map& attrs, CustomStatefulOp** op_inst) { - *op_inst = new MyStatefulReluGPU(); - return MX_SUCCESS; + *op_inst = new MyStatefulReluGPU(attrs); + return MX_SUCCESS; } REGISTER_OP(my_state_relu) @@ -205,46 +207,46 @@ __global__ void noisy_relu_gpu_forward(float *out, float *in, int64_t N, mx_gpu_ } } -MXReturnValue noisyForwardCPU(std::map attrs, - std::vector inputs, - std::vector outputs, - OpResource res) { - float* in_data = inputs[0].data(); - float* out_data = outputs[0].data(); - - mx_cpu_rand_t* states = res.get_cpu_rand_states(); - std::normal_distribution dist_normal; - - for (int i=0; i 0 ? in_data[i] + noise : 0; - } - return MX_SUCCESS; +MXReturnValue noisyForwardCPU(const std::unordered_map& attrs, + std::vector* inputs, + std::vector* outputs, + const OpResource& res) { + float* in_data = inputs->at(0).data(); + float* out_data = outputs->at(0).data(); + + mx_cpu_rand_t* states = res.get_cpu_rand_states(); + std::normal_distribution dist_normal; + + for (int i=0; iat(0).size(); ++i) { + float noise = dist_normal(*states); + out_data[i] = in_data[i] + noise > 0 ? in_data[i] + noise : 0; + } + return MX_SUCCESS; } -MXReturnValue noisyForwardGPU(std::map attrs, - std::vector inputs, - std::vector outputs, - OpResource res) { - float* in_data = inputs[0].data(); - float* out_data = outputs[0].data(); - - mx_stream_t cuda_stream = res.get_cuda_stream(); - int64_t N = inputs[0].size(); - - // below is mxnet recommended workflow to parallel random number generating - int nthread = (N + NumRandomPerThread - 1) / NumRandomPerThread; - // we should not launch more threads than mxnet supported random number GPU states - int num_thread_need = nthread < MX_NUM_GPU_RANDOM_STATES ? nthread : MX_NUM_GPU_RANDOM_STATES; - // each cuda thread processes [step * tid, step * id + step) snippet of input tensor - int step = (N + num_thread_need - 1) / num_thread_need; - // this can ensure number of parallel threads less than mxnet supported random number states - int num_block = (num_thread_need + NumThreadPerBlock - 1) / NumThreadPerBlock; - - noisy_relu_gpu_forward<<>>( +MXReturnValue noisyForwardGPU(const std::unordered_map& attrs, + std::vector* inputs, + std::vector* outputs, + const OpResource& res) { + float* in_data = inputs->at(0).data(); + float* out_data = outputs->at(0).data(); + + mx_stream_t cuda_stream = res.get_cuda_stream(); + int64_t N = inputs->at(0).size(); + + // below is mxnet recommended workflow to parallel random number generating + int nthread = (N + NumRandomPerThread - 1) / NumRandomPerThread; + // we should not launch more threads than mxnet supported random number GPU states + int num_thread_need = nthread < MX_NUM_GPU_RANDOM_STATES ? nthread : MX_NUM_GPU_RANDOM_STATES; + // each cuda thread processes [step * tid, step * id + step) snippet of input tensor + int step = (N + num_thread_need - 1) / num_thread_need; + // this can ensure number of parallel threads less than mxnet supported random number states + int num_block = (num_thread_need + NumThreadPerBlock - 1) / NumThreadPerBlock; + + noisy_relu_gpu_forward<<>>( out_data, in_data, N, res.get_gpu_rand_states(), step); - return MX_SUCCESS; + return MX_SUCCESS; } REGISTER_OP(my_noisy_relu) @@ -257,11 +259,11 @@ REGISTER_OP(my_noisy_relu) .setBackward(backwardGPU, "gpu"); MXReturnValue initialize(int version) { - if (version >= 10400) { - std::cout << "MXNet version " << version << " supported" << std::endl; - return MX_SUCCESS; - } else { - std::cout << "MXNet version " << version << " not supported" << std::endl; - return MX_FAIL; - } + if (version >= 10700) { + std::cout << "MXNet version " << version << " supported" << std::endl; + return MX_SUCCESS; + } else { + std::cout << "MXNet version " << version << " not supported" << std::endl; + return MX_FAIL; + } } diff --git a/example/extensions/lib_custom_op/transposecsr_lib.cc b/example/extensions/lib_custom_op/transposecsr_lib.cc index 0daeb3e9f83e..224cd6aa81b6 100644 --- a/example/extensions/lib_custom_op/transposecsr_lib.cc +++ b/example/extensions/lib_custom_op/transposecsr_lib.cc @@ -26,7 +26,7 @@ #include #include "lib_api.h" -void transpose(MXTensor src, MXTensor dst, OpResource res) { +void transpose(MXTensor& src, MXTensor& dst, const OpResource& res) { MXSparse* A = src.data(); MXSparse* B = dst.data(); std::vector shape = src.shape; @@ -63,76 +63,78 @@ void transpose(MXTensor src, MXTensor dst, OpResource res) { } } -MXReturnValue forward(std::map attrs, - std::vector inputs, - std::vector outputs, - OpResource res) { +MXReturnValue forward(const std::unordered_map& attrs, + std::vector* inputs, + std::vector* outputs, + const OpResource& res) { // The data types and storage types of inputs and outputs should be the same. - if(inputs[0].dtype != outputs[0].dtype || inputs[0].stype != outputs[0].stype) { + if(inputs->at(0).dtype != outputs->at(0).dtype || + inputs->at(0).stype != outputs->at(0).stype) { std::cout << "Error! Expected all inputs and outputs to be the same type." - << "Found input storage type:" << inputs[0].stype - << " Found output storage type:" << outputs[0].stype - << " Found input data type:" << inputs[0].dtype - << " Found output data type:" << outputs[0].dtype << std::endl; + << "Found input storage type:" << inputs->at(0).stype + << " Found output storage type:" << outputs->at(0).stype + << " Found input data type:" << inputs->at(0).dtype + << " Found output data type:" << outputs->at(0).dtype << std::endl; return MX_FAIL; } - transpose(inputs[0], outputs[0], res); + transpose(inputs->at(0), outputs->at(0), res); return MX_SUCCESS; } -MXReturnValue backward(std::map attrs, - std::vector inputs, - std::vector outputs, - OpResource res) { +MXReturnValue backward(const std::unordered_map& attrs, + std::vector* inputs, + std::vector* outputs, + const OpResource& res) { return MX_SUCCESS; } -MXReturnValue parseAttrs(std::map attrs, int* num_in, int* num_out) { +MXReturnValue parseAttrs(const std::unordered_map& attrs, + int* num_in, int* num_out) { *num_in = 1; *num_out = 1; return MX_SUCCESS; } -MXReturnValue inferType(std::map attrs, - std::vector &intypes, - std::vector &outtypes) { +MXReturnValue inferType(const std::unordered_map& attrs, + std::vector* intypes, + std::vector* outtypes) { // validate inputs - if (intypes.size() != 1) { + if (intypes->size() != 1) { std::cout << "Expected 1 inputs to inferType" << std::endl; return MX_FAIL; } - if (intypes[0] != kFloat32) { + if (intypes->at(0) != kFloat32) { std::cout << "Expected input to have float32 type" << std::endl; return MX_FAIL; } - outtypes[0] = intypes[0]; + outtypes->at(0) = intypes->at(0); return MX_SUCCESS; } -MXReturnValue inferSType(std::map attrs, - std::vector &instypes, - std::vector &outstypes) { - if (instypes[0] != kCSRStorage) { +MXReturnValue inferSType(const std::unordered_map& attrs, + std::vector* instypes, + std::vector* outstypes) { + if (instypes->at(0) != kCSRStorage) { std::cout << "Expected storage type is kCSRStorage" << std::endl; return MX_FAIL; } - outstypes[0] = instypes[0]; + outstypes->at(0) = instypes->at(0); return MX_SUCCESS; } -MXReturnValue inferShape(std::map attrs, - std::vector> &inshapes, - std::vector> &outshapes) { +MXReturnValue inferShape(const std::unordered_map& attrs, + std::vector>* inshapes, + std::vector>* outshapes) { // validate inputs - if (inshapes.size() != 1) { + if (inshapes->size() != 1) { std::cout << "Expected 1 inputs to inferShape" << std::endl; return MX_FAIL; } - outshapes[0].push_back(inshapes[0][1]); - outshapes[0].push_back(inshapes[0][0]); + outshapes->at(0).push_back(inshapes->at(0)[1]); + outshapes->at(0).push_back(inshapes->at(0)[0]); return MX_SUCCESS; } @@ -147,34 +149,35 @@ REGISTER_OP(my_transposecsr) /* ------------------------------------------------------------------------- */ class MyStatefulTransposeCSR : public CustomStatefulOp { - public: - explicit MyStatefulTransposeCSR(int count) : count(count) {} - - MXReturnValue Forward(std::vector inputs, - std::vector outputs, - OpResource op_res) { - std::cout << "Info: keyword + number of forward: " << ++count << std::endl; - std::map attrs; - return forward(attrs, inputs, outputs, op_res); - } + public: + explicit MyStatefulTransposeCSR(int count, + const std::unordered_map& attrs) + : count(count), attrs_(attrs) {} + + MXReturnValue Forward(std::vector* inputs, + std::vector* outputs, + const OpResource& op_res) { + std::cout << "Info: keyword + number of forward: " << ++count << std::endl; + return forward(attrs_, inputs, outputs, op_res); + } - MXReturnValue Backward(std::vector inputs, - std::vector outputs, - OpResource op_res) { - std::map attrs; - return backward(attrs, inputs, outputs, op_res); - } + MXReturnValue Backward(std::vector* inputs, + std::vector* outputs, + const OpResource& op_res) { + return backward(attrs_, inputs, outputs, op_res); + } - private: - int count; + private: + int count; + const std::unordered_map attrs_; }; -MXReturnValue createOpState(std::map attrs, +MXReturnValue createOpState(const std::unordered_map& attrs, CustomStatefulOp** op_inst) { // testing passing of keyword arguments - int count = attrs.count("test_kw") > 0 ? std::stoi(attrs["test_kw"]) : 0; + int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0; // creating stateful operator instance - *op_inst = new MyStatefulTransposeCSR(count); + *op_inst = new MyStatefulTransposeCSR(count, attrs); std::cout << "Info: stateful operator created" << std::endl; return MX_SUCCESS; } @@ -187,7 +190,7 @@ REGISTER_OP(my_state_transposecsr) .setCreateOpState(createOpState, "cpu"); MXReturnValue initialize(int version) { - if (version >= 10400) { + if (version >= 10700) { std::cout << "MXNet version " << version << " supported" << std::endl; return MX_SUCCESS; } else { diff --git a/example/extensions/lib_custom_op/transposerowsp_lib.cc b/example/extensions/lib_custom_op/transposerowsp_lib.cc index 883d816cfa81..46d3c4d41a4c 100644 --- a/example/extensions/lib_custom_op/transposerowsp_lib.cc +++ b/example/extensions/lib_custom_op/transposerowsp_lib.cc @@ -26,7 +26,7 @@ #include #include "lib_api.h" -void transpose(MXTensor src, MXTensor dst, OpResource res) { +void transpose(MXTensor& src, MXTensor& dst, const OpResource& res) { MXSparse* A = src.data(); MXSparse* B = dst.data(); @@ -66,75 +66,77 @@ void transpose(MXTensor src, MXTensor dst, OpResource res) { } } -MXReturnValue forward(std::map attrs, - std::vector inputs, - std::vector outputs, - OpResource res) { +MXReturnValue forward(const std::unordered_map& attrs, + std::vector* inputs, + std::vector* outputs, + const OpResource& res) { // The data types and storage types of inputs and outputs should be the same. - if(inputs[0].dtype != outputs[0].dtype || inputs[0].stype != outputs[0].stype) { + if(inputs->at(0).dtype != outputs->at(0).dtype || + inputs->at(0).stype != outputs->at(0).stype) { std::cout << "Error! Expected all inputs and outputs to be the same type." - << "Found input storage type:" << inputs[0].stype - << " Found output storage type:" << outputs[0].stype - << " Found input data type:" << inputs[0].dtype - << " Found output data type:" << outputs[0].dtype << std::endl; + << "Found input storage type:" << inputs->at(0).stype + << " Found output storage type:" << outputs->at(0).stype + << " Found input data type:" << inputs->at(0).dtype + << " Found output data type:" << outputs->at(0).dtype << std::endl; return MX_FAIL; } - transpose(inputs[0], outputs[0], res); + transpose(inputs->at(0), outputs->at(0), res); return MX_SUCCESS; } -MXReturnValue backward(std::map attrs, - std::vector inputs, - std::vector outputs, - OpResource res) { +MXReturnValue backward(const std::unordered_map& attrs, + std::vector* inputs, + std::vector* outputs, + const OpResource& res) { return MX_SUCCESS; } -MXReturnValue parseAttrs(std::map attrs, int* num_in, int* num_out) { +MXReturnValue parseAttrs(const std::unordered_map& attrs, + int* num_in, int* num_out) { *num_in = 1; *num_out = 1; return MX_SUCCESS; } -MXReturnValue inferType(std::map attrs, - std::vector &intypes, - std::vector &outtypes) { +MXReturnValue inferType(const std::unordered_map& attrs, + std::vector* intypes, + std::vector* outtypes) { // validate inputs - if (intypes.size() != 1) { + if (intypes->size() != 1) { std::cout << "Expected 1 inputs to inferType" << std::endl; return MX_FAIL; } - if (intypes[0] != kFloat32) { + if (intypes->at(0) != kFloat32) { std::cout << "Expected input to have float32 type" << std::endl; return MX_FAIL; } - outtypes[0] = intypes[0]; + outtypes->at(0) = intypes->at(0); return MX_SUCCESS; } -MXReturnValue inferSType(std::map attrs, - std::vector &instypes, - std::vector &outstypes) { - if (instypes[0] != kRowSparseStorage) { +MXReturnValue inferSType(const std::unordered_map& attrs, + std::vector* instypes, + std::vector* outstypes) { + if (instypes->at(0) != kRowSparseStorage) { std::cout << "Expected storage type is kRowSparseStorage" << std::endl; return MX_FAIL; } - outstypes[0] = instypes[0]; + outstypes->at(0) = instypes->at(0); return MX_SUCCESS; } -MXReturnValue inferShape(std::map attrs, - std::vector> &inshapes, - std::vector> &outshapes) { +MXReturnValue inferShape(const std::unordered_map& attrs, + std::vector>* inshapes, + std::vector>* outshapes) { // validate inputs - if (inshapes.size() != 1) { + if (inshapes->size() != 1) { std::cout << "Expected 1 inputs to inferShape" << std::endl; return MX_FAIL; } - outshapes[0].push_back(inshapes[0][1]); - outshapes[0].push_back(inshapes[0][0]); + outshapes->at(0).push_back(inshapes->at(0)[1]); + outshapes->at(0).push_back(inshapes->at(0)[0]); return MX_SUCCESS; } @@ -149,34 +151,35 @@ REGISTER_OP(my_transposerowsp) /* ------------------------------------------------------------------------- */ class MyStatefulTransposeRowSP : public CustomStatefulOp { - public: - explicit MyStatefulTransposeRowSP(int count) : count(count) {} - - MXReturnValue Forward(std::vector inputs, - std::vector outputs, - OpResource op_res) { - std::cout << "Info: keyword + number of forward: " << ++count << std::endl; - std::map attrs; - return forward(attrs, inputs, outputs, op_res); - } + public: + explicit MyStatefulTransposeRowSP(int count, + const std::unordered_map& attrs) + : count(count), attrs_(attrs) {} + + MXReturnValue Forward(std::vector* inputs, + std::vector* outputs, + const OpResource& op_res) { + std::cout << "Info: keyword + number of forward: " << ++count << std::endl; + return forward(attrs_, inputs, outputs, op_res); + } - MXReturnValue Backward(std::vector inputs, - std::vector outputs, - OpResource op_res) { - std::map attrs; - return backward(attrs, inputs, outputs, op_res); - } + MXReturnValue Backward(std::vector* inputs, + std::vector* outputs, + const OpResource& op_res) { + return backward(attrs_, inputs, outputs, op_res); + } - private: - int count; + private: + int count; + const std::unordered_map attrs_; }; -MXReturnValue createOpState(std::map attrs, +MXReturnValue createOpState(const std::unordered_map& attrs, CustomStatefulOp** op_inst) { // testing passing of keyword arguments - int count = attrs.count("test_kw") > 0 ? std::stoi(attrs["test_kw"]) : 0; + int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0; // creating stateful operator instance - *op_inst = new MyStatefulTransposeRowSP(count); + *op_inst = new MyStatefulTransposeRowSP(count, attrs); std::cout << "Info: stateful operator created" << std::endl; return MX_SUCCESS; } @@ -189,7 +192,7 @@ REGISTER_OP(my_state_transposerowsp) .setCreateOpState(createOpState, "cpu"); MXReturnValue initialize(int version) { - if (version >= 10400) { + if (version >= 10700) { std::cout << "MXNet version " << version << " supported" << std::endl; return MX_SUCCESS; } else { diff --git a/example/extensions/lib_pass/Makefile b/example/extensions/lib_pass/Makefile new file mode 100644 index 000000000000..759a08c48c89 --- /dev/null +++ b/example/extensions/lib_pass/Makefile @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +all: pass_lib + +pass_lib: + g++ -shared -fPIC -std=c++11 pass_lib.cc -o libpass_lib.so -I ../../../include/mxnet + +clean: + rm -rf libpass_lib.so diff --git a/example/extensions/lib_pass/README.md b/example/extensions/lib_pass/README.md new file mode 100644 index 000000000000..c2771242440f --- /dev/null +++ b/example/extensions/lib_pass/README.md @@ -0,0 +1,190 @@ + + + + + + + + + + + + + + + + + +Custom Graph Pass Example and Tutorial +======================================= + +## Introduction + +Adding custom graph passes in MXNet used to require deep understanding of the MXNet backend, including nnvm pass registration and other internal classes, followed by recompiling MXNet from source. This feature allows adding custom graph passes by dynamically loading external libraries at runtime. + +This custom graph pass feature enables users to write custom model modification strategies without compiling against all of MXNet header files and dependencies. When a library containing custom passes is loaded dynamically, the components found in the library will be registered in MXNet so that users can use those natively just like other built-in components. + +## Getting Started + +### Have MXNet Ready + +To run the following example, the build type of MXNet doesn’t matter since the custom pass doesn’t interact with the execution of other native MXNet features. Note that if you want to use your custom pass with models running on GPU, you still need an MXNet CUDA build. + +### Run An Example + +You can start getting familiar with custom passes by running an example provided in the **example/extensions/lib_pass** directory. The `myPass` example just copies the input graph to the output. Go to the **lib_pass** directory and follow these steps: + +1. Run `make`. The Makefile will generate the dynamic library **libpass_lib.so** which is compiled from the `pass_lib.cc` file. This is the library you are going to load that contains everything for the custom pass. +2. Run `python test_pass.py`. It’ll first load the above library, find the components, register them in the MXNet backend, then execute the pass on the model and execute the operators like a regular MXNet operator and output the result. Below is the output when running the `python test_pass.py` command. Notice that it loads 2 passes: myPass and jsonPass. + +``` +[10:38:03] src/c_api/c_api.cc:286: Found 0 operators in library +[10:38:03] src/c_api/c_api.cc:785: Found 0 partitioners in library +[07:14:00] src/c_api/c_api.cc:887: Found 2 graph passes in library +[07:14:00] src/c_api/c_api.cc:902: Graph Pass [0] myPass +[07:14:00] src/c_api/c_api.cc:902: Graph Pass [1] jsonPass +``` + +### Basic Files For Custom Pass Library +* **lib_pass/pass_lib.cc**: This file has a source code implementation of all required components to make a custom pass, it also shows registration of them so that they can be loaded by MXNet. +* **lib_pass/Makefile**: This file compiles the source code to a dynamic shared library, with a header file `include/mxnet/lib_api.h` from MXNet source code. Currently the custom pass is compatible with C++11 onwards. +* **lib_pass/test_pass.py**: This file calls `mx.library.load(‘libpass_lib.so’)` to load the library containing the custom components, executes the pass on the model using the `optimize_for` API, and prints outputs of the forward passes. The outputs should be the same as the regular MXNet forward pass without running the pass. +* **include/mxnet/lib_api.h**: This file from MXNet source code is the single header file needed to include all necessary data types and function prototypes for writing a custom library. You can either specify the include path in the `Makefile`, or copy the header file over to `example/extensions/lib_pass` folder. Note that apart from this header, the custom library is independent of MXNet source. +## Writing Custom Pass Library +To build your own library containing a custom pass, compose a C++ source file like `mypass_lib.cc`, include `lib_api.h` header file, and write your custom pass with these essential functions: +- `initialize` - Library Initialization Function +- `REGISTER_PASS` - Pass Registration Macro +- `graphPass` - Pass Implementation +Then compile it to the `mypass_lib.so` dynamic library using the following command: +```bash +g++ -shared -fPIC -std=c++11 mypass_lib.cc -o libmypass_lib.so -I ../../../include/mxnet +``` + +Finally, you can write a Python script to load the library and execute your pass on a model: + +```python +import mxnet as mx +mx.library.load(‘libmypass_lib.so’) +sym, _, _ = mx.model.load_checkpoint('mymodel', 0) +# Symbol/Module flow +sym2 = sym.optimize_for("myPass") +# Gluon flow 1 +sym_block = nn.SymbolBlock(sym, inputs) +sym_block.hybridize(backend='myPass') +# Gluon flow 2 +sym_block = nn.SymbolBlock(sym, inputs) +sym_block.optimize_for(x, backend='myPass') +``` + +### Using a Custom Pass Library + +APIs in MXNet are available in both Symbol and Gluon APIs. For the Symbol API, the `optimize_for` API can be called on Symbol objects to return a new Symbol post graph pass. + +``` +optimize_for(backend, args=None, aux=None, ctx=None, **kwargs) +``` + +The `optimize_for` API takes at least 1 argument, `backend` which is a string that identifies which backend to use to optimize the model. The `args` and `aux` arguments are optional and take a list of NDArray or dict of str to NDArray. They are used to infer shapes and types and before executing the graph pass. The `ctx` argument is optional and takes a device context to infer storage types. It also takes any other user-specified options that will be passed to the backend APIs. + +For the Gluon API, the `hybridize` API can be called on HybridBlocks to execute a graph pass on the internal CachedOp Symbol. + +``` +hybridize(backend=None, backend_opts=None, **kwargs) +``` + +The `hybridize` function prepares the HybridBlock to be converted into a backend symbol. The `backend` argument is a string that identifies which pass that will be executed on the model. The `backend_opts` takes other user-specified options that will be passed to the backend APIs. The actual pass runs once just before the first the forward pass. + +If you just want to run a graph pass on the HybridBlock but not run a complete forward pass, you can use the `optimize_for` API that combines the work done in the `hybridize` API with part of the work done in the forward pass. + +``` +optimize_for(x, backend=None, backend_opts=None, **kwargs) +``` + +When the `optimize_for` API is called on a HybridBlock it runs the graph pass immediately. This lets users export the modified model without running a complete forward pass. + +``` +block.optimize_for(x, backend='myPass') +block.export('optimized') +``` + +But you can also use `optimize_for` in place of `hybridize` and run inference immediately after too. + +``` +block.optimize_for(x, backend='myPass') +block(x) +``` + +### Writing A Custom Graph Pass + +There are several essential building blocks for making a custom pass: + +* [initialize](./pass_lib.cc#44): + * This function is the library initialization function necessary for any dynamic libraries. It lets you check if the user is using a compatible version of MXNet. Note that this `version` parameter is passed from MXNet when library is loaded. + + MXReturnValue initialize(int version) + +* [graphPass](./pass_lib.cc#31): + * This function provides a copy of the model graph as a JSON string, and provides an interface for returning a modified model JSON string. Also this is where a custom pass can validate the options specified by the user. + + MXReturnValue graphPass( + const std::string& in_graph, + const std::string** out_graph, + const std::unordered_map& options, + const std::unordered_map& args, + const std::unordered_map& aux, + const PassResource& res) + +* [REGISTER_PASS(my_pass_name)](./pass_lib.cc#L41): + * This macro registers the custom pass and its properties to MXNet by its name. The argument to `setBody` is the `graphPass` function. + + REGISTER_PASS(my_pass_name) + .setBody(graphPass); + +Let’s take a closer look at those registry functions: + +* **graphPass**: This function takes six arguments. The 1st argument is a JSON string of the model architecture graph, where nodes are inputs/params/weights and edges are data dependencies. The graph is pre-sorted in topological order. The 2nd argument is a pointer to a pointer of a JSON model string. It is expected users will dereference and assign the address of their output string allocated with `new` and `delete` will be called on it automatically. The third argument is the map of options specified by the user. Users can pass custom options to the pass and they are passed to this function in the `options` map. The fourth and fifth arguments are the named tensor mapping for the args and aux params for the model. They will contain the model params if the user provides them to the `optimize_for` API. The last argument is the `PassResource` object for memory allocation and other utilities. The details of `PassResource` are covered in the section below + +### Pass Resource + +Some graph passes require allocating new NDArrays to add/replace model params. The `alloc_arg` and `alloc_aux` APIs enabling allocating new NDArrays and integrating them with the user-provide args and aux params. Both APIs have the following signature: + +``` + MXTensor* alloc_xxx(const std::string& name, + const std::vector& shapes, + const MXContext &ctx, + MXDType dtype) +``` + +If the `name` provided matches the name of an existing param it replaces the previous one. Otherwise it adds a new param to the appropriate arg/aux set. + +### Parsing a JSON string + +To simplify custom libraries, basic JSON parsing utility functions have been implemented in the `lib_api.h` header file. You create a `JsonParser` object and parse the string by calling the `parse_to_json` API like: + +```c++ +JsonParser parser; +JsonVal json_val = parser.parse_to_json(json_string); +``` + +A `JsonVal` is a class that represents the nodes in a JSON structure. You can check the type of a node (num, str, list, or map) by comparing the `JsonVal.type` to `STR`, `NUM`, `LIST`, or `MAP`. Then you can get that value from the node like: + +```c++ +switch(json_val.type) { + case STR: + std::string str = json_val.str; + break; + case NUM: + int num = json_val.num; + break; + case LIST: + std::vector list = json_val.list; + break; + case MAP: + std::map map = json_val.map; + break; + default: + // error +} +``` + +There are also convenience constructors for creating `JsonVal` objects for strings and numbers like `JsonVal("myKey")` or `JsonVal(42)`. This makes it easy to get specific keys from a map like `json_val.map[JsonVal("nodes")]`. diff --git a/example/extensions/lib_pass/pass_lib.cc b/example/extensions/lib_pass/pass_lib.cc new file mode 100644 index 000000000000..bbdcd73a7a0b --- /dev/null +++ b/example/extensions/lib_pass/pass_lib.cc @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file subgraph_lib.cc + * \brief subgraph operator implementation library file + */ + +#include +#include +#include +#include "lib_api.h" + +/* \brief a basic pass that copies the input to the output */ +MXReturnValue myPass(const std::string& in_graph, const std::string** out_graph, + const std::unordered_map& options, + const std::unordered_map& args, + const std::unordered_map& aux, + const PassResource& res) { + for (auto kv : options) { + std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl; + } + + *out_graph = new std::string(in_graph); + return MX_SUCCESS; +} + +REGISTER_PASS(myPass) +.setBody(myPass); + +/* \brief a basic pass that parses the input string to JSON and then dumps it back */ +MXReturnValue jsonPass(const std::string& in_graph, const std::string** out_graph, + const std::unordered_map& options, + const std::unordered_map& args, + const std::unordered_map& aux, + const PassResource& res) { + for (auto kv : options) + std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl; + + // add test arg/aux + + MXTensor* arg_ = res.alloc_arg("test_arg",{1},MXContext::CPU(0),kFloat32); + MXTensor* aux_ = res.alloc_aux("test_aux",{1},MXContext::CPU(0),kFloat32); + + // convert json string to json object + JsonParser parser; + JsonVal json_val = parser.parse_to_json(in_graph); + + // get nodes list + JsonVal nodes = json_val.map[JsonVal("nodes")]; + + // loop over nodes + for(int i=0; i= 10700) { + std::cout << "MXNet version " << version << " supported" << std::endl; + return MX_SUCCESS; + } else { + std::cout << "MXNet version " << version << " not supported" << std::endl; + return MX_FAIL; + } +} diff --git a/example/extensions/lib_pass/test_pass.py b/example/extensions/lib_pass/test_pass.py new file mode 100644 index 000000000000..8930c9478152 --- /dev/null +++ b/example/extensions/lib_pass/test_pass.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=arguments-differ + +# This test checks if dynamic loading of library into MXNet is successful +# and checks the end of end computation of custom operator + +import os, ctypes +import mxnet as mx +from mxnet.gluon import nn +from mxnet import nd +from mxnet.base import _LIB, check_call, mx_uint, c_str, c_str_array, SymbolHandle + +# load library +if (os.name=='posix'): + path = os.path.abspath('libpass_lib.so') + mx.library.load(path) +elif (os.name=='nt'): + path = os.path.abspath('libpass_lib.dll') + mx.library.load(path) + +############################################### +# Test with not consuming params +############################################### +# example model, ops do not have args (use outputs from other ops as inputs) +a = mx.sym.var('a') +b = mx.sym.var('b') +c = a + b +d = mx.sym.exp(c) +sym = mx.sym.log(d) + +def test_model(pass_name): + # execute in MXNet + print('-------------------------------') + print('Testing regular MXNet execution') + exe = sym.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) + out = exe.forward() + print(out) + + # Symbol optimize_for + # with propogating shapes/types + print('-------------------------------') + print('Testing pass "%s" with shapes/types' % pass_name) + arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')] + aux = [] + mysym2 = sym.optimize_for(pass_name,arg_array,aux) + print(mysym2.tojson()) + exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) + out2 = exe2.forward() + print(out2) + + # without propogating shapes/types + print('-------------------------------') + print('Testing pass "%s" without shapes/types' % pass_name) + mysym3 = sym.optimize_for(pass_name, myOpt='yello') + exe3 = mysym3.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) + out3 = exe3.forward() + print(out3) + + # Gluon Hybridize + print('-------------------------------') + print('Testing pass "%s" Gluon Hybridize with shapes/types' % pass_name) + inputs = [a,b] + sym_block = nn.SymbolBlock(sym, inputs) + sym_block.initialize() + sym_block.hybridize(backend=pass_name) + out4 = sym_block(mx.nd.ones((3,2)),mx.nd.ones((3,2))) + print(out4) + + # Gluon optimize_for + print('-------------------------------') + print('Testing pass "%s" Gluon Hybridize with shapes/types without inference' % pass_name) + inputs = [a,b] + sym_block2 = nn.SymbolBlock(sym, inputs) + sym_block2.initialize() + sym_block2.optimize_for(mx.nd.ones((3,2)), mx.nd.ones((3,2)), backend=pass_name) + sym_block2.export('modified') + +test_model('myPass') +test_model('jsonPass') diff --git a/example/extensions/lib_subgraph/README.md b/example/extensions/lib_subgraph/README.md index 83c823676f18..6644a1fdc8ff 100644 --- a/example/extensions/lib_subgraph/README.md +++ b/example/extensions/lib_subgraph/README.md @@ -22,13 +22,13 @@ Custom Partitioner Example and Tutorial Adding custom model partitioners in MXNet used to require deep understanding of the MXNet backend, including operator registration and other internal classes, followed by recompiling MXNet from source. This feature allows adding custom partitioners by dynamically loading external libraries at runtime. -This custom partitioner feature enables users to write custom model partitioning strategies without compiling against all of MXNet header files and dependencies. When a library containing custom partitioners is loaded dynamically, the components found in the library will be re-registered in MXNet so that users can use those natively just like other built-in components. +This custom partitioner feature enables users to write custom model partitioning strategies without compiling against all of MXNet header files and dependencies. When a library containing custom partitioners is loaded dynamically, the components found in the library will be registered in MXNet so that users can use those natively just like other built-in components. ## Getting Started ### Have MXNet Ready -The custom partitioner feature was merged recently (#15969) and is not available in versions of MXNet prior to v1.7.0. To use the feature now, please install MXNet either by installing the nightly pip wheel or compiling from source. For running the following example, it doesn’t matter if it is a CUDA, MKLDNN or plain MXNet build; the custom partitioner doesn’t interact with the execution of other native MXNet features. Note that if you want to write your custom partitioners running on GPU, you still need an MXNet CUDA build. +To run the following example, the build type of MXNet doesn’t matter since the custom partitioner doesn’t interact with the execution of other native MXNet features. Note that if you want to use your custom partitioners with models running on GPU, you still need an MXNet CUDA build. ### Run An Example diff --git a/example/extensions/lib_subgraph/subgraph_lib.cc b/example/extensions/lib_subgraph/subgraph_lib.cc index d821bdb0d1c2..28442078ebe6 100644 --- a/example/extensions/lib_subgraph/subgraph_lib.cc +++ b/example/extensions/lib_subgraph/subgraph_lib.cc @@ -50,9 +50,9 @@ void myExp(MXTensor &in, MXTensor &out) { * so all we need to do is go through the ops in order * and execute each op. */ -MXReturnValue myExecutor(std::vector inputs, - std::vector outputs, - std::string subgraph_sym) { +MXReturnValue myExecutor(std::vector* inputs, + std::vector* outputs, + const std::string& subgraph_sym) { std::cout << "Info: subgraph symbol is: " << std::endl; std::cout << subgraph_sym << std::endl; @@ -79,12 +79,12 @@ MXReturnValue myExecutor(std::vector inputs, // handle each op type if (op.compare("null") == 0) { // null is an input data to the subgraph, add to data storage - data.push_back(inputs[input_cnt++]); + data.push_back(inputs->at(input_cnt++)); } else if (op.compare("log") == 0) { // get input tensor based on node ID inputs from data storage MXTensor &input = data[node_inputs.list[0].list[0].num]; // create temporary storage - MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, {"cpu", 0}, kDefaultStorage); + MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, MXContext::CPU(0), kDefaultStorage); // save allocated ptr to free later to_free.push_back(tmp.data_ptr); // execute log operator @@ -95,7 +95,7 @@ MXReturnValue myExecutor(std::vector inputs, // get input tensor based on node ID inputs from data storage MXTensor &input = data[node_inputs.list[0].list[0].num]; // create temporary storage - MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, {"cpu", 0}, kDefaultStorage); + MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, MXContext::CPU(0), kDefaultStorage); // save allocated ptr to free later to_free.push_back(tmp.data_ptr); // execute exp operator @@ -118,7 +118,7 @@ MXReturnValue myExecutor(std::vector inputs, // get computed result MXTensor &result = data[heads.list[0].list[0].num]; // get output tensor to pass to MX - MXTensor &out = outputs[j]; + MXTensor &out = outputs->at(j); float *out_data = out.data(); float *res_data = result.data(); // loop and copy data @@ -137,34 +137,34 @@ MXReturnValue myExecutor(std::vector inputs, class MyStatefulOp : public CustomStatefulOp { public: - explicit MyStatefulOp(std::string sym, std::map attrs) + explicit MyStatefulOp(const std::string& sym, + const std::unordered_map& attrs) : subgraph_sym(sym), attrs_(attrs) { for (auto kv : attrs) { std::cout << "subgraphOp attributes: " << kv.first << " ==> " << kv.second << std::endl; } } - MXReturnValue Forward(std::vector inputs, - std::vector outputs, - OpResource op_res) { + MXReturnValue Forward(std::vector* inputs, + std::vector* outputs, + const OpResource& op_res) { return myExecutor(inputs, outputs, subgraph_sym); } private: - std::string subgraph_sym; - std::map attrs_; + const std::string subgraph_sym; + const std::unordered_map attrs_; }; -MXReturnValue createOpState(std::map attrs, +MXReturnValue createOpState(const std::unordered_map& attrs, CustomStatefulOp** op_inst) { std::string serialized_subgraph = "[empty]"; // MXNet subgraph is stored as Symbol in operator node attrs subgraphs field // custom subgraph is stored as json string in custom operator attrs map entry if (attrs.count(MX_STR_SUBGRAPH_SYM_JSON)) { // user can now parse json and run other custom ops inside subgraph - serialized_subgraph = attrs[MX_STR_SUBGRAPH_SYM_JSON]; + serialized_subgraph = attrs.at(MX_STR_SUBGRAPH_SYM_JSON); } - attrs.erase(MX_STR_SUBGRAPH_SYM_JSON); *op_inst = new MyStatefulOp(serialized_subgraph, attrs); std::cout << "Info: stateful operator created" << std::endl; return MX_SUCCESS; @@ -176,9 +176,9 @@ REGISTER_OP(_custom_subgraph_op) const std::vector op_names({"exp","log"}); -MXReturnValue mySupportedOps(std::string json, - std::vector& ids, - std::unordered_map& options) { +MXReturnValue mySupportedOps(const std::string& json, + std::vector* ids, + const std::unordered_map& options) { for (auto kv : options) { std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl; } @@ -208,19 +208,19 @@ MXReturnValue mySupportedOps(std::string json, if((dtype == kFloat32 && options.count("reqFloat") > 0) || options.count("reqFloat") == 0) { //check if op is in whitelist if(std::find(op_names.begin(),op_names.end(),op.str.c_str()) != op_names.end()) { - // found op in whitelist, set value to 1 to include op in subgraph - ids[i]=true; + // found op in whitelist, set value to -1 to include op in any subgraph + ids->at(i) = -1; } } } return MX_SUCCESS; } -MXReturnValue myReviewSubgraph(std::string json, int subgraph_id, bool* accept, - std::unordered_map& options, - std::unordered_map& attrs, - std::map& args, - std::map& aux) { +MXReturnValue myReviewSubgraph(const std::string& json, int subgraph_id, bool* accept, + const std::unordered_map& options, + std::unordered_map* attrs, + const std::unordered_map& args, + const std::unordered_map& aux) { for (auto kv : options) { std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl; } @@ -242,24 +242,97 @@ MXReturnValue myReviewSubgraph(std::string json, int subgraph_id, bool* accept, } // check if option `reject` was specified, and if so check if value is 'True' - if(options.count("reject") > 0 && options["reject"].compare("True") == 0) { + if(options.count("reject") > 0 && options.at("reject").compare("True") == 0) { // if specified, reject the subgraph. this is only used for testing *accept = false; std::cout << "rejecting subgraph" << std::endl; } else { *accept = true; std::cout << "accepting subgraph" << std::endl; - attrs["myKey"] = "myVal"; + attrs->insert(std::pair("myKey","myVal")); } return MX_SUCCESS; } REGISTER_PARTITIONER(myProp) -.addStrategy("strategy1", mySupportedOps, "_custom_subgraph_op") +.addStrategy("strategy1", "_custom_subgraph_op") +.setSupportedOps("strategy1", mySupportedOps) +.setReviewSubgraph("strategy1", myReviewSubgraph); + +class MySelector : public CustomOpSelector { + public: + MySelector(const std::string& json, + const std::unordered_map& options) : + graph_json(json), options_(options) { + for (auto kv : options) { + std::cout << "selector options: " << kv.first + << " ==> " << kv.second << std::endl; + } + //convert json string to json object + JsonParser parser; + JsonVal json_val = parser.parse_to_json(json); + //get nodes list + nodes = json_val.map[JsonVal("nodes")]; + } + bool chooseNode(int nodeID) { + JsonVal node = nodes.list[nodeID]; + JsonVal op = node.map[JsonVal("op")]; + + //get shape/type if available + std::string shape; + int dtype = -1; + if(node.map.find(JsonVal("attrs")) != node.map.end()) { + JsonVal attrs = node.map[JsonVal("attrs")]; + if(attrs.map.find(JsonVal("shape")) != attrs.map.end()) + shape = attrs.map[JsonVal("shape")].str; + if(attrs.map.find(JsonVal("dtype")) != attrs.map.end()) + dtype = std::stoi(attrs.map[JsonVal("dtype")].str); + } + + //check if op dtype is float, and if option was specified to require float types + if((dtype == kFloat32 && options_.count("reqFloat") > 0) || options_.count("reqFloat") == 0) { + //check if op is in whitelist + if(std::find(op_names.begin(),op_names.end(),op.str.c_str()) != op_names.end()) { + // found op in whitelist, return true to include op subgraph + return true; + } + } + return false; + } + virtual bool Select(int nodeID) { + return chooseNode(nodeID); + } + virtual bool SelectInput(int nodeID, int input_nodeID) { + return chooseNode(input_nodeID); + } + virtual bool SelectOutput(int nodeID, int output_nodeID) { + return chooseNode(output_nodeID); + } + virtual void Filter(std::vector& candidates, + std::vector& keep) { + keep.insert(keep.end(), candidates.begin(), candidates.end()); + } + virtual void Reset() {} + private: + std::string graph_json; + JsonVal nodes; + const std::unordered_map options_; +}; + +MXReturnValue createSelector(const std::string& json, CustomOpSelector** sel_inst, + const std::unordered_map& options) { + *sel_inst = new MySelector(json, options); + std::cout << "Info: selector created" << std::endl; + return MX_SUCCESS; +} + +REGISTER_PARTITIONER(mySelect) +.addStrategy("strategy1", "_custom_subgraph_op") +.setCreateSelector("strategy1", createSelector) .setReviewSubgraph("strategy1", myReviewSubgraph); MXReturnValue initialize(int version) { - if (version >= 10400) { + if (version >= 10700) { std::cout << "MXNet version " << version << " supported" << std::endl; return MX_SUCCESS; } else { diff --git a/example/extensions/lib_subgraph/test_subgraph.py b/example/extensions/lib_subgraph/test_subgraph.py index 55a40514f105..eb7102a1511c 100644 --- a/example/extensions/lib_subgraph/test_subgraph.py +++ b/example/extensions/lib_subgraph/test_subgraph.py @@ -37,9 +37,6 @@ path = os.path.abspath('libsubgraph_lib.dll') mx.library.load(path) -############################################### -# Test with subgraph not consuming params -############################################### # example model, ops to be partitioned do not have args (use outputs from other ops as inputs) a = mx.sym.var('a') b = mx.sym.var('b') @@ -47,98 +44,104 @@ d = mx.sym.exp(c) sym = mx.sym.log(d) -#execute in MXNet -print('-------------------------------') -print('Testing regular MXNet execution') -exe = sym.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) -out = exe.forward() -print(out) - -# with propogating shapes/types -print('-------------------------------') -print('Testing partitioning with shapes/types') -arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')] -mysym2 = sym.optimize_for("myProp",arg_array) -print(mysym2.tojson()) -exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) -out2 = exe2.forward() -print(out2) - -# with propogating shapes/types, rejecting subgraph -print('-------------------------------') -print('Testing partitioning with shapes/types - rejecting subgraph') -arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')] -mysym2 = sym.optimize_for("myProp", arg_array, reject=True) -exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) -out2 = exe2.forward() -print(out2) - -# without propogating shapes/types -print('-------------------------------') -print('Testing partitioning without shapes/types') -mysym3 = sym.optimize_for("myProp", myOpt='yello') -exe3 = mysym3.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) -out3 = exe3.forward() -print(out3) - -# Gluon Hybridize partitioning with shapes/types -print('-------------------------------') -print('Testing Gluon Hybridize partitioning with shapes/types') -inputs = [a,b] -sym_block = nn.SymbolBlock(sym, inputs) -sym_block.initialize() -sym_block.hybridize(backend='myProp') -out4 = sym_block(mx.nd.ones((3,2)),mx.nd.ones((3,2))) -print(out4) - -# Gluon Hybridize partitioning with shapes/types without inference -print('-------------------------------') -print('Testing Gluon Hybridize partitioning with shapes/types without inference') -inputs = [a,b] -sym_block2 = nn.SymbolBlock(sym, inputs) -sym_block2.initialize() -sym_block2.optimize_for(mx.nd.ones((3,2)), mx.nd.ones((3,2)), backend='myProp') -sym_block2.export('partitioned') - - -############################################### -# Test with subgraph directly consuming params -############################################### # example model, ops to be partitioned have args d2 = mx.sym.exp(a) sym2 = mx.sym.log(d2) -#execute in MXNet -print('-------------------------------') -print('Testing regular MXNet execution') -exe5 = sym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))}) -out5 = exe5.forward() -print(out5) - -# with propogating shapes/types -print('-------------------------------') -print('Testing partitioning with shapes/types') -arg_array = [mx.nd.ones((3,2),dtype='float32')] -mysym6 = sym2.optimize_for("myProp", arg_array, reqArgs=True) -print(mysym6.tojson()) -exe6 = mysym6.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))}) -out6 = exe6.forward() -print(out6) - -# without propogating shapes/types -print('-------------------------------') -print('Testing partitioning without shapes/types') -mysym7 = sym2.optimize_for("myProp", reqArgs=True) -exe7 = mysym7.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))}) -out7 = exe7.forward() -print(out7) - -# Gluon Hybridize partitioning with shapes/types -print('-------------------------------') -print('Testing Gluon Hybridize partitioning with shapes/types') -inputs = [a] -sym2_block = nn.SymbolBlock(sym2, inputs) -sym2_block.initialize() -sym2_block.hybridize(backend='myProp') -out8 = sym2_block(mx.nd.ones((3,2))) -print(out8) +def test(backend): + ############################################### + # Test with subgraph not consuming params + ############################################### + #execute in MXNet + print('-------------------------------') + print('Testing regular MXNet execution') + exe = sym.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) + out = exe.forward() + print(out) + + # with propogating shapes/types + print('-------------------------------') + print('Testing %s partitioning with shapes/types' % backend) + arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')] + mysym2 = sym.optimize_for(backend,arg_array) + print(mysym2.tojson()) + exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) + out2 = exe2.forward() + print(out2) + + # with propogating shapes/types, rejecting subgraph + print('-------------------------------') + print('Testing %s partitioning with shapes/types - rejecting subgraph' % backend) + arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')] + mysym2 = sym.optimize_for(backend, arg_array, reject=True) + exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) + out2 = exe2.forward() + print(out2) + + # without propogating shapes/types + print('-------------------------------') + print('Testing %s partitioning without shapes/types' % backend) + mysym3 = sym.optimize_for(backend, myOpt='yello') + exe3 = mysym3.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) + out3 = exe3.forward() + print(out3) + + # Gluon Hybridize partitioning with shapes/types + print('-------------------------------') + print('Testing %s Gluon Hybridize partitioning with shapes/types' % backend) + inputs = [a,b] + sym_block = nn.SymbolBlock(sym, inputs) + sym_block.initialize() + sym_block.hybridize(backend=backend) + out4 = sym_block(mx.nd.ones((3,2)),mx.nd.ones((3,2))) + print(out4) + + # Gluon Hybridize partitioning with shapes/types without inference + print('-------------------------------') + print('Testing %s Gluon Hybridize partitioning with shapes/types without inference' % backend) + inputs = [a,b] + sym_block2 = nn.SymbolBlock(sym, inputs) + sym_block2.initialize() + sym_block2.optimize_for(mx.nd.ones((3,2)), mx.nd.ones((3,2)), backend=backend) + sym_block2.export('partitioned') + + ############################################### + # Test with subgraph directly consuming params + ############################################### + #execute in MXNet + print('-------------------------------') + print('Testing regular MXNet execution') + exe5 = sym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))}) + out5 = exe5.forward() + print(out5) + + # with propogating shapes/types + print('-------------------------------') + print('Testing %s partitioning with shapes/types' % backend) + arg_array = [mx.nd.ones((3,2),dtype='float32')] + mysym6 = sym2.optimize_for(backend, arg_array, reqArgs=True) + print(mysym6.tojson()) + exe6 = mysym6.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))}) + out6 = exe6.forward() + print(out6) + + # without propogating shapes/types + print('-------------------------------') + print('Testing %s partitioning without shapes/types' % backend) + mysym7 = sym2.optimize_for(backend, reqArgs=True) + exe7 = mysym7.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))}) + out7 = exe7.forward() + print(out7) + + # Gluon Hybridize partitioning with shapes/types + print('-------------------------------') + print('Testing %s Gluon Hybridize partitioning with shapes/types' % backend) + inputs = [a] + sym2_block = nn.SymbolBlock(sym2, inputs) + sym2_block.initialize() + sym2_block.hybridize(backend=backend) + out8 = sym2_block(mx.nd.ones((3,2))) + print(out8) + +test("myProp") +test("mySelect") diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index e3d9062cec79..cfb2400c9290 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -234,9 +234,10 @@ MXNET_DLL const char *MXGetLastError(); /*! * \brief Load library dynamically * \param path to the library .so file + * \param 0 for quiet, 1 for verbose * \return 0 when success, -1 when failure happens. */ -MXNET_DLL int MXLoadLib(const char *path); +MXNET_DLL int MXLoadLib(const char *path, unsigned verbose); /*! * \brief Get list of features supported on the runtime @@ -2176,7 +2177,13 @@ MXNET_DLL int MXOptimizeForBackend(SymbolHandle sym_handle, NDArrayHandle* in_aux_handle, const mx_uint num_options, const char** keys, - const char** vals); + const char** vals, + int* new_args_cnt, + NDArrayHandle** new_args_handle, + char*** new_arg_names_handle, + int* new_aux_cnt, + NDArrayHandle** new_aux_handle, + char*** new_aux_names_handle); //-------------------------------------------- diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h index c793a30c96d9..c8ba712a9ec4 100644 --- a/include/mxnet/lib_api.h +++ b/include/mxnet/lib_api.h @@ -22,7 +22,11 @@ * \file lib_api.h * \brief APIs to interact with libraries * This API specifies function prototypes to - * register custom ops for library authors + * register custom ops, partitioner, and passes + * for library authors + * See example/extension/lib_custom_op/README.md + * See example/extension/lib_subgraph/README.md + * See example/extension/lib_pass/README.md */ #ifndef MXNET_LIB_API_H_ @@ -45,7 +49,7 @@ #endif /* Make sure to update the version number everytime you make changes */ -#define MX_LIBRARY_VERSION 6 +#define MX_LIBRARY_VERSION 7 /*! * \brief For loading multiple custom op libraries in Linux, exporting same symbol multiple @@ -237,10 +241,20 @@ enum MXStorageType { * dev_type is string repr of supported context, currently only "cpu" and "gpu" * dev_id is the device index where the tensor locates */ -typedef struct { +struct MXContext { + MXContext() : dev_type("error"), dev_id(-1) {} + explicit MXContext(std::string dev_type_, int dev_id_) + : dev_type(dev_type_), dev_id(dev_id_) {} + explicit MXContext(const char* dev_type_, int dev_id_) + : dev_type(dev_type_), dev_id(dev_id_) {} + static MXContext CPU() { return MXContext("cpu", 0); } + static MXContext GPU() { return MXContext("gpu", 0); } + static MXContext CPU(int dev_id) { return MXContext("cpu", dev_id); } + static MXContext GPU(int dev_id) { return MXContext("gpu", dev_id); } + std::string dev_type; int dev_id; -} MXContext; +}; enum MXReturnValue { MX_FAIL = 0, @@ -382,7 +396,7 @@ struct MXTensor { } /*! \brief helper function to get data size */ - inline int64_t size() { + inline int64_t size() const { int64_t size = 1; for (unsigned int i = 0; i < shape.size(); i++) { size *= shape[i]; @@ -427,9 +441,13 @@ struct MXTensor { /*! \brief resource malloc function to allocate memory inside Forward/Backward functions */ typedef void* (*xpu_malloc_t)(void*, int); - +/*! \brief sparse alloc function to allocate memory inside Forward/Backward functions */ typedef void (*sparse_malloc_t)(void*, int, int, int, void**, int64_t**, int64_t**); - +/*! \brief resource malloc function to allocate ndarrays for graph passes */ +typedef void (*nd_malloc_t)(const void* _ndarray_alloc, const int64_t* shapes, int num_shapes, + const char* dev_str, int dev_id, int dtype, const char* name, + int isArg, void** data); +/*! \brief GPU stream pointer, is void* when not compiled with CUDA */ #if defined(__NVCC__) typedef cudaStream_t mx_stream_t; typedef curandStatePhilox4_32_10_t mx_gpu_rand_t; @@ -444,6 +462,38 @@ typedef std::mt19937 mx_cpu_rand_t; #define MX_NUM_CPU_RANDOM_STATES 1024 #define MX_NUM_GPU_RANDOM_STATES 32768 +class PassResource { + public: + PassResource(std::unordered_map* new_args, + std::unordered_map* new_aux, + nd_malloc_t nd_malloc, const void* nd_alloc) + : new_args_(new_args), new_aux_(new_aux), nd_malloc_(nd_malloc), nd_alloc_(nd_alloc) {} + MXTensor* alloc_arg(const std::string& name, const std::vector& shapes, + const MXContext &ctx, MXDType dtype) const { + void* data; + nd_malloc_(nd_alloc_, shapes.data(), shapes.size(), ctx.dev_type.c_str(), ctx.dev_id, + dtype, name.c_str(), 1, &data); + MXTensor tensor(data, shapes, dtype, 0, ctx, kDefaultStorage); + (*new_args_)[name] = tensor; + return &(new_args_->at(name)); + } + MXTensor* alloc_aux(const std::string& name, const std::vector& shapes, + const MXContext &ctx, MXDType dtype) const { + void* data; + nd_malloc_(nd_alloc_, shapes.data(), shapes.size(), ctx.dev_type.c_str(), ctx.dev_id, + dtype, name.c_str(), 0, &data); + MXTensor tensor(data, shapes, dtype, 0, ctx, kDefaultStorage); + (*new_aux_)[name] = tensor; + return &(new_aux_->at(name)); + } + + private: + std::unordered_map* new_args_; + std::unordered_map* new_aux_; + nd_malloc_t nd_malloc_; + const void* nd_alloc_; +}; + /*! * \brief provide resource APIs memory allocation mechanism to Forward/Backward functions */ @@ -459,36 +509,36 @@ class OpResource { rand_cpu_states(rng_cpu_states), rand_gpu_states(rng_gpu_states) {} /*! \brief allocate cpu memory controlled by MXNet */ - void* alloc_cpu(int size) { + void* alloc_cpu(int size) const { return cpu_malloc(cpu_alloc, size); } /*! \brief allocate gpu memory controlled by MXNet */ - void* alloc_gpu(int size) { + void* alloc_gpu(int size) const { return gpu_malloc(gpu_alloc, size); } /*! \brief return the cuda stream object with correct type */ - mx_stream_t get_cuda_stream() { + mx_stream_t get_cuda_stream() const { return static_cast(cuda_stream); } /*! \brief allocate sparse memory controlled by MXNet */ - void alloc_sparse(MXSparse* sparse, int index, int indices_len, int indptr_len = 0) { + void alloc_sparse(MXSparse* sparse, int index, int indices_len, int indptr_len = 0) const { sparse_malloc(sparse_alloc, index, indices_len, indptr_len, &(sparse->data), &(sparse->indices), &(sparse->indptr)); } /*! \brief get pointer to initialized and seeded random number states located on CPU */ /* Access each state by states[id], but this id should be <= MX_NUM_CPU_RANDOM_STATES */ - mx_cpu_rand_t* get_cpu_rand_states() { + mx_cpu_rand_t* get_cpu_rand_states() const { return static_cast(rand_cpu_states); } /*! \brief get pointer to initialized and seeded random number states located on GPU */ /* Access each state by states[id], but this id should be <= MX_NUM_GPU_RANDOM_STATES */ /* Note that if you are using cpu build, it will return a nullptr */ - mx_gpu_rand_t* get_gpu_rand_states() { + mx_gpu_rand_t* get_gpu_rand_states() const { return static_cast(rand_gpu_states); } @@ -507,16 +557,17 @@ class OpResource { void *rand_cpu_states, *rand_gpu_states; }; -/*! - * \brief Json utility to parse serialized subgraph symbol - */ /*! \brief Macro to help passing serialized subgraph through attribute dict */ #define MX_STR_SUBGRAPH_SYM_JSON "subgraph_sym_json" -#define MX_STR_DTYPE "__dtype__" -#define MX_STR_SHAPE "__shape__" +#define MX_STR_DTYPE "__ext_dtype__" +#define MX_STR_SHAPE "__ext_shape__" /* \brief get shape value from list of shapes string - * format: [[1]] or [[1],[2]] + * + * Examples: + * + * getShapeAt("[[1]]", 0) returns "[1]" + * getShapeAt("[[1],[2,3]]", 1) returns "[2,3]" */ std::string getShapeAt(const std::string& shape, unsigned index) { int idx = 1; // start at 1 to skip the first square bracket [ @@ -529,7 +580,11 @@ std::string getShapeAt(const std::string& shape, unsigned index) { } /* \brief get dtype value from list of dtypes string - * format: [1] or [1,2] + * + * Examples: + * + * getDtypeAt("[1]", 0) returns "1" + * getDtypeAt("[1,2]", 1) returns "2" */ std::string getDtypeAt(const std::string& dtype, unsigned index) { // find the beginning of the output dtype for the particular output index @@ -541,6 +596,9 @@ std::string getDtypeAt(const std::string& dtype, unsigned index) { return dtype.substr(idx+1, stop-idx-1); } +/*! + * \brief Json utility to parse serialized subgraph symbol + */ /*! \brief Types of JSON objects */ enum JsonType {ERR, STR, NUM, LIST, MAP}; @@ -589,14 +647,14 @@ struct JsonVal { /*! \brief functions used for parsing JSON */ struct JsonParser { - JsonVal parse_to_json(std::string json) { + JsonVal parse_to_json(const std::string& json) { unsigned int idx = 0; return parse(json, &idx); } - void print_json_val(JsonVal val) { + void print_json_val(const JsonVal& val) { std::cout << json_val_string(val) << std::endl; } - // debug function to convert a JSON object to a string + // debug function to dump data structure to string std::string json_val_string(const JsonVal &val) { std::string ret; switch (val.type) { @@ -625,7 +683,7 @@ struct JsonParser { return ret; } // parse a string JSON object - JsonVal parse_string(std::string json, unsigned int* idx) { + JsonVal parse_string(const std::string& json, unsigned int* idx) { JsonVal ret(STR); while (*idx < json.size()) { if (json[*idx] == '"') { @@ -640,7 +698,7 @@ struct JsonParser { return JsonVal(); } // parse a number JSON object - JsonVal parse_num(std::string json, unsigned int* idx) { + JsonVal parse_num(const std::string& json, unsigned int* idx) { JsonVal ret(NUM); while (*idx < json.size()) { if (json[*idx] >= '0' && json[*idx] <= '9') { @@ -654,7 +712,7 @@ struct JsonParser { return ret; } // parse a list of JSON objects - JsonVal parse_list(std::string json, unsigned int* idx) { + JsonVal parse_list(const std::string& json, unsigned int* idx) { JsonVal ret(LIST); while (*idx < json.size()) { if (json[*idx] == ']') { @@ -670,7 +728,7 @@ struct JsonParser { return JsonVal(); } // parse a map of JSON objects - JsonVal parse_map(std::string json, unsigned int* idx) { + JsonVal parse_map(const std::string& json, unsigned int* idx) { JsonVal ret(MAP), key; while (*idx < json.size()) { if (json[*idx] == '}') { @@ -690,7 +748,7 @@ struct JsonParser { return JsonVal(); } // generic parse function - JsonVal parse(std::string json, unsigned int *idx) { + JsonVal parse(const std::string& json, unsigned int *idx) { JsonVal ret; while (*idx < json.size()) { if (json[*idx] == '"') { @@ -710,21 +768,93 @@ struct JsonParser { } return ret; } + // convert JSON object back to JSON-compatible string + std::string dump(const JsonVal &val) { + std::string ret; + switch (val.type) { + case ERR: + ret = "json(Error)"; + break; + case STR: + ret = "\"" + val.str + "\""; + break; + case NUM: + ret = val.str; + break; + case LIST: + ret = "["; + for (unsigned i=0; i < val.list.size(); i++) { + auto &item = val.list[i]; + ret += dump(item); + if (i < val.list.size()-1) + ret += ","; + } + ret += "]"; + break; + case MAP: + ret = "{"; + unsigned cnt = 0; + for (auto &item : val.map) { + ret += dump(item.first) + " : " + dump(item.second); + if (cnt++ < val.map.size()-1) + ret += ","; + } + ret += "}"; + break; + } + return ret; + } +}; + +/* \brief An abstract class for library authors creating custom + * partitioners. Optional, can just implement supportedOps instead + */ +class CustomOpSelector { + public: + /* \brief Select a node to include in subgraph, return true to include node + * nodeID - index of node in graph + */ + virtual bool Select(int nodeID) = 0; + /* \brief Select an input node from current node to include in subgraph + * return true to include node + * nodeID - index of node in graph + * input_nodeID - index of input node in graph + */ + virtual bool SelectInput(int nodeID, int input_nodeID) = 0; + /* \brief Select an output node from current node to include in subgraph + * return true to include node + * nodeID - index of node in graph + * output_nodeID - index of output node in graph + */ + virtual bool SelectOutput(int nodeID, int output_nodeID) = 0; + /* \brief Review nodes to include in subgraph + * return set of candidate nodes to keep in subgraph + * candidates - indices of nodes to include in subgraph + * keep - indices of nodes to keep in subgraph + */ + virtual void Filter(const std::vector& candidates, + std::vector* keep) { + keep->insert(keep->end(), candidates.begin(), candidates.end()); + } + /* \brief Reset any selector state, called after growing subgraph, before filter + * Called after finished calling SelectInput/SelectOutput and growing subgraph + */ + virtual void Reset() {} }; /*! - * \brief An abstract class for library author creating stateful op + * \brief An abstract class for library authors creating stateful op * custom library should override Forward and destructor, and has an * option to implement Backward */ class CustomStatefulOp { public: - virtual MXReturnValue Forward(std::vector inputs, - std::vector outputs, - OpResource op_res) = 0; - virtual MXReturnValue Backward(std::vector inputs, - std::vector outputs, - OpResource op_res) { + virtual MXReturnValue Forward(std::vector* inputs, + std::vector* outputs, + const OpResource& op_res) = 0; + virtual MXReturnValue Backward(std::vector* inputs, + std::vector* outputs, + const OpResource& op_res) { std::cout << "Error! Operator does not support backward" << std::endl; return MX_FAIL; } @@ -740,21 +870,31 @@ class CustomStatefulOpWrapper { }; /*! \brief Custom Operator function templates */ -typedef MXReturnValue (*fcomp_t)(std::map, - std::vector, std::vector, - OpResource res); -typedef MXReturnValue (*parseAttrs_t)(std::map, - int*, int*); -typedef MXReturnValue (*inferType_t)(std::map, - std::vector&, std::vector&); -typedef MXReturnValue (*inferSType_t)(std::map, - std::vector&, std::vector&); -typedef MXReturnValue (*inferShape_t)(std::map, - std::vector >&, - std::vector >&); -typedef MXReturnValue (*mutateInputs_t)(std::map, - std::vector&); -typedef MXReturnValue (*createOpState_t)(std::map, +typedef MXReturnValue (*fcomp_t)(const std::unordered_map& attributes, + std::vector* inputs, + std::vector* outputs, + const OpResource& res); +typedef MXReturnValue (*parseAttrs_t)(const std::unordered_map& attributes, + int* num_inputs, int* num_outputs); +typedef MXReturnValue (*inferType_t)(const std::unordered_map& attributes, + std::vector* in_types, + std::vector* out_types); +typedef MXReturnValue (*inferSType_t)(const std::unordered_map& attributes, + std::vector* in_storage_types, + std::vector* out_storage_types); +typedef MXReturnValue (*inferShape_t)(const std::unordered_map& attributes, + std::vector >* in_shapes, + std::vector >* out_shapes); +typedef MXReturnValue (*mutateInputs_t)(const std::unordered_map& attributes, + std::vector* input_indices); +typedef MXReturnValue (*createOpState_t)(const std::unordered_map& attributes, CustomStatefulOp**); /*! @@ -852,14 +992,45 @@ class CustomOp { std::unordered_map create_op_ctx_map; }; +/*! \brief Custom Pass Create function template */ +typedef MXReturnValue (*graphPass_t)(const std::string& in_graph, const std::string** out_graph, + const std::unordered_map& options, + const std::unordered_map& args, + const std::unordered_map& aux, + const PassResource& res); + +/*! + * \brief An abstract class for graph passes + */ +class CustomPass { + public: + CustomPass() : name("ERROR") {} + explicit CustomPass(const char* pass_name) + : name(pass_name) {} + CustomPass& setBody(graphPass_t fn) { + pass = fn; + return *this; + } + + /*! \brief pass name */ + const char* name; + /*! \brief pass function */ + graphPass_t pass; +}; + /*! \brief Custom Subgraph Create function template */ -typedef MXReturnValue (*supportedOps_t)(std::string, std::vector&, - std::unordered_map&); -typedef MXReturnValue (*reviewSubgraph_t)(std::string, int, bool*, - std::unordered_map&, - std::unordered_map&, - std::map&, - std::map&); +typedef MXReturnValue (*supportedOps_t)(const std::string& json, std::vector* ids, + const std::unordered_map& options); +typedef MXReturnValue (*createSelector_t)(const std::string& json, CustomOpSelector** sel_inst, + const std::unordered_map& options); +typedef MXReturnValue (*reviewSubgraph_t)(const std::string& json, int subgraph_id, bool* accept, + const std::unordered_map& options, + std::unordered_map* attrs, + const std::unordered_map& args, + const std::unordered_map& aux); /*! * \brief An abstract class for subgraph property @@ -870,32 +1041,52 @@ class CustomPartitioner { explicit CustomPartitioner(const char* backend_name) : name(backend_name) {} CustomPartitioner& addStrategy(const char* prop_name, - supportedOps_t fn, const char* sg_name) { strategies.push_back(prop_name); - supportedOps.push_back(fn); op_names.push_back(sg_name); return *this; } + CustomPartitioner& setSupportedOps(const char* prop_name, supportedOps_t fn) { + supported_map[std::string(prop_name)] = fn; + return *this; + } + CustomPartitioner& setCreateSelector(const char* prop_name, createSelector_t fn) { + selector_map[std::string(prop_name)] = fn; + return *this; + } CustomPartitioner& setReviewSubgraph(const char* prop_name, reviewSubgraph_t fn) { review_map[std::string(prop_name)] = fn; return *this; } + supportedOps_t getSupportedOps(int stg_id) { + std::string prop(strategies[stg_id]); + if (supported_map.count(prop) > 0) + return supported_map[prop]; + else + return nullptr; + } + createSelector_t getCreateSelector(int stg_id) { + std::string prop(strategies[stg_id]); + if (selector_map.count(prop) > 0) + return selector_map[prop]; + else + return nullptr; + } reviewSubgraph_t getReviewSubgraph(int stg_id) { std::string prop(strategies[stg_id]); - if (review_map.find(prop) != review_map.end()) + if (review_map.count(prop) > 0) return review_map[prop]; else return nullptr; } - /*! \brief partitioner name */ + /*! \brief partitioner name */ const char* name; + std::map supported_map; + std::map selector_map; std::map review_map; /*! \brief strategy names */ std::vector strategies; - /*! \brief supported ops function */ - std::vector supportedOps; /*! \brief subgraph operator name */ std::vector op_names; }; @@ -959,6 +1150,9 @@ class Registry { #define MX_REGISTER_PROP_NAME_(Name) MXNet ## _CustomSubProp ## _ #define MX_REGISTER_PROP_DEF_(Name) CustomPartitioner MX_REGISTER_PROP_NAME_(Name) +#define MX_REGISTER_PASS_NAME_(Name) MXNet ## _CustomPass ## _ +#define MX_REGISTER_PASS_DEF_(Name) CustomPass MX_REGISTER_PASS_NAME_(Name) + /*! \brief assign a var to a value */ #define REGISTER_OP(Name) MX_STR_CONCAT(MX_REGISTER_DEF_(Name), __COUNTER__) = \ Registry::get()->add(MX_TOSTRING(Name)) @@ -967,6 +1161,10 @@ class Registry { MX_STR_CONCAT(MX_REGISTER_PROP_DEF_(Name), __COUNTER__) = \ Registry::get()->add(MX_TOSTRING(Name)) +#define REGISTER_PASS(Name) \ + MX_STR_CONCAT(MX_REGISTER_PASS_DEF_(Name), __COUNTER__) = \ + Registry::get()->add(MX_TOSTRING(Name)) + /* -------------- BELOW ARE CTYPE FUNCTIONS PROTOTYPES --------------- */ /*! @@ -998,6 +1196,7 @@ typedef int (*opCallParseAttrs_t)(parseAttrs_t parseAttrs, const char* const* ke typedef int (*opCallInferShape_t)(inferShape_t inferShape, const char* const* keys, const char* const* vals, int num, unsigned int** inshapes, int* indims, int num_in, + unsigned int*** mod_inshapes, int** mod_indims, unsigned int*** outshapes, int** outdims, int num_out); #define MXLIB_OPCALLINFERTYPE_STR "_opCallInferType" @@ -1069,14 +1268,37 @@ typedef int (*partRegGetCount_t)(int idx, const char** name); #define MXLIB_PARTREGGET_STR "_partRegGet" typedef void (*partRegGet_t)(int part_idx, int stg_idx, const char** strategy, - supportedOps_t* supportedOps, reviewSubgraph_t* reviewSubgraph, - const char** op_name); + supportedOps_t* supportedOps, createSelector_t* createSelector, + reviewSubgraph_t* reviewSubgraph, const char** op_name); #define MXLIB_PARTCALLSUPPORTEDOPS_STR "_partCallSupportedOps" typedef int (*partCallSupportedOps_t)(supportedOps_t supportedOps, const char *json, int num_ids, int *ids, const char* const* opt_keys, const char* const* opt_vals, int num_opts); +#define MXLIB_PARTCALLCREATESELECTOR_STR "_partCallCreateSelector" +typedef int (*partCallCreateSelector_t)(createSelector_t createSelector, const char *json, + void** selector, const char* const* opt_keys, + const char* const* opt_vals, int num_opts); + +#define MXLIB_PARTCALLSELECT_STR "_partCallSelect" +typedef void (*partCallSelect_t)(void* sel_inst, int nodeID, int* selected); + +#define MXLIB_PARTCALLSELECTINPUT_STR "_partCallSelectInput" +typedef void (*partCallSelectInput_t)(void* sel_inst, int nodeID, int input_nodeID, + int* selected); + +#define MXLIB_PARTCALLSELECTOUTPUT_STR "_partCallSelectOutput" +typedef void (*partCallSelectOutput_t)(void* sel_inst, int nodeID, int output_nodeID, + int* selected); + +#define MXLIB_PARTCALLFILTER_STR "_partCallFilter" +typedef void (*partCallFilter_t)(void* sel_inst, int* candidates, int num_candidates, + int** keep, int* num_keep); + +#define MXLIB_PARTCALLRESET_STR "_partCallReset" +typedef void (*partCallReset_t)(void* sel_inst); + #define MXLIB_PARTCALLREVIEWSUBGRAPH_STR "_partCallReviewSubgraph" typedef int (*partCallReviewSubgraph_t)(reviewSubgraph_t reviewSubgraph, const char *json, int subgraph_id, int *accept, const char* const* opt_keys, @@ -1093,45 +1315,61 @@ typedef int (*partCallReviewSubgraph_t)(reviewSubgraph_t reviewSubgraph, const c const size_t* aux_IDs, const char* const* aux_dev_type, const int* aux_dev_id); +#define MXLIB_PASSREGSIZE_STR "_passRegSize" +typedef int (*passRegSize_t)(void); + +#define MXLIB_PASSREGGET_STR "_passRegGet" +typedef void (*passRegGet_t)(int pass_idx, graphPass_t* graphPass, const char** pass_name); + +#define MXLIB_PASSCALLGRAPHPASS_STR "_passCallGraphPass" +typedef int (*passCallGraphPass_t)(graphPass_t graphPass, const char *in_graph, + char** out_graph, const char* const* opt_keys, + const char* const* opt_vals, int num_opts, + const char* pass_name, const char* const* arg_names, + int num_args, void* const* arg_data, + const int64_t* const* arg_shapes, const int* arg_dims, + const int* arg_types, const size_t* arg_IDs, + const char* const* arg_dev_type, const int* arg_dev_id, + const char* const* aux_names, int num_aux, + void* const* aux_data, const int64_t* const* aux_shapes, + const int* aux_dims, const int* aux_types, + const size_t* aux_IDs, const char* const* aux_dev_type, + const int* aux_dev_id, nd_malloc_t nd_malloc, + const void* nd_alloc); + #define MXLIB_INITIALIZE_STR "initialize" typedef int (*initialize_t)(int version); #define MXLIB_OPVERSION_STR "_opVersion" typedef int (*opVersion_t)(); -extern "C" { - /*! \brief returns MXNet library version */ #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - __declspec(dllexport) int __cdecl +#define MX_INT_RET __declspec(dllexport) int __cdecl +#define MX_VOID_RET __declspec(dllexport) void __cdecl #else - int +#define MX_INT_RET int +#define MX_VOID_RET void #endif - _opVersion() { + +extern "C" { + /*! \brief returns MXNet library version */ + MX_INT_RET _opVersion() { return MX_LIBRARY_VERSION; } /*! \brief returns number of ops registered in this library */ -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - __declspec(dllexport) int __cdecl -#else - int -#endif - _opRegSize() { + MX_INT_RET _opRegSize() { return Registry::get()->size(); } /*! \brief returns operator registration at specified index */ -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - __declspec(dllexport) void __cdecl -#else - void -#endif - _opRegGet(int idx, const char** name, int *isSGop, - const char*** forward_ctx, fcomp_t** forward_fp, int* forward_count, - const char*** backward_ctx, fcomp_t** backward_fp, int* backward_count, - const char*** create_op_ctx, createOpState_t** create_op_fp, int* create_op_count, - parseAttrs_t* parse, inferType_t* type, inferSType_t* stype, - inferShape_t* shape, mutateInputs_t* mutate) { + MX_VOID_RET _opRegGet(int idx, const char** name, int *isSGop, + const char*** forward_ctx, fcomp_t** forward_fp, + int* forward_count, const char*** backward_ctx, + fcomp_t** backward_fp, int* backward_count, + const char*** create_op_ctx, createOpState_t** create_op_fp, + int* create_op_count, parseAttrs_t* parse, inferType_t* type, + inferSType_t* stype, inferShape_t* shape, mutateInputs_t* mutate) { CustomOp &op = Registry::get()->get(idx); *name = op.name; *parse = op.parse_attrs; @@ -1153,26 +1391,16 @@ extern "C" { } /*! \brief calls free from the external library for library allocated arrays */ -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - __declspec(dllexport) void __cdecl -#else - void -#endif - _opCallFree(void* ptr) { + MX_VOID_RET _opCallFree(void* ptr) { free(ptr); } /*! \brief returns status of calling parse attributes function for operator from library */ -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - __declspec(dllexport) int __cdecl -#else - int -#endif - _opCallParseAttrs(parseAttrs_t parseAttrs, const char* const* keys, - const char* const* vals, int num, - int* num_in, int* num_out) { + MX_INT_RET _opCallParseAttrs(parseAttrs_t parseAttrs, const char* const* keys, + const char* const* vals, int num, + int* num_in, int* num_out) { // create map of attributes from list - std::map attrs; + std::unordered_map attrs; for (int i = 0; i < num; i++) { attrs[std::string(keys[i])] = std::string(vals[i]); } @@ -1181,17 +1409,13 @@ extern "C" { } /*! \brief returns status of calling inferShape function for operator from library */ -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - __declspec(dllexport) int __cdecl -#else - int -#endif - _opCallInferShape(inferShape_t inferShape, const char* const* keys, - const char* const* vals, int num, - unsigned int** inshapes, int* indims, int num_in, - unsigned int*** outshapes, int** outdims, int num_out) { + MX_INT_RET _opCallInferShape(inferShape_t inferShape, const char* const* keys, + const char* const* vals, int num, + unsigned int** inshapes, int* indims, int num_in, + unsigned int*** mod_inshapes, int** mod_indims, + unsigned int*** outshapes, int** outdims, int num_out) { // create map of attributes from list - std::map attrs; + std::unordered_map attrs; for (int i = 0; i < num; i++) { attrs[std::string(keys[i])] = std::string(vals[i]); } @@ -1207,9 +1431,21 @@ extern "C" { // create a vector of shapes for outputs std::vector > out_shapes(num_out); - int retval = inferShape(attrs, in_shapes, out_shapes); - if (!retval) - return retval; + int retval = inferShape(attrs, &in_shapes, &out_shapes); + if (!retval) return retval; + + // allocate space for modified input dims, shape + *mod_indims = static_cast(malloc (num_in * sizeof(int))); + *mod_inshapes = static_cast(malloc (num_in * sizeof(unsigned*))); + + // copy modified input shapes + for (int i = 0; i < num_in; i++) { + (*mod_indims)[i] = in_shapes[i].size(); + (*mod_inshapes)[i] = static_cast(malloc ((*mod_indims)[i] * sizeof(unsigned))); + for (int j = 0; j < (*mod_indims)[i]; j++) { + (*mod_inshapes)[i][j] = in_shapes[i][j]; + } + } // allocate space for output dims, shape *outdims = static_cast(malloc (num_out * sizeof(int))); @@ -1219,7 +1455,7 @@ extern "C" { for (int i = 0; i < num_out; i++) { (*outdims)[i] = out_shapes[i].size(); (*outshapes)[i] = static_cast(malloc ((*outdims)[i] * sizeof(unsigned))); - for (int j = 0; j < indims[i]; j++) { + for (int j = 0; j < (*outdims)[i]; j++) { (*outshapes)[i][j] = out_shapes[i][j]; } } @@ -1228,16 +1464,11 @@ extern "C" { } /*! \brief returns status of calling inferType function for operator from library */ -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - __declspec(dllexport) int __cdecl -#else - int -#endif - _opCallInferType(inferType_t inferType, const char* const* keys, - const char* const* vals, int num, - int* intypes, int num_in, int* outtypes, int num_out) { + MX_INT_RET _opCallInferType(inferType_t inferType, const char* const* keys, + const char* const* vals, int num, + int* intypes, int num_in, int* outtypes, int num_out) { // create map of attributes from list - std::map attrs; + std::unordered_map attrs; for (int i = 0; i < num; i++) { attrs[std::string(keys[i])] = std::string(vals[i]); } @@ -1251,10 +1482,14 @@ extern "C" { // create a vector of types for outputs std::vector out_types(num_out, -1); - int retval = inferType(attrs, in_types, out_types); + int retval = inferType(attrs, &in_types, &out_types); if (!retval) return retval; + // copy modified input types + for (int i = 0; i < num_in; i++) { + intypes[i] = in_types[i]; + } // copy output types for (int i = 0; i < num_out; i++) { outtypes[i] = out_types[i]; @@ -1264,16 +1499,11 @@ extern "C" { } /*! \brief returns status of calling inferSType function for operator from library */ -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - __declspec(dllexport) int __cdecl -#else - int -#endif - _opCallInferSType(inferSType_t inferSType, const char* const* keys, - const char* const* vals, int num, - int* instypes, int num_in, int* outstypes, int num_out) { + MX_INT_RET _opCallInferSType(inferSType_t inferSType, const char* const* keys, + const char* const* vals, int num, + int* instypes, int num_in, int* outstypes, int num_out) { // create map of attributes from list - std::map attrs; + std::unordered_map attrs; for (int i = 0; i < num; i++) { attrs[std::string(keys[i])] = std::string(vals[i]); } @@ -1287,11 +1517,15 @@ extern "C" { // create a vector of types for outputs std::vector out_stypes(num_out, -1); - int retval = inferSType(attrs, in_stypes, out_stypes); + int retval = inferSType(attrs, &in_stypes, &out_stypes); if (!retval) return retval; + // copy modified input storage types + for (int i = 0; i < num_in; i++) { + instypes[i] = in_stypes[i]; + } // copy output storage types for (int i = 0; i < num_out; i++) { outstypes[i] = out_stypes[i]; @@ -1301,26 +1535,21 @@ extern "C" { } /*! \brief returns status of calling Forward/Backward function for operator from library */ -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - __declspec(dllexport) int __cdecl -#else - int -#endif - _opCallFCompute(fcomp_t fcomp, const char* const* keys, const char* const* vals, int num, - const int64_t** inshapes, int* indims, void** indata, int* intypes, - size_t* inIDs, const char** indev_type, int* indev_id, int num_in, - const int64_t** outshapes, int* outdims, void** outdata, int* outtypes, - size_t* outIDs, const char** outdev_type, int* outdev_id, int num_out, - xpu_malloc_t cpu_malloc, void* cpu_alloc, - xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream, - sparse_malloc_t sparse_malloc, void* sparse_alloc, - int* instypes, int* outstypes, void** in_indices, void** out_indices, - void** in_indptr, void** out_indptr, - int64_t* in_indices_shapes, int64_t* out_indices_shapes, - int64_t* in_indptr_shapes, int64_t* out_indptr_shapes, - void* rng_cpu_states, void* rng_gpu_states) { + MX_INT_RET _opCallFCompute(fcomp_t fcomp, const char* const* keys, const char* const* vals, + int num, const int64_t** inshapes, int* indims, void** indata, + int* intypes, size_t* inIDs, const char** indev_type, int* indev_id, + int num_in, const int64_t** outshapes, int* outdims, void** outdata, + int* outtypes, size_t* outIDs, const char** outdev_type, + int* outdev_id, int num_out, xpu_malloc_t cpu_malloc, void* cpu_alloc, + xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream, + sparse_malloc_t sparse_malloc, void* sparse_alloc, + int* instypes, int* outstypes, void** in_indices, void** out_indices, + void** in_indptr, void** out_indptr, + int64_t* in_indices_shapes, int64_t* out_indices_shapes, + int64_t* in_indptr_shapes, int64_t* out_indptr_shapes, + void* rng_cpu_states, void* rng_gpu_states) { // create map of attributes from list - std::map attrs; + std::unordered_map attrs; for (int i = 0; i < num; i++) { attrs[std::string(keys[i])] = std::string(vals[i]); } @@ -1334,7 +1563,7 @@ extern "C" { // Dense representation. if (instypes[i] == 0) { inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i], - inIDs[i], {indev_type[i], indev_id[i]}, kDefaultStorage); + inIDs[i], MXContext(indev_type[i], indev_id[i]), kDefaultStorage); } else { // Sparse representation. MXStorageType type; @@ -1347,7 +1576,8 @@ extern "C" { in_indices_shapes[i], in_indptr[i], in_indptr_shapes[i]); } inputs[i].setTensor(reinterpret_cast(&in_sparse[i]), (MXDType)intypes[i], - inshapes[i], indims[i], inIDs[i], {indev_type[i], indev_id[i]}, type); + inshapes[i], indims[i], inIDs[i], + MXContext(indev_type[i], indev_id[i]), type); } } @@ -1359,7 +1589,7 @@ extern "C" { // Dense representation. if (outstypes[i] == 0) { outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], outdims[i], - outIDs[i], {outdev_type[i], outdev_id[i]}, kDefaultStorage); + outIDs[i], MXContext(outdev_type[i], outdev_id[i]), kDefaultStorage); } else { // Sparse representation. MXStorageType type; @@ -1373,27 +1603,22 @@ extern "C" { out_indices_shapes[i], out_indptr[i], out_indptr_shapes[i]); } outputs[i].setTensor(reinterpret_cast(&out_sparse[i]), (MXDType)outtypes[i], - outshapes[i], outdims[i], outIDs[i], {outdev_type[i], - outdev_id[i]}, type); + outshapes[i], outdims[i], outIDs[i], + MXContext(outdev_type[i], outdev_id[i]), type); } } OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc, cuda_stream, sparse_malloc, sparse_alloc, rng_cpu_states, rng_gpu_states); - return fcomp(attrs, inputs, outputs, res); + return fcomp(attrs, &inputs, &outputs, res); } /*! \brief returns status of calling mutateInputs function for operator from library */ -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - __declspec(dllexport) int __cdecl -#else - int -#endif - _opCallMutateInputs(mutateInputs_t mutate, const char* const* keys, - const char* const* vals, int num, - int** mutate_indices, int* indices_size) { + MX_INT_RET _opCallMutateInputs(mutateInputs_t mutate, const char* const* keys, + const char* const* vals, int num, + int** mutate_indices, int* indices_size) { // create map of attributes from list - std::map attrs; + std::unordered_map attrs; for (int i = 0; i < num; i++) { attrs[std::string(keys[i])] = std::string(vals[i]); } @@ -1401,7 +1626,7 @@ extern "C" { // create a vector of mutate input indices std::vector mut_ind; - int retval = mutate(attrs, mut_ind); + int retval = mutate(attrs, &mut_ind); if (!retval) return retval; @@ -1416,16 +1641,11 @@ extern "C" { } /*! \brief returns status of calling createStatefulOp function for operator from library */ -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - __declspec(dllexport) int __cdecl -#else - int -#endif - _opCallCreateOpState(createOpState_t create_op, const char* const* keys, - const char* const* vals, int num, - void** state_op) { + MX_INT_RET _opCallCreateOpState(createOpState_t create_op, const char* const* keys, + const char* const* vals, int num, + void** state_op) { // create map of attributes from list - std::map attrs; + std::unordered_map attrs; for (int i = 0; i < num; i++) { attrs[std::string(keys[i])] = std::string(vals[i]); } @@ -1437,24 +1657,20 @@ extern "C" { } /*! \brief returns status of calling Stateful Forward/Backward for operator from library */ -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - __declspec(dllexport) int __cdecl -#else - int -#endif - _opCallFStatefulCompute(int is_forward, void* state_op, - const int64_t** inshapes, int* indims, void** indata, int* intypes, - size_t* inIDs, const char** indev_type, int* indev_id, int num_in, - const int64_t** outshapes, int* outdims, void** outdata, int* outtypes, - size_t* outIDs, const char** outdev_type, int* outdev_id, int num_out, - xpu_malloc_t cpu_malloc, void* cpu_alloc, - xpu_malloc_t gpu_malloc, void* gpu_alloc, void* stream, - sparse_malloc_t sparse_malloc, void* sparse_alloc, - int* instypes, int* outstypes, void** in_indices, void** out_indices, - void** in_indptr, void** out_indptr, - int64_t* in_indices_shapes, int64_t* out_indices_shapes, - int64_t* in_indptr_shapes, int64_t* out_indptr_shapes, - void* rng_cpu_states, void* rng_gpu_states) { + MX_INT_RET _opCallFStatefulCompute(int is_forward, void* state_op, const int64_t** inshapes, + int* indims, void** indata, int* intypes, size_t* inIDs, + const char** indev_type, int* indev_id, int num_in, + const int64_t** outshapes, int* outdims, void** outdata, + int* outtypes, size_t* outIDs, const char** outdev_type, + int* outdev_id, int num_out, xpu_malloc_t cpu_malloc, + void* cpu_alloc, xpu_malloc_t gpu_malloc, void* gpu_alloc, + void* stream, sparse_malloc_t sparse_malloc, + void* sparse_alloc, int* instypes, int* outstypes, + void** in_indices, void** out_indices, void** in_indptr, + void** out_indptr, int64_t* in_indices_shapes, + int64_t* out_indices_shapes, int64_t* in_indptr_shapes, + int64_t* out_indptr_shapes, + void* rng_cpu_states, void* rng_gpu_states) { // create a vector of tensors for inputs std::vector inputs(num_in); // create a vector for sparse inputs @@ -1464,7 +1680,7 @@ extern "C" { if (instypes[i] == 0) { // Dense representation. inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i], - inIDs[i], {indev_type[i], indev_id[i]}, kDefaultStorage); + inIDs[i], MXContext(indev_type[i], indev_id[i]), kDefaultStorage); } else { // Sparse representation. MXStorageType type; @@ -1477,8 +1693,8 @@ extern "C" { in_indices_shapes[i], in_indptr[i], in_indptr_shapes[i]); } inputs[i].setTensor(reinterpret_cast(&in_sparse[i]), (MXDType)intypes[i], - inshapes[i], indims[i], inIDs[i], {indev_type[i], - indev_id[i]}, type); + inshapes[i], indims[i], inIDs[i], + MXContext(indev_type[i], indev_id[i]), type); } } @@ -1491,7 +1707,7 @@ extern "C" { if (outstypes[i] == 0) { // Dense representation. outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], outdims[i], - outIDs[i], {outdev_type[i], outdev_id[i]}, kDefaultStorage); + outIDs[i], MXContext(outdev_type[i], outdev_id[i]), kDefaultStorage); } else { // Sparse representation. MXStorageType type; @@ -1505,8 +1721,8 @@ extern "C" { out_indices_shapes[i], out_indptr[i], out_indptr_shapes[i]); } outputs[i].setTensor(reinterpret_cast(&out_sparse[i]), (MXDType)outtypes[i], - outshapes[i], outdims[i], outIDs[i], {outdev_type[i], - outdev_id[i]}, type); + outshapes[i], outdims[i], outIDs[i], + MXContext(outdev_type[i], outdev_id[i]), type); } } @@ -1515,68 +1731,50 @@ extern "C" { CustomStatefulOp* op_ptr = reinterpret_cast(state_op); if (is_forward) { - return op_ptr->Forward(inputs, outputs, res); + return op_ptr->Forward(&inputs, &outputs, res); } - return op_ptr->Backward(inputs, outputs, res); + return op_ptr->Backward(&inputs, &outputs, res); } /*! \brief returns number of partitioners registered in this library */ -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - __declspec(dllexport) int __cdecl -#else - int -#endif - _partRegSize() { + MX_INT_RET _partRegSize() { return Registry::get()->size(); } /* returns number of strategies registered for partitioner * at specified index */ -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - __declspec(dllexport) int __cdecl -#else - int -#endif - _partRegGetCount(int idx, const char** name) { + MX_INT_RET _partRegGetCount(int idx, const char** name) { CustomPartitioner part = Registry::get()->get(idx); *name = part.name; return part.strategies.size(); } /*! \brief returns partitioner registration at specified index */ -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - __declspec(dllexport) void __cdecl -#else - void -#endif - _partRegGet(int part_idx, int stg_idx, const char** strategy, supportedOps_t* supportedOps, - reviewSubgraph_t* reviewSubgraph, const char** op_name) { + MX_VOID_RET _partRegGet(int part_idx, int stg_idx, const char** strategy, + supportedOps_t* supportedOps, createSelector_t* createSelector, + reviewSubgraph_t* reviewSubgraph, const char** op_name) { CustomPartitioner part = Registry::get()->get(part_idx); *strategy = part.strategies[stg_idx]; - *supportedOps = part.supportedOps[stg_idx]; *op_name = part.op_names[stg_idx]; + *supportedOps = part.getSupportedOps(stg_idx); + *createSelector = part.getCreateSelector(stg_idx); *reviewSubgraph = part.getReviewSubgraph(stg_idx); } - /*! \brief returns status of calling parse attributes function for operator from library */ -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - __declspec(dllexport) int __cdecl -#else - int -#endif - _partCallSupportedOps(supportedOps_t supportedOps, const char *json, - int num_ids, int *ids, const char* const* opt_keys, - const char* const* opt_vals, int num_opts) { + /*! \brief returns status of calling supported ops function from library */ + MX_INT_RET _partCallSupportedOps(supportedOps_t supportedOps, const char *json, + int num_ids, int *ids, const char* const* opt_keys, + const char* const* opt_vals, int num_opts) { std::string subgraph_json(json); // create map of options from list std::unordered_map opts; for (int i = 0; i < num_opts; i++) opts[std::string(opt_keys[i])] = std::string(opt_vals[i]); - // create array of bools for operator support - std::vector _ids(num_ids, false); + // create array of subgraph IDs for operator support + std::vector _ids(num_ids, -2); // call user's supportedOps function - MXReturnValue retval = supportedOps(subgraph_json, _ids, opts); + MXReturnValue retval = supportedOps(subgraph_json, &_ids, opts); if (!retval) return retval; // copy bools in ids to ints @@ -1586,26 +1784,83 @@ extern "C" { return retval; } - /*! \brief returns status of calling parse attributes function for operator from library */ -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - __declspec(dllexport) int __cdecl -#else - int -#endif - _partCallReviewSubgraph(reviewSubgraph_t reviewSubgraph, const char *json, - int subgraph_id, int *accept, const char* const* opt_keys, - const char* const* opt_vals, int num_opts, - char*** attr_keys, char*** attr_vals, int *num_attrs, - const char* const* arg_names, int num_args, - void* const* arg_data, const int64_t* const* arg_shapes, - const int* arg_dims, const int* arg_types, - const size_t* arg_IDs, const char* const* arg_dev_type, - const int* arg_dev_id, - const char* const* aux_names, int num_aux, - void* const* aux_data, const int64_t* const* aux_shapes, - const int* aux_dims, const int* aux_types, - const size_t* aux_IDs, const char* const* aux_dev_type, - const int* aux_dev_id) { + /*! \brief returns status of calling create selector function from library */ + MX_INT_RET _partCallCreateSelector(createSelector_t createSelector, const char *json, + void** selector, const char* const* opt_keys, + const char* const* opt_vals, int num_opts) { + std::string symbol_json(json); + // create map of options from list + std::unordered_map opts; + for (int i = 0; i < num_opts; i++) + opts[std::string(opt_keys[i])] = std::string(opt_vals[i]); + + // void pointer to hold selector instance created in custom library + // eventually pointer is populated by instance from custom library + CustomOpSelector** sel_ptr = reinterpret_cast(selector); + + // call user's createSelector function + return createSelector(symbol_json, sel_ptr, opts); + } + + /*! \brief returns status of calling select function from library */ + MX_VOID_RET _partCallSelect(void* sel_inst, int nodeID, int* selected) { + CustomOpSelector* sel_ptr = reinterpret_cast(sel_inst); + *selected = sel_ptr->Select(nodeID); + } + + /*! \brief returns status of calling select input function from library */ + MX_VOID_RET _partCallSelectInput(void* sel_inst, int nodeID, + int input_nodeID, int* selected) { + CustomOpSelector* sel_ptr = reinterpret_cast(sel_inst); + *selected = sel_ptr->SelectInput(nodeID, input_nodeID); + } + + /*! \brief returns status of calling select output function from library */ + MX_VOID_RET _partCallSelectOutput(void* sel_inst, int nodeID, + int output_nodeID, int* selected) { + CustomOpSelector* sel_ptr = reinterpret_cast(sel_inst); + *selected = sel_ptr->SelectOutput(nodeID, output_nodeID); + } + + /*! \brief returns status of calling filter function from library */ + MX_VOID_RET _partCallFilter(void* sel_inst, int* candidates, int num_candidates, + int** keep, int* num_keep) { + CustomOpSelector* sel_ptr = reinterpret_cast(sel_inst); + std::vector candidates_(num_candidates); + for (int i=0; i < num_candidates; i++) { + candidates_[i] = candidates[i]; + } + std::vector keep_; + + sel_ptr->Filter(candidates_, &keep_); + + *num_keep = keep_.size(); + *keep = static_cast(malloc(keep_.size() * sizeof(int))); + for (unsigned i=0; i < keep_.size(); i++) + (*keep)[i] = keep_[i]; + } + + /*! \brief returns status of calling reset selector function from library */ + MX_VOID_RET _partCallReset(void* sel_inst) { + CustomOpSelector* sel_ptr = reinterpret_cast(sel_inst); + sel_ptr->Reset(); + } + + /*! \brief returns status of calling review subgraph function from library */ + MX_INT_RET _partCallReviewSubgraph(reviewSubgraph_t reviewSubgraph, const char *json, + int subgraph_id, int *accept, const char* const* opt_keys, + const char* const* opt_vals, int num_opts, + char*** attr_keys, char*** attr_vals, int *num_attrs, + const char* const* arg_names, int num_args, + void* const* arg_data, const int64_t* const* arg_shapes, + const int* arg_dims, const int* arg_types, + const size_t* arg_IDs, const char* const* arg_dev_type, + const int* arg_dev_id, + const char* const* aux_names, int num_aux, + void* const* aux_data, const int64_t* const* aux_shapes, + const int* aux_dims, const int* aux_types, + const size_t* aux_IDs, const char* const* aux_dev_type, + const int* aux_dev_id) { std::string subgraph_json(json); bool accept_bool = false; // create map of attributes from list @@ -1614,34 +1869,33 @@ extern "C" { opts[std::string(opt_keys[i])] = std::string(opt_vals[i]); // create a map of named tensors for args - std::map args; + std::unordered_map args; for (int i = 0; i < num_args; i++) { std::vector shapes; for (int j = 0; j < arg_dims[i]; j++) shapes.push_back(arg_shapes[i][j]); MXTensor tensor(arg_data[i], shapes, (MXDType)arg_types[i], - arg_IDs[i], {arg_dev_type[i], arg_dev_id[i]}); + arg_IDs[i], MXContext(arg_dev_type[i], arg_dev_id[i])); args[arg_names[i]] = tensor; } // create a map of named tensors for aux - std::map aux; + std::unordered_map aux; for (int i = 0; i < num_aux; i++) { std::vector shapes; for (int j = 0; j < aux_dims[i]; j++) shapes.push_back(aux_shapes[i][j]); MXTensor tensor(aux_data[i], shapes, (MXDType)aux_types[i], - aux_IDs[i], {aux_dev_type[i], aux_dev_id[i]}); + aux_IDs[i], MXContext(aux_dev_type[i], aux_dev_id[i])); aux[aux_names[i]] = tensor; } - // attributes to set on subgraph node std::unordered_map attrs; MXReturnValue retval = reviewSubgraph(subgraph_json, subgraph_id, &accept_bool, - opts, attrs, args, aux); + opts, &attrs, args, aux); if (!retval) return retval; *accept = accept_bool; @@ -1666,6 +1920,79 @@ extern "C" { return retval; } + /*! \brief returns number of graph passes registered in this library */ + MX_INT_RET _passRegSize() { + return Registry::get()->size(); + } + + /*! \brief returns pass registration at specified index */ + MX_VOID_RET _passRegGet(int pass_idx, graphPass_t* graphPass, + const char** pass_name) { + CustomPass pass = Registry::get()->get(pass_idx); + *graphPass = pass.pass; + *pass_name = pass.name; + } + + /*! \brief returns status of calling graph pass function from library */ + MX_INT_RET _passCallGraphPass(graphPass_t graphPass, const char *json, + char** graph, const char* const* opt_keys, + const char* const* opt_vals, int num_opts, + const char* pass_name, const char* const* arg_names, int num_args, + void* const* arg_data, const int64_t* const* arg_shapes, + const int* arg_dims, const int* arg_types, + const size_t* arg_IDs, const char* const* arg_dev_type, + const int* arg_dev_id, const char* const* aux_names, int num_aux, + void* const* aux_data, const int64_t* const* aux_shapes, + const int* aux_dims, const int* aux_types, + const size_t* aux_IDs, const char* const* aux_dev_type, + const int* aux_dev_id, nd_malloc_t nd_malloc, + const void* nd_alloc) { + std::string graph_json(json); + const std::string* out_graph = nullptr; + // create map of attributes from list + std::unordered_map opts; + for (int i = 0; i < num_opts; i++) + opts[std::string(opt_keys[i])] = std::string(opt_vals[i]); + + // create a map of named tensors for args + std::unordered_map args; + for (int i = 0; i < num_args; i++) { + std::vector shapes; + for (int j = 0; j < arg_dims[i]; j++) + shapes.push_back(arg_shapes[i][j]); + + MXTensor tensor(arg_data[i], shapes, (MXDType)arg_types[i], + arg_IDs[i], MXContext(arg_dev_type[i], arg_dev_id[i])); + args[arg_names[i]] = tensor; + } + // create a map of named tensors for aux + std::unordered_map aux; + for (int i = 0; i < num_aux; i++) { + std::vector shapes; + for (int j = 0; j < aux_dims[i]; j++) + shapes.push_back(aux_shapes[i][j]); + + MXTensor tensor(aux_data[i], shapes, (MXDType)aux_types[i], + aux_IDs[i], MXContext(aux_dev_type[i], aux_dev_id[i])); + aux[aux_names[i]] = tensor; + } + + std::unordered_map new_args, new_aux; + PassResource res(&new_args, &new_aux, nd_malloc, nd_alloc); + MXReturnValue retval = graphPass(graph_json, &out_graph, opts, args, aux, res); + if (!retval) return retval; + + if (out_graph == nullptr) { + std::cout << "Error calling graph pass '" << pass_name + << "' returned out_graph string is null" << std::endl; + return MX_FAIL; + } + *graph = static_cast(malloc((out_graph->length()+1) * sizeof(char))); + out_graph->copy(*graph, out_graph->size()+1); + delete out_graph; + return retval; + } + /*! * \brief Checks if the MXNet version is supported by the library. * If supported, initializes the library. diff --git a/perl-package/AI-MXNetCAPI/mxnet.i b/perl-package/AI-MXNetCAPI/mxnet.i index 846b28ff0e34..9602b08d675e 100644 --- a/perl-package/AI-MXNetCAPI/mxnet.i +++ b/perl-package/AI-MXNetCAPI/mxnet.i @@ -1627,16 +1627,22 @@ int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle *out); * \param vals values corresponding to keys */ int MXOptimizeForBackend(SymbolHandle sym_handle, - const char* in, - const int dev_type, - SymbolHandle* in, - const mx_uint in, - NDArrayHandle* in, - const mx_uint in, - NDArrayHandle* in, - const mx_uint in, - const char** keys, - const char** vals); + const char* in, + const int dev_type, + SymbolHandle* in, + const mx_uint in, + NDArrayHandle* in, + const mx_uint in, + NDArrayHandle* in, + const mx_uint in, + const char** keys, + const char** vals, + int* new_args_cnt, + NDArrayHandle** new_args_handle, + char*** new_arg_names_handle, + int* new_aux_cnt, + NDArrayHandle** new_aux_handle, + char*** new_aux_names_handle); //-------------------------------------------- // Part 4: Executor interface diff --git a/python/mxnet/library.py b/python/mxnet/library.py index 13df2ec71298..e0c60d4588f9 100644 --- a/python/mxnet/library.py +++ b/python/mxnet/library.py @@ -20,16 +20,20 @@ import ctypes import sys import os -from .base import _LIB, check_call, MXNetError, _init_op_module +from .base import _LIB, check_call, MXNetError, _init_op_module, mx_uint from .ndarray.register import _make_ndarray_function from .symbol.register import _make_symbol_function -def load(path): +def load(path, verbose=True): """Loads library dynamically. Parameters --------- - path : Path to library .so/.dll file + path : string + Path to library .so/.dll file + + verbose : boolean + defaults to True, set to False to avoid printing library info Returns --------- @@ -46,9 +50,10 @@ def load(path): if not file_ext in ['.so', '.dll']: raise MXNetError("load path %s is NOT a library file" % path) + verbose_val = 1 if verbose else 0 byt_obj = path.encode('utf-8') chararr = ctypes.c_char_p(byt_obj) - check_call(_LIB.MXLoadLib(chararr)) + check_call(_LIB.MXLoadLib(chararr, mx_uint(verbose_val))) #regenerate operators _init_op_module('mxnet', 'ndarray', _make_ndarray_function) diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 0a19018b6e62..37b91867f84c 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -1445,6 +1445,7 @@ def _gen_atomic_symbol(self): return Symbol(handle) + # pylint: disable=too-many-locals def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs): """Partitions current symbol and optimizes it for a given backend, returns new partitioned symbol. @@ -1499,6 +1500,13 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs): ctx = current_context() assert isinstance(ctx, Context) + new_args_size = ctypes.c_uint() + new_arg_names = ctypes.POINTER(ctypes.c_char_p)() + new_args_handle = ctypes.POINTER(NDArrayHandle)() + new_aux_size = ctypes.c_uint() + new_aux_names = ctypes.POINTER(ctypes.c_char_p)() + new_aux_handle = ctypes.POINTER(NDArrayHandle)() + key_list = [] val_list = [] for key, val in kwargs.items(): @@ -1514,7 +1522,37 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs): aux_handle, mx_uint(len(key_list)), c_str_array(key_list), - c_str_array(val_list))) + c_str_array(val_list), + ctypes.byref(new_args_size), + ctypes.byref(new_args_handle), + ctypes.byref(new_arg_names), + ctypes.byref(new_aux_size), + ctypes.byref(new_aux_handle), + ctypes.byref(new_aux_names))) + arg_names = self.list_arguments() + if isinstance(args, dict): + for i in range(new_args_size.value): + args[py_str(new_arg_names[i])] = NDArray(NDArrayHandle(new_args_handle[i])) + elif isinstance(args, list): + for i in range(new_args_size.value): + name = py_str(new_arg_names[i]) + if name in arg_names: + idx = arg_names.index(name) + args[idx] = NDArray(NDArrayHandle(new_args_handle[i])) + else: + args.append(NDArray(NDArrayHandle(new_args_handle[i]))) + aux_names = self.list_auxiliary_states() + if isinstance(aux, dict): + for i in range(new_aux_size.value): + aux[py_str(new_aux_names[i])] = NDArray(NDArrayHandle(new_aux_handle[i])) + elif isinstance(aux, list): + for i in range(new_aux_size.value): + name = py_str(new_aux_names[i]) + if name in aux_names: + idx = aux_names.index(name) + aux[idx] = NDArray(NDArrayHandle(new_aux_handle[i])) + else: + aux.append(NDArray(NDArrayHandle(new_aux_handle[i]))) return Symbol(out) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 09fede6bd056..b9511aef15f7 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -209,7 +209,7 @@ void CustomFComputeDispatcher(const std::string op_name, // create lambda that allocates memory for sparse and // returns allocated arrays for data, indices and indptr. auto sparse_alloc = [&](int index, int indices_len, int idxptr_len, - void** data, int64_t** indices, int64_t** indptr) { + void** data, int64_t** indices, int64_t** indptr) { if (idxptr_len == 0) { // Row Sparse outputs[index].CheckAndAlloc({mshadow::Shape1(indices_len)}); @@ -320,28 +320,7 @@ void CustomFComputeDispatcher(const std::string op_name, } } -/*! - * \brief Loads dynamic custom library and initializes it - * \param path library path - */ -int MXLoadLib(const char *path) { - API_BEGIN(); - void *lib = LibraryInitializer::Get()->lib_load(path); - if (!lib) - LOG(FATAL) << "Unable to load library"; - - // check that library and MXNet use same version of library API - opVersion_t opVersion = get_func(lib, const_cast(MXLIB_OPVERSION_STR)); - int libVersion = opVersion(); - if (MX_LIBRARY_VERSION != libVersion) - LOG(FATAL) << "Library version (" << libVersion << ") does not match MXNet version (" - << MX_LIBRARY_VERSION << ")"; - - // initialize library by passing MXNet version - initialize_t initialize = get_func(lib, const_cast(MXLIB_INITIALIZE_STR)); - if (!initialize(static_cast(MXNET_VERSION))) - LOG(FATAL) << "Library failed to initialize"; - +void registerOperators(void *lib, int verbose) { // get C type interface functions opCallFree_t callFree = get_func(lib, const_cast(MXLIB_OPCALLFREE_STR)); @@ -369,16 +348,10 @@ int MXLoadLib(const char *path) { opCallFStatefulComp_t callFStatefulComp = get_func(lib, const_cast(MXLIB_OPCALLFSTATEFULCOMP_STR)); - partCallSupportedOps_t callSupportedOps = - get_func(lib, const_cast(MXLIB_PARTCALLSUPPORTEDOPS_STR)); - - partCallReviewSubgraph_t callReviewSubgraph = - get_func(lib, const_cast(MXLIB_PARTCALLREVIEWSUBGRAPH_STR)); - // get number of operators registered in the library opRegSize_t opRegSize = get_func(lib, const_cast(MXLIB_OPREGSIZE_STR)); int numOps = opRegSize(); - LOG(INFO) << "Found " << numOps << " operators in library"; + if (verbose) LOG(INFO) << "Found " << numOps << " operators in library"; /* * Get all custom operators implementation from custom library @@ -443,8 +416,8 @@ int MXLoadLib(const char *path) { CHECK(createop_map.size() != 0) << "Error loading '" << name << "' custom subgraph op, CreateOpState function was not set."; } - LOG(INFO) << "\tOp[" << i << "] " << name; - if (isSubgraphOp) LOG(INFO) << "\t\tisSubgraphOp"; + if (verbose) LOG(INFO) << "\tOp[" << i << "] " << name; + if (verbose && isSubgraphOp) LOG(INFO) << "\t\tisSubgraphOp"; std::string name_str(name); /* @@ -564,15 +537,42 @@ int MXLoadLib(const char *path) { } } + // modified input shapes will be allocated by infer shape function + uint32_t** mod_inshapes = nullptr; + int* mod_indims = nullptr; // output shapes will be allocated by infer shape function uint32_t** outshapes = nullptr; int* outdims = nullptr; CHECK(callInferShape(shape_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), inshapes.data(), indims.data(), in_shape->size(), + &mod_inshapes, &mod_indims, &outshapes, &outdims, out_shape->size())) << "Error calling InferShape for custom operator '" << name_str << "'"; + std::vector in_shapes(in_shape->size()); + // determine amount of memory needed to store all the modified input shapes + buff_size = 0; + for (unsigned i = 0; i < in_shape->size(); i++) { + buff_size += mod_indims[i]; + } + + // copy modified input shapes from custom op memory to MXNet memory + std::vector mod_inbuff(buff_size); + ptr = mod_inbuff.data(); + for (unsigned i = 0; i < in_shape->size(); ++i) { + in_shapes[i] = ptr; + for (int j = 0; j < mod_indims[i]; ++j, ++ptr) { + *ptr = static_cast(mod_inshapes[i][j]); + } + } + + // assign modified input shapes to ShapeVector + for (unsigned i = 0; i < in_shape->size(); ++i) { + SHAPE_ASSIGN_CHECK(*in_shape, i, + mxnet::TShape(in_shapes[i], in_shapes[i]+mod_indims[i])); + } + std::vector out_shapes(out_shape->size()); // determine amount of memory needed to store all the output shapes buff_size = 0; @@ -597,6 +597,12 @@ int MXLoadLib(const char *path) { } // free memory used by custom op to allocate shapes/dims + callFree(mod_indims); + for (unsigned i = 0; i < in_shape->size(); i++) { + callFree(mod_inshapes[i]); + } + callFree(mod_inshapes); + callFree(outdims); for (unsigned i = 0; i < out_shape->size(); i++) { callFree(outshapes[i]); @@ -628,6 +634,10 @@ int MXLoadLib(const char *path) { outtypes.data(), out_type->size())) << "Error calling InferType for custom operator '" << name_str << "'"; + // copy and assign modified input types from custom op to MXNet memory + for (size_t i = 0; i < in_type->size(); i++) { + TYPE_ASSIGN_CHECK(*in_type, i, intypes[i]); + } // copy and assign output types from custom op to MXNet memory for (size_t i = 0; i < out_type->size(); i++) { TYPE_ASSIGN_CHECK(*out_type, i, outtypes[i]); @@ -693,6 +703,10 @@ int MXLoadLib(const char *path) { outstypes.data(), out_stypes->size())) << "Error calling InferSType for custom operator '" << name_str << "'"; + // copy and assign modified input storage types from custom op to MXNet memory. + for (size_t i = 0; i < in_stypes->size(); i++) { + STORAGE_TYPE_ASSIGN_CHECK(*in_stypes, i, instypes[i]); + } // copy and assign output storage types from custom op to MXNet memory. for (size_t i = 0; i < out_stypes->size(); i++) { STORAGE_TYPE_ASSIGN_CHECK(*out_stypes, i, outstypes[i]); @@ -937,12 +951,41 @@ int MXLoadLib(const char *path) { } regOp.add_argument("data", "NDArray[]", "Source inputs"); } +} + +void registerPartitioners(void *lib, int verbose) { + // get C type interface functions + opCallFree_t callFree = get_func(lib, const_cast(MXLIB_OPCALLFREE_STR)); + + partCallSupportedOps_t callSupportedOps = + get_func(lib, const_cast(MXLIB_PARTCALLSUPPORTEDOPS_STR)); + + partCallCreateSelector_t callCreateSelector = + get_func(lib, const_cast(MXLIB_PARTCALLCREATESELECTOR_STR)); + + partCallSelect_t callSelect = + get_func(lib, const_cast(MXLIB_PARTCALLSELECT_STR)); + + partCallSelectInput_t callSelectInput = + get_func(lib, const_cast(MXLIB_PARTCALLSELECTINPUT_STR)); + + partCallSelectOutput_t callSelectOutput = + get_func(lib, const_cast(MXLIB_PARTCALLSELECTOUTPUT_STR)); + + partCallFilter_t callFilter = + get_func(lib, const_cast(MXLIB_PARTCALLFILTER_STR)); + + partCallReset_t callReset = + get_func(lib, const_cast(MXLIB_PARTCALLRESET_STR)); + + partCallReviewSubgraph_t callReviewSubgraph = + get_func(lib, const_cast(MXLIB_PARTCALLREVIEWSUBGRAPH_STR)); // get number of partitioners registered in the library partRegSize_t partRegSize = get_func(lib, const_cast(MXLIB_PARTREGSIZE_STR)); int numParts = partRegSize(); - LOG(INFO) << "Found " << numParts << " partitioners in library"; + if (verbose) LOG(INFO) << "Found " << numParts << " partitioners in library"; /* * Get all custom partitioners implementation from custom library @@ -958,7 +1001,7 @@ int MXLoadLib(const char *path) { CHECK(count > 0) << "Error loading '" << name << "' custom partitioner, no strategies defined"; std::string name_str(name); - LOG(INFO) << "\tPartitioner[" << i << "] " << name; + if (verbose) LOG(INFO) << "\tPartitioner[" << i << "] " << name; mxnet::op::SubgraphBackendRegistry::Get()->__REGISTER_BACKEND__(name); @@ -966,26 +1009,283 @@ int MXLoadLib(const char *path) { const char* strategy; // function pointers holding implementation from custom library supportedOps_t supportedOps_fp = nullptr; + createSelector_t createSelector_fp = nullptr; reviewSubgraph_t reviewSubgraph_fp = nullptr; // name of subgraph op const char* op_name = nullptr; - // get custom partitioner strategy from the dynamic library - partRegGet(i, j, &strategy, &supportedOps_fp, &reviewSubgraph_fp, &op_name); + // get custom partitioner strategy from the dynamic library + partRegGet(i, j, &strategy, &supportedOps_fp, &createSelector_fp, + &reviewSubgraph_fp, &op_name); // validate custom partitioner functions from the dynamic library - CHECK(supportedOps_fp != nullptr) << "Error loading '" << name - << "' custom partitioner strategy '" << strategy - << "', supportedOps function was not set."; + if (supportedOps_fp == nullptr && createSelector_fp == nullptr) + LOG(ERROR) << "Error loading '" << name << "' custom partitioner strategy '" + << strategy << "', must implement supportedOps or createSelector"; std::string strategy_str(strategy); std::string op_name_str(op_name); - LOG(INFO) << "\t\tStrategy[" << j << "] " << strategy_str - << " subgraphOp: '" << op_name_str << "'"; + if (verbose) LOG(INFO) << "\t\tStrategy[" << j << "] " << strategy_str + << " subgraphOp: '" << op_name_str << "'"; mxnet::op::SubgraphBackendRegistry::Get()->__REGISTER_CUSTOM_PROPERTY__ (name_str, std::make_shared - (strategy_str, callSupportedOps, supportedOps_fp, - callReviewSubgraph, reviewSubgraph_fp, callFree, op_name_str)); + (strategy_str, callSupportedOps, supportedOps_fp, callCreateSelector, + createSelector_fp, callSelect, callSelectInput, callSelectOutput, + callFilter, callReset, callReviewSubgraph, reviewSubgraph_fp, callFree, + op_name_str)); } } +} + +void registerPasses(void *lib, int verbose) { + // get C type interface functions + opCallFree_t callFree = get_func(lib, const_cast(MXLIB_OPCALLFREE_STR)); + + passCallGraphPass_t callGraphPass = + get_func(lib, const_cast(MXLIB_PASSCALLGRAPHPASS_STR)); + + // get number of passes registered in the library + partRegSize_t passRegSize = get_func(lib, + const_cast(MXLIB_PASSREGSIZE_STR)); + int numPasses = passRegSize(); + if (verbose) LOG(INFO) << "Found " << numPasses << " graph passes in library"; + + /* + * Get all custom pass implementation from custom library + * loop and register each pass in the library to NNVM + */ + passRegGet_t passRegGet = get_func(lib, const_cast(MXLIB_PASSREGGET_STR)); + for (int i = 0; i < numPasses; i++) { + const char* name; + // function pointers holding implementation from custom library + graphPass_t pass_fp = nullptr; + + // main function to get custom pass implemenation from the custom library + passRegGet(i, &pass_fp, &name); + + if (verbose) LOG(INFO) << "\tGraph Pass [" << i << "] " << name; + + auto pass_lambda = [=] (nnvm::Graph&& g) { + // get pass name + const char* pass_name = g.GetAttr("pass_name"); + // get options + const std::vector>& options_map = + g.GetAttr>>("options_map"); + // convert options_map_ to char* to pass to backend library + std::vector opt_keys, opt_vals; + for (auto& kv : options_map) { + opt_keys.push_back(kv.first.c_str()); + opt_vals.push_back(kv.second.c_str()); + } + + // get input args and arg names + std::vector in_arg_names = g.GetAttr>("in_arg_names"); + std::vector in_aux_names = g.GetAttr>("in_aux_names"); + NDArray **in_args_ptr = g.GetAttr("in_args"); + NDArray **in_aux_ptr = g.GetAttr("in_aux"); + + // get shapes/types + mxnet::ShapeVector shapes; + if (g.HasAttr("shape")) + shapes = g.GetAttr("shape"); + std::vector dtypes; + if (g.HasAttr("dtype")) + dtypes = g.GetAttr >("dtype"); + g.attrs.clear(); + const nnvm::IndexedGraph& indexed_graph = g.indexed_graph(); + + // set shape attrs for each node in the graph + if (shapes.size() > 0) { + for (unsigned nid = 0; nid < indexed_graph.num_nodes(); nid++) { + nnvm::Node* node = const_cast(indexed_graph[nid].source); + std::stringstream ss; + ss << "["; + // set the output shapes for this node + for (unsigned oid = 0; oid < node->num_outputs(); oid++) { + const uint32_t out_entry_id = indexed_graph.entry_id(nid, oid); + mxnet::TShape& shape = shapes[out_entry_id]; + ss << shape; + if (oid < node->num_outputs()-1) ss << ","; + } + ss << "]"; + node->attrs.dict[MX_STR_SHAPE] = ss.str(); + } + } + // set dtype attrs for each node in the graph + if (dtypes.size() > 0) { + for (unsigned nid = 0; nid < indexed_graph.num_nodes(); nid++) { + nnvm::Node* node = const_cast(indexed_graph[nid].source); + std::stringstream ss; + ss << "["; + // set the output dtypes for this node + for (unsigned oid = 0; oid < node->num_outputs(); oid++) { + const uint32_t out_entry_id = indexed_graph.entry_id(nid, oid); + int dtype = dtypes[out_entry_id]; + ss << dtype; + if (oid < node->num_outputs()-1) ss << ","; + } + ss << "]"; + node->attrs.dict[MX_STR_DTYPE] = ss.str(); + } + } + + std::vector arg_names, aux_names; + std::vector arg_data, aux_data; + std::vector arg_shapes, aux_shapes; + std::vector arg_dims, aux_dims; + std::vector arg_types, aux_types; + std::vector arg_verIDs, aux_verIDs; + std::vector arg_dev_type, aux_dev_type; + std::vector arg_dev_id, aux_dev_id; + + // convert input args + for (size_t i=0; i < in_arg_names.size(); i++) { + arg_names.push_back(in_arg_names[i].c_str()); + const NDArray &in_arg = *(in_args_ptr[i]); + +#if MXNET_USE_MKLDNN == 1 + // reorder data if in MKLDNN format + if (in_arg.IsMKLDNNData()) { + in_arg.Reorder2DefaultAsync(); + in_arg.WaitToRead(); + } +#endif + + // pull out parts of NDArray to send to backend + arg_data.push_back(in_arg.data().dptr_); + arg_shapes.push_back(in_arg.shape().data()); + arg_dims.push_back(in_arg.shape().ndim()); + arg_types.push_back(in_arg.dtype()); + arg_verIDs.push_back(in_arg.version()); + const char* arg_ctx_str = in_arg.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu"; + arg_dev_type.push_back(arg_ctx_str); + arg_dev_id.push_back(in_arg.ctx().real_dev_id()); + } + + // convert input aux + for (size_t i=0; i < in_aux_names.size(); i++) { + aux_names.push_back(in_aux_names[i].c_str()); + const auto &in_aux = *(in_aux_ptr[i]); + +#if MXNET_USE_MKLDNN == 1 + // reorder data if in MKLDNN format + if (in_aux.IsMKLDNNData()) { + in_aux.Reorder2DefaultAsync(); + in_aux.WaitToRead(); + } +#endif + + // pull out parts of NDArray to send to backend + aux_data.push_back(in_aux.data().dptr_); + aux_shapes.push_back(in_aux.shape().data()); + aux_dims.push_back(in_aux.shape().ndim()); + aux_types.push_back(in_aux.dtype()); + aux_verIDs.push_back(in_aux.version()); + const char* aux_ctx_str = in_aux.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu"; + aux_dev_type.push_back(aux_ctx_str); + aux_dev_id.push_back(in_aux.ctx().real_dev_id()); + } + + // convert graph to string + std::string in_json = nnvm::pass::SaveJSON(g); + + std::vector new_arg_names, new_aux_names; + std::vector new_args, new_aux; + + // create lambda that captures stream & resource objects + // this temp workspace holds memory allocated by custom library via OpResource + auto ndarray_alloc = [&](const mxnet::TShape &shape, Context ctx, int dtype, + std::string name, bool isArg) { + NDArray* arr = new NDArray(shape, ctx, dtype); + if (isArg) { + new_args.push_back(arr); + new_arg_names.push_back(name); + } else { + new_aux.push_back(arr); + new_aux_names.push_back(name); + } + return arr; + }; + + // create no-capture lambda so that we can cast it to function pointer + // lambda with captures cannot be cast to function pointer and pass to lib_api.h + // this needs to be a lambda function so that we can do the decltype cast + typedef decltype(ndarray_alloc) alloc_type_ndarray; + auto ndarray_malloc = [](const void* _ndarray_alloc, const int64_t* shapes, int num_shapes, + const char* dev_str, int dev_id, int dtype, const char* name, + int isArg, void** data) { + mxnet::TShape shape(num_shapes, 0); + for (int i = 0; i < num_shapes; i++) + shape[i] = shapes[i]; + int dev_type = -1; + if (strcmp(dev_str, "cpu") == 0) + dev_type = kCPU; + else + dev_type = kGPU; + Context ctx = Context::Create(static_cast(dev_type), dev_id); + + // cast the void* argument to the type for the cpu_alloc lambda function + const alloc_type_ndarray* ndalloc = static_cast(_ndarray_alloc); + // call cpu_alloc to actually allocate memory and return the pointer + NDArray* arr = (*ndalloc)(shape, ctx, dtype, name, isArg); + *data = arr->data().dptr_; + }; + + char* out_json; + CHECK(callGraphPass(pass_fp, in_json.c_str(), &out_json, opt_keys.data(), + opt_vals.data(), opt_keys.size(), pass_name, + arg_names.data(), arg_names.size(), arg_data.data(), + arg_shapes.data(), arg_dims.data(), arg_types.data(), + arg_verIDs.data(), arg_dev_type.data(), + arg_dev_id.data(), aux_names.data(), aux_names.size(), + aux_data.data(), aux_shapes.data(), aux_dims.data(), + aux_types.data(), aux_verIDs.data(), + aux_dev_type.data(), aux_dev_id.data(), + ndarray_malloc, &ndarray_alloc)) + << "Error calling graph pass for '" << pass_name << "'"; + + std::string out_string(out_json); + nnvm::Graph out_graph = nnvm::pass::LoadJSON(out_string); + + out_graph.attrs["new_args"] = std::make_shared(new_args); + out_graph.attrs["new_arg_names"] = std::make_shared(new_arg_names); + out_graph.attrs["new_aux"] = std::make_shared(new_aux); + out_graph.attrs["new_aux_names"] = std::make_shared(new_aux_names); + + callFree(out_json); + return out_graph; + }; + + nnvm::PassFunctionReg& pass = dmlc::Registry::Get()->__REGISTER__(name); + pass.set_body(pass_lambda); + pass.set_change_graph(true); + } +} + +/*! + * \brief Loads dynamic custom library and initializes it + * \param path library path + */ +int MXLoadLib(const char *path, unsigned verbose) { + API_BEGIN(); + void *lib = LibraryInitializer::Get()->lib_load(path); + if (!lib) + LOG(FATAL) << "Unable to load library"; + + // check that library and MXNet use same version of library API + opVersion_t opVersion = get_func(lib, const_cast(MXLIB_OPVERSION_STR)); + int libVersion = opVersion(); + if (MX_LIBRARY_VERSION != libVersion) + LOG(FATAL) << "Library version (" << libVersion << ") does not match MXNet version (" + << MX_LIBRARY_VERSION << ")"; + + // initialize library by passing MXNet version + initialize_t initialize = get_func(lib, const_cast(MXLIB_INITIALIZE_STR)); + if (!initialize(static_cast(MXNET_VERSION))) + LOG(FATAL) << "Library failed to initialize"; + + // find ops, partitioners, and passes in library + registerOperators(lib, verbose); + registerPartitioners(lib, verbose); + registerPasses(lib, verbose); API_END(); } diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index afc64f73de7c..9585654b412b 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -752,7 +752,6 @@ int _SimpleBindImpl(SymbolHandle symbol_handle, &arg_grad_vec, &aux_state_vec, use_shared_buffer ? &shared_buffer_map : nullptr, reinterpret_cast(shared_exec_handle)); - // copy ndarray ptrs to ret->handles so that front end // can access them ret->ret_handles.clear(); diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index d2b17a920c9c..3b3d83cbd6c2 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -1349,7 +1349,13 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, NDArrayHandle* in_aux_handle, const mx_uint num_options, const char** keys, - const char** vals) { + const char** vals, + int* new_args_cnt, + NDArrayHandle** new_args_handle, + char*** new_arg_names_handle, + int* new_aux_cnt, + NDArrayHandle** new_aux_handle, + char*** new_aux_names_handle) { // create copy of input symbol nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); @@ -1360,6 +1366,10 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, const auto& mutable_nodes = indexed_graph.mutable_input_nodes(); std::vector input_names = sym->ListInputNames(nnvm::Symbol::kAll); size_t num_forward_inputs = input_names.size(); + + NDArray ***new_args_ptr = reinterpret_cast(new_args_handle); + NDArray ***new_aux_ptr = reinterpret_cast(new_aux_handle); + if (args_len || aux_len) { NDArray **in_args_ptr = reinterpret_cast(in_args_handle); NDArray **in_aux_ptr = reinterpret_cast(in_aux_handle); @@ -1430,14 +1440,62 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, for (mx_uint i = 0; i < num_options; ++i) options_map.emplace_back(keys[i], vals[i]); - const auto backend = mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(backend_name); - const auto& subgraph_prop_list = backend->GetSubgraphProperties(); - for (auto property : subgraph_prop_list) { - property->PrePartition(g, options_map); - g.attrs["subgraph_property"] = std::make_shared(property); - g = ApplyPass(std::move(g), "BuildSubgraph"); - g.attrs.erase("subgraph_property"); - property->PostPartition(g); + if (mxnet::op::SubgraphBackendRegistry::Get()->backend_map_.count(backend_name) > 0) { + // use subgraph backend + const auto backend = mxnet::op::SubgraphBackendRegistry + ::Get()->GetSubgraphBackend(backend_name); + const auto& subgraph_prop_list = backend->GetSubgraphProperties(); + for (auto property : subgraph_prop_list) { + property->PrePartition(g, options_map); + g.attrs["subgraph_property"] = std::make_shared(property); + g = ApplyPass(std::move(g), "BuildSubgraph"); + g.attrs.erase("subgraph_property"); + property->PostPartition(g); + } + } else if (dmlc::Registry::Find(backend_name) != nullptr) { + // use graph pass + g.attrs["options_map"] = std::make_shared(options_map); + g.attrs["pass_name"] = std::make_shared(backend_name); + g = ApplyPass(std::move(g), backend_name); + + std::vector new_args = g.GetAttr>("new_args"); + std::vector new_aux = g.GetAttr>("new_aux"); + std::vector new_arg_names = g.GetAttr>("new_arg_names"); + std::vector new_aux_names = g.GetAttr>("new_aux_names"); + g.attrs.erase("new_args"); + g.attrs.erase("new_aux"); + g.attrs.erase("new_arg_names"); + g.attrs.erase("new_aux_names"); + + NDArray** new_arg_arr = new NDArray*[new_arg_names.size()]; + NDArray** new_aux_arr = new NDArray*[new_aux_names.size()]; + char** new_arg_cstr = new char*[new_arg_names.size()]; + char** new_aux_cstr = new char*[new_aux_names.size()]; + for (unsigned i = 0; i < new_arg_names.size(); i++) { + new_arg_arr[i] = new_args[i]; + std::string& s = new_arg_names[i]; + char* tmp = new char[s.length()+1]; + s.copy(tmp, s.length()); + tmp[s.length()] = '\0'; + new_arg_cstr[i] = tmp; + } + for (unsigned i = 0; i < new_aux_names.size(); i++) { + new_aux_arr[i] = new_aux[i]; + std::string& s = new_aux_names[i]; + char* tmp = new char[s.length()+1]; + s.copy(tmp, s.length()); + tmp[s.length()] = '\0'; + new_aux_cstr[i] = tmp; + } + *new_args_cnt = new_arg_names.size(); + *new_aux_cnt = new_aux_names.size(); + *new_arg_names_handle = new_arg_cstr; + *new_aux_names_handle = new_aux_cstr; + *new_args_ptr = new_arg_arr; + *new_aux_ptr = new_aux_arr; + } else { + // cannot find graph pass or subgraph backend registered in this name + LOG(ERROR) << "Error optimizing for backend '" << backend_name << "' cannot be found"; } s->outputs = g.outputs; *ret_sym_handle = s; diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc index 6819fbd33075..9875e1e7566a 100644 --- a/src/executor/infer_graph_attr_pass.cc +++ b/src/executor/infer_graph_attr_pass.cc @@ -662,6 +662,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, const std::string name = inode.source->attrs.name; const uint32_t num_inputs = inode.inputs.size(); const uint32_t num_outputs = inode.source->num_outputs(); + if (inode.source->is_variable()) { // Variable node. No operator. Only one output entry. CHECK(inode.source->op() == nullptr); diff --git a/src/operator/subgraph/partitioner/custom_subgraph_property.h b/src/operator/subgraph/partitioner/custom_subgraph_property.h index b7f2cc2d0fef..ea721c5aa71a 100644 --- a/src/operator/subgraph/partitioner/custom_subgraph_property.h +++ b/src/operator/subgraph/partitioner/custom_subgraph_property.h @@ -37,6 +37,7 @@ #include "../common.h" #include "../subgraph_property.h" #include "../../include/mxnet/lib_api.h" + namespace mxnet { namespace op { @@ -47,18 +48,88 @@ namespace op { */ class CustomContainOpSelector: public SubgraphSelector { public: - explicit CustomContainOpSelector(std::unordered_set supported_nodes) : - supported_nodes_(supported_nodes) {} + explicit CustomContainOpSelector(std::unordered_map supported_nodes, + void* sel_inst, partCallSelect_t callSelect, + partCallSelectInput_t callSelectInput, + partCallSelectOutput_t callSelectOutput, + partCallFilter_t callFilter, + partCallReset_t callReset, + opCallFree_t callFree, + std::unordered_map node2id) : + supported_nodes_(supported_nodes), sel_inst_(sel_inst), callSelect_(callSelect), + callSelectInput_(callSelectInput), callSelectOutput_(callSelectOutput), + callFilter_(callFilter), callReset_(callReset), callFree_(callFree), + node2id_(node2id) {} virtual bool Select(const nnvm::Node &n) { - return supported_nodes_.count(n.attrs.name) > 0; + if (!sel_inst_) { + return supported_nodes_.count(n.attrs.name) > 0; + } else { + int selected = 0; + callSelect_(sel_inst_, node2id_[&n], &selected); + return selected; + } } virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) { - return supported_nodes_.count(new_node.attrs.name) > 0; + if (!sel_inst_) { + // check that op type is supported and that both nodes have the same ID + // or the new node 's subgraph ID is any (-1) + return supported_nodes_.count(new_node.attrs.name) > 0 && + (supported_nodes_[n.attrs.name] == supported_nodes_[new_node.attrs.name] || + supported_nodes_[new_node.attrs.name] == -1); + } else { + int selected = 0; + callSelectInput_(sel_inst_, node2id_[&n], node2id_[&new_node], &selected); + return selected; + } } virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) { - return supported_nodes_.count(new_node.attrs.name) > 0; + if (!sel_inst_) { + // check that op type is supported and that both nodes have the same ID + // or the new node 's subgraph ID is any (-1) + return supported_nodes_.count(new_node.attrs.name) > 0 && + (supported_nodes_[n.attrs.name] == supported_nodes_[new_node.attrs.name] || + supported_nodes_[new_node.attrs.name] == -1); + } else { + int selected = 0; + callSelectOutput_(sel_inst_, node2id_[&n], node2id_[&new_node], &selected); + return selected; + } } - std::unordered_set supported_nodes_; + virtual std::vector Filter(const std::vector& candidates) { + if (!sel_inst_) { + return candidates; + } else { + std::unordered_map rev_map; + std::vector cand; + for (nnvm::Node* node : candidates) { + cand.push_back(node2id_[node]); + rev_map[node2id_[node]] = node; + } + int* keep_ = nullptr; + int num_keep = 0; + callFilter_(sel_inst_, cand.data(), cand.size(), &keep_, &num_keep); + std::vector keep; + for (int i=0; i < num_keep; i++) { + keep.push_back(rev_map[keep_[i]]); + } + callFree_(keep_); + return keep; + } + } + virtual void Reset() { + if (sel_inst_) + return callReset_(sel_inst_); + } + + std::unordered_map supported_nodes_; + void* sel_inst_; + partCallSelect_t callSelect_; + partCallSelectInput_t callSelectInput_; + partCallSelectOutput_t callSelectOutput_; + partCallFilter_t callFilter_; + partCallReset_t callReset_; + opCallFree_t callFree_; + std::unordered_map node2id_; }; /* @@ -73,12 +144,26 @@ class CustomSubgraphProperty: public SubgraphProperty { subgraph_prop("error"), call_supported_ops_(nullptr), supported_ops_(nullptr), + call_create_selector_(nullptr), + create_selector_(nullptr), + callSelect_(nullptr), + callSelectInput_(nullptr), + callSelectOutput_(nullptr), + callFilter_(nullptr), + callReset_(nullptr), call_review_subgraph_(nullptr), review_subgraph_(nullptr), subgraph_op_name("error") {} CustomSubgraphProperty(std::string subgraph_prop_name, partCallSupportedOps_t call_supported_ops, supportedOps_t supported_ops, + partCallCreateSelector_t call_create_selector, + createSelector_t create_selector, + partCallSelect_t callSelect, + partCallSelectInput_t callSelectInput, + partCallSelectOutput_t callSelectOutput, + partCallFilter_t callFilter, + partCallReset_t callReset, partCallReviewSubgraph_t call_review_subgraph, reviewSubgraph_t review_subgraph, opCallFree_t call_free, @@ -86,6 +171,13 @@ class CustomSubgraphProperty: public SubgraphProperty { subgraph_prop(subgraph_prop_name), call_supported_ops_(call_supported_ops), supported_ops_(supported_ops), + call_create_selector_(call_create_selector), + create_selector_(create_selector), + callSelect_(callSelect), + callSelectInput_(callSelectInput), + callSelectOutput_(callSelectOutput), + callFilter_(callFilter), + callReset_(callReset), call_review_subgraph_(call_review_subgraph), review_subgraph_(review_subgraph), call_free_(call_free), @@ -175,6 +267,13 @@ class CustomSubgraphProperty: public SubgraphProperty { graph.attrs.clear(); const nnvm::IndexedGraph& indexed_graph = graph.indexed_graph(); + // create map from nnvm::Node to nid + node2id.clear(); + for (unsigned nid = 0; nid < indexed_graph.num_nodes(); nid++) { + nnvm::Node* node = const_cast(indexed_graph[nid].source); + node2id[node] = nid; + } + // set shape attrs for each node in the graph if (g.HasAttr("shape")) { mxnet::ShapeVector shapes = g.GetAttr("shape"); @@ -212,15 +311,8 @@ class CustomSubgraphProperty: public SubgraphProperty { } } - CHECK(supported_ops_ != nullptr) - << "supported_ops_ is null for " << subgraph_prop << std::endl; - CHECK(call_supported_ops_ != nullptr) - << "call_supported_ops_ is null for " << subgraph_prop << std::endl; - - std::string subgraph_json = nnvm::pass::SaveJSON(graph); - std::vector supported_node_IDs(indexed_graph.num_nodes(), 0); - const char* json = subgraph_json.c_str(); - int *ids = supported_node_IDs.data(); + std::string graph_json = nnvm::pass::SaveJSON(graph); + const char* json = graph_json.c_str(); // clear options from previous call opt_keys_.clear(); @@ -236,16 +328,35 @@ class CustomSubgraphProperty: public SubgraphProperty { opt_vals_.push_back(kv.second.c_str()); } - CHECK(call_supported_ops_(supported_ops_, json, supported_node_IDs.size(), ids, - opt_keys_.data(), opt_vals_.data(), opt_keys_.size())) - << "Error calling supported_ops for '" << subgraph_prop << "'"; + // check if supportedOps was registered + if (supported_ops_ && call_supported_ops_) { + // setup array of subgraph IDs for each node + std::vector supported_node_IDs(indexed_graph.num_nodes(), -2); + int *ids = supported_node_IDs.data(); + // call supportedOps + CHECK(call_supported_ops_(supported_ops_, json, supported_node_IDs.size(), ids, + opt_keys_.data(), opt_vals_.data(), opt_keys_.size())) + << "Error calling supported_ops for '" << subgraph_prop << "'"; - const auto& idx = g.indexed_graph(); - // loop and add node names for each supported node ID - for (unsigned i = 0; i < supported_node_IDs.size(); i++) { - if (supported_node_IDs[i]) { - supported_nodes.insert(idx[i].source->attrs.name); + const auto& idx = g.indexed_graph(); + // loop and add node names for each supported node ID + for (unsigned i = 0; i < supported_node_IDs.size(); i++) { + if (supported_node_IDs[i] != -2) { + supported_nodes[idx[i].source->attrs.name] = supported_node_IDs[i]; + } } + } else if (call_create_selector_ && callSelect_ && callSelectInput_ && + callSelectOutput_ && callFilter_ && callReset_ && + create_selector_) { + sel_inst = nullptr; + CHECK(call_create_selector_(create_selector_, json, &sel_inst, + opt_keys_.data(), opt_vals_.data(), opt_keys_.size())) + << "Error calling supported_ops for '" << subgraph_prop << "'"; + } else { + CHECK(supported_ops_ != nullptr) + << "supported_ops_ is null for " << subgraph_prop << std::endl; + CHECK(call_supported_ops_ != nullptr) + << "call_supported_ops_ is null for " << subgraph_prop << std::endl; } } // override CreateSubgraphNode @@ -315,8 +426,8 @@ class CustomSubgraphProperty: public SubgraphProperty { ss << "["; for (unsigned i=0; i < sym.outputs.size(); i++) { const nnvm::NodeEntry& e = sym.outputs[i]; - if (e.node->attrs.dict.count("__shape__") > 0) { - std::string& shape = e.node->attrs.dict["__shape__"]; + if (e.node->attrs.dict.count(MX_STR_SHAPE) > 0) { + std::string& shape = e.node->attrs.dict[MX_STR_SHAPE]; // add this shape to the list ss << getShapeAt(shape, e.index); } @@ -324,7 +435,7 @@ class CustomSubgraphProperty: public SubgraphProperty { ss << ","; } ss << "]"; - n->attrs.dict["__shape__"] = ss.str(); + n->attrs.dict[MX_STR_SHAPE] = ss.str(); } // set dtypes { @@ -332,8 +443,8 @@ class CustomSubgraphProperty: public SubgraphProperty { ss << "["; for (unsigned i=0; i < sym.outputs.size(); i++) { const nnvm::NodeEntry& e = sym.outputs[i]; - if (e.node->attrs.dict.count("__dtype__") > 0) { - std::string& dtype = e.node->attrs.dict["__dtype__"]; + if (e.node->attrs.dict.count(MX_STR_DTYPE) > 0) { + std::string& dtype = e.node->attrs.dict[MX_STR_DTYPE]; // add this dtype to the list ss << getDtypeAt(dtype, e.index); } @@ -341,7 +452,7 @@ class CustomSubgraphProperty: public SubgraphProperty { ss << ","; } ss << "]"; - n->attrs.dict["__dtype__"] = ss.str(); + n->attrs.dict[MX_STR_DTYPE] = ss.str(); } // set user specified attributes for (auto attr : user_attrs) @@ -374,37 +485,46 @@ class CustomSubgraphProperty: public SubgraphProperty { } // pass down other attributes if available - if (orig.node->attrs.dict.count("__dtype__") > 0) { + if (orig.node->attrs.dict.count(MX_STR_DTYPE) > 0) { // get dtype string from other node - std::string& dtype = orig.node->attrs.dict["__dtype__"]; + std::string& dtype = orig.node->attrs.dict[MX_STR_DTYPE]; std::stringstream ss; ss << "[" << getDtypeAt(dtype, orig.index) << "]"; - e->node->attrs.dict["__dtype__"] = ss.str(); + e->node->attrs.dict[MX_STR_DTYPE] = ss.str(); } - if (orig.node->attrs.dict.count("__shape__") > 0) { + if (orig.node->attrs.dict.count(MX_STR_SHAPE) > 0) { // get shape string from other node - std::string& shape = orig.node->attrs.dict["__shape__"]; + std::string& shape = orig.node->attrs.dict[MX_STR_SHAPE]; // create new shape string for this node std::stringstream ss; ss << "[" << getShapeAt(shape, orig.index) << "]"; - e->node->attrs.dict["__shape__"] = ss.str(); + e->node->attrs.dict[MX_STR_SHAPE] = ss.str(); } } } // override CreateSubgraphSelector virtual SubgraphSelectorPtr CreateSubgraphSelector() const { - return std::make_shared(supported_nodes); + return std::make_shared(supported_nodes, + sel_inst, callSelect_, callSelectInput_, callSelectOutput_, + callFilter_, callReset_, call_free_, node2id); } std::string subgraph_prop; partCallSupportedOps_t call_supported_ops_; supportedOps_t supported_ops_; + partCallCreateSelector_t call_create_selector_; + createSelector_t create_selector_; + partCallSelect_t callSelect_; + partCallSelectInput_t callSelectInput_; + partCallSelectOutput_t callSelectOutput_; + partCallFilter_t callFilter_; + partCallReset_t callReset_; partCallReviewSubgraph_t call_review_subgraph_; reviewSubgraph_t review_subgraph_; opCallFree_t call_free_; - std::unordered_set supported_nodes; + std::unordered_map supported_nodes; std::string subgraph_op_name; std::vector> options_map_; std::vector opt_keys_, opt_vals_; @@ -419,6 +539,8 @@ class CustomSubgraphProperty: public SubgraphProperty { std::vector arg_verIDs, aux_verIDs; std::vector arg_dev_type, aux_dev_type; std::vector arg_dev_id, aux_dev_id; + void* sel_inst = nullptr; + std::unordered_map node2id; }; } // namespace op } // namespace mxnet