4 changes: 4 additions & 0 deletions paddle/phi/kernels/cpu/roll_grad_kernel.cc
@@ -27,6 +27,10 @@ void RollGradKernel(const Context& dev_ctx,
                    const IntArray& shifts,
                    const std::vector<int64_t>& axis,
                    DenseTensor* x_grad) {
  if (x_grad && x_grad->numel() == 0) {
    dev_ctx.template Alloc<T>(x_grad);
    return;
  }
  std::vector<T> out_vec;
  phi::TensorToVector(out_grad, dev_ctx, &out_vec);

4 changes: 4 additions & 0 deletions paddle/phi/kernels/gpu/roll_grad_kernel.cu
@@ -30,6 +30,10 @@ void RollGradKernel(const Context& dev_ctx,
                    const std::vector<int64_t>& axis,
                    DenseTensor* x_grad) {
  auto* out_grad_data = out_grad.data<T>();
  if (x_grad && x_grad->numel() == 0) {
    dev_ctx.template Alloc<T>(x_grad);
    return;
  }
  T* x_grad_data = dev_ctx.template Alloc<T>(x_grad);

  auto shifts_data = shifts.GetData();
4 changes: 4 additions & 0 deletions paddle/phi/kernels/xpu/roll_grad_kernel.cc
@@ -28,6 +28,10 @@ void RollGradKernel(const Context& dev_ctx,
                    DenseTensor* x_grad) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  auto shifts_data = shifts.GetData();
  if (x_grad && x_grad->numel() == 0) {
    dev_ctx.template Alloc<T>(x_grad);
    return;
  }
  dev_ctx.template Alloc<T>(x_grad);
  DDim input_dim = x.dims();
  std::vector<int64_t> xshape;
4 changes: 4 additions & 0 deletions paddle/phi/kernels/xpu/roll_kernel.cc
@@ -26,6 +26,10 @@ void RollKernel(const Context& dev_ctx,
                DenseTensor* out) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  auto shifts_data = shifts.GetData();
  if (out && out->numel() == 0) {
    dev_ctx.template Alloc<T>(out);
    return;
  }
  dev_ctx.template Alloc<T>(out);
  DDim input_dim = x.dims();
  std::vector<int64_t> xshape;
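All four kernels above gain the same guard: when the destination tensor has zero elements, the kernel simply allocates the (empty) output and returns before doing any work. Below is a minimal usage sketch of the case this guards, not part of the diff, with an illustrative shape and assuming a Paddle build that includes these kernels:

import paddle

x = paddle.zeros([5, 0], dtype='float32')  # zero-size input tensor
x.stop_gradient = False
y = paddle.roll(x, shifts=1, axis=0)       # forward roll of an empty tensor
y.sum().backward()                         # backward pass takes the new early return
print(y.shape, x.grad.shape)               # expected: [5, 0] [5, 0]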
54 changes: 54 additions & 0 deletions test/fft/test_fft.py
@@ -1934,5 +1934,59 @@ def test_ifftshift(self):
            )


@place(DEVICES)
@parameterize(
    (TEST_CASE_NAME, 'x', 'axes', 'dtype'),
    [
        ('test_1d', np.random.randn(0), (0,), 'float64'),
        (
            'test_2d_odd_with_all_axes',
            np.random.randn(5, 0) + 1j * np.random.randn(5, 0),
            None,
            'complex128',
        ),
    ],
)
class TestFftShift_ZeroSize(unittest.TestCase):
    def test_fftshift(self):
        """Test fftshift with zero-size input"""
        with paddle.base.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.fftshift(self.x, self.axes),
                paddle.fft.fftshift(
                    paddle.to_tensor(self.x), self.axes
                ).numpy(),
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)),
            )


@place(DEVICES)
@parameterize(
    (TEST_CASE_NAME, 'x', 'axes', 'dtype'),
    [
        ('test_1d', np.random.randn(0), (0,), 'float64'),
        (
            'test_2d_odd_with_all_axes',
            np.random.randn(5, 0) + 1j * np.random.randn(5, 0),
            None,
            'complex128',
        ),
    ],
)
class TestIfftShift_ZeroSize(unittest.TestCase):
    def test_ifftshift(self):
        """Test ifftshift with zero-size input"""
        with paddle.base.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.ifftshift(self.x, self.axes),
                paddle.fft.ifftshift(
                    paddle.to_tensor(self.x), self.axes
                ).numpy(),
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)),
            )


if __name__ == '__main__':
    unittest.main()
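The check the two new test classes perform can be reproduced standalone roughly as follows (a sketch, not part of the diff; assumes numpy, scipy, and paddle are importable):

import numpy as np
import scipy.fft
import paddle

x = np.random.randn(5, 0) + 1j * np.random.randn(5, 0)     # empty complex128 input, as in the tests
expected = scipy.fft.fftshift(x)                            # SciPy reference on the empty array
actual = paddle.fft.fftshift(paddle.to_tensor(x)).numpy()   # Paddle result converted back to numpy
np.testing.assert_allclose(expected, actual)                # both sides are empty (5, 0) arrays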