Improve createOpState for custom ops (#19103)
* initial commit

* fixed whitespace

* update test_gemm.py to gluon

* updated relu for gluon

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
3 people committed Sep 11, 2020
1 parent 95e1814 commit e5a7814
Showing 10 changed files with 158 additions and 51 deletions.
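At a glance, the commit extends the stateful-op creation callback so that custom libraries receive the device context, the input shapes, and the input types at creation time. A minimal before/after sketch of the callback, using the names that appear in the diffs below:

// before this commit: only the attributes and the output instance pointer
MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
                            CustomStatefulOp** op_inst);

// after this commit: the target context, per-input shapes, and per-input types
MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
                            const MXContext& ctx,
                            const std::vector<std::vector<unsigned int> >& in_shapes,
                            const std::vector<int> in_types,
                            CustomStatefulOp** op_inst);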
3 changes: 3 additions & 0 deletions example/extensions/lib_custom_op/gemm_lib.cc
@@ -204,6 +204,9 @@ class MyStatefulGemm : public CustomStatefulOp {
};

MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
                            const MXContext& ctx,
                            const std::vector<std::vector<unsigned int> >& in_shapes,
                            const std::vector<int> in_types,
                            CustomStatefulOp** op_inst) {
  // testing passing of keyword arguments
  int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0;
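With the extra arguments available, a library can validate or specialize its state when it is created. A minimal sketch (the shape check and the MyStatefulGemm constructor argument are illustrative assumptions, not part of this diff):

MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
                            const MXContext& ctx,
                            const std::vector<std::vector<unsigned int> >& in_shapes,
                            const std::vector<int> in_types,
                            CustomStatefulOp** op_inst) {
  // hypothetical validation: gemm expects exactly two 2-D inputs
  if (in_shapes.size() != 2 || in_shapes[0].size() != 2 || in_shapes[1].size() != 2)
    return MX_FAIL;
  // ctx carries the device ("cpu" or "gpu") the state will run on,
  // so a device-specific state object could be chosen here
  int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0;
  *op_inst = new MyStatefulGemm(count);
  return MX_SUCCESS;
}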
6 changes: 6 additions & 0 deletions example/extensions/lib_custom_op/relu_lib.cc
@@ -116,12 +116,18 @@ MXReturnValue MyStatefulReluGPU::Backward(std::vector<MXTensor>* inputs,


MXReturnValue createOpStateCPU(const std::unordered_map<std::string, std::string>& attrs,
                               const MXContext& ctx,
                               const std::vector<std::vector<unsigned int> >& in_shapes,
                               const std::vector<int> in_types,
                               CustomStatefulOp** op_inst) {
  *op_inst = new MyStatefulReluCPU(attrs);
  return MX_SUCCESS;
}

MXReturnValue createOpStateGPU(const std::unordered_map<std::string, std::string>& attrs,
                               const MXContext& ctx,
                               const std::vector<std::vector<unsigned int> >& in_shapes,
                               const std::vector<int> in_types,
                               CustomStatefulOp** op_inst) {
  *op_inst = new MyStatefulReluGPU(attrs);
  return MX_SUCCESS;
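Each creator is then registered for its device. A sketch of how relu_lib.cc wires these up (the REGISTER_OP chaining shown here is the pattern used by the custom-op examples, reproduced from memory rather than from this diff):

REGISTER_OP(my_relu)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape)
.setForward(forwardCPU, "cpu")
.setForward(forwardGPU, "gpu")
.setBackward(backwardCPU, "cpu")
.setBackward(backwardGPU, "gpu")
.setCreateOpState(createOpStateCPU, "cpu")
.setCreateOpState(createOpStateGPU, "gpu");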
79 changes: 51 additions & 28 deletions example/extensions/lib_custom_op/test_gemm.py
@@ -47,36 +47,59 @@
t = mx.sym.Variable('t')
c = mx.sym.my_gemm(s,t)
d = mx.sym.state_gemm(s,t,test_kw=200)
e = mx.sym.linalg.gemm2(s,t)

in_grad = [mx.nd.empty((2,3)),mx.nd.empty((3,1))]
in_grad2 = [mx.nd.empty((2,3)),mx.nd.empty((3,1))]

exe = c.bind(ctx=mx.cpu(),args={'s':a,'t':b},args_grad=in_grad)
exe2 = d.bind(ctx=mx.cpu(),args={'s':a,'t':b},args_grad=in_grad2)
out_grad = mx.nd.ones((2,1))

out = exe.forward()
print(out)
print("-------")
# stateless
block = mx.gluon.nn.SymbolBlock(c,[s,t])
with mx.autograd.record():
    a_ = mx.nd.array([[1,2,3],[4,5,6]])
    b_ = mx.nd.array([[7],[8],[9]])
    a_.attach_grad()
    b_.attach_grad()
    # forward
    out = block(a_,b_)
print(out)
print('+++++')
# backward
out.backward(out_grad)
print(a_.grad)
print(b_.grad)
print("-------")

out2 = exe2.forward()
out2 = exe2.forward()
# stateful
block2 = mx.gluon.nn.SymbolBlock(d,[s,t])
block2.hybridize(static_alloc=True, static_shape=True)
out2 = block2(a,b)
out2 = block2(a,b)
print(out2)
print("-------")
with mx.autograd.record():
    a_ = mx.nd.array([[1,2,3],[4,5,6]])
    b_ = mx.nd.array([[7],[8],[9]])
    a_.attach_grad()
    b_.attach_grad()
    # forward
    out2 = block2(a_,b_)
print('+++++')
# backward
out2.backward(out_grad)
print(a_.grad)
print(b_.grad)
print("-------")

# baseline forward
e = mx.sym.linalg.gemm2(s,t)
in_grad3 = [mx.nd.empty((2,3)),mx.nd.empty((3,1))]
exe3 = e.bind(ctx=mx.cpu(),args={'s':a,'t':b},args_grad=in_grad3)
out3 = exe3.forward()
print(out3)

print("--------start backward compute--------")
out_grad = mx.nd.ones((2,1))
exe.backward([out_grad])
print(in_grad)
print("-------")
exe2.backward([out_grad])
print(in_grad2)
print("-------")
exe3.backward([out_grad])
print(in_grad3)
# baseline
block3 = mx.gluon.nn.SymbolBlock(e,[s,t])
with mx.autograd.record():
    a_ = mx.nd.array([[1,2,3],[4,5,6]])
    b_ = mx.nd.array([[7],[8],[9]])
    a_.attach_grad()
    b_.attach_grad()
    # forward
    out3 = block3(a_,b_)
print(out3)
print('+++++')
# backward
out3.backward(out_grad)
print(a_.grad)
print(b_.grad)
46 changes: 32 additions & 14 deletions example/extensions/lib_custom_op/test_relu.py
@@ -46,21 +46,39 @@
d = mx.sym.Variable('d')
e = mx.sym.my_relu(c)
base = mx.sym.relu(d)
in_grad = [mx.nd.empty((2,2), ctx=mx.gpu())]
in_grad_base = [mx.nd.empty((2,2), ctx=mx.gpu())]
exe = e.bind(ctx=mx.gpu(), args={'c':b}, args_grad=in_grad)
exe_base = base.bind(ctx=mx.gpu(), args={'d':b}, args_grad=in_grad_base)
out = exe.forward()
out_base = exe_base.forward()
print(out)
print(out_base)

print("--------backward compute--------")
#in_grad = [mx.nd.empty((2,2), ctx=mx.gpu())]
#in_grad_base = [mx.nd.empty((2,2), ctx=mx.gpu())]
out_grad = mx.nd.ones((2,2), ctx=mx.gpu())
exe.backward([out_grad])
exe_base.backward([out_grad])
print(in_grad)
print(in_grad_base)
#exe = e.bind(ctx=mx.gpu(), args={'c':b}, args_grad=in_grad)
block = mx.gluon.nn.SymbolBlock(e,[c])
#exe_base = base.bind(ctx=mx.gpu(), args={'d':b}, args_grad=in_grad_base)
block_base = mx.gluon.nn.SymbolBlock(base,[d])

# base
with mx.autograd.record():
    b_ = mx.nd.array([[-2,-1],[1,2]], ctx=mx.gpu())
    b_.attach_grad()
    # forward
    out_base = block_base(b_)
print(out_base)
print('+++++')
# backward
out_base.backward(out_grad)
print(b_.grad)
print("-------")

# custom relu
with mx.autograd.record():
    b_ = mx.nd.array([[-2,-1],[1,2]], ctx=mx.gpu())
    b_.attach_grad()
    # forward
    out = block(b_)
print(out)
print('+++++')
# backward
out.backward(out_grad)
print(b_.grad)
print("-------")

print("--------test ndarray with size of 1 million---------")
b = mx.nd.uniform(shape=(100,100,100), ctx=mx.gpu())
3 changes: 3 additions & 0 deletions example/extensions/lib_custom_op/transposecsr_lib.cc
@@ -176,6 +176,9 @@ class MyStatefulTransposeCSR : public CustomStatefulOp {
};

MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
                            const MXContext& ctx,
                            const std::vector<std::vector<unsigned int> >& in_shapes,
                            const std::vector<int> in_types,
                            CustomStatefulOp** op_inst) {
  // testing passing of keyword arguments
  int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0;
3 changes: 3 additions & 0 deletions example/extensions/lib_custom_op/transposerowsp_lib.cc
@@ -178,6 +178,9 @@ class MyStatefulTransposeRowSP : public CustomStatefulOp {
};

MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
                            const MXContext& ctx,
                            const std::vector<std::vector<unsigned int> >& in_shapes,
                            const std::vector<int> in_types,
                            CustomStatefulOp** op_inst) {
  // testing passing of keyword arguments
  int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0;
3 changes: 3 additions & 0 deletions example/extensions/lib_subgraph/subgraph_lib.cc
@@ -156,6 +156,9 @@ class MyStatefulOp : public CustomStatefulOp {
};

MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
                            const MXContext& ctx,
                            const std::vector<std::vector<unsigned int> >& in_shapes,
                            const std::vector<int> in_types,
                            CustomStatefulOp** op_inst) {
  std::string serialized_subgraph = "[empty]";
  // MXNet subgraph is stored as Symbol in operator node attrs subgraphs field
13 changes: 9 additions & 4 deletions include/mxnet/lib_api.h
@@ -732,6 +732,9 @@ typedef MXReturnValue (*mutateInputs_t)(const std::unordered_map<std::string,
                                        std::vector<int>* input_indices);
typedef MXReturnValue (*createOpState_t)(const std::unordered_map<std::string,
                                         std::string>& attributes,
                                         const MXContext& ctx,
                                         const std::vector<std::vector<unsigned int> >& in_shapes,
                                         const std::vector<int> in_types,
                                         CustomStatefulOp**);

@@ -1000,8 +1003,9 @@ typedef int (*opCallMutateInputs_t)(mutateInputs_t mutate, const char* const* keys,

#define MXLIB_OPCALLCREATEOPSTATE_STR "_opCallCreateOpState"
typedef int (*opCallCreateOpState_t)(createOpState_t create_op, const char* const* keys,
                                     const char* const* vals, int num,
                                     void** state_op);
                                     const char* const* vals, int num, const char* dev_type,
                                     int dev_id, unsigned int** inshapes, int* indims,
                                     int num_in, const int* intypes, void** state_op);

#define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute"
typedef int (*opCallFStatefulComp_t)(int is_forward, void* state_op,
@@ -1190,8 +1194,9 @@ extern "C" {

/*! \brief returns status of calling createStatefulOp function for operator from library */
MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const char* const* keys,
                                const char* const* vals, int num,
                                void** state_op);
                                const char* const* vals, int num, const char* dev_type,
                                int dev_id, unsigned int** inshapes, int* indims,
                                int num_in, const int* intypes, void** state_op);

/*! \brief returns status of calling Stateful Forward/Backward for operator from library */
MX_INT_RET _opCallFStatefulCompute(int is_forward, void* state_op, const int64_t** inshapes,
30 changes: 28 additions & 2 deletions src/c_api/c_api.cc
@@ -1093,6 +1093,28 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
  attr_vals.push_back(kv.second.c_str());
}

// string repr of supported context for custom library, currently only "cpu" and "gpu"
const char* ctx_str = ctx.dev_mask() == Context::kCPU ? "cpu" : "gpu";

std::vector<uint32_t*> inshapes(in_shapes.size());
std::vector<int> indims(in_shapes.size());

// determine amount of memory needed to store all the input shapes
size_t buff_size = 0;
for (size_t i = 0; i < in_shapes.size(); ++i)
  buff_size += in_shapes[i].ndim();

// copy input shapes to raw memory layout
std::vector<uint32_t> inbuff(buff_size);
uint32_t *ptr = inbuff.data();
for (size_t i = 0; i < in_shapes.size(); ++i) {
  inshapes[i] = ptr;
  indims[i] = in_shapes[i].ndim();
  for (int j = 0; j < in_shapes[i].ndim(); ++j, ++ptr) {
    *ptr = static_cast<uint32_t>(in_shapes[i][j]);
  }
}

// convert subgraph symbol from node attributes to char*
std::string subgraph_json;
if (!attrs.subgraphs.empty()) {
@@ -1111,15 +1133,19 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
CHECK(createop_map.count("cpu") > 0)
<< "CPU CreateOpState not implemented for '" << name_str << "'";
int retval = callCreateOpState(createop_map.at("cpu"), attr_keys.data(), attr_vals.data(),
attr_keys.size(), &state_op_inst);
attr_keys.size(), ctx_str, ctx.real_dev_id(),
inshapes.data(), indims.data(),
in_shapes.size(), in_types.data(), &state_op_inst);
std::string msgs = getExtensionMsgs(msgSize, msgGet);
CHECK(retval) << "Error calling CreateOpState CPU for custom operator '" << name_str << "'"
<< msgs;
} else if (ctx.dev_mask() == Context::kGPU) {
CHECK(createop_map.count("gpu") > 0)
<< "GPU CreateOpState not implemented for '" << name_str << "'";
int retval = callCreateOpState(createop_map.at("gpu"), attr_keys.data(), attr_vals.data(),
attr_keys.size(), &state_op_inst);
attr_keys.size(), ctx_str, ctx.real_dev_id(),
inshapes.data(), indims.data(),
in_shapes.size(), in_types.data(), &state_op_inst);
std::string msgs = getExtensionMsgs(msgSize, msgGet);
CHECK(retval) << "Error calling CreateOpState GPU for custom operator '" << name_str << "'"
<< msgs;
23 changes: 20 additions & 3 deletions src/lib_api.cc
@@ -1204,19 +1204,36 @@ MX_INT_RET _opCallMutateInputs(mxnet::ext::mutateInputs_t mutate, const char* co

/*! \brief returns status of calling createStatefulOp function for operator from library */
MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const char* const* keys,
                                const char* const* vals, int num,
                                void** state_op) {
                                const char* const* vals, int num, const char* dev_type,
                                int dev_id, unsigned int** inshapes, int* indims,
                                int num_in, const int* intypes, void** state_op) {
  // create map of attributes from list
  std::unordered_map<std::string, std::string> attrs;
  for (int i = 0; i < num; i++) {
    attrs[std::string(keys[i])] = std::string(vals[i]);
  }

  mxnet::ext::MXContext ctx(dev_type, dev_id);

  // create a vector of shapes for inputs
  std::vector<std::vector<unsigned int> > in_shapes(num_in);
  for (int i = 0; i < num_in; i++) {
    for (int j = 0; j < indims[i]; j++) {
      in_shapes[i].push_back(inshapes[i][j]);
    }
  }

  // create a vector of types for inputs
  std::vector<int> in_types(num_in);
  for (int i = 0; i < num_in; i++) {
    in_types[i] = intypes[i];
  }

  // void pointer to hold custom state op instance created in custom library
  // eventually state_op pointer is populated by instance from custom library
  mxnet::ext::CustomStatefulOp** op_ptr =
      reinterpret_cast<mxnet::ext::CustomStatefulOp**>(state_op);
  return create_op(attrs, op_ptr);
  return create_op(attrs, ctx, in_shapes, in_types, op_ptr);
}

/*! \brief returns status of calling Stateful Forward/Backward for operator from library */
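Taken together, c_api.cc and lib_api.cc implement a small C-ABI marshaling scheme: MXNet flattens the shape vectors into one contiguous buffer plus per-input pointers and dimension counts, and the library side rebuilds std::vector objects from that layout. A self-contained sketch of the round trip (standalone illustration, independent of MXNet):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // shapes as MXNet sees them: a 2x3 input and a 3x1 input
  std::vector<std::vector<uint32_t> > shapes = {{2, 3}, {3, 1}};

  // sending side (mirrors registerOperators in c_api.cc): flatten into one buffer
  std::size_t buff_size = 0;
  for (const auto& s : shapes) buff_size += s.size();
  std::vector<uint32_t> inbuff(buff_size);
  std::vector<uint32_t*> inshapes(shapes.size());
  std::vector<int> indims(shapes.size());
  uint32_t* ptr = inbuff.data();
  for (std::size_t i = 0; i < shapes.size(); ++i) {
    inshapes[i] = ptr;  // each entry points into the flat buffer
    indims[i] = static_cast<int>(shapes[i].size());
    for (uint32_t d : shapes[i]) *ptr++ = d;
  }

  // receiving side (mirrors _opCallCreateOpState in lib_api.cc): rebuild the vectors
  std::vector<std::vector<uint32_t> > rebuilt(shapes.size());
  for (std::size_t i = 0; i < shapes.size(); ++i)
    rebuilt[i].assign(inshapes[i], inshapes[i] + indims[i]);

  assert(rebuilt == shapes);  // the round trip is lossless
  return 0;
}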
