Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DLPack Conversion API #1573

Merged
merged 9 commits into from
Aug 10, 2018
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,32 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVMArrayHandle to,
TVMStreamHandle stream);

/*!
* \brief Produce an array from the DLManagedTensor that shares data memory
* with the DLManagedTensor.
* \param from The source DLManagedTensor.
* \param out The output array handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayFromDLPack(DLManagedTensor* from,
TVMArrayHandle* out);

/*!
 * \brief Produce a DLManagedTensor from the array that shares data memory with
* the array.
* \param from The source array.
* \param out The DLManagedTensor handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayToDLPack(TVMArrayHandle from,
DLManagedTensor** out);

/*!
* \brief Delete (free) a DLManagedTensor's data.
* \param dltensor Pointer to the DLManagedTensor.
*/
TVM_DLL void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor);

/*!
* \brief Create a new runtime stream.
*
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class NDArray {
* that is DLPack compatible.
*
 * The memory is retained until the NDArray goes out of scope.
*
* \param tensor The DLPack tensor to copy from.
* \return The created NDArray view.
*/
TVM_DLL static NDArray FromDLPack(DLManagedTensor* tensor);
Expand Down
62 changes: 61 additions & 1 deletion python/tvm/_ffi/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
import ctypes
import numpy as np
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str
from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle
from .runtime_ctypes import TypeCode, tvm_shape_index_t

Expand All @@ -28,6 +28,17 @@
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase


# ctypes prototype for a PyCapsule destructor callback: void (*)(void*).
TVMPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
# Capsule name used by the DLPack exchange protocol.
_c_str_dltensor = c_str('dltensor')


# used for PyCapsule manipulation
# NOTE(review): ctypes.pythonapi may be missing on non-CPython interpreters,
# hence the hasattr guard; DLPack conversion presumably requires CPython.
if hasattr(ctypes, 'pythonapi'):
    ctypes.pythonapi.PyCapsule_GetName.restype = ctypes.c_char_p
    ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
    ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object


def context(dev_type, dev_id=0):
"""Construct a TVM context with given device type and id.

Expand Down Expand Up @@ -62,6 +73,7 @@ def context(dev_type, dev_id=0):
dev_type = TVMContext.STR2MASK[dev_type]
return TVMContext(dev_type, dev_id)


def numpyasarray(np_data):
"""Return a TVMArray representation of a numpy array.
"""
Expand Down Expand Up @@ -112,6 +124,42 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
ctypes.byref(handle)))
return _make_array(handle, False)


def from_dlpack(dltensor):
    """Produce an array from a DLPack tensor without memory copy.

    Retrieves the underlying DLPack tensor's pointer to create an array
    from the data. Removes the original DLPack tensor's destructor as now
    the array is responsible for destruction.

    Parameters
    ----------
    dltensor : DLPack tensor
        A PyCapsule wrapping a DLManagedTensor (capsule name 'dltensor').

    Returns
    -------
    arr: tvm.nd.NDArray
        The array view of the tensor data.
    """
    dltensor = ctypes.py_object(dltensor)
    name = ctypes.pythonapi.PyCapsule_GetName(dltensor)
    ptr = ctypes.pythonapi.PyCapsule_GetPointer(dltensor, name)
    handle = TVMArrayHandle()
    check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle)))
    # The NDArray now owns the memory; clear the capsule's destructor so the
    # producer framework does not free it a second time.
    ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, None)
    return _make_array(handle, False)


def _dlpack_deleter(pycapsule):
    """PyCapsule destructor: run the DLManagedTensor's own deleter.

    Invoked by Python when a 'dltensor' capsule produced by to_dlpack is
    garbage collected without having been consumed by another framework.
    """
    pycapsule = ctypes.py_object(pycapsule)
    if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor):
        ptr = ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor)
        _LIB.TVMDLManagedTensorCallDeleter(ptr)
        # BUG FIX: original referenced the undefined name `dltensor` here,
        # raising NameError whenever the destructor fired. Clear the
        # destructor on `pycapsule` (only when valid) so it cannot run twice.
        ctypes.pythonapi.PyCapsule_SetDestructor(pycapsule, TVMPyCapsuleDestructor(0))


# Module-level reference keeps the ctypes callback alive for the C side.
_c_dlpack_deleter = TVMPyCapsuleDestructor(_dlpack_deleter)


class NDArrayBase(_NDArrayBase):
"""A simple Device/CPU Array object in runtime."""
@property
Expand Down Expand Up @@ -260,6 +308,18 @@ def copyto(self, target):
raise ValueError("Unsupported target type %s" % str(type(target)))
return target

def to_dlpack(self):
    """Produce a DLPack tensor from the array without copying memory.

    The capsule's destructor is set so an unconsumed tensor is freed.

    Returns
    -------
    dlpack : PyCapsule
        DLPack tensor view of the array data.
    """
    handle = ctypes.c_void_p()
    check_call(_LIB.TVMArrayToDLPack(self.handle, ctypes.byref(handle)))
    return ctypes.pythonapi.PyCapsule_New(handle, _c_str_dltensor, _c_dlpack_deleter)


def free_extension_handle(handle, type_code):
"""Free c++ extension type handle

Expand Down
43 changes: 43 additions & 0 deletions python/tvm/contrib/dlpack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Wrapping functions to bridge frameworks with DLPack support to TVM"""
from .. import ndarray

def convert_func(tvm_func, tensor_type, to_dlpack_func):
    """Convert a tvm function into one that accepts a tensor from another
    framework, provided the other framework supports DLPACK.

    Parameters
    ----------
    tvm_func: Function
        Built tvm function operating on arrays

    tensor_type: Type
        Type of the tensors of the target framework

    to_dlpack_func: Function
        Function to convert the source tensors to DLPACK

    Returns
    -------
    wrapped_func: Function
        Function that converts tensor_type arguments via DLPack before
        calling tvm_func; other arguments are passed through unchanged.

    Raises
    ------
    TypeError
        If tvm_func is not callable.
    """
    # Raise a real error instead of `assert`, which is stripped under -O.
    if not callable(tvm_func):
        raise TypeError("tvm_func must be callable, got %s" % type(tvm_func))

    def _wrapper(*args):
        # Convert only the framework's tensors; leave other args untouched.
        args = tuple(ndarray.from_dlpack(to_dlpack_func(arg))
                     if isinstance(arg, tensor_type) else arg
                     for arg in args)
        return tvm_func(*args)

    return _wrapper

def to_pytorch_func(tvm_func):
    """Convert a tvm function into one that accepts PyTorch tensors.

    Parameters
    ----------
    tvm_func: Function
        Built tvm function operating on arrays

    Returns
    -------
    wrapped_func: Function
        Wrapped tvm function that operates on PyTorch tensors

    Raises
    ------
    ImportError
        If PyTorch is not installed.
    """
    # Import lazily so the module works without PyTorch installed.
    import torch
    import torch.utils.dlpack
    return convert_func(tvm_func, torch.Tensor, torch.utils.dlpack.to_dlpack)
2 changes: 1 addition & 1 deletion python/tvm/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as _np

from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty
from ._ffi.ndarray import context, empty, from_dlpack
from ._ffi.ndarray import _set_class_ndarray
from ._ffi.ndarray import register_extension, free_extension_handle

Expand Down
36 changes: 29 additions & 7 deletions src/runtime/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,16 @@ struct NDArray::Internal {
arr.data_ = nullptr;
return tensor;
}
// Container to DLManagedTensor
// Wraps the container's tensor in a heap-allocated DLManagedTensor that
// shares the underlying memory. IncRef keeps the container alive until the
// consumer invokes ret->deleter (NDArrayDLPackDeleter), which releases it.
static DLManagedTensor* ToDLPack(NDArray::Container* from) {
  CHECK(from != nullptr);
  DLManagedTensor* ret = new DLManagedTensor();
  ret->dl_tensor = from->dl_tensor;
  ret->manager_ctx = from;
  from->IncRef();
  ret->deleter = NDArrayDLPackDeleter;
  return ret;
}
};

NDArray NDArray::CreateView(std::vector<int64_t> shape,
Expand All @@ -115,13 +125,7 @@ NDArray NDArray::CreateView(std::vector<int64_t> shape,
}

DLManagedTensor* NDArray::ToDLPack() const {
CHECK(data_ != nullptr);
DLManagedTensor* ret = new DLManagedTensor();
ret->dl_tensor = data_->dl_tensor;
ret->manager_ctx = const_cast<NDArray*>(this);
data_->IncRef();
ret->deleter = NDArrayDLPackDeleter;
return ret;
return Internal::ToDLPack(data_);
}

NDArray NDArray::Empty(std::vector<int64_t> shape,
Expand Down Expand Up @@ -213,6 +217,24 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
API_END();
}

// C API entry: create an array handle that shares memory with `from`.
// Returns 0 on success, -1 on failure (via API_BEGIN/API_END).
int TVMArrayFromDLPack(DLManagedTensor* from,
                       TVMArrayHandle* out) {
  API_BEGIN();
  *out = NDArray::Internal::MoveAsDLTensor(NDArray::FromDLPack(from));
  API_END();
}

// C API entry: export the array as a DLManagedTensor sharing its memory.
// Returns 0 on success, -1 on failure (via API_BEGIN/API_END).
int TVMArrayToDLPack(TVMArrayHandle from,
                     DLManagedTensor** out) {
  API_BEGIN();
  *out = NDArray::Internal::ToDLPack(reinterpret_cast<NDArray::Container*>(from));
  API_END();
}

// Invoke the DLManagedTensor's own deleter callback, freeing its resources.
void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor) {
  (*(dltensor->deleter))(dltensor);
}

int TVMArrayCopyFromBytes(TVMArrayHandle handle,
void* data,
size_t nbytes) {
Expand Down
44 changes: 44 additions & 0 deletions tests/python/contrib/test_dlpack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import tvm
import numpy as np
from tvm.contrib.dlpack import to_pytorch_func

def test():
    """Round-trip DLPack conversion checks; PyTorch interop if available."""
    a = np.random.randn(1337)
    tvm_a = tvm.nd.array(a)
    # tvm -> dlpack -> tvm must preserve the data without copy.
    np.testing.assert_equal(tvm.nd.from_dlpack(tvm_a.to_dlpack()).asnumpy(), a)

    try:
        import torch
        import torch.utils.dlpack

        x = torch.rand(56, 56)
        tvm_x = tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(x))
        np.testing.assert_equal(x.numpy(), tvm_x.asnumpy())
        y = tvm.nd.from_dlpack(tvm_x.to_dlpack())
        np.testing.assert_equal(y.asnumpy(), tvm_x.asnumpy())
        np.testing.assert_equal(torch.utils.dlpack.from_dlpack(y.to_dlpack()).numpy(),
                                tvm_x.asnumpy())

        # Compare a tvm-built matmul against torch.mm via the pytorch wrapper.
        n = tvm.convert(137)
        xx = torch.rand(137, 137)
        yy = torch.rand(137, 137)
        zz = xx.mm(yy)
        XX = tvm.placeholder((n, n), name='X')
        YY = tvm.placeholder((n, n), name='Y')

        k = tvm.reduce_axis((0, n), name='k')
        ZZ = tvm.compute((n, n), lambda i, j: tvm.sum(XX[i, k] * YY[k, j], axis=k))
        s = tvm.create_schedule(ZZ.op)
        f = tvm.build(s, [XX, YY, ZZ], target_host='llvm', name='f')

        f_pytorch = to_pytorch_func(f)
        # Removed dead duplicate `zz2 = torch.empty(...)` that was assigned
        # earlier and immediately overwritten before first use.
        zz2 = torch.empty(137, 137)
        f_pytorch(xx, yy, zz2)
        np.testing.assert_allclose(zz.numpy(), zz2.numpy(), rtol=1e-6)

    except ImportError:
        # PyTorch is an optional dependency; skip the interop checks.
        pass


if __name__ == '__main__':
    test()