@@ -1470,7 +1470,7 @@ void FFTC2CInferMeta(const MetaTensor& x,
14701470 if (config.is_runtime ) {
14711471 const phi::DDim x_dim = x.dims ();
14721472 for (auto axis : axes) {
1473- PADDLE_ENFORCE_GT (x_dim[axis],
1473+ PADDLE_ENFORCE_GE (x_dim[axis],
14741474 0 ,
14751475 common::errors::InvalidArgument (
14761476 " Invalid fft n-point (%d)." , x_dim[axis]));
@@ -1497,7 +1497,7 @@ void FFTC2RInferMeta(const MetaTensor& x,
14971497 if (config.is_runtime ) {
14981498 size_t signal_dims = axes.size ();
14991499 for (size_t i = 0 ; i < signal_dims - 1 ; i++) {
1500- PADDLE_ENFORCE_GT (x_dim[axes[i]],
1500+ PADDLE_ENFORCE_GE (x_dim[axes[i]],
15011501 0 ,
15021502 common::errors::InvalidArgument (
15031503 " Invalid fft n-point (%d)." , x_dim[axes[i]]));
@@ -1513,7 +1513,7 @@ void FFTC2RInferMeta(const MetaTensor& x,
15131513 } else if (config.is_runtime ) {
15141514 const int64_t input_last_dim_size = x_dim[last_fft_axis];
15151515 const int64_t fft_n_point = (input_last_dim_size - 1 ) * 2 ;
1516- PADDLE_ENFORCE_GT (fft_n_point,
1516+ PADDLE_ENFORCE_GE (fft_n_point,
15171517 0 ,
15181518 common::errors::InvalidArgument (
15191519 " Invalid fft n-point (%d)." , fft_n_point));
@@ -1542,7 +1542,7 @@ void FFTR2CInferMeta(const MetaTensor& x,
15421542 // they might be -1 to indicate unknown size ar compile time
15431543 if (config.is_runtime ) {
15441544 for (auto axis : axes) {
1545- PADDLE_ENFORCE_GT (x_dim[axis],
1545+ PADDLE_ENFORCE_GE (x_dim[axis],
15461546 0 ,
15471547 common::errors::InvalidArgument (
15481548 " Invalid fft n-point (%d)." , x_dim[axis]));
0 commit comments