Skip to content

Commit bccc063

Browse files
committed
Fix
1 parent 3efb8db commit bccc063

File tree

5 files changed

+19
-8
lines changed

5 files changed

+19
-8
lines changed

paddle/phi/kernels/cpu/p_norm_grad_kernel.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ void PNormGradKernel(const Context& dev_ctx,
5959
auto* in_norm_dy = &out_grad;
6060
auto* out_dx = x_grad;
6161
dev_ctx.template Alloc<T>(out_dx);
62+
if (x.numel() == 0) return;
6263

6364
T eps = static_cast<T>(epsilon);
6465
auto xdim = in_x->dims();

paddle/phi/kernels/cpu/p_norm_kernel.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,11 @@ void PNormKernel(const Context& dev_ctx,
6363
GetDims(xdim, axis, &pre, &n, &post, asvector);
6464

6565
if (x.numel() == 0) {
66-
if (out->numel() > 0) {
67-
std::vector<int64_t> vec_dims = common::vectorize(out->dims());
68-
phi::Full<T, Context>(
69-
dev_ctx, phi::IntArray(vec_dims), static_cast<T>(0), out);
66+
if (out) {
67+
phi::Full<T, Context>(dev_ctx,
68+
phi::IntArray(common::vectorize(out->dims())),
69+
static_cast<T>(0),
70+
out);
7071
}
7172
return;
7273
}

paddle/phi/kernels/gpu/p_norm_grad_kernel.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ void PNormGradKernel(const Context& dev_ctx,
8383
auto* in_norm_dy = &out_grad;
8484
auto* out_dx = x_grad;
8585
dev_ctx.template Alloc<T>(out_dx);
86+
if (x.numel() == 0) return;
8687

8788
auto xdim = in_x->dims();
8889
bool reduce_all = (in_norm->numel() == 1);

paddle/phi/kernels/gpu/p_norm_kernel.cu

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,11 @@ void PNormKernel(const Context& dev_ctx,
9595
funcs::details::GetReduceDim(axis_dims, xdim.size(), asvector);
9696

9797
if (x.numel() == 0) {
98-
if (out->numel() > 0) {
99-
std::vector<int64_t> vec_dims = common::vectorize(out->dims());
100-
phi::Full<T, Context>(
101-
dev_ctx, phi::IntArray(vec_dims), static_cast<T>(0), out);
98+
if (out) {
99+
phi::Full<T, Context>(dev_ctx,
100+
phi::IntArray(common::vectorize(out->dims())),
101+
static_cast<T>(0),
102+
out);
102103
}
103104
return;
104105
}

test/legacy_test/test_norm_op.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,13 @@ def test_norm_x_type():
212212
self.assertRaises(TypeError, test_norm_x_type)
213213

214214

215+
class TestNormOp_ZeroSize(TestNormOp):
216+
def init_test_case(self):
217+
self.shape = [5, 3, 0, 7]
218+
self.axis = 0
219+
self.epsilon = 1e-8
220+
221+
215222
if __name__ == '__main__':
216223
paddle.enable_static()
217224
unittest.main()

0 commit comments

Comments
 (0)