Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 46 additions & 14 deletions python/paddle/utils/dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import enum
import warnings
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Protocol

import paddle

Expand All @@ -27,14 +27,22 @@
if TYPE_CHECKING:
from typing_extensions import CapsuleType

from paddle import Any, Tensor
from paddle import Tensor

__all__ = [
'to_dlpack',
'from_dlpack',
]


class SupportDLPack(Protocol):
def __dlpack__(self) -> CapsuleType:
pass

def __dlpack_device__(self) -> tuple[enum.IntEnum, int]:
pass


class DLDeviceType(enum.IntEnum):
kDLCPU = (1,)
kDLCUDA = (2,)
Expand All @@ -53,15 +61,16 @@ def to_dlpack(x: Tensor) -> CapsuleType:
Encodes a tensor to DLPack.

Args:
x (Tensor): The input tensor, and the data type can be `bool`, `float16`, `float32`,
`float64`, `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`,
`complex128`.
x (Tensor): The input tensor, and the data type can be ``bool``, ``float16``, ``float32``,
``float64``, ``int8``, ``int16``, ``int32``, ``int64``, ``uint8``, ``complex64``,
``complex128``.

Returns:
dltensor, and the data type is PyCapsule.

Examples:
.. code-block:: python
:name: code-paddle-to-paddle

>>> import paddle
>>> # x is a tensor with shape [2, 4]
Expand All @@ -78,6 +87,20 @@ def to_dlpack(x: Tensor) -> CapsuleType:
>>> print(dlpack)
>>> # doctest: +SKIP('the address will change in every run')
<capsule object "used_dltensor" at 0x7f6103c681b0>

.. code-block:: python
:name: code-paddle-to-torch

>>> # doctest: +SKIP('torch will not be installed')
>>> # type: ignore
>>> # convert tensor from paddle to other framework using to_dlpack
>>> import torch

>>> x = paddle.randn([2, 4]).to(device="cpu")
>>> y = torch.from_dlpack(paddle.utils.dlpack.to_dlpack(x))
>>> print(y.shape)
torch.Size([2, 4])
>>> # doctest: -SKIP
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里要不要拆分成两段示例代码?看起来比较独立

拆分示例代码需要为每段添加 name(可以全局搜 :name: 参考)

然后中文文档需要分别引用两段代码,可参考

https://github.com/PaddlePaddle/docs/blob/3db4cd854485ca4b9c74dd3ed8364be25746a776/docs/api/paddle/diag_cn.rst?plain=1#L31-L39

"""

if in_dygraph_mode():
Expand All @@ -93,36 +116,42 @@ def to_dlpack(x: Tensor) -> CapsuleType:
return x._to_dlpack()


def from_dlpack(dlpack: Any) -> Tensor:
def from_dlpack(dlpack: SupportDLPack | CapsuleType) -> Tensor:
"""
Decodes a DLPack to a tensor. The returned Paddle tensor will share the memory with
the tensor from given dlpack.

Args:
dlpack (object with `__dlpack__` attribute, or a PyCapsule):
The tensor or DLPack capsule to convert.
dlpack (SupportDLPack | CapsuleType): A PyCapsule object with the dltensor,
or that implements '__dlpack__' and '__dlpack_device__' methods.

If `dlpack` is a tensor (or ndarray) object, it must support
the `__dlpack__` protocol (i.e., have a `dlpack.__dlpack__`
method). Otherwise `dlpack` may be a DLPack capsule, which is
an opaque `PyCapsule` instance, typically produced by a
`to_dlpack` function or method.


Returns:
out (Tensor), a tensor decoded from DLPack. One thing to be noted, if we get
an input dltensor with data type as `bool`, we return the decoded
tensor as `uint8`.
out (Tensor): A tensor decoded from DLPack. The data type of returned tensor
can be one of: ``int32``, ``int64``, ``float16``, ``float32`` and ``float64``.
The device of returned tensor can be one of: ``CPU``, ``CUDAPlace``, ``CUDAPinnedPlace``.

Examples:
.. code-block:: python
:name: code-paddle-from-paddle

>>> import paddle
>>> # From DLPack capsule
>>> x = paddle.to_tensor([[0.2, 0.3, 0.5, 0.9],
... [0.1, 0.2, 0.6, 0.7]])
... [0.1, 0.2, 0.6, 0.7]], place="cpu")
>>> dlpack = paddle.utils.dlpack.to_dlpack(x)

>>> y = paddle.utils.dlpack.from_dlpack(dlpack)
>>> # dlpack capsule will be renamed to 'used_dltensor' after decoded
>>> print(dlpack)
>>> # doctest: +SKIP('the address will change in every run')
<capsule object "used_dltensor" at 0x7f6103c681b0>

>>> print(y)
Tensor(shape=[2, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
[[0.20000000, 0.30000001, 0.50000000, 0.89999998],
Expand All @@ -134,7 +163,10 @@ def from_dlpack(dlpack: Any) -> Tensor:
[[10. , 0.30000001, 0.50000000, 0.89999998],
[0.10000000, 0.20000000, 0.60000002, 0.69999999]])

>>> # Directly from external tensor that has '__dlpack__' attribute
.. code-block:: python
:name: code-paddle-from-numpy

>>> # Directly from external tensor that implements '__dlpack__' and '__dlpack_device__' methods
>>> import numpy as np
>>> x = np.array([[0.2, 0.3, 0.5, 0.9],
... [0.1, 0.2, 0.6, 0.7]])
Expand Down