Commit d64d3c3

[0-size Tensor Job2 No.37] Add 0-size Tensor support for index_put (#73513)
* Fix
* Fix
1 parent aaf077d commit d64d3c3
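
In user-facing terms, the commit makes index_put accept inputs whose result has zero elements. A minimal sketch of the newly supported call (assuming a Paddle build that includes this commit; the shapes mirror the new test below, and the index values are placeholders):

import paddle

x = paddle.randn([10, 0])                       # 0-size tensor: zero elements
indices = (paddle.zeros([10], dtype='int64'),)  # index along dim 0
value = paddle.randn([1, 1])

out = paddle.index_put(x, indices, value)       # previously unsupported for 0-size inputs
print(out.shape)                                # [10, 0], there is nothing to write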

File tree

7 files changed: +107 −0 lines changed

paddle/phi/kernels/cpu/index_put_grad_kernel.cc

Lines changed: 12 additions & 0 deletions

@@ -182,6 +182,18 @@ void IndexPutGradKernel(const Context& dev_ctx,
                         bool accumulate,
                         DenseTensor* x_grad,
                         DenseTensor* value_grad) {
+  if (out_grad.numel() == 0) {
+    dev_ctx.template Alloc<T>(x_grad);
+    // Fill value_grad with 0.
+    if (value_grad) {
+      phi::Full<T, Context>(
+          dev_ctx,
+          phi::IntArray(common::vectorize(value_grad->dims())),
+          0,
+          value_grad);
+    }
+    return;
+  }
   PADDLE_ENFORCE_EQ(
       x.dtype(),
       value.dtype(),
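
The early return above defines the backward pass for the 0-size case: x_grad is merely allocated (it is itself 0-size), and value_grad, when requested, is zero-filled, since no element of value contributes to a 0-size output. A minimal sketch of the resulting gradients (same assumptions and placeholder shapes as the example above, mirroring the assertions in the new test):

import numpy as np
import paddle

x = paddle.randn([10, 0])
value = paddle.randn([1, 1])
x.stop_gradient = False
value.stop_gradient = False
indices = (paddle.zeros([10], dtype='int64'),)

out = paddle.index_put(x, indices, value)
out.sum().backward()

assert x.grad.shape == x.shape  # 0-size gradient, same shape as x
np.testing.assert_allclose(value.grad.numpy(), np.zeros(value.shape))  # zero gradient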

paddle/phi/kernels/cpu/index_put_kernel.cc

Lines changed: 4 additions & 0 deletions

@@ -105,6 +105,10 @@ void IndexPutKernel(const Context& dev_ctx,
                     const DenseTensor& value,
                     bool accumulate,
                     DenseTensor* out) {
+  if (out && out->numel() == 0) {
+    dev_ctx.template Alloc<T>(out);
+    return;
+  }
   PADDLE_ENFORCE_EQ(
       x.dtype(),
       value.dtype(),

paddle/phi/kernels/gpu/index_put_grad_kernel.cu

Lines changed: 13 additions & 0 deletions

@@ -231,6 +231,19 @@ void IndexPutGradKernel(const Context& dev_ctx,
                         bool accumulate,
                         DenseTensor* x_grad,
                         DenseTensor* value_grad) {
+  if (out_grad.numel() == 0) {
+    dev_ctx.template Alloc<T>(x_grad);
+    // Fill value_grad with 0.
+    if (value_grad) {
+      phi::Full<T, Context>(
+          dev_ctx,
+          phi::IntArray(common::vectorize(value_grad->dims())),
+          0,
+          value_grad);
+    }
+    return;
+  }
+
   PADDLE_ENFORCE_EQ(
       x.dtype(),
       value.dtype(),

paddle/phi/kernels/gpu/index_put_kernel.cu

Lines changed: 4 additions & 0 deletions

@@ -116,6 +116,10 @@ void IndexPutKernel(const Context& dev_ctx,
                     const DenseTensor& value,
                     bool accumulate,
                     DenseTensor* out) {
+  if (out && out->numel() == 0) {
+    dev_ctx.template Alloc<T>(out);
+    return;
+  }
   PADDLE_ENFORCE_EQ(
       x.dtype(),
       value.dtype(),

paddle/phi/kernels/xpu/index_put_grad_kernel.cc

Lines changed: 12 additions & 0 deletions

@@ -31,6 +31,18 @@ void IndexPutGradKernel(const Context& dev_ctx,
                         bool accumulate,
                         DenseTensor* x_grad,
                         DenseTensor* value_grad) {
+  if (out_grad.numel() == 0) {
+    dev_ctx.template Alloc<T>(x_grad);
+    // Fill value_grad with 0.
+    if (value_grad) {
+      phi::Full<T, Context>(
+          dev_ctx,
+          phi::IntArray(common::vectorize(value_grad->dims())),
+          0,
+          value_grad);
+    }
+    return;
+  }
   PADDLE_ENFORCE_EQ(
       x.dtype(),
       value.dtype(),

paddle/phi/kernels/xpu/index_put_kernel.cc

Lines changed: 4 additions & 0 deletions

@@ -28,6 +28,10 @@ void IndexPutKernel(const Context& dev_ctx,
                     const DenseTensor& value,
                     bool accumulate,
                     DenseTensor* out) {
+  if (out && out->numel() == 0) {
+    dev_ctx.template Alloc<T>(out);
+    return;
+  }
   PADDLE_ENFORCE_EQ(
       x.dtype(),
       value.dtype(),

test/legacy_test/test_index_put_op.py

Lines changed: 58 additions & 0 deletions

@@ -1028,5 +1028,63 @@ def init_dtype_type(self):
         self.index_type_pd1 = "bool"
 
 
+class TestIndexPutAPI_ZeroSize(unittest.TestCase):
+    def setUp(self):
+        self.init_dtype_type()
+        self.setPlace()
+
+    def init_dtype_type(self):
+        self.dtype_np = np.float32
+        self.index_type_np = np.int64
+        self.x_shape = (10, 0)
+        self.indices_shapes = [[10]]
+        self.value_shape = [1, 1]
+        self.dtype_pd = paddle.float32
+        self.index_type_pd = paddle.int64
+
+    def setPlace(self):
+        self.place = []
+        if (
+            os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
+            in ['1', 'true', 'on']
+            or not paddle.is_compiled_with_cuda()
+        ):
+            self.place.append('cpu')
+        if self.dtype_np is np.float16:
+            self.place = []
+        if paddle.is_compiled_with_cuda():
+            self.place.append('gpu')
+
+    def test_dygraph_forward(self):
+        paddle.disable_static()
+        for place in self.place:
+            paddle.device.set_device(place)
+            x_pd = paddle.randn(self.x_shape, dtype=self.dtype_pd)
+            x_np = x_pd.numpy()
+            value_pd = paddle.randn(self.value_shape, dtype=self.dtype_pd)
+            value_np = value_pd.numpy()
+            x_pd.stop_gradient = False
+            value_pd.stop_gradient = False
+            indices_pd = [
+                paddle.randn(indices_shape).astype(dtype=self.index_type_pd)
+                for indices_shape in self.indices_shapes
+            ]
+            indices_np = [item.numpy() for item in indices_pd]
+            indices_pd = tuple(indices_pd)
+            accumulate = False
+            ref_res = compute_index_put_ref(
+                x_np, indices_np, value_np, accumulate
+            )
+            pd_res = paddle.index_put(x_pd, indices_pd, value_pd, accumulate)
+            np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7)
+
+            # check grad
+            pd_res.sum().backward()
+            np.testing.assert_allclose(x_pd.grad.shape, x_pd.shape)
+            np.testing.assert_allclose(
+                value_pd.grad.numpy(), np.zeros(value_pd.shape)
+            )
+
+
 if __name__ == '__main__':
     unittest.main()
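
One note on the reference check in this test: for a 0-size x, the scatter has no destination elements, so index_put degenerates to an identity on x, which is why both the forward comparison and the zero value_grad assertion hold. A hypothetical NumPy illustration of that edge case (compute_index_put_ref itself is defined elsewhere in this test file):

import numpy as np

x = np.random.randn(10, 0).astype(np.float32)
idx = np.zeros(10, dtype=np.int64)
value = np.random.randn(1, 1).astype(np.float32)

ref = x.copy()
ref[idx] = value  # the selected block has shape (10, 0): nothing is written
assert ref.shape == (10, 0) and ref.size == 0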
