Commit
Fix fused_attention_op scope. (#37065)
att, bug fix
limin2021 authored Nov 10, 2021
1 parent 48d53cf commit ad44a40
Showing 4 changed files with 135 additions and 108 deletions.
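
For orientation, here is a minimal standalone sketch, in plain C++ rather than the Paddle operator API (the function name below is made up), of what this commit scopes: the second-LayerNorm outputs Ln2Mean, Ln2Variance and BiasDropoutResidualOut are now only produced when pre_layer_norm is false, while LnMean, LnVariance and LnOut are only produced when it is true. The gradient op and the CUDA kernels in the diff below are rewired to match.

#include <iostream>
#include <string>
#include <vector>

// Illustrative only (not the PaddlePaddle operator API): after this fix, the
// optional outputs of fused_attention depend on the pre_layer_norm attribute.
std::vector<std::string> OptionalOutputs(bool pre_layer_norm) {
  if (pre_layer_norm) {
    // Pre-LN: the first LayerNorm runs on the input X.
    return {"LnMean", "LnVariance", "LnOut"};
  }
  // Post-LN: only the second LayerNorm (after bias + dropout + residual) runs.
  return {"Ln2Mean", "Ln2Variance", "BiasDropoutResidualOut"};
}

int main() {
  const bool modes[] = {true, false};
  for (bool pre : modes) {
    std::cout << "pre_layer_norm=" << std::boolalpha << pre << " ->";
    for (const std::string& name : OptionalOutputs(pre)) std::cout << ' ' << name;
    std::cout << '\n';
  }
  return 0;
}
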
128 changes: 70 additions & 58 deletions paddle/fluid/operators/fused/fused_attention_op.cc
@@ -42,6 +42,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut",
"FusedAttentionOp");
} else {
OP_INOUT_CHECK(ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"), "Output", "Ln2Variance",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output",
"BiasDropoutResidualOut", "FusedAttentionOp");
}

// qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
@@ -70,12 +77,7 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("OutLinearOut"), "Output", "OutLinearOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"), "Output", "Ln2Variance",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output",
"BiasDropoutResidualOut", "FusedAttentionOp");

OP_INOUT_CHECK(ctx->HasOutput("DropoutMaskOut"), "Output", "DropoutMaskOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "FusedAttentionOp");
@@ -109,6 +111,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("LnOut", ctx->GetInputDim("X"));
} else {
ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X"));
}
// [batch_size, seq_len, 3, num_head, head_size]
ctx->SetOutputDim("QKVOut",
@@ -138,12 +144,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], y_dim[1], y_dim[2]});
ctx->SetOutputDim("OutLinearOut", ctx->GetInputDim("X"));

ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]});
if (ctx->Attrs().Get<bool>("dropout_is_test") == false) {
ctx->SetOutputDim("DropoutMaskOut", ctx->GetInputDim("X"));
}
ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X"));

ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
}
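
To make the shape bookkeeping in this InferShape hunk concrete, here is a small worked example with made-up sizes. It assumes QKVW (y_dim) is laid out as [3, num_head, dim_head, embed_dim], which is consistent with the comments above but not spelled out in the lines shown here.

#include <cstdio>

// Worked example of the output shapes set above, with hypothetical sizes.
int main() {
  const int batch_size = 2, seq_len = 8, embed_dim = 16;
  const int num_head = 4, dim_head = embed_dim / num_head;

  // Per-token LayerNorm statistics: x_dim[0] * x_dim[1] elements.
  std::printf("LnMean / LnVariance / Ln2Mean / Ln2Variance : [%d]\n",
              batch_size * seq_len);
  // QKVOut: [batch_size, seq_len, 3, num_head, dim_head]
  std::printf("QKVOut  : [%d, %d, 3, %d, %d]\n", batch_size, seq_len, num_head,
              dim_head);
  // FMHAOut: {x_dim[0], x_dim[1], y_dim[1], y_dim[2]}
  std::printf("FMHAOut : [%d, %d, %d, %d]\n", batch_size, seq_len, num_head,
              dim_head);
  // LnOut, BiasDropoutResidualOut, DropoutMaskOut, OutLinearOut and Y share X's shape.
  std::printf("LnOut / BiasDropoutResidualOut / Y : [%d, %d, %d]\n",
              batch_size, seq_len, embed_dim);
  return 0;
}
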

@@ -314,25 +318,28 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
});

AddComment(R"DOC(
  Add fused attention op whose logic is as follows:
  // @input: [batch_size, seq_len, 3, num_head, head_dim]
  // @final_out: [batch_size, seq_len, num_heads, head_dim]
  if (pre_layernorm)
    out = layer_norm(input);
  out = compute_qkv(out) + bias;
  // fmha module
  {
    out = transpose(out, perm=[2, 0, 3, 1, 4]);
    out = q * k^t;
    out = attn_mask + out;
    out = softmax(out);
    out = dropout(out);
    out = out * v;
    out = transpose(out, perm=[0, 2, 1, 3]);
  }
out = out_linear(out);
final_out = layer_norm(residual + dropout(bias + out));
if (pre_layernorm)
final_out = residual + dropout(bias + out);
else
final_out = layer_norm(residual + dropout(bias + out));
)DOC");
}
};
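
The branch this commit adds is exactly the epilogue in the pseudocode above. Below is a minimal sketch in plain C++ (not the Paddle kernels; the LayerNorm scale/bias are omitted and dropout is shown in its inference form, i.e. as identity) of how the final output differs between the two modes.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Per-token LayerNorm without scale/bias, for illustration only.
std::vector<float> LayerNorm(const std::vector<float>& x, float eps = 1e-5f) {
  float mean = 0.f, var = 0.f;
  for (float v : x) mean += v;
  mean /= static_cast<float>(x.size());
  for (float v : x) var += (v - mean) * (v - mean);
  var /= static_cast<float>(x.size());
  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)
    y[i] = (x[i] - mean) / std::sqrt(var + eps);
  return y;
}

// pre_layernorm:  final_out = residual + dropout(bias + out)
// otherwise:      final_out = layer_norm(residual + dropout(bias + out))
std::vector<float> Epilogue(const std::vector<float>& out,
                            const std::vector<float>& bias,
                            const std::vector<float>& residual,
                            bool pre_layernorm) {
  std::vector<float> sum(out.size());
  for (std::size_t i = 0; i < out.size(); ++i)
    sum[i] = residual[i] + (bias[i] + out[i]);  // dropout treated as identity
  return pre_layernorm ? sum : LayerNorm(sum);
}

int main() {
  const std::vector<float> out = {0.5f, -1.0f, 2.0f};
  const std::vector<float> bias = {0.1f, 0.1f, 0.1f};
  const std::vector<float> residual = {1.0f, 0.0f, -1.0f};
  const bool modes[] = {true, false};
  for (bool pre : modes) {
    const std::vector<float> y = Epilogue(out, bias, residual, pre);
    std::printf("pre_layernorm=%d: %.3f %.3f %.3f\n", static_cast<int>(pre),
                y[0], y[1], y[2]);
  }
  return 0;
}
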
@@ -347,27 +354,29 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument(
"GradOp is only callable when attn_dropout_is_test is false"));

OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance",
"FusedAttentionGrad");
if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) {
ctx->SetOutputDim(framework::GradVarName("Ln2Scale"),
ctx->GetInputDim("Ln2Scale"));
}
if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Ln2Bias"),
ctx->GetInputDim("Ln2Bias"));
}
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad");
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
if (ctx->Attrs().Get<bool>("pre_layer_norm") == false) {
OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance",
"FusedAttentionGrad");
if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) {
ctx->SetOutputDim(framework::GradVarName("Ln2Scale"),
ctx->GetInputDim("Ln2Scale"));
}
if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Ln2Bias"),
ctx->GetInputDim("Ln2Bias"));
}
} else {
OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("LnOut"), "Input", "LnOut",
"FusedAttentionGrad");
}

OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
@@ -402,6 +411,9 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
ctx->SetOutputDim(framework::GradVarName("LnOut"),
ctx->GetInputDim("LnOut"));
} else {
ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"),
ctx->GetInputDim("BiasDropoutResidualOut"));
}
ctx->SetOutputDim(framework::GradVarName("FMHAOut"),
ctx->GetInputDim("FMHAOut"));
@@ -426,8 +438,6 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
ctx->GetInputDim("QKVBiasOut"));
ctx->SetOutputDim(framework::GradVarName("OutLinearOut"),
ctx->GetInputDim("OutLinearOut"));
ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"),
ctx->GetInputDim("BiasDropoutResidualOut"));
}

protected:
@@ -478,17 +488,17 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("LnBias"),
this->InputGrad("LnBias"));
}
}

if (this->HasInput("Ln2Scale")) {
op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
op->SetOutput(framework::GradVarName("Ln2Scale"),
this->InputGrad("Ln2Scale"));
}
if (this->HasInput("Ln2Bias")) {
op->SetInput("Ln2Bias", this->Input("Ln2Bias"));
op->SetOutput(framework::GradVarName("Ln2Bias"),
this->InputGrad("Ln2Bias"));
} else {
if (this->HasInput("Ln2Scale")) {
op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
op->SetOutput(framework::GradVarName("Ln2Scale"),
this->InputGrad("Ln2Scale"));
}
if (this->HasInput("Ln2Bias")) {
op->SetInput("Ln2Bias", this->Input("Ln2Bias"));
op->SetOutput(framework::GradVarName("Ln2Bias"),
this->InputGrad("Ln2Bias"));
}
}

op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
@@ -511,6 +521,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
if (this->HasOutput("LnVariance")) {
op->SetInput("LnVariance", this->Output("LnVariance"));
}
} else {
op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
op->SetInput("BiasDropoutResidualOut",
this->Output("BiasDropoutResidualOut"));
}
op->SetInput("QKVOut", this->Output("QKVOut"));
op->SetInput("QKVBiasOut", this->Output("QKVBiasOut"));
@@ -523,12 +538,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {

op->SetInput("FMHAOut", this->Output("FMHAOut"));
op->SetInput("OutLinearOut", this->Output("OutLinearOut"));

op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
op->SetInput("DropoutMaskOut", this->Output("DropoutMaskOut"));
op->SetInput("BiasDropoutResidualOut",
this->Output("BiasDropoutResidualOut"));
op->SetInput("QKVOut", this->Output("QKVOut"));

// backward outputs: dinput
@@ -537,7 +547,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("LnOut"),
this->OutputGrad("LnOut"));
}
} else {
op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"),
this->OutputGrad("BiasDropoutResidualOut"));
}

op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut"));
op->SetOutput(framework::GradVarName("QKVBiasOut"),
this->OutputGrad("QKVBiasOut"));
@@ -553,8 +567,6 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {

op->SetOutput(framework::GradVarName("FMHAOut"),
this->OutputGrad("FMHAOut"));
op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"),
this->OutputGrad("BiasDropoutResidualOut"));
op->SetOutput(framework::GradVarName("OutLinearOut"),
this->OutputGrad("OutLinearOut"));
}
91 changes: 52 additions & 39 deletions paddle/fluid/operators/fused/fused_attention_op.cu
@@ -95,15 +95,6 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
const auto qkv_w_dims = qkv_weight->dims();

auto *x_data = input_x->data<T>();
auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
auto *ln_mean_data =
pre_layer_norm ? ln_mean->mutable_data<U>(ctx.GetPlace()) : nullptr;
auto *ln_var_data =
pre_layer_norm ? ln_var->mutable_data<U>(ctx.GetPlace()) : nullptr;
auto *ln_out_data =
pre_layer_norm ? ln_out->mutable_data<T>(ctx.GetPlace()) : nullptr;

auto *qkv_weight_data = qkv_weight->data<T>();
auto *qkv_bias_data = qkv_bias->data<T>();
auto *qkv_out_data = qkv_out->mutable_data<T>(ctx.GetPlace());
@@ -130,16 +121,8 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto *out_linear_out_data = out_linear_out->mutable_data<T>(ctx.GetPlace());

// get data ptr for bias+dropout+residual+layernorm
auto *ln_scale_2_data =
(ln_scale_2 == nullptr ? nullptr : ln_scale_2->data<U>());
auto *ln_bias_2_data =
(ln_bias_2 == nullptr ? nullptr : ln_bias_2->data<U>());
auto *dropout_mask_out_data =
dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
auto *bias_dropout_residual_out_data =
bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
auto *ln_mean_2_data = ln_mean_2->mutable_data<U>(ctx.GetPlace());
auto *ln_var_2_data = ln_var_2->mutable_data<U>(ctx.GetPlace());
auto *final_out_data = out->mutable_data<T>(ctx.GetPlace());

int batch_size = input_x_dims[0];
@@ -178,6 +161,13 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
ln_epsilon);

if (pre_layer_norm) {
auto *ln_scale_data =
(ln_scale == nullptr ? nullptr : ln_scale->data<U>());
auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
auto *ln_mean_data = ln_mean->mutable_data<U>(ctx.GetPlace());
auto *ln_var_data = ln_var->mutable_data<U>(ctx.GetPlace());
auto *ln_out_data = ln_out->mutable_data<T>(ctx.GetPlace());

layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data,
ln_out_data, ln_mean_data, ln_var_data);
qkv_compute.ComputeForward(qkv_weight_data, ln_out_data, qkv_bias_data,
@@ -196,12 +186,27 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
// out_linear_out: [batch_size, seq_len, embed_dim]
out_linear_compute.ComputeForward(out_linear_weight_data, fmha_out_data,
nullptr, out_linear_out_data, nullptr);
// output = layernorm(residual + dropout(input + bias))
fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
ctx.cuda_device_context(), out_linear_out_data, x_data,
out_linear_bias_data, ln_scale_2_data, ln_bias_2_data,
bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data,
ln_mean_2_data, ln_var_2_data);
if (pre_layer_norm) {
// output = (residual + dropout(input + bias))
fused_dropout_layernorm_helper.ResidualDropoutBias(
ctx.cuda_device_context(), out_linear_out_data, x_data,
out_linear_bias_data, final_out_data, dropout_mask_out_data);
} else {
auto *ln_scale_2_data =
(ln_scale_2 == nullptr ? nullptr : ln_scale_2->data<U>());
auto *ln_bias_2_data =
(ln_bias_2 == nullptr ? nullptr : ln_bias_2->data<U>());
auto *bias_dropout_residual_out_data =
bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
auto *ln_mean_2_data = ln_mean_2->mutable_data<U>(ctx.GetPlace());
auto *ln_var_2_data = ln_var_2->mutable_data<U>(ctx.GetPlace());
// output = layernorm(residual + dropout(input + bias))
fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
ctx.cuda_device_context(), out_linear_out_data, x_data,
out_linear_bias_data, ln_scale_2_data, ln_bias_2_data,
bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data,
ln_mean_2_data, ln_var_2_data);
}
}
};
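
For the pre_layer_norm branch above, here is a rough per-element sketch of what the ResidualDropoutBias call is expected to compute, in plain C++ rather than the CUDA helper. It assumes upscale-in-train style dropout, which may differ from the actual dropout_param configuration; the function name and signature are illustrative only.

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <vector>

// Sketch of the pre_layer_norm epilogue (training mode):
//   mask[i] ~ Bernoulli(1 - p)
//   out[i]  = residual[i] + mask[i] / (1 - p) * (src[i] + bias[i])
void ResidualDropoutBiasSketch(const std::vector<float>& src,
                               const std::vector<float>& residual,
                               const std::vector<float>& bias,
                               float p,
                               std::vector<float>* out,
                               std::vector<std::uint8_t>* mask) {
  out->resize(src.size());
  mask->resize(src.size());
  for (std::size_t i = 0; i < src.size(); ++i) {
    const bool keep = (std::rand() / (RAND_MAX + 1.0)) >= p;
    (*mask)[i] = keep ? 1 : 0;  // stored so the backward pass can reuse it
    (*out)[i] = residual[i] + (*mask)[i] / (1.0f - p) * (src[i] + bias[i]);
  }
}
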

@@ -271,10 +276,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *src_mask_out_data =
(src_mask == nullptr) ? nullptr : src_mask_out->data<T>();
auto *out_linear_out_data = out_linear_out->data<T>();
auto *ln_2_mean_data = ln_2_mean->data<U>();
auto *ln_2_var_data = ln_2_var->data<U>();
auto *dropout_mask_out_data = dropout_mask_out->data<uint8_t>();
auto *bias_dropout_residual_out_data = bias_dropout_residual_out->data<T>();

// output's grad
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
@@ -312,8 +314,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *d_fmha_out_data = d_fmha_out->mutable_data<T>(ctx.GetPlace());
auto *d_out_linear_out_data =
d_out_linear_out->mutable_data<T>(ctx.GetPlace());
auto *d_bias_dropout_residual_out_data =
d_bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());

// parameter grad
auto *d_qkv_weight = ctx.Output<Tensor>(framework::GradVarName("QKVW"));
@@ -331,12 +331,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
d_out_linear_weight->mutable_data<T>(ctx.GetPlace());
auto *d_out_linear_bias_data =
d_out_linear_bias->mutable_data<T>(ctx.GetPlace());
auto *d_ln_2_scale_data =
(d_ln_2_scale == nullptr ? nullptr : d_ln_2_scale->mutable_data<U>(
ctx.GetPlace()));
auto *d_ln_2_bias_data =
(d_ln_2_bias == nullptr ? nullptr
: d_ln_2_bias->mutable_data<U>(ctx.GetPlace()));

const auto input_x_dims = input_x->dims();
const auto qkv_w_dims = qkv_weight->dims();
Expand Down Expand Up @@ -382,11 +376,30 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2,
ln2epsilon);

fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
ctx.cuda_device_context(), d_y_data, bias_dropout_residual_out_data,
dropout_mask_out_data, ln_2_scale_data, ln_2_mean_data, ln_2_var_data,
d_bias_dropout_residual_out_data, d_ln_2_scale_data, d_ln_2_bias_data,
d_out_linear_out_data, d_out_linear_bias_data, d_residual_data);
if (pre_layer_norm) {
fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
ctx.cuda_device_context(), d_y_data, dropout_mask_out_data,
d_out_linear_out_data, d_residual_data, d_out_linear_bias_data);
} else {
auto *ln_2_mean_data = ln_2_mean->data<U>();
auto *ln_2_var_data = ln_2_var->data<U>();
auto *bias_dropout_residual_out_data =
bias_dropout_residual_out->data<T>();
auto *d_ln_2_scale_data =
(d_ln_2_scale == nullptr ? nullptr : d_ln_2_scale->mutable_data<U>(
ctx.GetPlace()));
auto *d_ln_2_bias_data =
(d_ln_2_bias == nullptr ? nullptr : d_ln_2_bias->mutable_data<U>(
ctx.GetPlace()));
auto *d_bias_dropout_residual_out_data =
d_bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());

fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
ctx.cuda_device_context(), d_y_data, bias_dropout_residual_out_data,
dropout_mask_out_data, ln_2_scale_data, ln_2_mean_data, ln_2_var_data,
d_bias_dropout_residual_out_data, d_ln_2_scale_data, d_ln_2_bias_data,
d_out_linear_out_data, d_out_linear_bias_data, d_residual_data);
}

out_linear_compute.ComputeBackward(fmha_out_data, out_linear_weight_data,
d_out_linear_out_data, d_fmha_out_data,
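
And the gradient counterpart used by the pre_layer_norm branch of the backward kernel above, consistent with the forward sketch given earlier (again plain C++ with an illustrative name, a single row, and upscale-in-train dropout assumed; the real helper also reduces d_bias over all bsz_seq rows).

#include <cstddef>
#include <cstdint>
#include <vector>

// Gradient of out = residual + mask/(1-p) * (src + bias):
//   d_residual[i] = d_out[i]
//   d_src[i]      = d_out[i] * mask[i] / (1 - p)
//   d_bias        = d_src accumulated over all rows (one row shown here)
void ResidualDropoutBiasGradSketch(const std::vector<float>& d_out,
                                   const std::vector<std::uint8_t>& mask,
                                   float p,
                                   std::vector<float>* d_src,
                                   std::vector<float>* d_residual,
                                   std::vector<float>* d_bias) {
  d_src->resize(d_out.size());
  d_residual->resize(d_out.size());
  d_bias->assign(d_out.size(), 0.f);
  for (std::size_t i = 0; i < d_out.size(); ++i) {
    (*d_residual)[i] = d_out[i];
    (*d_src)[i] = d_out[i] * mask[i] / (1.0f - p);
    (*d_bias)[i] += (*d_src)[i];  // accumulate over rows when bsz_seq > 1
  }
}
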
