dynamic custom operator support (apache#15921)
* fixed example to use absolute path

* added example for custom ops, added support for custom op registration

* added fcompute registration for loaded operators
moved library import order to after ndarray/symbol

* changed dynamic ops to be contrib

* added num in/out

* removed contrib op registration
re-registered ops from mx.nd.op to mx.nd

* added support for infer shape, updated example to call operator

* fixed whitespace

* fixed whitespace

* fixed whitespace

* added temporary support for operator multi-registration

* insanity checked

* update docblocks

* small format fix

* fix unittest with correct library

* implement InferType

* initial support for resource manager, temp space

* fixed formatting

* changed decltype to typedef

* fixed whitespace

* Added windows declaration types, change APIs to return MXReturnValue instead of int

* added library version number, API to get, and check to validate

* Changed CMakeLists to build lib_ops instead of lib_api, updated lib_api example, fixed whitespace

* add prototype of subgraph op

* implement FMutateInput as optional attribute

* fix sanity check

* replace fcompute to fcomputeEx and implement simple finferstoragetype

* changed fcompute to forward

* initial commit with fgradient support

* enabled gradient registration

* fixed whitespace

* prototype of createopstate and fstatefulcompute

* make custom state op interface work

* subgraph forward

* refactor stateful forward and add op resource

* wip gemm backward

* stateful backward and subgraph test

* implement gemm and state gemm, refactor test files

* add body to pure virtual destructor

* subgraph passing from python to custom lib

* rm lib_api c++11 dep, rm warpctc, add rm flag

* fix conflict

* subgraph json parsing utility

* add data size and fix unsigned warnings

* use c++ struct and fix cpplint

* refactor op registry

* fix line length and win array of ci; condense lines

* add mxnet_extension dir

* fixed extension to be dll for windows

* updated examples to use the same format as the example in the top-level Makefile: "lib<name>.so"

* removed destructor for CustomStatefulOp

* fix error in gemm test and clear up subgraph test

* lib path fix

* add unittest for custom op

* update Makefile revolve merge

* fix test and rename folder

* fix makefile rename

* fix cmake rename

* add explicit cpu context

* wkcn feedback: change mxtensor func name. use c++11 flag

* add operator keyword test and refine info print

* using typedef in forward

* small refine of docblock

* change names

* add separate stateful compute and pass state_op ptr

* user example using opresource alloc

* added DLTensor into MXTensor

* fixed whitespace

* added error check when DLTensor does not support MXNet data type

* changed to throw runtime exception

* changed include to stdexcept

* retrigger CI

* empty commit

* empty commit

* remove merge conflict

* add setdltensor for easy use and add docs

* CI

* re-trigger CI

* ci

* ci
samskalicky authored and wkcn committed Dec 6, 2019
1 parent 8dd7051 commit ae472c2
Showing 17 changed files with 2,160 additions and 34 deletions.
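
As context for the diffs that follow, here is a minimal usage sketch (not part of this commit) of how a dynamically loaded custom operator is meant to be called from Python. It is based on the mx.library.load calls and the REGISTER_OP(my_gemm) registration shown in this diff; the library filename and matrix values are illustrative assumptions.

import os
import mxnet as mx

# load the compiled extension library (filename is an assumption,
# matching the libgemm_lib.so target in the new Makefile below)
mx.library.load(os.path.abspath('libgemm_lib.so'))

# my_gemm is registered in gemm_lib.cc and exposed under mx.nd;
# it computes C = A * B for two 2D float32 matrices
a = mx.nd.array([[1, 2, 3], [4, 5, 6]], dtype='float32')       # 2x3
b = mx.nd.array([[7, 8], [9, 10], [11, 12]], dtype='float32')  # 3x2
c = mx.nd.my_gemm(a, b)                                        # 2x2 result
print(c)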
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -699,7 +699,7 @@ else()

endif()

add_library(sample_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/lib_api/mylib.cc)
add_library(sample_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/gemm_lib.cc)
target_include_directories(sample_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
set(MXNET_INSTALL_TARGETS mxnet)
if(UNIX)
10 changes: 6 additions & 4 deletions Makefile
@@ -662,6 +662,10 @@ cpplint:
pylint:
python3 -m pylint --rcfile=$(ROOTDIR)/ci/other/pylintrc --ignore-patterns=".*\.so$$,.*\.dll$$,.*\.dylib$$" python/mxnet

# sample lib for MXNet extension dynamically loading custom operator
sample_lib:
$(CXX) -shared -fPIC -std=c++11 example/extensions/lib_custom_op/gemm_lib.cc -o libsample_lib.so -I include/mxnet

# Cython build
cython:
cd python; $(PYTHON) setup.py build_ext --inplace --with-cython
@@ -705,10 +709,6 @@ rpkgtest:
Rscript -e 'require(testthat);res<-test_dir("R-package/tests/testthat");if(!testthat:::all_passed(res)){stop("Test failures", call. = FALSE)}'
Rscript -e 'res<-covr:::package_coverage("R-package");fileConn<-file(paste("r-package_coverage_",toString(runif(1)),".json"));writeLines(covr:::to_codecov(res), fileConn);close(fileConn)'


sample_lib:
$(CXX) -shared -fPIC example/lib_api/mylib.cc -o libsample_lib.so -I include/mxnet

scalaclean:
(cd $(ROOTDIR)/scala-package && mvn clean)

@@ -760,6 +760,7 @@ clean: rclean cyclean $(EXTRA_PACKAGES_CLEAN)
cd $(NNVM_PATH); $(MAKE) clean; cd -
cd $(TVM_PATH); $(MAKE) clean; cd -
cd $(AMALGAMATION_PATH); $(MAKE) clean; cd -
$(RM) libsample_lib.so
$(RM) -r $(patsubst %, %/*.d, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.d, $(EXTRA_OPERATORS))
$(RM) -r $(patsubst %, %/*.o, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.o, $(EXTRA_OPERATORS))
else
@@ -771,6 +772,7 @@ clean: rclean mkldnn_clean cyclean testclean $(EXTRA_PACKAGES_CLEAN)
cd $(NNVM_PATH); $(MAKE) clean; cd -
cd $(TVM_PATH); $(MAKE) clean; cd -
cd $(AMALGAMATION_PATH); $(MAKE) clean; cd -
$(RM) libsample_lib.so
endif

clean_all: clean
@@ -16,16 +16,16 @@
# under the License.

all:
g++ -shared -fPIC mylib.cc -o mylib.so -I ../../include/mxnet
g++ -std=c++11 -shared -fPIC init_lib.cc -o libinit_lib.so -I ../../../include/mxnet

test:
g++ -std=c++11 -O3 -o libtest libtest.cc -ldl -I ../../include/mxnet
g++ -std=c++11 -O3 -o libtest libtest.cc -ldl -I ../../../include/mxnet

windows:
cl /LD mylib.cc
cl /LD init_lib.cc

win_test:
cl libtest.cc

clean:
rm -rf mylib.so libtest
rm -rf *.so libtest
@@ -19,19 +19,19 @@

/*!
* Copyright (c) 2015 by Contributors
* \file mylib.cc
* \file init_lib.cc
* \brief Sample library file
*/

#include <iostream>
#include "lib_api.h"

int initialize(int version) {
MXReturnValue initialize(int version) {
if (version >= 10400) {
std::cout << "MXNet version " << version << " supported" << std::endl;
return 1;
return MX_SUCCESS;
} else {
std::cout << "MXNet version " << version << " not supported" << std::endl;
return 0;
return MX_FAIL;
}
}
@@ -40,10 +40,10 @@ int main(void) {
// Get a handle to the library.
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
HINSTANCE handle;
handle = LoadLibrary(TEXT("mylib.dll"));
handle = LoadLibrary(TEXT("libinit_lib.dll"));
#else
void *handle;
handle = dlopen("mylib.so", RTLD_LAZY);
handle = dlopen("libinit_lib.so", RTLD_LAZY);
#endif

if (!handle) {
@@ -26,6 +26,8 @@
import os

if (os.name=='posix'):
mx.library.load('mylib.so')
path = os.path.abspath('libinit_lib.so')
mx.library.load(path)
elif (os.name=='nt'):
mx.library.load('mylib.dll')
path = os.path.abspath('libinit_lib.dll')
mx.library.load(path)
27 changes: 27 additions & 0 deletions example/extensions/lib_custom_op/Makefile
@@ -0,0 +1,27 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

all: subgraph_lib gemm_lib

gemm_lib:
g++ -shared -fPIC -std=c++11 gemm_lib.cc -o libgemm_lib.so -I ../../../include/mxnet

subgraph_lib:
g++ -shared -fPIC -std=c++11 subgraph_lib.cc -o libsubgraph_lib.so -I ../../../include/mxnet

clean:
rm -rf libsubgraph_lib.so libgemm_lib.so
235 changes: 235 additions & 0 deletions example/extensions/lib_custom_op/gemm_lib.cc
@@ -0,0 +1,235 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file gemm_lib.cc
* \brief Sample 2D gemm custom operator implementation library file
*/

#include <iostream>
#include "lib_api.h"

// main matrix multiplication routine
void gemm(const float* A, const float* B, float* C,
const unsigned n, const unsigned k, const unsigned m) {
unsigned i, j, kk;
for (i = 0; i < n; i++) {
for (j = 0; j < m; j++) {
C[i*m+j] = 0;
for (kk = 0; kk < k; kk++) {
C[i*m+j] += A[i*k+kk] * B[kk*m+j];
}
}
}
}

void transpose(const float* A, float* At, const unsigned n, const unsigned m) {
unsigned i, j;
for (i = 0; i < n; i++) {
for (j = 0; j < m; j++) {
At[i*m+j] = A[j*n+i];
}
}
}

/*
* Executes C = A * B
* inputs[0] = A; inputs[1] = B; outputs[0] = C
*/
MXReturnValue forward(std::map<std::string, std::string> attrs,
std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource res) {
// simple example of using runtime data type
if (inputs[0].dtype == kFloat32) {
typedef float DType;
// extract data pointers from tensors
// if using dltensor repr, below lines can be changed to something like
// DType* A = reinterpret_cast<DType*>(inputs[0].dltensor.data);
DType* A = inputs[0].data<DType>();
DType* B = inputs[1].data<DType>();
DType* C = outputs[0].data<DType>();
// set tensor shapes
unsigned n = inputs[0].shape[0];
unsigned k = inputs[0].shape[1];
unsigned m = inputs[1].shape[1];

gemm(A, B, C, n, k, m);
}
return MX_SUCCESS;
}

/*
* Executes dA = dC * B.T; Executes dB = A.T * dC
***** gradient inputs
* inputs[0] = dC
***** original inputs
* inputs[1] = A; inputs[2] = B
***** original outputs
* inputs[3] = C
***** gradient outputs
* outputs[0] = dA; outputs[1] = dB
*/
MXReturnValue backward(std::map<std::string, std::string> attrs,
std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource res) {
// extract data pointers from tensors
float* dC = inputs[0].data<float>();
float* A = inputs[1].data<float>();
float* B = inputs[2].data<float>();
float* dA = outputs[0].data<float>();
float* dB = outputs[1].data<float>();
// set tensor shapes
unsigned n = inputs[1].shape[0];
unsigned k = inputs[1].shape[1];
unsigned m = inputs[2].shape[1];
// allocate temporary workspace memory through resource manager
// for multiple arrays better to request a big memory pool
void *workspace = res.alloc((k*n + m*k) * sizeof(float));
float *At = static_cast<float*>(workspace);
float *Bt = static_cast<float*>(workspace) + (k*n);

transpose(A, At, k, n);
transpose(B, Bt, m, k);
gemm(dC, Bt, dA, n, m, k);
gemm(At, dC, dB, k, n, m);

return MX_SUCCESS;
}

MXReturnValue parseAttrs(std::map<std::string, std::string> attrs, int* num_in, int* num_out) {
*num_in = 2;
*num_out = 1;
return MX_SUCCESS;
}

MXReturnValue inferType(std::map<std::string, std::string> attrs,
std::vector<int> &intypes,
std::vector<int> &outtypes) {
// validate inputs
if (intypes.size() != 2) {
std::cout << "Expected 2 inputs to inferType" << std::endl;
return MX_FAIL;
}
for (unsigned i = 0; i < intypes.size(); i++) {
if (intypes[i] != kFloat32) {
std::cout << "Expected input " << i << " to have float32 type" << std::endl;
return MX_FAIL;
}
}

outtypes[0] = intypes[0];
return MX_SUCCESS;
}

MXReturnValue inferShape(std::map<std::string, std::string> attrs,
std::vector<std::vector<unsigned int>> &inshapes,
std::vector<std::vector<unsigned int>> &outshapes) {
// validate inputs
if (inshapes.size() != 2) {
std::cout << "Expected 2 inputs to inferShape" << std::endl;
return MX_FAIL;
}
if (inshapes[0].size() != 2 || inshapes[1].size() != 2) {
std::cout << "Expected 2D matrices for both inputs to inferShape" << std::endl;
return MX_FAIL;
}

unsigned n = inshapes[0][0];
unsigned k = inshapes[0][1];
unsigned kk = inshapes[1][0];
unsigned m = inshapes[1][1];
if (k != kk) {
std::cout << "Exected first input axis 1 equals to second input axis 0" << std::endl;
return MX_FAIL;
}

outshapes[0] = {n, m};
return MX_SUCCESS;
}

REGISTER_OP(my_gemm)
.setForward(forward)
.setBackward(backward)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape);

/* ------------------------------------------------------------------------- */

class MyStatefulGemm : public CustomStatefulOp {
public:
explicit MyStatefulGemm(int count) : count(count) {}

MXReturnValue Forward(std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource op_res) {
++count;
std::cout << "Info: keyword + number of forward: " << count << std::endl;
std::map<std::string, std::string> attrs;
return forward(attrs, inputs, outputs, op_res);
}

MXReturnValue Backward(std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource op_res) {
std::map<std::string, std::string> attrs;
return backward(attrs, inputs, outputs, op_res);
}

~MyStatefulGemm() {}

private:
int count;
};

MXReturnValue createOpState(std::map<std::string, std::string> attrs,
CustomStatefulOp** op_inst) {
int count = 0;
if (attrs.count("test_kw") > 0)
count = std::stoi(attrs["test_kw"]);
*op_inst = new MyStatefulGemm(count);
std::cout << "Info: stateful operator created" << std::endl;
return MX_SUCCESS;
}

MXReturnValue mutateInputs(std::map<std::string, std::string> attrs,
std::vector<int> &input_indices) {
// input_indices.push_back(1); // mark mutate input
return MX_SUCCESS;
}

REGISTER_OP(state_gemm)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape)
.setMutateInputs(mutateInputs)
.setCreateOpState(createOpState);

MXReturnValue initialize(int version) {
if (version >= 10400) {
std::cout << "MXNet version " << version << " supported" << std::endl;
return MX_SUCCESS;
} else {
std::cout << "MXNet version " << version << " not supported" << std::endl;
return MX_FAIL;
}
}
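
A similar hedged sketch (not part of the diff) for the stateful operator registered above: createOpState reads the test_kw attribute to seed MyStatefulGemm's forward-call counter, so the keyword can be passed at call time. The exact calling convention is an assumption based on this registration.

import os
import mxnet as mx

# assumed library name; see the lib_custom_op Makefile above
mx.library.load(os.path.abspath('libgemm_lib.so'))

a = mx.nd.ones((2, 3), dtype='float32')
b = mx.nd.ones((3, 2), dtype='float32')
# test_kw travels through the attrs map into createOpState,
# which uses it to initialize the forward-call counter
c = mx.nd.state_gemm(a, b, test_kw=100)
print(c)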
