Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions benchmarks/benchmarks/cwt_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def setup(self, n, wavelet, max_scale, dtype, method):
except ImportError:
raise NotImplementedError("cwt not available")
self.data = np.ones(n, dtype=dtype)
self.batch_data = np.ones((5, n), dtype=dtype)
self.scales = np.arange(1, max_scale + 1)


Expand All @@ -33,3 +34,12 @@ def time_cwt(self, n, wavelet, max_scale, dtype, method):
raise NotImplementedError(
"fft-based convolution not available.")
pywt.cwt(self.data, self.scales, wavelet)

def time_cwt_batch(self, n, wavelet, max_scale, dtype, method):
try:
pywt.cwt(self.batch_data, self.scales, wavelet, method=method,
axis=-1)
except TypeError:
# older PyWavelets does not support the axis argument
raise NotImplementedError(
"axis argument not available.")
146 changes: 87 additions & 59 deletions pywt/_cwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def next_fast_len(n):
return 2**ceil(np.log2(n))


def cwt(data, scales, wavelet, sampling_period=1., method='conv'):
def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
"""
cwt(data, scales, wavelet)

Expand Down Expand Up @@ -66,12 +66,16 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv'):
The ``fft`` method is ``O(N * log2(N))`` with
``N = len(scale) + len(data) - 1``. It is well suited for large size
signals but slightly slower than ``conv`` on small ones.
axis: int, optional
Axis over which to compute the CWT. If not given, the last axis is
used.

Returns
-------
coefs : array_like
Continuous wavelet transform of the input signal for the given scales
and wavelet
and wavelet. The first axis of ``coefs`` corresponds to the scales.
The remaining axes match the shape of ``data``.
frequencies : array_like
If the unit of sampling period are seconds and given, than frequencies
are in hertz. Otherwise, a sampling period of 1 is assumed.
Expand Down Expand Up @@ -112,62 +116,86 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv'):
wavelet = DiscreteContinuousWavelet(wavelet)
if np.isscalar(scales):
scales = np.array([scales])
if data.ndim == 1:
dt_out = dt_cplx if wavelet.complex_cwt else dt
out = np.empty((np.size(scales), data.size), dtype=dt_out)
precision = 10
int_psi, x = integrate_wavelet(wavelet, precision=precision)

# convert int_psi, x to the same precision as the data
dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt
int_psi = np.asarray(int_psi, dtype=dt_psi)
x = np.asarray(x, dtype=data.real.dtype)

if method == 'fft':
size_scale0 = -1
fft_data = None
elif not method == 'conv':
raise ValueError("method must be 'conv' or 'fft'")

for i, scale in enumerate(scales):
step = x[1] - x[0]
j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step)
j = j.astype(int) # floor
if j[-1] >= int_psi.size:
j = np.extract(j < int_psi.size, j)
int_psi_scale = int_psi[j][::-1]

if method == 'conv':
if not np.isscalar(axis):
raise ValueError("axis must be a scalar.")

dt_out = dt_cplx if wavelet.complex_cwt else dt
out = np.empty((np.size(scales),) + data.shape, dtype=dt_out)
precision = 10
int_psi, x = integrate_wavelet(wavelet, precision=precision)

# convert int_psi, x to the same precision as the data
dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt
int_psi = np.asarray(int_psi, dtype=dt_psi)
x = np.asarray(x, dtype=data.real.dtype)

if method == 'fft':
size_scale0 = -1
fft_data = None
elif not method == 'conv':
raise ValueError("method must be 'conv' or 'fft'")

if data.ndim > 1:
# move axis to be transformed last (so it is contiguous)
data = data.swapaxes(-1, axis)

# reshape to (n_batch, data.shape[-1])
data_shape_pre = data.shape
data = data.reshape((-1, data.shape[-1]))

for i, scale in enumerate(scales):
step = x[1] - x[0]
j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step)
j = j.astype(int) # floor
if j[-1] >= int_psi.size:
j = np.extract(j < int_psi.size, j)
int_psi_scale = int_psi[j][::-1]

if method == 'conv':
if data.ndim == 1:
conv = np.convolve(data, int_psi_scale)
else:
# The padding is selected for:
# - optimal FFT complexity
# - to be larger than the two signals length to avoid circular
# convolution
size_scale = next_fast_len(data.size + int_psi_scale.size - 1)
if size_scale != size_scale0:
# Must recompute fft_data when the padding size changes.
fft_data = fftmodule.fft(data, size_scale)
size_scale0 = size_scale
fft_wav = fftmodule.fft(int_psi_scale, size_scale)
conv = fftmodule.ifft(fft_wav * fft_data)
conv = conv[:data.size + int_psi_scale.size - 1]

coef = - np.sqrt(scale) * np.diff(conv)
if out.dtype.kind != 'c':
coef = coef.real
d = (coef.size - data.size) / 2.
if d > 0:
out[i, :] = coef[floor(d):-ceil(d)]
elif d == 0.:
out[i, :] = coef
else:
raise ValueError(
"Selected scale of {} too small.".format(scale))
frequencies = scale2frequency(wavelet, scales, precision)
if np.isscalar(frequencies):
frequencies = np.array([frequencies])
frequencies /= sampling_period
return out, frequencies
else:
raise ValueError("Only dim == 1 supported")
# batch convolution via loop
conv_shape = list(data.shape)
conv_shape[-1] += int_psi_scale.size - 1
conv_shape = tuple(conv_shape)
conv = np.empty(conv_shape, dtype=dt_out)
for n in range(data.shape[0]):
conv[n, :] = np.convolve(data[n], int_psi_scale)
else:
# The padding is selected for:
# - optimal FFT complexity
# - to be larger than the two signals length to avoid circular
# convolution
size_scale = next_fast_len(
data.shape[-1] + int_psi_scale.size - 1
)
if size_scale != size_scale0:
# Must recompute fft_data when the padding size changes.
fft_data = fftmodule.fft(data, size_scale, axis=-1)
size_scale0 = size_scale
fft_wav = fftmodule.fft(int_psi_scale, size_scale, axis=-1)
conv = fftmodule.ifft(fft_wav * fft_data, axis=-1)
conv = conv[..., :data.shape[-1] + int_psi_scale.size - 1]

coef = - np.sqrt(scale) * np.diff(conv, axis=-1)
if out.dtype.kind != 'c':
coef = coef.real
# transform axis is always -1 due to the data reshape above
d = (coef.shape[-1] - data.shape[-1]) / 2.
if d > 0:
coef = coef[..., floor(d):-ceil(d)]
elif d < 0:
raise ValueError(
"Selected scale of {} too small.".format(scale))
if data.ndim > 1:
# restore original data shape and axis position
coef = coef.reshape(data_shape_pre)
coef = coef.swapaxes(axis, -1)
out[i, ...] = coef

frequencies = scale2frequency(wavelet, scales, precision)
if np.isscalar(frequencies):
frequencies = np.array([frequencies])
frequencies /= sampling_period
return out, frequencies
84 changes: 61 additions & 23 deletions pywt/tests/test_cwt_wavelets.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
from itertools import product

from numpy.testing import (assert_allclose, assert_warns, assert_almost_equal,
assert_raises, assert_equal)
import pytest
import numpy as np
import pywt

Expand Down Expand Up @@ -344,29 +346,65 @@ def test_cwt_parameters_in_names():
assert_raises(ValueError, func, 'fbsp1-1-1-1')


def test_cwt_complex():
for dtype, tol in [(np.float32, 1e-5), (np.float64, 1e-13)]:
time, sst = pywt.data.nino()
sst = np.asarray(sst, dtype=dtype)
dt = time[1] - time[0]
wavelet = 'cmor1.5-1.0'
scales = np.arange(1, 32)

for method in ['conv', 'fft']:
# real-valued tranfsorm as a reference
[cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method)

# verify same precision
assert_equal(cfs.real.dtype, sst.dtype)

# complex-valued transform equals sum of the transforms of the real
# and imaginary components
sst_complex = sst + 1j*sst
[cfs_complex, f] = pywt.cwt(sst_complex, scales, wavelet, dt,
method=method)
assert_allclose(cfs + 1j*cfs, cfs_complex, atol=tol, rtol=tol)
# verify dtype is preserved
assert_equal(cfs_complex.dtype, sst_complex.dtype)
@pytest.mark.parametrize('dtype, tol, method',
[(np.float32, 1e-5, 'conv'),
(np.float32, 1e-5, 'fft'),
(np.float64, 1e-13, 'conv'),
(np.float64, 1e-13, 'fft')])
def test_cwt_complex(dtype, tol, method):
time, sst = pywt.data.nino()
sst = np.asarray(sst, dtype=dtype)
dt = time[1] - time[0]
wavelet = 'cmor1.5-1.0'
scales = np.arange(1, 32)

# real-valued tranfsorm as a reference
[cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method)

# verify same precision
assert_equal(cfs.real.dtype, sst.dtype)

# complex-valued transform equals sum of the transforms of the real
# and imaginary components
sst_complex = sst + 1j*sst
[cfs_complex, f] = pywt.cwt(sst_complex, scales, wavelet, dt,
method=method)
assert_allclose(cfs + 1j*cfs, cfs_complex, atol=tol, rtol=tol)
# verify dtype is preserved
assert_equal(cfs_complex.dtype, sst_complex.dtype)


@pytest.mark.parametrize('axis, method', product([0, 1], ['conv', 'fft']))
def test_cwt_batch(axis, method):
dtype = np.float64
time, sst = pywt.data.nino()
n_batch = 8
batch_axis = 1 - axis
sst1 = np.asarray(sst, dtype=dtype)
sst = np.stack((sst1, ) * n_batch, axis=batch_axis)
dt = time[1] - time[0]
wavelet = 'cmor1.5-1.0'
scales = np.arange(1, 32)

# non-batch transform as reference
[cfs1, f] = pywt.cwt(sst1, scales, wavelet, dt, method=method, axis=axis)

shape_in = sst.shape
[cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method, axis=axis)

# shape of input is not modified
assert_equal(shape_in, sst.shape)

# verify same precision
assert_equal(cfs.real.dtype, sst.dtype)

# verify expected shape
assert_equal(cfs.shape[0], len(scales))
assert_equal(cfs.shape[1 + batch_axis], n_batch)
assert_equal(cfs.shape[1 + axis], sst.shape[axis])

# batch result on stacked input is the same as stacked 1d result
assert_equal(cfs, np.stack((cfs1,) * n_batch, axis=batch_axis + 1))


def test_cwt_small_scales():
Expand Down