From 36edb0e1c40d039e4f09be73b72b1903b219492d Mon Sep 17 00:00:00 2001 From: Liu-xiandong <85323580+Liu-xiandong@users.noreply.github.com> Date: Tue, 19 Oct 2021 15:09:57 +0800 Subject: [PATCH] [cherry-pick]Add sparse attention cherrypick (#36447) The code of this PR can only support CUDA 11.2. Currently, CI does not have a GPU with CUDA 11.2, and all tests will be skipped automatically. The new OP is paddle._C_ops.sparse_attention. The work on the Python API will be done in a follow-up PR. The code of this PR lacks tests on dynamic graphs and static graphs; these will be added in subsequent PRs. --- cmake/operators.cmake | 2 +- paddle/fluid/operators/CMakeLists.txt | 6 +- paddle/fluid/operators/sparse_attention_op.cc | 193 +++++++ paddle/fluid/operators/sparse_attention_op.cu | 537 ++++++++++++++++++ paddle/fluid/platform/dynload/cusparse.cc | 4 + paddle/fluid/platform/dynload/cusparse.h | 20 +- .../unittests/test_sparse_attention_op.py | 205 +++++++ .../white_list/op_threshold_white_list.py | 1 + 8 files changed, 960 insertions(+), 8 deletions(-) create mode 100644 paddle/fluid/operators/sparse_attention_op.cc create mode 100644 paddle/fluid/operators/sparse_attention_op.cu create mode 100644 python/paddle/fluid/tests/unittests/test_sparse_attention_op.py diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 2c010a1e6297f..7541b234ceaa6 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -214,7 +214,7 @@ function(op_library TARGET) foreach(manual_pybind_op "compare_all_op" "compare_op" "logical_op" "bitwise_op" "nccl_op" "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op" "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" -"sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op" +"sync_batch_norm_op" "sparse_attention_op" "dgc_op" "fused_fc_elementwise_layernorm_op" "skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" 
"fusion_gru_op" "fusion_lstm_op" "fused_bn_add_activation_op") if ("${TARGET}" STREQUAL "${manual_pybind_op}") diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 0d7d0a5e13bf3..c487313f91c58 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -78,7 +78,7 @@ if(WITH_UNITY_BUILD) include(unity_build_rule.cmake) endif() -register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op lstm_op run_program_op eye_op recurrent_op +register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op sparse_attention_op lstm_op run_program_op eye_op recurrent_op sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS}) op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS}) @@ -94,6 +94,10 @@ if (WITH_GPU OR WITH_ROCM) endif() op_library(sync_batch_norm_op) file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n") + if ((NOT WIN32) AND (NOT WITH_ROCM) AND (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 11.2) ) + op_library(sparse_attention_op) + file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sparse_attention);\n") + endif() else() op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) endif() diff --git a/paddle/fluid/operators/sparse_attention_op.cc b/paddle/fluid/operators/sparse_attention_op.cc new file mode 100644 index 0000000000000..9b6bc1b629045 --- /dev/null +++ b/paddle/fluid/operators/sparse_attention_op.cc @@ -0,0 +1,193 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class SparseAttentionOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput( + "Q", + "(Tensor), The input tensor of query in attention, " + "whose dimension : `[batch_size, num_heads, target_len, head_dim]`."); + AddInput( + "K", + "(Tensor), The input tensor of key in attention, " + "whose dimension : `[batch_size, num_heads, target_len, head_dim]`."); + AddInput( + "V", + "(Tensor), The input tensor of value in attention, " + "whose dimension : `[batch_size, num_heads, target_len, head_dim]`."); + AddInput("Offset", + "(Tensor, default: Tensor), The input tensor of offset in " + "CSR sparse format, " + "whose dimension : `[batch_size, num_heads, target_len + 1]`."); + AddInput("Columns", + "(Tensor, default: Tensor), The input tensor of columns in " + "CSR sparse format, " + "whose dimension : `[batch_size, num_heads, sparse_nnz_num]`."); + AddOutput( + "Out", + "(Tensor), The output tensor of result in attention, " + "whose dimension : `[batch_size, num_heads, target_len, head_dim]`."); + AddOutput("SparseDotSdd", + "(Tensor), The output tensor of result in SparseDotSdd step, " + "whose dimension : `[batch_size, num_heads, sparse_nnz_dim]`.") + .AsIntermediate(); + AddOutput("Softmax", + "(Tensor), The output tensor of result in Softmax step, " + "whose dimension : `[batch_size, num_heads, sparse_nnz_dim]`.") + .AsIntermediate(); + AddComment(R"DOC( + Compute 
the value of the sparse attention module. Its input value includes five tensors. + Q, K, and V represent query, key, and value in the Attention module, respectively. + The CSR format is used to represent the sparsity feature in the Attention module. + The CSR format contains two tensors, offset and columns. + )DOC"); + } +}; + +class SparseAttentionOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Q"), "Input", "Q", "sparse_attention"); + OP_INOUT_CHECK(ctx->HasInput("K"), "Input", "K", "sparse_attention"); + OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "sparse_attention"); + OP_INOUT_CHECK(ctx->HasInput("Offset"), "Input", "Offset", + "sparse_attention"); + OP_INOUT_CHECK(ctx->HasInput("Columns"), "Input", "Columns", + "sparse_attention"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "sparse_attention"); + OP_INOUT_CHECK(ctx->HasOutput("SparseDotSdd"), "Output", "SparseDotSdd", + "sparse_attention"); + OP_INOUT_CHECK(ctx->HasOutput("Softmax"), "Output", "Softmax", + "sparse_attention"); + + auto dims_q = ctx->GetInputDim("Q"); + auto dims_k = ctx->GetInputDim("K"); + auto dims_v = ctx->GetInputDim("V"); + auto dims_columns = ctx->GetInputDim("Columns"); + + PADDLE_ENFORCE_EQ(dims_q.size(), static_cast(4), + platform::errors::InvalidArgument( + "Dimension in query' shapes should be 4.")); + PADDLE_ENFORCE_EQ(dims_k.size(), static_cast(4), + platform::errors::InvalidArgument( + "Dimension in key' shapes should be 4.")); + PADDLE_ENFORCE_EQ(dims_v.size(), static_cast(4), + platform::errors::InvalidArgument( + "Dimension in value' shapes should be 4.")); + + auto batch_size = dims_q[0]; + auto num_heads = dims_q[1]; + auto M = dims_q[2]; + auto N = dims_q[3]; + auto sparse_nnz = dims_columns[2]; + ctx->SetOutputDim("Out", {batch_size, num_heads, M, N}); + ctx->SetOutputDim("SparseDotSdd", 
{batch_size, num_heads, sparse_nnz}); + ctx->SetOutputDim("Softmax", {batch_size, num_heads, sparse_nnz}); + ctx->ShareLoD("Q", "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto input_data_type = + OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "Q", "K"); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } +}; + +class SparseAttentionOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Q"), "Input", "Q", "sparse_attention_grad"); + OP_INOUT_CHECK(ctx->HasInput("K"), "Input", "K", "sparse_attention_grad"); + OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "sparse_attention_grad"); + OP_INOUT_CHECK(ctx->HasInput("Offset"), "Input", "Offset", + "sparse_attention_grad"); + OP_INOUT_CHECK(ctx->HasInput("Columns"), "Input", "Columns", + "sparse_attention_grad"); + OP_INOUT_CHECK(ctx->HasInput("SparseDotSdd"), "Input", "SparseDotSdd", + "sparse_attention_grad"); + OP_INOUT_CHECK(ctx->HasInput("Softmax"), "Input", "Softmax", + "sparse_attention_grad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@GRAD", "sparse_attention_grad"); + + auto x_grad_name = framework::GradVarName("Q"); + auto y_grad_name = framework::GradVarName("K"); + auto z_grad_name = framework::GradVarName("V"); + + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("Q")); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("K")); + } + if (ctx->HasOutput(z_grad_name)) { + ctx->SetOutputDim(z_grad_name, ctx->GetInputDim("V")); + } + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return 
framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); + } +}; + +template +class SparseAttentionGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("sparse_attention_grad"); + op->SetInput("Q", this->Input("Q")); + op->SetInput("K", this->Input("K")); + op->SetInput("V", this->Input("V")); + op->SetInput("Offset", this->Input("Offset")); + op->SetInput("Columns", this->Input("Columns")); + op->SetInput("SparseDotSdd", this->Output("SparseDotSdd")); + op->SetInput("Softmax", this->Output("Softmax")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("Q"), this->InputGrad("Q")); + op->SetOutput(framework::GradVarName("K"), this->InputGrad("K")); + op->SetOutput(framework::GradVarName("V"), this->InputGrad("V")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(sparse_attention, ops::SparseAttentionOp, + ops::SparseAttentionOpMaker, + ops::SparseAttentionGradOpMaker, + ops::SparseAttentionGradOpMaker); + +REGISTER_OPERATOR(sparse_attention_grad, ops::SparseAttentionOpGrad); diff --git a/paddle/fluid/operators/sparse_attention_op.cu b/paddle/fluid/operators/sparse_attention_op.cu new file mode 100644 index 0000000000000..88ee8999c5f4a --- /dev/null +++ b/paddle/fluid/operators/sparse_attention_op.cu @@ -0,0 +1,537 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/op_registry.h" +#if defined(PADDLE_WITH_CUDA) +#include "paddle/fluid/platform/dynload/cusparse.h" +#endif + +namespace ops = paddle::operators; +namespace plf = paddle::platform; + +namespace paddle { +namespace operators { + +template +__forceinline__ __device__ T CudaShuffleXorSync(unsigned mask, T val, + int width = warpSize) { + return __shfl_xor_sync(mask, val, width); +} + +template +__device__ __forceinline__ void WarpReduceSum(T* sum) { +#pragma unroll + for (int offset = warp_size / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < batch_size; ++i) { + T sum_val = CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); + sum[i] = sum[i] + sum_val; + } + } +} + +template +__device__ __forceinline__ void WarpReduceMax(T* sum) { +#pragma unroll + for (int offset = warp_size / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < batch_size; ++i) { + T max_val = CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); + sum[i] = max(sum[i], max_val); + } + } +} + +template +__global__ void BlockSparseSoftmaxForward(T* softmax, const T* src, T scale, + const T* kp_mask, const T* attn_mask, + const int* layout_rowptr, + const int* layout_colindex, + int num_rows) { + // current thread related info + const int WarpSize = 32; + const int cur_row = blockIdx.x * blockDim.y + threadIdx.y; + if (cur_row < num_rows) { + const int cur_block_row = cur_row / BlockSize; + const int cur_block_nnz = + layout_rowptr[cur_block_row + 1] - 
layout_rowptr[cur_block_row]; + + T srcdata[(BlockSize * BlockNnzMax + WarpSize - 1) / WarpSize]; + T attndata[(BlockSize * BlockNnzMax + WarpSize - 1) / WarpSize]; + + // read kp mask + T cur_kp_mask = (kp_mask == nullptr) ? 0 : kp_mask[cur_row]; + + // read tensor data, attn mask + const int iter = (cur_block_nnz + WarpSize - 1) / WarpSize; + const T* srcptr = src + layout_rowptr[cur_block_row]; + const T* attnptr = nullptr; // stays null when no attn mask is given + if (attn_mask != nullptr) { + attnptr = attn_mask + cur_block_row * num_rows; + } + const int* colindex = layout_colindex + layout_rowptr[cur_block_row]; + for (int j = 0; j < iter; j++) { + int cur_block_col = j * WarpSize + threadIdx.x; + int cur_reg_index = j; + if (cur_block_col < cur_block_nnz) { + if ((attnptr != nullptr) && + std::abs(attnptr[colindex[cur_block_col]]) < + std::numeric_limits::epsilon()) { + srcdata[cur_reg_index] = + -std::numeric_limits::infinity() * scale + cur_kp_mask; + } else { + srcdata[cur_reg_index] = scale * srcptr[cur_block_col] + cur_kp_mask; + } + } else { + srcdata[cur_reg_index] = -std::numeric_limits::infinity(); + } + } + + // max value + T max_value = srcdata[0]; + const int kIteration = + (cur_block_nnz * BlockSize + WarpSize - 1) / WarpSize; +#pragma unroll + for (int it = 1; it < kIteration; ++it) { + max_value = (max_value > srcdata[it]) ? 
max_value : srcdata[it]; + } + WarpReduceMax(&max_value); + + // exp sum + T sum = 0; +#pragma unroll + for (int it = 0; it < kIteration; ++it) { + srcdata[it] = std::exp(srcdata[it] - max_value); + sum += srcdata[it]; + } + WarpReduceSum(&sum); + + // compute softmax and write out + T* softmaxptr = softmax + layout_rowptr[cur_block_row]; + for (int j = 0; j < iter; j++) { + int cur_block_col = j * WarpSize + threadIdx.x; + int cur_reg_index = j; + if (cur_block_col < cur_block_nnz) { + softmaxptr[cur_block_col] = srcdata[cur_reg_index] / sum; + } + } + } +} + +template +__global__ void BlockSparseSoftmaxBackward(T* dst, const T* grad, const T* src, + T scale, const int* layout_rowptr, + const int* layout_colindex, + int num_rows) { + // current thread related info + const int WarpSize = 32; + const int cur_row = blockIdx.x * blockDim.y + threadIdx.y; + if (cur_row < num_rows) { + const int cur_block_row = cur_row / BlockSize; + const int cur_block_nnz = + layout_rowptr[cur_block_row + 1] - layout_rowptr[cur_block_row]; + + T srcdata[(BlockSize * BlockNnzMax + WarpSize - 1) / WarpSize]; + T graddata[(BlockSize * BlockNnzMax + WarpSize - 1) / WarpSize]; + + // read tensor data, attn mask + const int iter = (cur_block_nnz + WarpSize - 1) / WarpSize; + const T* srcptr = src + layout_rowptr[cur_block_row]; + const T* gradptr = grad + layout_rowptr[cur_block_row]; + for (int j = 0; j < iter; j++) { + int cur_block_col = j * WarpSize + threadIdx.x; + int cur_reg_index = j; + if (cur_block_col < cur_block_nnz) { + srcdata[cur_reg_index] = srcptr[cur_block_col]; + graddata[cur_reg_index] = gradptr[cur_block_col]; + } else { + srcdata[cur_reg_index] = 0; + graddata[cur_reg_index] = 0; + } + } + + T sum = 0; + const int kIteration = + (cur_block_nnz * BlockSize + WarpSize - 1) / WarpSize; +#pragma unroll + for (int it = 0; it < kIteration; ++it) { + sum += srcdata[it] * graddata[it]; + } + WarpReduceSum(&sum); + + // compute softmax and write out + T* dstptr = dst + 
layout_rowptr[cur_block_row]; + for (int j = 0; j < iter; j++) { + int cur_block_col = j * WarpSize + threadIdx.x; + int cur_reg_index = j; + if (cur_block_col < cur_block_nnz) { + dstptr[cur_block_col] = + scale * srcdata[cur_reg_index] * (graddata[cur_reg_index] - sum); + } + } + } +} + +using Tensor = framework::Tensor; +/* +input: sparse C in CSR format (num_rows,num_rows) +output: sparse C after softmax operation +*/ +template +void SparseSoftmaxForward(const platform::CUDADeviceContext& ctx, + const Tensor* offset, const Tensor* columns, + Tensor* input, Tensor* output, const int blocksize, + const int num_rows, const int num_cols) { + const int* offset_data = offset->data(); + const int* columns_data = columns->data(); + T* input_data = input->data(); + T* output_data = output->data(); + + const int block_size = 1; + dim3 blocks(32, 4, 1); + int grid = (num_rows * block_size + 3) / 4; + T scaling = static_cast(1.0) / sqrt(static_cast(num_cols)); + + const int block_nnz_max = 256; + BlockSparseSoftmaxForward<<>>( + output_data, input_data, scaling, nullptr, nullptr, offset_data, + columns_data, num_rows); +} + +template +void SparseSoftmaxBackward(const platform::CUDADeviceContext& ctx, + const Tensor* offset, const Tensor* columns, + Tensor* dx, const Tensor* dout, const Tensor* out, + const int blocksize, const int num_rows, + const int num_cols) { + const int* offset_data = offset->data(); + const int* columns_data = columns->data(); + T* dx_data = dx->data(); + const T* dout_data = dout->data(); + const T* out_data = out->data(); + + const int block_size = 1; + dim3 blocks(32, 4, 1); + int grid = (num_rows * block_size + 3) / 4; + T scaling = static_cast(1.0) / sqrt(static_cast(num_cols)); + + const int block_nnz_max = 256; + BlockSparseSoftmaxBackward<<>>( + dx_data, dout_data, out_data, scaling, offset_data, columns_data, + num_rows); +} + +using VarType = framework::proto::VarType; +inline cudaDataType_t GetGpuType(const VarType::Type data_type) { + if 
(data_type == VarType::FP32) { + return CUDA_R_32F; + } else if (data_type == VarType::FP64) { + return CUDA_R_64F; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Not support tensor type in sparse_attention OP: %s", + framework::DataTypeToString(data_type))); + } +} + +inline cusparseOperation_t GetTransposeOperation(const bool transpose) { + if (transpose) { + return CUSPARSE_OPERATION_TRANSPOSE; + } else { + return CUSPARSE_OPERATION_NON_TRANSPOSE; + } +} + +void CusparseDestroy(cusparseDnMatDescr_t* dn_mat_first, + cusparseDnMatDescr_t* dn_mat_second, + cusparseSpMatDescr_t* sp_mat) { + platform::dynload::cusparseDestroyDnMat(*dn_mat_first); + platform::dynload::cusparseDestroyDnMat(*dn_mat_second); + platform::dynload::cusparseDestroySpMat(*sp_mat); +} + +/* +input: dense A (num_rows,num_cols), dense B (num_rows,num_cols) +output: sparse C in CSR format (num_rows,num_rows) +*/ +template +void DotSdd(const platform::CUDADeviceContext& ctx, const Tensor* a, + const Tensor* b, const Tensor* c_offset, const Tensor* c_columns, + Tensor* c_value, const int num_rows, const int num_cols, + const bool a_transpose, const bool b_transpose) { + const T* a_data = a->data(); + const T* b_data = b->data(); + const int* c_offset_data = c_offset->data(); + const int* c_columns_data = c_columns->data(); + T* c_value_data = c_value->data(); + + cudaDataType_t gpu_type = GetGpuType(c_value->type()); + cusparseHandle_t handle = nullptr; + cusparseDnMatDescr_t mat_a, mat_b; + cusparseSpMatDescr_t mat_c; + platform::dynload::cusparseCreate(&handle); + + // Create dense matrix A + platform::dynload::cusparseCreateDnMat(&mat_a, num_rows, num_cols, num_cols, + const_cast(a_data), gpu_type, + CUSPARSE_ORDER_ROW); + // Create dense matrix B + platform::dynload::cusparseCreateDnMat(&mat_b, num_rows, num_cols, num_cols, + const_cast(b_data), gpu_type, + CUSPARSE_ORDER_ROW); + // Create sparse matrix C in CSR format + int c_nnz = c_columns->dims()[1]; + 
platform::dynload::cusparseCreateCsr( + &mat_c, num_rows, num_rows, c_nnz, const_cast(c_offset_data), + const_cast(c_columns_data), c_value_data, CUSPARSE_INDEX_32I, + CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, gpu_type); + + T alpha = 1; + T beta = 0; + + size_t buffer_size = 0; + platform::dynload::cusparseSDDMM_bufferSize( + handle, GetTransposeOperation(a_transpose), + GetTransposeOperation(b_transpose), &alpha, mat_a, mat_b, &beta, mat_c, + gpu_type, CUSPARSE_SDDMM_ALG_DEFAULT, &buffer_size); + auto d_buffer_ptr = paddle::memory::Alloc(ctx, buffer_size); + void* d_buffer = static_cast(d_buffer_ptr->ptr()); + + platform::dynload::cusparseSDDMM(handle, GetTransposeOperation(a_transpose), + GetTransposeOperation(b_transpose), &alpha, + mat_a, mat_b, &beta, mat_c, gpu_type, + CUSPARSE_SDDMM_ALG_DEFAULT, d_buffer); + + CusparseDestroy(&mat_a, &mat_b, &mat_c); + platform::dynload::cusparseDestroy(handle); +} + +/* +input: sparse A in CSR format (num_rows,num_rows), dense B (num_rows,num_cols) +output: dense C (num_rows,num_cols) +*/ +template +void DotDsd(const platform::CUDADeviceContext& ctx, const Tensor* a_offset, + const Tensor* a_columns, const Tensor* a_value, const Tensor* b, + Tensor* c, const int num_rows, const int num_cols, + const bool a_transpose, const bool b_transpose) { + const int* a_offset_data = a_offset->data(); + const int* a_columns_data = a_columns->data(); + const T* a_value_data = a_value->data(); + const T* b_data = b->data(); + T* c_data = c->data(); + + cudaDataType_t gpu_type = GetGpuType(c->type()); + cusparseHandle_t handle = nullptr; + cusparseSpMatDescr_t mat_a; + cusparseDnMatDescr_t mat_b, mat_c; + platform::dynload::cusparseCreate(&handle); + + // Create sparse matrix A in CSR format + int a_nnz = a_columns->dims()[1]; + platform::dynload::cusparseCreateCsr( + &mat_a, num_rows, num_rows, a_nnz, const_cast(a_offset_data), + const_cast(a_columns_data), const_cast(a_value_data), + CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, 
CUSPARSE_INDEX_BASE_ZERO, + gpu_type); + + // Create dense matrix B + platform::dynload::cusparseCreateDnMat(&mat_b, num_rows, num_cols, num_cols, + const_cast(b_data), gpu_type, + CUSPARSE_ORDER_ROW); + // Create dense matrix C + platform::dynload::cusparseCreateDnMat(&mat_c, num_rows, num_cols, num_cols, + c_data, gpu_type, CUSPARSE_ORDER_ROW); + + T alpha = 1; + T beta = 0; + + size_t buffer_size = 0; + // allocate an external buffer if needed + platform::dynload::cusparseSpMM_bufferSize( + handle, GetTransposeOperation(a_transpose), + GetTransposeOperation(b_transpose), &alpha, mat_a, mat_b, &beta, mat_c, + gpu_type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size); + auto d_buffer_ptr = paddle::memory::Alloc(ctx, buffer_size); + void* d_buffer = static_cast(d_buffer_ptr->ptr()); + + platform::dynload::cusparseSpMM(handle, GetTransposeOperation(a_transpose), + GetTransposeOperation(b_transpose), &alpha, + mat_a, mat_b, &beta, mat_c, gpu_type, + CUSPARSE_SPMM_ALG_DEFAULT, d_buffer); + + CusparseDestroy(&mat_b, &mat_c, &mat_a); + platform::dynload::cusparseDestroy(handle); +} + +std::vector GetSplitTensor(Tensor* input) { + auto dims = input->dims(); + int batch_size = dims[0]; + int num_heads = dims[1]; + std::vector new_dims(dims.size() - 1); + new_dims[0] = batch_size * num_heads; + for (int i = 1; i < new_dims.size(); i++) { + new_dims[i] = dims[i + 1]; + } + input->Resize(framework::make_ddim(new_dims)); + return input->Split(1, 0); +} + +template +class SparseAttentionCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto query = *ctx.Input("Q"); + auto key = *ctx.Input("K"); + auto value = *ctx.Input("V"); + auto offset = *ctx.Input("Offset"); + auto columns = *ctx.Input("Columns"); + auto output_ptr = ctx.Output("Out"); + output_ptr->mutable_data(ctx.GetPlace()); + auto sparse_dot_sdd_ptr = ctx.Output("SparseDotSdd"); + sparse_dot_sdd_ptr->mutable_data(ctx.GetPlace()); + auto softmax_ptr 
= ctx.Output("Softmax"); + softmax_ptr->mutable_data(ctx.GetPlace()); + + auto output = *output_ptr; + auto result_sdd = *sparse_dot_sdd_ptr; + auto result_softmax = *softmax_ptr; + + auto query_dims = query.dims(); + int batch_size = query_dims[0]; + int num_heads = query_dims[1]; + int M = query_dims[2]; + int N = query_dims[3]; + + std::vector query_lists = GetSplitTensor(&query); + std::vector key_lists = GetSplitTensor(&key); + std::vector value_lists = GetSplitTensor(&value); + std::vector offset_lists = GetSplitTensor(&offset); + std::vector columns_lists = GetSplitTensor(&columns); + std::vector result_sdd_lists = GetSplitTensor(&result_sdd); + std::vector result_softmax_lists = GetSplitTensor(&result_softmax); + std::vector output_lists = GetSplitTensor(&output); + + const auto& dev_ctx = ctx.cuda_device_context(); + const int iter_num = batch_size * num_heads; + for (int i = 0; i < iter_num; i++) { + DotSdd(dev_ctx, &query_lists[i], &key_lists[i], + &offset_lists[i], &columns_lists[i], + &result_sdd_lists[i], M, N, false, true); + + SparseSoftmaxForward( + dev_ctx, &offset_lists[i], &columns_lists[i], &result_sdd_lists[i], + &result_softmax_lists[i], 1, M, N); + + DotDsd(dev_ctx, &offset_lists[i], &columns_lists[i], + &result_softmax_lists[i], &value_lists[i], + &output_lists[i], M, N, false, false); + } + } +}; + +template +class SparseAttentionGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto query = *ctx.Input("Q"); + auto key = *ctx.Input("K"); + auto value = *ctx.Input("V"); + auto offset = *ctx.Input("Offset"); + auto columns = *ctx.Input("Columns"); + auto sparse_dot_sdd = *ctx.Input("SparseDotSdd"); + auto softmax = *ctx.Input("Softmax"); + auto dout = *ctx.Input(framework::GradVarName("Out")); + auto* dquery_ptr = ctx.Output(framework::GradVarName("Q")); + auto* dkey_ptr = ctx.Output(framework::GradVarName("K")); + auto* dvalue_ptr = 
ctx.Output(framework::GradVarName("V")); + dquery_ptr->mutable_data(ctx.GetPlace()); + dkey_ptr->mutable_data(ctx.GetPlace()); + dvalue_ptr->mutable_data(ctx.GetPlace()); + auto dquery = *dquery_ptr; + auto dkey = *dkey_ptr; + auto dvalue = *dvalue_ptr; + + auto query_dims = query.dims(); + int batch_size = query_dims[0]; + int num_heads = query_dims[1]; + int M = query_dims[2]; + int N = query_dims[3]; + + std::vector query_lists = GetSplitTensor(&query); + std::vector key_lists = GetSplitTensor(&key); + std::vector value_lists = GetSplitTensor(&value); + std::vector offset_lists = GetSplitTensor(&offset); + std::vector columns_lists = GetSplitTensor(&columns); + std::vector sparse_dot_sdd_lists = GetSplitTensor(&sparse_dot_sdd); + std::vector softmax_lists = GetSplitTensor(&softmax); + std::vector dout_lists = GetSplitTensor(&dout); + std::vector dquery_lists = GetSplitTensor(&dquery); + std::vector dkey_lists = GetSplitTensor(&dkey); + std::vector dvalue_lists = GetSplitTensor(&dvalue); + + const int iter_num = batch_size * num_heads; + const auto& dev_ctx = ctx.cuda_device_context(); + for (int i = 0; i < iter_num; i++) { + // dValue = transpose(result_softmax) * dOut + DotDsd(dev_ctx, &offset_lists[i], &columns_lists[i], + &softmax_lists[i], &dout_lists[i], + &dvalue_lists[i], M, N, true, false); + + // dSoftmax = dOut * transpose(Value) + int nnz_num = columns_lists[i].dims()[1]; // nnz of this slice, as in DotSdd + Tensor dsoftmax; + dsoftmax.Resize({nnz_num}); + dsoftmax.mutable_data(ctx.GetPlace()); + DotSdd(dev_ctx, &dout_lists[i], &value_lists[i], + &offset_lists[i], &columns_lists[i], &dsoftmax, + M, N, false, true); + + // dSparseDotSdd = dSoftmax * softmax'(SparseDotSdd) + Tensor dsparse_dot_sdd; + dsparse_dot_sdd.Resize({nnz_num}); + dsparse_dot_sdd.mutable_data(ctx.GetPlace()); + SparseSoftmaxBackward( + dev_ctx, &offset_lists[i], &columns_lists[i], &dsparse_dot_sdd, + &dsoftmax, &softmax_lists[i], 1, M, N); + + // dQuery = dSparseDotSdd * Key + DotDsd(dev_ctx, &offset_lists[i], 
&columns_lists[i], + &dsparse_dot_sdd, &key_lists[i], + &dquery_lists[i], M, N, false, false); + + // dKey = transpose(dSparseDotSdd) * Query + DotDsd(dev_ctx, &offset_lists[i], &columns_lists[i], + &dsparse_dot_sdd, &query_lists[i], + &dkey_lists[i], M, N, true, false); + } + } +}; + +} // namespace operators +} // namespace paddle +REGISTER_OP_CUDA_KERNEL( + sparse_attention, + ops::SparseAttentionCUDAKernel, + ops::SparseAttentionCUDAKernel); + +REGISTER_OP_CUDA_KERNEL( + sparse_attention_grad, + ops::SparseAttentionGradCUDAKernel, + ops::SparseAttentionGradCUDAKernel); diff --git a/paddle/fluid/platform/dynload/cusparse.cc b/paddle/fluid/platform/dynload/cusparse.cc index 2b41da541d9ae..2a1fe322dabcf 100644 --- a/paddle/fluid/platform/dynload/cusparse.cc +++ b/paddle/fluid/platform/dynload/cusparse.cc @@ -26,6 +26,10 @@ void *cusparse_dso_handle; #ifdef CUSPARSE_ROUTINE_EACH CUSPARSE_ROUTINE_EACH(DEFINE_WRAP); #endif + +#ifdef CUSPARSE_ROUTINE_EACH_R2 // SDDMM wrappers declared in cusparse.h +CUSPARSE_ROUTINE_EACH_R2(DEFINE_WRAP); +#endif } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/cusparse.h b/paddle/fluid/platform/dynload/cusparse.h index 98841949676e4..e5be003fadf06 100644 --- a/paddle/fluid/platform/dynload/cusparse.h +++ b/paddle/fluid/platform/dynload/cusparse.h @@ -41,8 +41,9 @@ extern void *cusparse_dso_handle; }; \ extern DynLoad__##__name __name -#ifndef _WIN32 -#if CUDA_VERSION >= 11020 +#if !defined(PADDLE_WITH_ARM) && !defined(_WIN32) +// APIs available after CUDA 11.0 +#if CUDA_VERSION >= 11000 #define CUSPARSE_ROUTINE_EACH(__macro) \ __macro(cusparseCreate); \ __macro(cusparseCreateCsr); \ @@ -51,12 +52,19 @@ extern void *cusparse_dso_handle; __macro(cusparseSpMM); \ __macro(cusparseDestroySpMat); \ __macro(cusparseDestroyDnMat); \ - __macro(cusparseDestroy); \ - __macro(cusparseSDDMM_bufferSize); \ - __macro(cusparseSDDMM_preprocess); \ - __macro(cusparseSDDMM); + __macro(cusparseDestroy); 
CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP); + +// APIs available after CUDA 11.2 +#if CUDA_VERSION >= 11020 +#define CUSPARSE_ROUTINE_EACH_R2(__macro) \ + __macro(cusparseSDDMM_bufferSize); \ + __macro(cusparseSDDMM_preprocess); \ + __macro(cusparseSDDMM); + +CUSPARSE_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP) +#endif #endif #endif diff --git a/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py b/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py new file mode 100644 index 0000000000000..48401fb55ef3f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py @@ -0,0 +1,205 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+import numpy as np
+from op_test import OpTest
+import paddle.fluid.core as core
+import paddle
+import os
+import re
+import platform
+
+
+def get_cuda_version():
+    """Return the CUDA release reported by ``nvcc --version`` as an integer.
+
+    A release string "X.Y" is encoded as ``X * 1000 + Y * 10`` (for example
+    "11.2" becomes 11020). Returns -1 when the command output contains no
+    "release X.Y," string.
+    """
+    # Read inside a context manager so the pipe is closed deterministically.
+    with os.popen("nvcc --version") as pipe:
+        result = pipe.read()
+    regex = r'release (\S+),'
+    match = re.search(regex, result)
+    if match:
+        num = str(match.group(1))
+        integer, decimal = num.split('.')
+        return int(integer) * 1000 + int(float(decimal) * 10)
+    else:
+        return -1
+
+
+def get_linux_platform():
+    """Classify the host OS: 0 for Windows, 1 for Linux, -1 for anything else."""
+    system = platform.system().lower()
+    if system == 'windows':
+        return 0
+    elif system == 'linux':
+        return 1
+    else:
+        return -1
+
+
+def get_suitable_env():
+    """Return True only on Linux with an nvcc release of at least 11.2."""
+    return get_cuda_version() >= 11020 and get_linux_platform() == 1
+
+
+def softmax(x):
+    """Numerically stable softmax along axis 1 of a 2-D array."""
+    # Subtract the per-row maximum before exponentiating to avoid overflow.
+    row_max = np.max(x, axis=1, keepdims=True)
+    e_x = np.exp(x - row_max)
+    row_sum = np.sum(e_x, axis=1, keepdims=True)
+    f_x = e_x / row_sum
+    return f_x
+
+
+def get_csr_value(mat, layout, nnz):
+    """Gather the entries of ``mat`` selected by ``layout`` in row-major order.
+
+    Args:
+        mat: 2-D array to read values from.
+        layout: 2-D array of the same shape; an entry equal to 1 selects the
+            matching entry of ``mat``.
+        nnz: length of the returned vector.
+
+    Returns:
+        A 1-D array of length ``nnz``; slots beyond the number of selected
+        entries stay zero.
+    """
+    row, col = mat.shape[0], mat.shape[1]
+    value = np.zeros(nnz)
+    ptr = 0
+    for i in range(row):
+        for j in range(col):
+            if layout[i][j] == 1:
+                value[ptr] = mat[i][j]
+                ptr += 1
+    return value
+
+
+def ref_sparse_attention(q, k, v, offset, columns):
+    """NumPy reference of sparse attention for a single (batch, head) slice.
+
+    Args:
+        q, k, v: 2-D arrays; ``q`` has shape (row, col).
+        offset: CSR row offsets, indexed from 0 to ``row`` inclusive.
+        columns: CSR column indices; its length is taken as nnz.
+
+    Returns:
+        A tuple ``(result, a_value, b_value)``: the attention output, the
+        masked ``q @ k.T`` values in CSR order, and the softmax values in
+        CSR order.
+    """
+    row, col, nnz = q.shape[0], q.shape[1], columns.shape[0]
+    # Expand the CSR description into a dense 0/1 layout mask.
+    mat = np.zeros((row, row))
+    for cur_row in range(row):
+        start_ptr = int(offset[cur_row])
+        end_ptr = int(offset[cur_row + 1])
+        for ptr in range(start_ptr, end_ptr):
+            cur_col = int(columns[ptr])
+            mat[cur_row][cur_col] = 1
+    a = np.dot(q, k.T) * mat
+    # The SDD values are gathered before the 1/sqrt(col) scaling is applied.
+    a_value = get_csr_value(a, mat, nnz)
+    scaling = float(col)**-0.5
+    a = scaling * a
+    # Positions outside the layout must contribute nothing to the softmax.
+    a[mat == 0] = float('-inf')
+    b = softmax(a)
+    b_value = get_csr_value(b, mat, nnz)
+    result = np.dot(b, v)
+    return result, a_value, b_value
+
+
+def ref_batch_sparse_attention(q, k, v, offset, columns):
+    """Apply ``ref_sparse_attention`` to every (batch, head) slice.
+
+    Args:
+        q, k, v: 4-D arrays; ``q`` has shape (batch_size, num_heads, row, col).
+        offset, columns: per-slice CSR arrays indexed as ``[batch][head]``;
+            ``columns.shape[2]`` is taken as nnz.
+
+    Returns:
+        ``(result, result_sdd, result_softmax)`` with shapes
+        (batch_size, num_heads, row, col), (batch_size, num_heads, nnz) and
+        (batch_size, num_heads, nnz).
+    """
+    batch_size, num_heads, row, col = q.shape
+    nnz = columns.shape[2]
+    result = np.zeros((batch_size, num_heads, row, col))
+    result_sdd = np.zeros((batch_size, num_heads, nnz))
+    result_softmax = np.zeros((batch_size, num_heads, nnz))
+    for i in range(batch_size):
+        for j in range(num_heads):
+            cur_q, cur_k, cur_v = q[i][j], k[i][j], v[i][j]
+            cur_offset, cur_columns = offset[i][j], columns[i][j]
+            cur_result, cur_sdd, cur_softmax = ref_sparse_attention(
+                cur_q, cur_k, cur_v, cur_offset, cur_columns)
+            result[i][j] = cur_result
+            result_sdd[i][j], result_softmax[i][j] = cur_sdd, cur_softmax
+    return result, result_sdd, result_softmax
+
+
+def init_csr_format(batch_size, num_heads, rows, blocksize):
+    """Build a block-diagonal CSR layout replicated over batch and heads.
+
+    The (rows, rows) layout has ``blocksize`` x ``blocksize`` blocks of ones
+    along the diagonal; a trailing partial block is clipped to the matrix.
+
+    Args:
+        batch_size: number of copies along axis 0.
+        num_heads: number of copies along axis 1.
+        rows: side length of the square layout.
+        blocksize: side length of each diagonal block.
+
+    Returns:
+        ``(offset, columns)`` as float arrays of shape
+        (batch_size, num_heads, rows + 1) and (batch_size, num_heads, nnz).
+    """
+    # Floor division is required: with true division a partial trailing block
+    # produces a fractional block count and a wrong nnz.
+    block_num, block_last = rows // blocksize, rows % blocksize
+    nnz_num = block_num * blocksize * blocksize + block_last * block_last
+    offset = np.zeros(rows + 1)
+    columns = np.zeros(nnz_num)
+    mat = np.zeros((rows, rows))
+    for i in range(0, rows, blocksize):
+        for x in range(blocksize):
+            for y in range(blocksize):
+                p_x, p_y = i + x, i + y
+                if (p_x < rows) and (p_y < rows):
+                    mat[p_x][p_y] = 1
+    # Walk the dense mask row by row to emit the CSR offsets and columns.
+    p_offset, p_column, count = 0, 0, 0
+    for i in range(rows):
+        for j in range(rows):
+            if mat[i][j] != 0:
+                count += 1
+                columns[p_column] = j
+                p_column += 1
+        p_offset += 1
+        offset[p_offset] = count
+    offset = np.expand_dims(np.expand_dims(offset, 0), 0)
+    offset = offset.repeat(num_heads, axis=1)
+    offset = offset.repeat(batch_size, axis=0)
+    columns = np.expand_dims(np.expand_dims(columns, 0), 0)
+    columns = columns.repeat(num_heads, axis=1)
+    columns = columns.repeat(batch_size, axis=0)
+    return offset, columns
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or not get_suitable_env(),
+    "core is not compiled with CUDA, or the CUDA version is lower than 11.2, "
+    "or the platform is not Linux")
+class TestSparseAttentionOp(OpTest):
+    """Checks the ``sparse_attention`` op against the NumPy reference.
+
+    Subclasses override ``config`` to vary the input shape, block size and
+    dtype.
+    """
+
+    def config(self):
+        # shape is unpacked by the reference as (batch, heads, rows, cols).
+        self.shape = (1, 1, 16, 8)
+        self.blocksize = 2
+        self.dtype = "float64"
+
+    def setUp(self):
+        paddle.enable_static()
+        self.config()
+        self.op_type = "sparse_attention"
+        self.place = paddle.CUDAPlace(0)
+        self.q = np.random.random(self.shape).astype(self.dtype)
+        self.k = np.random.random(self.shape).astype(self.dtype)
+        self.v = np.random.random(self.shape).astype(self.dtype)
+        offset, columns = init_csr_format(self.shape[0], self.shape[1],
+                                          self.shape[2], self.blocksize)
+        # The CSR arrays are built as floats; the op is fed int32 indices.
+        self.offset = offset.astype('int32')
+        self.columns = columns.astype('int32')
+
+        result, result_sdd, result_softmax = ref_batch_sparse_attention(
+            self.q, self.k, self.v, self.offset, self.columns)
+
+        self.inputs = {
+            'Q': self.q,
+            'K': self.k,
+            'V': self.v,
+            'Offset': self.offset,
+            'Columns': self.columns
+        }
+        self.outputs = {
+            'Out': result.astype(self.dtype),
+            'SparseDotSdd': result_sdd.astype(self.dtype),
+            'Softmax': result_softmax.astype(self.dtype)
+        }
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ['Q'], 'Out')
+        self.check_grad_with_place(self.place, ['K'], 'Out')
+        self.check_grad_with_place(self.place, ['V'], 'Out')
+
+
+class TestSparseAttentionOpFp32Test(TestSparseAttentionOp):
+    """Same checks with float32 inputs and an (8, 16) slice shape."""
+
+    def config(self):
+        self.shape = (1, 1, 8, 16)
+        self.blocksize = 2
+        self.dtype = "float32"
+
+
+class TestSparseAttentionOpShapeTest(TestSparseAttentionOp):
+    """Same checks with several batches and heads and a larger block size."""
+
+    def config(self):
+        self.shape = (2, 2, 32, 8)
+        self.blocksize = 8
+        self.dtype = "float64"
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py
index 26d63826cc87a..1c8c89d13abc7 100644
--- a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py
+++ b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py
@@ -46,6 +46,7 @@
     'cudnn_lstm', \
     'rnn', \
     'lgamma', \
+    'sparse_attention', \
     'svd', \
     'matrix_power', \
     'solve', \