Skip to content
18 changes: 9 additions & 9 deletions pywt/_dwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ def dwt_coeff_len(data_len, filter_len, mode):
Data length.
filter_len : int
Filter length.
mode : str, optional (default: 'symmetric')
Signal extension mode, see Modes
mode : str, optional
Signal extension mode, see :ref:`Modes <ref-modes>`.

Returns
-------
Expand Down Expand Up @@ -130,7 +130,7 @@ def dwt(data, wavelet, mode='symmetric', axis=-1):
wavelet : Wavelet object or name
Wavelet to use
mode : str, optional
Signal extension mode, see Modes
Signal extension mode, see :ref:`Modes <ref-modes>`.
axis: int, optional
Axis over which to compute the DWT. If not given, the
last axis is used.
Expand Down Expand Up @@ -199,14 +199,14 @@ def idwt(cA, cD, wavelet, mode='symmetric', axis=-1):
----------
cA : array_like or None
Approximation coefficients. If None, will be set to array of zeros
with same shape as `cD`.
with same shape as ``cD``.
cD : array_like or None
Detail coefficients. If None, will be set to array of zeros
with same shape as `cA`.
with same shape as ``cA``.
wavelet : Wavelet object or name
Wavelet to use
mode : str, optional (default: 'symmetric')
Signal extension mode, see Modes
Signal extension mode, see :ref:`Modes <ref-modes>`.
axis: int, optional
Axis over which to compute the inverse DWT. If not given, the
last axis is used.
Expand All @@ -224,7 +224,7 @@ def idwt(cA, cD, wavelet, mode='symmetric', axis=-1):
>>> pywt.idwt(cA, cD, 'db2', 'smooth')
array([ 1., 2., 3., 4., 5., 6.])

One of the neat features of `idwt` is that one of the ``cA`` and ``cD``
One of the neat features of ``idwt`` is that one of the ``cA`` and ``cD``
arguments can be set to None. In that situation the reconstruction will be
performed using only the other one. Mathematically speaking, this is
equivalent to passing a zero-filled array as one of the arguments.
Expand Down Expand Up @@ -300,7 +300,7 @@ def downcoef(part, data, wavelet, mode='symmetric', level=1):

Partial Discrete Wavelet Transform data decomposition.

Similar to `pywt.dwt`, but computes only one set of coefficients.
Similar to ``pywt.dwt``, but computes only one set of coefficients.
Useful when you need only approximation or only details at the given level.

Parameters
Expand All @@ -316,7 +316,7 @@ def downcoef(part, data, wavelet, mode='symmetric', level=1):
wavelet : Wavelet object or name
Wavelet to use
mode : str, optional
Signal extension mode, see `Modes`. Default is 'symmetric'.
Signal extension mode, see :ref:`Modes <ref-modes>`.
level : int, optional
Decomposition level. Default is 1.

Expand Down
23 changes: 19 additions & 4 deletions pywt/_extensions/_swt.pyx
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#cython: boundscheck=False, wraparound=False
from . cimport common
from . cimport c_wt
from cpython cimport bool

import warnings
import numpy as np
Expand All @@ -9,6 +10,7 @@ cimport numpy as np
from .common cimport pywt_index_t
from ._pywt cimport c_wavelet_from_object, cdata_t, Wavelet, _check_dtype


include "config.pxi"

def swt_max_level(size_t input_len):
Expand Down Expand Up @@ -47,7 +49,8 @@ def swt_max_level(size_t input_len):
return max_level


def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level):
def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level,
bool trim_approx=False):
cdef cdata_t[::1] cA, cD
cdef Wavelet w
cdef int retval
Expand Down Expand Up @@ -142,14 +145,20 @@ def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level):
raise RuntimeError("C swt failed.")

data = cA
ret.append((cA, cD))
if not trim_approx:
ret.append((np.asarray(cA), np.asarray(cD)))
else:
ret.append(np.asarray(cD))

if trim_approx:
ret.append(np.asarray(cA))
ret.reverse()
return ret


cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
size_t start_level, unsigned int axis=0):
size_t start_level, unsigned int axis=0,
bool trim_approx=False):
# memory-views do not support n-dimensional arrays, use np.ndarray instead
cdef common.ArrayInfo data_info, output_info
cdef np.ndarray cD, cA
Expand Down Expand Up @@ -289,13 +298,19 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
if retval == -5:
raise TypeError("Array must be floating point, not {}"
.format(data.dtype))
ret.append((cA, cD))
if not trim_approx:
ret.append((cA, cD))
else:
ret.append(cD)

# previous approx coeffs are the data for the next level
data = cA
# update data_info to match the new data array
data_info.strides = <pywt_index_t *> data.strides
data_info.shape = <size_t *> data.shape

if trim_approx:
ret.append(cA)

ret.reverse()
return ret
10 changes: 5 additions & 5 deletions pywt/_multidim.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def dwt2(data, wavelet, mode='symmetric', axes=(-2, -1)):
Wavelet to use. This can also be a tuple containing a wavelet to
apply along each axis in ``axes``.
mode : str or 2-tuple of strings, optional
Signal extension mode, see Modes (default: 'symmetric'). This can
Signal extension mode, see :ref:`Modes <ref-modes>`. This can
also be a tuple of modes specifying the mode to use on each axis in
``axes``.
axes : 2-tuple of ints, optional
Expand Down Expand Up @@ -84,13 +84,13 @@ def idwt2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)):
----------
coeffs : tuple
(cA, (cH, cV, cD)) A tuple with approximation coefficients and three
details coefficients 2D arrays like from `dwt2`. If any of these
details coefficients 2D arrays like from ``dwt2``. If any of these
components are set to ``None``, it will be treated as zeros.
wavelet : Wavelet object or name string, or 2-tuple of wavelets
Wavelet to use. This can also be a tuple containing a wavelet to
apply along each axis in ``axes``.
mode : str or 2-tuple of strings, optional
Signal extension mode, see Modes (default: 'symmetric'). This can
Signal extension mode, see :ref:`Modes <ref-modes>`. This can
also be a tuple of modes specifying the mode to use on each axis in
``axes``.
axes : 2-tuple of ints, optional
Expand Down Expand Up @@ -131,7 +131,7 @@ def dwtn(data, wavelet, mode='symmetric', axes=None):
apply along each axis in ``axes``.
mode : str or tuple of string, optional
Signal extension mode used in the decomposition,
see Modes (default: 'symmetric'). This can also be a tuple of modes
see :ref:`Modes <ref-modes>`. This can also be a tuple of modes
specifying the mode to use on each axis in ``axes``.
axes : sequence of ints, optional
Axes over which to compute the DWT. Repeated elements mean the DWT will
Expand Down Expand Up @@ -233,7 +233,7 @@ def idwtn(coeffs, wavelet, mode='symmetric', axes=None):
apply along each axis in ``axes``.
mode : str or list of string, optional
Signal extension mode used in the decomposition,
see Modes (default: 'symmetric'). This can also be a tuple of modes
see :ref:`Modes <ref-modes>`. This can also be a tuple of modes
specifying the mode to use on each axis in ``axes``.
axes : sequence of ints, optional
Axes over which to compute the IDWT. Repeated elements mean the IDWT
Expand Down
Loading