Skip to content

Commit aaf077d

Browse files
authored
[0-size Tensor Job2 No.48] Add 0-size Tensor support for paddle.median (#73477)
* Fix * Fix
1 parent 993b1ed commit aaf077d

File tree

5 files changed

+91
-27
lines changed

5 files changed

+91
-27
lines changed

paddle/phi/kernels/cpu/top_k_kernel.cc

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "paddle/phi/backends/cpu/cpu_context.h"
1818
#include "paddle/phi/core/kernel_registry.h"
19+
#include "paddle/phi/kernels/full_kernel.h"
1920
#include "paddle/phi/kernels/funcs/eigen/common.h"
2021
#include "paddle/phi/kernels/funcs/math_function.h"
2122

@@ -159,19 +160,26 @@ void TopkKernel(const Context& dev_ctx,
159160
}
160161

161162
int k = k_scalar.to<int>();
162-
PADDLE_ENFORCE_GE(
163-
x.numel(),
164-
k,
165-
errors::InvalidArgument(
166-
"x has only %d element, can not find %d top values.", x.numel(), k));
167-
163+
// out shape [-1]
168164
if (k_scalar.FromTensor()) {
169165
auto out_dims = out->dims();
170166
// according to axis to set K value in the dim
171167
out_dims[axis] = k;
172168
out->Resize(out_dims);
173169
indices->Resize(out_dims);
174170
}
171+
if (x.numel() == 0) {
172+
phi::Full<T, Context>(
173+
dev_ctx, phi::IntArray(common::vectorize(out->dims())), NAN, out);
174+
phi::Full<int64_t, Context>(
175+
dev_ctx, phi::IntArray(common::vectorize(indices->dims())), 0, indices);
176+
return;
177+
}
178+
PADDLE_ENFORCE_GE(
179+
x.numel(),
180+
k,
181+
errors::InvalidArgument(
182+
"x has only %d element, can not find %d top values.", x.numel(), k));
175183

176184
T* out_data = dev_ctx.template Alloc<T>(out);
177185
int64_t* indices_data = dev_ctx.template Alloc<int64_t>(indices);

paddle/phi/kernels/gpu/top_k_kernel.cu

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
#include "paddle/phi/common/bfloat16.h"
2121
#include "paddle/phi/core/kernel_registry.h"
2222
#include "paddle/phi/core/tensor_utils.h"
23+
#include "paddle/phi/kernels/full_kernel.h"
2324
#include "paddle/phi/kernels/funcs/gather.cu.h"
2425
#include "paddle/phi/kernels/funcs/math_function.h"
2526
#include "paddle/phi/kernels/funcs/top_k_function_cuda.h"
26-
2727
namespace phi {
2828

2929
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
@@ -83,17 +83,25 @@ void TopkKernel(const Context& dev_ctx,
8383
if (axis < 0) axis += in_dims.size();
8484

8585
int k = k_scalar.to<int>();
86-
PADDLE_ENFORCE_GE(
87-
x.numel(),
88-
k,
89-
errors::InvalidArgument(
90-
"x has only %d element, can not find %d top values.", x.numel(), k));
86+
// out shape [-1]
9187
if (k_scalar.FromTensor()) {
9288
phi::DDim out_dims = out->dims();
9389
out_dims[axis] = k;
9490
out->Resize(out_dims);
9591
indices->Resize(out_dims);
9692
}
93+
if (x.numel() == 0) {
94+
phi::Full<T, Context>(
95+
dev_ctx, phi::IntArray(common::vectorize(out->dims())), NAN, out);
96+
phi::Full<int64_t, Context>(
97+
dev_ctx, phi::IntArray(common::vectorize(indices->dims())), 0, indices);
98+
return;
99+
}
100+
PADDLE_ENFORCE_GE(
101+
x.numel(),
102+
k,
103+
errors::InvalidArgument(
104+
"x has only %d element, can not find %d top values.", x.numel(), k));
97105

98106
const auto& out_dims = out->dims();
99107

paddle/phi/kernels/xpu/top_k_kernel.cc

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "paddle/phi/backends/xpu/enforce_xpu.h"
1818
#include "paddle/phi/core/kernel_registry.h"
19+
#include "paddle/phi/kernels/full_kernel.h"
1920
#include "paddle/phi/kernels/funcs/math_function.h"
2021
#include "paddle/phi/kernels/xpu/xpu_mem_util.h"
2122
namespace phi {
@@ -52,19 +53,26 @@ void TopkKernel(const Context& dev_ctx,
5253
}
5354

5455
int64_t k = k_scalar.to<int64_t>();
55-
PADDLE_ENFORCE_GE(
56-
x.numel(),
57-
k,
58-
errors::InvalidArgument(
59-
"x has only %d element, can not find %d top values.", x.numel(), k));
60-
56+
// out shape [-1]
6157
if (k_scalar.FromTensor()) {
6258
auto out_dims_ = out->dims();
6359
// according to axis to set K value in the dim
6460
out_dims_[axis] = k;
6561
out->Resize(out_dims_);
6662
indices->Resize(out_dims_);
6763
}
64+
if (x.numel() == 0) {
65+
phi::Full<T, Context>(
66+
dev_ctx, phi::IntArray(common::vectorize(out->dims())), NAN, out);
67+
phi::Full<int64_t, Context>(
68+
dev_ctx, phi::IntArray(common::vectorize(indices->dims())), 0, indices);
69+
return;
70+
}
71+
PADDLE_ENFORCE_GE(
72+
x.numel(),
73+
k,
74+
errors::InvalidArgument(
75+
"x has only %d element, can not find %d top values.", x.numel(), k));
6876

6977
const T* in_data = x.data<T>();
7078
int64_t* indices_data = dev_ctx.template Alloc<int64_t>(indices);

python/paddle/tensor/stat.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -598,10 +598,6 @@ def median(
598598
if not isinstance(x, (Variable, paddle.pir.Value)):
599599
raise TypeError("In median, the input x should be a Tensor.")
600600

601-
if in_dynamic_mode() and x.size == 0:
602-
# TODO: Currently, `__eq__` don't support arguments (`pir.Value` & `int`)
603-
raise ValueError("In median, the size of input x should not be 0.")
604-
605601
is_flatten = False
606602
dims = len(x.shape)
607603
if dims == 0:
@@ -658,7 +654,7 @@ def median(
658654
keepdim=True,
659655
)
660656
else: # mode == 'min'
661-
if sz & 1 == 0:
657+
if sz & 1 == 0 and kth != 0:
662658
out_tensor = paddle.slice(
663659
tensor_topk, axes=[axis], starts=[kth - 1], ends=[kth]
664660
)

test/legacy_test/test_median.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def np_medain_min(data, keepdims=False):
5050
return np_res + np.sum(np.isnan(data).astype(data.dtype) * data)
5151

5252

53-
def np_medain_min_axis(data, axis=None, keepdims=False):
53+
def np_median_min_axis(data, axis=None, keepdims=False):
5454
data = copy.deepcopy(data)
5555
if axis is None:
5656
return np_medain_min(data, keepdims)
@@ -232,7 +232,7 @@ class TestMedianMin(unittest.TestCase):
232232
def static_single_test_median(self, lis_test):
233233
paddle.enable_static()
234234
x, axis, keepdims = lis_test
235-
res_np = np_medain_min_axis(x, axis=axis, keepdims=keepdims)
235+
res_np = np_median_min_axis(x, axis=axis, keepdims=keepdims)
236236
main_program = paddle.static.Program()
237237
startup_program = paddle.static.Program()
238238
exe = paddle.static.Executor()
@@ -245,7 +245,7 @@ def static_single_test_median(self, lis_test):
245245

246246
def dygraph_single_test_median(self, lis_test):
247247
x, axis, keepdims = lis_test
248-
res_np = np_medain_min_axis(x, axis=axis, keepdims=keepdims)
248+
res_np = np_median_min_axis(x, axis=axis, keepdims=keepdims)
249249
if axis is None:
250250
res_pd = paddle.median(
251251
paddle.to_tensor(x), axis, keepdims, mode='min'
@@ -335,7 +335,7 @@ def test_float16(self):
335335
for keepdims in [False, True]
336336
]
337337
for axis, keepdims in lis_tests:
338-
res_np = np_medain_min_axis(x, axis=axis, keepdims=keepdims)
338+
res_np = np_median_min_axis(x, axis=axis, keepdims=keepdims)
339339
if axis is None:
340340
res_pd = paddle.median(
341341
paddle.to_tensor(x), axis, keepdims, mode='min'
@@ -357,5 +357,49 @@ def test_output_dtype(self):
357357
np.testing.assert_equal(res.numpy().dtype, np.dtype(inp_dtype))
358358

359359

360+
class TestMedianAvg_ZeroSize(unittest.TestCase):
361+
def dygraph_single_test_median(self, lis_test):
362+
x, axis, keepdims = lis_test
363+
res_np = np.median(x, axis=axis, keepdims=keepdims)
364+
x_pd = paddle.to_tensor(x)
365+
x_pd.stop_gradient = False
366+
res_pd = paddle.median(x_pd, axis, keepdims)
367+
np.testing.assert_allclose(res_pd.numpy(), res_np)
368+
paddle.sum(res_pd).backward()
369+
np.testing.assert_allclose(x_pd.grad.shape, x_pd.shape)
370+
371+
def test_median_dygraph(self):
372+
paddle.disable_static()
373+
h = 0
374+
w = 4
375+
l = 2
376+
x = np.arange(h * w * l).reshape([h, w, l])
377+
self.dygraph_single_test_median([x, 1, False])
378+
379+
380+
class TestMedianMin_ZeroSize(unittest.TestCase):
381+
382+
def dygraph_single_test_median(self, lis_test):
383+
x, axis, keepdims = lis_test
384+
res_np = np_median_min_axis(x, axis=axis, keepdims=keepdims)
385+
x_pd = paddle.to_tensor(x)
386+
x_pd.stop_gradient = False
387+
if axis is None:
388+
res_pd = paddle.median(x_pd, axis, keepdims, mode='min')
389+
else:
390+
res_pd, _ = paddle.median(x_pd, axis, keepdims, mode='min')
391+
np.testing.assert_allclose(res_pd.numpy(), res_np)
392+
paddle.sum(res_pd).backward()
393+
np.testing.assert_allclose(x_pd.grad.shape, x_pd.shape)
394+
395+
def test_median_dygraph(self):
396+
paddle.disable_static()
397+
h = 0
398+
w = 4
399+
l = 2
400+
x = np.arange(h * w * l).reshape([h, w, l]).astype("float32")
401+
self.dygraph_single_test_median([x, 1, False])
402+
403+
360404
if __name__ == '__main__':
361405
unittest.main()

0 commit comments

Comments
 (0)