Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add nn.functional.sparse_attention and some test cases, test=develop … #36551

Merged

Conversation

Liu-xiandong
Copy link
Member

PR types

New features

PR changes

APIs

Describe

cherry-pick #PR35757
Add paddle.nn.functional.sparse_attention API

  • 本个PR主要将sparse_attention功能在python层进行了一层封装,OP的主体代码见:#PR35676

  • 此外,对于封装的python 接口,增加了相应的单测。

Example

import paddle
import numpy as np

query_data = np.array([[[[0, 1,], [2, 3], [ 0, 1], [2, 3]]]]).astype("float32")
key_data = np.array([[[[0, 1,], [2, 3], [ 0, 1], [2, 3]]]]).astype("float32")
value_data = np.array([[[[0, 1,], [2, 3], [ 0, 1], [2, 3]]]]).astype("float32")
sparse_csr_offset_data = np.array([[[0, 2, 4, 6, 8]]]).astype("int32")
sparse_csr_columns_data = np.array([[[0, 1, 0, 1, 2, 3, 2, 3]]]).astype("int32")
print(query_data.shape)
# (1, 1, 4, 2)
print(sparse_csr_offset_data.shape)
# (1, 1, 5)
print(sparse_csr_columns_data.shape)
# (1, 1, 8)
paddle.disable_static()
query = paddle.to_tensor(query_data, stop_gradient=False, place=paddle.CUDAPlace(0))
key = paddle.to_tensor(key_data, stop_gradient=False, place=paddle.CUDAPlace(0))
value = paddle.to_tensor(value_data, stop_gradient=False, place=paddle.CUDAPlace(0))
offset = paddle.to_tensor(sparse_csr_offset_data, stop_gradient=False, place=paddle.CUDAPlace(0))
columns = paddle.to_tensor(sparse_csr_columns_data, stop_gradient=False, place=paddle.CUDAPlace(0))
output = paddle.nn.functional.sparse_attention(query, key, value, offset, columns)
print(output)

# [[[[1.60885942, 2.60885954],
#       [1.99830270, 2.99830270],
#       [1.60885942, 2.60885954],
#       [1.99830270, 2.99830270]]]]

Result

由于目前CI平台没有CUDA11.2的机器资源,因而将本地计算结果粘贴如下:

  1. 本地单测结果
    image
  2. API example结果
    image

…addlePaddle#35757)

Add paddle.nn.functional.sparse_attention API

    本个PR主要将sparse_attention功能在python层进行了一层封装,OP的主体代码见:#PR35676

    此外,对于封装的python 接口,增加了相应的单测。
@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@lanxianghit lanxianghit merged commit c57d1e9 into PaddlePaddle:release/2.2 Oct 25, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants