This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Improve createOpState for custom ops #19103

Merged · 4 commits · Sep 11, 2020
3 changes: 3 additions & 0 deletions example/extensions/lib_custom_op/gemm_lib.cc
@@ -204,6 +204,9 @@ class MyStatefulGemm : public CustomStatefulOp {
};

MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
                            const MXContext& ctx,
                            const std::vector<std::vector<unsigned int> >& in_shapes,
                            const std::vector<int> in_types,
                            CustomStatefulOp** op_inst) {
  // testing passing of keyword arguments
  int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0;
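For context, createOpState now receives the resolved device context plus per-input shapes and types alongside the attribute map. A minimal sketch of how a library author might use the new arguments, assuming the MyStatefulGemm(count) constructor from the rest of this file (the shape check is illustrative only, not part of the PR):

MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
                            const MXContext& ctx,
                            const std::vector<std::vector<unsigned int> >& in_shapes,
                            const std::vector<int> in_types,
                            CustomStatefulOp** op_inst) {
  // the creator can now validate inputs up front: gemm expects two 2-D inputs
  if (in_shapes.size() != 2 || in_shapes[0].size() != 2 || in_shapes[1].size() != 2)
    return MX_FAIL;
  // in_types carries one dtype enum per input; ctx identifies the target device
  int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0;
  *op_inst = new MyStatefulGemm(count);
  return MX_SUCCESS;
}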
6 changes: 6 additions & 0 deletions example/extensions/lib_custom_op/relu_lib.cc
@@ -116,12 +116,18 @@ MXReturnValue MyStatefulReluGPU::Backward(std::vector<MXTensor>* inputs,


MXReturnValue createOpStateCPU(const std::unordered_map<std::string, std::string>& attrs,
                               const MXContext& ctx,
                               const std::vector<std::vector<unsigned int> >& in_shapes,
                               const std::vector<int> in_types,
                               CustomStatefulOp** op_inst) {
  *op_inst = new MyStatefulReluCPU(attrs);
  return MX_SUCCESS;
}

MXReturnValue createOpStateGPU(const std::unordered_map<std::string, std::string>& attrs,
                               const MXContext& ctx,
                               const std::vector<std::vector<unsigned int> >& in_shapes,
                               const std::vector<int> in_types,
                               CustomStatefulOp** op_inst) {
  *op_inst = new MyStatefulReluGPU(attrs);
  return MX_SUCCESS;
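This example registers separate CPU and GPU creators, but the new ctx argument would also allow a single creator that dispatches on the device type. A hypothetical sketch, assuming MXContext exposes its dev_type string as defined in lib_api.h:

MXReturnValue createOpStateAny(const std::unordered_map<std::string, std::string>& attrs,
                               const MXContext& ctx,
                               const std::vector<std::vector<unsigned int> >& in_shapes,
                               const std::vector<int> in_types,
                               CustomStatefulOp** op_inst) {
  // pick the state implementation matching the device the backend resolved
  if (ctx.dev_type == "cpu")
    *op_inst = new MyStatefulReluCPU(attrs);
  else if (ctx.dev_type == "gpu")
    *op_inst = new MyStatefulReluGPU(attrs);
  else
    return MX_FAIL;
  return MX_SUCCESS;
}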
79 changes: 51 additions & 28 deletions example/extensions/lib_custom_op/test_gemm.py
@@ -47,36 +47,59 @@
t = mx.sym.Variable('t')
c = mx.sym.my_gemm(s,t)
d = mx.sym.state_gemm(s,t,test_kw=200)
e = mx.sym.linalg.gemm2(s,t)

in_grad = [mx.nd.empty((2,3)),mx.nd.empty((3,1))]
in_grad2 = [mx.nd.empty((2,3)),mx.nd.empty((3,1))]

exe = c.bind(ctx=mx.cpu(),args={'s':a,'t':b},args_grad=in_grad)
exe2 = d.bind(ctx=mx.cpu(),args={'s':a,'t':b},args_grad=in_grad2)
out_grad = mx.nd.ones((2,1))

out = exe.forward()
print(out)
print("-------")
# stateless
block = mx.gluon.nn.SymbolBlock(c,[s,t])
with mx.autograd.record():
    a_ = mx.nd.array([[1,2,3],[4,5,6]])
    b_ = mx.nd.array([[7],[8],[9]])
    a_.attach_grad()
    b_.attach_grad()
    # forward
    out = block(a_,b_)
print(out)
print('+++++')
# backward
out.backward(out_grad)
print(a_.grad)
print(b_.grad)
print("-------")

out2 = exe2.forward()
out2 = exe2.forward()
# stateful
block2 = mx.gluon.nn.SymbolBlock(d,[s,t])
block2.hybridize(static_alloc=True, static_shape=True)
out2 = block2(a,b)
out2 = block2(a,b)
print(out2)
print("-------")
with mx.autograd.record():
    a_ = mx.nd.array([[1,2,3],[4,5,6]])
    b_ = mx.nd.array([[7],[8],[9]])
    a_.attach_grad()
    b_.attach_grad()
    # forward
    out2 = block2(a_,b_)
print('+++++')
# backward
out2.backward(out_grad)
print(a_.grad)
print(b_.grad)
print("-------")

# baseline forward
e = mx.sym.linalg.gemm2(s,t)
in_grad3 = [mx.nd.empty((2,3)),mx.nd.empty((3,1))]
exe3 = e.bind(ctx=mx.cpu(),args={'s':a,'t':b},args_grad=in_grad3)
out3 = exe3.forward()
print(out3)

print("--------start backward compute--------")
out_grad = mx.nd.ones((2,1))
exe.backward([out_grad])
print(in_grad)
print("-------")
exe2.backward([out_grad])
print(in_grad2)
print("-------")
exe3.backward([out_grad])
print(in_grad3)
# baseline
block3 = mx.gluon.nn.SymbolBlock(e,[s,t])
with mx.autograd.record():
    a_ = mx.nd.array([[1,2,3],[4,5,6]])
    b_ = mx.nd.array([[7],[8],[9]])
    a_.attach_grad()
    b_.attach_grad()
    # forward
    out3 = block3(a_,b_)
print(out3)
print('+++++')
# backward
out3.backward(out_grad)
print(a_.grad)
print(b_.grad)
46 changes: 32 additions & 14 deletions example/extensions/lib_custom_op/test_relu.py
@@ -46,21 +46,39 @@
d = mx.sym.Variable('d')
e = mx.sym.my_relu(c)
base = mx.sym.relu(d)
in_grad = [mx.nd.empty((2,2), ctx=mx.gpu())]
in_grad_base = [mx.nd.empty((2,2), ctx=mx.gpu())]
exe = e.bind(ctx=mx.gpu(), args={'c':b}, args_grad=in_grad)
exe_base = base.bind(ctx=mx.gpu(), args={'d':b}, args_grad=in_grad_base)
out = exe.forward()
out_base = exe_base.forward()
print(out)
print(out_base)

print("--------backward compute--------")
#in_grad = [mx.nd.empty((2,2), ctx=mx.gpu())]
#in_grad_base = [mx.nd.empty((2,2), ctx=mx.gpu())]
out_grad = mx.nd.ones((2,2), ctx=mx.gpu())
exe.backward([out_grad])
exe_base.backward([out_grad])
print(in_grad)
print(in_grad_base)
#exe = e.bind(ctx=mx.gpu(), args={'c':b}, args_grad=in_grad)
block = mx.gluon.nn.SymbolBlock(e,[c])
#exe_base = base.bind(ctx=mx.gpu(), args={'d':b}, args_grad=in_grad_base)
block_base = mx.gluon.nn.SymbolBlock(base,[d])

# base
with mx.autograd.record():
    b_ = mx.nd.array([[-2,-1],[1,2]], ctx=mx.gpu())
    b_.attach_grad()
    # forward
    out_base = block_base(b_)
print(out_base)
print('+++++')
# backward
out_base.backward(out_grad)
print(b_.grad)
print("-------")

# custom relu
with mx.autograd.record():
    b_ = mx.nd.array([[-2,-1],[1,2]], ctx=mx.gpu())
    b_.attach_grad()
    # forward
    out = block(b_)
print(out)
print('+++++')
# backward
out.backward(out_grad)
print(b_.grad)
print("-------")

print("--------test ndarray with size of 1 million---------")
b = mx.nd.uniform(shape=(100,100,100), ctx=mx.gpu())
3 changes: 3 additions & 0 deletions example/extensions/lib_custom_op/transposecsr_lib.cc
@@ -176,6 +176,9 @@ class MyStatefulTransposeCSR : public CustomStatefulOp {
};

MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
                            const MXContext& ctx,
                            const std::vector<std::vector<unsigned int> >& in_shapes,
                            const std::vector<int> in_types,
                            CustomStatefulOp** op_inst) {
  // testing passing of keyword arguments
  int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0;
3 changes: 3 additions & 0 deletions example/extensions/lib_custom_op/transposerowsp_lib.cc
@@ -178,6 +178,9 @@ class MyStatefulTransposeRowSP : public CustomStatefulOp {
};

MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
                            const MXContext& ctx,
                            const std::vector<std::vector<unsigned int> >& in_shapes,
                            const std::vector<int> in_types,
                            CustomStatefulOp** op_inst) {
  // testing passing of keyword arguments
  int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0;
3 changes: 3 additions & 0 deletions example/extensions/lib_subgraph/subgraph_lib.cc
@@ -156,6 +156,9 @@ class MyStatefulOp : public CustomStatefulOp {
};

MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
                            const MXContext& ctx,
                            const std::vector<std::vector<unsigned int> >& in_shapes,
                            const std::vector<int> in_types,
                            CustomStatefulOp** op_inst) {
  std::string serialized_subgraph = "[empty]";
  // MXNet subgraph is stored as Symbol in operator node attrs subgraphs field
13 changes: 9 additions & 4 deletions include/mxnet/lib_api.h
@@ -732,6 +732,9 @@ typedef MXReturnValue (*mutateInputs_t)(const std::unordered_map<std::string,
                                        std::vector<int>* input_indices);
typedef MXReturnValue (*createOpState_t)(const std::unordered_map<std::string,
                                         std::string>& attributes,
                                         const MXContext& ctx,
                                         const std::vector<std::vector<unsigned int> >& in_shapes,
                                         const std::vector<int> in_types,
                                         CustomStatefulOp**);

/*!
Expand Down Expand Up @@ -1000,8 +1003,9 @@ typedef int (*opCallMutateInputs_t)(mutateInputs_t mutate, const char* const* ke

#define MXLIB_OPCALLCREATEOPSTATE_STR "_opCallCreateOpState"
typedef int (*opCallCreateOpState_t)(createOpState_t create_op, const char* const* keys,
                                     const char* const* vals, int num,
                                     void** state_op);
                                     const char* const* vals, int num, const char* dev_type,
                                     int dev_id, unsigned int** inshapes, int* indims,
                                     int num_in, const int* intypes, void** state_op);

#define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute"
typedef int (*opCallFStatefulComp_t)(int is_forward, void* state_op,
@@ -1190,8 +1194,9 @@ extern "C" {

/*! \brief returns status of calling createStatefulOp function for operator from library */
MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const char* const* keys,
                                const char* const* vals, int num,
                                void** state_op);
                                const char* const* vals, int num, const char* dev_type,
                                int dev_id, unsigned int** inshapes, int* indims,
                                int num_in, const int* intypes, void** state_op);

/*! \brief returns status of calling Stateful Forward/Backward for operator from library */
MX_INT_RET _opCallFStatefulCompute(int is_forward, void* state_op, const int64_t** inshapes,
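For reference, the new C-ABI parameters above map one-to-one onto the C++ arguments rebuilt in src/lib_api.cc below; these annotations are editorial, not part of the header:

// dev_type : "cpu" or "gpu", combined with dev_id to reconstruct an MXContext
// inshapes : num_in pointers, inshapes[i] addressing indims[i] unsigned ints (the shape of input i)
// indims   : number of dimensions of each input shape
// num_in   : number of inputs
// intypes  : num_in dtype enum values, one per input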
30 changes: 28 additions & 2 deletions src/c_api/c_api.cc
@@ -1093,6 +1093,28 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
    attr_vals.push_back(kv.second.c_str());
  }

  // string repr of supported context for custom library, currently only "cpu" and "gpu"
  const char* ctx_str = ctx.dev_mask() == Context::kCPU ? "cpu" : "gpu";

  std::vector<uint32_t*> inshapes(in_shapes.size());
  std::vector<int> indims(in_shapes.size());

  // determine amount of memory needed to store all the input shapes
  size_t buff_size = 0;
  for (size_t i = 0; i < in_shapes.size(); ++i)
    buff_size += in_shapes[i].ndim();

  // copy input shapes to raw memory layout
  std::vector<uint32_t> inbuff(buff_size);
  uint32_t *ptr = inbuff.data();
  for (size_t i = 0; i < in_shapes.size(); ++i) {
    inshapes[i] = ptr;
    indims[i] = in_shapes[i].ndim();
    for (int j = 0; j < in_shapes[i].ndim(); ++j, ++ptr) {
      *ptr = static_cast<uint32_t>(in_shapes[i][j]);
    }
  }

  // convert subgraph symbol from node attributes to char*
  std::string subgraph_json;
  if (!attrs.subgraphs.empty()) {
@@ -1111,15 +1133,19 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
      CHECK(createop_map.count("cpu") > 0)
          << "CPU CreateOpState not implemented for '" << name_str << "'";
      int retval = callCreateOpState(createop_map.at("cpu"), attr_keys.data(), attr_vals.data(),
                                     attr_keys.size(), &state_op_inst);
                                     attr_keys.size(), ctx_str, ctx.real_dev_id(),
                                     inshapes.data(), indims.data(),
                                     in_shapes.size(), in_types.data(), &state_op_inst);
      std::string msgs = getExtensionMsgs(msgSize, msgGet);
      CHECK(retval) << "Error calling CreateOpState CPU for custom operator '" << name_str << "'"
                    << msgs;
    } else if (ctx.dev_mask() == Context::kGPU) {
      CHECK(createop_map.count("gpu") > 0)
          << "GPU CreateOpState not implemented for '" << name_str << "'";
      int retval = callCreateOpState(createop_map.at("gpu"), attr_keys.data(), attr_vals.data(),
                                     attr_keys.size(), &state_op_inst);
                                     attr_keys.size(), ctx_str, ctx.real_dev_id(),
                                     inshapes.data(), indims.data(),
                                     in_shapes.size(), in_types.data(), &state_op_inst);
      std::string msgs = getExtensionMsgs(msgSize, msgGet);
      CHECK(retval) << "Error calling CreateOpState GPU for custom operator '" << name_str << "'"
                    << msgs;
23 changes: 20 additions & 3 deletions src/lib_api.cc
@@ -1204,19 +1204,36 @@ MX_INT_RET _opCallMutateInputs(mxnet::ext::mutateInputs_t mutate, const char* co

/*! \brief returns status of calling createStatefulOp function for operator from library */
MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const char* const* keys,
                                const char* const* vals, int num,
                                void** state_op) {
                                const char* const* vals, int num, const char* dev_type,
                                int dev_id, unsigned int** inshapes, int* indims,
                                int num_in, const int* intypes, void** state_op) {
  // create map of attributes from list
  std::unordered_map<std::string, std::string> attrs;
  for (int i = 0; i < num; i++) {
    attrs[std::string(keys[i])] = std::string(vals[i]);
  }

  mxnet::ext::MXContext ctx(dev_type, dev_id);

  // create a vector of shapes for inputs
  std::vector<std::vector<unsigned int> > in_shapes(num_in);
  for (int i = 0; i < num_in; i++) {
    for (int j = 0; j < indims[i]; j++) {
      in_shapes[i].push_back(inshapes[i][j]);
    }
  }

  // create a vector of types for inputs
  std::vector<int> in_types(num_in);
  for (int i = 0; i < num_in; i++) {
    in_types[i] = intypes[i];
  }

  // void pointer to hold custom state op instance created in custom library
  // eventually state_op pointer is populated by instance from custom library
  mxnet::ext::CustomStatefulOp** op_ptr =
      reinterpret_cast<mxnet::ext::CustomStatefulOp**>(state_op);
  return create_op(attrs, op_ptr);
  return create_op(attrs, ctx, in_shapes, in_types, op_ptr);
}

/*! \brief returns status of calling Stateful Forward/Backward for operator from library */
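To make the raw memory layout concrete, a worked example of the round trip, assuming the two gemm inputs of shape (2,3) and (3,1) from the test above:

// c_api.cc packs the shapes end to end:
//   inbuff   = [2, 3, 3, 1]              (buff_size = 2 + 2)
//   inshapes = [&inbuff[0], &inbuff[2]]  (one pointer per input)
//   indims   = [2, 2], num_in = 2
// _opCallCreateOpState then rebuilds in_shapes = {{2, 3}, {3, 1}},
// copies intypes into in_types, and forwards everything, together with
// MXContext(dev_type, dev_id), to the library's createOpState.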