Dynamic custom operator GPU support (#17270)
* poc gpu customop end to end

* add backward and device id

* clear up customop makefile

* new fcomp register

* new setforward to pass custom context to c_api

* resolve sam comment: add cond register and fix setforward char

* tmp stateful op

* passing ctx of stateful op

* add gpu alloc and refactor all fcomp

* resolve sam comments and refactor alloc

* add gpu check to pass cpu build

* add unittest and resolve ptrend comments

* add cmake and jenkins

* fix windows

* windows gpu cmake build fix

* remove verbose
rondogency authored Jan 31, 2020
1 parent 6214c4d commit a726c40
Showing 11 changed files with 885 additions and 350 deletions.
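In short, this commit lets a dynamically loaded custom operator declare separate CPU and GPU implementations and run on either context. A minimal sketch of the user-facing flow in Python, assuming the example library below has been built as librelu_lib.so (the path and input values are illustrative, not part of this commit):

import mxnet as mx

# load the compiled extension library at runtime
mx.library.load('build/librelu_lib.so')

# operators registered by the library appear under mx.nd;
# with this commit they can run on the GPU context as well
a = mx.nd.array([[-2.0, 3.0], [1.0, -4.0]], ctx=mx.gpu(0))
print(mx.nd.my_relu(a))  # negative entries become zero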
36 changes: 25 additions & 11 deletions CMakeLists.txt
@@ -685,10 +685,6 @@ if(MSVC)

endif()

-add_library(sample_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/gemm_lib.cc)
-add_library(subgraph_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_subgraph/subgraph_lib.cc)
-target_include_directories(sample_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
-target_include_directories(subgraph_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
set(MXNET_INSTALL_TARGETS mxnet)
if(UNIX)
string(APPEND CMAKE_CUDA_FLAGS "${CUDA_ARCH_FLAGS_SPACES}")
@@ -701,15 +697,8 @@ if(UNIX)
target_link_libraries(mxnet PRIVATE ${BEGIN_WHOLE_ARCHIVE} $<TARGET_FILE:mxnet_static> ${END_WHOLE_ARCHIVE})
target_link_libraries(mxnet PRIVATE mxnet_static)
target_link_libraries(mxnet_static PUBLIC ${CMAKE_DL_LIBS})
-target_compile_options(sample_lib PUBLIC -shared)
-target_compile_options(subgraph_lib PUBLIC -shared)
set_target_properties(mxnet_static PROPERTIES OUTPUT_NAME mxnet)
elseif(MSVC)
-target_compile_options(sample_lib PUBLIC /LD)
-target_compile_options(subgraph_lib PUBLIC /LD)
-set_target_properties(sample_lib PROPERTIES PREFIX "lib")
-set_target_properties(subgraph_lib PROPERTIES PREFIX "lib")

if(USE_CUDA)
if(MSVC)
if(USE_SPLIT_ARCH_DLL)
@@ -762,6 +751,31 @@ elseif(MSVC)

endif()

+add_library(customop_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/gemm_lib.cc)
+add_library(subgraph_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_subgraph/subgraph_lib.cc)
+target_include_directories(customop_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
+target_include_directories(subgraph_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
+if (USE_CUDA)
+add_library(customop_gpu_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/relu_lib.cu)
+target_include_directories(customop_gpu_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
+endif()
+if(UNIX)
+target_compile_options(customop_lib PUBLIC -shared)
+target_compile_options(subgraph_lib PUBLIC -shared)
+if (USE_CUDA)
+target_compile_options(customop_gpu_lib PUBLIC -shared)
+endif()
+elseif(MSVC)
+target_compile_options(customop_lib PUBLIC /LD)
+target_compile_options(subgraph_lib PUBLIC /LD)
+set_target_properties(customop_lib PROPERTIES PREFIX "lib")
+set_target_properties(subgraph_lib PROPERTIES PREFIX "lib")
+if (USE_CUDA)
+target_compile_options(customop_gpu_lib PUBLIC "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-fPIC>")
+set_target_properties(customop_gpu_lib PROPERTIES PREFIX "lib")
+endif()
+endif()

if(USE_DIST_KVSTORE)
add_subdirectory("3rdparty/ps-lite")
add_definitions(-DMXNET_USE_DIST_KVSTORE)
20 changes: 13 additions & 7 deletions Makefile
@@ -457,7 +457,7 @@ endif
.PHONY: clean all extra-packages test lint clean_all rcpplint rcppexport roxygen\
cython2 cython3 cython cyclean

-all: lib/libmxnet.a lib/libmxnet.so $(BIN) extra-packages sample_lib subgraph_lib
+all: lib/libmxnet.a lib/libmxnet.so $(BIN) extra-packages extension_libs

SRC = $(wildcard src/*/*/*/*.cc src/*/*/*.cc src/*/*.cc src/*.cc)
OBJ = $(patsubst %.cc, build/%.o, $(SRC))
@@ -664,11 +664,19 @@ cpplint:
pylint:
python3 -m pylint --rcfile=$(ROOTDIR)/ci/other/pylintrc --ignore-patterns=".*\.so$$,.*\.dll$$,.*\.dylib$$" python/mxnet

-# sample lib for MXNet extension dynamically loading custom operator
-sample_lib:
-$(CXX) -shared -fPIC -std=c++11 example/extensions/lib_custom_op/gemm_lib.cc -o libsample_lib.so -I include/mxnet
+# MXNet extension dynamically loading libraries
+EXT_LIBS = custom_op_lib subgraph_lib
+ifeq ($(USE_CUDA), 1)
+EXT_LIBS += custom_op_gpu_lib
+endif
+extension_libs: $(EXT_LIBS)
+
+custom_op_lib:
+$(CXX) -shared -fPIC -std=c++11 example/extensions/lib_custom_op/gemm_lib.cc -o build/libcustomop_lib.so -I include/mxnet
+custom_op_gpu_lib:
+$(NVCC) -shared -std=c++11 -Xcompiler -fPIC example/extensions/lib_custom_op/relu_lib.cu -o build/libcustomop_gpu_lib.so -I include/mxnet
subgraph_lib:
-$(CXX) -shared -fPIC -std=c++11 example/extensions/lib_subgraph/subgraph_lib.cc -o libsubgraph_lib.so -I include/mxnet
+$(CXX) -shared -fPIC -std=c++11 example/extensions/lib_subgraph/subgraph_lib.cc -o build/libsubgraph_lib.so -I include/mxnet

# Cython build
cython:
@@ -734,7 +742,6 @@ clean: rclean cyclean $(EXTRA_PACKAGES_CLEAN)
cd $(NNVM_PATH); $(MAKE) clean; cd -
cd $(TVM_PATH); $(MAKE) clean; cd -
cd $(AMALGAMATION_PATH); $(MAKE) clean; cd -
-$(RM) libsample_lib.so libsubgraph_lib.so
$(RM) -r $(patsubst %, %/*.d, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.d, $(EXTRA_OPERATORS))
$(RM) -r $(patsubst %, %/*.o, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.o, $(EXTRA_OPERATORS))
else
@@ -746,7 +753,6 @@ clean: rclean mkldnn_clean cyclean testclean $(EXTRA_PACKAGES_CLEAN)
cd $(NNVM_PATH); $(MAKE) clean; cd -
cd $(TVM_PATH); $(MAKE) clean; cd -
cd $(AMALGAMATION_PATH); $(MAKE) clean; cd -
-$(RM) libsample_lib.so libsubgraph_lib.so
endif

clean_all: clean
14 changes: 7 additions & 7 deletions ci/jenkins/Jenkins_steps.groovy
@@ -23,23 +23,23 @@
utils = load('ci/Jenkinsfile_utils.groovy')

// mxnet libraries
-mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
-mx_lib_cython = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
+mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
+mx_lib_cython = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'

// Python wheels
mx_pip = 'build/*.whl'

// mxnet cmake libraries, in cmake builds we do not produce a libnvvm static library by default.
mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so'
-mx_cmake_lib_no_tvm_op = 'build/libmxnet.so, build/libmxnet.a, build/libsample_lib.so, build/libsubgraph_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so'
+mx_cmake_lib_no_tvm_op = 'build/libmxnet.so, build/libmxnet.a, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so'
mx_cmake_lib_cython = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
// mxnet cmake libraries, in cmake builds we do not produce a libnvvm static library by default.
-mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/libsample_lib.so, build/libsubgraph_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests'
+mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests'
mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so'
-mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
+mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
mx_tensorrt_lib = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so'
-mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
-mx_lib_cpp_examples_no_tvm_op = 'lib/libmxnet.so, lib/libmxnet.a, libsample_lib.so, libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
+mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
+mx_lib_cpp_examples_no_tvm_op = 'lib/libmxnet.so, lib/libmxnet.a, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
mx_lib_cpp_examples_cpu = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/cpp-package/example/*'

// Python unittest for CPU
7 changes: 5 additions & 2 deletions example/extensions/lib_custom_op/Makefile
@@ -15,10 +15,13 @@
# specific language governing permissions and limitations
# under the License.

-all: gemm_lib
+all: gemm_lib relu_lib

gemm_lib:
g++ -shared -fPIC -std=c++11 gemm_lib.cc -o libgemm_lib.so -I ../../../include/mxnet

+relu_lib:
+nvcc -shared -std=c++11 -Xcompiler -fPIC relu_lib.cu -o librelu_lib.so -I ../../../include/mxnet

clean:
-rm -rf libgemm_lib.so
+rm -rf libgemm_lib.so librelu_lib.so
17 changes: 8 additions & 9 deletions example/extensions/lib_custom_op/gemm_lib.cc
@@ -103,7 +103,7 @@ MXReturnValue backward(std::map<std::string, std::string> attrs,
unsigned m = inputs[2].shape[1];
// allocate temporary workspace memory through resource manager
// for multiple arrays better to request a big memory pool
-void *workspace = res.alloc((k*n + m*k) * sizeof(float));
+void *workspace = res.alloc_cpu((k*n + m*k) * sizeof(float));
float *At = static_cast<float*>(workspace);
float *Bt = static_cast<float*>(workspace) + (k*n);

@@ -167,8 +167,8 @@ MXReturnValue inferShape(std::map<std::string, std::string> attrs,
}

REGISTER_OP(my_gemm)
-.setForward(forward)
-.setBackward(backward)
+.setForward(forward, "cpu")
+.setBackward(backward, "cpu")
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape);
@@ -182,8 +182,7 @@ class MyStatefulGemm : public CustomStatefulOp {
MXReturnValue Forward(std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource op_res) {
-++count;
-std::cout << "Info: keyword + number of forward: " << count << std::endl;
+std::cout << "Info: keyword + number of forward: " << ++count << std::endl;
std::map<std::string, std::string> attrs;
return forward(attrs, inputs, outputs, op_res);
}
@@ -203,9 +202,9 @@

MXReturnValue createOpState(std::map<std::string, std::string> attrs,
CustomStatefulOp** op_inst) {
-int count = 0;
-if (attrs.count("test_kw") > 0)
-count = std::stoi(attrs["test_kw"]);
+// testing passing of keyword arguments
+int count = attrs.count("test_kw") > 0 ? std::stoi(attrs["test_kw"]) : 0;
+// creating stateful operator instance
*op_inst = new MyStatefulGemm(count);
std::cout << "Info: stateful operator created" << std::endl;
return MX_SUCCESS;
@@ -222,7 +221,7 @@ REGISTER_OP(state_gemm)
.setInferType(inferType)
.setInferShape(inferShape)
.setMutateInputs(mutateInputs)
-.setCreateOpState(createOpState);
+.setCreateOpState(createOpState, "cpu");

MXReturnValue initialize(int version) {
if (version >= 10400) {
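As a usage note, a hedged sketch of how the stateful operator above could be exercised from Python once the library is loaded; the test_kw keyword mirrors the attribute parsed in createOpState, though this exact snippet is an illustration and not part of the diff:

import mxnet as mx

mx.library.load('build/libcustomop_lib.so')

a = mx.nd.random.uniform(shape=(2, 3))
b = mx.nd.random.uniform(shape=(3, 2))

print(mx.nd.my_gemm(a, b))                  # stateless GEMM
print(mx.nd.state_gemm(a, b, test_kw=100))  # stateful GEMM; test_kw seeds the forward-call counter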
191 changes: 191 additions & 0 deletions example/extensions/lib_custom_op/relu_lib.cu
@@ -0,0 +1,191 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2020 by Contributors
* \file relu_lib.cu
* \brief simple custom relu operator implemented using CUDA function
*/

#include <iostream>
#include "lib_api.h"

__global__ void relu_gpu_forward(float *out, float *in, int64_t N) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < N)
out[tid] = in[tid] > 0 ? in[tid] : 0;
}

__global__ void relu_gpu_backward(float *ingrad, float *outgrad, float *indata, int64_t N) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < N)
ingrad[tid] = indata[tid] > 0 ? 1 * outgrad[tid] : 0;
}

MXReturnValue forwardCPU(std::map<std::string, std::string> attrs,
std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource res) {
float* in_data = inputs[0].data<float>();
float* out_data = outputs[0].data<float>();
for (int i=0; i<inputs[0].size(); i++) {
out_data[i] = in_data[i] > 0 ? in_data[i] : 0;
}
return MX_SUCCESS;
}

MXReturnValue backwardCPU(std::map<std::string, std::string> attrs,
std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource res) {
float* out_grad = inputs[0].data<float>();
float* in_data = inputs[1].data<float>();
float* in_grad = outputs[0].data<float>();
for (int i=0; i<inputs[1].size(); i++) {
in_grad[i] = in_data[i] > 0 ? 1 * out_grad[i] : 0;
}
return MX_SUCCESS;
}

MXReturnValue forwardGPU(std::map<std::string, std::string> attrs,
std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource res) {
float* in_data = inputs[0].data<float>();
float* out_data = outputs[0].data<float>();

mx_stream_t cuda_stream = res.get_cuda_stream();
int64_t N = inputs[0].size();
int block = 256;
int grid = (N + (block - 1)) / block;
relu_gpu_forward<<<grid,block,0,cuda_stream>>>(out_data, in_data, N);

return MX_SUCCESS;
}

MXReturnValue backwardGPU(std::map<std::string, std::string> attrs,
std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource res) {
float* out_grad = inputs[0].data<float>();
float* in_data = inputs[1].data<float>();
float* in_grad = outputs[0].data<float>();

mx_stream_t cuda_stream = res.get_cuda_stream();
int64_t N = inputs[0].size();
int block = 256;
int grid = (N + (block - 1)) / block;
relu_gpu_backward<<<grid,block,0,cuda_stream>>>(in_grad, out_grad, in_data, N);

return MX_SUCCESS;
}

MXReturnValue parseAttrs(std::map<std::string, std::string> attrs, int* num_in, int* num_out) {
*num_in = 1;
*num_out = 1;
return MX_SUCCESS;
}

MXReturnValue inferType(std::map<std::string, std::string> attrs,
std::vector<int> &intypes,
std::vector<int> &outtypes) {
outtypes[0] = intypes[0];
return MX_SUCCESS;
}

MXReturnValue inferShape(std::map<std::string, std::string> attrs,
std::vector<std::vector<unsigned int>> &inshapes,
std::vector<std::vector<unsigned int>> &outshapes) {
outshapes[0] = inshapes[0];
return MX_SUCCESS;
}

REGISTER_OP(my_relu)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape)
.setForward(forwardCPU, "cpu")
.setForward(forwardGPU, "gpu")
.setBackward(backwardCPU, "cpu")
.setBackward(backwardGPU, "gpu");

class MyStatefulReluCPU : public CustomStatefulOp {
public:
explicit MyStatefulReluCPU() {}
MXReturnValue Forward(std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource op_res) {
std::map<std::string, std::string> attrs;
return forwardCPU(attrs, inputs, outputs, op_res);
}
MXReturnValue Backward(std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource op_res) {
std::map<std::string, std::string> attrs;
return backwardCPU(attrs, inputs, outputs, op_res);
}
~MyStatefulReluCPU() {}
};

class MyStatefulReluGPU : public CustomStatefulOp {
public:
explicit MyStatefulReluGPU() {}
MXReturnValue Forward(std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource op_res) {
std::map<std::string, std::string> attrs;
return forwardGPU(attrs, inputs, outputs, op_res);
}
MXReturnValue Backward(std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource op_res) {
std::map<std::string, std::string> attrs;
return backwardGPU(attrs, inputs, outputs, op_res);
}
~MyStatefulReluGPU() {}
};

MXReturnValue createOpStateCPU(std::map<std::string, std::string> attrs,
CustomStatefulOp** op_inst) {
*op_inst = new MyStatefulReluCPU();
return MX_SUCCESS;
}

MXReturnValue createOpStateGPU(std::map<std::string, std::string> attrs,
CustomStatefulOp** op_inst) {
*op_inst = new MyStatefulReluGPU();
return MX_SUCCESS;
}

REGISTER_OP(my_state_relu)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape)
.setCreateOpState(createOpStateCPU, "cpu")
.setCreateOpState(createOpStateGPU, "gpu");

MXReturnValue initialize(int version) {
if (version >= 10400) {
std::cout << "MXNet version " << version << " supported" << std::endl;
return MX_SUCCESS;
} else {
std::cout << "MXNet version " << version << " not supported" << std::endl;
return MX_FAIL;
}
}
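Since the commit message mentions a new unittest, here is a sketch of the kind of CPU-versus-GPU consistency check that exercises both registered paths, assuming a built librelu_lib.so and an available GPU (paths and values are illustrative):

import mxnet as mx
import numpy as np

mx.library.load('build/librelu_lib.so')

x = np.random.uniform(-1, 1, size=(64, 64)).astype(np.float32)
a_cpu = mx.nd.array(x, ctx=mx.cpu())
a_gpu = mx.nd.array(x, ctx=mx.gpu(0))

# forward results from the "cpu" and "gpu" registrations should agree
out_cpu = mx.nd.my_relu(a_cpu)
out_gpu = mx.nd.my_relu(a_gpu)
assert np.allclose(out_cpu.asnumpy(), out_gpu.asnumpy())

# backward runs through the registered GPU backward function
a_gpu.attach_grad()
with mx.autograd.record():
    out = mx.nd.my_relu(a_gpu)
out.backward()
print(a_gpu.grad)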