
Commit

Fix bias_add bug.
limin2021 committed Nov 11, 2021
1 parent 6c183a8 commit c9e3439
Showing 3 changed files with 72 additions and 34 deletions.
3 changes: 2 additions & 1 deletion paddle/fluid/operators/fused/attn_bias_add.cu.h
@@ -87,7 +87,8 @@ __global__ void BroadcastKernelBinary(
kernel_primitives::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(
result, arg0, arg1, func);
// store
- kernel_primitives::WriteData<OutT, VecSize, 1, 1>(out + fix, result, num);
+ kernel_primitives::WriteData<OutT, VecSize, 1, 1, true>(out + fix, result,
+     num);
}

// bias add forward impl for "[m, n] + [n] = [m, n]"
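The only functional change in this file is the extra boolean template argument passed to kernel_primitives::WriteData. In Paddle's kernel primitives that trailing parameter appears to enable boundary checking, so the vectorized store stops at the `num` valid elements instead of writing a full vector past the end of the row, which is presumably the bias_add bug the commit title refers to. A minimal CPU-side sketch of such a guarded store (illustrative only; WriteVec and its layout are stand-ins, not the kernel_primitives API):

#include <cstdio>

// Sketch only: a VecSize-wide register buffer is flushed to `out`; when
// IsBoundary is true the store is clamped to the `num` valid elements.
template <typename T, int VecSize, bool IsBoundary>
void WriteVec(T* out, const T* result, int num) {
  for (int i = 0; i < VecSize; ++i) {
    if (IsBoundary && i >= num) break;  // guard the partial tail vector
    out[i] = result[i];
  }
}

int main() {
  float buf[4] = {1.f, 2.f, 3.f, 4.f};
  float out[2] = {0.f, 0.f};
  // Only 2 of the 4 vector lanes are valid; with IsBoundary=false this call
  // would write past the end of `out`.
  WriteVec<float, /*VecSize=*/4, /*IsBoundary=*/true>(out, buf, /*num=*/2);
  std::printf("%g %g\n", out[0], out[1]);
  return 0;
}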
75 changes: 56 additions & 19 deletions paddle/fluid/operators/fused/attn_gemm.h
@@ -11,10 +11,14 @@ limitations under the License. */

#pragma once

#include "paddle/fluid/operators/fused/attn_bias_add.cu.h"
// #include "paddle/fluid/operators/fused/attn_bias_add.cu.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/float16.h"

#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"

namespace paddle {
namespace operators {

@@ -36,8 +40,10 @@ class AttnMatMul {

~AttnMatMul() {}

- void ComputeForward(const T* weight_data, const T* input_data,
- const T* bias_data, T* output_data, T* bias_out_data) {
+ void ComputeForward(const framework::Tensor* weight,
+ const framework::Tensor* input,
+ const framework::Tensor* bias, framework::Tensor* output,
+ framework::Tensor* bias_out) {
// Note: for blas.GEMM API in Paddle, it treats all inputs as row-major.
// here: (transa, transb): nt, input * weight.
CBLAS_TRANSPOSE transA = CblasNoTrans;
@@ -54,16 +60,27 @@ class AttnMatMul {
// here: (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out)
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
blas.GEMM(transA, transB, bsz_seq_, output_size_, input_size_, alpha,
- input_data, weight_data, beta, output_data);
+ input->data<T>(), weight->data<T>(), beta, output->data<T>());
if (compute_bias_) {
// compute output + bias
- LaunchBiasAddFwKernel(dev_ctx_, bsz_seq_, output_size_, output_data,
- bias_data, bias_out_data);
+ std::vector<const Tensor*> ins;
+ std::vector<Tensor*> outs;
+ ins.emplace_back(output);
+ ins.emplace_back(bias);
+ outs.emplace_back(bias_out);
+ int elewise_add_axis = -1;
+ LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+ dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor<T>());
}
}

- void ComputeBackward(const T* input, const T* weight, const T* d_output,
- T* d_input, T* d_weight, T* d_bias) {
+ // void ComputeBackward(const T* input, const T* weight, const T* d_output,
+ // T* d_input, T* d_weight, T* d_bias) {
+ void ComputeBackward(const framework::Tensor* input,
+ const framework::Tensor* weight,
+ const framework::Tensor* d_output,
+ framework::Tensor* d_input, framework::Tensor* d_weight,
+ framework::Tensor* d_bias) {
T alpha = static_cast<T>(1.0);
T beta = static_cast<T>(0.0);
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
@@ -81,11 +98,11 @@ class AttnMatMul {

T* dB_input_1_ptr = nullptr;
T* dB_input_2_ptr = nullptr;
- T* dB_output_ptr = d_weight;
+ T* dB_output_ptr = d_weight->data<T>();

T* dA_input_1_ptr = nullptr;
T* dA_input_2_ptr = nullptr;
- T* dA_output_ptr = d_input;
+ T* dA_output_ptr = d_input->data<T>();

if (!transA_) {
// fw: gemm-nt
@@ -104,10 +121,10 @@ class AttnMatMul {
dA_n = input_size_;
dA_k = output_size_;

- blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, d_output,
- input, beta, dB_output_ptr);
- blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
- weight, beta, dA_output_ptr);
+ blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha,
+ d_output->data<T>(), input->data<T>(), beta, dB_output_ptr);
+ blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha,
+ d_output->data<T>(), weight->data<T>(), beta, dA_output_ptr);
} else { // fw: gemm-nn
// bw: gemm-tn, dB = A^t * dC
dB_transA = CblasTrans;
@@ -123,10 +140,10 @@ class AttnMatMul {
dA_n = input_size_;
dA_k = output_size_;

- blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, input,
- d_output, beta, dB_output_ptr);
- blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
- weight, beta, dA_output_ptr);
+ blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha,
+ input->data<T>(), d_output->data<T>(), beta, dB_output_ptr);
+ blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha,
+ d_output->data<T>(), weight->data<T>(), beta, dA_output_ptr);
}
} else if (transB_) {
PADDLE_THROW(platform::errors::InvalidArgument(
@@ -138,7 +155,27 @@ class AttnMatMul {
"parameters."));
}
if (compute_bias_) {
- LaunchBiasAddBwKernel(dev_ctx_, bsz_seq_, output_size_, d_output, d_bias);
+ // reduce: {0, 1, 2, 3, 4} -> {2, 3, 4} or {0, 1, 2} -> {2}
+ const auto input_dims = d_output->dims();
+ const auto output_dims = d_bias->dims();
+ bool support_case_1 =
+ (input_dims.size() == 5 && output_dims.size() == 3 &&
+ (input_dims[2] == output_dims[0]) &&
+ (input_dims[3] == output_dims[1]) &&
+ (input_dims[4] == output_dims[2]));
+ bool support_case_2 =
+ (input_dims.size() == 3 && output_dims.size() == 1 &&
+ (input_dims[2] == output_dims[0]));
+ if (support_case_1 || support_case_2) {
+ gpuStream_t stream = dev_ctx_.stream();
+ TensorReduceFunctorImpl<T, T, CustomSum>(*d_output, d_bias, {0, 1},
+ stream);
+ } else {
+ PADDLE_THROW(platform::errors::InvalidArgument(
+ "Only support reduce when the input dims are [0,1,2,3,4] and "
+ "output is [2,3,4] "
+ "or input is [0,1,2] and output is [2]."));
+ }
}
}

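In attn_gemm.h the fused LaunchBiasAddFwKernel / LaunchBiasAddBwKernel calls are replaced with Paddle's generic kernels: the forward bias add becomes a broadcast elementwise add (LaunchElementwiseCudaKernel with axis = -1), and the bias gradient becomes a sum of d_output over the leading dimensions (the {0, 1} reduction handed to TensorReduceFunctorImpl). The underlying math is the "[m, n] + [n] = [m, n]" pattern noted in attn_bias_add.cu.h plus its column-sum adjoint; a plain CPU sketch, with m standing in for bsz_seq_ and n for output_size_ (illustrative, not Paddle code):

#include <cstdio>
#include <vector>

// Forward bias add: out[i][j] = x[i][j] + bias[j]  ("[m, n] + [n] = [m, n]").
void BiasAddFw(const std::vector<float>& x, const std::vector<float>& bias,
               std::vector<float>* out, int m, int n) {
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) (*out)[i * n + j] = x[i * n + j] + bias[j];
}

// Backward bias grad: d_bias[j] = sum_i d_out[i][j], i.e. a reduction over
// the flattened leading (batch * seq) dimension.
void BiasAddBw(const std::vector<float>& d_out, std::vector<float>* d_bias,
               int m, int n) {
  for (int j = 0; j < n; ++j) {
    float acc = 0.f;
    for (int i = 0; i < m; ++i) acc += d_out[i * n + j];
    (*d_bias)[j] = acc;
  }
}

int main() {
  const int m = 2, n = 3;
  std::vector<float> x = {1, 2, 3, 4, 5, 6}, bias = {10, 20, 30};
  std::vector<float> out(m * n), d_bias(n);
  BiasAddFw(x, bias, &out, m, n);  // out = {11, 22, 33, 14, 25, 36}
  BiasAddBw(out, &d_bias, m, n);   // treating out as d_out: d_bias = {25, 47, 69}
  std::printf("%g %g %g\n", d_bias[0], d_bias[1], d_bias[2]);
  return 0;
}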
28 changes: 14 additions & 14 deletions paddle/fluid/operators/fused/fused_attention_op.cu
@@ -170,11 +170,11 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {

layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data,
ln_out_data, ln_mean_data, ln_var_data);
- qkv_compute.ComputeForward(qkv_weight_data, ln_out_data, qkv_bias_data,
- qkv_out_data, qkv_bias_out_data);
+ qkv_compute.ComputeForward(qkv_weight, ln_out, qkv_bias, qkv_out,
+ qkv_bias_out);
} else {
- qkv_compute.ComputeForward(qkv_weight_data, x_data, qkv_bias_data,
- qkv_out_data, qkv_bias_out_data);
+ qkv_compute.ComputeForward(qkv_weight, input_x, qkv_bias, qkv_out,
+ qkv_bias_out);
}
fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2,
qk_out, src_mask_out, softmax_out,
@@ -184,8 +184,8 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
// fmha_out: [batch_size, seq_len, num_head, head_dim]
// weight: [embed_dim, embed_dim]
// out_linear_out: [batch_size, seq_len, embed_dim]
- out_linear_compute.ComputeForward(out_linear_weight_data, fmha_out_data,
- nullptr, out_linear_out_data, nullptr);
+ out_linear_compute.ComputeForward(out_linear_weight, fmha_out, nullptr,
+ out_linear_out, nullptr);
if (pre_layer_norm) {
// output = (residual + dropout(input + bias))
fused_dropout_layernorm_helper.ResidualDropoutBias(
@@ -401,9 +401,10 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
d_out_linear_out_data, d_out_linear_bias_data, d_residual_data);
}

- out_linear_compute.ComputeBackward(fmha_out_data, out_linear_weight_data,
- d_out_linear_out_data, d_fmha_out_data,
- d_out_linear_weight_data, nullptr);
+ out_linear_compute.ComputeBackward(fmha_out, out_linear_weight,
+ d_out_linear_out, d_fmha_out,
+ d_out_linear_weight, nullptr);
+
fmha_ref_compute.ComputeBackward(
*transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
*attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
@@ -432,15 +433,14 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
(d_ln_bias == nullptr ? nullptr
: d_ln_bias->mutable_data<U>(ctx.GetPlace()));

- qkv_compute.ComputeBackward(ln_out_data, qkv_weight_data,
- d_qkv_bias_out_data, d_ln_out_data,
- d_qkv_weight_data, d_qkv_bias_data);
+ qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_bias_out, d_ln_out,
+ d_qkv_weight, d_qkv_bias);
layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data,
ln_mean_data, ln_var_data, d_x_data,
d_ln_scale_data, d_ln_bias_data);
} else {
- qkv_compute.ComputeBackward(x_data, qkv_weight_data, d_qkv_bias_out_data,
- d_x_data, d_qkv_weight_data, d_qkv_bias_data);
+ qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x,
+ d_qkv_weight, d_qkv_bias);
}
// gradient accumulation
std::vector<const Tensor *> ins;
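The call-site updates in fused_attention_op.cu only swap raw data pointers for framework::Tensor arguments; the GEMM layout itself is unchanged. For orientation, a small row-major CPU sketch of the non-transposed ("gemm-nt") case AttnMatMul documents — out = input * weight^T with input [bsz_seq, input_size] and weight [output_size, input_size] — together with the two backward products computed in ComputeBackward (d_weight = d_output^T * input, d_input = d_output * weight). This is a reference under those layout assumptions, not the Paddle implementation:

#include <cstdio>
#include <vector>

// Row-major shapes for the "gemm-nt" layout:
//   input  : [m, k]  (m = bsz_seq_, k = input_size_)
//   weight : [n, k]  (n = output_size_)
//   out    : [m, n] = input * weight^T
void ForwardNT(const std::vector<float>& input, const std::vector<float>& weight,
               std::vector<float>* out, int m, int n, int k) {
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) acc += input[i * k + p] * weight[j * k + p];
      (*out)[i * n + j] = acc;
    }
}

// Backward, matching the gemm-nt branch of ComputeBackward:
//   d_weight[n, k] = d_out^T * input,  d_input[m, k] = d_out * weight
void BackwardNT(const std::vector<float>& input, const std::vector<float>& weight,
                const std::vector<float>& d_out, std::vector<float>* d_input,
                std::vector<float>* d_weight, int m, int n, int k) {
  for (int j = 0; j < n; ++j)
    for (int p = 0; p < k; ++p) {
      float acc = 0.f;
      for (int i = 0; i < m; ++i) acc += d_out[i * n + j] * input[i * k + p];
      (*d_weight)[j * k + p] = acc;
    }
  for (int i = 0; i < m; ++i)
    for (int p = 0; p < k; ++p) {
      float acc = 0.f;
      for (int j = 0; j < n; ++j) acc += d_out[i * n + j] * weight[j * k + p];
      (*d_input)[i * k + p] = acc;
    }
}

int main() {
  const int m = 2, n = 2, k = 3;  // tiny stand-ins for bsz_seq/output/input sizes
  std::vector<float> input = {1, 2, 3, 4, 5, 6}, weight = {1, 0, 1, 0, 1, 0};
  std::vector<float> out(m * n), d_input(m * k), d_weight(n * k);
  ForwardNT(input, weight, &out, m, n, k);  // out = {4, 2, 10, 5}
  BackwardNT(input, weight, out, &d_input, &d_weight, m, n, k);  // out reused as d_out
  std::printf("out[0]=%g d_weight[0]=%g d_input[0]=%g\n", out[0], d_weight[0],
              d_input[0]);
  return 0;
}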
