diff --git a/paddle/phi/kernels/gpu/masked_fill_grad_kernel.cu b/paddle/phi/kernels/gpu/masked_fill_grad_kernel.cu index e54d46e0115bb3..2034b339a0b775 100644 --- a/paddle/phi/kernels/gpu/masked_fill_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/masked_fill_grad_kernel.cu @@ -275,14 +275,24 @@ void GPUMaskedFillGrad(const phi::GPUContext& dev_ctx, config); if (value_grad) { DenseTensor zero_tensor; - FullLikeKernel( - dev_ctx, out_grad, Scalar(T(0.0)), out_grad.dtype(), &zero_tensor); + phi::Full( + dev_ctx, + phi::IntArray(common::vectorize(out_grad.dims())), + T(0.0), + &zero_tensor); DenseTensor value_grad_tensor; value_grad_tensor.set_meta(out_grad.meta()); WhereKernel( dev_ctx, mask, out_grad, zero_tensor, &value_grad_tensor); - SumKernel( - dev_ctx, value_grad_tensor, {1}, out_grad.dtype(), false, value_grad); + std::vector v_dims(value_grad_tensor.dims().size()); + std::iota(v_dims.begin(), v_dims.end(), 0); + IntArray v_axis(v_dims); + SumKernel(dev_ctx, + value_grad_tensor, + v_axis, + value_grad->dtype(), + false, + value_grad); } } else {