[MXNET-779] Add DLPack Transformation API #12047
Changes from 15 commits
@@ -93,6 +93,10 @@ typedef void *CudaModuleHandle;
typedef void *CudaKernelHandle;
/*! \brief handle to a Profile object (domain, duration, counter, etc.) */
typedef void *ProfileHandle;
/*! \brief handle to a DLManagedTensor */
typedef void *DLManagedTensorHandle;
/*! \brief handle to a PyObject */
typedef void *PyObjectHandle;

typedef void (*ExecutorMonitorCallback)(const char*,
                                        NDArrayHandle,
@@ -737,6 +741,57 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
 */
MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle,
                               void **out_pdata);
/*!
 * \brief Create a reference view of an NDArray, represented as a
 *        DLManagedTensor, that stays valid until all pending writes
 *        on the NDArray are finished.
 * \param handle the handle to the ndarray
 * \param out_dlpack pointer holder to get the pointer of the DLManagedTensor
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXNDArrayToDLPackForRead(NDArrayHandle handle,
                                       DLManagedTensorHandle *out_dlpack);

/*!
 * \brief Create a reference view of an NDArray, represented as a
 *        DLManagedTensor, that stays valid until all pending reads and
 *        writes on the NDArray are finished.
 * \param handle the handle to the ndarray
 * \param out_dlpack pointer holder to get the pointer of the DLManagedTensor
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXNDArrayToDLPackForWrite(NDArrayHandle handle,
                                        DLManagedTensorHandle *out_dlpack);

/*!
 * \brief Create an NDArray backed by a DLPack tensor.
 *
 * This allows us to create an NDArray using memory allocated by an
 * external, DLPack-compatible deep learning framework.
 *
 * The memory is retained until the NDArray goes out of scope.
 *
 * \param dlpack the pointer of the input DLManagedTensor
 * \param out_handle pointer holder to get the pointer of the NDArray
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack,
                                  NDArrayHandle *out_handle);
/*!
 * \brief Delete a DLPack tensor
 * \param dlpack the pointer of the input DLManagedTensor
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXNDArrayCallDLPackDeleter(DLManagedTensorHandle dlpack);

/*!
 * \brief Delete a DLPack tensor stored in a PyCapsule
 * \param dlpack_capsule the pointer of a PyCapsule storing a DLManagedTensor
 */
MXNET_DLL void MXNDArrayCallDLPackCapsuleDeleter(PyObjectHandle dlpack_capsule);
Reviewer: I am less certain why we need the deleter function here; can it be handled directly on the python/cython side?

Author: I tried to implement a deleter function in Python, but the deleter may be released by the Python GC before it is called (see the test code below), which raises a segmentation fault.

Reviewer: Take a look at the destructor in apache/tvm#1573. There is some subtlety here, but it can nevertheless be implemented.

Author: I tried to write a Python function as the destructor, but it couldn't pass CI. PyTorch implements the destructor with the Python C API in C++, and CuPy implements it in Cython, so in both cases the code is built as C++.

Author: Yes, I knew that trick and tried it in my previous PR, but it failed the Windows test. TVM's CI apparently runs no Windows test, so its CI passes. On Linux the destructor is called first and only then released, so it works there.

Reviewer: This is strange, as the destructor itself sits in the global scope and should be destructed after the dltensors (which have a local scope).

Author: Yes.

Reviewer: I see two problems in the particular gist you pasted.
```python
import ctypes

cfunc = ctypes.CFUNCTYPE(None, ctypes.c_void_p)

def dfunc(dltensor):
    pycaps = ctypes.cast(dltensor, ctypes.py_object)

c_destructor = cfunc(dfunc)
c_str_dltensor = ctypes.c_char_p(b"dltensor")

def test():
    a = ctypes.pythonapi.PyCapsule_New(1, c_str_dltensor, c_destructor)

test()
```

Author: Thank you!
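For reference, a variant of the gist that avoids the segfault is sketched below, under two assumptions the thread hints at: the `ctypes` prototype for `PyCapsule_New` must be declared (its default `int` return type mishandles the returned `PyObject*` on 64-bit builds), and the destructor callback object must stay referenced at module scope so the GC cannot collect its trampoline before CPython invokes it. The helper name `make_capsule` is an illustration, not part of the PR.

```python
import ctypes

# Declare the prototype so the returned PyObject* is handled correctly;
# the default restype is a truncating C int.
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
ctypes.pythonapi.PyCapsule_New.argtypes = [
    ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]

deleted = []

# Kept at module scope: if this callback object were garbage-collected,
# CPython would later jump through a freed trampoline and segfault.
@ctypes.CFUNCTYPE(None, ctypes.c_void_p)
def _capsule_destructor(capsule_ptr):
    deleted.append(capsule_ptr)

_c_str_dltensor = ctypes.c_char_p(b"dltensor")

def make_capsule(raw_ptr):
    """Wrap a non-NULL raw pointer in a PyCapsule named 'dltensor'."""
    return ctypes.pythonapi.PyCapsule_New(
        ctypes.c_void_p(raw_ptr), _c_str_dltensor,
        ctypes.cast(_capsule_destructor, ctypes.c_void_p))
```

Dropping the returned capsule then runs the destructor exactly once, which is the behaviour `MXNDArrayCallDLPackCapsuleDeleter` otherwise provides from the C side.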
/*!
 * \brief get the type of the data in NDArray
 * \param handle the handle to the ndarray
@@ -104,6 +104,14 @@ class TBlob {
    : dptr_(dptr), shape_(shape), type_flag_(type_flag) {
  SetDLTensor(dev_mask, dev_id);
}
/*!
 * \brief constructor that constructs a TBlob from a DLTensor
 * \param dltensor the source DLTensor
 */
explicit TBlob(const DLTensor &dltensor)
    : dptr_(dltensor.data),
      shape_(TShape(dltensor.shape, dltensor.shape + dltensor.ndim)),
      type_flag_(DLDataTypeTransform(dltensor.dtype)), dltensor_(dltensor) {
}

Reviewer: Need to add a compactness check. Specifically, TBlob only supports compact tensors, so we need to check that strides == null or that the strides reflect a compact layout.

Author: Thanks. I will move the strides check from ndarray.cpp to tensor_blob.h.
/*!
 * \brief constructor from tensor
 * \param src source tensor
@@ -336,14 +344,51 @@ class TBlob {
      }
    }
  }

  static int DLDataTypeTransform(DLDataType dldata_type) {
    if (dldata_type.lanes != 1) {
      LOG(FATAL) << "Unsupported DLDataType whose lanes != 1";
    }
    switch (dldata_type.code) {
      case kDLFloat:
        switch (dldata_type.bits) {
          case 16:
            return mshadow::kFloat16;
          case 32:
            return mshadow::kFloat32;
          case 64:
            return mshadow::kFloat64;
        }
        break;
      case kDLUInt:
        switch (dldata_type.bits) {
          case 8:
            return mshadow::kUint8;
        }
        break;
      case kDLInt:
        switch (dldata_type.bits) {
          case 8:
            return mshadow::kInt8;
          case 32:
            return mshadow::kInt32;
          case 64:
            return mshadow::kInt64;
        }
        break;
    }
    LOG(FATAL) << "Unknown DLDataType{" << dldata_type.code
               << ", " << dldata_type.bits
               << ", " << dldata_type.lanes << "}";
    return mshadow::kFloat32;
  }
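For illustration, the mapping implemented by `DLDataTypeTransform` above can be mirrored in Python; the constant values come from dlpack.h, and the function name `dl_dtype_to_mx` is hypothetical.

```python
# DLDataTypeCode values as defined in dlpack.h.
kDLInt, kDLUInt, kDLFloat = 0, 1, 2

# (code, bits) -> MXNet dtype name, matching the C++ switch above.
_DL_TO_MX = {
    (kDLFloat, 16): "float16",
    (kDLFloat, 32): "float32",
    (kDLFloat, 64): "float64",
    (kDLUInt, 8): "uint8",
    (kDLInt, 8): "int8",
    (kDLInt, 32): "int32",
    (kDLInt, 64): "int64",
}

def dl_dtype_to_mx(code, bits, lanes=1):
    """Map a DLDataType triple to an MXNet dtype name, rejecting vector types."""
    if lanes != 1:
        raise ValueError("Unsupported DLDataType whose lanes != 1")
    try:
        return _DL_TO_MX[(code, bits)]
    except KeyError:
        raise ValueError("Unknown DLDataType{%d, %d, %d}" % (code, bits, lanes))
```

Note that the C++ version cannot raise, so after `LOG(FATAL)` it returns `mshadow::kFloat32` purely to satisfy the compiler; the Python mirror raises instead.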
  inline void SetDLTensor(int dev_mask, int dev_id) {
    dltensor_.data = dptr_;
    dltensor_.ctx = DLContext{static_cast<DLDeviceType>(dev_mask), dev_id};
    dltensor_.ndim = shape_.ndim();
    dltensor_.dtype = DTypeTransform(type_flag_);
    dltensor_.shape = shape_.data();
-   dltensor_.strides = NULL;
+   dltensor_.strides = nullptr;
    dltensor_.byte_offset = 0;
  }
@@ -31,21 +31,24 @@

class NDArrayBase(object):
    """Base data structure for ndarray"""
-   __slots__ = ["handle", "writable"]
+   __slots__ = ["handle", "writable", "dlpack"]
    # pylint: disable= no-member

-   def __init__(self, handle, writable=True):
+   def __init__(self, handle, writable=True, dlpack=None):
Reviewer: dlpack should not be a member; the PyCapsule manages itself.

Author: The dlpack member in NDArray is the PyCapsule returned by the conversion call. NDArray doesn't have the deleter function, so I made dlpack a member of NDArray.

Reviewer: A better way is to keep the NDArray's shared_ptr inside the manager_ctx itself; take a look at TVM's NDArray-to-DLManagedTensor implementation.

Author: NDArray in MXNet and TVM are different. NDArray in TVM has the function […] Setting […]

Reviewer: You can create a new NDArray() that copies the original NDArray (which increases the refcount) and put that as the context.

Reviewer: In your case, when a gets deleted, b still holds an NDArrayDLManager, which is allocated by new, and that object still holds an NDArray (which holds a shared_ptr), so the original resource won't be released.

Reviewer: You need to be careful to use the shape from the same NDArray in your NDArrayDLManager.

Author: Thanks. In the other case:

```python
from torch.utils import dlpack
a = torch.tensor([1, 2, 3])
pack = dlpack.to_dlpack(a)
b = mx.nd.from_dlpack(pack)
del a, pack
```

When […] In my PR, […]

Reviewer: If you copy the NDArray, they hold the same shared_ptr to the data; note that shared_ptr can be copied, and its ref counter is automatically managed.

Author: I made a copy of the NDArray as the member of NDArrayDLManager, and the copy increases the refcount. Which object will call the deleter function? In my case, when […]
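The manager-context pattern the reviewer describes can be sketched as follows. `DLManagedTensorSketch` and the `std::vector` stand-in are simplifications (the real code uses dlpack's `DLManagedTensor` and `mxnet::NDArray`), but the ownership mechanics are the same: the heap-allocated manager holds a copy of the array, so the shared buffer survives until the consumer calls the deleter.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Stand-in for dlpack's DLManagedTensor (shape/dtype fields omitted).
struct DLManagedTensorSketch {
  void* data;
  void* manager_ctx;
  void (*deleter)(DLManagedTensorSketch*);
};

// Name taken from the PR discussion: the manager owns a copy of the
// producing array, keeping its refcount above zero.
struct NDArrayDLManager {
  std::shared_ptr<std::vector<float>> array;
  DLManagedTensorSketch tensor;
};

DLManagedTensorSketch* ToDLPack(const std::shared_ptr<std::vector<float>>& arr) {
  auto* mgr = new NDArrayDLManager{arr, {}};   // copy bumps the refcount
  mgr->tensor.data = mgr->array->data();       // shape/data from the SAME copy
  mgr->tensor.manager_ctx = mgr;
  mgr->tensor.deleter = [](DLManagedTensorSketch* self) {
    // Deleting the manager drops the shared_ptr; the buffer is freed
    // only once every other owner has also released it.
    delete static_cast<NDArrayDLManager*>(self->manager_ctx);
  };
  return &mgr->tensor;
}
```

This answers the "who calls the deleter" question in the thread: the consumer framework calls `deleter` when it is done, and the original producer's handle can be dropped at any time without invalidating the data.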
        """initialize a new NDArray

        Parameters
        ----------
        handle : NDArrayHandle
            NDArray handle of C API
        dlpack : PyCapsule (DLPack)
            DLPack Object
        """
        if handle is not None:
            assert isinstance(handle, NDArrayHandle)
        self.handle = handle
        self.writable = writable
        self.dlpack = dlpack

    def __del__(self):
        check_call(_LIB.MXNDArrayFree(self.handle))
Reviewer: From the API point of view, we can just expose ToDLPack, and in the Python API explicitly call wait_for_read and wait_for_write.

Author: I will modify it.
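The reviewer's suggestion can be sketched as a thin Python wrapper. `FakeNDArray` and its method names (`wait_to_read`, `_to_dlpack`) are stand-ins invented for this sketch, shown only to make the call order explicit: synchronize first on the Python side, then call a single conversion entry point.

```python
class FakeNDArray:
    """Stand-in for mx.nd.NDArray; the method names are assumptions."""

    def __init__(self):
        self.calls = []

    def wait_to_read(self):
        # Block until pending writes on this array have finished.
        self.calls.append("wait_to_read")

    def _to_dlpack(self):
        # Single conversion entry point (the reviewer's proposed ToDLPack).
        self.calls.append("to_dlpack")
        return object()


def to_dlpack_for_read(arr):
    """Synchronize on the Python side, then hand out the DLPack capsule."""
    arr.wait_to_read()
    return arr._to_dlpack()
```

With this split, the C API needs only one exported conversion function instead of the read/write pair declared in the header diff above.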