From 6580ad1611346fe40ff4fdac8d5791a0850f2cfd Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Fri, 22 Oct 2021 11:21:25 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Bug=20Fixes=E3=80=91Elementwise=5Fadd?= =?UTF-8?q?=20triple=20grad,=20fixed=20an=20input=20uninitialized=20proble?= =?UTF-8?q?m=20(#36618)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Support elementwise_add triple grad Kernel * Change code-format to follow CI std * Removed unreasonable code, and fixed an input uninitialized issue * Support elementwise_add triple grad Kernel * Change code-format to follow CI std * Removed unreasonable code, and fixed an input uninitialized issue --- paddle/fluid/operators/elementwise/elementwise_op.h | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index 5703e904c240b..13e4624ef717f 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -445,18 +445,7 @@ class ElementwiseOpTripleGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { framework::proto::VarType::Type input_data_type; - if (ctx.HasInput("DDX") == false) { - OP_INOUT_CHECK(ctx.HasInput("DDY"), "Input", "DDY", - "ElementwiseOpTripleGrad"); - input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDY"); - } else if (ctx.HasInput("DDY") == false) { - OP_INOUT_CHECK(ctx.HasInput("DDX"), "Input", "DDX", - "ElementwiseOpTripleGrad"); - input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX"); - } else { - input_data_type = - OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "DDX", "DDY"); - } + input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "D_DDOut"); #ifdef PADDLE_WITH_MKLDNN if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {