diff --git a/pywt/_dwt.py b/pywt/_dwt.py index 56114566a..ea2d05047 100644 --- a/pywt/_dwt.py +++ b/pywt/_dwt.py @@ -92,8 +92,8 @@ def dwt_coeff_len(data_len, filter_len, mode): Data length. filter_len : int Filter length. - mode : str, optional (default: 'symmetric') - Signal extension mode, see Modes + mode : str, optional + Signal extension mode, see :ref:`Modes `. Returns ------- @@ -130,7 +130,7 @@ def dwt(data, wavelet, mode='symmetric', axis=-1): wavelet : Wavelet object or name Wavelet to use mode : str, optional - Signal extension mode, see Modes + Signal extension mode, see :ref:`Modes `. axis: int, optional Axis over which to compute the DWT. If not given, the last axis is used. @@ -199,14 +199,14 @@ def idwt(cA, cD, wavelet, mode='symmetric', axis=-1): ---------- cA : array_like or None Approximation coefficients. If None, will be set to array of zeros - with same shape as `cD`. + with same shape as ``cD``. cD : array_like or None Detail coefficients. If None, will be set to array of zeros - with same shape as `cA`. + with same shape as ``cA``. wavelet : Wavelet object or name Wavelet to use mode : str, optional (default: 'symmetric') - Signal extension mode, see Modes + Signal extension mode, see :ref:`Modes `. axis: int, optional Axis over which to compute the inverse DWT. If not given, the last axis is used. @@ -224,7 +224,7 @@ def idwt(cA, cD, wavelet, mode='symmetric', axis=-1): >>> pywt.idwt(cA, cD, 'db2', 'smooth') array([ 1., 2., 3., 4., 5., 6.]) - One of the neat features of `idwt` is that one of the ``cA`` and ``cD`` + One of the neat features of ``idwt`` is that one of the ``cA`` and ``cD`` arguments can be set to None. In that situation the reconstruction will be performed using only the other one. Mathematically speaking, this is equivalent to passing a zero-filled array as one of the arguments. @@ -300,7 +300,7 @@ def downcoef(part, data, wavelet, mode='symmetric', level=1): Partial Discrete Wavelet Transform data decomposition. - Similar to `pywt.dwt`, but computes only one set of coefficients. + Similar to ``pywt.dwt``, but computes only one set of coefficients. Useful when you need only approximation or only details at the given level. Parameters @@ -316,7 +316,7 @@ def downcoef(part, data, wavelet, mode='symmetric', level=1): wavelet : Wavelet object or name Wavelet to use mode : str, optional - Signal extension mode, see `Modes`. Default is 'symmetric'. + Signal extension mode, see :ref:`Modes `. level : int, optional Decomposition level. Default is 1. diff --git a/pywt/_extensions/_swt.pyx b/pywt/_extensions/_swt.pyx index 8955a4409..0b3d82103 100644 --- a/pywt/_extensions/_swt.pyx +++ b/pywt/_extensions/_swt.pyx @@ -1,6 +1,7 @@ #cython: boundscheck=False, wraparound=False from . cimport common from . cimport c_wt +from cpython cimport bool import warnings import numpy as np @@ -9,6 +10,7 @@ cimport numpy as np from .common cimport pywt_index_t from ._pywt cimport c_wavelet_from_object, cdata_t, Wavelet, _check_dtype + include "config.pxi" def swt_max_level(size_t input_len): @@ -47,7 +49,8 @@ def swt_max_level(size_t input_len): return max_level -def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level): +def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level, + bool trim_approx=False): cdef cdata_t[::1] cA, cD cdef Wavelet w cdef int retval @@ -142,14 +145,20 @@ def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level): raise RuntimeError("C swt failed.") data = cA - ret.append((cA, cD)) + if not trim_approx: + ret.append((np.asarray(cA), np.asarray(cD))) + else: + ret.append(np.asarray(cD)) + if trim_approx: + ret.append(np.asarray(cA)) ret.reverse() return ret cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level, - size_t start_level, unsigned int axis=0): + size_t start_level, unsigned int axis=0, + bool trim_approx=False): # memory-views do not support n-dimensional arrays, use np.ndarray instead cdef common.ArrayInfo data_info, output_info cdef np.ndarray cD, cA @@ -289,7 +298,10 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level, if retval == -5: raise TypeError("Array must be floating point, not {}" .format(data.dtype)) - ret.append((cA, cD)) + if not trim_approx: + ret.append((cA, cD)) + else: + ret.append(cD) # previous approx coeffs are the data for the next level data = cA @@ -297,5 +309,8 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level, data_info.strides = data.strides data_info.shape = data.shape + if trim_approx: + ret.append(cA) + ret.reverse() return ret diff --git a/pywt/_multidim.py b/pywt/_multidim.py index 39d9dc2bf..3636d01c6 100644 --- a/pywt/_multidim.py +++ b/pywt/_multidim.py @@ -33,7 +33,7 @@ def dwt2(data, wavelet, mode='symmetric', axes=(-2, -1)): Wavelet to use. This can also be a tuple containing a wavelet to apply along each axis in ``axes``. mode : str or 2-tuple of strings, optional - Signal extension mode, see Modes (default: 'symmetric'). This can + Signal extension mode, see :ref:`Modes `. This can also be a tuple of modes specifying the mode to use on each axis in ``axes``. axes : 2-tuple of ints, optional @@ -84,13 +84,13 @@ def idwt2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)): ---------- coeffs : tuple (cA, (cH, cV, cD)) A tuple with approximation coefficients and three - details coefficients 2D arrays like from `dwt2`. If any of these + details coefficients 2D arrays like from ``dwt2``. If any of these components are set to ``None``, it will be treated as zeros. wavelet : Wavelet object or name string, or 2-tuple of wavelets Wavelet to use. This can also be a tuple containing a wavelet to apply along each axis in ``axes``. mode : str or 2-tuple of strings, optional - Signal extension mode, see Modes (default: 'symmetric'). This can + Signal extension mode, see :ref:`Modes `. This can also be a tuple of modes specifying the mode to use on each axis in ``axes``. axes : 2-tuple of ints, optional @@ -131,7 +131,7 @@ def dwtn(data, wavelet, mode='symmetric', axes=None): apply along each axis in ``axes``. mode : str or tuple of string, optional Signal extension mode used in the decomposition, - see Modes (default: 'symmetric'). This can also be a tuple of modes + see :ref:`Modes `. This can also be a tuple of modes specifying the mode to use on each axis in ``axes``. axes : sequence of ints, optional Axes over which to compute the DWT. Repeated elements mean the DWT will @@ -233,7 +233,7 @@ def idwtn(coeffs, wavelet, mode='symmetric', axes=None): apply along each axis in ``axes``. mode : str or list of string, optional Signal extension mode used in the decomposition, - see Modes (default: 'symmetric'). This can also be a tuple of modes + see :ref:`Modes `. This can also be a tuple of modes specifying the mode to use on each axis in ``axes``. axes : sequence of ints, optional Axes over which to compute the IDWT. Repeated elements mean the IDWT diff --git a/pywt/_multilevel.py b/pywt/_multilevel.py index 1ec785eb1..8bc195579 100644 --- a/pywt/_multilevel.py +++ b/pywt/_multilevel.py @@ -57,10 +57,10 @@ def wavedec(data, wavelet, mode='symmetric', level=None, axis=-1): wavelet : Wavelet object or name string Wavelet to use mode : str, optional - Signal extension mode, see `Modes` (default: 'symmetric') + Signal extension mode, see :ref:`Modes `. level : int, optional Decomposition level (must be >= 0). If level is None (default) then it - will be calculated using the `dwt_max_level` function. + will be calculated using the ``dwt_max_level`` function. axis: int, optional Axis over which to compute the DWT. If not given, the last axis is used. @@ -69,9 +69,10 @@ def wavedec(data, wavelet, mode='symmetric', level=None, axis=-1): ------- [cA_n, cD_n, cD_n-1, ..., cD2, cD1] : list Ordered list of coefficients arrays - where `n` denotes the level of decomposition. The first element - (`cA_n`) of the result is approximation coefficients array and the - following elements (`cD_n` - `cD_1`) are details coefficients arrays. + where ``n`` denotes the level of decomposition. The first element + (``cA_n``) of the result is approximation coefficients array and the + following elements (``cD_n`` - ``cD_1``) are details coefficients + arrays. Examples -------- @@ -119,14 +120,14 @@ def waverec(coeffs, wavelet, mode='symmetric', axis=-1): wavelet : Wavelet object or name string Wavelet to use mode : str, optional - Signal extension mode, see `Modes` (default: 'symmetric') + Signal extension mode, see :ref:`Modes `. axis: int, optional Axis over which to compute the inverse DWT. If not given, the last axis is used. Notes ----- - It may sometimes be desired to run `waverec` with some sets of + It may sometimes be desired to run ``waverec`` with some sets of coefficients omitted. This can best be done by setting the corresponding arrays to zero arrays of matching shape and dtype. Explicitly removing list entries or setting them to None is not supported. @@ -183,25 +184,26 @@ def wavedec2(data, wavelet, mode='symmetric', level=None, axes=(-2, -1)): 2D input data wavelet : Wavelet object or name string, or 2-tuple of wavelets Wavelet to use. This can also be a tuple containing a wavelet to - apply along each axis in `axes`. + apply along each axis in ``axes``. mode : str or 2-tuple of str, optional - Signal extension mode, see `Modes` (default: 'symmetric'). This can - also be a tuple containing a mode to apply along each axis in `axes`. + Signal extension mode, see :ref:`Modes `. This can + also be a tuple containing a mode to apply along each axis in ``axes``. level : int, optional Decomposition level (must be >= 0). If level is None (default) then it - will be calculated using the `dwt_max_level` function. + will be calculated using the ``dwt_max_level`` function. axes : 2-tuple of ints, optional Axes over which to compute the DWT. Repeated elements are not allowed. Returns ------- [cAn, (cHn, cVn, cDn), ... (cH1, cV1, cD1)] : list - Coefficients list. For user-specified `axes`, `cH*` - corresponds to ``axes[0]`` while `cV*` corresponds to ``axes[1]``. + Coefficients list. For user-specified ``axes``, ``cH*`` + corresponds to ``axes[0]`` while ``cV*`` corresponds to ``axes[1]``. The first element returned is the approximation coefficients for the nth level of decomposition. Remaining elements are tuples of detail coefficients in descending order of decomposition level. - (i.e. `cH1` are the horizontal detail coefficients at the first level) + (i.e. ``cH1`` are the horizontal detail coefficients at the first + level) Examples -------- @@ -257,10 +259,10 @@ def waverec2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)): Coefficients list [cAn, (cHn, cVn, cDn), ... (cH1, cV1, cD1)] wavelet : Wavelet object or name string, or 2-tuple of wavelets Wavelet to use. This can also be a tuple containing a wavelet to - apply along each axis in `axes`. + apply along each axis in ``axes``. mode : str or 2-tuple of str, optional - Signal extension mode, see `Modes` (default: 'symmetric'). This can - also be a tuple containing a mode to apply along each axis in `axes`. + Signal extension mode, see :ref:`Modes `. This can + also be a tuple containing a mode to apply along each axis in ``axes``. axes : 2-tuple of ints, optional Axes over which to compute the IDWT. Repeated elements are not allowed. @@ -270,7 +272,7 @@ def waverec2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)): Notes ----- - It may sometimes be desired to run `waverec2` with some sets of + It may sometimes be desired to run ``waverec2`` with some sets of coefficients omitted. This can best be done by setting the corresponding arrays to zero arrays of matching shape and dtype. Explicitly removing list or tuple entries or setting them to None is not supported. @@ -357,13 +359,13 @@ def wavedecn(data, wavelet, mode='symmetric', level=None, axes=None): nD input data wavelet : Wavelet object or name string, or tuple of wavelets Wavelet to use. This can also be a tuple containing a wavelet to - apply along each axis in `axes`. + apply along each axis in ``axes``. mode : str or tuple of str, optional - Signal extension mode, see `Modes` (default: 'symmetric'). This can - also be a tuple containing a mode to apply along each axis in `axes`. + Signal extension mode, see :ref:`Modes `. This can + also be a tuple containing a mode to apply along each axis in ``axes``. level : int, optional Decomposition level (must be >= 0). If level is None (default) then it - will be calculated using the `dwt_max_level` function. + will be calculated using the ``dwt_max_level`` function. axes : sequence of ints, optional Axes over which to compute the DWT. Axes may not be repeated. The default is None, which means transform all axes @@ -373,16 +375,16 @@ def wavedecn(data, wavelet, mode='symmetric', level=None, axes=None): ------- [cAn, {details_level_n}, ... {details_level_1}] : list Coefficients list. Coefficients are listed in descending order of - decomposition level. `cAn` are the approximation coefficients at - level `n`. Each `details_level_i` element is a dictionary - containing detail coefficients at level `i` of the decomposition. As + decomposition level. ``cAn`` are the approximation coefficients at + level ``n``. Each ``details_level_i`` element is a dictionary + containing detail coefficients at level ``i`` of the decomposition. As a concrete example, a 3D decomposition would have the following set of - keys in each `details_level_i` dictionary:: + keys in each ``details_level_i`` dictionary:: {'aad', 'ada', 'daa', 'add', 'dad', 'dda', 'ddd'} where the order of the characters in each key map to the specified - `axes`. + ``axes``. Examples -------- @@ -456,10 +458,10 @@ def waverecn(coeffs, wavelet, mode='symmetric', axes=None): Coefficients list [cAn, {details_level_n}, ... {details_level_1}] wavelet : Wavelet object or name string, or tuple of wavelets Wavelet to use. This can also be a tuple containing a wavelet to - apply along each axis in `axes`. + apply along each axis in ``axes``. mode : str or tuple of str, optional - Signal extension mode, see `Modes` (default: 'symmetric'). This can - also be a tuple containing a mode to apply along each axis in `axes`. + Signal extension mode, see :ref:`Modes `. This can + also be a tuple containing a mode to apply along each axis in ``axes``. axes : sequence of ints, optional Axes over which to compute the IDWT. Axes may not be repeated. @@ -469,7 +471,7 @@ def waverecn(coeffs, wavelet, mode='symmetric', axes=None): Notes ----- - It may sometimes be desired to run `waverecn` with some sets of + It may sometimes be desired to run ``waverecn`` with some sets of coefficients omitted. This can best be done by setting the corresponding arrays to zero arrays of matching shape and dtype. Explicitly removing list or dictionary entries or setting them to None is not supported. @@ -655,7 +657,7 @@ def _prepare_coeffs_axes(coeffs, axes): def coeffs_to_array(coeffs, padding=0, axes=None): """ - Arrange a wavelet coefficient list from `wavedecn` into a single array. + Arrange a wavelet coefficient list from ``wavedecn`` into a single array. Parameters ---------- @@ -665,7 +667,7 @@ def coeffs_to_array(coeffs, padding=0, axes=None): padding : float or None, optional If None, raise an error if the coefficients cannot be tightly packed. axes : sequence of ints, optional - Axes over which the DWT that created `coeffs` was performed. The + Axes over which the DWT that created ``coeffs`` was performed. The default value of None corresponds to all axes. Returns @@ -674,8 +676,8 @@ def coeffs_to_array(coeffs, padding=0, axes=None): Wavelet transform coefficient array. coeff_slices : list List of slices corresponding to each coefficient. As a 2D example, - `coeff_arr[coeff_slices[1]['dd']]` would extract the first level detail - coefficients from `coeff_arr`. + ``coeff_arr[coeff_slices[1]['dd']]`` would extract the first level + detail coefficients from ``coeff_arr``. See Also -------- @@ -773,17 +775,17 @@ def coeffs_to_array(coeffs, padding=0, axes=None): def array_to_coeffs(arr, coeff_slices, output_format='wavedecn'): """ Convert a combined array of coefficients back to a list compatible with - `waverecn`. + ``waverecn``. Parameters ---------- arr : array-like An array containing all wavelet coefficients. This should have been - generated via `coeffs_to_array`. + generated via ``coeffs_to_array``. coeff_slices : list of tuples List of slices corresponding to each coefficient as obtained from - `array_to_coeffs`. + ``array_to_coeffs``. output_format : {'wavedec', 'wavedec2', 'wavedecn'} Make the form of the coefficients compatible with this type of multilevel transform. @@ -800,7 +802,7 @@ def array_to_coeffs(arr, coeff_slices, output_format='wavedecn'): Notes ----- A single large array containing all coefficients will have subsets stored, - into a `waverecn` list, c, as indicated below:: + into a ``waverecn`` list, c, as indicated below:: +---------------+---------------+-------------------------------+ | | | | @@ -867,13 +869,13 @@ def wavedecn_shapes(shape, wavelet, mode='symmetric', level=None, axes=None): The shape of the data to be transformed. wavelet : Wavelet object or name string, or tuple of wavelets Wavelet to use. This can also be a tuple containing a wavelet to - apply along each axis in `axes`. + apply along each axis in ``axes``. mode : str or tuple of str, optional - Signal extension mode, see Modes (default: 'symmetric'). This can - also be a tuple containing a mode to apply along each axis in `axes`. + Signal extension mode, see :ref:`Modes `. This can + also be a tuple containing a mode to apply along each axis in ``axes``. level : int, optional Decomposition level (must be >= 0). If level is None (default) then it - will be calculated using the `dwt_max_level` function. + will be calculated using the ``dwt_max_level`` function. axes : sequence of ints, optional Axes over which to compute the DWT. Axes may not be repeated. The default is None, which means transform all axes @@ -882,7 +884,7 @@ def wavedecn_shapes(shape, wavelet, mode='symmetric', level=None, axes=None): Returns ------- shapes : [cAn, {details_level_n}, ... {details_level_1}] : list - Coefficients shape list. Mirrors the output of `wavedecn`, except + Coefficients shape list. Mirrors the output of ``wavedecn``, except it contains only the shapes of the coefficient arrays rather than the arrays themselves. @@ -922,9 +924,9 @@ def wavedecn_size(shapes): Parameters ---------- shapes : list of coefficient shapes - A set of coefficient shapes as returned by `wavedecn_shapes`. + A set of coefficient shapes as returned by ``wavedecn_shapes``. Alternatively, the user can specify a set of coefficients as returned - by `wavedecn`. + by ``wavedecn``. Returns ------- @@ -944,7 +946,7 @@ def wavedecn_size(shapes): 3087 """ def _size(x): - """Size corresponding to `x` as either a shape tuple or an ndarray.""" + """Size corresponding to ``x`` as either a shape tuple or ndarray.""" if isinstance(x, np.ndarray): return x.size else: @@ -963,7 +965,7 @@ def dwtn_max_level(shape, wavelet, axes=None): """Compute the maximum level of decomposition for n-dimensional data. This returns the maximum number of levels of decomposition suitable for use - with `wavedec`, `wavedec2` or `wavedecn`. + with ``wavedec``, ``wavedec2`` or ``wavedecn``. Parameters ---------- @@ -971,7 +973,7 @@ def dwtn_max_level(shape, wavelet, axes=None): Input data shape. wavelet : Wavelet object or name string, or tuple of wavelets Wavelet to use. This can also be a tuple containing a wavelet to - apply along each axis in `axes`. + apply along each axis in ``axes``. axes : sequence of ints, optional Axes over which to compute the DWT. Axes may not be repeated. @@ -982,7 +984,7 @@ def dwtn_max_level(shape, wavelet, axes=None): Notes ----- - The level returned is the smallest `dwt_max_level` over all axes. + The level returned is the smallest ``dwt_max_level`` over all axes. Examples -------- @@ -1009,9 +1011,11 @@ def ravel_coeffs(coeffs, axes=None): ---------- coeffs : array-like A list of multilevel wavelet coefficients as returned by - `wavedec`, `wavedec2` or `wavedecn`. + ``wavedec``, ``wavedec2`` or ``wavedecn``. This function is also + compatible with the output of ``swt``, ``swt2`` and ``swtn`` if those + functions were called with ``trim_approx=True``. axes : sequence of ints, optional - Axes over which the DWT that created `coeffs` was performed. The + Axes over which the DWT that created ``coeffs`` was performed. The default value of None corresponds to all axes. Returns @@ -1022,7 +1026,7 @@ def ravel_coeffs(coeffs, axes=None): coeff_slices : list List of slices corresponding to each coefficient. As a 2D example, ``coeff_arr[coeff_slices[1]['dd']]`` would extract the first level - detail coefficients from `coeff_arr`. + detail coefficients from ``coeff_arr``. coeff_shapes : list List of shapes corresponding to each coefficient. For example, in 2D, ``coeff_shapes[1]['dd']`` would contain the original shape of the first @@ -1092,23 +1096,24 @@ def unravel_coeffs(arr, coeff_slices, coeff_shapes, output_format='wavedecn'): ---------- arr : array-like An array containing all wavelet coefficients. This should have been - generated by applying `ravel_coeffs` to the output of `wavedec`, - `wavedec2` or `wavedecn`. + generated by applying ``ravel_coeffs`` to the output of ``wavedec``, + ``wavedec2`` or ``wavedecn`` (or via ``swt``, ``swt2`` or ``swtn`` + with ``trim_approx=True``). coeff_slices : list of tuples List of slices corresponding to each coefficient as obtained from - `ravel_coeffs`. + ``ravel_coeffs``. coeff_shapes : list of tuples List of shapes corresponding to each coefficient as obtained from - `ravel_coeffs`. - output_format : {'wavedec', 'wavedec2', 'wavedecn'}, optional + ``ravel_coeffs``. + output_format : {'wavedec', 'wavedec2', 'wavedecn', 'swt', 'swt2', 'swtn'}, optional Make the form of the unraveled coefficients compatible with this type - of multilevel transform. The default is 'wavedecn'. + of multilevel transform. The default is ``'wavedecn'``. Returns ------- coeffs: list List of wavelet transform coefficients. The specific format of the list - elements is determined by `output_format`. + elements is determined by ``output_format``. See Also -------- @@ -1141,13 +1146,13 @@ def unravel_coeffs(arr, coeff_slices, coeff_shapes, output_format='wavedecn'): for n in range(1, len(coeff_slices)): slice_dict = coeff_slices[n] shape_dict = coeff_shapes[n] - if output_format == 'wavedec': + if output_format in ['wavedec', 'swt']: d = arr[slice_dict['d']].reshape(shape_dict['d']) - elif output_format == 'wavedec2': + elif output_format in ['wavedec2', 'swt2']: d = (arr[slice_dict['da']].reshape(shape_dict['da']), arr[slice_dict['ad']].reshape(shape_dict['ad']), arr[slice_dict['dd']].reshape(shape_dict['dd'])) - elif output_format == 'wavedecn': + elif output_format in ['wavedecn', 'swtn']: d = {} for k, v in coeff_slices[n].items(): d[k] = arr[v].reshape(shape_dict[k]) @@ -1362,12 +1367,12 @@ def fswavedecn(data, wavelet, mode='symmetric', levels=None, axes=None): Wavelet to use. This can also be a tuple containing a wavelet to apply along each axis in ``axes``. mode : str or tuple of str, optional - Signal extension mode, see `Modes` (default: 'symmetric'). This can + Signal extension mode, see :ref:`Modes `. This can also be a tuple containing a mode to apply along each axis in ``axes``. levels : int or sequence of ints, optional Decomposition levels along each axis (must be >= 0). If an integer is provided, the same number of levels are used for all axes. If - ``levels`` is None (default), `dwt_max_level` will be used to compute + ``levels`` is None (default), ``dwt_max_level`` will be used to compute the maximum number of levels possible for each axis. axes : sequence of ints, optional Axes over which to compute the transform. Axes may not be repeated. The @@ -1378,7 +1383,7 @@ def fswavedecn(data, wavelet, mode='symmetric', levels=None, axes=None): fswavedecn_result : FswavedecnResult object Contains the wavelet coefficients, slice objects to allow obtaining the coefficients per detail or approximation level, and more. - See `FswavedecnResult` for details. + See ``FswavedecnResult`` for details. Examples -------- diff --git a/pywt/_swt.py b/pywt/_swt.py index 472c2ec2d..2600b169c 100644 --- a/pywt/_swt.py +++ b/pywt/_swt.py @@ -6,7 +6,7 @@ from ._c99_config import _have_c99_complex from ._extensions._dwt import idwt_single from ._extensions._swt import swt_max_level, swt as _swt, swt_axis as _swt_axis -from ._extensions._pywt import Modes, _check_dtype +from ._extensions._pywt import Wavelet, Modes, _check_dtype from ._multidim import idwt2, idwtn from ._utils import _as_wavelet, _wavelets_per_axis @@ -14,7 +14,18 @@ __all__ = ["swt", "swt_max_level", 'iswt', 'swt2', 'iswt2', 'swtn', 'iswtn'] -def swt(data, wavelet, level=None, start_level=0, axis=-1): +def _rescale_wavelet_filterbank(wavelet, sf): + wav = Wavelet(wavelet.name + 'r', + [np.asarray(f) * sf for f in wavelet.filter_bank]) + + # copy attributes from the original wavelet + wav.orthogonal = wavelet.orthogonal + wav.biorthogonal = wavelet.biorthogonal + return wav + + +def swt(data, wavelet, level=None, start_level=0, axis=-1, + trim_approx=False, norm=False): """ Multilevel 1D stationary wavelet transform. @@ -33,6 +44,13 @@ def swt(data, wavelet, level=None, start_level=0, axis=-1): axis: int, optional Axis over which to compute the SWT. If not given, the last axis is used. + trim_approx : bool, optional + If True, approximation coefficients at the final level are retained. + norm : bool, optional + If True, transform is normalized so that the energy of the coefficients + will be equal to the energy of ``data``. In other words, + ``np.linalg.norm(data.ravel())`` will equal the norm of the + concatenated transform coefficients when ``trim_approx`` is True. Returns ------- @@ -49,20 +67,59 @@ def swt(data, wavelet, level=None, start_level=0, axis=-1): [(cAm+n, cDm+n), ..., (cAm+1, cDm+1), (cAm, cDm)] + If ``trim_approx`` is ``True``, then the output list is exactly as in + ``pywt.wavedec``, where the first coefficient in the list is the + approximation coefficient at the final level and the rest are the + detail coefficients:: + + [cAn, cDn, ..., cD2, cD1] + Notes ----- The implementation here follows the "algorithm a-trous" and requires that the signal length along the transformed axis be a multiple of ``2**level``. If this is not the case, the user should pad up to an appropriate size using a function such as ``numpy.pad``. + + A primary benefit of this transform in comparison to its decimated + counterpart (``pywt.wavedecn``), is that it is shift-invariant. This comes + at cost of redundancy in the transform (the size of the output coefficients + is larger than the input). + + When the following three conditions are true:: + + 1.) The wavelet is orthogonal + 2.) ``swt`` is called with ``norm=True`` + 3.) ``swt`` is called with ``trim_approx=True`` + + the transform has the following additional properties that may be + desirable in applications: + 1.) energy is conserved + 2.) variance is partitioned across scales + + When used with ``norm=True``, this transform is closely related to the + multiple-overlap DWT (MODWT) as popularized for time-series analysis, + although the underlying implementation is slightly different from the one + published in [1]_. Specifically, the implementation used here requires a + signal that is a multiple of ``2**level`` in length. + + References + ---------- + .. [1] DB Percival and AT Walden. Wavelet Methods for Time Series Analysis. + Cambridge University Press, 2000. """ + if not _have_c99_complex and np.iscomplexobj(data): data = np.asarray(data) - coeffs_real = swt(data.real, wavelet, level, start_level) - coeffs_imag = swt(data.imag, wavelet, level, start_level) - coeffs_cplx = [] - for (cA_r, cD_r), (cA_i, cD_i) in zip(coeffs_real, coeffs_imag): - coeffs_cplx.append((cA_r + 1j*cA_i, cD_r + 1j*cD_i)) + coeffs_real = swt(data.real, wavelet, level, start_level, trim_approx) + coeffs_imag = swt(data.imag, wavelet, level, start_level, trim_approx) + if not trim_approx: + coeffs_cplx = [] + for (cA_r, cD_r), (cA_i, cD_i) in zip(coeffs_real, coeffs_imag): + coeffs_cplx.append((cA_r + 1j*cA_i, cD_r + 1j*cD_i)) + else: + coeffs_cplx = [cr + 1j*ci + for (cr, ci) in zip(coeffs_real, coeffs_imag)] return coeffs_cplx # accept array_like input; make a copy to ensure a contiguous array @@ -70,6 +127,12 @@ def swt(data, wavelet, level=None, start_level=0, axis=-1): data = np.array(data, dtype=dt) wavelet = _as_wavelet(wavelet) + if norm: + if not wavelet.orthogonal: + warnings.warn( + "norm=True, but the wavelet is not orthogonal: \n" + "\tThe conditions for energy preservation are not satisfied.") + wavelet = _rescale_wavelet_filterbank(wavelet, 1/np.sqrt(2)) if axis < 0: axis = axis + data.ndim @@ -80,13 +143,13 @@ def swt(data, wavelet, level=None, start_level=0, axis=-1): level = swt_max_level(data.shape[axis]) if data.ndim == 1: - ret = _swt(data, wavelet, level, start_level) + ret = _swt(data, wavelet, level, start_level, trim_approx) else: - ret = _swt_axis(data, wavelet, level, start_level, axis) - return [(np.asarray(cA), np.asarray(cD)) for cA, cD in ret] + ret = _swt_axis(data, wavelet, level, start_level, axis, trim_approx) + return ret -def iswt(coeffs, wavelet): +def iswt(coeffs, wavelet, norm=False): """ Multilevel 1D inverse discrete stationary wavelet transform. @@ -101,6 +164,10 @@ def iswt(coeffs, wavelet): ``start_level`` from ``pywt.swt``. wavelet : Wavelet object or name string Wavelet to use + norm : bool, optional + Controls the normalization used by the inverse transform. This must + be set equal to the value that was used by ``pywt.swt`` to preserve the + energy of a round-trip transform. Returns ------- @@ -114,23 +181,41 @@ def iswt(coeffs, wavelet): array([ 1., 2., 3., 4., 5., 6., 7., 8.]) """ # copy to avoid modification of input data - dt = _check_dtype(coeffs[0][0]) - output = np.array(coeffs[0][0], dtype=dt, copy=True) + # If swt was called with trim_approx=False, first element is a tuple + trim_approx = not isinstance(coeffs[0], (tuple, list)) + + if trim_approx: + cA = coeffs[0] + coeffs = coeffs[1:] + else: + cA = coeffs[0][0] + + dt = _check_dtype(cA) + output = np.array(cA, dtype=dt, copy=True) if not _have_c99_complex and np.iscomplexobj(output): # compute real and imaginary separately then combine - coeffs_real = [(cA.real, cD.real) for (cA, cD) in coeffs] - coeffs_imag = [(cA.imag, cD.imag) for (cA, cD) in coeffs] + if trim_approx: + coeffs_real = [c.real for c in coeffs] + coeffs_imag = [c.imag for c in coeffs] + else: + coeffs_real = [(cA.real, cD.real) for (cA, cD) in coeffs] + coeffs_imag = [(cA.imag, cD.imag) for (cA, cD) in coeffs] return iswt(coeffs_real, wavelet) + 1j*iswt(coeffs_imag, wavelet) # num_levels, equivalent to the decomposition level, n num_levels = len(coeffs) wavelet = _as_wavelet(wavelet) + if norm: + wavelet = _rescale_wavelet_filterbank(wavelet, np.sqrt(2)) mode = Modes.from_object('periodization') for j in range(num_levels, 0, -1): step_size = int(pow(2, j-1)) last_index = step_size - _, cD = coeffs[num_levels - j] + if trim_approx: + cD = coeffs[-j] + else: + _, cD = coeffs[-j] cD = np.asarray(cD, dtype=_check_dtype(cD)) if cD.dtype != output.dtype: # upcast to a common dtype (float64 or complex128) @@ -170,7 +255,8 @@ def iswt(coeffs, wavelet): return output -def swt2(data, wavelet, level, start_level=0, axes=(-2, -1)): +def swt2(data, wavelet, level, start_level=0, axes=(-2, -1), + trim_approx=False, norm=False): """ Multilevel 2D stationary wavelet transform. @@ -187,11 +273,20 @@ def swt2(data, wavelet, level, start_level=0, axes=(-2, -1)): The level at which the decomposition will start (default: 0) axes : 2-tuple of ints, optional Axes over which to compute the SWT. Repeated elements are not allowed. + trim_approx : bool, optional + If True, approximation coefficients at the final level are retained. + norm : bool, optional + If True, transform is normalized so that the energy of the coefficients + will be equal to the energy of ``data``. In other words, + ``np.linalg.norm(data.ravel())`` will equal the norm of the + concatenated transform coefficients when ``trim_approx`` is True. Returns ------- coeffs : list - Approximation and details coefficients (for ``start_level = m``):: + Approximation and details coefficients (for ``start_level = m``). + If ``trim_approx`` is ``True``, approximation coefficients are + retained for all levels:: [ (cA_m+level, @@ -209,12 +304,40 @@ def swt2(data, wavelet, level, start_level=0, axes=(-2, -1)): where cA is approximation, cH is horizontal details, cV is vertical details, cD is diagonal details and m is ``start_level``. + If ``trim_approx`` is ``False``, approximation coefficients are only + retained at the final level of decomposition. This matches the format + used by ``pywt.wavedec2``:: + + [ + cA_m+level, + (cH_m+level, cV_m+level, cD_m+level), + ..., + (cH_m+1, cV_m+1, cD_m+1), + (cH_m, cV_m, cD_m), + ] + Notes ----- The implementation here follows the "algorithm a-trous" and requires that the signal length along the transformed axes be a multiple of ``2**level``. If this is not the case, the user should pad up to an appropriate size using a function such as ``numpy.pad``. + + A primary benefit of this transform in comparison to its decimated + counterpart (``pywt.wavedecn``), is that it is shift-invariant. This comes + at cost of redundancy in the transform (the size of the output coefficients + is larger than the input). + + When the following three conditions are true:: + + 1.) The wavelet is orthogonal + 2.) ``swt2`` is called with ``norm=True`` + 3.) ``swt2`` is called with ``trim_approx=True`` + + the transform has the following additional properties that may be + desirable in applications: + 1.) energy is conserved + 2.) variance is partitioned across scales """ axes = tuple(axes) data = np.asarray(data) @@ -226,15 +349,20 @@ def swt2(data, wavelet, level, start_level=0, axes=(-2, -1)): raise ValueError("Input array has fewer dimensions than the specified " "axes") - coefs = swtn(data, wavelet, level, start_level, axes) + coefs = swtn(data, wavelet, level, start_level, axes, trim_approx, norm) ret = [] + if trim_approx: + ret.append(coefs[0]) + coefs = coefs[1:] for c in coefs: - ret.append((c['aa'], (c['da'], c['ad'], c['dd']))) - + if trim_approx: + ret.append((c['da'], c['ad'], c['dd'])) + else: + ret.append((c['aa'], (c['da'], c['ad'], c['dd']))) return ret -def iswt2(coeffs, wavelet): +def iswt2(coeffs, wavelet, norm=False): """ Multilevel 2D inverse discrete stationary wavelet transform. @@ -262,6 +390,10 @@ def iswt2(coeffs, wavelet): wavelet : Wavelet object or name string, or 2-tuple of wavelets Wavelet to use. This can also be a 2-tuple of wavelets to apply per axis. + norm : bool, optional + Controls the normalization used by the inverse transform. This must + be set equal to the value that was used by ``pywt.swt2`` to preserve + the energy of a round-trip transform. Returns ------- @@ -281,9 +413,17 @@ def iswt2(coeffs, wavelet): """ + # If swt was called with trim_approx=False, first element is a tuple + trim_approx = not isinstance(coeffs[0], (tuple, list)) + if trim_approx: + cA = coeffs[0] + coeffs = coeffs[1:] + else: + cA = coeffs[0][0] + # copy to avoid modification of input data - dt = _check_dtype(coeffs[0][0]) - output = np.array(coeffs[0][0], dtype=dt, copy=True) + dt = _check_dtype(cA) + output = np.array(cA, dtype=dt, copy=True) if output.ndim != 2: raise ValueError( @@ -292,11 +432,17 @@ def iswt2(coeffs, wavelet): # num_levels, equivalent to the decomposition level, n num_levels = len(coeffs) wavelets = _wavelets_per_axis(wavelet, axes=(0, 1)) + if norm: + wavelets = [_rescale_wavelet_filterbank(wav, np.sqrt(2)) + for wav in wavelets] for j in range(num_levels): step_size = int(pow(2, num_levels-j-1)) last_index = step_size - _, (cH, cV, cD) = coeffs[j] + if trim_approx: + (cH, cV, cD) = coeffs[j] + else: + _, (cH, cV, cD) = coeffs[j] # We are going to assume cH, cV, and cD are of equal size if (cH.shape != cV.shape) or (cH.shape != cD.shape): raise RuntimeError( @@ -353,7 +499,8 @@ def iswt2(coeffs, wavelet): return output -def swtn(data, wavelet, level, start_level=0, axes=None): +def swtn(data, wavelet, level, start_level=0, axes=None, trim_approx=False, + norm=False): """ n-dimensional stationary wavelet transform. @@ -371,6 +518,13 @@ def swtn(data, wavelet, level, start_level=0, axes=None): axes : sequence of ints, optional Axes over which to compute the SWT. A value of ``None`` (the default) selects all axes. Axes may not be repeated. + trim_approx : bool, optional + If True, approximation coefficients at the final level are retained. + norm : bool, optional + If True, transform is normalized so that the energy of the coefficients + will be equal to the energy of ``data``. In other words, + ``np.linalg.norm(data.ravel())`` will equal the norm of the + concatenated transform coefficients when ``trim_approx`` is True. Returns ------- @@ -391,19 +545,45 @@ def swtn(data, wavelet, level, start_level=0, axes=None): For user-specified ``axes``, the order of the characters in the dictionary keys map to the specified ``axes``. + If ``trim_approx`` is ``True``, the first element of the list contains + the array of approximation coefficients from the final level of + decomposition, while the remaining coefficient dictionaries contain + only detail coefficients. This matches the behavior of `pywt.wavedecn`. + Notes ----- The implementation here follows the "algorithm a-trous" and requires that the signal length along the transformed axes be a multiple of ``2**level``. If this is not the case, the user should pad up to an appropriate size using a function such as ``numpy.pad``. + + A primary benefit of this transform in comparison to its decimated + counterpart (``pywt.wavedecn``), is that it is shift-invariant. This comes + at cost of redundancy in the transform (the size of the output coefficients + is larger than the input). + + When the following three conditions are true:: + + 1.) The wavelet is orthogonal + 2.) ``swtn`` is called with ``norm=True`` + 3.) ``swtn`` is called with ``trim_approx=True`` + + the transform has the following additional properties that may be + desirable in applications: + 1.) energy is conserved + 2.) variance is partitioned across scales """ data = np.asarray(data) if not _have_c99_complex and np.iscomplexobj(data): - real = swtn(data.real, wavelet, level, start_level, axes) - imag = swtn(data.imag, wavelet, level, start_level, axes) - cplx = [] - for rdict, idict in zip(real, imag): + real = swtn(data.real, wavelet, level, start_level, axes, trim_approx) + imag = swtn(data.imag, wavelet, level, start_level, axes, trim_approx) + if trim_approx: + cplx = [real[0] + 1j * imag[0]] + offset = 1 + else: + cplx = [] + offset = 0 + for rdict, idict in zip(real[offset:], imag[offset:]): cplx.append( dict((k, rdict[k] + 1j * idict[k]) for k in rdict.keys())) return cplx @@ -421,7 +601,13 @@ def swtn(data, wavelet, level, start_level=0, axes=None): num_axes = len(axes) wavelets = _wavelets_per_axis(wavelet, axes) - + if norm: + if not np.all([wav.orthogonal for wav in wavelets]): + warnings.warn( + "norm=True, but the wavelets used are not orthogonal: \n" + "\tThe conditions for energy preservation are not satisfied.") + wavelets = [_rescale_wavelet_filterbank(wav, 1/np.sqrt(2)) + for wav in wavelets] ret = [] for i in range(start_level, start_level + level): coeffs = [('', data)] @@ -439,12 +625,15 @@ def swtn(data, wavelet, level, start_level=0, axes=None): # data for the next level is the approximation coeffs from this level data = coeffs['a' * num_axes] - + if trim_approx: + coeffs.pop('a' * num_axes) + if trim_approx: + ret.append(data) ret.reverse() return ret -def iswtn(coeffs, wavelet, axes=None): +def iswtn(coeffs, wavelet, axes=None, norm=False): """ Multilevel nD inverse discrete stationary wavelet transform. @@ -459,6 +648,10 @@ def iswtn(coeffs, wavelet, axes=None): Axes over which to compute the inverse SWT. Axes may not be repeated. The default is ``None``, which means transform all axes (``axes = range(data.ndim)``). + norm : bool, optional + Controls the normalization used by the inverse transform. This must + be set equal to the value that was used by ``pywt.swtn`` to preserve + the energy of a round-trip transform. Returns ------- @@ -479,11 +672,18 @@ def iswtn(coeffs, wavelet, axes=None): """ # key length matches the number of axes transformed - ndim_transform = max(len(key) for key in coeffs[0].keys()) + ndim_transform = max(len(key) for key in coeffs[-1].keys()) + + trim_approx = not isinstance(coeffs[0], dict) + if trim_approx: + cA = coeffs[0] + coeffs = coeffs[1:] + else: + cA = coeffs[0]['a'*ndim_transform] # copy to avoid modification of input data - dt = _check_dtype(coeffs[0]['a'*ndim_transform]) - output = np.array(coeffs[0]['a'*ndim_transform], dtype=dt, copy=True) + dt = _check_dtype(cA) + output = np.array(cA, dtype=dt, copy=True) ndim = output.ndim if axes is None: @@ -498,6 +698,9 @@ def iswtn(coeffs, wavelet, axes=None): # num_levels, equivalent to the decomposition level, n num_levels = len(coeffs) wavelets = _wavelets_per_axis(wavelet, axes) + if norm: + wavelets = [_rescale_wavelet_filterbank(wav, np.sqrt(2)) + for wav in wavelets] # initialize various slice objects used in the loops below # these will remain slice(None) only on axes that aren't transformed @@ -509,7 +712,8 @@ def iswtn(coeffs, wavelet, axes=None): for j in range(num_levels): step_size = int(pow(2, num_levels-j-1)) last_index = step_size - a = coeffs[j].pop('a'*ndim_transform) # will restore later + if not trim_approx: + a = coeffs[j].pop('a'*ndim_transform) # will restore later details = coeffs[j] # make sure dtype matches the coarsest level approximation coefficients common_dtype = np.result_type(*( @@ -560,5 +764,6 @@ def iswtn(coeffs, wavelet, axes=None): output[tuple(indices)] += x ntransforms += 1 output[tuple(indices)] /= ntransforms # normalize - coeffs[j]['a'*ndim_transform] = a # restore approx coeffs to dict + if not trim_approx: + coeffs[j]['a'*ndim_transform] = a # restore approx coeffs to dict return output diff --git a/pywt/tests/test_swt.py b/pywt/tests/test_swt.py index 8c340c7d0..bb70d3949 100644 --- a/pywt/tests/test_swt.py +++ b/pywt/tests/test_swt.py @@ -153,9 +153,15 @@ def test_swt_iswt_integration(): current_wavelet.rec_len)))) input_length = 2**(input_length_power + max_level - 1) X = np.arange(input_length) - coeffs = pywt.swt(X, current_wavelet, max_level) - Y = pywt.iswt(coeffs, current_wavelet) - assert_allclose(Y, X, rtol=1e-5, atol=1e-7) + for norm in [True, False]: + if norm and not current_wavelet.orthogonal: + # non-orthogonal wavelets to avoid warnings when norm=True + continue + for trim_approx in [True, False]: + coeffs = pywt.swt(X, current_wavelet, max_level, + trim_approx=trim_approx, norm=norm) + Y = pywt.iswt(coeffs, current_wavelet, norm=norm) + assert_allclose(Y, X, rtol=1e-5, atol=1e-7) def test_swt_dtypes(): @@ -235,9 +241,15 @@ def test_swt2_iswt2_integration(wavelets=None): input_length = 2**(input_length_power + max_level - 1) X = np.arange(input_length**2).reshape(input_length, input_length) - coeffs = pywt.swt2(X, current_wavelet, max_level) - Y = pywt.iswt2(coeffs, current_wavelet) - assert_allclose(Y, X, rtol=1e-5, atol=1e-5) + for norm in [True, False]: + if norm and not current_wavelet.orthogonal: + # non-orthogonal wavelets to avoid warnings when norm=True + continue + for trim_approx in [True, False]: + coeffs = pywt.swt2(X, current_wavelet, max_level, + trim_approx=trim_approx, norm=norm) + Y = pywt.iswt2(coeffs, current_wavelet, norm=norm) + assert_allclose(Y, X, rtol=1e-5, atol=1e-5) def test_swt2_iswt2_quick(): @@ -355,10 +367,16 @@ def test_swtn_iswtn_integration(wavelets=None): N = 2**(input_length_power + max_level - 1) X = np.arange(N**ndim).reshape((N, )*ndim) - coeffs = pywt.swtn(X, wav, max_level, axes=axes) - coeffs_copy = deepcopy(coeffs) - Y = pywt.iswtn(coeffs, wav, axes=axes) - assert_allclose(Y, X, rtol=1e-5, atol=1e-5) + for norm in [True, False]: + if norm and not wav.orthogonal: + # non-orthogonal wavelets to avoid warnings + continue + for trim_approx in [True, False]: + coeffs = pywt.swtn(X, wav, max_level, axes=axes, + trim_approx=trim_approx, norm=norm) + coeffs_copy = deepcopy(coeffs) + Y = pywt.iswtn(coeffs, wav, axes=axes, norm=norm) + assert_allclose(Y, X, rtol=1e-5, atol=1e-5) # verify the inverse transform didn't modify any coeffs for c, c2 in zip(coeffs, coeffs_copy): @@ -530,8 +548,86 @@ def test_iswtn_mixed_dtypes(): def test_swt_zero_size_axes(): # raise on empty input array assert_raises(ValueError, pywt.swt, [], 'db2') - + # >1D case uses a different code path so check there as well x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis assert_raises(ValueError, pywt.swtn, x, 'db2', level=1, axes=(0,)) + +def test_swt_variance_and_energy_preservation(): + """Verify that the 1D SWT partitions variance among the coefficients.""" + # When norm is True and the wavelet is orthogonal, the sum of the + # variances of the coefficients should equal the variance of the signal. + wav = 'db2' + rstate = np.random.RandomState(5) + x = rstate.randn(256) + coeffs = pywt.swt(x, wav, trim_approx=True, norm=True) + variances = [np.var(c) for c in coeffs] + assert_allclose(np.sum(variances), np.var(x)) + + # also verify L2-norm energy preservation property + assert_allclose(np.linalg.norm(x), + np.linalg.norm(np.concatenate(coeffs))) + + # non-orthogonal wavelet with norm=True raises a warning + assert_warns(UserWarning, pywt.swt, x, 'bior2.2', norm=True) + + +def test_swt2_variance_and_energy_preservation(): + """Verify that the 2D SWT partitions variance among the coefficients.""" + # When norm is True and the wavelet is orthogonal, the sum of the + # variances of the coefficients should equal the variance of the signal. + wav = 'db2' + rstate = np.random.RandomState(5) + x = rstate.randn(64, 64) + coeffs = pywt.swt2(x, wav, level=4, trim_approx=True, norm=True) + coeff_list = [coeffs[0].ravel()] + for d in coeffs[1:]: + for v in d: + coeff_list.append(v.ravel()) + variances = [np.var(v) for v in coeff_list] + assert_allclose(np.sum(variances), np.var(x)) + + # also verify L2-norm energy preservation property + assert_allclose(np.linalg.norm(x), + np.linalg.norm(np.concatenate(coeff_list))) + + # non-orthogonal wavelet with norm=True raises a warning + assert_warns(UserWarning, pywt.swt2, x, 'bior2.2', level=4, norm=True) + + +def test_swtn_variance_and_energy_preservation(): + """Verify that the nD SWT partitions variance among the coefficients.""" + # When norm is True and the wavelet is orthogonal, the sum of the + # variances of the coefficients should equal the variance of the signal. + wav = 'db2' + rstate = np.random.RandomState(5) + x = rstate.randn(64, 64) + coeffs = pywt.swtn(x, wav, level=4, trim_approx=True, norm=True) + coeff_list = [coeffs[0].ravel()] + for d in coeffs[1:]: + for k, v in d.items(): + coeff_list.append(v.ravel()) + variances = [np.var(v) for v in coeff_list] + assert_allclose(np.sum(variances), np.var(x)) + + # also verify L2-norm energy preservation property + assert_allclose(np.linalg.norm(x), + np.linalg.norm(np.concatenate(coeff_list))) + + # non-orthogonal wavelet with norm=True raises a warning + assert_warns(UserWarning, pywt.swtn, x, 'bior2.2', level=4, norm=True) + + +def test_swt_ravel_and_unravel(): + # When trim_approx=True, all swt functions can user pywt.ravel_coeffs + for ndim, _swt, _iswt, ravel_type in [ + (1, pywt.swt, pywt.iswt, 'swt'), + (2, pywt.swt2, pywt.iswt2, 'swt2'), + (3, pywt.swtn, pywt.iswtn, 'swtn')]: + x = np.ones((16, ) * ndim) + c = _swt(x, 'sym2', level=3, trim_approx=True) + arr, slices, shapes = pywt.ravel_coeffs(c) + c = pywt.unravel_coeffs(arr, slices, shapes, output_format=ravel_type) + r = _iswt(c, 'sym2') + assert_allclose(x, r)