Skip to content

Commit 38bc55c

Browse files
committed
Fix
1 parent 22fd52c commit 38bc55c

File tree

3 files changed

+50
-0
lines changed

3 files changed

+50
-0
lines changed

paddle/phi/kernels/cpu/roll_grad_kernel.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ void RollGradKernel(const Context& dev_ctx,
2727
const IntArray& shifts,
2828
const std::vector<int64_t>& axis,
2929
DenseTensor* x_grad) {
30+
if (x_grad && x_grad->numel() == 0) {
31+
dev_ctx.template Alloc<T>(x_grad);
32+
return;
33+
}
3034
std::vector<T> out_vec;
3135
phi::TensorToVector(out_grad, dev_ctx, &out_vec);
3236

paddle/phi/kernels/gpu/roll_grad_kernel.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ void RollGradKernel(const Context& dev_ctx,
3030
const std::vector<int64_t>& axis,
3131
DenseTensor* x_grad) {
3232
auto* out_grad_data = out_grad.data<T>();
33+
if (x_grad && x_grad->numel() == 0) {
34+
dev_ctx.template Alloc<T>(x_grad);
35+
return;
36+
}
3337
T* x_grad_data = dev_ctx.template Alloc<T>(x_grad);
3438

3539
auto shifts_data = shifts.GetData();

test/fft/test_fft.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1934,5 +1934,47 @@ def test_ifftshift(self):
19341934
)
19351935

19361936

1937+
def create_test_class(op_type, dtype, shape, axis):
1938+
class Cls(unittest.TestCase):
1939+
def test_zero_size(self):
1940+
paddle.disable_static()
1941+
import scipy # noqa: F401
1942+
1943+
numpy_tensor_1 = np.random.rand(*shape).astype(dtype)
1944+
paddle_x = paddle.to_tensor(numpy_tensor_1)
1945+
paddle_x.stop_gradient = False
1946+
1947+
paddle_api = eval(f"paddle.fft.{op_type}")
1948+
paddle_out = paddle_api(
1949+
paddle_x, axes=axis
1950+
) # Here the parameter is axes
1951+
numpy_api = eval(f"scipy.fft.{op_type}")
1952+
numpy_out = numpy_api(numpy_tensor_1, axes=axis)
1953+
1954+
np.testing.assert_allclose(
1955+
paddle_out.numpy(),
1956+
numpy_out,
1957+
1e-2,
1958+
1e-2,
1959+
)
1960+
np.testing.assert_allclose(
1961+
paddle_out.shape,
1962+
numpy_out.shape,
1963+
)
1964+
1965+
cls_name = f"{op_type}{dtype}_0SizeTest"
1966+
Cls.__name__ = cls_name
1967+
globals()[cls_name] = Cls
1968+
1969+
1970+
op_list = [
1971+
("fftshift", ("complex64", "complex128"), (0, -1)),
1972+
("ifftshift", ("complex64", "complex128"), (0, -1)),
1973+
]
1974+
1975+
for op in op_list:
1976+
create_test_class(op[0], op[1][0], [3, 4, 0], op[2][0])
1977+
create_test_class(op[0], op[1][1], [3, 4, 0, 3, 4], op[2][1])
1978+
19371979
if __name__ == '__main__':
19381980
unittest.main()

0 commit comments

Comments
 (0)