
[DLPACK] Support from_dlpack with shared memory #67927

Merged

HydrogenSulfate merged 40 commits into PaddlePaddle:develop from HydrogenSulfate:support_dlpack on Sep 19, 2024

Conversation

@HydrogenSulfate (Contributor) commented Sep 2, 2024

PR Category

User Experience

PR Types

Bug fixes

Description

Pcard-75624

Before this PR:

  1. from_dlpack converted tensors coming from other deep learning frameworks by copying device memory, which introduced an unnecessary memory copy.
  2. The bundled dlpack version was old and did not support some data types, such as bool.

After this PR:

  1. from_blob is used to hold the tensor data provided by the producer directly: only the pointer is copied, so both sides share the same data region and no data copy takes place. The producer's deleter is passed through to avoid leaking device memory.
  2. dlpack is upgraded from v0.4 to v0.8, aligning behavior with other deep learning frameworks, including support for the bool type and the CUDAPinned device; this enables the NVIDIA/warp toolkit, whose unit tests now pass (see the bool round-trip sketch right after this list).
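
A minimal sketch of the new bool support (written for this description rather than taken from the PR's test suite; it only assumes the paddle.utils.dlpack.to_dlpack / from_dlpack APIs exercised below):

import numpy as np
import paddle

# bool tensors could not be exchanged with the old dlpack v0.4; with v0.8 they
# should round-trip through a DLPack capsule without a data copy.
x = paddle.to_tensor([True, False, True])
capsule = paddle.utils.dlpack.to_dlpack(x)
y = paddle.utils.dlpack.from_dlpack(capsule)

assert x.data_ptr() == y.data_ptr()  # same buffer, no copy
np.testing.assert_array_equal(x.numpy(), y.numpy())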

Tested with million-element tensors: data can be shared, and tensors converted back and forth, between Paddle and high-performance computing frameworks such as numpy, cupy and pytorch on both CPU and GPU:

(- means support is not required, ✓ means supported)

| from (row) \ to (column) | Paddle(CPU) | Paddle(GPU) | Pytorch(CPU) | Pytorch(GPU) | Cupy(GPU) |
| --- | --- | --- | --- | --- | --- |
| Paddle(CPU) | ✓ | - | ✓ | - | - |
| Paddle(GPU) | - | ✓ | - | ✓ | ✓ |
| Pytorch(CPU) | ✓ | - |  | - | - |
| Pytorch(GPU) | - | ✓ | - |  |  |
| Cupy(GPU) | - | ✓ | - |  |  |

Related PR: NVIDIA/warp#313
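
As a supplementary sketch of the numpy interoperability mentioned above (not part of the test script below; it assumes numpy >= 1.22, which implements __dlpack__, and relies on from_dlpack accepting any object that implements the __dlpack__ protocol, as discussed in the review comments):

import numpy as np
import paddle

# numpy arrays expose __dlpack__, so they can be handed to from_dlpack directly.
x = np.arange(12, dtype="float32").reshape(3, 4)
y = paddle.utils.dlpack.from_dlpack(x)

# If the buffer is shared, writes through either handle are visible in the other.
x[1, 1] = 2.0
y[0, 0] = -1.0
np.testing.assert_allclose(x, y.numpy())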

import numpy as np


def dlpack_from_cupy_to_paddle():
    print("testing dlpack_from_cupy_to_paddle")
    import cupy as cp
    import paddle
    memory_pool = cp.get_default_memory_pool()
    for i in range(3):
        x = cp.zeros([100, 10000], dtype="float32")
        t = x.toDlpack()
        y = paddle.utils.dlpack.from_dlpack(t)

        # modify in both frameworks
        x[5, 5] = 2.0
        y[1, 0] = 1.0
        np.testing.assert_allclose(cp.asnumpy(x), y.numpy())

        used_bytes = memory_pool.used_bytes()
        used_megabytes = used_bytes / (1024 ** 2)
        print(f"{i} cupy mem: {used_megabytes:.2f} MB, paddle mem: {paddle.device.cuda.max_memory_allocated() / (1 << 20):.2f} MB")


def dlpack_from_paddle_to_cupy():
    print("testing dlpack_from_paddle_to_cupy")
    import cupy as cp
    import paddle
    memory_pool = cp.get_default_memory_pool()
    for i in range(3):
        x = paddle.randn([100, 10000], dtype="float32")
        t = paddle.utils.dlpack.to_dlpack(x)
        y = cp.from_dlpack(t)

        # modify in both frameworks
        x[5, 5] = 2.0
        y[1, 0] = 1.0
        np.testing.assert_allclose(cp.asnumpy(y), x.numpy())

        used_bytes = memory_pool.used_bytes()
        used_megabytes = used_bytes / (1024 ** 2)
        print(f"{i} cupy mem: {used_megabytes:.2f} MB, paddle mem: {paddle.device.cuda.max_memory_allocated() / (1 << 20):.2f} MB")


def dlpack_from_paddle_to_paddle():
    print("testing dlpack_from_paddle_to_paddle")
    import paddle
    from paddle.utils.dlpack import from_dlpack
    from paddle.utils.dlpack import to_dlpack
    for i in range(10):
        x = paddle.randn([100, 10000])
        t = to_dlpack(x)
        y = from_dlpack(t)

        assert x.data_ptr() == y.data_ptr()
        np.testing.assert_array_equal(x.numpy(), y.numpy())
        print(f"{i} paddle mem: {paddle.device.cuda.max_memory_allocated() / (1 << 20):.2f} MB")


def dlpack_from_torch_to_paddle():
    print("testing dlpack_from_torch_to_paddle")
    import paddle
    import torch
    for i in range(10):
        x = torch.randn(100, 10000, device='cuda:0')
        t = torch.utils.dlpack.to_dlpack(x)
        y = paddle.utils.dlpack.from_dlpack(t)

        # modify in both frameworks
        x[5, 5] = 2.0
        y[1, 0] = 1.0

        assert x.data_ptr() == y.data_ptr()
        np.testing.assert_array_equal(x.detach().cpu().numpy(), y.numpy())
        print(f"{i} paddle mem: {paddle.device.cuda.max_memory_allocated() / (1 << 20):.2f} MB, torch mem: {torch.cuda.max_memory_allocated() / (1 << 20):.2f} MB")


def dlpack_from_paddle_to_torch():
    print("testing dlpack_from_paddle_to_torch")
    import paddle
    import torch
    for i in range(10):
        x = paddle.randn([100, 10000])
        t = paddle.utils.dlpack.to_dlpack(x)
        y = torch.utils.dlpack.from_dlpack(t)

        # modify in both frameworks
        x[5, 5] = 2.0
        y[1, 0] = 1.0

        assert x.data_ptr() == y.data_ptr()
        np.testing.assert_array_equal(x.numpy(), y.detach().cpu().numpy())
        print(f"{i} paddle mem: {paddle.device.cuda.max_memory_allocated() / (1 << 20):.2f} MB, torch mem: {torch.cuda.max_memory_allocated() / (1 << 20):.2f} MB")


def dlpack_from_paddle_to_paddle_cpu():
    print("testing dlpack_from_paddle_to_paddle")
    import paddle
    from paddle.utils.dlpack import from_dlpack
    from paddle.utils.dlpack import to_dlpack
    for i in range(10):
        x = paddle.randn([100, 10000]).to("cpu")
        t = to_dlpack(x)
        y = from_dlpack(t)

        # modify in both frameworks
        x[5, 5] = 2.0
        y[1, 0] = 1.0

        assert ('cpu' in str(x.place) and 'cpu' in str(y.place)), f"{x.place}, {y.place}"
        assert x.data_ptr() == y.data_ptr()
        np.testing.assert_array_equal(x.numpy(), y.numpy())

        print(f"{i} paddle mem: {paddle.device.cuda.max_memory_allocated() / (1 << 20):.2f} MB")


def dlpack_from_torch_to_paddle_cpu():
    print("testing dlpack_from_torch_to_paddle")
    import paddle
    import torch
    for i in range(10):
        x = torch.randn(100, 10000, device='cpu')
        t = torch.utils.dlpack.to_dlpack(x)
        y = paddle.utils.dlpack.from_dlpack(t)
        assert 'cpu' in str(y.place), y.place

        # modify in both frameworks
        x[5, 5] = 2.0
        y[1, 0] = 1.0

        assert x.data_ptr() == y.data_ptr()
        np.testing.assert_array_equal(x.detach().numpy(), y.numpy())
        print(f"{i} paddle mem: {paddle.device.cuda.max_memory_allocated() / (1 << 20):.2f} MB, torch mem: {torch.cuda.max_memory_allocated() / (1 << 20):.2f} MB")


def dlpack_from_paddle_to_torch_cpu():
    print("testing dlpack_from_paddle_to_torch")
    import paddle
    import torch
    for i in range(10):
        x = paddle.randn([100, 10000]).to("cpu")
        t = paddle.utils.dlpack.to_dlpack(x)
        y = torch.utils.dlpack.from_dlpack(t)

        # modify in both frameworks
        x[5, 5] = 2.0
        y[1, 0] = 1.0

        assert 'cpu' in str(y.device), y.device
        assert x.data_ptr() == y.data_ptr()
        np.testing.assert_array_equal(x.numpy(), y.detach().numpy())
        print(f"{i} paddle mem: {paddle.device.cuda.max_memory_allocated() / (1 << 20):.2f} MB, torch mem: {torch.cuda.max_memory_allocated() / (1 << 20):.2f} MB")


if __name__ == "__main__":
    # paddle <-> paddle
    dlpack_from_paddle_to_paddle()

    # paddle <-> cupy
    dlpack_from_cupy_to_paddle()
    dlpack_from_paddle_to_cupy()

    # paddle <-> pytorch
    dlpack_from_torch_to_paddle()
    dlpack_from_paddle_to_torch()

    # paddle <-> paddle(cpu)
    dlpack_from_paddle_to_paddle_cpu()
    # paddle <-> pytorch(cpu)
    dlpack_from_torch_to_paddle_cpu()
    dlpack_from_paddle_to_torch_cpu()

Output:

testing dlpack_from_paddle_to_paddle
W0911 20:49:53.677103 20470 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 12.0, Runtime API Version: 11.6
W0911 20:49:53.713279 20470 gpu_resources.cc:164] device: 0, cuDNN Version: 8.4.
0 paddle mem: 3.81 MB
1 paddle mem: 7.63 MB
2 paddle mem: 7.63 MB
3 paddle mem: 7.63 MB
4 paddle mem: 7.63 MB
5 paddle mem: 7.63 MB
6 paddle mem: 7.63 MB
7 paddle mem: 7.63 MB
8 paddle mem: 7.63 MB
9 paddle mem: 7.63 MB
testing dlpack_from_cupy_to_paddle
0 cupy mem: 3.81 MB, paddle mem: 7.63 MB
1 cupy mem: 3.81 MB, paddle mem: 7.63 MB
2 cupy mem: 3.81 MB, paddle mem: 7.63 MB
testing dlpack_from_paddle_to_cupy
0 cupy mem: 0.00 MB, paddle mem: 7.63 MB
1 cupy mem: 0.00 MB, paddle mem: 7.63 MB
2 cupy mem: 0.00 MB, paddle mem: 7.63 MB
testing dlpack_from_torch_to_paddle
0 paddle mem: 7.63 MB, torch mem: 3.81 MB
1 paddle mem: 7.63 MB, torch mem: 7.63 MB
2 paddle mem: 7.63 MB, torch mem: 7.63 MB
3 paddle mem: 7.63 MB, torch mem: 7.63 MB
4 paddle mem: 7.63 MB, torch mem: 7.63 MB
5 paddle mem: 7.63 MB, torch mem: 7.63 MB
6 paddle mem: 7.63 MB, torch mem: 7.63 MB
7 paddle mem: 7.63 MB, torch mem: 7.63 MB
8 paddle mem: 7.63 MB, torch mem: 7.63 MB
9 paddle mem: 7.63 MB, torch mem: 7.63 MB
testing dlpack_from_paddle_to_torch
0 paddle mem: 7.63 MB, torch mem: 7.63 MB
1 paddle mem: 7.63 MB, torch mem: 7.63 MB
2 paddle mem: 7.63 MB, torch mem: 7.63 MB
3 paddle mem: 7.63 MB, torch mem: 7.63 MB
4 paddle mem: 7.63 MB, torch mem: 7.63 MB
5 paddle mem: 7.63 MB, torch mem: 7.63 MB
6 paddle mem: 7.63 MB, torch mem: 7.63 MB
7 paddle mem: 7.63 MB, torch mem: 7.63 MB
8 paddle mem: 7.63 MB, torch mem: 7.63 MB
9 paddle mem: 7.63 MB, torch mem: 7.63 MB
testing dlpack_from_paddle_to_paddle
0 paddle mem: 7.63 MB
1 paddle mem: 7.63 MB
2 paddle mem: 7.63 MB
3 paddle mem: 7.63 MB
4 paddle mem: 7.63 MB
5 paddle mem: 7.63 MB
6 paddle mem: 7.63 MB
7 paddle mem: 7.63 MB
8 paddle mem: 7.63 MB
9 paddle mem: 7.63 MB
testing dlpack_from_torch_to_paddle
0 paddle mem: 7.63 MB, torch mem: 7.63 MB
1 paddle mem: 7.63 MB, torch mem: 7.63 MB
2 paddle mem: 7.63 MB, torch mem: 7.63 MB
3 paddle mem: 7.63 MB, torch mem: 7.63 MB
4 paddle mem: 7.63 MB, torch mem: 7.63 MB
5 paddle mem: 7.63 MB, torch mem: 7.63 MB
6 paddle mem: 7.63 MB, torch mem: 7.63 MB
7 paddle mem: 7.63 MB, torch mem: 7.63 MB
8 paddle mem: 7.63 MB, torch mem: 7.63 MB
9 paddle mem: 7.63 MB, torch mem: 7.63 MB
testing dlpack_from_paddle_to_torch
0 paddle mem: 7.63 MB, torch mem: 7.63 MB
1 paddle mem: 7.63 MB, torch mem: 7.63 MB
2 paddle mem: 7.63 MB, torch mem: 7.63 MB
3 paddle mem: 7.63 MB, torch mem: 7.63 MB
4 paddle mem: 7.63 MB, torch mem: 7.63 MB
5 paddle mem: 7.63 MB, torch mem: 7.63 MB
6 paddle mem: 7.63 MB, torch mem: 7.63 MB
7 paddle mem: 7.63 MB, torch mem: 7.63 MB
8 paddle mem: 7.63 MB, torch mem: 7.63 MB
9 paddle mem: 7.63 MB, torch mem: 7.63 MB

@paddle-bot commented Sep 2, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@HydrogenSulfate changed the title from "Support dlpack" to "Support from_dlpack with shared memory" on Sep 2, 2024
@HydrogenSulfate changed the title from "Support from_dlpack with shared memory" to "[WIP] Support from_dlpack with shared memory" on Sep 2, 2024
@HydrogenSulfate changed the title from "[WIP] Support from_dlpack with shared memory" to "Support from_dlpack with shared memory" on Sep 10, 2024
@HydrogenSulfate changed the title from "Support from_dlpack with shared memory" to "[DLPACK] Support from_dlpack with shared memory" on Sep 10, 2024
@HydrogenSulfate mentioned this pull request on Sep 12, 2024
@DesmonDay self-requested a review on Sep 13, 2024

@SigureMo (Member) left a comment:

LGTMeow 🐾 for type annotations change



- def from_dlpack(dlpack: CapsuleType) -> Tensor:
+ def from_dlpack(dlpack: Any) -> Tensor:
A Member commented on this diff:

Is this a usability improvement? Previously only a dlpack capsule was accepted, but now anything that implements the __dlpack__ Protocol can be passed.

Type-wise, I would suggest SupportDLPack | CapsuleType:

class SupportDLPack(Protocol):
    def __dlpack__(self) -> CapsuleType: ...

That said, the current form is not a big problem; it can be done in a follow-up PR.
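
A self-contained version of that suggestion might look like the sketch below (CapsuleType and Tensor are stand-in aliases here; the real aliases live in Paddle's type stubs and are not shown in this thread):

from typing import Any, Protocol, Union

CapsuleType = Any  # stand-in for the PyCapsule alias used in Paddle's stubs
Tensor = Any       # stand-in for paddle.Tensor

class SupportDLPack(Protocol):
    def __dlpack__(self) -> CapsuleType: ...

def from_dlpack(dlpack: Union[SupportDLPack, CapsuleType]) -> Tensor: ...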

@DesmonDay (Contributor) left a comment:

LGTM

strides[i] = _strides[i];
if (shape[i] < 2) {
strides[i] = 1;
}
A Contributor commented:

Is this change equivalent to the original behavior?

@HydrogenSulfate (Author) replied on Sep 19, 2024:

> Is this change equivalent to the original behavior?

The original strides computation was likely wrong: it derived strides directly from the shape and did not handle the case where x is non-contiguous, so the converted dlpack tensor was always contiguous. Following PyTorch's approach, the original tensor's strides should be used directly: https://github.com/pytorch/pytorch/blob/db80b98ec460ca5b2fd84c1dfb6426925f64c8cc/aten/src/ATen/DLConvertor.cpp#L267-L276

Following up on your comment, I tested the strides before and after conversion in this PR, and it looks like from_dlpack still handles strides incorrectly; I need to fix that and add a strides unit test.
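
To make the non-contiguous case concrete, a small illustration (using torch only because its DLConvertor was cited above; the shapes are arbitrary):

import torch

# A transposed view is non-contiguous: shape (6, 4) but strides (1, 6).
x = torch.randn(4, 6).t()
print(x.shape, x.stride())   # torch.Size([6, 4]) (1, 6)

# Recomputing strides from the shape alone, as the old code effectively did,
# yields the contiguous layout (4, 1), which no longer describes x's memory,
# so a consumer reading the exported DLPack tensor would see scrambled data.
contiguous_strides = (x.shape[1], 1)
print(contiguous_strides)    # (4, 1)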

@XiaoguangHu01 (Contributor) left a comment:

LGTM

@HydrogenSulfate merged commit c188b1d into PaddlePaddle:develop on Sep 19, 2024
@HydrogenSulfate deleted the support_dlpack branch on Sep 19, 2024
@HydrogenSulfate linked an issue on Oct 20, 2024 that may be closed by this pull request

Successfully merging this pull request may close these issues.

About from_dlpack

4 participants