diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 00439962a944..a01cc6a77940 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -93,6 +93,8 @@ typedef void *CudaModuleHandle;
 typedef void *CudaKernelHandle;
 /*! \brief handle to a Profile object (domain, duration, counter, etc.) */
 typedef void *ProfileHandle;
+/*! \brief handle to a DLManagedTensor */
+typedef void *DLManagedTensorHandle;

 typedef void (*ExecutorMonitorCallback)(const char*,
                                         NDArrayHandle,
@@ -746,6 +748,40 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
  */
 MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle,
                                void **out_pdata);
+/*!
+ * \brief Create a reference view of an NDArray, represented
+ * as a DLManagedTensor.
+ * Note: MXNet uses asynchronous execution. Call MXNDArrayWaitToRead or
+ * MXNDArrayWaitToWrite before calling MXNDArrayToDLPack.
+ * \param handle the handle to the ndarray
+ * \param out_dlpack pointer holder to get the pointer of the DLManagedTensor
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXNDArrayToDLPack(NDArrayHandle handle,
+                                DLManagedTensorHandle *out_dlpack);
+
+/*!
+ * \brief Create an NDArray backed by a DLPack tensor.
+ *
+ * This allows us to create an NDArray using memory
+ * allocated by an external, DLPack-compatible
+ * deep learning framework.
+ *
+ * The memory is retained until the NDArray goes out of scope.
+ *
+ * \param dlpack the pointer of the input DLManagedTensor
+ * \param out_handle pointer holder to get the pointer of the NDArray
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack,
+                                  NDArrayHandle *out_handle);
+/*!
+ * \brief Delete a DLPack tensor
+ * \param dlpack the pointer of the input DLManagedTensor
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXNDArrayCallDLPackDeleter(DLManagedTensorHandle dlpack);
+
 /*!
  * \brief get the type of the data in NDArray
  * \param handle the handle to the narray
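Taken together, these three entry points are the entire C surface of the feature. As a minimal sketch (not part of the patch), here is how they can be driven directly through ctypes, assuming a libmxnet built with this patch; the Python frontend added further below wraps exactly this sequence:

```python
# Sketch only: calls the new C API via ctypes instead of the Python wrappers.
import ctypes
import mxnet as mx
from mxnet.base import _LIB, check_call, NDArrayHandle

x = mx.nd.ones((2, 3))
x.wait_to_read()            # MXNet executes asynchronously: sync before exporting

dlpack = ctypes.c_void_p()
check_call(_LIB.MXNDArrayToDLPack(x.handle, ctypes.byref(dlpack)))

out = NDArrayHandle()
check_call(_LIB.MXNDArrayFromDLPack(dlpack, ctypes.byref(out)))
y = mx.nd.NDArray(handle=out)   # takes ownership of the new handle
# Ownership of the DLManagedTensor moved into the new NDArray;
# MXNDArrayCallDLPackDeleter is only for capsules that are never
# consumed, and must NOT be called on `dlpack` here.
```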
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 6141a4da78ef..afae5dcfcffe 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -116,6 +116,26 @@ class NDArray {
         dtype_(data.type_flag_), storage_type_(kDefaultStorage),
         entry_({nullptr, 0, 0}) {
   }
+
+  /*!
+   * \brief Construct a static NDArray that shares data with a TBlob and
+   * uses a custom deleter.
+   * Use with caution: allocate ONLY ONE NDArray for each TBlob, and make
+   * sure the memory region stays available throughout the lifetime of the
+   * NDArray.
+   * \param data the memory content of static data
+   * \param dev_id the device id this tensor sits at
+   * \param deleter the function pointer of the custom deleter
+   */
+  NDArray(const TBlob &data, int dev_id, const std::function<void()>& deleter)
+      : ptr_(new Chunk(data, dev_id),
+             [deleter](Chunk *p) {
+               deleter();  // call custom deleter
+               delete p;   // delete Chunk object
+             }),
+        shape_(data.shape_),
+        dtype_(data.type_flag_), storage_type_(kDefaultStorage),
+        entry_({nullptr, 0, 0}) {
+  }
+
   /*! \brief create ndarray from shared memory */
   NDArray(int shared_pid, int shared_id, const TShape& shape, int dtype)
       : ptr_(std::make_shared<Chunk>(shared_pid, shared_id, shape, dtype)),
         shape_(shape),
@@ -523,6 +543,26 @@ class NDArray {
     return ret;
   }

+  /*!
+   * \brief Create a reference view of the NDArray, represented
+   * as a DLManagedTensor.
+   * \return A DLManagedTensor
+   */
+  DLManagedTensor* ToDLPack() const;
+
+  /*!
+   * \brief Create an NDArray backed by a DLPack tensor.
+   *
+   * This allows us to create an NDArray using memory
+   * allocated by an external, DLPack-compatible
+   * deep learning framework.
+   *
+   * The memory is retained until the NDArray goes out of scope.
+   *
+   * \return The created NDArray view.
+   */
+  static NDArray FromDLPack(const DLManagedTensor* tensor);
+
   /*!
    * \brief Update ndarray chunk storage handles using existing ndarray storage handles
    * Also update the aux_handle, aux_shapes and aux_types.
diff --git a/include/mxnet/tensor_blob.h b/include/mxnet/tensor_blob.h
index 6f604a5bb8d9..496e8c7cfced 100755
--- a/include/mxnet/tensor_blob.h
+++ b/include/mxnet/tensor_blob.h
@@ -104,6 +104,39 @@ class TBlob {
       : dptr_(dptr), shape_(shape), type_flag_(type_flag) {
     SetDLTensor(dev_mask, dev_id);
   }
+  /*!
+   * \brief constructor that constructs a TBlob from a DLTensor
+   * \param dltensor the source DLTensor object
+   */
+  explicit TBlob(const DLTensor &dltensor)
+      : dptr_(dltensor.data),
+        shape_(TShape(dltensor.shape, dltensor.shape + dltensor.ndim)),
+        type_flag_(DLDataTypeTransform(dltensor.dtype)),
+        dltensor_(dltensor) {
+    // compactness check for DLTensor
+    if (dltensor.strides != nullptr) {
+      // check strides
+      const int &ndim = dltensor.ndim;
+      const int64_t *shape = dltensor.shape;
+      const int64_t *strides = dltensor.strides;
+      if (ndim >= 1) {
+        bool err = false;
+        if (strides[ndim - 1] != 1) {
+          err = true;
+        } else {
+          for (int i = ndim - 2; i >= 0; --i) {
+            if (strides[i] != shape[i + 1] * strides[i + 1]) {
+              err = true;
+              break;
+            }
+          }
+        }
+        if (err) {
+          LOG(FATAL) << "Unsupported DLPack because MXNet only supports compact tensors now";
+        }
+      }
+    }
+  }
   /*!
    * \brief constructor from tensor
    * \param src source tensor
@@ -336,6 +369,36 @@ class TBlob {
       }
     }
   }
+  static int DLDataTypeTransform(DLDataType dldata_type) {
+    if (dldata_type.lanes != 1) {
+      LOG(FATAL) << "Unsupported DLDataType whose lanes != 1";
+    }
+    switch (dldata_type.code) {
+      case kDLFloat:
+        switch (dldata_type.bits) {
+          case 16: return mshadow::kFloat16;
+          case 32: return mshadow::kFloat32;
+          case 64: return mshadow::kFloat64;
+        }
+        break;
+      case kDLUInt:
+        switch (dldata_type.bits) {
+          case 8: return mshadow::kUint8;
+        }
+        break;
+      case kDLInt:
+        switch (dldata_type.bits) {
+          case 8: return mshadow::kInt8;
+          case 32: return mshadow::kInt32;
+          case 64: return mshadow::kInt64;
+        }
+        break;
+    }
+    LOG(FATAL) << "Unknown DLDataType{" << dldata_type.code
+               << ", " << dldata_type.bits
+               << ", " << dldata_type.lanes << "}";
+    return mshadow::kFloat32;
+  }

   inline void SetDLTensor(int dev_mask, int dev_id) {
     dltensor_.data = dptr_;
@@ -343,7 +406,7 @@ class TBlob {
     dltensor_.ndim = shape_.ndim();
     dltensor_.dtype = DTypeTransform(type_flag_);
     dltensor_.shape = shape_.data();
-    dltensor_.strides = NULL;
+    dltensor_.strides = nullptr;
     dltensor_.byte_offset = 0;
   }
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 89e1c9e087b5..84b9e5831c69 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -232,6 +232,7 @@ def _load_lib():
 CudaModuleHandle = ctypes.c_void_p
 CudaKernelHandle = ctypes.c_void_p
 ProfileHandle = ctypes.c_void_p
+DLPackHandle = ctypes.c_void_p


 #----------------------------
@@ -726,3 +727,6 @@ def write_all_str(module_file, module_all_list):
         module_op_file.close()
     write_all_str(module_internal_file, module_internal_all)
     module_internal_file.close()
+
+ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
+ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
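For reference, the stride check in the new `TBlob(const DLTensor&)` constructor admits exactly the dense row-major (C-contiguous) layouts. Below is the same rule as a standalone Python sketch; `is_compact` is illustrative only and not part of the patch:

```python
def is_compact(shape, strides):
    """True iff `strides` describe a dense row-major (C-contiguous) layout."""
    if strides is None:        # DLPack convention: null strides mean compact
        return True
    ndim = len(shape)
    if ndim == 0:
        return True
    if strides[-1] != 1:       # innermost dimension must be unit-stride
        return False
    # each outer stride must equal the product of all inner extents
    return all(strides[i] == shape[i + 1] * strides[i + 1]
               for i in range(ndim - 2, -1, -1))

assert is_compact([2, 10], [10, 1])        # dense 2x10: accepted
assert not is_compact([2, 10], [1, 2])     # column-major: LOG(FATAL) in C++
```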
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index d6d619f30cab..fabf42e1dc63 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -34,8 +34,8 @@
 from functools import reduce # pylint: disable=redefined-builtin
 import numpy as np
 from ..base import _LIB, numeric_types, integer_types
-from ..base import c_array, c_array_buf, c_handle_array, mx_real_t
-from ..base import mx_uint, NDArrayHandle, check_call
+from ..base import c_str, c_array, c_array_buf, c_handle_array, mx_real_t
+from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle
 from ..base import ctypes2buffer
 from ..context import Context, current_context
 from . import _internal
@@ -46,7 +46,8 @@
            "ones", "add", "arange", "eye", "divide", "equal", "full", "greater", "greater_equal",
            "imdecode", "lesser", "lesser_equal", "logical_and", "logical_or", "logical_xor",
            "maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal", "onehot_encode",
-           "power", "subtract", "true_divide", "waitall", "_new_empty_handle", "histogram"]
+           "power", "subtract", "true_divide", "waitall", "_new_empty_handle", "histogram",
+           "to_dlpack_for_read", "to_dlpack_for_write", "from_dlpack"]

 _STORAGE_TYPE_UNDEFINED = -1
 _STORAGE_TYPE_DEFAULT = 0
@@ -178,7 +179,6 @@ class NDArray(NDArrayBase):
     # See C++ side of definition(kTVMNDArrayTypeCode) at include/mxmet/tensor_blob.h
     _tvm_tcode = 19
     # pylint: disable= no-member, undefined-variable
-
     @property
     def _tvm_handle(self):
         return self.handle.value
@@ -2205,6 +2205,52 @@ def tostype(self, stype):
         """
         return op.cast_storage(self, stype=stype)

+    def to_dlpack_for_read(self):
+        """Returns a reference view of this NDArray, represented as a
+        DLManagedTensor, once all previous write operations on the current
+        array have finished.
+
+        Returns
+        -------
+        PyCapsule (the pointer of DLManagedTensor)
+            a reference view of the NDArray, represented as a DLManagedTensor.
+
+        Examples
+        --------
+        >>> x = mx.nd.ones((2,3))
+        >>> y = mx.nd.to_dlpack_for_read(x)
+        >>> type(y)
+        <class 'PyCapsule'>
+        >>> z = mx.nd.from_dlpack(y)
+        >>> z
+        [[1. 1. 1.]
+         [1. 1. 1.]]
+        <NDArray 2x3 @cpu(0)>
+        """
+        return to_dlpack_for_read(self)
+
+    def to_dlpack_for_write(self):
+        """Returns a reference view of this NDArray, represented as a
+        DLManagedTensor, once all previous read/write operations on the
+        current array have finished.
+
+        Returns
+        -------
+        PyCapsule (the pointer of DLManagedTensor)
+            a reference view of the NDArray, represented as a DLManagedTensor.
+
+        Examples
+        --------
+        >>> x = mx.nd.ones((2,3))
+        >>> w = mx.nd.to_dlpack_for_write(x)
+        >>> type(w)
+        <class 'PyCapsule'>
+        >>> u = mx.nd.from_dlpack(w)
+        >>> u += 1
+        >>> x
+        [[2. 2. 2.]
+         [2. 2. 2.]]
+        <NDArray 2x3 @cpu(0)>
+        """
+        return to_dlpack_for_write(self)

 def _get_indexing_dispatch_code(key):
     """Returns a dispatch code for calling basic or advanced indexing functions."""
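A caveat worth spelling out about the two methods above: both flavors export a view of the same memory. The read/write distinction only controls which pending engine operations are waited for before the capsule is produced; nothing stops a consumer from writing through a `for_read` view. A hedged sketch of the hazard:

```python
# Sketch: both capsule flavors expose the SAME buffer, so a "for_read"
# capsule is read-only by convention only.
import mxnet as mx

x = mx.nd.zeros((2, 3))
view = mx.nd.from_dlpack(x.to_dlpack_for_read())
view += 1              # legal, but violates the read-only convention
view.wait_to_read()    # x and view have separate engine variables, so
print(x.asnumpy())     # synchronize via the view before reading the source
```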
2.]] + + """ + return to_dlpack_for_write(self) def _get_indexing_dispatch_code(key): """Returns a dispatch code for calling basic or advanced indexing functions.""" @@ -3851,3 +3897,128 @@ def histogram(a, bins=10, range=None): return _internal._histogram(data=a, bin_cnt=bins, range=range) raise ValueError("bins argument should be either an integer or an NDArray") # pylint: enable= no-member, protected-access, redefined-builtin + +PyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p) +_c_str_dltensor = c_str('dltensor') +_c_str_used_dltensor = c_str('used_dltensor') + +def _dlpack_deleter(pycapsule): + pycapsule = ctypes.c_void_p(pycapsule) + if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor): + ptr = ctypes.c_void_p( + ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor)) + check_call(_LIB.MXNDArrayCallDLPackDeleter(ptr)) + +_c_dlpack_deleter = PyCapsuleDestructor(_dlpack_deleter) + +def to_dlpack_for_read(data): + """Returns a reference view of NDArray that represents as DLManagedTensor until + all previous write operations on the current array are finished. + + Parameters + ---------- + data: NDArray + input data. + + Returns + ------- + PyCapsule (the pointer of DLManagedTensor) + a reference view of NDArray that represents as DLManagedTensor. + + Examples + -------- + >>> x = mx.nd.ones((2,3)) + >>> y = mx.nd.to_dlpack_for_read(x) + >>> type(y) + + >>> z = mx.nd.from_dlpack(y) + >>> z + [[1. 1. 1.] + [1. 1. 1.]] + + """ + data.wait_to_read() + dlpack = DLPackHandle() + check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack))) + return ctypes.pythonapi.PyCapsule_New(dlpack, _c_str_dltensor, _c_dlpack_deleter) + +def to_dlpack_for_write(data): + """Returns a reference view of NDArray that represents as DLManagedTensor until + all previous read/write operations on the current array are finished. + + Parameters + ---------- + data: NDArray + input data. + + Returns + ------- + PyCapsule (the pointer of DLManagedTensor) + a reference view of NDArray that represents as DLManagedTensor. + + Examples + -------- + >>> x = mx.nd.ones((2,3)) + >>> w = mx.nd.to_dlpack_for_write(x) + >>> type(w) + + >>> u = mx.nd.from_dlpack(w) + >>> u += 1 + >>> x + [[2. 2. 2.] + [2. 2. 2.]] + + """ + check_call(_LIB.MXNDArrayWaitToWrite(data.handle)) + dlpack = DLPackHandle() + check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack))) + return ctypes.pythonapi.PyCapsule_New(dlpack, _c_str_dltensor, _c_dlpack_deleter) + +def from_dlpack(dlpack): + """Returns a NDArray backed by a dlpack tensor. + + Parameters + ---------- + dlpack: PyCapsule (the pointer of DLManagedTensor) + input data + + Returns + ------- + NDArray + a NDArray backed by a dlpack tensor + + Examples + -------- + >>> x = mx.nd.ones((2,3)) + >>> y = mx.nd.to_dlpack_for_read(x) + >>> type(y) + + >>> z = mx.nd.from_dlpack(y) + >>> type(z) + + >>> z + [[ 1. 1. 1.] + [ 1. 1. 1.]] + + + >>> w = mx.nd.to_dlpack_for_write(x) + >>> type(w) + + >>> u = mx.nd.from_dlpack(w) + >>> u += 1 + >>> x + [[2. 2. 2.] + [2. 2. 2.]] + + """ + handle = NDArrayHandle() + dlpack = ctypes.py_object(dlpack) + assert ctypes.pythonapi.PyCapsule_IsValid(dlpack, _c_str_dltensor), ValueError( + 'Invalid DLPack Tensor. 
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 1ef3f0fca9f3..56e318097a3c 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -500,6 +500,31 @@ int MXNDArrayGetData(NDArrayHandle handle,
   API_END();
 }

+int MXNDArrayToDLPack(NDArrayHandle handle,
+                      DLManagedTensorHandle *out_dlpack) {
+  API_BEGIN();
+  NDArray *arr = static_cast<NDArray*>(handle);
+  *out_dlpack = arr->ToDLPack();
+  API_END();
+}
+
+int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack,
+                        NDArrayHandle *out_handle) {
+  API_BEGIN();
+  *out_handle = new NDArray(NDArray::FromDLPack(
+      static_cast<DLManagedTensor*>(dlpack)));
+  API_END();
+}
+
+int MXNDArrayCallDLPackDeleter(DLManagedTensorHandle dlpack) {
+  API_BEGIN();
+  if (dlpack != nullptr) {
+    DLManagedTensor *p_dlpack = static_cast<DLManagedTensor*>(dlpack);
+    p_dlpack->deleter(p_dlpack);
+  }
+  API_END();
+}
+
 int MXNDArrayGetDType(NDArrayHandle handle,
                       int *out_dtype) {
   API_BEGIN();
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 853838a87f4c..5bcb1c2bf485 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -312,6 +312,34 @@ NDArray NDArray::data_ndarray() const {
   return ret;
 }

+struct NDArrayDLManager {
+  NDArray handle;  // ref NDArray
+  DLManagedTensor tensor;
+};
+
+DLManagedTensor* NDArray::ToDLPack() const {
+  NDArrayDLManager* dlmanager(new NDArrayDLManager);
+  dlmanager->handle = *this;
+  if (!is_none()) {
+    dlmanager->tensor.dl_tensor = data().dltensor();
+  }
+  dlmanager->tensor.manager_ctx = dlmanager;
+  dlmanager->tensor.deleter = [](DLManagedTensor* dlmanager) {
+    delete static_cast<NDArrayDLManager*>(dlmanager->manager_ctx);
+  };
+  return &(dlmanager->tensor);
+}
+
+NDArray NDArray::FromDLPack(const DLManagedTensor* tensor) {
+  const DLTensor &dl_tensor = tensor->dl_tensor;
+  auto deleter = [tensor]() {
+    if (tensor->deleter != nullptr) {
+      tensor->deleter(const_cast<DLManagedTensor*>(tensor));
+    }
+  };
+  return NDArray(TBlob(dl_tensor), dl_tensor.ctx.device_id, deleter);
+}
+
 bool NDArray::fresh_out_grad() const {
   if (Imperative::AGInfo::IsNone(*this)) return false;
   Imperative::AGInfo& info = Imperative::AGInfo::Get(entry_.node);
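The ownership chain above is: `ToDLPack` heap-allocates an `NDArrayDLManager` whose `handle` member keeps the source NDArray (and thus its storage chunk) alive, and `FromDLPack` wraps the producer's `deleter` into the new NDArray's custom deleter. The practical consequence, sketched below under the same assumptions, is that the exported buffer outlives the original Python-side object; the unit test's `del a, pack, ...` exercises exactly this path:

```python
# Sketch: the buffer stays valid after the source NDArray and the capsule
# are gone, because the C++ manager object holds a reference.
import mxnet as mx

a = mx.nd.random.uniform(shape=(2, 10))
expected = a.asnumpy()
b = mx.nd.from_dlpack(a.to_dlpack_for_read())
del a                                        # drop the Python-side source
mx.test_utils.assert_almost_equal(expected, b.asnumpy())
```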
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index c48801ec1cec..e5fc19b190a8 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -1439,6 +1439,37 @@ def test_ndarray_cpu_shared_ctx():
     res = mx.nd.zeros((1, 2, 3), ctx=ctx)
     assert(res.context == ctx)

+@with_seed()
+def test_dlpack():
+    for dtype in [np.float32, np.int32]:
+        for shape in [(3, 4, 5, 6), (2, 10), (15,)]:
+            a = mx.nd.random.uniform(shape=shape).astype(dtype)
+            a_np = a.asnumpy()
+
+            pack = a.to_dlpack_for_read()
+            b = mx.nd.from_dlpack(pack)
+
+            a_copy = a.copy()
+            pack2 = a_copy.to_dlpack_for_write()
+            c = mx.nd.from_dlpack(pack2)
+
+            pack3 = mx.nd.to_dlpack_for_read(a)
+            d = mx.nd.from_dlpack(pack3)
+
+            a_copy = a.copy()
+            pack4 = mx.nd.to_dlpack_for_write(a_copy)
+            e = mx.nd.from_dlpack(pack4)
+
+            del a, pack, pack2, pack3, pack4
+
+            b_np = b.asnumpy()
+            c_np = c.asnumpy()
+            d_np = d.asnumpy()
+            e_np = e.asnumpy()
+            mx.test_utils.assert_almost_equal(a_np, b_np)
+            mx.test_utils.assert_almost_equal(a_np, c_np)
+            mx.test_utils.assert_almost_equal(a_np, d_np)
+            mx.test_utils.assert_almost_equal(a_np, e_np)

 if __name__ == '__main__':
     import nose