Skip to content

Commit

Permalink
fix argument checking (PaddlePaddle#28)
Browse files Browse the repository at this point in the history
* fix argument checking
  • Loading branch information
Feiyu Chan authored Sep 9, 2021
1 parent d21249e commit 3e7792e
Showing 1 changed file with 127 additions and 9 deletions.
136 changes: 127 additions & 9 deletions python/paddle/tensor/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Sequence
import numpy as np
import paddle
from .attribute import is_complex, is_floating_point, is_interger, _real_to_complex_dtype, _complex_to_real_dtype
Expand All @@ -28,28 +29,57 @@ def _check_normalization(norm):
format(norm))


def _check_fft_n(n):
if not isinstance(n, int):
raise ValueError(
"Invalid FFT argument n({}), it shoule be an integer.".format(n))
if n <= 0:
raise ValueError(
"Invalid FFT argument n({}), it should be positive.".format(n))


def _check_fft_shape(x, s):
ndim = x.ndim
if not isinstance(s, Sequence):
raise ValueError(
"Invaid FFT argument s({}), it should be a sequence of integers.")

if len(s) > ndim:
raise ValueError(
"Length of fft sizes should not be larger than the rank of input. "
"Received, len of s: {}, rank of x: {}".format(len(s), ndim))
"Length of FFT argument s should not be larger than the rank of input. "
"Received s: {}, rank of x: {}".format(s, ndim))
for size in s:
if not isinstance(size, int) or size <= 0:
raise ValueError("FFT sizes {} contains invalid value ({})".format(
s, size))


def _check_fft_axis(x, axis):
ndim = x.ndim
if not isinstance(axis, int):
raise ValueError(
"Invalid FFT axis ({}), it shoule be an integer.".format(axis))
if axis < -ndim or axis >= ndim:
raise ValueError(
"Invalid FFT axis ({}), it should be in range [-{}, {})".format(
axis, ndim, ndim))


def _check_fft_axes(x, axes):
ndim = x.ndim
if not isinstance(axes, Sequence):
raise ValueError(
"Invalid FFT axes ({}), it should be a sequence of integers.".
format(axes))
if len(axes) > ndim:
raise ValueError(
"Length of fft axes should not be larger than the rank of input. "
"Received, len of axes: {}, rank of x: {}".format(len(axes), ndim))
for axis in axes:
if not isinstance(axis, int) or axis < -ndim or axis >= ndim:
raise ValueError("FFT axes {} contains invalid value ({})".format(
axes, axis))
raise ValueError(
"FFT axes {} contains invalid value ({}), it should be in range [-{}, {})".
format(axes, axis, ndim, ndim))


def _resize_fft_input(x, s, axes):
Expand Down Expand Up @@ -89,6 +119,12 @@ def _normalize_axes(x, axes):
return [item if item >= 0 else (item + ndim) for item in axes]


def _check_at_least_ndim(x, rank):
if x.ndim < rank:
raise ValueError("The rank of the input ({}) should >= {}".format(
x.ndim, rank))


# public APIs 1d
def fft(x, n=None, axis=-1, norm="backward", name=None):
if not is_complex(x):
Expand Down Expand Up @@ -157,26 +193,92 @@ def ihfftn(x, s=None, axes=None, norm="backward", name=None):

## public APIs 2d
def fft2(x, s=None, axes=(-2, -1), norm="backward", name=None):
_check_at_least_ndim(x, 2)
if s is not None:
if not isinstance(s, Sequence) or len(s) != 2:
raise ValueError(
"Invalid FFT argument s ({}), it should be a sequence of 2 integers.".
format(s))
if axes is not None:
if not isinstance(axes, Sequence) or len(s) != 2:
raise ValueError(
"Invalid FFT argument axes ({}), it should be a sequence of 2 integers.".
format(axes))
return fftn(x, s, axes, norm, name)


def ifft2(x, s=None, axes=(-2, -1), norm="backward", name=None):
_check_at_least_ndim(x, 2)
if s is not None:
if not isinstance(s, Sequence) or len(s) != 2:
raise ValueError(
"Invalid FFT argument s ({}), it should be a sequence of 2 integers.".
format(s))
if axes is not None:
if not isinstance(axes, Sequence) or len(s) != 2:
raise ValueError(
"Invalid FFT argument axes ({}), it should be a sequence of 2 integers.".
format(axes))
return ifftn(x, s, axes, norm, name)


def rfft2(x, s=None, axes=(-2, -1), norm="backward", name=None):
_check_at_least_ndim(x, 2)
if s is not None:
if not isinstance(s, Sequence) or len(s) != 2:
raise ValueError(
"Invalid FFT argument s ({}), it should be a sequence of 2 integers.".
format(s))
if axes is not None:
if not isinstance(axes, Sequence) or len(s) != 2:
raise ValueError(
"Invalid FFT argument axes ({}), it should be a sequence of 2 integers.".
format(axes))
return rfftn(x, s, axes, norm, name)


def irfft2(x, s=None, axes=(-2, -1), norm="backward", name=None):
_check_at_least_ndim(x, 2)
if s is not None:
if not isinstance(s, Sequence) or len(s) != 2:
raise ValueError(
"Invalid FFT argument s ({}), it should be a sequence of 2 integers.".
format(s))
if axes is not None:
if not isinstance(axes, Sequence) or len(s) != 2:
raise ValueError(
"Invalid FFT argument axes ({}), it should be a sequence of 2 integers.".
format(axes))
return irfftn(x, s, axes, norm, name)


def hfft2(x, s=None, axes=(-2, -1), norm="backward", name=None):
_check_at_least_ndim(x, 2)
if s is not None:
if not isinstance(s, Sequence) or len(s) != 2:
raise ValueError(
"Invalid FFT argument s ({}), it should be a sequence of 2 integers.".
format(s))
if axes is not None:
if not isinstance(axes, Sequence) or len(s) != 2:
raise ValueError(
"Invalid FFT argument axes ({}), it should be a sequence of 2 integers.".
format(axes))
return hfftn(x, s, axes, norm, name)


def ihfft2(x, s=None, axes=(-2, -1), norm="backward", name=None):
_check_at_least_ndim(x, 2)
if s is not None:
if not isinstance(s, Sequence) or len(s) != 2:
raise ValueError(
"Invalid FFT argument s ({}), it should be a sequence of 2 integers.".
format(s))
if axes is not None:
if not isinstance(axes, Sequence) or len(s) != 2:
raise ValueError(
"Invalid FFT argument axes ({}), it should be a sequence of 2 integers.".
format(axes))
return ihfftn(x, s, axes, norm, name)


Expand Down Expand Up @@ -232,12 +334,14 @@ def fft_c2c(x, n, axis, norm, forward, name):
if is_interger(x):
x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype()))
_check_normalization(norm)

axis = axis or -1
_check_fft_axis(x, axis)
axes = [axis]
axes = _normalize_axes(x, axes)
if n is not None:
_check_fft_n(n)
s = [n]
_check_fft_shape(x, s)
x = _resize_fft_input(x, s, axes)
op_type = 'fft_c2c'

Expand All @@ -262,11 +366,12 @@ def fft_r2c(x, n, axis, norm, forward, onesided, name):
x = paddle.cast(x, paddle.get_default_dtype())
_check_normalization(norm)
axis = axis or -1
_check_fft_axis(x, axis)
axes = [axis]
axes = _normalize_axes(x, axes)
if n is not None:
_check_fft_n(n)
s = [n]
_check_fft_shape(x, s)
x = _resize_fft_input(x, s, axes)
op_type = 'fft_r2c'
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], op_type)
Expand Down Expand Up @@ -298,11 +403,12 @@ def fft_c2r(x, n, axis, norm, forward, name):
x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype()))
_check_normalization(norm)
axis = axis or -1
_check_fft_axis(x, axis)
axes = [axis]
axes = _normalize_axes(x, axes)
if n is not None:
_check_fft_n(n)
s = [n // 2 + 1]
_check_fft_shape(x, s)
x = _resize_fft_input(x, s, axes)
op_type = 'fft_c2r'
check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type)
Expand Down Expand Up @@ -349,6 +455,10 @@ def fftn_c2c(x, s, axes, norm, forward, name):
axes_argsoft = np.argsort(axes).tolist()
axes = [axes[i] for i in axes_argsoft]
if s is not None:
if len(s) != len(axes):
raise ValueError(
"Length of s ({}) and length of axes ({}) does not match.".
format(len(s), len(axes)))
s = [s[i] for i in axes_argsoft]

if s is not None:
Expand Down Expand Up @@ -391,7 +501,11 @@ def fftn_r2c(x, s, axes, norm, forward, onesided, name):
axes_argsoft = np.argsort(axes[:-1]).tolist()
axes = [axes[i] for i in axes_argsoft] + [axes[-1]]
if s is not None:
s = [s[i] for i in axes_argsoft] + s[-1]
if len(s) != len(axes):
raise ValueError(
"Length of s ({}) and length of axes ({}) does not match.".
format(len(s), len(axes)))
s = [s[i] for i in axes_argsoft] + [s[-1]]

if s is not None:
x = _resize_fft_input(x, s, axes)
Expand Down Expand Up @@ -442,7 +556,11 @@ def fftn_c2r(x, s, axes, norm, forward, name):
axes_argsoft = np.argsort(axes[:-1]).tolist()
axes = [axes[i] for i in axes_argsoft] + [axes[-1]]
if s is not None:
s = [s[i] for i in axes_argsoft] + s[-1]
if len(s) != len(axes):
raise ValueError(
"Length of s ({}) and length of axes ({}) does not match.".
format(len(s), len(axes)))
s = [s[i] for i in axes_argsoft] + [s[-1]]

if s is not None:
fft_input_shape = list(s)
Expand Down

0 comments on commit 3e7792e

Please sign in to comment.