diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index f79f224029b2..568f79c31586 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -818,10 +818,12 @@ MXNET_DLL int MXNDArrayToDLPack(NDArrayHandle handle, * The memory is retained until the NDArray went out of scope. * * \param dlpack the pointer of the input DLManagedTensor +* \param transient_handle whether the handle will be destructed before calling the deleter * \param out_handle pointer holder to get pointer of NDArray * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack, + const bool transient_handle, NDArrayHandle *out_handle); /*! diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 340c38005493..e694573ed8eb 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -587,7 +587,7 @@ class NDArray { * * \return The created NDArray view. */ - static NDArray FromDLPack(const DLManagedTensor* tensor); + static NDArray FromDLPack(const DLManagedTensor* tensor, bool transient_handle); /*! * \brief Update ndarray chunk storage handles using existing ndarray storage handles diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 1c182731c78e..759c0aadcfe4 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -4156,7 +4156,7 @@ def from_dlpack(dlpack): assert ctypes.pythonapi.PyCapsule_IsValid(dlpack, _c_str_dltensor), ValueError( 'Invalid DLPack Tensor. 
DLTensor capsules can be consumed only once.') dlpack_handle = ctypes.c_void_p(ctypes.pythonapi.PyCapsule_GetPointer(dlpack, _c_str_dltensor)) - check_call(_LIB.MXNDArrayFromDLPack(dlpack_handle, ctypes.byref(handle))) + check_call(_LIB.MXNDArrayFromDLPack(dlpack_handle, False, ctypes.byref(handle))) # Rename PyCapsule (DLPack) ctypes.pythonapi.PyCapsule_SetName(dlpack, _c_str_used_dltensor) # delete the deleter of the old dlpack @@ -4262,8 +4262,6 @@ def _make_dl_managed_tensor(array): if not ndarray.flags['C_CONTIGUOUS']: raise ValueError("Only c-contiguous arrays are supported for zero-copy") c_obj = _make_dl_managed_tensor(ndarray) - address = ctypes.addressof(c_obj) - address = ctypes.cast(address, ctypes.c_void_p) handle = NDArrayHandle() - check_call(_LIB.MXNDArrayFromDLPack(address, ctypes.byref(handle))) + check_call(_LIB.MXNDArrayFromDLPack(ctypes.byref(c_obj), True, ctypes.byref(handle))) return NDArray(handle=handle) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index f549ddd13994..536c53537038 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -562,10 +562,12 @@ int MXNDArrayToDLPack(NDArrayHandle handle, } int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack, + const bool transient_handle, NDArrayHandle *out_handle) { API_BEGIN(); *out_handle = new NDArray(NDArray::FromDLPack( - static_cast<DLManagedTensor*>(dlpack))); + static_cast<DLManagedTensor*>(dlpack), + transient_handle)); API_END(); } diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 0bfca8c10a1a..60de62dd32eb 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -355,14 +355,19 @@ DLManagedTensor* NDArray::ToDLPack() const { return &(dlmanager->tensor); } -NDArray NDArray::FromDLPack(const DLManagedTensor* tensor) { - DLManagedTensor tensor_copy = *tensor; - auto deleter = [tensor_copy](){ - if (tensor_copy.deleter != nullptr) { - tensor_copy.deleter(const_cast<DLManagedTensor*>(&tensor_copy)); +NDArray NDArray::FromDLPack(const DLManagedTensor* tensor, bool transient_handle) { + DLManagedTensor 
*tensor_copy = transient_handle + ? new DLManagedTensor(*tensor) + : const_cast<DLManagedTensor*>(tensor); + auto deleter = [tensor_copy, transient_handle](){ + if (tensor_copy->deleter != nullptr) { + tensor_copy->deleter(tensor_copy); + } + if (transient_handle) { + delete tensor_copy; } }; - return NDArray(TBlob(tensor_copy.dl_tensor), tensor_copy.dl_tensor.ctx.device_id, deleter); + return NDArray(TBlob(tensor_copy->dl_tensor), tensor_copy->dl_tensor.ctx.device_id, deleter); } bool NDArray::fresh_out_grad() const {