Fix fused_attention_op and fused_feedforward_op bug when pre_layer_norm is false. (#36793)

* Fix bug when pre_layer_norm is false.
limin2021 committed Oct 28, 2021
1 parent 96edcea commit d344aeb
Showing 5 changed files with 254 additions and 144 deletions.
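Every hunk below applies the same guard: code that declares, shape-infers, or wires up the LayerNorm-only tensors (LnMean, LnVariance, LnOut and their gradients) now runs only when the pre_layer_norm attribute is true, so running the op with pre_layer_norm=false no longer requires variables that the post-LayerNorm path never produces. What follows is a minimal standalone sketch of that pattern; ShapeCtx and InferShapeSketch are hypothetical stand-ins, not Paddle's real InferShapeContext or the actual operator code.

// Standalone sketch (not Paddle code) of the guard pattern added by this
// commit: LayerNorm-only outputs are shape-inferred only when the
// pre_layer_norm attribute is true.
#include <cstdint>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

struct ShapeCtx {
  std::map<std::string, bool> bool_attrs;                 // op attributes
  std::set<std::string> outputs;                          // declared output names
  std::map<std::string, std::vector<int64_t>> out_dims;   // inferred dims

  bool GetBoolAttr(const std::string& n) const { return bool_attrs.at(n); }
  bool HasOutput(const std::string& n) const { return outputs.count(n) > 0; }
  void SetOutputDim(const std::string& n, std::vector<int64_t> d) {
    out_dims[n] = std::move(d);
  }
};

// Mirrors the shape of FusedAttentionOp::InferShape after the fix.
void InferShapeSketch(ShapeCtx* ctx, int64_t batch, int64_t seq_len) {
  if (ctx->GetBoolAttr("pre_layer_norm")) {
    // Only the pre-LayerNorm path produces these tensors.
    ctx->SetOutputDim("LnMean", {batch * seq_len});
    ctx->SetOutputDim("LnVariance", {batch * seq_len});
    if (ctx->HasOutput("LnOut")) ctx->SetOutputDim("LnOut", {batch, seq_len});
  }
  // Outputs shared by both paths are always set (dims here are placeholders).
  ctx->SetOutputDim("QKVOut", {batch, seq_len, 3, 8, 64});
}

int main() {
  ShapeCtx ctx;
  ctx.bool_attrs["pre_layer_norm"] = false;  // the configuration that used to fail
  ctx.outputs = {"QKVOut"};                  // no LayerNorm outputs declared
  InferShapeSketch(&ctx, 2, 128);
  std::cout << "inferred dims for " << ctx.out_dims.size() << " outputs\n";  // prints 1
  return 0;
}

The diff applies the same guard symmetrically in FusedAttentionGradOp and FusedAttentionGradOpMaker, so the grad op never references LayerNorm variables that the forward op did not declare.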
101 changes: 64 additions & 37 deletions paddle/fluid/operators/fused/fused_attention_op.cc
@@ -37,12 +37,15 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
"FusedAttentionOp");

OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("LnVariance"), "Output", "LnVariance",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut",
"FusedAttentionOp");
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("LnVariance"), "Output", "LnVariance",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut",
"FusedAttentionOp");
}

// qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
OP_INOUT_CHECK(ctx->HasOutput("QKVOut"), "Output", "QKVOut",
"FusedAttentionOp");
@@ -101,9 +104,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"input qkv_weight = [%s]",
x_dim, y_dim));

ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("LnOut", ctx->GetInputDim("X"));
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("LnOut", ctx->GetInputDim("X"));
}
// [batch_size, seq_len, 3, num_head, head_size]
ctx->SetOutputDim("QKVOut",
{x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});
@@ -351,11 +356,11 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
ctx->GetInputDim("Ln2Bias"));
}
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance",
"FusedAttentionGrad");
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("LnOut"), "Input", "LnOut",
"FusedAttentionGrad");
}
@@ -370,13 +375,15 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
"FusedAttentionGrad");

if (ctx->HasOutput(framework::GradVarName("LnScale"))) {
ctx->SetOutputDim(framework::GradVarName("LnScale"),
ctx->GetInputDim("LnScale"));
}
if (ctx->HasOutput(framework::GradVarName("LnBias"))) {
ctx->SetOutputDim(framework::GradVarName("LnBias"),
ctx->GetInputDim("LnBias"));
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
if (ctx->HasOutput(framework::GradVarName("LnScale"))) {
ctx->SetOutputDim(framework::GradVarName("LnScale"),
ctx->GetInputDim("LnScale"));
}
if (ctx->HasOutput(framework::GradVarName("LnBias"))) {
ctx->SetOutputDim(framework::GradVarName("LnBias"),
ctx->GetInputDim("LnBias"));
}
}
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
@@ -390,8 +397,10 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("QKVBias"),
ctx->GetInputDim("QKVBias"));

ctx->SetOutputDim(framework::GradVarName("LnOut"),
ctx->GetInputDim("LnOut"));
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
ctx->SetOutputDim(framework::GradVarName("LnOut"),
ctx->GetInputDim("LnOut"));
}
ctx->SetOutputDim(framework::GradVarName("FMHAOut"),
ctx->GetInputDim("FMHAOut"));
ctx->SetOutputDim(framework::GradVarName("QKTVOut"),
@@ -442,16 +451,23 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("SrcMask", this->Input("SrcMask"));
op->SetInput("OutLinearW", this->Input("OutLinearW"));
op->SetInput("OutLinearBias", this->Input("OutLinearBias"));
if (this->HasInput("LnScale")) {
op->SetInput("LnScale", this->Input("LnScale"));
op->SetOutput(framework::GradVarName("LnScale"),
this->InputGrad("LnScale"));
}
if (this->HasInput("LnBias")) {
op->SetInput("LnBias", this->Input("LnBias"));
op->SetOutput(framework::GradVarName("LnBias"),
this->InputGrad("LnBias"));

op->SetAttrMap(this->Attrs());
bool is_pre_layer_norm =
BOOST_GET_CONST(bool, op->GetAttr("pre_layer_norm"));
if (is_pre_layer_norm) {
if (this->HasInput("LnScale")) {
op->SetInput("LnScale", this->Input("LnScale"));
op->SetOutput(framework::GradVarName("LnScale"),
this->InputGrad("LnScale"));
}
if (this->HasInput("LnBias")) {
op->SetInput("LnBias", this->Input("LnBias"));
op->SetOutput(framework::GradVarName("LnBias"),
this->InputGrad("LnBias"));
}
}

if (this->HasInput("Ln2Scale")) {
op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
op->SetOutput(framework::GradVarName("Ln2Scale"),
@@ -473,9 +489,17 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
this->InputGrad("OutLinearW"));

// use forward outputs as backward inputs.
op->SetInput("LnOut", this->Output("LnOut"));
op->SetInput("LnMean", this->Output("LnMean"));
op->SetInput("LnVariance", this->Output("LnVariance"));
if (is_pre_layer_norm) {
if (this->HasOutput("LnOut")) {
op->SetInput("LnOut", this->Output("LnOut"));
}
if (this->HasOutput("LnMean")) {
op->SetInput("LnMean", this->Output("LnMean"));
}
if (this->HasOutput("LnVariance")) {
op->SetInput("LnVariance", this->Output("LnVariance"));
}
}
op->SetInput("QKVOut", this->Output("QKVOut"));
op->SetInput("QKVBiasOut", this->Output("QKVBiasOut"));
op->SetInput("TransposeOut2", this->Output("TransposeOut2"));
@@ -496,7 +520,12 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("QKVOut", this->Output("QKVOut"));

// backward outputs: dinput
op->SetOutput(framework::GradVarName("LnOut"), this->OutputGrad("LnOut"));
if (is_pre_layer_norm) {
if (this->HasOutput("LnOut")) {
op->SetOutput(framework::GradVarName("LnOut"),
this->OutputGrad("LnOut"));
}
}
op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut"));
op->SetOutput(framework::GradVarName("QKVBiasOut"),
this->OutputGrad("QKVBiasOut"));
@@ -517,8 +546,6 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
this->OutputGrad("BiasDropoutResidualOut"));
op->SetOutput(framework::GradVarName("OutLinearOut"),
this->OutputGrad("OutLinearOut"));

op->SetAttrMap(this->Attrs());
}
};

44 changes: 25 additions & 19 deletions paddle/fluid/operators/fused/fused_attention_op.cu
@@ -97,9 +97,12 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto *x_data = input_x->data<T>();
auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
auto *ln_mean_data = ln_mean->mutable_data<U>(ctx.GetPlace());
auto *ln_var_data = ln_var->mutable_data<U>(ctx.GetPlace());
auto *ln_out_data = ln_out->mutable_data<T>(ctx.GetPlace());
auto *ln_mean_data =
pre_layer_norm ? ln_mean->mutable_data<U>(ctx.GetPlace()) : nullptr;
auto *ln_var_data =
pre_layer_norm ? ln_var->mutable_data<U>(ctx.GetPlace()) : nullptr;
auto *ln_out_data =
pre_layer_norm ? ln_out->mutable_data<T>(ctx.GetPlace()) : nullptr;

auto *qkv_weight_data = qkv_weight->data<T>();
auto *qkv_bias_data = qkv_bias->data<T>();
@@ -243,9 +246,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *out_linear_bias_data = out_linear_bias->data<T>();

// fw output
auto *ln_mean = ctx.Input<Tensor>("LnMean");
auto *ln_var = ctx.Input<Tensor>("LnVariance");
auto *ln_out = ctx.Input<Tensor>("LnOut");
auto *fmha_out = ctx.Input<Tensor>("FMHAOut");
auto *transpose_out_2 = ctx.Input<Tensor>("TransposeOut2");
auto *qk_out = ctx.Input<Tensor>("QKOut");
@@ -260,9 +260,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *dropout_mask_out = ctx.Input<Tensor>("DropoutMaskOut");
auto *bias_dropout_residual_out =
ctx.Input<Tensor>("BiasDropoutResidualOut");
auto *ln_mean_data = ln_mean->data<U>();
auto *ln_var_data = ln_var->data<U>();
auto *ln_out_data = ln_out->data<T>();
auto *fmha_out_data = fmha_out->data<T>();
auto *transpose_out_2_data = transpose_out_2->data<T>();
auto *qk_out_data = qk_out->data<T>();
@@ -277,7 +274,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {

// output's grad
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_ln_out = ctx.Output<Tensor>(framework::GradVarName("LnOut"));
auto *d_qkv_out = ctx.Output<Tensor>(framework::GradVarName("QKVOut"));
auto *d_qkv_bias_out =
ctx.Output<Tensor>(framework::GradVarName("QKVBiasOut"));
@@ -297,7 +293,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *d_bias_dropout_residual_out =
ctx.Output<Tensor>(framework::GradVarName("BiasDropoutResidualOut"));
auto *d_x_data = d_x->mutable_data<T>(ctx.GetPlace());
auto *d_ln_out_data = d_ln_out->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_out_data = d_qkv_out->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_bias_out_data = d_qkv_bias_out->mutable_data<T>(ctx.GetPlace());
auto *d_qktv_out_data = d_qktv_out->mutable_data<T>(ctx.GetPlace());
@@ -315,8 +310,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
d_bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());

// parameter grad
auto *d_ln_scale = ctx.Output<Tensor>(framework::GradVarName("LnScale"));
auto *d_ln_bias = ctx.Output<Tensor>(framework::GradVarName("LnBias"));
auto *d_qkv_weight = ctx.Output<Tensor>(framework::GradVarName("QKVW"));
auto *d_qkv_bias = ctx.Output<Tensor>(framework::GradVarName("QKVBias"));
auto *d_out_linear_weight =
Expand All @@ -325,12 +318,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
ctx.Output<Tensor>(framework::GradVarName("OutLinearBias"));
auto *d_ln_2_scale = ctx.Output<Tensor>(framework::GradVarName("Ln2Scale"));
auto *d_ln_2_bias = ctx.Output<Tensor>(framework::GradVarName("Ln2Bias"));
auto *d_ln_scale_data =
(d_ln_scale == nullptr ? nullptr
: d_ln_scale->mutable_data<U>(ctx.GetPlace()));
auto *d_ln_bias_data =
(d_ln_bias == nullptr ? nullptr
: d_ln_bias->mutable_data<U>(ctx.GetPlace()));

auto *d_qkv_weight_data = d_qkv_weight->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_bias_data = d_qkv_bias->mutable_data<T>(ctx.GetPlace());
auto *d_out_linear_weight_data =
@@ -407,6 +395,24 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
cudaMemcpyDeviceToDevice);

if (pre_layer_norm) {
auto *ln_mean = ctx.Input<Tensor>("LnMean");
auto *ln_var = ctx.Input<Tensor>("LnVariance");
auto *ln_out = ctx.Input<Tensor>("LnOut");
auto *ln_mean_data = ln_mean->data<U>();
auto *ln_var_data = ln_var->data<U>();
auto *ln_out_data = ln_out->data<T>();

auto *d_ln_out = ctx.Output<Tensor>(framework::GradVarName("LnOut"));
auto *d_ln_scale = ctx.Output<Tensor>(framework::GradVarName("LnScale"));
auto *d_ln_bias = ctx.Output<Tensor>(framework::GradVarName("LnBias"));
auto *d_ln_out_data = d_ln_out->mutable_data<T>(ctx.GetPlace());
auto *d_ln_scale_data =
(d_ln_scale == nullptr ? nullptr
: d_ln_scale->mutable_data<U>(ctx.GetPlace()));
auto *d_ln_bias_data =
(d_ln_bias == nullptr ? nullptr
: d_ln_bias->mutable_data<U>(ctx.GetPlace()));

qkv_compute.ComputeBackward(ln_out_data, qkv_weight_data,
d_qkv_bias_out_data, d_ln_out_data,
d_qkv_weight_data, d_qkv_bias_data);
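The fused_attention_op.cu changes above are the kernel-side half of the same fix: the LayerNorm buffers are allocated (forward) or fetched (backward) only inside the pre_layer_norm branch, and are left as null pointers otherwise. Below is a standalone sketch of that guarded-allocation pattern; MaybeTensor and RunKernelSketch are hypothetical stand-ins, not Paddle's Tensor class or the real kernels.

// Standalone sketch (not Paddle code) of the kernel-side guard: buffers that
// exist only on the pre-LayerNorm path are touched only behind the flag, so
// the post-LayerNorm path never dereferences an unallocated tensor.
#include <cassert>
#include <cstdio>
#include <vector>

struct MaybeTensor {
  std::vector<float> buf;   // empty when the tensor was never produced
  bool allocated() const { return !buf.empty(); }
  float* data() { return buf.data(); }
};

void RunKernelSketch(bool pre_layer_norm, MaybeTensor* ln_mean,
                     MaybeTensor* qkv_out) {
  // Optional buffer: mirrors the `pre_layer_norm ? ln_mean->mutable_data<U>(...)
  // : nullptr` guards added in the forward kernel.
  float* ln_mean_data = pre_layer_norm ? ln_mean->data() : nullptr;
  if (pre_layer_norm) {
    assert(ln_mean->allocated());
    ln_mean_data[0] = 0.f;  // stand-in for writing the LayerNorm statistics
  }
  // Shared buffer: written on both paths.
  qkv_out->data()[0] = 1.f;
}

int main() {
  MaybeTensor ln_mean;      // deliberately never allocated
  MaybeTensor qkv_out;
  qkv_out.buf.resize(4);
  RunKernelSketch(/*pre_layer_norm=*/false, &ln_mean, &qkv_out);
  std::printf("post-LayerNorm path ran without touching LnMean\n");
  return 0;
}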
(Diffs for the remaining three changed files are not shown here.)

1 comment on commit d344aeb

@paddle-bot-old

Congratulations! Your pull request passed all required CI checks. You can ask reviewer(s) to approve and merge. 🎉
