Skip to content

Commit cf56cbd

Browse files
authored
[0-size Tensor No.11、262、287] Add 0-size Tensor support for argsort/sort API. (#72872)
* fix * fix * fix * fix * fix * fix * fix * fix
1 parent c0032d7 commit cf56cbd

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

paddle/phi/kernels/xpu/argsort_kernel.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "paddle/phi/backends/xpu/enforce_xpu.h"
1818
#include "paddle/phi/backends/xpu/xpu_context.h"
1919
#include "paddle/phi/core/kernel_registry.h"
20+
#include "paddle/phi/kernels/full_kernel.h"
2021
#include "paddle/phi/kernels/funcs/math_function.h"
2122

2223
namespace phi {
@@ -99,6 +100,15 @@ void ArgsortKernel(const Context& dev_ctx,
99100
DenseTensor* indices) {
100101
auto in_dims = input.dims();
101102
auto rank = in_dims.size();
103+
104+
if (input.numel() == 0) {
105+
output->Resize(in_dims);
106+
indices->Resize(in_dims);
107+
dev_ctx.template Alloc<T>(output);
108+
dev_ctx.template Alloc<int64_t>(indices);
109+
return;
110+
}
111+
102112
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
103113
int64_t n = in_dims[axis];
104114

test/xpu/test_argsort_op_xpu.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,65 @@ def test_check_grad(self):
9696
self.check_grad_with_place(self.place, {'X'}, 'Out')
9797

9898

99+
class XPUTestArgsortOp_0D(XPUOpTestWrapper):
100+
def __init__(self):
101+
self.op_name = 'argsort'
102+
self.use_dynamic_create_class = False
103+
104+
class TestArgsortOpCase1(XPUOpTest):
105+
def setUp(self):
106+
self.set_xpu()
107+
self.op_type = "argsort"
108+
self.place = paddle.XPUPlace(0)
109+
self.dtype = self.in_type
110+
self.input_shape = 0
111+
self.axis = (
112+
-1 if not hasattr(self, 'init_axis') else self.init_axis
113+
)
114+
self.descending = (
115+
False
116+
if not hasattr(self, 'init_descending')
117+
else self.init_descending
118+
)
119+
120+
if self.dtype == np.float32:
121+
self.x = np.random.random(self.input_shape).astype(
122+
self.dtype
123+
)
124+
else:
125+
self.x = np.random.choice(
126+
low=-1000, high=1000, size=self.input_shape
127+
).astype(self.dtype)
128+
129+
self.inputs = {"X": self.x}
130+
self.attrs = {"axis": self.axis, "descending": self.descending}
131+
self.get_output()
132+
self.outputs = {"Out": self.sorted_x, "Indices": self.indices}
133+
134+
def get_output(self):
135+
if self.descending:
136+
self.indices = np.flip(
137+
np.argsort(self.x, kind='heapsort', axis=self.axis),
138+
self.axis,
139+
)
140+
self.sorted_x = np.flip(
141+
np.sort(self.x, kind='heapsort', axis=self.axis), self.axis
142+
)
143+
else:
144+
self.indices = np.argsort(self.x, kind='heapsort', axis=self.axis)
145+
self.sorted_x = np.sort(self.x, kind='heapsort', axis=self.axis)
146+
147+
def set_xpu(self):
148+
self.__class__.use_xpu = True
149+
self.__class__.no_need_check_grad = True
150+
151+
def test_check_output(self):
152+
self.check_output_with_place(self.place)
153+
154+
def test_check_grad(self):
155+
self.check_grad_with_place(self.place, {'X'}, 'Out')
156+
157+
99158
class XPUTestArgsortOp_LargeN(XPUOpTestWrapper):
100159
def __init__(self):
101160
self.op_name = 'argsort'

0 commit comments

Comments
 (0)