Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions paddle/phi/kernels/gpu/set_value_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ void SetTensorValueKernelV2(const Context& dev_ctx,
const std::vector<int64_t>& decrease_axes,
const std::vector<int64_t>& none_axes,
DenseTensor* out) {
if (in.numel() == 0) {
dev_ctx.template Alloc<T>(out);
return;
}

auto in_dims = in.dims();
auto meta = in.meta();
std::vector<int64_t> starts_local = starts.GetData();
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/kernels/impl/set_value_grad_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,11 @@ void SetValueGradKernel(const Context& dev_ctx,
const std::vector<int64_t>& none_axes,
DenseTensor* x_grad,
DenseTensor* value_grad) {
if (out_grad.numel() == 0) {
if (x_grad) dev_ctx.template Alloc<T>(x_grad);
if (value_grad) dev_ctx.template Alloc<T>(value_grad);
return;
}
const int rank = out_grad.dims().size();
std::vector<int64_t> starts_local = starts.GetData();
std::vector<int64_t> ends_local = ends.GetData();
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/kernels/impl/set_value_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,11 @@ void SetTensorValueKernel(const Context& dev_ctx,
const std::vector<int64_t>& decrease_axes,
const std::vector<int64_t>& none_axes,
DenseTensor* out) {
if (x.numel() == 0) {
dev_ctx.template Alloc<T>(out);
return;
}

const int rank = x.dims().size();

switch (rank) {
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/kernels/xpu/set_value_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,11 @@ void SetValueGradKernel(const Context& dev_ctx,
const std::vector<int64_t>& none_axes,
DenseTensor* x_grad,
DenseTensor* value_grad) {
if (out_grad.numel() == 0) {
if (x_grad) dev_ctx.template Alloc<T>(x_grad);
if (value_grad) dev_ctx.template Alloc<T>(value_grad);
return;
}
const int rank = out_grad.dims().size();

switch (rank) {
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/kernels/xpu/set_value_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,10 @@ void SetTensorValueKernel(const Context& dev_ctx,
const std::vector<int64_t>& decrease_axes,
const std::vector<int64_t>& none_axes,
DenseTensor* out) {
if (x.numel() == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
SetValueKernelImpl<T, Context>(dev_ctx,
x,
value.data<T>(),
Expand Down
55 changes: 55 additions & 0 deletions test/legacy_test/test_slice_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,5 +342,60 @@ def test_error_index(self):
)


class TestSliceScatterApi_ZeroSize(unittest.TestCase):
    """Exercise paddle.slice_scatter when the inputs have zero elements.

    Both ``x`` and ``value`` have a leading dimension of 0, so the op must
    produce a 0-numel output and a 0-numel gradient without crashing.
    """

    def setUp(self):
        np.random.seed(2023)
        self.init_shape()
        self.place = get_places()

    def init_np(self):
        # bfloat16 has no native numpy dtype; uint16 bit patterns stand in
        # for it, matching the convention used by the other tests in this file.
        np_dtype = 'uint16' if self.dtype == 'bfloat16' else self.dtype
        self.x_np = np.random.random(self.x_shape).astype(np_dtype)
        self.value_np = np.random.random(self.value_shape).astype(np_dtype)

    def init_dtype(self):
        self.dtype = 'float64'

    def init_shape(self):
        # Leading dimension is 0, so both tensors contain no elements.
        self.x_shape = [0, 6]
        self.value_shape = [0, 2]
        self.axes = [1]
        self.starts = [2]
        self.ends = [6]
        self.strides = [2]

    def test_api_dygraph(self):
        self.init_dtype()
        self.init_np()
        slice_kwargs = dict(
            axes=self.axes,
            starts=self.starts,
            ends=self.ends,
            strides=self.strides,
        )
        for dev in self.place:
            paddle.disable_static(dev)
            x = paddle.to_tensor(self.x_np)
            x.stop_gradient = False
            value = paddle.to_tensor(self.value_np)
            result = paddle.slice_scatter(x, value, **slice_kwargs)
            expected = numpy_ref(self.x_np, self.value_np, **slice_kwargs)
            np.testing.assert_allclose(result.numpy(), expected)
            result.sum().backward()
            # A 0-numel tensor's gradient is also 0-numel; comparing it with
            # the (empty) input verifies shape/dtype agreement.
            np.testing.assert_allclose(x.grad.numpy(), x.numpy())
            paddle.enable_static()

# Run the full test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()