Skip to content

Commit ade2f10

Browse files
committed
fix
1 parent 9c507ff commit ade2f10

File tree

4 files changed

+9
-30
lines changed

4 files changed

+9
-30
lines changed

paddle/phi/kernels/impl/unfold_grad_kernel_impl.h

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#include <vector>
1818

1919
#include "paddle/phi/core/dense_tensor.h"
20-
#include "paddle/phi/kernels/full_kernel.h"
2120
#include "paddle/phi/kernels/funcs/im2col.h"
2221
#include "paddle/phi/kernels/funcs/math_function.h"
2322
#include "paddle/phi/kernels/funcs/unfold_functor.h"
@@ -33,14 +32,8 @@ void UnfoldGradKernel(const Context& dev_ctx,
3332
const std::vector<int>& paddings,
3433
const std::vector<int>& dilations,
3534
DenseTensor* x_grad) {
36-
if (out_grad.numel() == 0) {
37-
if (x_grad) {
38-
dev_ctx.template Alloc<T>(x_grad);
39-
phi::Full<T, Context>(
40-
dev_ctx, phi::IntArray(common::vectorize(x.dims())), 0, x_grad);
41-
}
42-
return;
43-
}
35+
dev_ctx.template Alloc<T>(x_grad);
36+
if (x_grad->numel() == 0) return;
4437

4538
if (!x_grad) return;
4639
const auto& x_dims = x_grad->dims();

paddle/phi/kernels/impl/unfold_kernel_impl.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#include <vector>
1818

1919
#include "paddle/phi/core/dense_tensor.h"
20-
#include "paddle/phi/kernels/full_kernel.h"
2120
#include "paddle/phi/kernels/funcs/im2col.h"
2221
#include "paddle/phi/kernels/funcs/math_function.h"
2322
#include "paddle/phi/kernels/funcs/unfold_functor.h"
@@ -33,11 +32,8 @@ void UnfoldKernel(const Context& dev_ctx,
3332
const std::vector<int>& dilations,
3433
DenseTensor* out) {
3534
const int batch_size = static_cast<int>(x.dims()[0]);
36-
if (x && x.numel() == 0) {
37-
phi::Full<T, Context>(
38-
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
39-
return;
40-
}
35+
dev_ctx.template Alloc<T>(out);
36+
if (out->numel() == 0) return;
4137
phi::funcs::Im2ColFunctor<phi::funcs::ColFormat::kCFO, Context, T> im2col;
4238
const auto& x_dims = x.dims();
4339

paddle/phi/kernels/xpu/unfold_grad_kernel.cc

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
#include "paddle/phi/backends/xpu/enforce_xpu.h"
1818
#include "paddle/phi/core/kernel_registry.h"
19-
#include "paddle/phi/kernels/full_kernel.h"
2019
#include "paddle/phi/kernels/funcs/unfold_functor.h"
2120

2221
namespace phi {
@@ -31,14 +30,9 @@ void UnfoldGradKernel(const Context& dev_ctx,
3130
const std::vector<int>& dilations_,
3231
DenseTensor* x_grad) {
3332
using XPUType = typename XPUTypeTrait<T>::Type;
34-
if (out_grad.numel() == 0) {
35-
if (x_grad) {
36-
dev_ctx.template Alloc<T>(x_grad);
37-
phi::Full<T, Context>(
38-
dev_ctx, phi::IntArray(common::vectorize(x.dims())), 0, x_grad);
39-
}
40-
return;
41-
}
33+
dev_ctx.template Alloc<T>(x_grad);
34+
if (x_grad->numel() == 0) return;
35+
4236
const std::string data_format = common::DataLayoutToString(x.layout());
4337
bool is_nchw = data_format == "NCHW";
4438
PADDLE_ENFORCE_EQ(is_nchw,

paddle/phi/kernels/xpu/unfold_kernel.cc

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
#include "paddle/phi/backends/xpu/enforce_xpu.h"
1818
#include "paddle/phi/core/kernel_registry.h"
19-
#include "paddle/phi/kernels/full_kernel.h"
2019
#include "paddle/phi/kernels/funcs/unfold_functor.h"
2120

2221
namespace phi {
@@ -30,11 +29,8 @@ void UnfoldKernel(const Context& dev_ctx,
3029
const std::vector<int>& dilations_,
3130
DenseTensor* out) {
3231
using XPUType = typename XPUTypeTrait<T>::Type;
33-
if (x.numel() == 0) {
34-
phi::Full<T, Context>(
35-
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
36-
return;
37-
}
32+
dev_ctx.template Alloc<T>(out);
33+
if (out->numel() == 0) return;
3834
const std::string data_format = common::DataLayoutToString(x.layout());
3935
bool is_nchw = data_format == "NCHW";
4036
PADDLE_ENFORCE_EQ(is_nchw,

0 commit comments

Comments
 (0)