Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Enhancements for MXTensor for custom operators #17204

Merged
merged 16 commits into from
Jan 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 45 additions & 45 deletions include/mxnet/lib_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
#include <utility>
#include <stdexcept>

#define MX_LIBRARY_VERSION 1
#define MX_LIBRARY_VERSION 2

/*
* Import from DLPack https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h
Expand Down Expand Up @@ -198,6 +198,7 @@ enum MXDType {
kInt32 = 4,
kInt8 = 5,
kInt64 = 6,
kUNSET = 100,
};

enum MXReturnValue {
Expand All @@ -209,10 +210,22 @@ enum MXReturnValue {
* \brief Tensor data structure used by custom operator
*/
struct MXTensor {
MXTensor() : data_ptr(NULL) {}

MXTensor(void *data_ptr, const std::vector<int64_t> &shape, MXDType dtype)
: data_ptr(data_ptr), shape(shape), dtype(dtype) {}
MXTensor() : data_ptr(NULL), dtype(kUNSET), verID(0) {}

MXTensor(void *data_ptr, const std::vector<int64_t> &shape, MXDType dtype,
size_t vID)
: data_ptr(data_ptr), shape(shape), dtype(dtype), verID(vID) {}

/*! \brief populate internal tensor fields and refresh the DLPack view
 * \param dptr pointer to the raw tensor data (not owned by MXTensor)
 * \param type element data type, one of the MXDType enum values
 * \param dims array of dimension sizes, length ndims
 * \param ndims number of dimensions (rank) of the tensor
 * \param vID version number of the tensor from the MXNet engine
 */
void setTensor(void *dptr, MXDType type, const int64_t* dims,
               int ndims, size_t vID) {
  data_ptr = dptr; dtype = type; verID = vID;
  // copy the shape in one call instead of a manual clear()+push_back loop
  shape.assign(dims, dims + ndims);
  // keep the cached DLTensor representation consistent with the new fields
  setDLTensor();
}

/*! \brief populate DLTensor fields */
void setDLTensor() {
Expand Down Expand Up @@ -277,6 +290,14 @@ struct MXTensor {
return size;
}

/*! \brief compare this tensor with oth: same data pointer, dtype, version, and shape */
inline bool isSame(const MXTensor &oth) const {
  // guard clauses: bail out on the first mismatching field
  if (data_ptr != oth.data_ptr) return false;
  if (dtype != oth.dtype) return false;
  if (verID != oth.verID) return false;
  return shape == oth.shape;
}

// data is a flattened 1D representation of the tensor; elements are in contiguous memory
// the user can access each element using the shape of the tensor
void *data_ptr;
Expand All @@ -287,6 +308,9 @@ struct MXTensor {
// type can only be MXDType enum types
MXDType dtype;

// version number, updated when the tensor has changed since the last use by a custom op
size_t verID;

// corresponding DLTensor repr of MXTensor
// easy way to reuse functions taking DLTensor
DLTensor dltensor;
Expand Down Expand Up @@ -684,15 +708,9 @@ typedef int (*opCallInferType_t)(inferType_t, const char* const*, const char* co

#define MXLIB_OPCALLFCOMP_STR "_opCallFCompute"
typedef int (*opCallFComp_t)(fcomp_t, const char* const*, const char* const*, int,
const int64_t**, int*, void**, int*, int,
const int64_t**, int*, void**, int*, int,
xpu_malloc_t, void*);

#define MXLIB_OPCALLBKWD_STR "_opCallBackward"
typedef int (*opCallBkwd_t)(fcomp_t, const char* const*, const char* const*, int,
const int64_t**, int*, void**, int*, int,
const int64_t**, int*, void**, int*, int,
xpu_malloc_t, void*);
const int64_t**, int*, void**, int*, size_t*, int,
const int64_t**, int*, void**, int*, size_t*, int,
xpu_malloc_t, void*);

#define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs"
typedef int (*opCallMutateInputs_t)(mutateInputs_t, const char* const*, const char* const*, int,
Expand All @@ -703,9 +721,9 @@ typedef int (*opCallCreateOpState_t)(createOpState_t, const char* const*, const
void**);

#define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute"
typedef int (*opCallFStatefulComp_t)(bool, void*, const int64_t**, int*, void**, int*, int,
const int64_t**, int*, void**, int*, int,
xpu_malloc_t, void*);
typedef int (*opCallFStatefulComp_t)(bool, void*, const int64_t**, int*, void**, int*, size_t*,
int, const int64_t**, int*, void**, int*, size_t*,
int, xpu_malloc_t, void*);

#define MXLIB_INITIALIZE_STR "initialize"
typedef int (*initialize_t)(int);
Expand Down Expand Up @@ -876,9 +894,9 @@ extern "C" {
_opCallFCompute(fcomp_t fcomp, const char* const* keys,
const char* const* vals, int num,
const int64_t** inshapes, int* indims,
void** indata, int* intypes, int num_in,
void** indata, int* intypes, size_t* inIDs, int num_in,
const int64_t** outshapes, int* outdims,
void** outdata, int* outtypes, int num_out,
void** outdata, int* outtypes, size_t* outIDs, int num_out,
xpu_malloc_t cpu_malloc, void* cpu_alloc) {
// create map of attributes from list
std::map<std::string, std::string> attrs;
Expand All @@ -889,23 +907,14 @@ extern "C" {
// create a vector of tensors for inputs
std::vector<MXTensor> inputs(num_in);
for (int i = 0; i < num_in; i++) {
inputs[i].data_ptr = indata[i];
inputs[i].dtype = (MXDType)intypes[i];
for (int j = 0; j < indims[i]; j++) {
inputs[i].shape.push_back(inshapes[i][j]);
}
inputs[i].setDLTensor();
inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i], inIDs[i]);
}

// create a vector of tensors for outputs
std::vector<MXTensor> outputs(num_out);
for (int i = 0; i < num_out; i++) {
outputs[i].data_ptr = outdata[i];
outputs[i].dtype = (MXDType) outtypes[i];
for (int j = 0; j < outdims[i]; j++) {
outputs[i].shape.push_back(outshapes[i][j]);
}
outputs[i].setDLTensor();
outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], outdims[i],
outIDs[i]);
}

OpResource res(cpu_malloc, cpu_alloc);
Expand Down Expand Up @@ -973,30 +982,21 @@ extern "C" {
#endif
_opCallFStatefulCompute(bool is_forward, void* state_op,
const int64_t** inshapes, int* indims,
void** indata, int* intypes, int num_in,
void** indata, int* intypes, size_t* inIDs, int num_in,
const int64_t** outshapes, int* outdims,
void** outdata, int* outtypes, int num_out,
void** outdata, int* outtypes, size_t* outIDs, int num_out,
xpu_malloc_t cpu_malloc, void* cpu_alloc) {
// create a vector of tensors for inputs
std::vector<MXTensor> inputs(num_in);
for (int i = 0; i < num_in; i++) {
inputs[i].data_ptr = indata[i];
inputs[i].dtype = (MXDType)intypes[i];
for (int j = 0; j < indims[i]; j++) {
inputs[i].shape.push_back(inshapes[i][j]);
}
inputs[i].setDLTensor();
inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i], inIDs[i]);
}

// create a vector of tensors for outputs
std::vector<MXTensor> outputs(num_out);
for (int i = 0; i < num_out; i++) {
outputs[i].data_ptr = outdata[i];
outputs[i].dtype = (MXDType) outtypes[i];
for (int j = 0; j < outdims[i]; j++) {
outputs[i].shape.push_back(outshapes[i][j]);
}
outputs[i].setDLTensor();
outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], outdims[i],
outIDs[i]);
}
OpResource res(cpu_malloc, cpu_alloc);
CustomStatefulOp* op_ptr = reinterpret_cast<CustomStatefulOp*>(state_op);
Expand Down
17 changes: 12 additions & 5 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -395,13 +395,15 @@ int MXLoadLib(const char *path) {
std::vector<const int64_t *> in_shapes, out_shapes;
std::vector<int> in_dims, out_dims;
std::vector<int> in_types, out_types;
std::vector<size_t> in_verIDs, out_verIDs;

// convert input tensors to constituent parts
for (size_t i = 0; i < inputs.size(); i++) {
in_data.push_back(inputs[i].data().dptr_);
in_shapes.push_back(inputs[i].shape().data());
in_dims.push_back(inputs[i].shape().ndim());
in_types.push_back(inputs[i].dtype());
in_verIDs.push_back(inputs[i].version());
}

// convert output tensors to constituent parts
Expand All @@ -410,6 +412,7 @@ int MXLoadLib(const char *path) {
out_shapes.push_back(outputs[i].shape().data());
out_dims.push_back(outputs[i].shape().ndim());
out_types.push_back(outputs[i].dtype());
out_verIDs.push_back(outputs[i].version());
}

// get memory resource
Expand Down Expand Up @@ -438,9 +441,10 @@ int MXLoadLib(const char *path) {
// call fcompute function
CHECK(callFComp(fcomp_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
in_shapes.data(), in_dims.data(), in_data.data(),
in_types.data(), in_data.size(),
in_types.data(), in_verIDs.data(), in_data.size(),
out_shapes.data(), out_dims.data(), out_data.data(),
out_types.data(), out_data.size(), cpu_malloc, &cpu_alloc))
out_types.data(), out_verIDs.data(), out_data.size(),
cpu_malloc, &cpu_alloc))
<< "Error calling FCompute for custom operator '" << name_str << "'";

// return type void
Expand Down Expand Up @@ -570,13 +574,15 @@ int MXLoadLib(const char *path) {
std::vector<const int64_t *> in_shapes, out_shapes;
std::vector<int> in_dims, out_dims;
std::vector<int> in_types, out_types;
std::vector<size_t> in_verIDs, out_verIDs;

// convert input tensors to constituent parts
for (size_t i = 0; i < inputs.size(); i++) {
in_data.push_back(inputs[i].data().dptr_);
in_shapes.push_back(inputs[i].shape().data());
in_dims.push_back(inputs[i].shape().ndim());
in_types.push_back(inputs[i].dtype());
in_verIDs.push_back(inputs[i].version());
}

// convert output tensors to constituent parts
Expand All @@ -585,6 +591,7 @@ int MXLoadLib(const char *path) {
out_shapes.push_back(outputs[i].shape().data());
out_dims.push_back(outputs[i].shape().ndim());
out_types.push_back(outputs[i].dtype());
out_verIDs.push_back(outputs[i].version());
}

// get memory resource
Expand Down Expand Up @@ -618,9 +625,9 @@ int MXLoadLib(const char *path) {

// call fcompute function
CHECK(callFStatefulComp(is_forward, state_op_inst, in_shapes.data(), in_dims.data(),
in_data.data(), in_types.data(), in_data.size(),
out_shapes.data(), out_dims.data(), out_data.data(),
out_types.data(), out_data.size(), cpu_malloc, &cpu_alloc))
in_data.data(), in_types.data(), in_verIDs.data(), in_data.size(),
out_shapes.data(), out_dims.data(), out_data.data(), out_types.data(),
out_verIDs.data(), out_data.size(), cpu_malloc, &cpu_alloc))
<< "Error calling FStatefulCompute for custom operator '" << name_str << "'";
};

Expand Down