Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-779]Add DLPack Transformation API #12047

Merged
merged 31 commits into from
Sep 22, 2018
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
822706e
add dlpack convertor api
wkcn Aug 3, 2018
8aac3da
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
wkcn Aug 3, 2018
ab6fa85
add to_dlpack and from_dlpack for NDArray
wkcn Aug 6, 2018
8c6e9d2
fix dlpack deleter and add unittest for dlpack
wkcn Aug 6, 2018
9fdfa7d
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
wkcn Aug 6, 2018
1142787
update 3rdparty
wkcn Aug 6, 2018
16df8d5
fix for cpplint
wkcn Aug 6, 2018
bfcffa2
fix pylint and add destructor for dlpack
wkcn Aug 6, 2018
f5c2552
fix pylint in base.py
wkcn Aug 6, 2018
98b5d11
fix lint in base.py
wkcn Aug 6, 2018
7bdde8f
add document for DLPack transformation API
wkcn Aug 6, 2018
f225d27
add to_dlpack_for_read and to_dlpack_for_write
wkcn Aug 7, 2018
afc1518
fix lint for ndarray.py and fix typo in c_api.h
wkcn Aug 7, 2018
8b397fd
fix function name error in c_api
wkcn Aug 7, 2018
d48074a
update code indent in tensor_blob.h and c_api.cc, remove unused type …
wkcn Aug 7, 2018
58c5d87
use MXNDArrayToDLPack in c_api and add compactness check in TBlob
wkcn Aug 9, 2018
72edbf8
merge master and fix merge conflict
wkcn Aug 11, 2018
ef8ffcd
use python function as destructor of DLPack
wkcn Aug 11, 2018
afa1898
remove unused PyObjectHandle and update DLDataTypeTransform
wkcn Aug 11, 2018
a4d3aee
update from_dlpack code
wkcn Aug 11, 2018
493deb0
fix pylint in ndarray.py
wkcn Aug 11, 2018
adf36ef
rename dlpack after using it
wkcn Aug 12, 2018
26db4d0
merge master
wkcn Aug 13, 2018
dec838d
DLManagedTensor manages itself
wkcn Aug 22, 2018
850c3dc
add deleter for TBlob and Chunk in NDArray
wkcn Aug 22, 2018
fc99323
remove unused code in python/mxnet/base.py
wkcn Aug 22, 2018
ffe60c6
retrigger CI
wkcn Aug 22, 2018
cbb17c3
add deleter for shared_ptr<Chunk>
wkcn Sep 10, 2018
e56be1f
Merge branch 'master' into DLPack-convertor-API
wkcn Sep 10, 2018
b1204bc
compilation okay
wkcn Sep 10, 2018
fe1387f
fix cpplint
wkcn Sep 10, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ typedef void *CudaModuleHandle;
typedef void *CudaKernelHandle;
/*! \brief handle to a Profile object (domain, duration, counter, etc.) */
typedef void *ProfileHandle;
/*! \brief handle to DLManagedTensor*/
typedef void *DLManagedTensorHandle;

typedef void (*ExecutorMonitorCallback)(const char*,
NDArrayHandle,
Expand Down Expand Up @@ -737,6 +739,36 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
*/
MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle,
void **out_pdata);
/*!
 * \brief Create a reference view of NDArray that
 * represents as DLManagedTensor.
 *
 * The returned tensor shares memory with the NDArray (no copy is
 * made); release it through its deleter, e.g. via
 * MXNDArrayCallDLPackDeleter, once it is no longer needed.
 *
 * \param handle the handle to the ndarray
 * \param out_dlpack pointer holder to get pointer of DLManagedTensor
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXNDArrayToDLPack(NDArrayHandle handle,
                                DLManagedTensorHandle *out_dlpack);
/*!
 * \brief Create a NDArray backed by a dlpack tensor.
 *
 * This allows us to create a NDArray using the memory
 * allocated by an external deep learning framework
 * that is DLPack compatible.
 *
 * The memory is retained until the NDArray went out of scope.
 *
 * \param dlpack the pointer of the input DLManagedTensor
 * \param out_handle pointer holder to get pointer of NDArray
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack,
                                  NDArrayHandle *out_handle);
/*!
 * \brief Delete a dlpack tensor by invoking its deleter callback.
 * \param dlpack the pointer of the input DLManagedTensor
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXNDArrayCallDLPackDeleter(DLManagedTensorHandle dlpack);
/*!
* \brief get the type of the data in NDArray
* \param handle the handle to the narray
Expand Down
20 changes: 20 additions & 0 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,26 @@ class NDArray {
return ret;
}

  /*!
   * \brief Create a reference view of NDArray that
   * represents as DLManagedTensor.
   * \return A DLManagedTensor that shares memory with this NDArray.
   */
  DLManagedTensor* ToDLPack() const;

  /*!
   * \brief Create a NDArray backed by a dlpack tensor.
   *
   * This allows us to create a NDArray using the memory
   * allocated by an external deep learning framework
   * that is DLPack compatible.
   *
   * The memory is retained until the NDArray went out of scope.
   *
   * \param tensor the input DLManagedTensor whose memory backs the view
   * \return The created NDArray view.
   */
  static NDArray FromDLPack(DLManagedTensor* tensor);

/*!
* \brief Update ndarray chunk storage handles using existing ndarray storage handles
* Also update the aux_handle, aux_shapes and aux_types.
Expand Down
47 changes: 46 additions & 1 deletion include/mxnet/tensor_blob.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ class TBlob {
: dptr_(dptr), shape_(shape), type_flag_(type_flag) {
SetDLTensor(dev_mask, dev_id);
}
/*!
* \brief constructor that construct TBlob from DLTensor
* \param DLTensor Object
*/
explicit TBlob(const DLTensor &dltensor) : dptr_(dltensor.data),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to add compactness check

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specifically, TBlob only support compact tensors, need to check strides == null or the strides reflect a compact setting

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I will move the strides check from ndarray.cpp to tensor_blob.h.

shape_(TShape(dltensor.shape, dltensor.shape + dltensor.ndim)),
type_flag_(DLDataTypeTransform(dltensor.dtype)), dltensor_(dltensor) {
}
/*!
* \brief constructor from tensor
* \param src source tensor
Expand Down Expand Up @@ -336,14 +344,51 @@ class TBlob {
}
}
}
static int DLDataTypeTransform(DLDataType dldata_type) {
if (dldata_type.lanes != 1) {
LOG(FATAL) << "Unsupported DLDataType whose lanes != 1";
}
switch (dldata_type.code) {
case kDLFloat:
switch (dldata_type.bits) {
case 16:
return mshadow::kFloat16;
case 32:
return mshadow::kFloat32;
case 64:
return mshadow::kFloat64;
}
break;
case kDLUInt:
switch (dldata_type.bits) {
case 8:
return mshadow::kUint8;
}
break;
case kDLInt:
switch (dldata_type.bits) {
case 8:
return mshadow::kInt8;
case 32:
return mshadow::kInt32;
case 64:
return mshadow::kInt64;
}
break;
}
LOG(FATAL) << "Unknown DLDataType{" << dldata_type.code
<< ", " << dldata_type.bits
<< ", " << dldata_type.lanes << "}";
return mshadow::kFloat32;
}

inline void SetDLTensor(int dev_mask, int dev_id) {
dltensor_.data = dptr_;
dltensor_.ctx = DLContext{static_cast<DLDeviceType>(dev_mask), dev_id};
dltensor_.ndim = shape_.ndim();
dltensor_.dtype = DTypeTransform(type_flag_);
dltensor_.shape = shape_.data();
dltensor_.strides = NULL;
dltensor_.strides = nullptr;
dltensor_.byte_offset = 0;
}

Expand Down
7 changes: 5 additions & 2 deletions python/mxnet/_ctypes/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@

class NDArrayBase(object):
"""Base data structure for ndarray"""
__slots__ = ["handle", "writable"]
__slots__ = ["handle", "writable", "dlpack_handle"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do not put dlpack_handle inside it, instead directly return a pycapsule that contains the handle

# pylint: disable= no-member

def __init__(self, handle, writable=True):
def __init__(self, handle, writable=True, dlpack_handle=None):
"""initialize a new NDArray

Parameters
Expand All @@ -46,9 +46,12 @@ def __init__(self, handle, writable=True):
assert isinstance(handle, NDArrayHandle)
self.handle = handle
self.writable = writable
self.dlpack_handle = dlpack_handle

    def __del__(self):
        """Free the underlying NDArray handle; when this array was
        created from a DLPack tensor, also invoke the DLPack deleter
        so the borrowed memory is returned to its owner."""
        check_call(_LIB.MXNDArrayFree(self.handle))
        # dlpack_handle is only non-None for arrays constructed from a
        # DLManagedTensor; calling the deleter releases that memory.
        if self.dlpack_handle is not None:
            check_call(_LIB.MXNDArrayCallDLPackDeleter(self.dlpack_handle))

def __reduce__(self):
return (_ndarray_cls, (None,), self.__getstate__())
Expand Down
10 changes: 10 additions & 0 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def _load_lib():
CudaModuleHandle = ctypes.c_void_p
CudaKernelHandle = ctypes.c_void_p
ProfileHandle = ctypes.c_void_p
DLPackHandle = ctypes.c_void_p


#----------------------------
Expand Down Expand Up @@ -729,3 +730,12 @@ def write_all_str(module_file, module_all_list):
module_op_file.close()
write_all_str(module_internal_file, module_internal_all)
module_internal_file.close()

# ctypes prototypes for the CPython PyCapsule C-API, used by the DLPack
# conversion helpers (to_dlpack / from_dlpack).
# PyCapsule_New(pointer, name, destructor) -> capsule object
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
ctypes.pythonapi.PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p,
                                           ctypes.c_void_p]

# PyCapsule_GetPointer(capsule, name) -> void*
# NOTE(review): argtypes are (re)assigned at each call site because the
# capsule is passed either as py_object or as a raw c_void_p - verify
# that the two call sites cannot interleave.
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p

# PyCapsule_SetName(capsule, name) -> int (0 on success)
ctypes.pythonapi.PyCapsule_SetName.restype = ctypes.c_int
ctypes.pythonapi.PyCapsule_SetName.argtypes = [ctypes.py_object, ctypes.c_char_p]
1 change: 1 addition & 0 deletions python/mxnet/cython/base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ctypedef void* SymbolHandle
ctypedef void* NDArrayHandle
ctypedef void* OpHandle
ctypedef void* CachedOpHandle
ctypedef void* DLPackHandle
ctypedef unsigned nn_uint

cdef py_str(const char* x):
Expand Down
6 changes: 5 additions & 1 deletion python/mxnet/cython/ndarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ cdef class NDArrayBase:
# handle for symbolic operator.
cdef NDArrayHandle chandle
cdef int cwritable
cdef DLPackHandle cdlpack_handle

cdef _set_handle(self, handle):
cdef unsigned long long ptr
Expand All @@ -52,12 +53,15 @@ cdef class NDArrayBase:
def __get__(self):
return bool(self.cwritable)

def __init__(self, handle, writable=True):
def __init__(self, handle, writable=True, dlpack_handle=None):
self._set_handle(handle)
self.cwritable = writable
self.cdlpack_handle = dlpack_handle

    def __dealloc__(self):
        """Free the NDArray handle; if this array wraps memory borrowed
        through DLPack, also invoke the DLPack deleter."""
        CALL(MXNDArrayFree(self.chandle))
        # cdlpack_handle is set only for arrays created from a
        # DLManagedTensor; the deleter hands the memory back to its owner.
        if self.cdlpack_handle:
            CALL(MXNDArrayCallDLPackDeleter(self.cdlpack_handle))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do not put dlpack_handle inside the object


def __reduce__(self):
return (_ndarray_cls, (None,), self.__getstate__())
Expand Down
98 changes: 96 additions & 2 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import numpy as np
from ..base import _LIB, numeric_types, integer_types
from ..base import c_array, c_array_buf, c_handle_array, mx_real_t
from ..base import mx_uint, NDArrayHandle, check_call
from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle
from ..base import ctypes2buffer
from ..context import Context, current_context
from . import _internal
Expand All @@ -46,7 +46,8 @@
"ones", "add", "arange", "eye", "divide", "equal", "full", "greater", "greater_equal",
"imdecode", "lesser", "lesser_equal", "logical_and", "logical_or", "logical_xor",
"maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal", "onehot_encode",
"power", "subtract", "true_divide", "waitall", "_new_empty_handle", "histogram"]
"power", "subtract", "true_divide", "waitall", "_new_empty_handle", "histogram",
"to_dlpack", "from_dlpack"]

_STORAGE_TYPE_UNDEFINED = -1
_STORAGE_TYPE_DEFAULT = 0
Expand Down Expand Up @@ -2205,6 +2206,22 @@ def tostype(self, stype):
"""
return op.cast_storage(self, stype=stype)

def asdlpack(self):
"""Returns a reference view of NDArray that represents as DLManagedTensor.

Returns
-------
PyCapsule (the pointer of DLManagedTensor)
a reference view of NDArray that represents as DLManagedTensor.

Examples
--------
>>> x = mx.nd.ones((2,3))
>>> y = x.asdlpack()
>>> type(y)
<class 'PyCapsule'>
"""
return to_dlpack(self)

def _get_indexing_dispatch_code(key):
"""Returns a dispatch code for calling basic or advanced indexing functions."""
Expand Down Expand Up @@ -3851,3 +3868,80 @@ def histogram(a, bins=10, range=None):
return _internal._histogram(data=a, bin_cnt=bins, range=range)
raise ValueError("bins argument should be either an integer or an NDArray")
# pylint: enable= no-member, protected-access, redefined-builtin

def pycapsule_dlpack_deleter(dlpack):
    """PyCapsule destructor for capsules produced by ``to_dlpack``.

    Parameters
    ----------
    dlpack: void *
        raw pointer to the capsule object being destroyed; CPython hands
        the destructor a PyObject* as a plain pointer.
    """
    # The capsule arrives as a raw pointer here, so argtypes must use
    # c_void_p; from_dlpack sets them to py_object for its own call.
    ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
    try:
        dlpack_handle = ctypes.c_void_p(
            ctypes.pythonapi.PyCapsule_GetPointer(
                ctypes.c_void_p(dlpack), b'dltensor'))
        check_call(_LIB.MXNDArrayCallDLPackDeleter(dlpack_handle))
    except ValueError:
        # PyCapsule_GetPointer raises ValueError when the capsule is no
        # longer named 'dltensor' (from_dlpack renames consumed capsules
        # to 'used_dltensor'); ownership then belongs to the new NDArray
        # and nothing must be freed here.
        pass

def to_dlpack(data):
    """Returns a reference view of NDArray that represents as DLManagedTensor.

    Parameters
    ----------
    data: NDArray
        input data.

    Returns
    -------
    PyCapsule (the pointer of DLManagedTensor)
        a reference view of NDArray that represents as DLManagedTensor.

    Examples
    --------
    >>> x = mx.nd.ones((2,3))
    >>> y = mx.nd.to_dlpack(x)
    >>> type(y)
    <class 'PyCapsule'>
    """
    # The CFUNCTYPE wrapper must outlive every capsule that uses it as
    # destructor: PyCapsule_New stores only the raw C function pointer,
    # so letting the wrapper be garbage-collected (as the original
    # per-call temporary was) leaves capsules with a dangling destructor.
    # Cache it once on the function object instead.
    if not hasattr(to_dlpack, '_c_dlpack_deleter'):
        to_dlpack._c_dlpack_deleter = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(
            pycapsule_dlpack_deleter)
    dlpack = DLPackHandle()
    check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack)))
    return ctypes.pythonapi.PyCapsule_New(dlpack, b'dltensor',
                                          to_dlpack._c_dlpack_deleter)

def from_dlpack(dlpack):
    """Returns a NDArray backed by a dlpack tensor.

    Parameters
    ----------
    dlpack: PyCapsule (the pointer of DLManagedTensor)
        input data

    Returns
    -------
    NDArray
        a NDArray backed by a dlpack tensor

    Examples
    --------
    >>> x = mx.nd.ones((2,3))
    >>> y = mx.nd.to_dlpack(x)
    >>> type(y)
    <class 'PyCapsule'>
    >>> z = mx.nd.from_dlpack(y)
    >>> type(z)
    <class 'mxnet.ndarray.ndarray.NDArray'>
    >>> z
    [[ 1.  1.  1.]
     [ 1.  1.  1.]]
    <NDArray 2x3 @cpu(0)>
    """
    handle = NDArrayHandle()
    # Here the capsule is passed as a Python object, so argtypes differ
    # from those set inside the capsule destructor.
    ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]
    dlpack_handle = ctypes.c_void_p(ctypes.pythonapi.PyCapsule_GetPointer(dlpack, b'dltensor'))
    # ctypes maps a NULL pointer to None rather than 0, so the original
    # `!= 0` comparison could never fail; compare against None instead.
    assert dlpack_handle.value is not None, \
        'Invalid DLPack Tensor. DLTensor capsules can be consumed only once.'
    check_call(_LIB.MXNDArrayFromDLPack(dlpack_handle, ctypes.byref(handle)))
    # Rename the capsule so it cannot be consumed twice; the capsule
    # destructor skips capsules no longer named 'dltensor'.
    ctypes.pythonapi.PyCapsule_SetName(dlpack, b'used_dltensor')
    return NDArray(handle=handle, dlpack_handle=dlpack_handle)
26 changes: 26 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,32 @@ int MXNDArrayGetData(NDArrayHandle handle,
API_END();
}

// Export an NDArray through the C API as a DLManagedTensor view.
// Returns 0 on success, -1 on failure (via API_BEGIN/API_END).
int MXNDArrayToDLPack(NDArrayHandle handle,
                      DLManagedTensorHandle *out_dlpack) {
  API_BEGIN();
  // The handle is an opaque NDArray*; ToDLPack produces a reference
  // view of its memory.
  *out_dlpack = static_cast<NDArray*>(handle)->ToDLPack();
  API_END();
}

// Wrap a DLManagedTensor in a heap-allocated NDArray and hand it out
// through the C API. Returns 0 on success, -1 on failure.
int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack,
                        NDArrayHandle *out_handle) {
  API_BEGIN();
  // Construct the NDArray first and heap-allocate only on success:
  // the original `new NDArray()` before FromDLPack leaked the empty
  // NDArray if FromDLPack threw (API_BEGIN/API_END turn exceptions
  // into a -1 return, so the pointer was never reclaimed).
  *out_handle = new NDArray(NDArray::FromDLPack(
      static_cast<DLManagedTensor*>(dlpack)));
  API_END();
}

// Invoke the deleter of a DLManagedTensor, releasing the memory it
// manages. Returns 0 on success, -1 on failure.
int MXNDArrayCallDLPackDeleter(DLManagedTensorHandle dlpack) {
  API_BEGIN();
  DLManagedTensor *p_dlpack = static_cast<DLManagedTensor*>(dlpack);
  // The DLPack spec allows `deleter` to be NULL when the producer has
  // nothing to free, so guard both the tensor and its deleter before
  // calling through the function pointer.
  if (p_dlpack != nullptr && p_dlpack->deleter != nullptr) {
    p_dlpack->deleter(p_dlpack);
  }
  API_END();
}

int MXNDArrayGetDType(NDArrayHandle handle,
int *out_dtype) {
API_BEGIN();
Expand Down
Loading