From 103e74dcac2488e82bafcf27dd09c001260c230d Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 20 May 2019 23:35:37 -0700 Subject: [PATCH 1/3] Fix --- include/mxnet/c_api.h | 2 ++ include/mxnet/ndarray.h | 2 +- python/mxnet/ndarray/ndarray.py | 6 ++---- src/c_api/c_api.cc | 4 +++- src/ndarray/ndarray.cc | 12 +++++++----- 5 files changed, 15 insertions(+), 11 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index f79f224029b2..568f79c31586 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -818,10 +818,12 @@ MXNET_DLL int MXNDArrayToDLPack(NDArrayHandle handle, * The memory is retained until the NDArray went out of scope. * * \param dlpack the pointer of the input DLManagedTensor +* \param transient_handle whether the handle will be destructed before calling the deleter * \param out_handle pointer holder to get pointer of NDArray * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack, + const bool transient_handle, NDArrayHandle *out_handle); /*! diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 340c38005493..e694573ed8eb 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -587,7 +587,7 @@ class NDArray { * * \return The created NDArray view. */ - static NDArray FromDLPack(const DLManagedTensor* tensor); + static NDArray FromDLPack(const DLManagedTensor* tensor, bool transient_handle); /*! * \brief Update ndarray chunk storage handles using existing ndarray storage handles diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 1c182731c78e..759c0aadcfe4 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -4156,7 +4156,7 @@ def from_dlpack(dlpack): assert ctypes.pythonapi.PyCapsule_IsValid(dlpack, _c_str_dltensor), ValueError( 'Invalid DLPack Tensor. DLTensor capsules can be consumed only once.') dlpack_handle = ctypes.c_void_p(ctypes.pythonapi.PyCapsule_GetPointer(dlpack, _c_str_dltensor)) - check_call(_LIB.MXNDArrayFromDLPack(dlpack_handle, ctypes.byref(handle))) + check_call(_LIB.MXNDArrayFromDLPack(dlpack_handle, False, ctypes.byref(handle))) # Rename PyCapsule (DLPack) ctypes.pythonapi.PyCapsule_SetName(dlpack, _c_str_used_dltensor) # delete the deleter of the old dlpack @@ -4262,8 +4262,6 @@ def _make_dl_managed_tensor(array): if not ndarray.flags['C_CONTIGUOUS']: raise ValueError("Only c-contiguous arrays are supported for zero-copy") c_obj = _make_dl_managed_tensor(ndarray) - address = ctypes.addressof(c_obj) - address = ctypes.cast(address, ctypes.c_void_p) handle = NDArrayHandle() - check_call(_LIB.MXNDArrayFromDLPack(address, ctypes.byref(handle))) + check_call(_LIB.MXNDArrayFromDLPack(ctypes.byref(c_obj), True, ctypes.byref(handle))) return NDArray(handle=handle) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index f549ddd13994..536c53537038 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -562,10 +562,12 @@ int MXNDArrayToDLPack(NDArrayHandle handle, } int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack, + const bool transient_handle, NDArrayHandle *out_handle) { API_BEGIN(); *out_handle = new NDArray(NDArray::FromDLPack( - static_cast(dlpack))); + static_cast(dlpack), + transient_handle)); API_END(); } diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 0bfca8c10a1a..ad1554ae597f 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -355,14 +355,16 @@ DLManagedTensor* NDArray::ToDLPack() const { return &(dlmanager->tensor); } -NDArray NDArray::FromDLPack(const DLManagedTensor* tensor) { - DLManagedTensor tensor_copy = *tensor; +NDArray NDArray::FromDLPack(const DLManagedTensor* tensor, bool transient_handle) { + DLManagedTensor *tensor_copy = transient_handle + ? new DLManagedTensor(*tensor) + : const_cast(tensor); auto deleter = [tensor_copy](){ - if (tensor_copy.deleter != nullptr) { - tensor_copy.deleter(const_cast(&tensor_copy)); + if (tensor_copy->deleter != nullptr) { + tensor_copy->deleter(tensor_copy); } }; - return NDArray(TBlob(tensor_copy.dl_tensor), tensor_copy.dl_tensor.ctx.device_id, deleter); + return NDArray(TBlob(tensor_copy->dl_tensor), tensor_copy->dl_tensor.ctx.device_id, deleter); } bool NDArray::fresh_out_grad() const { From 989bcc66cf791cd598db0ebcfdc8446da1bd6718 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 21 May 2019 00:10:06 -0700 Subject: [PATCH 2/3] Fix --- src/ndarray/ndarray.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index ad1554ae597f..60de62dd32eb 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -359,10 +359,13 @@ NDArray NDArray::FromDLPack(const DLManagedTensor* tensor, bool transient_handle DLManagedTensor *tensor_copy = transient_handle ? new DLManagedTensor(*tensor) : const_cast(tensor); - auto deleter = [tensor_copy](){ + auto deleter = [tensor_copy, transient_handle](){ if (tensor_copy->deleter != nullptr) { tensor_copy->deleter(tensor_copy); } + if (transient_handle) { + delete tensor_copy; + } }; return NDArray(TBlob(tensor_copy->dl_tensor), tensor_copy->dl_tensor.ctx.device_id, deleter); } From 414cca893017acb93237df790288ca9f0a0573dd Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 21 May 2019 13:04:03 -0700 Subject: [PATCH 3/3] Retrigger