Fix attn_bias_add bug. #37147

Merged: 6 commits, Nov 16, 2021

72 changes: 53 additions & 19 deletions paddle/fluid/operators/fused/attn_gemm.h
@@ -11,10 +11,13 @@ limitations under the License. */

#pragma once

#include "paddle/fluid/operators/fused/attn_bias_add.cu.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/float16.h"

#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"

namespace paddle {
namespace operators {

@@ -36,8 +39,10 @@ class AttnMatMul {

~AttnMatMul() {}

void ComputeForward(const T* weight_data, const T* input_data,
const T* bias_data, T* output_data, T* bias_out_data) {
void ComputeForward(const framework::Tensor* weight,
const framework::Tensor* input,
const framework::Tensor* bias, framework::Tensor* output,
framework::Tensor* bias_out) {
// Note: for blas.GEMM API in Paddle, it treats all inputs as row-major.
// here: (transa, transb): nt, input * weight.
CBLAS_TRANSPOSE transA = CblasNoTrans;
@@ -54,16 +59,25 @@ class AttnMatMul {
// here: (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out)
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
blas.GEMM(transA, transB, bsz_seq_, output_size_, input_size_, alpha,
input_data, weight_data, beta, output_data);
input->data<T>(), weight->data<T>(), beta, output->data<T>());
if (compute_bias_) {
// compute output + bias
LaunchBiasAddFwKernel(dev_ctx_, bsz_seq_, output_size_, output_data,
bias_data, bias_out_data);
std::vector<const Tensor*> ins;
std::vector<Tensor*> outs;
ins.emplace_back(output);
ins.emplace_back(bias);
outs.emplace_back(bias_out);
int elewise_add_axis = -1;
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor<T>());
}
}

void ComputeBackward(const T* input, const T* weight, const T* d_output,
T* d_input, T* d_weight, T* d_bias) {
void ComputeBackward(const framework::Tensor* input,
const framework::Tensor* weight,
const framework::Tensor* d_output,
framework::Tensor* d_input, framework::Tensor* d_weight,
framework::Tensor* d_bias) {
T alpha = static_cast<T>(1.0);
T beta = static_cast<T>(0.0);
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
@@ -81,11 +95,11 @@ class AttnMatMul {

T* dB_input_1_ptr = nullptr;
T* dB_input_2_ptr = nullptr;
T* dB_output_ptr = d_weight;
T* dB_output_ptr = d_weight->data<T>();

T* dA_input_1_ptr = nullptr;
T* dA_input_2_ptr = nullptr;
T* dA_output_ptr = d_input;
T* dA_output_ptr = d_input->data<T>();

if (!transA_) {
// fw: gemm-nt
@@ -104,10 +118,10 @@ class AttnMatMul {
dA_n = input_size_;
dA_k = output_size_;

blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, d_output,
input, beta, dB_output_ptr);
blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
weight, beta, dA_output_ptr);
blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha,
d_output->data<T>(), input->data<T>(), beta, dB_output_ptr);
blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha,
d_output->data<T>(), weight->data<T>(), beta, dA_output_ptr);
} else { // fw: gemm-nn
// bw: gemm-tn, dB = A^t * dC
dB_transA = CblasTrans;
@@ -123,10 +137,10 @@ class AttnMatMul {
dA_n = input_size_;
dA_k = output_size_;

blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, input,
d_output, beta, dB_output_ptr);
blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
weight, beta, dA_output_ptr);
blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha,
input->data<T>(), d_output->data<T>(), beta, dB_output_ptr);
blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha,
d_output->data<T>(), weight->data<T>(), beta, dA_output_ptr);
}
} else if (transB_) {
PADDLE_THROW(platform::errors::InvalidArgument(
@@ -138,7 +152,27 @@ class AttnMatMul {
"parameters."));
}
if (compute_bias_) {
LaunchBiasAddBwKernel(dev_ctx_, bsz_seq_, output_size_, d_output, d_bias);
// reduce: {0, 1, 2, 3, 4} -> {2, 3, 4} or {0, 1, 2} -> {2}
const auto input_dims = d_output->dims();
const auto output_dims = d_bias->dims();
bool support_case_1 =
(input_dims.size() == 5 && output_dims.size() == 3 &&
(input_dims[2] == output_dims[0]) &&
(input_dims[3] == output_dims[1]) &&
(input_dims[4] == output_dims[2]));
bool support_case_2 =
(input_dims.size() == 3 && output_dims.size() == 1 &&
(input_dims[2] == output_dims[0]));
if (support_case_1 || support_case_2) {
gpuStream_t stream = dev_ctx_.stream();
TensorReduceFunctorImpl<T, T, CustomSum>(*d_output, d_bias, {0, 1},
stream);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support reduce when the input dims are [0,1,2,3,4] and "
"output is [2,3,4]"
"or input is [0,1,2] and output is [2]."));
}
}
}

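Note: the net effect of the attn_gemm.h change is that the forward bias add now goes through the generic broadcast elementwise-add kernel (LaunchElementwiseCudaKernel with AddFunctor), and the backward pass computes the bias gradient by reducing d_output over the leading axes {0, 1} with TensorReduceFunctorImpl, replacing the old LaunchBiasAddFwKernel/LaunchBiasAddBwKernel pair. A minimal NumPy sketch of the underlying math, with made-up shapes (not taken from the PR):

```python
import numpy as np

# Illustrative shapes: activations [bsz, seq_len, output_size], bias [output_size].
bsz, seq_len, output_size = 2, 4, 8
out = np.random.rand(bsz, seq_len, output_size).astype(np.float32)
bias = np.random.rand(output_size).astype(np.float32)

# Forward: broadcast add over the trailing axis, which is what the
# elementwise-add kernel with axis = -1 computes on the GPU.
bias_out = out + bias

# Backward: d_bias is d_output summed over every axis except the bias axes,
# i.e. the {0, 1} reduction performed by TensorReduceFunctorImpl above.
d_output = np.random.rand(bsz, seq_len, output_size).astype(np.float32)
d_bias = d_output.sum(axis=(0, 1))
assert d_bias.shape == bias.shape
```

The same {0, 1} reduction covers both supported cases in ComputeBackward: a 5-D d_output against a 3-D bias, or a 3-D d_output against a 1-D bias.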
28 changes: 14 additions & 14 deletions paddle/fluid/operators/fused/fused_attention_op.cu
@@ -170,11 +170,11 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {

layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data,
ln_out_data, ln_mean_data, ln_var_data);
qkv_compute.ComputeForward(qkv_weight_data, ln_out_data, qkv_bias_data,
qkv_out_data, qkv_bias_out_data);
qkv_compute.ComputeForward(qkv_weight, ln_out, qkv_bias, qkv_out,
qkv_bias_out);
} else {
qkv_compute.ComputeForward(qkv_weight_data, x_data, qkv_bias_data,
qkv_out_data, qkv_bias_out_data);
qkv_compute.ComputeForward(qkv_weight, input_x, qkv_bias, qkv_out,
qkv_bias_out);
}
fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2,
qk_out, src_mask_out, softmax_out,
@@ -184,8 +184,8 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
// fmha_out: [batch_size, seq_len, num_head, head_dim]
// weight: [embed_dim, embed_dim]
// out_linear_out: [batch_size, seq_len, embed_dim]
out_linear_compute.ComputeForward(out_linear_weight_data, fmha_out_data,
nullptr, out_linear_out_data, nullptr);
out_linear_compute.ComputeForward(out_linear_weight, fmha_out, nullptr,
out_linear_out, nullptr);
if (pre_layer_norm) {
// output = (residual + dropout(input + bias))
fused_dropout_layernorm_helper.ResidualDropoutBias(
@@ -401,9 +401,10 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
d_out_linear_out_data, d_out_linear_bias_data, d_residual_data);
}

out_linear_compute.ComputeBackward(fmha_out_data, out_linear_weight_data,
d_out_linear_out_data, d_fmha_out_data,
d_out_linear_weight_data, nullptr);
out_linear_compute.ComputeBackward(fmha_out, out_linear_weight,
d_out_linear_out, d_fmha_out,
d_out_linear_weight, nullptr);

fmha_ref_compute.ComputeBackward(
*transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
*attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
@@ -432,15 +433,14 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
(d_ln_bias == nullptr ? nullptr
: d_ln_bias->mutable_data<U>(ctx.GetPlace()));

qkv_compute.ComputeBackward(ln_out_data, qkv_weight_data,
d_qkv_bias_out_data, d_ln_out_data,
d_qkv_weight_data, d_qkv_bias_data);
qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_bias_out, d_ln_out,
d_qkv_weight, d_qkv_bias);
layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data,
ln_mean_data, ln_var_data, d_x_data,
d_ln_scale_data, d_ln_bias_data);
} else {
qkv_compute.ComputeBackward(x_data, qkv_weight_data, d_qkv_bias_out_data,
d_x_data, d_qkv_weight_data, d_qkv_bias_data);
qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x,
d_qkv_weight, d_qkv_bias);
}
// gradient accumulation
std::vector<const Tensor *> ins;
19 changes: 13 additions & 6 deletions python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py
@@ -89,27 +89,32 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
qkv_weight = qkv_weight.reshape(qkv_weight.shape[0], qkv_weight.shape[1] *
qkv_weight.shape[2] * qkv_weight.shape[3])

qkv_bias = qkv_bias.reshape(qkv_bias.shape[0] * qkv_bias.shape[1] *
qkv_bias.shape[2])
if (pre_layer_norm):
ln_out = ln_out.reshape(batch_size * seq_len, embed_dim)
qkv = fc(ln_out, qkv_weight)
qkv_bias_out = qkv + qkv_bias
ln_out = ln_out.reshape(batch_size, seq_len, embed_dim)
else:
query = query.reshape(batch_size * seq_len, embed_dim)
qkv = fc(query, qkv_weight)
qkv_bias_out = qkv + qkv_bias
query = query.reshape(batch_size, seq_len, embed_dim)

qkv = qkv.reshape(batch_size, seq_len, 3, num_head, head_dim)
qkv_bias_out = qkv_bias_out.reshape(batch_size, seq_len, 3, num_head,
head_dim)
# q*k^t
qkv = qkv.transpose(
qkv_bias_out = qkv_bias_out.transpose(
(2, 0, 1, 3, 4)) # 3, batch_size, seq_len, num_head, head_dim
qkv = qkv.transpose(
qkv_bias_out = qkv_bias_out.transpose(
(0, 1, 3, 2, 4)) # 3, batch_size, num_head, seq_len, head_dim

q = qkv[0:1, ::]
q = qkv_bias_out[0:1, ::]
q = q.reshape(batch_size, num_head, seq_len, head_dim)
k = qkv[1:2, ::] #[1, batch_size, num_head, seq_len, head_dim]
k = qkv_bias_out[1:2, ::] #[1, batch_size, num_head, seq_len, head_dim]
k = k.reshape(batch_size, num_head, seq_len, head_dim)
v = qkv[2::]
v = qkv_bias_out[2::]
v = v.reshape(batch_size, num_head, seq_len, head_dim)

k = k.transpose([0, 1, 3, 2]) #[batch_size, num_head, head_dim, seq_len]
@@ -200,6 +205,8 @@ def run_imperative(self):
self.embed_dim, self.num_heads, self.dropout_prob,
self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm,
self.need_weight, self.weight_attr, self.bias_attr)
qkv_bias = np.random.random(fused_attn.qkv_bias.shape).astype('float32')
fused_attn.qkv_bias.set_value(paddle.to_tensor(qkv_bias))
out = fused_attn(
paddle.to_tensor(self.query),
paddle.to_tensor(self.query),
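Note: in the updated reference path the qkv bias is flattened and added right after the matmul, and every later reshape/transpose operates on qkv_bias_out instead of qkv, matching what the fused kernel produces. A hedged NumPy sketch of that reference computation, with illustrative sizes (the test derives its own from its config):

```python
import numpy as np

# Illustrative sizes only.
batch_size, seq_len, num_head, head_dim = 2, 4, 2, 4

qkv = np.random.rand(batch_size * seq_len, 3 * num_head * head_dim)
qkv_bias = np.random.rand(3, num_head, head_dim)

# Flatten the bias and add it before any reshaping, as the test now does.
qkv_bias_out = qkv + qkv_bias.reshape(3 * num_head * head_dim)

# [batch, seq, 3, num_head, head_dim] -> [3, batch, num_head, seq, head_dim]
qkv_bias_out = qkv_bias_out.reshape(batch_size, seq_len, 3, num_head, head_dim)
qkv_bias_out = qkv_bias_out.transpose(2, 0, 1, 3, 4).transpose(0, 1, 3, 2, 4)

q, k, v = qkv_bias_out  # each is [batch, num_head, seq_len, head_dim]
```

Unpacking along the leading axis of size 3 yields q, k and v in the [batch, num_head, seq_len, head_dim] layout that the attention step expects.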
33 changes: 28 additions & 5 deletions python/paddle/incubate/nn/functional/fused_transformer.py
@@ -76,8 +76,18 @@ def fused_feedforward(x,
ln1_epsilon (float, optional): Small float of first layer_norm added to denominator to avoid dividing by zero. Default is 1e-5.
ln2_epsilon (float, optional): Small float of second layer_norm added to denominator to avoid dividing by zero. Default is 1e-5.
pre_layer_norm (bool, optional): add layer_norm in the pre-processing stage or post-processing stage.
training (bool): A flag indicating whether it is in train phase or not. Default True.
mode(str): ['upscale_in_train'(default) | 'downscale_in_infer'].
training (bool, optional): A flag indicating whether it is in train phase or not. Default True.
mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']

1. upscale_in_train(default), upscale the output at training time

- train: out = input * mask / ( 1.0 - p )
- inference: out = input

2. downscale_in_infer, downscale the output at inference

- train: out = input * mask
- inference: out = input * (1.0 - p)
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Returns:
@@ -245,7 +255,10 @@ def fused_multi_head_attention(x,
out = out * v
out = transpose(out, perm=[0, 2, 1, 3])
out = out_linear(out)
out = layer_norm(x + dropout(linear_bias + out))
if pre_layer_norm:
out = x + dropout(linear_bias + out)
else:
out = layer_norm(x + dropout(linear_bias + out))

Parameters:
x (Tensor): The input tensor of fused_multi_head_attention. The shape is
@@ -278,8 +291,18 @@ def fused_multi_head_attention(x,
0 for no dropout. Default 0.5.
ln_epsilon (float, optional): Small float value added to denominator of layer_norm
to avoid dividing by zero. Default is 1e-5.
training (bool): A flag indicating whether it is in train phase or not. Default True.
mode(str): ['upscale_in_train'(default) | 'downscale_in_infer'].
training (bool, optional): A flag indicating whether it is in train phase or not. Default True.
mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']

1. upscale_in_train(default), upscale the output at training time

- train: out = input * mask / ( 1.0 - p )
- inference: out = input

2. downscale_in_infer, downscale the output at inference

- train: out = input * mask
- inference: out = input * (1.0 - p)
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Returns:
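Note: the mode formulas added to both docstrings above can be sanity-checked with a few lines of NumPy; the sketch below only illustrates the documented scaling rules and is not Paddle's dropout implementation (dropout_ref and its arguments are hypothetical names):

```python
import numpy as np

def dropout_ref(x, p, training=True, mode="upscale_in_train", seed=0):
    # Reference-only sketch of the two documented scaling modes.
    rng = np.random.default_rng(seed)
    mask = (rng.random(x.shape) >= p).astype(x.dtype)
    if mode == "upscale_in_train":
        # train: out = input * mask / (1.0 - p); inference: out = input
        return x * mask / (1.0 - p) if training else x
    else:  # "downscale_in_infer"
        # train: out = input * mask; inference: out = input * (1.0 - p)
        return x * mask if training else x * (1.0 - p)

x = np.ones((2, 3), dtype=np.float32)
print(dropout_ref(x, p=0.5, training=True))
print(dropout_ref(x, p=0.5, training=False, mode="downscale_in_infer"))
```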
4 changes: 2 additions & 2 deletions python/paddle/incubate/nn/layer/fused_transformer.py
@@ -39,8 +39,6 @@ class FusedMultiHeadAttention(Layer):
attn_dropout_rate (float, optional): The dropout probability used on attention
weights to drop some attention targets for the dropout in attention.
0 for no dropout. Default 0.5.
epsilon (float, optional): he small value added to the variance to prevent
division by zero. Default: 1e-05.
kdim (int, optional): The feature size in key. If None, assumed equal to
`embed_dim`. Default None.
vdim (int, optional): The feature size in value. If None, assumed equal to
@@ -56,6 +54,8 @@ class FusedMultiHeadAttention(Layer):
Default: None, which means the default bias parameter property is used.
If it is set to False, this layer will not have trainable bias parameter.
See usage for details in :code:`ParamAttr`.
epsilon (float, optional): The small value added to the variance to prevent
division by zero. Default: 1e-05.

Examples:
