diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index 6c4ac318264e8..f7c7129c7732b 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -37,12 +37,15 @@ class FusedAttentionOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean", - "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("LnVariance"), "Output", "LnVariance", - "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut", - "FusedAttentionOp"); + if (ctx->Attrs().Get("pre_layer_norm") == true) { + OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("LnVariance"), "Output", "LnVariance", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut", + "FusedAttentionOp"); + } + // qkv_out: [batch_size, seq_len, 3, num_head, dim_head] OP_INOUT_CHECK(ctx->HasOutput("QKVOut"), "Output", "QKVOut", "FusedAttentionOp"); @@ -101,9 +104,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel { "input qkv_weight = [%s]", x_dim, y_dim)); - ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]}); - ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]}); - ctx->SetOutputDim("LnOut", ctx->GetInputDim("X")); + if (ctx->Attrs().Get("pre_layer_norm") == true) { + ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("LnOut", ctx->GetInputDim("X")); + } // [batch_size, seq_len, 3, num_head, head_size] ctx->SetOutputDim("QKVOut", {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); @@ -351,11 +356,11 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { ctx->GetInputDim("Ln2Bias")); } OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad"); - OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean", - "FusedAttentionGrad"); - OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance", - "FusedAttentionGrad"); if (ctx->Attrs().Get("pre_layer_norm") == true) { + OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean", + "FusedAttentionGrad"); + OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance", + "FusedAttentionGrad"); OP_INOUT_CHECK(ctx->HasInput("LnOut"), "Input", "LnOut", "FusedAttentionGrad"); } @@ -370,13 +375,15 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias", "FusedAttentionGrad"); - if (ctx->HasOutput(framework::GradVarName("LnScale"))) { - ctx->SetOutputDim(framework::GradVarName("LnScale"), - ctx->GetInputDim("LnScale")); - } - if (ctx->HasOutput(framework::GradVarName("LnBias"))) { - ctx->SetOutputDim(framework::GradVarName("LnBias"), - ctx->GetInputDim("LnBias")); + if (ctx->Attrs().Get("pre_layer_norm") == true) { + if (ctx->HasOutput(framework::GradVarName("LnScale"))) { + ctx->SetOutputDim(framework::GradVarName("LnScale"), + ctx->GetInputDim("LnScale")); + } + if (ctx->HasOutput(framework::GradVarName("LnBias"))) { + ctx->SetOutputDim(framework::GradVarName("LnBias"), + ctx->GetInputDim("LnBias")); + } } if (ctx->HasOutput(framework::GradVarName("X"))) { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); @@ -390,8 +397,10 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { ctx->SetOutputDim(framework::GradVarName("QKVBias"), ctx->GetInputDim("QKVBias")); - ctx->SetOutputDim(framework::GradVarName("LnOut"), - ctx->GetInputDim("LnOut")); + if (ctx->Attrs().Get("pre_layer_norm") == true) { + ctx->SetOutputDim(framework::GradVarName("LnOut"), + ctx->GetInputDim("LnOut")); + } ctx->SetOutputDim(framework::GradVarName("FMHAOut"), ctx->GetInputDim("FMHAOut")); ctx->SetOutputDim(framework::GradVarName("QKTVOut"), @@ -442,16 +451,23 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { op->SetInput("SrcMask", this->Input("SrcMask")); op->SetInput("OutLinearW", this->Input("OutLinearW")); op->SetInput("OutLinearBias", this->Input("OutLinearBias")); - if (this->HasInput("LnScale")) { - op->SetInput("LnScale", this->Input("LnScale")); - op->SetOutput(framework::GradVarName("LnScale"), - this->InputGrad("LnScale")); - } - if (this->HasInput("LnBias")) { - op->SetInput("LnBias", this->Input("LnBias")); - op->SetOutput(framework::GradVarName("LnBias"), - this->InputGrad("LnBias")); + + op->SetAttrMap(this->Attrs()); + bool is_pre_layer_norm = + BOOST_GET_CONST(bool, op->GetAttr("pre_layer_norm")); + if (is_pre_layer_norm) { + if (this->HasInput("LnScale")) { + op->SetInput("LnScale", this->Input("LnScale")); + op->SetOutput(framework::GradVarName("LnScale"), + this->InputGrad("LnScale")); + } + if (this->HasInput("LnBias")) { + op->SetInput("LnBias", this->Input("LnBias")); + op->SetOutput(framework::GradVarName("LnBias"), + this->InputGrad("LnBias")); + } } + if (this->HasInput("Ln2Scale")) { op->SetInput("Ln2Scale", this->Input("Ln2Scale")); op->SetOutput(framework::GradVarName("Ln2Scale"), @@ -473,9 +489,17 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { this->InputGrad("OutLinearW")); // use forward outputs as backward inputs. - op->SetInput("LnOut", this->Output("LnOut")); - op->SetInput("LnMean", this->Output("LnMean")); - op->SetInput("LnVariance", this->Output("LnVariance")); + if (is_pre_layer_norm) { + if (this->HasOutput("LnOut")) { + op->SetInput("LnOut", this->Output("LnOut")); + } + if (this->HasOutput("LnMean")) { + op->SetInput("LnMean", this->Output("LnMean")); + } + if (this->HasOutput("LnVariance")) { + op->SetInput("LnVariance", this->Output("LnVariance")); + } + } op->SetInput("QKVOut", this->Output("QKVOut")); op->SetInput("QKVBiasOut", this->Output("QKVBiasOut")); op->SetInput("TransposeOut2", this->Output("TransposeOut2")); @@ -496,7 +520,12 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { op->SetInput("QKVOut", this->Output("QKVOut")); // backward outputs: dinput - op->SetOutput(framework::GradVarName("LnOut"), this->OutputGrad("LnOut")); + if (is_pre_layer_norm) { + if (this->HasOutput("LnOut")) { + op->SetOutput(framework::GradVarName("LnOut"), + this->OutputGrad("LnOut")); + } + } op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut")); op->SetOutput(framework::GradVarName("QKVBiasOut"), this->OutputGrad("QKVBiasOut")); @@ -517,8 +546,6 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { this->OutputGrad("BiasDropoutResidualOut")); op->SetOutput(framework::GradVarName("OutLinearOut"), this->OutputGrad("OutLinearOut")); - - op->SetAttrMap(this->Attrs()); } }; diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index 95e690cb17ec1..01bc49bcf4079 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -97,9 +97,12 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *x_data = input_x->data(); auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data()); auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data()); - auto *ln_mean_data = ln_mean->mutable_data(ctx.GetPlace()); - auto *ln_var_data = ln_var->mutable_data(ctx.GetPlace()); - auto *ln_out_data = ln_out->mutable_data(ctx.GetPlace()); + auto *ln_mean_data = + pre_layer_norm ? ln_mean->mutable_data(ctx.GetPlace()) : nullptr; + auto *ln_var_data = + pre_layer_norm ? ln_var->mutable_data(ctx.GetPlace()) : nullptr; + auto *ln_out_data = + pre_layer_norm ? ln_out->mutable_data(ctx.GetPlace()) : nullptr; auto *qkv_weight_data = qkv_weight->data(); auto *qkv_bias_data = qkv_bias->data(); @@ -243,9 +246,6 @@ class FusedAttentionGradKernel : public framework::OpKernel { auto *out_linear_bias_data = out_linear_bias->data(); // fw output - auto *ln_mean = ctx.Input("LnMean"); - auto *ln_var = ctx.Input("LnVariance"); - auto *ln_out = ctx.Input("LnOut"); auto *fmha_out = ctx.Input("FMHAOut"); auto *transpose_out_2 = ctx.Input("TransposeOut2"); auto *qk_out = ctx.Input("QKOut"); @@ -260,9 +260,6 @@ class FusedAttentionGradKernel : public framework::OpKernel { auto *dropout_mask_out = ctx.Input("DropoutMaskOut"); auto *bias_dropout_residual_out = ctx.Input("BiasDropoutResidualOut"); - auto *ln_mean_data = ln_mean->data(); - auto *ln_var_data = ln_var->data(); - auto *ln_out_data = ln_out->data(); auto *fmha_out_data = fmha_out->data(); auto *transpose_out_2_data = transpose_out_2->data(); auto *qk_out_data = qk_out->data(); @@ -277,7 +274,6 @@ class FusedAttentionGradKernel : public framework::OpKernel { // output's grad auto *d_x = ctx.Output(framework::GradVarName("X")); - auto *d_ln_out = ctx.Output(framework::GradVarName("LnOut")); auto *d_qkv_out = ctx.Output(framework::GradVarName("QKVOut")); auto *d_qkv_bias_out = ctx.Output(framework::GradVarName("QKVBiasOut")); @@ -297,7 +293,6 @@ class FusedAttentionGradKernel : public framework::OpKernel { auto *d_bias_dropout_residual_out = ctx.Output(framework::GradVarName("BiasDropoutResidualOut")); auto *d_x_data = d_x->mutable_data(ctx.GetPlace()); - auto *d_ln_out_data = d_ln_out->mutable_data(ctx.GetPlace()); auto *d_qkv_out_data = d_qkv_out->mutable_data(ctx.GetPlace()); auto *d_qkv_bias_out_data = d_qkv_bias_out->mutable_data(ctx.GetPlace()); auto *d_qktv_out_data = d_qktv_out->mutable_data(ctx.GetPlace()); @@ -315,8 +310,6 @@ class FusedAttentionGradKernel : public framework::OpKernel { d_bias_dropout_residual_out->mutable_data(ctx.GetPlace()); // parameter grad - auto *d_ln_scale = ctx.Output(framework::GradVarName("LnScale")); - auto *d_ln_bias = ctx.Output(framework::GradVarName("LnBias")); auto *d_qkv_weight = ctx.Output(framework::GradVarName("QKVW")); auto *d_qkv_bias = ctx.Output(framework::GradVarName("QKVBias")); auto *d_out_linear_weight = @@ -325,12 +318,7 @@ class FusedAttentionGradKernel : public framework::OpKernel { ctx.Output(framework::GradVarName("OutLinearBias")); auto *d_ln_2_scale = ctx.Output(framework::GradVarName("Ln2Scale")); auto *d_ln_2_bias = ctx.Output(framework::GradVarName("Ln2Bias")); - auto *d_ln_scale_data = - (d_ln_scale == nullptr ? nullptr - : d_ln_scale->mutable_data(ctx.GetPlace())); - auto *d_ln_bias_data = - (d_ln_bias == nullptr ? nullptr - : d_ln_bias->mutable_data(ctx.GetPlace())); + auto *d_qkv_weight_data = d_qkv_weight->mutable_data(ctx.GetPlace()); auto *d_qkv_bias_data = d_qkv_bias->mutable_data(ctx.GetPlace()); auto *d_out_linear_weight_data = @@ -407,6 +395,24 @@ class FusedAttentionGradKernel : public framework::OpKernel { cudaMemcpyDeviceToDevice); if (pre_layer_norm) { + auto *ln_mean = ctx.Input("LnMean"); + auto *ln_var = ctx.Input("LnVariance"); + auto *ln_out = ctx.Input("LnOut"); + auto *ln_mean_data = ln_mean->data(); + auto *ln_var_data = ln_var->data(); + auto *ln_out_data = ln_out->data(); + + auto *d_ln_out = ctx.Output(framework::GradVarName("LnOut")); + auto *d_ln_scale = ctx.Output(framework::GradVarName("LnScale")); + auto *d_ln_bias = ctx.Output(framework::GradVarName("LnBias")); + auto *d_ln_out_data = d_ln_out->mutable_data(ctx.GetPlace()); + auto *d_ln_scale_data = + (d_ln_scale == nullptr ? nullptr + : d_ln_scale->mutable_data(ctx.GetPlace())); + auto *d_ln_bias_data = + (d_ln_bias == nullptr ? nullptr + : d_ln_bias->mutable_data(ctx.GetPlace())); + qkv_compute.ComputeBackward(ln_out_data, qkv_weight_data, d_qkv_bias_out_data, d_ln_out_data, d_qkv_weight_data, d_qkv_bias_data); diff --git a/paddle/fluid/operators/fused/fused_feedforward_op.cc b/paddle/fluid/operators/fused/fused_feedforward_op.cc index 4e03c7369d10e..7da790fc5c6e2 100644 --- a/paddle/fluid/operators/fused/fused_feedforward_op.cc +++ b/paddle/fluid/operators/fused/fused_feedforward_op.cc @@ -41,18 +41,8 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel { "fused_feedforward"); OP_INOUT_CHECK(context->HasOutput("Dropout2Mask"), "Output", "Dropout2Mask", "fused_feedforward"); - OP_INOUT_CHECK(context->HasOutput("Ln1Mean"), "Output", "Ln1Mean", - "fused_feedforward"); - OP_INOUT_CHECK(context->HasOutput("Ln1Variance"), "Output", "Ln1Variance", - "fused_feedforward"); - OP_INOUT_CHECK(context->HasOutput("Ln2Mean"), "Output", "Ln2Mean", - "fused_feedforward"); - OP_INOUT_CHECK(context->HasOutput("Ln2Variance"), "Output", "Ln2Variance", - "fused_feedforward"); OP_INOUT_CHECK(context->HasOutput("Linear1Out"), "Output", "Linear1Out", "fused_feedforward"); - OP_INOUT_CHECK(context->HasOutput("Ln1Out"), "Output", "Ln1Out", - "fused_feedforward"); OP_INOUT_CHECK(context->HasOutput("Dropout1Out"), "Output", "Dropout1Out", "fused_feedforward"); OP_INOUT_CHECK(context->HasOutput("Dropout2Out"), "Output", "Dropout2Out", @@ -76,7 +66,6 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel { } context->SetOutputDim("Dropout1Out", tmp_dim_x); context->SetOutputDim("Linear1Out", tmp_dim_x); - context->SetOutputDim("Ln1Out", dim_x); context->SetOutputDim("Dropout2Out", dim_x); if (context->Attrs().Get("dropout2_is_test") == false) { @@ -84,10 +73,25 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel { } framework::DDim mean_dim = framework::make_ddim({mat_dim_x.batch_size_ * mat_dim_x.height_}); - context->SetOutputDim("Ln1Mean", mean_dim); - context->SetOutputDim("Ln1Variance", mean_dim); - context->SetOutputDim("Ln2Mean", mean_dim); - context->SetOutputDim("Ln2Variance", mean_dim); + bool pre_layer_norm = context->Attrs().Get("pre_layer_norm"); + if (pre_layer_norm) { + OP_INOUT_CHECK(context->HasOutput("Ln1Mean"), "Output", "Ln1Mean", + "fused_feedforward"); + OP_INOUT_CHECK(context->HasOutput("Ln1Variance"), "Output", "Ln1Variance", + "fused_feedforward"); + OP_INOUT_CHECK(context->HasOutput("Ln1Out"), "Output", "Ln1Out", + "fused_feedforward"); + context->SetOutputDim("Ln1Out", dim_x); + context->SetOutputDim("Ln1Mean", mean_dim); + context->SetOutputDim("Ln1Variance", mean_dim); + } else { + OP_INOUT_CHECK(context->HasOutput("Ln2Mean"), "Output", "Ln2Mean", + "fused_feedforward"); + OP_INOUT_CHECK(context->HasOutput("Ln2Variance"), "Output", "Ln2Variance", + "fused_feedforward"); + context->SetOutputDim("Ln2Mean", mean_dim); + context->SetOutputDim("Ln2Variance", mean_dim); + } context->ShareLoD("X", "Out"); } @@ -218,14 +222,13 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(ctx->Attrs().Get("dropout2_is_test"), false, platform::errors::InvalidArgument( "GradOp is only callable when is_test is false")); + bool pre_layer_norm = ctx->Attrs().Get("pre_layer_norm"); OP_INOUT_CHECK(ctx->HasInput("Dropout1Mask"), "Input", "Dropout1Mask", "FusedFeedForwardGrad"); OP_INOUT_CHECK(ctx->HasInput("Dropout2Mask"), "Input", "Dropout1Mask", "FusedFeedForwardGrad"); OP_INOUT_CHECK(ctx->HasInput("Linear1Out"), "Input", "Linear1Out", "FusedFeedForwardGrad"); - OP_INOUT_CHECK(ctx->HasInput("Ln1Out"), "Input", "Ln1Out", - "FusedFeedForwardGrad"); OP_INOUT_CHECK(ctx->HasInput("Dropout1Out"), "Input", "Dropout1Out", "FusedFeedForwardGrad"); OP_INOUT_CHECK(ctx->HasInput("Dropout2Out"), "Input", "Dropout2Out", @@ -234,14 +237,19 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel { "FusedFeedForwardGrad"); OP_INOUT_CHECK(ctx->HasInput("Linear2Weight"), "Input", "Linear2Weight", "FusedFeedForwardGrad"); - OP_INOUT_CHECK(ctx->HasInput("Ln1Mean"), "Input", "Ln1Mean", - "FusedFeedForwardGrad"); - OP_INOUT_CHECK(ctx->HasInput("Ln1Variance"), "Input", "Ln1Variance", - "FusedFeedForwardGrad"); - OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean", - "FusedFeedForwardGrad"); - OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance", - "FusedFeedForwardGrad"); + if (pre_layer_norm) { + OP_INOUT_CHECK(ctx->HasInput("Ln1Mean"), "Input", "Ln1Mean", + "FusedFeedForwardGrad"); + OP_INOUT_CHECK(ctx->HasInput("Ln1Variance"), "Input", "Ln1Variance", + "FusedFeedForwardGrad"); + OP_INOUT_CHECK(ctx->HasInput("Ln1Out"), "Input", "Ln1Out", + "FusedFeedForwardGrad"); + } else { + OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean", + "FusedFeedForwardGrad"); + OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance", + "FusedFeedForwardGrad"); + } OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", framework::GradVarName("Out"), "FusedFeedForwardGrad"); @@ -299,30 +307,36 @@ class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker { op->SetInput("Linear1Weight", this->Input("Linear1Weight")); op->SetInput("Linear1Bias", this->Input("Linear1Bias")); op->SetInput("Linear2Weight", this->Input("Linear2Weight")); - op->SetInput("Ln1Scale", this->Input("Ln1Scale")); - op->SetInput("Ln1Bias", this->Input("Ln1Bias")); - op->SetInput("Ln2Scale", this->Input("Ln2Scale")); - op->SetInput("Ln2Bias", this->Input("Ln2Bias")); op->SetInput("Dropout1Mask", this->Output("Dropout1Mask")); op->SetInput("Dropout2Mask", this->Output("Dropout2Mask")); op->SetInput("Linear1Out", this->Output("Linear1Out")); - op->SetInput("Ln1Out", this->Output("Ln1Out")); - op->SetInput("Ln1Mean", this->Output("Ln1Mean")); - op->SetInput("Ln1Variance", this->Output("Ln1Variance")); - op->SetInput("Ln2Mean", this->Output("Ln2Mean")); - op->SetInput("Ln2Variance", this->Output("Ln2Variance")); op->SetInput("Dropout1Out", this->Output("Dropout1Out")); op->SetInput("Dropout2Out", this->Output("Dropout2Out")); + op->SetAttrMap(this->Attrs()); + bool pre_layer_norm = BOOST_GET_CONST(bool, op->GetAttr("pre_layer_norm")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - op->SetOutput(framework::GradVarName("Ln1Scale"), - this->InputGrad("Ln1Scale")); - op->SetOutput(framework::GradVarName("Ln1Bias"), - this->InputGrad("Ln1Bias")); - op->SetOutput(framework::GradVarName("Ln2Scale"), - this->InputGrad("Ln2Scale")); - op->SetOutput(framework::GradVarName("Ln2Bias"), - this->InputGrad("Ln2Bias")); + if (pre_layer_norm) { + op->SetInput("Ln1Scale", this->Input("Ln1Scale")); + op->SetInput("Ln1Bias", this->Input("Ln1Bias")); + op->SetInput("Ln1Out", this->Output("Ln1Out")); + op->SetInput("Ln1Mean", this->Output("Ln1Mean")); + op->SetInput("Ln1Variance", this->Output("Ln1Variance")); + op->SetOutput(framework::GradVarName("Ln1Scale"), + this->InputGrad("Ln1Scale")); + op->SetOutput(framework::GradVarName("Ln1Bias"), + this->InputGrad("Ln1Bias")); + } else { + op->SetInput("Ln2Scale", this->Input("Ln2Scale")); + op->SetInput("Ln2Bias", this->Input("Ln2Bias")); + op->SetInput("Ln2Mean", this->Output("Ln2Mean")); + op->SetInput("Ln2Variance", this->Output("Ln2Variance")); + op->SetOutput(framework::GradVarName("Ln2Scale"), + this->InputGrad("Ln2Scale")); + op->SetOutput(framework::GradVarName("Ln2Bias"), + this->InputGrad("Ln2Bias")); + } op->SetOutput(framework::GradVarName("Linear1Weight"), this->InputGrad("Linear1Weight")); op->SetOutput(framework::GradVarName("Linear1Bias"), @@ -334,8 +348,6 @@ class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker { op->SetOutput(framework::GradVarName("Linear2Bias"), this->InputGrad("Linear2Bias")); } - - op->SetAttrMap(this->Attrs()); } }; diff --git a/paddle/fluid/operators/fused/fused_feedforward_op.cu b/paddle/fluid/operators/fused/fused_feedforward_op.cu index 61a8a9a82f2e0..3b47e65c4833d 100644 --- a/paddle/fluid/operators/fused/fused_feedforward_op.cu +++ b/paddle/fluid/operators/fused/fused_feedforward_op.cu @@ -113,26 +113,40 @@ class FusedFeedForwardKernel : public framework::OpKernel { auto* linear1_bias = context.Input("Linear1Bias"); auto* linear2_weight = context.Input("Linear2Weight"); auto* linear2_bias = context.Input("Linear2Bias"); - auto* ln1_scale = context.Input("Ln1Scale"); - auto* ln1_bias = context.Input("Ln1Bias"); - auto* ln2_scale = context.Input("Ln2Scale"); - auto* ln2_bias = context.Input("Ln2Bias"); - - auto* ln1_mean = context.Output("Ln1Mean"); - auto* ln1_variance = context.Output("Ln1Variance"); - auto* ln2_mean = context.Output("Ln2Mean"); - auto* ln2_variance = context.Output("Ln2Variance"); + const bool pre_layer_norm = context.Attr("pre_layer_norm"); + + auto* ln1_scale = + pre_layer_norm ? context.Input("Ln1Scale") : nullptr; + auto* ln1_bias = + pre_layer_norm ? context.Input("Ln1Bias") : nullptr; + auto* ln2_scale = !pre_layer_norm + ? context.Input("Ln2Scale") + : nullptr; + auto* ln2_bias = + !pre_layer_norm ? context.Input("Ln2Bias") : nullptr; + + auto* ln1_mean = + pre_layer_norm ? context.Output("Ln1Mean") : nullptr; + auto* ln1_variance = pre_layer_norm + ? context.Output("Ln1Variance") + : nullptr; + auto* ln2_mean = !pre_layer_norm + ? context.Output("Ln2Mean") + : nullptr; + auto* ln2_variance = !pre_layer_norm + ? context.Output("Ln2Variance") + : nullptr; auto* out = context.Output("Out"); auto* dropout1_mask = context.Output("Dropout1Mask"); auto* dropout2_mask = context.Output("Dropout2Mask"); auto* linear1_out = context.Output("Linear1Out"); - auto* ln1_out = context.Output("Ln1Out"); + auto* ln1_out = + pre_layer_norm ? context.Output("Ln1Out") : nullptr; auto* dropout1_out = context.Output("Dropout1Out"); auto* dropout2_out = context.Output("Dropout2Out"); const std::string act_method = context.Attr("act_method"); - const bool pre_layer_norm = context.Attr("pre_layer_norm"); const float epsilon1 = context.Attr("ln1_epsilon"); const float epsilon2 = context.Attr("ln2_epsilon"); @@ -144,12 +158,16 @@ class FusedFeedForwardKernel : public framework::OpKernel { out->mutable_data(place); dropout1_mask->mutable_data(place); dropout2_mask->mutable_data(place); - ln1_mean->mutable_data(place); - ln1_variance->mutable_data(place); - ln2_mean->mutable_data(place); - ln2_variance->mutable_data(place); + if (pre_layer_norm) { + ln1_mean->mutable_data(place); + ln1_variance->mutable_data(place); + ln1_out->mutable_data(place); + } else { + ln2_mean->mutable_data(place); + ln2_variance->mutable_data(place); + } + linear1_out->mutable_data(place); - ln1_out->mutable_data(place); dropout1_out->mutable_data(place); dropout2_out->mutable_data(place); @@ -193,16 +211,16 @@ class FusedFeedForwardGradKernel : public framework::OpKernel { const framework::Tensor& d_out, const framework::Tensor& x, const framework::Tensor& dropout1_mask, const framework::Tensor& dropout2_mask, - const framework::Tensor& linear1_out, const framework::Tensor& ln1_out, + const framework::Tensor& linear1_out, const framework::Tensor* ln1_out, const framework::Tensor& dropout1_out, const framework::Tensor& dropout2_out, const framework::Tensor& linear1_weight, const framework::Tensor* linear1_bias, const framework::Tensor& linear2_weight, const framework::Tensor* ln1_gamma, const framework::Tensor* ln1_beta, - const framework::Tensor& ln1_mean, const framework::Tensor& ln1_variance, + const framework::Tensor* ln1_mean, const framework::Tensor* ln1_variance, const framework::Tensor* ln2_gamma, const framework::Tensor* ln2_beta, - const framework::Tensor& ln2_mean, const framework::Tensor& ln2_variance, + const framework::Tensor* ln2_mean, const framework::Tensor* ln2_variance, framework::Tensor* d_x, framework::Tensor* d_linear1_weight, framework::Tensor* d_linear1_bias, framework::Tensor* d_linear2_weight, framework::Tensor* d_linear2_bias, framework::Tensor* d_ln1_gamma, @@ -252,8 +270,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel { } else { fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad( ctx, d_out.data(), dropout2_out.data(), - dropout2_mask.data(), ln2_gamma_ptr, ln2_mean.data(), - ln2_variance.data(), d_dropout2_out.data(), d_ln2_gamma_ptr, + dropout2_mask.data(), ln2_gamma_ptr, ln2_mean->data(), + ln2_variance->data(), d_dropout2_out.data(), d_ln2_gamma_ptr, d_ln2_beta_ptr, d_linear2_out.data(), d_linear2_bias_ptr, d_residual.data()); } @@ -273,13 +291,13 @@ class FusedFeedForwardGradKernel : public framework::OpKernel { if (pre_layer_norm) { framework::Tensor d_ln1_out; d_ln1_out.mutable_data({bsz_seq, d_model}, place); - MatMulGrad(ctx, d_linear1_out, ln1_out, linear1_weight, &d_ln1_out, + MatMulGrad(ctx, d_linear1_out, *ln1_out, linear1_weight, &d_ln1_out, d_linear1_weight); - pre_layernorm_helper.LayerNormGrad(ctx, d_ln1_out.data(), x.data(), - ln1_gamma_ptr, ln1_mean.data(), - ln1_variance.data(), d_x->data(), - d_ln1_gamma_ptr, d_ln1_beta_ptr); + pre_layernorm_helper.LayerNormGrad( + ctx, d_ln1_out.data(), x.data(), ln1_gamma_ptr, + ln1_mean->data(), ln1_variance->data(), d_x->data(), + d_ln1_gamma_ptr, d_ln1_beta_ptr); } else { MatMulGrad(ctx, d_linear1_out, x, linear1_weight, d_x, d_linear1_weight); } @@ -290,33 +308,52 @@ class FusedFeedForwardGradKernel : public framework::OpKernel { auto d_out = *context.Input(framework::GradVarName("Out")); auto x = *context.Input("X"); + const bool pre_layer_norm = context.Attr("pre_layer_norm"); auto dropout1_mask = *context.Input("Dropout1Mask"); auto dropout2_mask = *context.Input("Dropout2Mask"); auto linear1_out = *context.Input("Linear1Out"); - auto ln1_out = *context.Input("Ln1Out"); + auto* ln1_out = + pre_layer_norm ? context.Input("Ln1Out") : nullptr; auto dropout1_out = *context.Input("Dropout1Out"); auto dropout2_out = *context.Input("Dropout2Out"); auto linear1_weight = *context.Input("Linear1Weight"); auto* linear1_bias = context.Input("Linear1Bias"); auto linear2_weight = *context.Input("Linear2Weight"); - auto ln1_mean = *context.Input("Ln1Mean"); - auto ln1_variance = *context.Input("Ln1Variance"); - auto* ln1_scale = context.Input("Ln1Scale"); - auto* ln1_bias = context.Input("Ln1Bias"); - auto ln2_mean = *context.Input("Ln2Mean"); - auto ln2_variance = *context.Input("Ln2Variance"); - auto* ln2_scale = context.Input("Ln2Scale"); - auto* ln2_bias = context.Input("Ln2Bias"); + auto* ln1_mean = + pre_layer_norm ? context.Input("Ln1Mean") : nullptr; + auto* ln1_variance = pre_layer_norm + ? context.Input("Ln1Variance") + : nullptr; + auto* ln1_scale = + pre_layer_norm ? context.Input("Ln1Scale") : nullptr; + auto* ln1_bias = + pre_layer_norm ? context.Input("Ln1Bias") : nullptr; + auto* ln2_mean = + !pre_layer_norm ? context.Input("Ln2Mean") : nullptr; + auto* ln2_variance = !pre_layer_norm + ? context.Input("Ln2Variance") + : nullptr; + auto* ln2_scale = !pre_layer_norm + ? context.Input("Ln2Scale") + : nullptr; + auto* ln2_bias = + !pre_layer_norm ? context.Input("Ln2Bias") : nullptr; auto* d_x = context.Output(framework::GradVarName("X")); - auto* d_ln1_scale = - context.Output(framework::GradVarName("Ln1Scale")); - auto* d_ln1_bias = - context.Output(framework::GradVarName("Ln1Bias")); + auto* d_ln1_scale = pre_layer_norm + ? context.Output( + framework::GradVarName("Ln1Scale")) + : nullptr; + auto* d_ln1_bias = pre_layer_norm + ? context.Output( + framework::GradVarName("Ln1Bias")) + : nullptr; auto* d_ln2_scale = - context.Output(framework::GradVarName("Ln2Scale")); + pre_layer_norm ? nullptr : context.Output( + framework::GradVarName("Ln2Scale")); auto* d_ln2_bias = - context.Output(framework::GradVarName("Ln2Bias")); + pre_layer_norm ? nullptr : context.Output( + framework::GradVarName("Ln2Bias")); auto* d_linear1_weight = context.Output( framework::GradVarName("Linear1Weight")); auto* d_linear1_bias = context.Output( @@ -328,7 +365,6 @@ class FusedFeedForwardGradKernel : public framework::OpKernel { const float epsilon1 = context.Attr("ln1_epsilon"); const float epsilon2 = context.Attr("ln2_epsilon"); - const bool pre_layer_norm = context.Attr("pre_layer_norm"); const std::string act_method = context.Attr("act_method"); DropoutParam dropout_param1(context, 1); DropoutParam dropout_param2(context, 2); diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index 7359adff62021..c33e1f53dfdb6 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -65,7 +65,7 @@ def setUp(self): def config(self): self.x_type = np.float32 self.attn_mask_type = np.float64 - self.pre_layer_norm = True + self.pre_layer_norm = False self.training = True self.batch_size = 8 @@ -213,11 +213,40 @@ def test_fused_attention_op(self): x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-5) +class TestFusedAttentionOpPreLn(TestFusedAttentionOp): + def config(self): + self.x_type = np.float32 + self.attn_mask_type = np.float64 + self.pre_layer_norm = True + self.training = True + + self.batch_size = 8 + self.query_length = 128 + self.head_dim = 64 + self.num_heads = 16 + self.embed_dim = self.head_dim * self.num_heads + + self.dropout_prob = 0.0 + self.attn_dropout_prob = 0.0 + self.weight_attr = None + self.bias_attr = None + self.kdim, self.vdim = self.embed_dim, self.embed_dim + self.key_length, self.value_length = self.query_length, self.query_length + + def test_fused_attention_op(self): + final_out_ref, x_grad_ref = self.GetBaselineOut() + final_out, x_grad = self.GetFusedAttentionOut() + np.testing.assert_allclose( + final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1) + np.testing.assert_allclose( + x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1) + + class TestFusedAttentionOpFp16(TestFusedAttentionOp): def config(self): self.x_type = np.float16 self.attn_mask_type = np.float64 - self.pre_layer_norm = True + self.pre_layer_norm = False self.training = True self.batch_size = 8