Skip to content

Commit aed002f

Browse files
committed
Fix
1 parent 16c8e62 commit aed002f

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

paddle/phi/infermeta/unary.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1470,7 +1470,7 @@ void FFTC2CInferMeta(const MetaTensor& x,
14701470
if (config.is_runtime) {
14711471
const phi::DDim x_dim = x.dims();
14721472
for (auto axis : axes) {
1473-
PADDLE_ENFORCE_GT(x_dim[axis],
1473+
PADDLE_ENFORCE_GE(x_dim[axis],
14741474
0,
14751475
common::errors::InvalidArgument(
14761476
"Invalid fft n-point (%d).", x_dim[axis]));
@@ -1497,7 +1497,7 @@ void FFTC2RInferMeta(const MetaTensor& x,
14971497
if (config.is_runtime) {
14981498
size_t signal_dims = axes.size();
14991499
for (size_t i = 0; i < signal_dims - 1; i++) {
1500-
PADDLE_ENFORCE_GT(x_dim[axes[i]],
1500+
PADDLE_ENFORCE_GE(x_dim[axes[i]],
15011501
0,
15021502
common::errors::InvalidArgument(
15031503
"Invalid fft n-point (%d).", x_dim[axes[i]]));
@@ -1513,7 +1513,7 @@ void FFTC2RInferMeta(const MetaTensor& x,
15131513
} else if (config.is_runtime) {
15141514
const int64_t input_last_dim_size = x_dim[last_fft_axis];
15151515
const int64_t fft_n_point = (input_last_dim_size - 1) * 2;
1516-
PADDLE_ENFORCE_GT(fft_n_point,
1516+
PADDLE_ENFORCE_GE(fft_n_point,
15171517
0,
15181518
common::errors::InvalidArgument(
15191519
"Invalid fft n-point (%d).", fft_n_point));
@@ -1542,7 +1542,7 @@ void FFTR2CInferMeta(const MetaTensor& x,
15421542
// they might be -1 to indicate unknown size ar compile time
15431543
if (config.is_runtime) {
15441544
for (auto axis : axes) {
1545-
PADDLE_ENFORCE_GT(x_dim[axis],
1545+
PADDLE_ENFORCE_GE(x_dim[axis],
15461546
0,
15471547
common::errors::InvalidArgument(
15481548
"Invalid fft n-point (%d).", x_dim[axis]));

test/legacy_test/test_fft_op_0size.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,21 @@ def test_zero_size(self):
6161
"rfft2",
6262
"rfftn",
6363
]
64-
for op_type in op_list:
65-
create_test_class(op_type, "complex64", [3, 4, 0], 0)
66-
create_test_class(op_type, "complex128", [3, 4, 0, 3, 4], -1)
64+
axes_list = [
65+
((-2, -1), None),
66+
(None, None),
67+
(0, -1),
68+
((-2, -1), None),
69+
(None, None),
70+
(0, -1),
71+
((-2, -1), None),
72+
(None, None),
73+
((-2, -1), None),
74+
(None, None),
75+
]
76+
for i, op_type in enumerate(op_list):
77+
create_test_class(op_type, "complex64", [3, 4, 0], axes_list[i][0])
78+
create_test_class(op_type, "complex128", [3, 4, 0, 3, 4], axes_list[i][1])
6779

6880
if __name__ == '__main__':
6981
unittest.main()

0 commit comments

Comments
 (0)