Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sparse_attention OP, test=develop #35676

Merged

Conversation

Liu-xiandong
Copy link
Member

@Liu-xiandong Liu-xiandong commented Sep 13, 2021

PR types

New features

PR changes

OPs

Describe

Add paddle._C_ops.sparse_attention OPs

Example

import paddle
import numpy as np

query_data = np.array([[[[0, 1,], [2, 3], [ 0, 1], [2, 3]]]]).astype("float32")
key_data = np.array([[[[0, 1,], [2, 3], [ 0, 1], [2, 3]]]]).astype("float32")
value_data = np.array([[[[0, 1,], [2, 3], [ 0, 1], [2, 3]]]]).astype("float32")
sparse_csr_offset_data = np.array([[[0, 2, 4, 6, 8]]]).astype("int32")
sparse_csr_columns_data = np.array([[[0, 1, 0, 1, 2, 3, 2, 3]]]).astype("int32")
print(query_data.shape)
# (1, 1, 4, 2)
print(sparse_csr_offset_data.shape)
# (1, 1, 5)
print(sparse_csr_columns_data.shape)
# (1, 1, 8)
paddle.disable_static()
query = paddle.to_tensor(query_data, stop_gradient=False, place=paddle.CUDAPlace(0))
key = paddle.to_tensor(key_data, stop_gradient=False, place=paddle.CUDAPlace(0))
value = paddle.to_tensor(value_data, stop_gradient=False, place=paddle.CUDAPlace(0))
offset = paddle.to_tensor(sparse_csr_offset_data, stop_gradient=False, place=paddle.CUDAPlace(0))
columns = paddle.to_tensor(sparse_csr_columns_data, stop_gradient=False, place=paddle.CUDAPlace(0))
output = paddle._C_ops.sparse_attention(query, key, value, offset, columns)
print(output)

# [[[[1.60885942, 2.60885954],
#       [1.99830270, 2.99830270],
#       [1.60885942, 2.60885954],
#       [1.99830270, 2.99830270]]]]

Precautions

  • The code of this PR can only support CUDA 11.2. Currently, CI does not have GPU with CUDA 11.2 , and all tests will be skipped automatically.

  • The new OP is paddle._C_ops.sparse_attention. Regarding the work of the python API, it will be resolved in a follow-up PR.

  • The code of this PR lacks tests on dynamic graphs and static graphs, and will be added in subsequent PRs.

Result

image
image

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@ZHUI
Copy link
Collaborator

ZHUI commented Sep 17, 2021

这个是为了支持哪一些模型呢?

@Liu-xiandong
Copy link
Member Author

这个是为了支持哪一些模型呢?

NLP 那边的 sparse transformer

@Liu-xiandong Liu-xiandong reopened this Sep 17, 2021
@@ -46,6 +46,7 @@
'cudnn_lstm', \
'rnn', \
'lgamma', \
'sparse_attention', \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

添加白名单的理由是否充分?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前主要是矩阵乘导致的精度误差,咨询了彭军才,同意添加

paddle/fluid/operators/sparse_attention_op.cc Show resolved Hide resolved
@lanxianghit
Copy link
Contributor

example里给出的API定义不符合2.0规范 paddle.fluid.core.sparse_attention

@Liu-xiandong
Copy link
Member Author

example里给出的API定义不符合2.0规范 paddle.fluid.core.sparse_attention

目前行数比较多,Python API的封装将在下一个PR中提交

Copy link
Contributor

@Xreki Xreki left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CI没有CUDA11.2以上的环境,请在PR描述里面贴一下单测本地测试的结果。

paddle/fluid/operators/sparse_attention_op.cc Outdated Show resolved Hide resolved
paddle/fluid/operators/sparse_attention_op.cc Outdated Show resolved Hide resolved
paddle/fluid/operators/sparse_attention_op.cc Outdated Show resolved Hide resolved
paddle/fluid/operators/sparse_attention_op.cc Outdated Show resolved Hide resolved
paddle/fluid/operators/sparse_attention_op.cc Outdated Show resolved Hide resolved
paddle/fluid/operators/sparse_attention_op.cu Outdated Show resolved Hide resolved
paddle/fluid/operators/sparse_attention_op.cu Show resolved Hide resolved
paddle/fluid/operators/sparse_attention_op.cu Outdated Show resolved Hide resolved
&output_lists[i], M, N, false, false);
}
#else
PADDLE_THROW(platform::errors::InvalidArgument(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

错误类型好像有Unsupported

return -1


def get_linux_platform():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

操作系统,可以在CMakelists.txt里面控制。

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CMake可以控制C++的相关参数,但是python端不能直接控制。需要通过类似于注册算子的方式实现,例如#26180

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没那么复杂,cmake中可以判断是否WINMACOS系统,这两个系统就不定义这个单测了。

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, Thanks

Copy link
Contributor

@Xreki Xreki left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. 这个PR可以先合进去,sparse handle绑定ctx中已有的stream,下个PR需要处理下。

const T* srcptr = src + layout_rowptr[cur_block_row];
T* attnptr = nullptr;
if (attn_mask != nullptr) {
const T* attnptr = attn_mask + cur_block_row * num_rows;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这是想赋值给L84定义的attnptr吧,但是这样写重新定义了一个局部变量,所以赋值是无效的。

*/
template <typename DeviceContext, typename T>
void SparseSoftmaxForward(const platform::CUDADeviceContext& ctx,
const Tensor* offset, const Tensor* columns,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

输入最好用const Tensor &类型。

}
}

void CusparseDestroy(cusparseDnMatDescr_t* dn_mat_first,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这种封装方式不太好,一般应该遵守谁创建、谁销毁的原则。

GetTransposeOperation(b_transpose), &alpha, mat_a, mat_b, &beta, mat_c,
gpu_type, CUSPARSE_SDDMM_ALG_DEFAULT, &buffer_size);
auto d_buffer_ptr = paddle::memory::Alloc(ctx, buffer_size);
void* d_buffer = static_cast<void*>(d_buffer_ptr->ptr());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个是workspace吗?

return -1


def get_linux_platform():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没那么复杂,cmake中可以判断是否WINMACOS系统,这两个系统就不定义这个单测了。

@Liu-xiandong Liu-xiandong changed the title Add sparse_attention api, test=develop Add sparse_attention OP, test=develop Sep 28, 2021
@lanxianghit lanxianghit merged commit 6b587e9 into PaddlePaddle:develop Sep 28, 2021
AnnaTrainingG pushed a commit to AnnaTrainingG/Paddle that referenced this pull request Sep 29, 2021
Add sparse_attention OPs, python api will be added in next pr
Liu-xiandong added a commit to Liu-xiandong/Paddle that referenced this pull request Oct 14, 2021
Add sparse_attention OPs, python api will be added in next pr
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants