Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
b114e78
[DLPack] Bump dlpack to v1.1
SigureMo Sep 10, 2025
a40fb4d
remove unused DLPACK_TAG
SigureMo Sep 10, 2025
a4a237f
add versioned policy
SigureMo Sep 11, 2025
ee91c40
add phi deps
SigureMo Sep 11, 2025
ad7a308
use template
SigureMo Sep 11, 2025
c65294c
Merge branch 'develop' into dlpack/bump-to-dlpack-v11
SigureMo Sep 11, 2025
c5ac55a
add ut and fix dtype conversion
SigureMo Sep 11, 2025
d33a140
ignore all dirty submodules
SigureMo Sep 11, 2025
e0ce30f
add missing `TEST_API`
SigureMo Sep 11, 2025
9136519
add include for paddle/utils/test_macros.h
SigureMo Sep 11, 2025
7331552
remove meaningless test
SigureMo Sep 11, 2025
4806a4d
update uts
SigureMo Sep 11, 2025
b40643c
use kw-only arg stream
SigureMo Sep 11, 2025
fd8fec3
remove unused GetDstPtrByDLDataType
SigureMo Sep 12, 2025
afedf43
add copy and dl_device for __dlpack__
SigureMo Sep 12, 2025
ffbf8bd
Merge branch 'develop' into dlpack/bump-to-dlpack-v11
SigureMo Sep 12, 2025
cfc2770
support copy and device in from_dlpack
SigureMo Sep 12, 2025
01243cc
fix ut, remove explict __dlpack__ call
SigureMo Sep 12, 2025
eb33095
set `DLPACK_FLAG_BITMASK_IS_COPIED` when copied
SigureMo Sep 12, 2025
de0d28b
add cuda from_dlpack(tensor) case
SigureMo Sep 12, 2025
a13cefe
add debug info
SigureMo Sep 14, 2025
b272cb7
empty commit
SigureMo Sep 14, 2025
855740b
bump xdoctest to 1.3.0
SigureMo Sep 14, 2025
90a08e1
-skip after example check
SigureMo Sep 14, 2025
c6b0682
empty commit
SigureMo Sep 14, 2025
8c39863
empty commit
SigureMo Sep 14, 2025
38c256c
try to install latest numpy in ci
SigureMo Sep 15, 2025
7a381f9
refactor to avoid missing std::move
SigureMo Sep 15, 2025
2537547
add log for deleter called
SigureMo Sep 15, 2025
df4b917
revert copilot change
SigureMo Sep 15, 2025
b8ab744
refactor code to get better readability
SigureMo Sep 15, 2025
81f6fde
update example code
SigureMo Sep 15, 2025
21032f1
use singleton
SigureMo Sep 16, 2025
e038f25
Merge branch 'develop' into dlpack/bump-to-dlpack-v11
SigureMo Sep 16, 2025
0cb5b7f
cleanup debug info
SigureMo Sep 16, 2025
e9d8bf6
adjust order
SigureMo Sep 16, 2025
9caa727
cleanup unused code;
SigureMo Sep 16, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cmake/external/dlpack.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
include(ExternalProject)

set(DLPACK_PREFIX_DIR ${THIRD_PARTY_PATH}/dlpack)
set(DLPACK_TAG v0.8)
set(DLPACK_INCLUDE_DIR ${THIRD_PARTY_PATH}/dlpack/src/extern_dlpack/include)
set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/dlpack)
include_directories(${SOURCE_DIR}/include)
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ cc_library(
cc_library(
tensor
SRCS tensor_util.cc
DEPS data_type device_context phi common)
DEPS data_type dlpack_tensor device_context phi common)

cc_library(
lod_tensor
Expand Down Expand Up @@ -549,7 +549,7 @@ cc_library(
cc_library(
dlpack_tensor
SRCS dlpack_tensor.cc
DEPS tensor dlpack)
DEPS dlpack phi)

cc_library(
op_compatible_info
Expand Down
298 changes: 225 additions & 73 deletions paddle/fluid/framework/dlpack_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,82 @@ namespace paddle {
namespace framework {

namespace internal {
class PaddleDeleterManager {
public:
static PaddleDeleterManager &Instance() {
static PaddleDeleterManager instance;
return instance;
}

void AddDeleter(void *ptr, std::function<void(phi::Allocation *)> deleter) {
std::lock_guard<std::mutex> lock(mutex_);
ptr_to_deleter_[ptr] = deleter;
}

static void DeleterBridge(phi::Allocation *alloc) {
std::lock_guard<std::mutex> lock(PaddleDeleterManager::Instance().mutex_);
auto &ptr_to_deleter = PaddleDeleterManager::Instance().ptr_to_deleter_;
auto it = ptr_to_deleter.find(static_cast<void *>(alloc->ptr()));
if (it != ptr_to_deleter.end()) {
it->second(alloc); // call the deleter
ptr_to_deleter.erase(it); // remove the entry from the map safely
}
}

private:
std::unordered_map<void *, std::function<void(phi::Allocation *)>>
ptr_to_deleter_;
std::mutex mutex_;
};

template <typename T>
static ::DLDataType GetDLDataTypeCode() {
phi::DenseTensor from_blob(void *data,
T *src,
const phi::DDim &shape,
const phi::DDim &strides,
phi::DataType dtype,
const phi::Place &place,
const Deleter &deleter) {
auto meta = phi::DenseTensorMeta(dtype, shape, strides);

phi::Allocation::DeleterFnPtr f = nullptr;
if (deleter) {
auto g = [deleter, src](phi::Allocation *p) {
if (src->manager_ctx) {
deleter(src);
}
};

PaddleDeleterManager::Instance().AddDeleter(data, std::move(g));

f = PaddleDeleterManager::DeleterBridge;
}

// Calculate the number of elements of underlying storage
size_t size = 1;
for (auto i = 0; i < shape.size(); ++i) {
if (shape[i] == 0) {
size = 0;
break;
}
size += strides[i] * (shape[i] - 1);
}

auto alloc =
std::make_shared<phi::Allocation>(data, size * SizeOf(dtype), f, place);
return phi::DenseTensor(alloc, meta);
}

template <typename T>
::DLDataType GetDLDataTypeCode() {
::DLDataType dtype;
if (std::is_same<T, phi::dtype::complex<float>>::value ||
std::is_same<T, phi::dtype::complex<double>>::value) {
dtype.code = kDLComplex;
} else if (std::is_same<T, phi::dtype::float8_e4m3fn>::value) {
dtype.code = kDLFloat8_e4m3fn;
} else if (std::is_same<T, phi::dtype::float8_e5m2>::value) {
dtype.code = kDLFloat8_e5m2;
} else if (std::is_same<T, phi::dtype::bfloat16>::value) {
dtype.code = kDLBfloat;
} else if (std::is_same<T, phi::dtype::float16>::value ||
Expand Down Expand Up @@ -63,7 +133,7 @@ static std::unordered_map<int, ::DLDataType> CreateDLDataTypeMap() {
return result;
}

static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
static ::DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
static auto type_to_dtype_map = CreateDLDataTypeMap();
static auto type_to_dtype_map_end_it = type_to_dtype_map.end();
auto it = type_to_dtype_map.find(static_cast<int>(type));
Expand All @@ -72,7 +142,6 @@ static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
common::errors::InvalidArgument(
"Unsupported data type (%s).", DataTypeToString(type)));
return it->second;
#undef REG_DL_DATA_TYPE
}

struct DLDeviceVisitor {
Expand Down Expand Up @@ -138,26 +207,116 @@ struct DLDeviceVisitor {
};
} // namespace internal

phi::DataType DLDataTypeToPhiDataType(::DLDataType type) {
// vector types not currently supported
PADDLE_ENFORCE_LE(
type.lanes,
1,
common::errors::Unimplemented("Vector type is not supported currently."));

switch (type.bits) {
case 8:
if (type.code == kDLBool) return phi::DataType::BOOL;
if (type.code == kDLInt) return phi::DataType::INT8;
if (type.code == kDLUInt) return phi::DataType::UINT8;
if (type.code == kDLFloat8_e4m3fn) return phi::DataType::FLOAT8_E4M3FN;
if (type.code == kDLFloat8_e5m2) return phi::DataType::FLOAT8_E5M2;
PADDLE_THROW(common::errors::Unimplemented(
"DLDataType code <%d> is illegal when DLDataType.bits is <%d>.",
type.code,
type.bits));
case 16:
if (type.code == kDLInt) return phi::DataType::INT16;
if (type.code == kDLFloat) return phi::DataType::FLOAT16;
if (type.code == kDLBfloat) return phi::DataType::BFLOAT16;
PADDLE_THROW(common::errors::Unimplemented(
"DLDataType code <%d> is illegal when DLDataType.bits is <%d>.",
type.code,
type.bits));
case 32:
if (type.code == kDLInt) return phi::DataType::INT32;
if (type.code == kDLFloat) return phi::DataType::FLOAT32;
PADDLE_THROW(common::errors::Unimplemented(
"DLDataType code <%d> is illegal when DLDataType.bits is <%d>.",
type.code,
type.bits));
case 64:
if (type.code == kDLInt) return phi::DataType::INT64;
if (type.code == kDLFloat) return phi::DataType::FLOAT64;
if (type.code == kDLComplex) return phi::DataType::COMPLEX64;
PADDLE_THROW(common::errors::Unimplemented(
"DLDataType code <%d> is illegal when DLDataType.bits is <%d>.",
type.code,
type.bits));
case 128:
if (type.code == kDLComplex) return phi::DataType::COMPLEX128;
PADDLE_THROW(common::errors::Unimplemented(
"DLDataType code <%d> is illegal when DLDataType.bits is <%d>.",
type.code,
type.bits));
default:
PADDLE_THROW(common::errors::Unimplemented(
"Unsupported DLDataType.bits %d.", type.bits));
}
}

::DLDataType PhiDataTypeToDLDataType(phi::DataType dtype) {
return internal::GetDLDataTypeFromTypeIndex(
framework::TransToProtoVarType(dtype));
}

phi::Place DLDeviceToPlace(const DLDevice &dl_device) {
phi::Place place;
if (dl_device.device_type == kDLCPU) {
place = phi::CPUPlace();
} else if (dl_device.device_type == kDLCUDA) {
place = phi::GPUPlace(dl_device.device_id);
} else if (dl_device.device_type == kDLCUDAHost) {
place = phi::GPUPinnedPlace();
} else {
PADDLE_THROW(common::errors::Unimplemented("Given Place is not supported"));
}
return place;
}

DLDevice PlaceToDLDevice(const phi::Place &place) {
return phi::VisitPlace(place, internal::DLDeviceVisitor());
}

template <typename T>
struct PaddleDLMTensor {
phi::DenseTensor handle;
DLManagedTensor tensor;
T tensor;
};

static void deleter(DLManagedTensor *self) {
template <typename T>
static void deleter(T *self) {
if (self && self->manager_ctx) {
delete[] self->dl_tensor
.shape; // delete shape allocated in toDLPack manually
.shape; // delete shape allocated in ToDLPack manually
delete[] self->dl_tensor
.strides; // delete strides allocated in toDLPack manually
delete static_cast<PaddleDLMTensor *>(self->manager_ctx);
.strides; // delete strides allocated in ToDLPack manually
delete static_cast<PaddleDLMTensor<T> *>(self->manager_ctx);
}
}

DLManagedTensor *toDLPack(const phi::DenseTensor &src) {
PaddleDLMTensor *pdDLMTensor(new PaddleDLMTensor);
template <class T>
void FillVersionInfo(T *tensor, uint64_t flags) {}

template <>
void FillVersionInfo<DLManagedTensorVersioned>(DLManagedTensorVersioned *tensor,
uint64_t flags) {
tensor->flags = flags;
tensor->version.major = DLPACK_MAJOR_VERSION;
tensor->version.minor = DLPACK_MINOR_VERSION;
}

template <typename T>
T *ToDLPackImpl(const phi::DenseTensor &src, uint64_t flags) {
PaddleDLMTensor<T> *pdDLMTensor(new PaddleDLMTensor<T>);
pdDLMTensor->handle = const_cast<phi::DenseTensor &>(src);
pdDLMTensor->tensor.manager_ctx = pdDLMTensor;
pdDLMTensor->tensor.deleter = &deleter;
pdDLMTensor->tensor.deleter = &deleter<T>;

// init ndim
using DimType = decltype(pdDLMTensor->tensor.dl_tensor.ndim); // int32_t
Expand All @@ -181,81 +340,74 @@ DLManagedTensor *toDLPack(const phi::DenseTensor &src) {
strides[i] = 1;
}
}
pdDLMTensor->tensor.dl_tensor.strides = strides;

pdDLMTensor->tensor.dl_tensor.data = const_cast<void *>(src.data());
auto place = src.place();
pdDLMTensor->tensor.dl_tensor.device =
phi::VisitPlace(place, internal::DLDeviceVisitor());
pdDLMTensor->tensor.dl_tensor.dtype = internal::GetDLDataTypeFromTypeIndex(
framework::TransToProtoVarType(src.dtype()));
pdDLMTensor->tensor.dl_tensor.strides = strides;
pdDLMTensor->tensor.dl_tensor.device = PlaceToDLDevice(src.place());
pdDLMTensor->tensor.dl_tensor.dtype = PhiDataTypeToDLDataType(src.dtype());
pdDLMTensor->tensor.dl_tensor.byte_offset = 0;
FillVersionInfo(&(pdDLMTensor->tensor), flags);
return &(pdDLMTensor->tensor);
}

DLPackTensor::DLPackTensor(const phi::DenseTensor &tensor, LaneType lanes)
: t_{}, shape_{} {
// init data, data buffer
t_.data = const_cast<void *>(tensor.data());

// init device, DLDevice type with device_type and device_id
auto place = tensor.place();
t_.device = phi::VisitPlace(place, internal::DLDeviceVisitor());

// init dtype
t_.dtype = internal::GetDLDataTypeFromTypeIndex(
framework::TransToProtoVarType(tensor.dtype()));
t_.dtype.lanes = lanes;

// init ndim, tensor rank
auto &dims = tensor.dims();
using DimType = decltype(t_.ndim); // int
t_.ndim = static_cast<DimType>(dims.size());

// init shape, tensor dims
t_.shape = shape_;
for (DimType i = 0; i < t_.ndim; ++i) {
t_.shape[i] = dims[i];
}

// init strides, nullptr means the tensor is compact
t_.strides = nullptr;

// init byte_offset
t_.byte_offset = 0;
DLManagedTensor *ToDLPack(const phi::DenseTensor &src, uint64_t flags) {
return ToDLPackImpl<DLManagedTensor>(src, flags);
}

::DLManagedTensor *DLPackTensor::ToDLManagedTensor() {
// init shape
auto shape = new int64_t[t_.ndim];
using DimType = decltype(t_.ndim); // int
for (DimType i = 0; i < t_.ndim; ++i) {
shape[i] = t_.shape[i];
}
t_.shape = shape;
DLManagedTensorVersioned *ToDLPackVersioned(const phi::DenseTensor &src,
uint64_t flags) {
return ToDLPackImpl<DLManagedTensorVersioned>(src, flags);
}

// init strides
auto strides = new int64_t[t_.ndim];
for (DimType i = 0; i < t_.ndim; ++i) {
strides[i] = 1;
}
for (DimType i = t_.ndim - 2; i >= 0; --i) {
strides[i] = t_.shape[i + 1] * strides[i + 1];
template <typename T>
phi::DenseTensor FromDLPackImpl(T *src, Deleter deleter) {
std::vector<int64_t> shape_vec;
std::copy(src->dl_tensor.shape,
src->dl_tensor.shape + src->dl_tensor.ndim,
std::back_inserter(shape_vec));

phi::Place place = DLDeviceToPlace(src->dl_tensor.device);
phi::DataType dtype = DLDataTypeToPhiDataType(src->dl_tensor.dtype);

if (!src->dl_tensor.strides) {
return internal::from_blob(
src->dl_tensor.data,
src,
common::make_ddim(shape_vec),
phi::DenseTensorMeta::calc_strides(common::make_ddim(shape_vec)),
dtype,
place,
std::move(deleter));
} else {
std::vector<int64_t> strides_vec;
std::copy(src->dl_tensor.strides,
src->dl_tensor.strides + src->dl_tensor.ndim,
std::back_inserter(strides_vec));
return internal::from_blob(src->dl_tensor.data,
src,
common::make_ddim(shape_vec),
common::make_ddim(strides_vec),
dtype,
place,
deleter);
}
t_.strides = strides;

auto tensor = new DLManagedTensor;
tensor->dl_tensor = t_;
}

tensor->deleter = [](DLManagedTensor *arg) {
delete[] arg->dl_tensor.shape;
delete[] arg->dl_tensor.strides;
delete arg;
template <typename T>
phi::DenseTensor FromDLPackImpl(T *src) {
auto deleter = [src](void *self [[maybe_unused]]) {
if (src->deleter) {
src->deleter(src);
}
};
return FromDLPackImpl<T>(src, std::move(deleter));
}

tensor->manager_ctx = nullptr;
phi::DenseTensor FromDLPack(DLManagedTensor *src) {
return FromDLPackImpl<DLManagedTensor>(src);
}

return tensor;
phi::DenseTensor FromDLPackVersioned(DLManagedTensorVersioned *src) {
return FromDLPackImpl<DLManagedTensorVersioned>(src);
}

} // namespace framework
Expand Down
Loading