From 2cb8a3c57505e96c73f5bcb646c756736e1cb629 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Mon, 2 Sep 2024 17:42:28 +0800 Subject: [PATCH 01/21] support shared memory via dlpack tensor format --- paddle/fluid/framework/tensor_util.cc | 120 +++++++++++++++++++++----- paddle/fluid/framework/tensor_util.h | 2 +- 2 files changed, 100 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 9addd09249d045..ed865d0fe8cf33 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -794,6 +794,77 @@ void* GetDstPtrByDLDataType(DLDataType type, } } +phi::Place GetPlaceFromPtr(void* data) { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#ifdef PADDLE_WITH_CUDA +#if CUDA_VERSION >= 10000 + cudaPointerAttributes attr = {}; + cudaError_t status = cudaPointerGetAttributes(&attr, data); + if (status == cudaSuccess && attr.type == cudaMemoryTypeDevice) { + return phi::GPUPlace(attr.device); + } +#else + PADDLE_THROW( + phi::errors::Unimplemented("The GetPlaceFromPtr() method is only " + "supported when CUDA version >= 10.0.")); +#endif +#else + hipPointerAttribute_t attr = {}; + hipError_t status = hipPointerGetAttributes(&attr, data); + if (status == hipSuccess && attr.memoryType == hipMemoryTypeDevice) { + return phi::GPUPlace(attr.device); + } +#endif +#endif + return phi::CPUPlace(); +} + +using Deleter = std::function; + +phi::DenseTensor from_blob(void* data, + const phi::DDim& shape, + phi::DataType dtype, + phi::DataLayout layout, + const phi::Place& place, + const Deleter& deleter) { + PADDLE_ENFORCE_NOT_NULL( + data, phi::errors::InvalidArgument("data can not be nullptr.")); + + using AllocationDeleter = void (*)(phi::Allocation*); + phi::Place data_place; + if (place.GetType() == phi::AllocationType::UNDEFINED || + place.GetType() == phi::AllocationType::CPU || + place.GetType() == phi::AllocationType::GPU) { + data_place = GetPlaceFromPtr(data); + if (place.GetType() != phi::AllocationType::UNDEFINED) { + PADDLE_ENFORCE_EQ(data_place, + place, + phi::errors::InvalidArgument( + "Specified place does not match place of data. ", + "Specified: %s, Expected: %s.", + data_place.DebugString(), + place.DebugString())); + } + } else { + data_place = place; + } + + auto meta = phi::DenseTensorMeta(dtype, shape, layout); + + size_t size = SizeOf(dtype) * (meta.is_scalar ? 
1 : product(meta.dims)); + + AllocationDeleter alloc_deleter = nullptr; + if (deleter) { + static thread_local Deleter g_deleter = deleter; + alloc_deleter = [](phi::Allocation* p) { g_deleter(p->ptr()); }; + } + + auto alloc = + std::make_shared(data, size, alloc_deleter, data_place); + + return std::move(phi::DenseTensor(alloc, meta)); +} + void TensorFromDLPack(const ::DLTensor& dl_tensor, phi::DenseTensor* dst) { phi::CPUPlace dst_place = phi::CPUPlace(); phi::CPUPlace src_place = phi::CPUPlace(); @@ -834,41 +905,48 @@ void TensorFromDLPack(const ::DLTensor& dl_tensor, phi::DenseTensor* dst) { #endif } -void TensorFromDLPack(const DLManagedTensor* src, phi::DenseTensor* dst) { +void TensorFromDLPack(DLManagedTensor* src, phi::DenseTensor* dst) { std::vector vec; std::copy(src->dl_tensor.shape, src->dl_tensor.shape + src->dl_tensor.ndim, std::back_inserter(vec)); phi::DDim vddim = common::make_ddim(vec); - dst->Resize(vddim); - ::DLDataType type = src->dl_tensor.dtype; - - auto src_ptr = static_cast(src->dl_tensor.data); - auto size = common::product(vddim) * type.bits / 8; - + auto src_ptr = static_cast(src->dl_tensor.data); if (src->dl_tensor.device.device_type == kDLCPU) { phi::CPUPlace dst_place = phi::CPUPlace(); - phi::CPUPlace src_place = phi::CPUPlace(); - void* dst_ptr = GetDstPtrByDLDataType(type, dst, dst_place); - memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size); + + auto deleter = [src](void* self [[maybe_unused]]) { + if (src->deleter) { + src->deleter(src); + } + }; + + *dst = from_blob(src_ptr, + vddim, + dst->dtype(), + phi::DataLayout::NCHW, + dst_place, + deleter); } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (src->dl_tensor.device.device_type == kDLGPU) { phi::GPUPlace dst_place = phi::GPUPlace(src->dl_tensor.device.device_id); - phi::GPUPlace src_place = phi::GPUPlace(src->dl_tensor.device.device_id); - void* dst_ptr = GetDstPtrByDLDataType(type, dst, dst_place); - auto* ctx = phi::DeviceContextPool::Instance().GetByPlace(dst_place); - // Fix copy by share allocation. - memory::Copy(dst_place, - dst_ptr, - src_place, - src_ptr, - size, - reinterpret_cast(*ctx).stream()); + + auto deleter = [src](void* self [[maybe_unused]]) { + if (src->deleter) { + src->deleter(src); + } + }; + + *dst = from_blob(src_ptr, + vddim, + dst->dtype(), + phi::DataLayout::NCHW, + dst_place, + deleter); } #endif - src->deleter(const_cast(src)); #ifdef PADDLE_WITH_XPU PADDLE_THROW(phi::errors::Unimplemented("XPUPlace is not supported")); #endif diff --git a/paddle/fluid/framework/tensor_util.h b/paddle/fluid/framework/tensor_util.h index 5c7ea2e7f315e0..1339950ed19651 100644 --- a/paddle/fluid/framework/tensor_util.h +++ b/paddle/fluid/framework/tensor_util.h @@ -109,7 +109,7 @@ void TensorToVector(const phi::DenseTensor& src, std::vector* dst); TEST_API void TensorFromDLPack(const ::DLTensor& dl_tensor, phi::DenseTensor* dst); -void TensorFromDLPack(const DLManagedTensor* src, phi::DenseTensor* dst); +void TensorFromDLPack(DLManagedTensor* src, phi::DenseTensor* dst); // // The implementation of template functions. 
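
The hunk above is the core change of this series: TensorFromDLPack no longer allocates a fresh buffer and memcpy's the DLPack data into it, but wraps the producer's pointer through from_blob(), so the resulting DenseTensor shares memory with the exporting framework and the producer's deleter is only invoked when that allocation is released. The following standalone sketch is not Paddle code; BorrowedBuffer and ConsumeDLPack are invented names, and only the public structs from dlpack.h are assumed. It illustrates the same zero-copy consumer pattern that from_blob() implements internally:

    // Sketch of a zero-copy DLPack consumer: borrow the producer's buffer and
    // call its deleter on release instead of copying the data.
    #include <dlpack/dlpack.h>

    #include <cstdint>
    #include <memory>

    // Hypothetical holder type used only for this illustration.
    struct BorrowedBuffer {
      void* data = nullptr;         // borrowed pointer into the producer's memory
      int64_t numel = 1;            // element count derived from dl_tensor.shape
      std::shared_ptr<void> guard;  // keeps the producer's buffer alive
    };

    inline BorrowedBuffer ConsumeDLPack(DLManagedTensor* src) {
      BorrowedBuffer out;
      out.data =
          static_cast<char*>(src->dl_tensor.data) + src->dl_tensor.byte_offset;
      for (int32_t i = 0; i < src->dl_tensor.ndim; ++i) {
        out.numel *= src->dl_tensor.shape[i];
      }
      // No memcpy: ownership stays with the producer. Its deleter runs exactly
      // once, when the last copy of `guard` is destroyed.
      out.guard = std::shared_ptr<void>(static_cast<void*>(src), [](void* p) {
        auto* t = static_cast<DLManagedTensor*>(p);
        if (t->deleter) {
          t->deleter(t);
        }
      });
      return out;
    }
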
From c0d6fd4a86f3d839ee65beb3aa8dd7e491a2cbf7 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Mon, 2 Sep 2024 20:21:57 +0800 Subject: [PATCH 02/21] remove deleter --- paddle/fluid/framework/tensor_util.cc | 30 +++++++++++++++------------ 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index ed865d0fe8cf33..cfaff0b550c49f 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -862,10 +862,11 @@ phi::DenseTensor from_blob(void* data, auto alloc = std::make_shared(data, size, alloc_deleter, data_place); - return std::move(phi::DenseTensor(alloc, meta)); + return phi::DenseTensor(alloc, meta); } void TensorFromDLPack(const ::DLTensor& dl_tensor, phi::DenseTensor* dst) { + std::cout << "call DLTensor version" << std::endl; phi::CPUPlace dst_place = phi::CPUPlace(); phi::CPUPlace src_place = phi::CPUPlace(); @@ -906,6 +907,7 @@ void TensorFromDLPack(const ::DLTensor& dl_tensor, phi::DenseTensor* dst) { } void TensorFromDLPack(DLManagedTensor* src, phi::DenseTensor* dst) { + std::cout << "call DLManagedTensor version" << std::endl; std::vector vec; std::copy(src->dl_tensor.shape, src->dl_tensor.shape + src->dl_tensor.ndim, @@ -916,35 +918,37 @@ void TensorFromDLPack(DLManagedTensor* src, phi::DenseTensor* dst) { if (src->dl_tensor.device.device_type == kDLCPU) { phi::CPUPlace dst_place = phi::CPUPlace(); - auto deleter = [src](void* self [[maybe_unused]]) { - if (src->deleter) { - src->deleter(src); - } - }; + // auto deleter = [src](void* self [[maybe_unused]]) { + // if (src->deleter) { + // std::cout<< "delete ptr 1" << std::endl; + // src->deleter(src); + // } + // }; *dst = from_blob(src_ptr, vddim, dst->dtype(), phi::DataLayout::NCHW, dst_place, - deleter); + nullptr); } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (src->dl_tensor.device.device_type == kDLGPU) { phi::GPUPlace dst_place = phi::GPUPlace(src->dl_tensor.device.device_id); - auto deleter = [src](void* self [[maybe_unused]]) { - if (src->deleter) { - src->deleter(src); - } - }; + // auto deleter = [src](void* self [[maybe_unused]]) { + // if (src->deleter) { + // std::cout << "delete ptr 2" << std::endl; + // src->deleter(src); + // } + // }; *dst = from_blob(src_ptr, vddim, dst->dtype(), phi::DataLayout::NCHW, dst_place, - deleter); + nullptr); } #endif #ifdef PADDLE_WITH_XPU From e2b0d9ce59eb1bd252f0a7f30eb946c69a4ad6bc Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Mon, 9 Sep 2024 09:24:32 +0000 Subject: [PATCH 03/21] update todlpack code --- paddle/fluid/framework/dlpack_tensor.cc | 50 +++---- paddle/fluid/framework/tensor_util.cc | 179 +++++++++--------------- paddle/fluid/framework/tensor_util.h | 12 +- paddle/fluid/pybind/pybind.cc | 43 +++--- paddle/fluid/pybind/tensor.cc | 26 ++-- python/paddle/utils/dlpack.py | 56 +++++--- test/legacy_test/test_dlpack.py | 9 ++ 7 files changed, 178 insertions(+), 197 deletions(-) diff --git a/paddle/fluid/framework/dlpack_tensor.cc b/paddle/fluid/framework/dlpack_tensor.cc index 49896e0fbbef7b..a20a41747c8ce1 100644 --- a/paddle/fluid/framework/dlpack_tensor.cc +++ b/paddle/fluid/framework/dlpack_tensor.cc @@ -120,55 +120,39 @@ struct DLDeviceVisitor { }; } // namespace internal -struct PaddleDLMTensor { +struct PTenDLMTensor { phi::DenseTensor handle; DLManagedTensor tensor; - PaddleDLMTensor() : tensor() {} + // PTenDLMTensor() : tensor() {} }; -void deleter(DLManagedTensor *arg) { - 
delete[] arg->dl_tensor.shape; - delete[] arg->dl_tensor.strides; - delete static_cast(arg->manager_ctx); +static void deleter(DLManagedTensor *self) { + delete static_cast(self->manager_ctx); } DLManagedTensor *toDLPack(const phi::DenseTensor &src) { - PaddleDLMTensor *pdDLMTensor(new PaddleDLMTensor); + // init shape + auto shape = common::vectorize(src.dims()); + auto strides = common::vectorize(src.strides()); + for (int i = 0; i < src.dims().size(); i++) { + if (shape[i] < 2) { + strides[i] = 1; + } + } + + PTenDLMTensor *pdDLMTensor(new PTenDLMTensor); pdDLMTensor->handle = const_cast(src); pdDLMTensor->tensor.manager_ctx = pdDLMTensor; pdDLMTensor->tensor.deleter = &deleter; pdDLMTensor->tensor.dl_tensor.data = const_cast(src.data()); - - // init ndim - using DimType = decltype(pdDLMTensor->tensor.dl_tensor.ndim); // int - pdDLMTensor->tensor.dl_tensor.ndim = static_cast(src.dims().size()); - DimType ndim = pdDLMTensor->tensor.dl_tensor.ndim; - - // init shape - auto shape = new int64_t[ndim]; - for (DimType i = 0; i < ndim; ++i) { - shape[i] = src.dims()[i]; - } - pdDLMTensor->tensor.dl_tensor.shape = shape; - - // init stride - auto strides = new int64_t[ndim]; - for (DimType i = 0; i < ndim; ++i) { - strides[i] = 1; - } - for (DimType i = ndim - 2; i >= 0; --i) { - strides[i] = shape[i + 1] * strides[i + 1]; - } - pdDLMTensor->tensor.dl_tensor.strides = strides; - - // init device, DLDevice type with device_type and device_id auto place = src.place(); pdDLMTensor->tensor.dl_tensor.device = phi::VisitPlace(place, internal::DLDeviceVisitor()); - + pdDLMTensor->tensor.dl_tensor.ndim = static_cast(src.dims().size()); pdDLMTensor->tensor.dl_tensor.dtype = internal::GetDLDataTypeFromTypeIndex( framework::TransToProtoVarType(src.dtype())); - + pdDLMTensor->tensor.dl_tensor.shape = shape.data(); + pdDLMTensor->tensor.dl_tensor.strides = strides.data(); pdDLMTensor->tensor.dl_tensor.byte_offset = 0; return &(pdDLMTensor->tensor); } diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index cfaff0b550c49f..d915f9c49606dc 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -31,7 +31,7 @@ limitations under the License. 
*/ #ifdef PADDLE_WITH_DNNL #include "dnnl_debug.h" // NOLINT #endif - +#include namespace paddle { namespace framework { @@ -819,141 +819,92 @@ phi::Place GetPlaceFromPtr(void* data) { return phi::CPUPlace(); } +// NOTE: set 'g_deleter' and 'g_src' as static variable +// for being used in lambda function 'f' without specified in capture list using Deleter = std::function; +using AllocationDeleter = void (*)(phi::Allocation*); + +static Deleter g_deleter = nullptr; +static DLManagedTensor* g_src = nullptr; + +void set_global_deleter_and_src(const Deleter& deleter, DLManagedTensor* src) { + /*Update static variable g_deleter and g_src manually each time in a + * function*/ + g_deleter = deleter; + g_src = src; +} -phi::DenseTensor from_blob(void* data, - const phi::DDim& shape, - phi::DataType dtype, - phi::DataLayout layout, - const phi::Place& place, - const Deleter& deleter) { +inline phi::DenseTensor from_blob(void* data, + DLManagedTensor* src, + const phi::DDim& shape, + phi::DataType dtype, + phi::DataLayout layout, + const phi::Place& place, + const Deleter& deleter) { PADDLE_ENFORCE_NOT_NULL( data, phi::errors::InvalidArgument("data can not be nullptr.")); - using AllocationDeleter = void (*)(phi::Allocation*); - phi::Place data_place; - if (place.GetType() == phi::AllocationType::UNDEFINED || - place.GetType() == phi::AllocationType::CPU || - place.GetType() == phi::AllocationType::GPU) { - data_place = GetPlaceFromPtr(data); - if (place.GetType() != phi::AllocationType::UNDEFINED) { - PADDLE_ENFORCE_EQ(data_place, - place, - phi::errors::InvalidArgument( - "Specified place does not match place of data. ", - "Specified: %s, Expected: %s.", - data_place.DebugString(), - place.DebugString())); - } - } else { - data_place = place; - } - auto meta = phi::DenseTensorMeta(dtype, shape, layout); - size_t size = SizeOf(dtype) * (meta.is_scalar ? 
1 : product(meta.dims)); - - AllocationDeleter alloc_deleter = nullptr; + AllocationDeleter f = nullptr; if (deleter) { - static thread_local Deleter g_deleter = deleter; - alloc_deleter = [](phi::Allocation* p) { g_deleter(p->ptr()); }; + set_global_deleter_and_src(deleter, src); + f = [](phi::Allocation* p) { + if (g_src) { + g_deleter(g_src); + } + }; } - - auto alloc = - std::make_shared(data, size, alloc_deleter, data_place); - + auto alloc = std::make_shared(data, size, f, place); return phi::DenseTensor(alloc, meta); } -void TensorFromDLPack(const ::DLTensor& dl_tensor, phi::DenseTensor* dst) { - std::cout << "call DLTensor version" << std::endl; - phi::CPUPlace dst_place = phi::CPUPlace(); - phi::CPUPlace src_place = phi::CPUPlace(); - - std::vector vec; - std::copy(dl_tensor.shape, - dl_tensor.shape + dl_tensor.ndim, - std::back_inserter(vec)); - - phi::DDim vddim = common::make_ddim(vec); - - dst->Resize(vddim); - ::DLDataType type = dl_tensor.dtype; - void* dst_ptr = GetDstPtrByDLDataType(type, dst, dst_place); - - auto src_ptr = static_cast(dl_tensor.data); - auto size = common::product(vddim) * type.bits / 8; - - if (dl_tensor.device.device_type == kDLCPU) { - memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size); - } -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (dl_tensor.device.device_type == kDLGPU) { - phi::GPUPlace dst_place = phi::GPUPlace(dl_tensor.device.device_id); - phi::GPUPlace src_place = phi::GPUPlace(dl_tensor.device.device_id); - dst_ptr = GetDstPtrByDLDataType(type, dst, dst_place); - auto* ctx = phi::DeviceContextPool::Instance().GetByPlace(dst_place); - memory::Copy(dst_place, - dst_ptr, - src_place, - src_ptr, - size, - reinterpret_cast(*ctx).stream()); - } -#endif -#ifdef PADDLE_WITH_XPU - PADDLE_THROW(phi::errors::Unimplemented("XPUPlace is not supported")); -#endif -} - -void TensorFromDLPack(DLManagedTensor* src, phi::DenseTensor* dst) { - std::cout << "call DLManagedTensor version" << std::endl; +phi::DenseTensor TensorFromDLPack(DLManagedTensor* src, + std::function deleter) { std::vector vec; std::copy(src->dl_tensor.shape, src->dl_tensor.shape + src->dl_tensor.ndim, std::back_inserter(vec)); - phi::DDim vddim = common::make_ddim(vec); - auto src_ptr = static_cast(src->dl_tensor.data); + phi::Place place; if (src->dl_tensor.device.device_type == kDLCPU) { - phi::CPUPlace dst_place = phi::CPUPlace(); - - // auto deleter = [src](void* self [[maybe_unused]]) { - // if (src->deleter) { - // std::cout<< "delete ptr 1" << std::endl; - // src->deleter(src); - // } - // }; - - *dst = from_blob(src_ptr, - vddim, - dst->dtype(), - phi::DataLayout::NCHW, - dst_place, - nullptr); + place = phi::CPUPlace(); + } else if (src->dl_tensor.device.device_type == kDLGPU) { + place = phi::GPUPlace(); + } else { + PADDLE_THROW(phi::errors::Unimplemented("Given Place is not supported")); } -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (src->dl_tensor.device.device_type == kDLGPU) { - phi::GPUPlace dst_place = phi::GPUPlace(src->dl_tensor.device.device_id); - - // auto deleter = [src](void* self [[maybe_unused]]) { - // if (src->deleter) { - // std::cout << "delete ptr 2" << std::endl; - // src->deleter(src); - // } - // }; - *dst = from_blob(src_ptr, - vddim, - dst->dtype(), + if (!src->dl_tensor.strides) { + return from_blob(src->dl_tensor.data, + src, + common::make_ddim(vec), + phi::DataType::FLOAT32, + phi::DataLayout::NCHW, + place, + std::move(deleter)); + } else { + return from_blob(src->dl_tensor.data, + src, + 
common::make_ddim(vec), + phi::DataType::FLOAT32, phi::DataLayout::NCHW, - dst_place, - nullptr); + place, + std::move(deleter)); } -#endif -#ifdef PADDLE_WITH_XPU - PADDLE_THROW(phi::errors::Unimplemented("XPUPlace is not supported")); -#endif +} + +/* +refer: +https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/DLConvertor.cpp#L297 +*/ +phi::DenseTensor TensorFromDLPack(DLManagedTensor* src) { + auto deleter = [src](void* self [[maybe_unused]]) { + if (src->deleter) { + src->deleter(src); + } + }; + return TensorFromDLPack(src, std::move(deleter)); } template diff --git a/paddle/fluid/framework/tensor_util.h b/paddle/fluid/framework/tensor_util.h index 1339950ed19651..b396fd5d55e159 100644 --- a/paddle/fluid/framework/tensor_util.h +++ b/paddle/fluid/framework/tensor_util.h @@ -107,10 +107,14 @@ void TensorToVector(const phi::DenseTensor& src, std::vector* dst); // convert dlpack's DLTensor to tensor -TEST_API void TensorFromDLPack(const ::DLTensor& dl_tensor, - phi::DenseTensor* dst); -void TensorFromDLPack(DLManagedTensor* src, phi::DenseTensor* dst); - +// TEST_API void TensorFromDLPack(const ::DLTensor& dl_tensor, +// phi::DenseTensor* dst); +TEST_API phi::DenseTensor TensorFromDLPack(DLManagedTensor* src); +inline phi::DenseTensor fromDLPack(const DLManagedTensor* src) { + return fromDLPack(const_cast(src)); +} +phi::DenseTensor TensorFromDLPack(DLManagedTensor* src, + std::function deleter); // // The implementation of template functions. // diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 0cdf0e575eaa8c..1a6abe5aa8fd7f 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1259,29 +1259,34 @@ PYBIND11_MODULE(libpaddle, m) { phi::DeviceContextPool::Instance().Get(place)->Wait(); }); - m.def("from_dlpack", [](py::capsule *dltensor) { - DLManagedTensor *dmt = reinterpret_cast( - PyCapsule_GetPointer(dltensor->ptr(), "dltensor")); + m.def("from_dlpack", [](py::object data) { + DLManagedTensor *dlMTensor = reinterpret_cast( + PyCapsule_GetPointer(data.ptr(), "dltensor")); PADDLE_ENFORCE_NOT_NULL( - dmt, + dlMTensor, phi::errors::InvalidArgument( "from_dlpack received an invalid capsule. 
" - "Note that a DLPack tensor can be consumed only once.")); - - PyCapsule_SetName(dltensor->ptr(), "used_dltensor"); - DLTensor dl = dmt->dl_tensor; - phi::DenseTensor tensor; - - if (dl.device.device_type == kDLCPU) { - paddle::framework::TensorFromDLPack(dmt, &tensor); - } -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (dl.device.device_type == kDLGPU) { - paddle::framework::TensorFromDLPack(dmt, &tensor); - } -#endif - return tensor; + "Note that DLTensor capsules can be consumed only once, " + "so you might have already constructed a tensor from it once.")); + + // auto deleter_with_gil = [dlMTensor](void*) { + // if (dlMTensor->deleter) { + // pybind11::gil_scoped_acquire gil; + // printf("call gil deleter begin\n"); + // dlMTensor->deleter(dlMTensor); + // printf("call gil deleter end\n"); + // } else { + // printf("dlMTensor->deleter is nullptr\n"); + // } + // }; + + // auto ptensor = paddle::framework::TensorFromDLPack(dlMTensor, + // std::move(deleter_with_gil)); + auto ptensor = paddle::framework::TensorFromDLPack(dlMTensor); + + PyCapsule_SetName(data.ptr(), "used_dltensor"); + return ptensor; }); m.def("_create_loaded_parameter", diff --git a/paddle/fluid/pybind/tensor.cc b/paddle/fluid/pybind/tensor.cc index 2556a3d2081e45..5372961c1aa222 100644 --- a/paddle/fluid/pybind/tensor.cc +++ b/paddle/fluid/pybind/tensor.cc @@ -424,17 +424,21 @@ void BindTensor(pybind11::module &m) { // NOLINT )DOC") .def("_to_dlpack", [](phi::DenseTensor &self) { - DLManagedTensor *dmt = framework::toDLPack(self); - auto capsule = pybind11::capsule( - static_cast(dmt), "dltensor", [](PyObject *ptr) { - if (!PyCapsule_IsValid(ptr, "dltensor")) { - return; - } - DLManagedTensor *dmt = static_cast( - PyCapsule_GetPointer(ptr, "dltensor")); - dmt->deleter(dmt); - }); - return capsule; + DLManagedTensor *dlMTensor = framework::toDLPack(self); + + auto DLPack_Capsule_Destructor = [](PyObject *data) { + if (!PyCapsule_IsValid(data, "dltensor")) { + // early out, see DLPack spec: if a consuming library sets the + // capsule name to something else, they own it and we don't + // need to do anything + return; + } + DLManagedTensor *dlMTensor = reinterpret_cast( + PyCapsule_GetPointer(data, "dltensor")); + dlMTensor->deleter(dlMTensor); + }; + return PyCapsule_New( + dlMTensor, "dltensor", DLPack_Capsule_Destructor); }) .def("_set_float_element", TensorSetElement) .def("_get_float_element", TensorGetElement) diff --git a/python/paddle/utils/dlpack.py b/python/paddle/utils/dlpack.py index 77c077c87260e1..99d027eae23011 100644 --- a/python/paddle/utils/dlpack.py +++ b/python/paddle/utils/dlpack.py @@ -14,6 +14,7 @@ from __future__ import annotations +import enum from typing import TYPE_CHECKING import paddle @@ -33,6 +34,20 @@ ] +class DLDeviceType(enum.IntEnum): + # Enums as in DLPack specification (aten/src/ATen/dlpack.h) + kDLCPU = (1,) + kDLGPU = (2,) + kDLCPUPinned = (3,) + kDLOpenCL = (4,) + kDLVulkan = (7,) + kDLMetal = (8,) + kDLVPI = (9,) + kDLROCM = (10,) + kDLExtDev = (12,) + kDLOneAPI = (14,) + + def to_dlpack(x: Tensor) -> CapsuleType: """ Encodes a tensor to DLPack. @@ -71,7 +86,7 @@ def to_dlpack(x: Tensor) -> CapsuleType: return x._to_dlpack() -def from_dlpack(dlpack: CapsuleType) -> Tensor: +def from_dlpack(ext_tensor) -> Tensor: """ Decodes a DLPack to a tensor. 
@@ -98,18 +113,27 @@ def from_dlpack(dlpack: CapsuleType) -> Tensor: [0.10000000, 0.20000000, 0.60000002, 0.69999999]]) """ - t = type(dlpack) - dlpack_flag = t.__module__ == 'builtins' and t.__name__ == 'PyCapsule' - if not dlpack_flag: - raise TypeError( - "The type of 'dlpack' in from_dlpack must be PyCapsule object," - f" but received {type(dlpack)}." - ) - - if in_dygraph_mode(): - out = paddle.base.core.from_dlpack(dlpack) - out = paddle.to_tensor(out) - return out - - out = paddle.base.core.from_dlpack(dlpack) - return out + if hasattr(ext_tensor, '__dlpack__'): + device = ext_tensor.__dlpack_device__() + # device is either CUDA or ROCm, we need to pass the current + # stream + if device[0] in (DLDeviceType.kDLGPU, DLDeviceType.kDLROCM): + stream = paddle.device.cuda.current_stream(device[1]) + # cuda_stream is the pointer to the stream and it is a public + # attribute, but it is not documented + # The array API specify that the default legacy stream must be passed + # with a value of 1 for CUDA + # https://data-apis.org/array-api/latest/API_specification/array_object.html?dlpack-self-stream-none#dlpack-self-stream-none + is_gpu = device[0] == DLDeviceType.kDLGPU + # Since pytorch is not using PTDS by default, lets directly pass + # the legacy stream + stream_ptr = ( + 1 if is_gpu and stream.cuda_stream == 0 else stream.cuda_stream + ) + dlpack = ext_tensor.__dlpack__(stream=stream_ptr) + else: + dlpack = ext_tensor.__dlpack__() + else: + # Old versions just call the converter + dlpack = ext_tensor + return paddle.base.core.from_dlpack(dlpack) diff --git a/test/legacy_test/test_dlpack.py b/test/legacy_test/test_dlpack.py index dbf9505cebc3a0..5ac3e83c94d3bd 100644 --- a/test/legacy_test/test_dlpack.py +++ b/test/legacy_test/test_dlpack.py @@ -122,6 +122,15 @@ def test_to_dlpack_for_loop(self): x = paddle.rand([3, 5]) dlpack = paddle.utils.dlpack.to_dlpack(x) + def test_to_dlpack_modification(self): + # See Paddle issue 50120 + for i in range(10): + x = paddle.rand([3, 5]) + dlpack = paddle.utils.dlpack.to_dlpack(x) + y = paddle.utils.dlpack.from_dlpack(dlpack) + y[1:2, 2:5] = 2.0 + np.testing.assert_allclose(x.numpy(), y.numpy()) + class TestRaiseError(unittest.TestCase): def test_from_dlpack_raise_type_error(self): From 31f918c6eb1cb79f7bb971bb6ef0330b5da04525 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 10 Sep 2024 14:18:18 +0800 Subject: [PATCH 04/21] fix for other dtype --- paddle/fluid/framework/dlpack_tensor.cc | 42 +++++++---- paddle/fluid/framework/tensor_util.cc | 92 +++++++++---------------- paddle/fluid/pybind/tensor.cc | 34 +++++---- paddle/phi/core/allocator.h | 2 +- python/paddle/utils/dlpack.py | 8 ++- 5 files changed, 87 insertions(+), 91 deletions(-) diff --git a/paddle/fluid/framework/dlpack_tensor.cc b/paddle/fluid/framework/dlpack_tensor.cc index 6487c5b45b7734..bc0075cc8183b0 100644 --- a/paddle/fluid/framework/dlpack_tensor.cc +++ b/paddle/fluid/framework/dlpack_tensor.cc @@ -122,39 +122,57 @@ struct DLDeviceVisitor { }; } // namespace internal -struct PTenDLMTensor { +struct PaddleDLMTensor { phi::DenseTensor handle; DLManagedTensor tensor; - // PTenDLMTensor() : tensor() {} }; static void deleter(DLManagedTensor *self) { - delete static_cast(self->manager_ctx); + if (self && self->manager_ctx) { + delete[] self->dl_tensor + .shape; // delete shape allocated in toDLPack manually + delete[] self->dl_tensor + .strides; // delete strides allocated in toDLPack manually + delete static_cast(self->manager_ctx); + } } DLManagedTensor 
*toDLPack(const phi::DenseTensor &src) { + PaddleDLMTensor *pdDLMTensor(new PaddleDLMTensor); + pdDLMTensor->handle = const_cast(src); + pdDLMTensor->tensor.manager_ctx = pdDLMTensor; + pdDLMTensor->tensor.deleter = &deleter; + + // init ndim + using DimType = decltype(pdDLMTensor->tensor.dl_tensor.ndim); // int32_t + auto _shape = src.dims(); + pdDLMTensor->tensor.dl_tensor.ndim = static_cast(_shape.size()); + DimType ndim = pdDLMTensor->tensor.dl_tensor.ndim; + // init shape - auto shape = common::vectorize(src.dims()); - auto strides = common::vectorize(src.strides()); + int64_t *shape = new int64_t[ndim]; + for (DimType i = 0; i < ndim; ++i) { + shape[i] = _shape[i]; + } + pdDLMTensor->tensor.dl_tensor.shape = shape; + + // init strides + auto _strides = src.strides(); + int64_t *strides = new int64_t[ndim]; for (int i = 0; i < src.dims().size(); i++) { + strides[i] = _strides[i]; if (shape[i] < 2) { strides[i] = 1; } } + pdDLMTensor->tensor.dl_tensor.strides = strides; - PTenDLMTensor *pdDLMTensor(new PTenDLMTensor); - pdDLMTensor->handle = const_cast(src); - pdDLMTensor->tensor.manager_ctx = pdDLMTensor; - pdDLMTensor->tensor.deleter = &deleter; pdDLMTensor->tensor.dl_tensor.data = const_cast(src.data()); auto place = src.place(); pdDLMTensor->tensor.dl_tensor.device = phi::VisitPlace(place, internal::DLDeviceVisitor()); - pdDLMTensor->tensor.dl_tensor.ndim = static_cast(src.dims().size()); pdDLMTensor->tensor.dl_tensor.dtype = internal::GetDLDataTypeFromTypeIndex( framework::TransToProtoVarType(src.dtype())); - pdDLMTensor->tensor.dl_tensor.shape = shape.data(); - pdDLMTensor->tensor.dl_tensor.strides = strides.data(); pdDLMTensor->tensor.dl_tensor.byte_offset = 0; return &(pdDLMTensor->tensor); } diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 20913951066b79..64fafed81e5b7d 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -725,10 +725,8 @@ void TensorFromStream(std::istream& is, } } -// get tensor data point by DLDataType -void* GetDstPtrByDLDataType(DLDataType type, - phi::DenseTensor* dst, - const phi::Place& dst_place) { +// get Tensor data dtype from given DLDataType +phi::DataType GetDstPtrByDLDataType(DLDataType type) { // vector types not currently supported PADDLE_ENFORCE_LE( type.lanes, @@ -737,52 +735,37 @@ void* GetDstPtrByDLDataType(DLDataType type, switch (type.bits) { case 8: - if (type.code == kDLInt) - return static_cast(dst->mutable_data(dst_place)); - if (type.code == kDLUInt) - return static_cast(dst->mutable_data(dst_place)); + if (type.code == kDLInt) return phi::DataType::INT8; + if (type.code == kDLUInt) return phi::DataType::UINT8; PADDLE_THROW(common::errors::Unimplemented( "DLDataType code <%d> is illegal when DLDataType.bits is <%d>.", type.code, type.bits)); case 16: - if (type.code == kDLInt) - return static_cast(dst->mutable_data(dst_place)); - if (type.code == kDLFloat) - return static_cast( - dst->mutable_data(dst_place)); - if (type.code == kDLBfloat) - return static_cast( - dst->mutable_data(dst_place)); + if (type.code == kDLInt) return phi::DataType::INT16; + if (type.code == kDLFloat) return phi::DataType::FLOAT16; + if (type.code == kDLBfloat) return phi::DataType::BFLOAT16; PADDLE_THROW(common::errors::Unimplemented( "DLDataType code <%d> is illegal when DLDataType.bits is <%d>.", type.code, type.bits)); case 32: - if (type.code == kDLInt) - return static_cast(dst->mutable_data(dst_place)); - if (type.code == kDLFloat) - return 
static_cast(dst->mutable_data(dst_place)); + if (type.code == kDLInt) return phi::DataType::INT32; + if (type.code == kDLFloat) return phi::DataType::FLOAT32; PADDLE_THROW(common::errors::Unimplemented( "DLDataType code <%d> is illegal when DLDataType.bits is <%d>.", type.code, type.bits)); case 64: - if (type.code == kDLInt) - return static_cast(dst->mutable_data(dst_place)); - if (type.code == kDLFloat) - return static_cast(dst->mutable_data(dst_place)); - if (type.code == kDLComplex) - return static_cast( - dst->mutable_data>(dst_place)); + if (type.code == kDLInt) return phi::DataType::INT64; + if (type.code == kDLFloat) return phi::DataType::FLOAT64; + if (type.code == kDLComplex) return phi::DataType::COMPLEX64; PADDLE_THROW(common::errors::Unimplemented( "DLDataType code <%d> is illegal when DLDataType.bits is <%d>.", type.code, type.bits)); case 128: - if (type.code == kDLComplex) - return static_cast( - dst->mutable_data>(dst_place)); + if (type.code == kDLComplex) return phi::DataType::COMPLEX128; PADDLE_THROW(common::errors::Unimplemented( "DLDataType code <%d> is illegal when DLDataType.bits is <%d>.", type.code, @@ -821,25 +804,19 @@ phi::Place GetPlaceFromPtr(void* data) { // NOTE: set 'g_deleter' and 'g_src' as static variable // for being used in lambda function 'f' without specified in capture list using Deleter = std::function; -using AllocationDeleter = void (*)(phi::Allocation*); +using AllocationDeleter = std::function; -static Deleter g_deleter = nullptr; -static DLManagedTensor* g_src = nullptr; - -void set_global_deleter_and_src(const Deleter& deleter, DLManagedTensor* src) { - /*Update static variable g_deleter and g_src manually each time in a - * function*/ - g_deleter = deleter; - g_src = src; -} - -inline phi::DenseTensor from_blob(void* data, - DLManagedTensor* src, - const phi::DDim& shape, - phi::DataType dtype, - phi::DataLayout layout, - const phi::Place& place, - const Deleter& deleter) { +/* +code ref: +https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/DLConvertor.cpp +*/ +phi::DenseTensor from_blob(void* data, + DLManagedTensor* src, + const phi::DDim& shape, + phi::DataType dtype, + phi::DataLayout layout, + const phi::Place& place, + const Deleter& deleter) { PADDLE_ENFORCE_NOT_NULL( data, phi::errors::InvalidArgument("data can not be nullptr.")); @@ -847,10 +824,9 @@ inline phi::DenseTensor from_blob(void* data, size_t size = SizeOf(dtype) * (meta.is_scalar ? 
1 : product(meta.dims)); AllocationDeleter f = nullptr; if (deleter) { - set_global_deleter_and_src(deleter, src); - f = [](phi::Allocation* p) { - if (g_src) { - g_deleter(g_src); + f = [deleter, src](phi::Allocation* p) { + if (src->manager_ctx) { + deleter(src); } }; } @@ -874,11 +850,13 @@ phi::DenseTensor TensorFromDLPack(DLManagedTensor* src, PADDLE_THROW(phi::errors::Unimplemented("Given Place is not supported")); } + ::DLDataType type = src->dl_tensor.dtype; + auto dtype = GetDstPtrByDLDataType(type); if (!src->dl_tensor.strides) { return from_blob(src->dl_tensor.data, src, common::make_ddim(vec), - phi::DataType::FLOAT32, + dtype, phi::DataLayout::NCHW, place, std::move(deleter)); @@ -886,17 +864,13 @@ phi::DenseTensor TensorFromDLPack(DLManagedTensor* src, return from_blob(src->dl_tensor.data, src, common::make_ddim(vec), - phi::DataType::FLOAT32, + dtype, phi::DataLayout::NCHW, place, - std::move(deleter)); + deleter); } } -/* -refer: -https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/DLConvertor.cpp#L297 -*/ phi::DenseTensor TensorFromDLPack(DLManagedTensor* src) { auto deleter = [src](void* self [[maybe_unused]]) { if (src->deleter) { diff --git a/paddle/fluid/pybind/tensor.cc b/paddle/fluid/pybind/tensor.cc index 3b57b8daa8a3d2..3228b7ac03953b 100644 --- a/paddle/fluid/pybind/tensor.cc +++ b/paddle/fluid/pybind/tensor.cc @@ -430,24 +430,22 @@ void BindTensor(pybind11::module &m) { // NOLINT >>> print(t.shape()) [5, 30] )DOC") - .def("_to_dlpack", - [](phi::DenseTensor &self) { - DLManagedTensor *dlMTensor = framework::toDLPack(self); - - auto DLPack_Capsule_Destructor = [](PyObject *data) { - if (!PyCapsule_IsValid(data, "dltensor")) { - // early out, see DLPack spec: if a consuming library sets the - // capsule name to something else, they own it and we don't - // need to do anything - return; - } - DLManagedTensor *dlMTensor = reinterpret_cast( - PyCapsule_GetPointer(data, "dltensor")); - dlMTensor->deleter(dlMTensor); - }; - return PyCapsule_New( - dlMTensor, "dltensor", DLPack_Capsule_Destructor); - }) + .def( + "_to_dlpack", + [](phi::DenseTensor &self) { + DLManagedTensor *dlMTensor = framework::toDLPack(self); + auto capsule = pybind11::capsule( + static_cast(dlMTensor), "dltensor", [](PyObject *data) { + if (!PyCapsule_IsValid(data, "dltensor")) { + return; + } + DLManagedTensor *dlMTensor = + reinterpret_cast( + PyCapsule_GetPointer(data, "dltensor")); + dlMTensor->deleter(dlMTensor); + }); + return capsule; + }) .def("_set_float_element", TensorSetElement) .def("_get_float_element", TensorGetElement) .def("_set_double_element", TensorSetElement) diff --git a/paddle/phi/core/allocator.h b/paddle/phi/core/allocator.h index 1d89fd1b4aa88b..a495860136b196 100644 --- a/paddle/phi/core/allocator.h +++ b/paddle/phi/core/allocator.h @@ -29,7 +29,7 @@ namespace phi { class Allocation { public: using Place = phi::Place; - using DeleterFnPtr = void (*)(Allocation*); + using DeleterFnPtr = std::function; Allocation() = default; diff --git a/python/paddle/utils/dlpack.py b/python/paddle/utils/dlpack.py index 99d027eae23011..8451b70a4e2495 100644 --- a/python/paddle/utils/dlpack.py +++ b/python/paddle/utils/dlpack.py @@ -136,4 +136,10 @@ def from_dlpack(ext_tensor) -> Tensor: else: # Old versions just call the converter dlpack = ext_tensor - return paddle.base.core.from_dlpack(dlpack) + + out: paddle.base.libpaddle.Tensor = paddle.base.core.from_dlpack(dlpack) + + if in_dygraph_mode(): + out: Tensor = paddle.to_tensor(out) + + return out From 
09336d4ede9f13063c0a9aeb483ae59d0f71b23e Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 10 Sep 2024 14:19:46 +0800 Subject: [PATCH 05/21] add type check for dlpack --- python/paddle/utils/dlpack.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/python/paddle/utils/dlpack.py b/python/paddle/utils/dlpack.py index 8451b70a4e2495..04a644b7d62e2a 100644 --- a/python/paddle/utils/dlpack.py +++ b/python/paddle/utils/dlpack.py @@ -86,7 +86,7 @@ def to_dlpack(x: Tensor) -> CapsuleType: return x._to_dlpack() -def from_dlpack(ext_tensor) -> Tensor: +def from_dlpack(dlpack) -> Tensor: """ Decodes a DLPack to a tensor. @@ -113,8 +113,17 @@ def from_dlpack(ext_tensor) -> Tensor: [0.10000000, 0.20000000, 0.60000002, 0.69999999]]) """ - if hasattr(ext_tensor, '__dlpack__'): - device = ext_tensor.__dlpack_device__() + # Check the type of dlpack + t = type(dlpack) + dlpack_flag = t.__module__ == 'builtins' and t.__name__ == 'PyCapsule' + if not dlpack_flag: + raise TypeError( + "The type of 'dlpack' in from_dlpack must be PyCapsule object," + f" but received {type(dlpack)}." + ) + + if hasattr(dlpack, '__dlpack__'): + device = dlpack.__dlpack_device__() # device is either CUDA or ROCm, we need to pass the current # stream if device[0] in (DLDeviceType.kDLGPU, DLDeviceType.kDLROCM): @@ -130,14 +139,14 @@ def from_dlpack(ext_tensor) -> Tensor: stream_ptr = ( 1 if is_gpu and stream.cuda_stream == 0 else stream.cuda_stream ) - dlpack = ext_tensor.__dlpack__(stream=stream_ptr) + dlpack_ = dlpack.__dlpack__(stream=stream_ptr) else: - dlpack = ext_tensor.__dlpack__() + dlpack_ = dlpack.__dlpack__() else: # Old versions just call the converter - dlpack = ext_tensor + dlpack_ = dlpack - out: paddle.base.libpaddle.Tensor = paddle.base.core.from_dlpack(dlpack) + out: paddle.base.libpaddle.Tensor = paddle.base.core.from_dlpack(dlpack_) if in_dygraph_mode(): out: Tensor = paddle.to_tensor(out) From 10a34a04e70842330d7d210022137df9d333b73c Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 10 Sep 2024 14:25:31 +0800 Subject: [PATCH 06/21] add data_ptr consistency unitest --- test/legacy_test/test_dlpack.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/legacy_test/test_dlpack.py b/test/legacy_test/test_dlpack.py index 5ac3e83c94d3bd..5d991eee0bbf45 100644 --- a/test/legacy_test/test_dlpack.py +++ b/test/legacy_test/test_dlpack.py @@ -131,6 +131,14 @@ def test_to_dlpack_modification(self): y[1:2, 2:5] = 2.0 np.testing.assert_allclose(x.numpy(), y.numpy()) + def test_to_dlpack_data_ptr_consistency(self): + # See Paddle issue 50120 + for i in range(2): + x = paddle.rand([3, 5]) + dlpack = paddle.utils.dlpack.to_dlpack(x) + y = paddle.utils.dlpack.from_dlpack(dlpack) + self.assertEqual(x.data_ptr(), y.data_ptr()) + class TestRaiseError(unittest.TestCase): def test_from_dlpack_raise_type_error(self): From a18afea10abe8cbbf7177d39a98f7111f6163419 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 10 Sep 2024 16:14:22 +0800 Subject: [PATCH 07/21] remove cuda_runtime.h and remove redundant annotation --- paddle/fluid/framework/tensor_util.cc | 11 +++++------ paddle/fluid/pybind/pybind.cc | 15 ++------------- python/paddle/utils/dlpack.py | 1 - 3 files changed, 7 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 64fafed81e5b7d..ceb9c657ffae4b 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ 
b/paddle/fluid/framework/tensor_util.cc @@ -31,7 +31,7 @@ limitations under the License. */ #ifdef PADDLE_WITH_DNNL #include "dnnl_debug.h" // NOLINT #endif -#include + namespace paddle { namespace framework { @@ -801,15 +801,13 @@ phi::Place GetPlaceFromPtr(void* data) { return phi::CPUPlace(); } -// NOTE: set 'g_deleter' and 'g_src' as static variable -// for being used in lambda function 'f' without specified in capture list using Deleter = std::function; -using AllocationDeleter = std::function; - /* code ref: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/DLConvertor.cpp */ +using AllocationDeleter = void (*)(phi::Allocation*); + phi::DenseTensor from_blob(void* data, DLManagedTensor* src, const phi::DDim& shape, @@ -824,11 +822,12 @@ phi::DenseTensor from_blob(void* data, size_t size = SizeOf(dtype) * (meta.is_scalar ? 1 : product(meta.dims)); AllocationDeleter f = nullptr; if (deleter) { - f = [deleter, src](phi::Allocation* p) { + auto g = [deleter, src] { if (src->manager_ctx) { deleter(src); } }; + f = [](phi::Allocation* p) { g(); }; } auto alloc = std::make_shared(data, size, f, place); return phi::DenseTensor(alloc, meta); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 50bd4055517582..da2c78a1b28449 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1274,19 +1274,8 @@ PYBIND11_MODULE(libpaddle, m) { "Note that DLTensor capsules can be consumed only once, " "so you might have already constructed a tensor from it once.")); - // auto deleter_with_gil = [dlMTensor](void*) { - // if (dlMTensor->deleter) { - // pybind11::gil_scoped_acquire gil; - // printf("call gil deleter begin\n"); - // dlMTensor->deleter(dlMTensor); - // printf("call gil deleter end\n"); - // } else { - // printf("dlMTensor->deleter is nullptr\n"); - // } - // }; - - // auto ptensor = paddle::framework::TensorFromDLPack(dlMTensor, - // std::move(deleter_with_gil)); + // NOTE: Might meet bugged numpy version, see: + // https://github.com/pytorch/pytorch/blob/main/torch/csrc/utils/tensor_new.cpp#L1636-L1638 auto ptensor = paddle::framework::TensorFromDLPack(dlMTensor); PyCapsule_SetName(data.ptr(), "used_dltensor"); diff --git a/python/paddle/utils/dlpack.py b/python/paddle/utils/dlpack.py index 04a644b7d62e2a..8772b4b1537516 100644 --- a/python/paddle/utils/dlpack.py +++ b/python/paddle/utils/dlpack.py @@ -35,7 +35,6 @@ class DLDeviceType(enum.IntEnum): - # Enums as in DLPack specification (aten/src/ATen/dlpack.h) kDLCPU = (1,) kDLGPU = (2,) kDLCPUPinned = (3,) From b65b74c4bf41593f7f18d4255d074ac0e23ca888 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 10 Sep 2024 16:24:36 +0800 Subject: [PATCH 08/21] update ref doc --- paddle/fluid/framework/tensor_util.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index ceb9c657ffae4b..68a27a14a961b5 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -801,11 +801,12 @@ phi::Place GetPlaceFromPtr(void* data) { return phi::CPUPlace(); } -using Deleter = std::function; /* -code ref: +dlpack related code ref: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/DLConvertor.cpp +and paddle/phi/api/lib/tensor_utils.cc */ +using Deleter = std::function; using AllocationDeleter = void (*)(phi::Allocation*); phi::DenseTensor from_blob(void* data, From d8cadb438631f2e38dfd4bb3e676cad681847eb9 Mon Sep 17 00:00:00 2001 From: 
HydrogenSulfate <490868991@qq.com> Date: Tue, 10 Sep 2024 16:46:47 +0800 Subject: [PATCH 09/21] restore --- paddle/fluid/framework/tensor_util.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 68a27a14a961b5..a7c05810406427 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -807,7 +807,7 @@ dlpack related code ref: and paddle/phi/api/lib/tensor_utils.cc */ using Deleter = std::function; -using AllocationDeleter = void (*)(phi::Allocation*); +using AllocationDeleter = std::function; phi::DenseTensor from_blob(void* data, DLManagedTensor* src, @@ -823,12 +823,11 @@ phi::DenseTensor from_blob(void* data, size_t size = SizeOf(dtype) * (meta.is_scalar ? 1 : product(meta.dims)); AllocationDeleter f = nullptr; if (deleter) { - auto g = [deleter, src] { + f = [deleter, src](phi::Allocation* p) { if (src->manager_ctx) { deleter(src); } }; - f = [](phi::Allocation* p) { g(); }; } auto alloc = std::make_shared(data, size, f, place); return phi::DenseTensor(alloc, meta); From 3db178ac15d2a707168cdb33d3428e6c46b46fea Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 11 Sep 2024 19:05:03 +0800 Subject: [PATCH 10/21] use map and mutex instead of std::function --- paddle/fluid/framework/tensor_util.cc | 27 ++++++++++++++++++++++----- paddle/phi/core/allocator.h | 2 +- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index a7c05810406427..5591c19ddb3f8c 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor_util.h" #include +#include #include #include #include @@ -807,7 +808,18 @@ dlpack related code ref: and paddle/phi/api/lib/tensor_utils.cc */ using Deleter = std::function; -using AllocationDeleter = std::function; + +std::unordered_map> ptr_to_deleter; +std::mutex ptr_to_deleter_mutex; + +void DeleterBridge(phi::Allocation* alloc) { + std::lock_guard lock(ptr_to_deleter_mutex); + auto it = ptr_to_deleter.find(static_cast(alloc->ptr())); + if (it != ptr_to_deleter.end()) { + it->second(alloc); // 调用实际的删除器 + ptr_to_deleter.erase(it); // 删除对应的条目 + } +} phi::DenseTensor from_blob(void* data, DLManagedTensor* src, @@ -821,20 +833,25 @@ phi::DenseTensor from_blob(void* data, auto meta = phi::DenseTensorMeta(dtype, shape, layout); size_t size = SizeOf(dtype) * (meta.is_scalar ? 
1 : product(meta.dims)); - AllocationDeleter f = nullptr; + phi::Allocation::DeleterFnPtr f = nullptr; if (deleter) { - f = [deleter, src](phi::Allocation* p) { + auto g = [deleter, src](phi::Allocation* p) { if (src->manager_ctx) { deleter(src); } }; + + { + std::lock_guard lock(ptr_to_deleter_mutex); + ptr_to_deleter[data] = g; + } + f = DeleterBridge; } auto alloc = std::make_shared(data, size, f, place); return phi::DenseTensor(alloc, meta); } -phi::DenseTensor TensorFromDLPack(DLManagedTensor* src, - std::function deleter) { +phi::DenseTensor TensorFromDLPack(DLManagedTensor* src, Deleter deleter) { std::vector vec; std::copy(src->dl_tensor.shape, src->dl_tensor.shape + src->dl_tensor.ndim, diff --git a/paddle/phi/core/allocator.h b/paddle/phi/core/allocator.h index a495860136b196..1d89fd1b4aa88b 100644 --- a/paddle/phi/core/allocator.h +++ b/paddle/phi/core/allocator.h @@ -29,7 +29,7 @@ namespace phi { class Allocation { public: using Place = phi::Place; - using DeleterFnPtr = std::function; + using DeleterFnPtr = void (*)(Allocation*); Allocation() = default; From 79cf661c48c78748cb4aa357c0d5ae282fa95138 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 11 Sep 2024 19:09:28 +0800 Subject: [PATCH 11/21] polish annotation --- paddle/fluid/framework/tensor_util.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 5591c19ddb3f8c..dd27c29bcfa037 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -810,14 +810,14 @@ and paddle/phi/api/lib/tensor_utils.cc using Deleter = std::function; std::unordered_map> ptr_to_deleter; -std::mutex ptr_to_deleter_mutex; +std::mutex ptr_to_deleter_mutex; // use mutex to keep thread safe void DeleterBridge(phi::Allocation* alloc) { std::lock_guard lock(ptr_to_deleter_mutex); auto it = ptr_to_deleter.find(static_cast(alloc->ptr())); if (it != ptr_to_deleter.end()) { - it->second(alloc); // 调用实际的删除器 - ptr_to_deleter.erase(it); // 删除对应的条目 + it->second(alloc); // call the deleter + ptr_to_deleter.erase(it); // remove the entry from the map safely } } From 342768bb7b58b1168c741a127324d1b86f2c73a0 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 11 Sep 2024 20:27:35 +0800 Subject: [PATCH 12/21] fix for cpu tensor --- paddle/fluid/framework/tensor_util.cc | 3 +++ python/paddle/utils/dlpack.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index dd27c29bcfa037..890759185aa8ca 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -834,6 +834,7 @@ phi::DenseTensor from_blob(void* data, auto meta = phi::DenseTensorMeta(dtype, shape, layout); size_t size = SizeOf(dtype) * (meta.is_scalar ? 
1 : product(meta.dims)); phi::Allocation::DeleterFnPtr f = nullptr; + if (deleter) { auto g = [deleter, src](phi::Allocation* p) { if (src->manager_ctx) { @@ -845,8 +846,10 @@ phi::DenseTensor from_blob(void* data, std::lock_guard lock(ptr_to_deleter_mutex); ptr_to_deleter[data] = g; } + f = DeleterBridge; } + auto alloc = std::make_shared(data, size, f, place); return phi::DenseTensor(alloc, meta); } diff --git a/python/paddle/utils/dlpack.py b/python/paddle/utils/dlpack.py index 8772b4b1537516..25d957a0b108d8 100644 --- a/python/paddle/utils/dlpack.py +++ b/python/paddle/utils/dlpack.py @@ -148,6 +148,6 @@ def from_dlpack(dlpack) -> Tensor: out: paddle.base.libpaddle.Tensor = paddle.base.core.from_dlpack(dlpack_) if in_dygraph_mode(): - out: Tensor = paddle.to_tensor(out) + out: Tensor = paddle.Tensor(out, place=out._place()) return out From b5d777b6a6390d8b6ce04175998e6d765cbf156d Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 11 Sep 2024 20:40:38 +0800 Subject: [PATCH 13/21] add data_ptr consistency unitest --- test/legacy_test/test_dlpack.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/test/legacy_test/test_dlpack.py b/test/legacy_test/test_dlpack.py index 5d991eee0bbf45..fd6e703edae0b4 100644 --- a/test/legacy_test/test_dlpack.py +++ b/test/legacy_test/test_dlpack.py @@ -133,11 +133,13 @@ def test_to_dlpack_modification(self): def test_to_dlpack_data_ptr_consistency(self): # See Paddle issue 50120 - for i in range(2): - x = paddle.rand([3, 5]) - dlpack = paddle.utils.dlpack.to_dlpack(x) - y = paddle.utils.dlpack.from_dlpack(dlpack) - self.assertEqual(x.data_ptr(), y.data_ptr()) + for place in [base.CPUPlace(), base.CUDAPlace(0)]: + for i in range(10): + x = paddle.rand([3, 5]).to(device=place) + dlpack = paddle.utils.dlpack.to_dlpack(x) + y = paddle.utils.dlpack.from_dlpack(dlpack) + + self.assertEqual(x.data_ptr(), y.data_ptr()) class TestRaiseError(unittest.TestCase): From 72e85f924c67e6f1be3c36c47528e82225eb4d85 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Thu, 12 Sep 2024 12:08:06 +0800 Subject: [PATCH 14/21] update thirdpary/dlpack to v0.8 to support Tensor of bool dtype --- cmake/external/dlpack.cmake | 2 +- paddle/fluid/framework/dlpack_tensor.cc | 6 +- paddle/fluid/framework/tensor_util.cc | 2 +- python/paddle/utils/dlpack.py | 8 +- .../cpp/fluid/framework/dlpack_tensor_test.cc | 8 +- test/legacy_test/test_dlpack.py | 208 ++++++++++-------- third_party/dlpack | 2 +- 7 files changed, 131 insertions(+), 105 deletions(-) diff --git a/cmake/external/dlpack.cmake b/cmake/external/dlpack.cmake index 60752d99e76fdc..4677c9001ff41e 100644 --- a/cmake/external/dlpack.cmake +++ b/cmake/external/dlpack.cmake @@ -15,7 +15,7 @@ include(ExternalProject) set(DLPACK_PREFIX_DIR ${THIRD_PARTY_PATH}/dlpack) -set(DLPACK_TAG v0.4) +set(DLPACK_TAG v0.8) set(DLPACK_INCLUDE_DIR ${THIRD_PARTY_PATH}/dlpack/src/extern_dlpack/include) set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/dlpack) include_directories(${SOURCE_DIR}/include) diff --git a/paddle/fluid/framework/dlpack_tensor.cc b/paddle/fluid/framework/dlpack_tensor.cc index bc0075cc8183b0..51fc97a3332b97 100644 --- a/paddle/fluid/framework/dlpack_tensor.cc +++ b/paddle/fluid/framework/dlpack_tensor.cc @@ -35,6 +35,8 @@ static ::DLDataType GetDLDataTypeCode() { dtype.code = kDLFloat; } else if (std::is_unsigned::value) { dtype.code = kDLUInt; + } else if (std::is_same::value) { + dtype.code = kDLBool; } else if (std::is_integral::value) { dtype.code = 
kDLInt; } else { @@ -99,7 +101,7 @@ struct DLDeviceVisitor { inline ::DLDevice operator()(const phi::GPUPlace &place) const { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) ::DLDevice device; - device.device_type = kDLGPU; + device.device_type = kDLCUDA; device.device_id = place.device; // NOLINT return device; #else @@ -111,7 +113,7 @@ struct DLDeviceVisitor { inline ::DLDevice operator()(const phi::GPUPinnedPlace &place) const { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) ::DLDevice device; - device.device_type = kDLCPUPinned; + device.device_type = kDLCUDAHost; device.device_id = 0; return device; #else diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 890759185aa8ca..fb0b753d3a13b7 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -863,7 +863,7 @@ phi::DenseTensor TensorFromDLPack(DLManagedTensor* src, Deleter deleter) { phi::Place place; if (src->dl_tensor.device.device_type == kDLCPU) { place = phi::CPUPlace(); - } else if (src->dl_tensor.device.device_type == kDLGPU) { + } else if (src->dl_tensor.device.device_type == kDLCUDA) { place = phi::GPUPlace(); } else { PADDLE_THROW(phi::errors::Unimplemented("Given Place is not supported")); diff --git a/python/paddle/utils/dlpack.py b/python/paddle/utils/dlpack.py index 25d957a0b108d8..6679413549997a 100644 --- a/python/paddle/utils/dlpack.py +++ b/python/paddle/utils/dlpack.py @@ -36,8 +36,8 @@ class DLDeviceType(enum.IntEnum): kDLCPU = (1,) - kDLGPU = (2,) - kDLCPUPinned = (3,) + kDLCUDA = (2,) + kDLCUDAHost = (3,) kDLOpenCL = (4,) kDLVulkan = (7,) kDLMetal = (8,) @@ -125,14 +125,14 @@ def from_dlpack(dlpack) -> Tensor: device = dlpack.__dlpack_device__() # device is either CUDA or ROCm, we need to pass the current # stream - if device[0] in (DLDeviceType.kDLGPU, DLDeviceType.kDLROCM): + if device[0] in (DLDeviceType.kDLCUDA, DLDeviceType.kDLROCM): stream = paddle.device.cuda.current_stream(device[1]) # cuda_stream is the pointer to the stream and it is a public # attribute, but it is not documented # The array API specify that the default legacy stream must be passed # with a value of 1 for CUDA # https://data-apis.org/array-api/latest/API_specification/array_object.html?dlpack-self-stream-none#dlpack-self-stream-none - is_gpu = device[0] == DLDeviceType.kDLGPU + is_gpu = device[0] == DLDeviceType.kDLCUDA # Since pytorch is not using PTDS by default, lets directly pass # the legacy stream stream_ptr = ( diff --git a/test/cpp/fluid/framework/dlpack_tensor_test.cc b/test/cpp/fluid/framework/dlpack_tensor_test.cc index 5fdccf424dc99e..6a71be8aab57eb 100644 --- a/test/cpp/fluid/framework/dlpack_tensor_test.cc +++ b/test/cpp/fluid/framework/dlpack_tensor_test.cc @@ -79,10 +79,10 @@ void TestMain(const phi::Place &place, uint16_t lanes) { "but got %d", dl_tensor.device.device_id)); } else if (phi::is_gpu_place(place)) { - PADDLE_ENFORCE_EQ(kDLGPU, + PADDLE_ENFORCE_EQ(kDLCUDA, dl_tensor.device.device_type, common::errors::InvalidArgument( - "Device type should be kDLGPU, but got %d", + "Device type should be kDLCUDA, but got %d", dl_tensor.device.device_type)); PADDLE_ENFORCE_EQ( place.device, @@ -93,9 +93,9 @@ void TestMain(const phi::Place &place, uint16_t lanes) { dl_tensor.device.device_id)); } else if (phi::is_cuda_pinned_place(place)) { PADDLE_ENFORCE_EQ( - kDLCPUPinned, + kDLCUDAHost, dl_tensor.device.device_type, - common::errors::InvalidArgument("Device type should be kDLCPUPinned, " + 
common::errors::InvalidArgument("Device type should be kDLCUDAHost, " "but got %d", dl_tensor.device.device_type)); PADDLE_ENFORCE_EQ( diff --git a/test/legacy_test/test_dlpack.py b/test/legacy_test/test_dlpack.py index fd6e703edae0b4..d16adc40ddc02f 100644 --- a/test/legacy_test/test_dlpack.py +++ b/test/legacy_test/test_dlpack.py @@ -15,6 +15,7 @@ import unittest import numpy as np +from utils import dygraph_guard, static_guard import paddle from paddle import base @@ -23,123 +24,146 @@ class TestDLPack(unittest.TestCase): def test_dlpack_dygraph(self): - paddle.disable_static() - tensor = paddle.to_tensor(np.array([1, 2, 3, 4]).astype('int')) - dlpack = paddle.utils.dlpack.to_dlpack(tensor) - out_from_dlpack = paddle.utils.dlpack.from_dlpack(dlpack) - if paddle.in_dynamic_mode(): + with dygraph_guard(): + tensor = paddle.to_tensor(np.array([1, 2, 3, 4]).astype("int")) + dlpack = paddle.utils.dlpack.to_dlpack(tensor) + out_from_dlpack = paddle.utils.dlpack.from_dlpack(dlpack) self.assertTrue( isinstance(out_from_dlpack, paddle.base.core.eager.Tensor) ) - else: - self.assertTrue(isinstance(out_from_dlpack, paddle.Tensor)) - np.testing.assert_array_equal( - np.array(out_from_dlpack), np.array([1, 2, 3, 4]).astype('int') - ) + np.testing.assert_array_equal( + out_from_dlpack.numpy(), np.array([1, 2, 3, 4]).astype("int") + ) def test_dlpack_tensor_larger_than_2dim(self): - paddle.disable_static() - numpy_data = np.random.randn(4, 5, 6) - t = paddle.to_tensor(numpy_data) - # TODO: There may be a reference count problem of to_dlpack. - dlpack = paddle.utils.dlpack.to_dlpack(t) - out = paddle.utils.dlpack.from_dlpack(dlpack) - np.testing.assert_allclose(numpy_data, out.numpy(), rtol=1e-05) + with dygraph_guard(): + numpy_data = np.random.randn(4, 5, 6) + t = paddle.to_tensor(numpy_data) + dlpack = paddle.utils.dlpack.to_dlpack(t) + out = paddle.utils.dlpack.from_dlpack(dlpack) + np.testing.assert_allclose(numpy_data, out.numpy(), rtol=1e-05) def test_dlpack_static(self): - paddle.enable_static() - tensor = base.create_lod_tensor( - np.array([[1], [2], [3], [4]]).astype('int'), - [[1, 3]], - base.CPUPlace(), - ) - dlpack = paddle.utils.dlpack.to_dlpack(tensor) - out_from_dlpack = paddle.utils.dlpack.from_dlpack(dlpack) - self.assertTrue(isinstance(out_from_dlpack, base.core.Tensor)) - np.testing.assert_array_equal( - np.array(out_from_dlpack), - np.array([[1], [2], [3], [4]]).astype('int'), - ) - - # when build with cuda - if core.is_compiled_with_cuda(): - gtensor = base.create_lod_tensor( - np.array([[1], [2], [3], [4]]).astype('int'), + with static_guard(): + tensor = base.create_lod_tensor( + np.array([[1], [2], [3], [4]]).astype("int"), [[1, 3]], - base.CUDAPlace(0), + base.CPUPlace(), ) - gdlpack = paddle.utils.dlpack.to_dlpack(gtensor) - gout_from_dlpack = paddle.utils.dlpack.from_dlpack(gdlpack) - self.assertTrue(isinstance(gout_from_dlpack, base.core.Tensor)) + dlpack = paddle.utils.dlpack.to_dlpack(tensor) + out_from_dlpack = paddle.utils.dlpack.from_dlpack(dlpack) + self.assertTrue(isinstance(out_from_dlpack, base.core.Tensor)) np.testing.assert_array_equal( - np.array(gout_from_dlpack), - np.array([[1], [2], [3], [4]]).astype('int'), + np.array(out_from_dlpack), + np.array([[1], [2], [3], [4]]).astype("int"), ) + # when build with cuda + if core.is_compiled_with_cuda(): + gtensor = base.create_lod_tensor( + np.array([[1], [2], [3], [4]]).astype("int"), + [[1, 3]], + base.CUDAPlace(0), + ) + gdlpack = paddle.utils.dlpack.to_dlpack(gtensor) + gout_from_dlpack = 
paddle.utils.dlpack.from_dlpack(gdlpack) + self.assertTrue(isinstance(gout_from_dlpack, base.core.Tensor)) + np.testing.assert_array_equal( + np.array(gout_from_dlpack), + np.array([[1], [2], [3], [4]]).astype("int"), + ) + def test_dlpack_dtype_conversion(self): - paddle.disable_static() - # DLpack does not explicitly support bool data type. - dtypes = [ - "float16", - "float32", - "float64", - "int8", - "int16", - "int32", - "int64", - "uint8", - ] - data = np.ones((2, 3, 4)) - for dtype in dtypes: - x = paddle.to_tensor(data, dtype=dtype) - dlpack = paddle.utils.dlpack.to_dlpack(x) - o = paddle.utils.dlpack.from_dlpack(dlpack) - self.assertEqual(x.dtype, o.dtype) - np.testing.assert_allclose(x.numpy(), o.numpy(), rtol=1e-05) - - complex_dtypes = ["complex64", "complex128"] - for dtype in complex_dtypes: - x = paddle.to_tensor( - [[1 + 6j, 2 + 5j, 3 + 4j], [4 + 3j, 5 + 2j, 6 + 1j]], - dtype=dtype, - ) - dlpack = paddle.utils.dlpack.to_dlpack(x) - o = paddle.utils.dlpack.from_dlpack(dlpack) - self.assertEqual(x.dtype, o.dtype) - np.testing.assert_allclose(x.numpy(), o.numpy(), rtol=1e-05) + with dygraph_guard(): + dtypes = [ + "float16", + "float32", + "float64", + "int8", + "int16", + "int32", + "int64", + "uint8", + "bool", + ] + places = [base.CPUPlace()] + if paddle.is_compiled_with_cuda(): + places.append(base.CUDAPlace(0)) + + data = np.ones((2, 3, 4)) + for place in places: + for dtype in dtypes: + x = paddle.to_tensor(data, dtype=dtype, place=place) + dlpack = paddle.utils.dlpack.to_dlpack(x) + o = paddle.utils.dlpack.from_dlpack(dlpack) + self.assertEqual(x.dtype, o.dtype) + np.testing.assert_allclose(x.numpy(), o.numpy(), rtol=1e-05) + + complex_dtypes = ["complex64", "complex128"] + for place in places: + for dtype in complex_dtypes: + x = paddle.to_tensor( + [[1 + 6j, 2 + 5j, 3 + 4j], [4 + 3j, 5 + 2j, 6 + 1j]], + dtype=dtype, + place=place, + ) + dlpack = paddle.utils.dlpack.to_dlpack(x) + o = paddle.utils.dlpack.from_dlpack(dlpack) + self.assertEqual(x.dtype, o.dtype) + np.testing.assert_allclose(x.numpy(), o.numpy(), rtol=1e-05) def test_dlpack_deletion(self): # See Paddle issue 47171 - if paddle.is_compiled_with_cuda(): - for i in range(80): - a = paddle.rand(shape=[1024 * 128, 1024], dtype="float32") - dlpack = paddle.utils.dlpack.to_dlpack(a) - b = paddle.utils.dlpack.from_dlpack(dlpack) + with dygraph_guard(): + places = [base.CPUPlace()] + if paddle.is_compiled_with_cuda(): + places.append(base.CUDAPlace(0)) + for place in places: + for _ in range(20): + a = paddle.rand( + shape=[1024 * 128, 1024], dtype="float32" + ).to(device=place) + dlpack = paddle.utils.dlpack.to_dlpack(a) + b = paddle.utils.dlpack.from_dlpack(dlpack) def test_to_dlpack_for_loop(self): # See Paddle issue 50120 - for i in range(10): - x = paddle.rand([3, 5]) - dlpack = paddle.utils.dlpack.to_dlpack(x) + with dygraph_guard(): + places = [base.CPUPlace()] + if paddle.is_compiled_with_cuda(): + places.append(base.CUDAPlace(0)) + for place in places: + for _ in range(10): + x = paddle.rand([3, 5]).to(device=place) + dlpack = paddle.utils.dlpack.to_dlpack(x) def test_to_dlpack_modification(self): # See Paddle issue 50120 - for i in range(10): - x = paddle.rand([3, 5]) - dlpack = paddle.utils.dlpack.to_dlpack(x) - y = paddle.utils.dlpack.from_dlpack(dlpack) - y[1:2, 2:5] = 2.0 - np.testing.assert_allclose(x.numpy(), y.numpy()) + with dygraph_guard(): + places = [base.CPUPlace()] + if paddle.is_compiled_with_cuda(): + places.append(base.CUDAPlace(0)) + for place in places: + for _ in range(10): + x = 
paddle.rand([3, 5]).to(device=place) + dlpack = paddle.utils.dlpack.to_dlpack(x) + y = paddle.utils.dlpack.from_dlpack(dlpack) + y[1:2, 2:5] = 2.0 + np.testing.assert_allclose(x.numpy(), y.numpy()) def test_to_dlpack_data_ptr_consistency(self): # See Paddle issue 50120 - for place in [base.CPUPlace(), base.CUDAPlace(0)]: - for i in range(10): - x = paddle.rand([3, 5]).to(device=place) - dlpack = paddle.utils.dlpack.to_dlpack(x) - y = paddle.utils.dlpack.from_dlpack(dlpack) + with dygraph_guard(): + places = [base.CPUPlace()] + if paddle.is_compiled_with_cuda(): + places.append(base.CUDAPlace(0)) + for place in places: + for _ in range(10): + x = paddle.rand([3, 5]).to(device=place) + dlpack = paddle.utils.dlpack.to_dlpack(x) + y = paddle.utils.dlpack.from_dlpack(dlpack) - self.assertEqual(x.data_ptr(), y.data_ptr()) + self.assertEqual(x.data_ptr(), y.data_ptr()) class TestRaiseError(unittest.TestCase): @@ -152,5 +176,5 @@ def test_to_dlpack_raise_type_error(self): self.assertRaises(TypeError, paddle.utils.dlpack.to_dlpack, np.zeros(5)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/third_party/dlpack b/third_party/dlpack index a07f962d446b57..365b823cedb281 160000 --- a/third_party/dlpack +++ b/third_party/dlpack @@ -1 +1 @@ -Subproject commit a07f962d446b577adf4baef2b347a0f3a2a20617 +Subproject commit 365b823cedb281cd0240ca601aba9b78771f91a3 From f64e0ce539d91bf8a1c536df83820c6daad91101 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Thu, 12 Sep 2024 13:05:42 +0800 Subject: [PATCH 15/21] move std::is_same before std::is_unsigned --- paddle/fluid/framework/dlpack_tensor.cc | 6 ++++-- paddle/fluid/framework/tensor_util.cc | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/dlpack_tensor.cc b/paddle/fluid/framework/dlpack_tensor.cc index 51fc97a3332b97..99808691e108e1 100644 --- a/paddle/fluid/framework/dlpack_tensor.cc +++ b/paddle/fluid/framework/dlpack_tensor.cc @@ -33,10 +33,12 @@ static ::DLDataType GetDLDataTypeCode() { } else if (std::is_same::value || std::is_floating_point::value) { dtype.code = kDLFloat; - } else if (std::is_unsigned::value) { - dtype.code = kDLUInt; } else if (std::is_same::value) { + // Since std::is_unsigned::value is True, + // it is necessary to evaluate bool before std::is_unsigned. 
dtype.code = kDLBool; + } else if (std::is_unsigned::value) { + dtype.code = kDLUInt; } else if (std::is_integral::value) { dtype.code = kDLInt; } else { diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index fb0b753d3a13b7..4b6cd00f5d7e25 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -736,6 +736,7 @@ phi::DataType GetDstPtrByDLDataType(DLDataType type) { switch (type.bits) { case 8: + if (type.code == kDLBool) return phi::DataType::BOOL; if (type.code == kDLInt) return phi::DataType::INT8; if (type.code == kDLUInt) return phi::DataType::UINT8; PADDLE_THROW(common::errors::Unimplemented( From 5f8c202724df8e37d5045ceb745b7958f5c475a9 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Thu, 12 Sep 2024 13:53:13 +0800 Subject: [PATCH 16/21] restore overloaded version of TensorFromDLPack --- paddle/fluid/framework/tensor_util.cc | 41 +++++++++++++++++++++++++++ paddle/fluid/framework/tensor_util.h | 9 +++--- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 4b6cd00f5d7e25..89323bde04e27c 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -900,6 +900,47 @@ phi::DenseTensor TensorFromDLPack(DLManagedTensor* src) { return TensorFromDLPack(src, std::move(deleter)); } +// Keep the this overloaded version of the interface unchanged. +void TensorFromDLPack(const ::DLTensor& dl_tensor, phi::DenseTensor* dst) { + phi::CPUPlace dst_place = phi::CPUPlace(); + phi::CPUPlace src_place = phi::CPUPlace(); + + std::vector vec; + std::copy(dl_tensor.shape, + dl_tensor.shape + dl_tensor.ndim, + std::back_inserter(vec)); + + phi::DDim vddim = common::make_ddim(vec); + + dst->Resize(vddim); + ::DLDataType type = dl_tensor.dtype; + void* dst_ptr = GetDstPtrByDLDataType(type, dst, dst_place); + + auto src_ptr = static_cast(dl_tensor.data); + auto size = common::product(vddim) * type.bits / 8; + + if (dl_tensor.device.device_type == kDLCPU) { + memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size); + } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (dl_tensor.device.device_type == kDLGPU) { + phi::GPUPlace dst_place = phi::GPUPlace(dl_tensor.device.device_id); + phi::GPUPlace src_place = phi::GPUPlace(dl_tensor.device.device_id); + dst_ptr = GetDstPtrByDLDataType(type, dst, dst_place); + auto* ctx = phi::DeviceContextPool::Instance().GetByPlace(dst_place); + memory::Copy(dst_place, + dst_ptr, + src_place, + src_ptr, + size, + reinterpret_cast(*ctx).stream()); + } +#endif +#ifdef PADDLE_WITH_XPU + PADDLE_THROW(common::errors::Unimplemented("XPUPlace is not supported")); +#endif +} + template std::string format_tensor(const phi::DenseTensor& tensor) { // TODO(zhiqiu): use the print option to format tensor. 
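The trait-ordering fix in patch 15 is easy to verify outside Paddle: std::is_unsigned<bool>::value is true, so a dispatch chain that tests is_unsigned before the explicit bool case silently maps bool tensors to kDLUInt. The following is a minimal standalone C++ sketch, not Paddle code; the enum values only mirror recent dlpack.h (where kDLBool is 6) and the function names are made up for illustration.

// trait_order_demo.cc -- standalone sketch, not Paddle code.
// Shows why the bool check must run before the unsigned check when
// mapping C++ types to DLPack type codes: std::is_unsigned<bool> is true.
#include <cstdint>
#include <iostream>
#include <type_traits>

// Trimmed mirror of the DLPack type-code values (dlpack.h defines the real enum).
enum DemoCode : uint8_t { kInt = 0, kUInt = 1, kFloat = 2, kBool = 6 };

template <typename T>
DemoCode CodeBeforeFix() {
  // unsigned tested first: bool is swallowed by the kUInt branch
  if (std::is_unsigned<T>::value) return kUInt;
  if (std::is_same<T, bool>::value) return kBool;
  return std::is_floating_point<T>::value ? kFloat : kInt;
}

template <typename T>
DemoCode CodeAfterFix() {
  // bool tested first, matching the reordered else-if chain in the patch
  if (std::is_same<T, bool>::value) return kBool;
  if (std::is_unsigned<T>::value) return kUInt;
  return std::is_floating_point<T>::value ? kFloat : kInt;
}

int main() {
  std::cout << "before fix: " << static_cast<int>(CodeBeforeFix<bool>())  // 1 (kUInt)
            << ", after fix: " << static_cast<int>(CodeAfterFix<bool>())  // 6 (kBool)
            << std::endl;
  return 0;
}

Any C++11 compiler prints "before fix: 1, after fix: 6", which is exactly the mismatch the reordered else-if chain avoids.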
diff --git a/paddle/fluid/framework/tensor_util.h b/paddle/fluid/framework/tensor_util.h index da9c02d20f0c97..a55de60e2ba6da 100644 --- a/paddle/fluid/framework/tensor_util.h +++ b/paddle/fluid/framework/tensor_util.h @@ -106,13 +106,14 @@ template void TensorToVector(const phi::DenseTensor& src, std::vector* dst); // convert dlpack's DLTensor to tensor +TEST_API void TensorFromDLPack(const ::DLTensor& dl_tensor, + phi::DenseTensor* dst); -// TEST_API void TensorFromDLPack(const ::DLTensor& dl_tensor, -// phi::DenseTensor* dst); TEST_API phi::DenseTensor TensorFromDLPack(DLManagedTensor* src); -inline phi::DenseTensor fromDLPack(const DLManagedTensor* src) { - return fromDLPack(const_cast(src)); +inline phi::DenseTensor TensorFromDLPack(const DLManagedTensor* src) { + return TensorFromDLPack(const_cast(src)); } + phi::DenseTensor TensorFromDLPack(DLManagedTensor* src, std::function deleter); // From 31690509963101ab6565750f407f455246ea9850 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Thu, 12 Sep 2024 15:01:40 +0800 Subject: [PATCH 17/21] restore overload version of GetDstPtrByDLDataType and update dlpack docstring --- paddle/fluid/framework/tensor_util.cc | 70 ++++++++++++++++++++++++++- python/paddle/utils/dlpack.py | 26 +++++++--- 2 files changed, 88 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 89323bde04e27c..fc77293591a937 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -726,6 +726,74 @@ void TensorFromStream(std::istream& is, } } +// get tensor data point by DLDataType +void* GetDstPtrByDLDataType(DLDataType type, + phi::DenseTensor* dst, + const phi::Place& dst_place) { + // vector types not currently supported + PADDLE_ENFORCE_LE( + type.lanes, + 1, + common::errors::Unimplemented("Vector type is not supported currently.")); + + switch (type.bits) { + case 8: + if (type.code == kDLInt) + return static_cast(dst->mutable_data(dst_place)); + if (type.code == kDLUInt) + return static_cast(dst->mutable_data(dst_place)); + PADDLE_THROW(common::errors::Unimplemented( + "DLDataType code <%d> is illegal when DLDataType.bits is <%d>.", + type.code, + type.bits)); + case 16: + if (type.code == kDLInt) + return static_cast(dst->mutable_data(dst_place)); + if (type.code == kDLFloat) + return static_cast( + dst->mutable_data(dst_place)); + if (type.code == kDLBfloat) + return static_cast( + dst->mutable_data(dst_place)); + PADDLE_THROW(common::errors::Unimplemented( + "DLDataType code <%d> is illegal when DLDataType.bits is <%d>.", + type.code, + type.bits)); + case 32: + if (type.code == kDLInt) + return static_cast(dst->mutable_data(dst_place)); + if (type.code == kDLFloat) + return static_cast(dst->mutable_data(dst_place)); + PADDLE_THROW(common::errors::Unimplemented( + "DLDataType code <%d> is illegal when DLDataType.bits is <%d>.", + type.code, + type.bits)); + case 64: + if (type.code == kDLInt) + return static_cast(dst->mutable_data(dst_place)); + if (type.code == kDLFloat) + return static_cast(dst->mutable_data(dst_place)); + if (type.code == kDLComplex) + return static_cast( + dst->mutable_data>(dst_place)); + PADDLE_THROW(common::errors::Unimplemented( + "DLDataType code <%d> is illegal when DLDataType.bits is <%d>.", + type.code, + type.bits)); + case 128: + if (type.code == kDLComplex) + return static_cast( + dst->mutable_data>(dst_place)); + PADDLE_THROW(common::errors::Unimplemented( + "DLDataType code <%d> is illegal when 
DLDataType.bits is <%d>.", + type.code, + type.bits)); + default: + PADDLE_THROW(common::errors::Unimplemented( + "Unsupported DLDataType.bits %d.", type.bits)); + } +} + // get Tensor data dtype from given DLDataType phi::DataType GetDstPtrByDLDataType(DLDataType type) { // vector types not currently supported @@ -923,7 +991,7 @@ void TensorFromDLPack(const ::DLTensor& dl_tensor, phi::DenseTensor* dst) { memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size); } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (dl_tensor.device.device_type == kDLGPU) { + if (dl_tensor.device.device_type == kDLCUDA) { phi::GPUPlace dst_place = phi::GPUPlace(dl_tensor.device.device_id); phi::GPUPlace src_place = phi::GPUPlace(dl_tensor.device.device_id); dst_ptr = GetDstPtrByDLDataType(type, dst, dst_place); diff --git a/python/paddle/utils/dlpack.py b/python/paddle/utils/dlpack.py index 6679413549997a..c5801670d64083 100644 --- a/python/paddle/utils/dlpack.py +++ b/python/paddle/utils/dlpack.py @@ -70,6 +70,13 @@ def to_dlpack(x: Tensor) -> CapsuleType: >>> print(dlpack) >>> # doctest: +SKIP('the address will change in every run') + >>> #doctest: -SKIP + + >>> # dlpack capsule will be renamed to 'used_dltensor' after decoded + >>> y = paddle.utils.dlpack.from_dlpack(dlpack) + >>> print(dlpack) + >>> # doctest: +SKIP('the address will change in every run') + """ if in_dygraph_mode(): @@ -87,7 +94,8 @@ def to_dlpack(x: Tensor) -> CapsuleType: def from_dlpack(dlpack) -> Tensor: """ - Decodes a DLPack to a tensor. + Decodes a DLPack to a tensor. The returned Paddle tensor will share the memory with + the tensor from given dlpack. Args: dlpack (PyCapsule): a PyCapsule object with the dltensor. @@ -105,11 +113,17 @@ def from_dlpack(dlpack) -> Tensor: >>> x = paddle.to_tensor([[0.2, 0.3, 0.5, 0.9], ... [0.1, 0.2, 0.6, 0.7]]) >>> dlpack = paddle.utils.dlpack.to_dlpack(x) - >>> x = paddle.utils.dlpack.from_dlpack(dlpack) - >>> print(x) + >>> y = paddle.utils.dlpack.from_dlpack(dlpack) + >>> print(y) Tensor(shape=[2, 4], dtype=float32, place=Place(cpu), stop_gradient=True, [[0.20000000, 0.30000001, 0.50000000, 0.89999998], [0.10000000, 0.20000000, 0.60000002, 0.69999999]]) + >>> # data of tensor x is shared with tensor y + >>> y[0, 0] = 10.0 + >>> print(x) + Tensor(shape=[2, 4], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[10. 
, 0.30000001, 0.50000000, 0.89999998], + [0.10000000, 0.20000000, 0.60000002, 0.69999999]]) """ # Check the type of dlpack @@ -123,9 +137,9 @@ def from_dlpack(dlpack) -> Tensor: if hasattr(dlpack, '__dlpack__'): device = dlpack.__dlpack_device__() - # device is either CUDA or ROCm, we need to pass the current + # device is CUDA, we need to pass the current # stream - if device[0] in (DLDeviceType.kDLCUDA, DLDeviceType.kDLROCM): + if device[0] in (DLDeviceType.kDLCUDA,): stream = paddle.device.cuda.current_stream(device[1]) # cuda_stream is the pointer to the stream and it is a public # attribute, but it is not documented @@ -133,8 +147,6 @@ def from_dlpack(dlpack) -> Tensor: # with a value of 1 for CUDA # https://data-apis.org/array-api/latest/API_specification/array_object.html?dlpack-self-stream-none#dlpack-self-stream-none is_gpu = device[0] == DLDeviceType.kDLCUDA - # Since pytorch is not using PTDS by default, lets directly pass - # the legacy stream stream_ptr = ( 1 if is_gpu and stream.cuda_stream == 0 else stream.cuda_stream ) From 3783f7450236c354ad5fe0be4c1a571fe743f75c Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 13 Sep 2024 10:33:02 +0800 Subject: [PATCH 18/21] fix unitest --- test/cpp/fluid/framework/dlpack_tensor_test.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/cpp/fluid/framework/dlpack_tensor_test.cc b/test/cpp/fluid/framework/dlpack_tensor_test.cc index 6a71be8aab57eb..febbacd47fc9be 100644 --- a/test/cpp/fluid/framework/dlpack_tensor_test.cc +++ b/test/cpp/fluid/framework/dlpack_tensor_test.cc @@ -34,6 +34,9 @@ constexpr uint8_t GetDLDataTypeCode() { if (std::is_same::value) { return static_cast(kDLBfloat); } + if (std::is_same::value) { + return static_cast(kDLBool); + } return std::is_same::value || std::is_floating_point::value From 741f0784d8305d5e8d37c6f08d5eca90d14531f9 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 13 Sep 2024 14:41:18 +0800 Subject: [PATCH 19/21] reduce matrix size and loop time in test_dlpack --- test/legacy_test/test_dlpack.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/legacy_test/test_dlpack.py b/test/legacy_test/test_dlpack.py index d16adc40ddc02f..87827e2317f00b 100644 --- a/test/legacy_test/test_dlpack.py +++ b/test/legacy_test/test_dlpack.py @@ -119,10 +119,10 @@ def test_dlpack_deletion(self): if paddle.is_compiled_with_cuda(): places.append(base.CUDAPlace(0)) for place in places: - for _ in range(20): - a = paddle.rand( - shape=[1024 * 128, 1024], dtype="float32" - ).to(device=place) + for _ in range(4): + a = paddle.rand(shape=[3, 5], dtype="float32").to( + device=place + ) dlpack = paddle.utils.dlpack.to_dlpack(a) b = paddle.utils.dlpack.from_dlpack(dlpack) @@ -133,7 +133,7 @@ def test_to_dlpack_for_loop(self): if paddle.is_compiled_with_cuda(): places.append(base.CUDAPlace(0)) for place in places: - for _ in range(10): + for _ in range(4): x = paddle.rand([3, 5]).to(device=place) dlpack = paddle.utils.dlpack.to_dlpack(x) @@ -144,7 +144,7 @@ def test_to_dlpack_modification(self): if paddle.is_compiled_with_cuda(): places.append(base.CUDAPlace(0)) for place in places: - for _ in range(10): + for _ in range(4): x = paddle.rand([3, 5]).to(device=place) dlpack = paddle.utils.dlpack.to_dlpack(x) y = paddle.utils.dlpack.from_dlpack(dlpack) @@ -158,7 +158,7 @@ def test_to_dlpack_data_ptr_consistency(self): if paddle.is_compiled_with_cuda(): places.append(base.CUDAPlace(0)) for place in places: - for _ in range(10): + for _ 
in range(4): x = paddle.rand([3, 5]).to(device=place) dlpack = paddle.utils.dlpack.to_dlpack(x) y = paddle.utils.dlpack.from_dlpack(dlpack) From de95c8a503a0b0b3e6067e9b0b6aa51f9ccc2413 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Sun, 15 Sep 2024 11:57:14 +0800 Subject: [PATCH 20/21] support directly converting from object that has '__dlpack__' attribute --- python/paddle/utils/dlpack.py | 46 +++++++++++++++++++++------------ test/legacy_test/test_dlpack.py | 20 ++++++++++---- 2 files changed, 45 insertions(+), 21 deletions(-) diff --git a/python/paddle/utils/dlpack.py b/python/paddle/utils/dlpack.py index c5801670d64083..5431b88f4a449a 100644 --- a/python/paddle/utils/dlpack.py +++ b/python/paddle/utils/dlpack.py @@ -15,6 +15,7 @@ from __future__ import annotations import enum +import warnings from typing import TYPE_CHECKING import paddle @@ -26,7 +27,7 @@ if TYPE_CHECKING: from typing_extensions import CapsuleType - from paddle import Tensor + from paddle import Any, Tensor __all__ = [ 'to_dlpack', @@ -88,17 +89,25 @@ def to_dlpack(x: Tensor) -> CapsuleType: return x.value().get_tensor()._to_dlpack() - check_type(x, 'x', (LoDTensor), 'to_dlpack') + check_type(x, "x", (LoDTensor), "to_dlpack") return x._to_dlpack() -def from_dlpack(dlpack) -> Tensor: +def from_dlpack(dlpack: Any) -> Tensor: """ Decodes a DLPack to a tensor. The returned Paddle tensor will share the memory with the tensor from given dlpack. Args: - dlpack (PyCapsule): a PyCapsule object with the dltensor. + dlpack (object with `__dlpack__` attribute, or a PyCapsule): + The tensor or DLPack capsule to convert. + + If `dlpack` is a tensor (or ndarray) object, it must support + the `__dlpack__` protocol (i.e., have a `dlpack.__dlpack__` + method). Otherwise `dlpack` may be a DLPack capsule, which is + an opaque `PyCapsule` instance, typically produced by a + `to_dlpack` function or method. + Returns: out (Tensor), a tensor decoded from DLPack. One thing to be noted, if we get @@ -109,7 +118,7 @@ def from_dlpack(dlpack) -> Tensor: .. code-block:: python >>> import paddle - >>> # x is a tensor with shape [2, 4] + >>> # From DLPack capsule >>> x = paddle.to_tensor([[0.2, 0.3, 0.5, 0.9], ... [0.1, 0.2, 0.6, 0.7]]) >>> dlpack = paddle.utils.dlpack.to_dlpack(x) @@ -124,23 +133,28 @@ def from_dlpack(dlpack) -> Tensor: Tensor(shape=[2, 4], dtype=float32, place=Place(gpu:0), stop_gradient=True, [[10. , 0.30000001, 0.50000000, 0.89999998], [0.10000000, 0.20000000, 0.60000002, 0.69999999]]) - """ - # Check the type of dlpack - t = type(dlpack) - dlpack_flag = t.__module__ == 'builtins' and t.__name__ == 'PyCapsule' - if not dlpack_flag: - raise TypeError( - "The type of 'dlpack' in from_dlpack must be PyCapsule object," - f" but received {type(dlpack)}." - ) + >>> # Directly from external tensor that has '__dlpack__' attribute + >>> import numpy as np + >>> x = np.array([[0.2, 0.3, 0.5, 0.9], + ... [0.1, 0.2, 0.6, 0.7]]) + >>> y = paddle.utils.dlpack.from_dlpack(x) + >>> y[0, 0] = 10.0 + >>> # data of tensor x is shared with tensor y + >>> print(x) + [[10. 
0.3 0.5 0.9] + [ 0.1 0.2 0.6 0.7]] + """ - if hasattr(dlpack, '__dlpack__'): + if hasattr(dlpack, "__dlpack__"): device = dlpack.__dlpack_device__() # device is CUDA, we need to pass the current # stream if device[0] in (DLDeviceType.kDLCUDA,): - stream = paddle.device.cuda.current_stream(device[1]) + with warnings.catch_warnings(): + # ignore deprecation warning + warnings.filterwarnings("ignore", category=UserWarning) + stream = paddle.device.cuda.current_stream(device[1]) # cuda_stream is the pointer to the stream and it is a public # attribute, but it is not documented # The array API specify that the default legacy stream must be passed diff --git a/test/legacy_test/test_dlpack.py b/test/legacy_test/test_dlpack.py index 87827e2317f00b..368a118c49b98f 100644 --- a/test/legacy_test/test_dlpack.py +++ b/test/legacy_test/test_dlpack.py @@ -165,13 +165,23 @@ def test_to_dlpack_data_ptr_consistency(self): self.assertEqual(x.data_ptr(), y.data_ptr()) + def test_to_dlpack_from_ext_tensor(self): + with dygraph_guard(): + places = [base.CPUPlace()] + if paddle.is_compiled_with_cuda(): + places.append(base.CUDAPlace(0)) + for place in places: + for _ in range(4): + x = np.random.randn(3, 5) + y = paddle.utils.dlpack.from_dlpack(x) -class TestRaiseError(unittest.TestCase): - def test_from_dlpack_raise_type_error(self): - self.assertRaises( - TypeError, paddle.utils.dlpack.from_dlpack, np.zeros(5) - ) + self.assertEqual( + x.__array_interface__['data'][0], y.data_ptr() + ) + np.testing.assert_allclose(x, y.numpy()) + +class TestRaiseError(unittest.TestCase): def test_to_dlpack_raise_type_error(self): self.assertRaises(TypeError, paddle.utils.dlpack.to_dlpack, np.zeros(5)) From 1163d33030631353ba00db3647939c5af6d9b201 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 18 Sep 2024 13:14:19 +0800 Subject: [PATCH 21/21] support CUDAPinnedPlace, add unitest for place consistency check, remove redundant function in tensor_util.cc --- paddle/fluid/framework/tensor_util.cc | 27 ++------------------------- test/legacy_test/test_dlpack.py | 6 +++++- 2 files changed, 7 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index fc77293591a937..b1eb1f70c3c9ad 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -846,31 +846,6 @@ phi::DataType GetDstPtrByDLDataType(DLDataType type) { } } -phi::Place GetPlaceFromPtr(void* data) { -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -#ifdef PADDLE_WITH_CUDA -#if CUDA_VERSION >= 10000 - cudaPointerAttributes attr = {}; - cudaError_t status = cudaPointerGetAttributes(&attr, data); - if (status == cudaSuccess && attr.type == cudaMemoryTypeDevice) { - return phi::GPUPlace(attr.device); - } -#else - PADDLE_THROW( - phi::errors::Unimplemented("The GetPlaceFromPtr() method is only " - "supported when CUDA version >= 10.0.")); -#endif -#else - hipPointerAttribute_t attr = {}; - hipError_t status = hipPointerGetAttributes(&attr, data); - if (status == hipSuccess && attr.memoryType == hipMemoryTypeDevice) { - return phi::GPUPlace(attr.device); - } -#endif -#endif - return phi::CPUPlace(); -} - /* dlpack related code ref: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/DLConvertor.cpp @@ -934,6 +909,8 @@ phi::DenseTensor TensorFromDLPack(DLManagedTensor* src, Deleter deleter) { place = phi::CPUPlace(); } else if (src->dl_tensor.device.device_type == kDLCUDA) { place = phi::GPUPlace(); + } else if 
(src->dl_tensor.device.device_type == kDLCUDAHost) { + place = phi::GPUPinnedPlace(); } else { PADDLE_THROW(phi::errors::Unimplemented("Given Place is not supported")); } diff --git a/test/legacy_test/test_dlpack.py b/test/legacy_test/test_dlpack.py index bfdc37f3b22eb2..1319ab59aa8ac6 100644 --- a/test/legacy_test/test_dlpack.py +++ b/test/legacy_test/test_dlpack.py @@ -73,7 +73,7 @@ def test_dlpack_static(self): np.array([[1], [2], [3], [4]]).astype("int"), ) - def test_dlpack_dtype_conversion(self): + def test_dlpack_dtype_and_place_consistency(self): with dygraph_guard(): dtypes = [ "float16", @@ -89,6 +89,8 @@ def test_dlpack_dtype_conversion(self): places = [base.CPUPlace()] if paddle.is_compiled_with_cuda(): places.append(base.CUDAPlace(0)) + places.append(base.CUDAPinnedPlace()) + dtypes.append("bfloat16") data = np.ones((2, 3, 4)) for place in places: @@ -98,6 +100,7 @@ def test_dlpack_dtype_conversion(self): o = paddle.utils.dlpack.from_dlpack(dlpack) self.assertEqual(x.dtype, o.dtype) np.testing.assert_allclose(x.numpy(), o.numpy(), rtol=1e-05) + self.assertEqual(type(x.place), type(o.place)) complex_dtypes = ["complex64", "complex128"] for place in places: @@ -111,6 +114,7 @@ def test_dlpack_dtype_conversion(self): o = paddle.utils.dlpack.from_dlpack(dlpack) self.assertEqual(x.dtype, o.dtype) np.testing.assert_allclose(x.numpy(), o.numpy(), rtol=1e-05) + self.assertEqual(type(x.place), type(o.place)) def test_dlpack_deletion(self): # See Paddle issue 47171
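Taken together, the series ties the lifetime of the imported buffer to the producer's DLManagedTensor deleter instead of copying eagerly. The standalone C++ sketch below is not Paddle code; it assumes only that the dlpack.h header from third_party/dlpack is on the include path, and the Blob/MakeManaged/SharedView names are made up for illustration. It shows the ownership contract a zero-copy consumer has to respect: keep the buffer alive and call src->deleter(src) exactly once when the view is released.

// dlpack_deleter_demo.cc -- standalone sketch, not Paddle code.
#include <dlpack/dlpack.h>

#include <cstdint>
#include <iostream>
#include <vector>

// Backing storage owned by the producer and freed by the deleter.
struct Blob {
  std::vector<float> data;
  std::vector<int64_t> shape;
};

// Producer: wrap a CPU buffer into a DLManagedTensor whose deleter frees it.
DLManagedTensor* MakeManaged(std::vector<float> values,
                             std::vector<int64_t> shape) {
  auto* blob = new Blob{std::move(values), std::move(shape)};
  auto* mt = new DLManagedTensor{};
  mt->manager_ctx = blob;
  mt->dl_tensor.data = blob->data.data();
  mt->dl_tensor.device = {kDLCPU, 0};
  mt->dl_tensor.ndim = static_cast<int32_t>(blob->shape.size());
  mt->dl_tensor.dtype = {kDLFloat, 32, 1};
  mt->dl_tensor.shape = blob->shape.data();
  mt->dl_tensor.strides = nullptr;  // compact row-major layout
  mt->dl_tensor.byte_offset = 0;
  mt->deleter = [](DLManagedTensor* self) {
    delete static_cast<Blob*>(self->manager_ctx);
    delete self;
  };
  return mt;
}

// Consumer: shares the buffer and defers the deleter until release time.
class SharedView {
 public:
  explicit SharedView(DLManagedTensor* src) : src_(src) {}
  SharedView(const SharedView&) = delete;
  SharedView& operator=(const SharedView&) = delete;
  ~SharedView() {
    if (src_ && src_->deleter) src_->deleter(src_);  // release exactly once
  }
  float* data() const { return static_cast<float*>(src_->dl_tensor.data); }

 private:
  DLManagedTensor* src_;
};

int main() {
  DLManagedTensor* mt = MakeManaged({1.f, 2.f, 3.f, 4.f}, {2, 2});
  {
    SharedView view(mt);
    view.data()[0] = 10.f;                // writes through to the shared blob
    std::cout << view.data()[0] << "\n";  // prints 10
  }  // deleter runs here, freeing the producer's buffer
  return 0;
}

A consumer that shares the buffer this way must defer the deleter call until it is done with the memory; invoking it eagerly would leave the shared view dangling.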