Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GSOC] Add EpochsTFR support to spectral connectivity functions #232

Merged
merged 25 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
a9dc973
Add TFR support spec_conn_epochs
tsbinns Sep 6, 2024
535db69
Update epochs docstring
tsbinns Sep 9, 2024
14428f6
Update rank check comments
tsbinns Sep 9, 2024
1b1456d
Add TFR support spec_conn_time
tsbinns Sep 10, 2024
6fd0861
Switch tests to custom MNE branch
tsbinns Sep 10, 2024
fc55d70
Merge branch 'main' into specconn_tfr_support
tsbinns Sep 10, 2024
a38f8fd
Fix failing tfr_error test
tsbinns Sep 10, 2024
9c04f81
Fix time_tfr tolerances
tsbinns Sep 10, 2024
c861449
Fix misleading error message
tsbinns Sep 10, 2024
7013e19
Fix spec_conn_time docstring error
tsbinns Sep 10, 2024
160bc64
Revert "Switch tests to custom MNE branch"
tsbinns Sep 17, 2024
23dc1f2
Merge branch 'main' into specconn_tfr_support
tsbinns Sep 17, 2024
8e79c9f
Apply suggestions from code review
tsbinns Sep 20, 2024
23fd89c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 20, 2024
641fa1d
Name expected conn values
tsbinns Sep 23, 2024
232c0e9
Update Welch-Fourier variation message
tsbinns Sep 24, 2024
4898673
Merge branch 'main' into specconn_tfr_support
tsbinns Sep 26, 2024
269d209
Merge branch 'main' into specconn_tfr_support
tsbinns Oct 1, 2024
a0f55d3
Merge branch 'main' into specconn_tfr_support
tsbinns Jan 15, 2025
5a2b811
Update TFR support in spec_conn_epochs
tsbinns Jan 16, 2025
c354385
Update TFR support in spec_conn_time
tsbinns Jan 16, 2025
73ce0a6
Clean up class checking
tsbinns Jan 16, 2025
4dba53e
Merge branch 'main' into specconn_tfr_support
tsbinns Jan 17, 2025
f4c087c
Update weights check logic
tsbinns Jan 17, 2025
cba4d6e
Merge branch 'main' into specconn_tfr_support
tsbinns Jan 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 82 additions & 43 deletions mne_connectivity/spectral/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,21 @@
from mne.epochs import BaseEpochs
from mne.parallel import parallel_func
from mne.source_estimate import _BaseSourceEstimate
from mne.time_frequency import (
EpochsSpectrum,
EpochsSpectrumArray,
EpochsTFR,
EpochsTFRArray,
)
from mne.time_frequency.multitaper import (
_compute_mt_params,
_csd_from_mt,
_mt_spectra,
_psd_from_mt,
_psd_from_mt_adaptive,
)
from mne.time_frequency.spectrum import (
BaseSpectrum,
EpochsSpectrum,
EpochsSpectrumArray,
)
from mne.time_frequency.tfr import cwt, morlet
from mne.time_frequency.spectrum import BaseSpectrum
from mne.time_frequency.tfr import BaseTFR, cwt, morlet
from mne.utils import _arange_div, _check_option, _time_mask, logger, verbose, warn

from ..base import SpectralConnectivity, SpectroTemporalConnectivity
Expand Down Expand Up @@ -161,17 +163,19 @@ def _prepare_connectivity(
"""Check and precompute dimensions of results data."""
first_epoch = epoch_block[0]

# Sort times and freqs
if spectrum_computed:
# Sort times
if spectrum_computed and times_in is None: # if Spectrum object passed in as data
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
n_signals = first_epoch[0].shape[0]
times = None
n_times = None
n_times = 0
times_in = None
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
n_times_in = None
n_times_in = 0
tmin_idx = None
tmax_idx = None
warn_times = False
else:
else: # if data has a time dimension (timeseries or TFR object)
if spectrum_computed: # if TFR object passed in as data
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
first_epoch = (first_epoch[0][:, 0],) # just take first freq
(
n_signals,
times,
Expand All @@ -184,6 +188,9 @@ def _prepare_connectivity(
) = _check_times(
data=first_epoch, sfreq=sfreq, times=times_in, tmin=tmin, tmax=tmax
)

# Sort freqs
if not spectrum_computed: # if timeseries passed in as data
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
# check that fmin corresponds to at least 5 cycles
fmin = _check_freqs(sfreq=sfreq, fmin=fmin, n_times=n_times)
# compute frequencies to analyze based on number of samples, sampling rate,
Expand Down Expand Up @@ -511,14 +518,19 @@ def _epoch_spectral_connectivity(

# compute tapered spectra
if spectrum_computed: # use existing spectral info
# XXX: Will need to distinguish time-resolved spectra here if support added
# Select signals & freqs of interest (flexible indexing for optional tapers dim)
x_t = np.array(data)[:, sig_idx][..., freq_mask] # split dims to avoid np.ix_
if weights is None: # also assumes no tapers dim
x_t = np.expand_dims(x_t, axis=2) # CSD construction expects a tapers dim
weights = np.ones((1, 1, 1))
# Select entries of interest (flexible indexing for optional tapers dim)
if tmin_idx is not None and tmax_idx is not None:
x_t = np.asarray(data)[:, sig_idx][..., freq_mask, tmin_idx:tmax_idx]
else:
x_t = np.asarray(data)[:, sig_idx][..., freq_mask]
if weights is None: # assumes no tapers dim
x_t = np.expand_dims(x_t, axis=2) # CSD construction expects tapers dim
weights = np.ones((1, 1, 1))
if accumulate_psd:
this_psd = _psd_from_mt(x_t, weights)
if weights is not None: # only None if mode == 'cwt_morlet'
this_psd = _psd_from_mt(x_t, weights)
else:
this_psd = (x_t * x_t.conj()).real
else: # compute spectral info from scratch
x_t, this_psd, weights = _compute_spectra(
data=data,
Expand Down Expand Up @@ -727,14 +739,15 @@ def spectral_connectivity_epochs(

Parameters
----------
data : array-like, shape=(n_epochs, n_signals, n_times) | Epochs | ~mne.time_frequency.EpochsSpectrum
data : array-like, shape=(n_epochs, n_signals, n_times) | ~mne.Epochs | ~mne.time_frequency.EpochsSpectrum | ~mne.time_frequency.EpochsTFR
The data from which to compute connectivity. Can be epoched timeseries data as
an :term:`array-like` or :class:`~mne.Epochs` object, or Fourier coefficients
for each epoch as an :class:`~mne.time_frequency.EpochsSpectrum` object. If
timeseries data, the spectral information will be computed according to the
spectral estimation mode (see the ``mode`` parameter). If an
:class:`~mne.time_frequency.EpochsSpectrum` object, this spectral information
will be used and the ``mode`` parameter will be ignored.
for each epoch as an :class:`~mne.time_frequency.EpochsSpectrum` or
:class:`~mne.time_frequency.EpochsTFR` object. If timeseries data, the spectral
information will be computed according to the spectral estimation mode (see the
``mode`` parameter). If an :class:`~mne.time_frequency.EpochsSpectrum` or
:class:`~mne.time_frequency.EpochsTFR` object, this spectral information will be
used and the ``mode`` parameter will be ignored.
tsbinns marked this conversation as resolved.
Show resolved Hide resolved

Note that it is also possible to combine multiple timeseries signals by
providing a list of tuples, e.g.: ::
Expand All @@ -748,8 +761,9 @@ def spectral_connectivity_epochs(

.. versionchanged:: 0.8
Fourier coefficients stored in an :class:`~mne.time_frequency.EpochsSpectrum`
or :class:`~mne.time_frequency.EpochsSpectrumArray` object can also be passed
in as data. Storing Fourier coefficients requires ``mne >= 1.8``.
or :class:`~mne.time_frequency.EpochsTFR` object can also be passed in as
data. Storing Fourier coefficients in
:class:`~mne.time_frequency.EpochsSpectrum` objects requires ``mne >= 1.8``.
%(names)s
method : str | list of str
Connectivity measure(s) to compute. These can be ``['coh', 'cohy',
Expand Down Expand Up @@ -789,7 +803,8 @@ def spectral_connectivity_epochs(
mode : str
Spectrum estimation mode can be either: 'multitaper', 'fourier', or
'cwt_morlet'. Ignored if ``data`` is an
:class:`~mne.time_frequency.EpochsSpectrum` object.
:class:`~mne.time_frequency.EpochsSpectrum` or
:class:`~mne.time_frequency.EpochsTFR` object.
fmin : float | tuple of float
The lower frequency of interest. Multiple bands are defined using
a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq.
Expand Down Expand Up @@ -1105,7 +1120,10 @@ def spectral_connectivity_epochs(
weights = None
metadata = None
spectrum_computed = False
if isinstance(data, BaseEpochs | EpochsSpectrum | EpochsSpectrumArray):
if isinstance(
data,
BaseEpochs | EpochsSpectrum | EpochsSpectrumArray | EpochsTFR | EpochsTFRArray,
):
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
names = data.ch_names
sfreq = data.info["sfreq"]

Expand All @@ -1126,28 +1144,47 @@ def spectral_connectivity_epochs(
data.add_annotations_to_metadata(overwrite=True)
metadata = data.metadata

if isinstance(data, EpochsSpectrum | EpochsSpectrumArray):
if isinstance(
data, EpochsSpectrum | EpochsSpectrumArray | EpochsTFR | EpochsTFRArray
):
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
# XXX: Will need to be updated if new Spectrum methods are added
if not np.iscomplexobj(data.get_data()):
raise TypeError(
"if `data` is an EpochsSpectrum object, it must contain "
"complex-valued Fourier coefficients, such as that returned from "
"Epochs.compute_psd(output='complex')"
"if `data` is an EpochsSpectrum or EpochsTFR object, it must "
"contain complex-valued Fourier coefficients, such as that "
"returned from Epochs.compute_psd/tfr() with `output='complex'`"
)
if "segment" in data._dims:
raise ValueError(
"`data` cannot contain Fourier coefficients for individual segments"
)
if isinstance(data, EpochsSpectrum): # mode can be read mode from Spectrum
mode = data.method
mode = "fourier" if mode == "welch" else mode
else: # spectral method is "unknown", so take mode from data dimensions
# Currently, actual mode doesn't matter as long as we handle tapers and
# their weights in the same way as for multitaper spectra
mode = "multitaper" if "taper" in data._dims else "fourier"
mode = data.method
if isinstance(data, EpochsSpectrum | EpochsSpectrumArray):
if isinstance(data, EpochsSpectrum): # read mode from object
mode = "fourier" if mode == "welch" else mode
else: # infer mode from dimensions
# Currently, actual mode doesn't matter as long as we handle tapers
# and their weights in the same way as for multitaper spectra
mode = "multitaper" if "taper" in data._dims else "fourier"
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
weights = data.weights
else:
if isinstance(data, EpochsTFR): # read mode from object
if mode != "morlet": # FIXME: Add support for other TFR methods
raise ValueError(
"if `data` is an EpochsTFR object, the spectral method "
"must be 'morlet'"
)
else:
if "taper" in data._dims: # FIXME: Add support for multitaper TFR
raise ValueError(
"if `data` is an EpochsTFRArray object, it cannot contain "
"Fourier coefficients for individual tapers"
)
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
mode = "cwt_morlet" # currently only supported mode here
times_in = data.times
weights = None # no weights stored in TFR objects
spectrum_computed = True
freqs = data.freqs
weights = data.weights
else:
times_in = data.times # input times for Epochs input type
elif sfreq is None:
Expand Down Expand Up @@ -1235,7 +1272,7 @@ def spectral_connectivity_epochs(
spectral_params = dict(
eigvals=None, window_fun=None, wavelets=None, weights=weights
)
n_times_spectrum = 0
n_times_spectrum = n_times # 0 if no times
n_tapers = None if weights is None else weights.size

# unique signals for which we actually need to compute PSD etc.
Expand Down Expand Up @@ -1289,7 +1326,7 @@ def spectral_connectivity_epochs(
logger.info(f" the following metrics will be computed: {metrics_str}")

# check dimensions and time scale
if not spectrum_computed: # XXX: Can we assume upstream checks sufficient?
if not spectrum_computed:
for this_epoch in epoch_block:
_, _, _, warn_times = _get_and_verify_data_sizes(
this_epoch,
Expand Down Expand Up @@ -1469,7 +1506,9 @@ def spectral_connectivity_epochs(
freqs=freqs,
method=_method,
n_nodes=n_nodes,
spec_method=mode if not isinstance(data, BaseSpectrum) else data.method,
spec_method=(
mode if not isinstance(data, BaseSpectrum | BaseTFR) else data.method
),
indices=indices,
n_epochs_used=n_epochs,
freqs_used=freqs_used,
Expand Down
13 changes: 12 additions & 1 deletion mne_connectivity/spectral/epochs_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
import numpy as np
from mne.epochs import BaseEpochs
from mne.parallel import parallel_func
from mne.time_frequency import EpochsSpectrum, EpochsSpectrumArray
from mne.time_frequency import (
EpochsSpectrum,
EpochsSpectrumArray,
EpochsTFR,
EpochsTFRArray,
)
from mne.time_frequency.multitaper import _psd_from_mt
from mne.utils import ProgressBar, _validate_type, logger

Expand All @@ -40,6 +45,12 @@ def _check_rank_input(rank, data, indices):
data_arr = _psd_from_mt(data_arr, data.weights)
else:
data_arr = (data_arr * data_arr.conj()).real
elif isinstance(data, EpochsTFR | EpochsTFRArray):
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
# TFR objs will drop bad channels, so specify picking all channels
data_arr = data.get_data(picks=np.arange(data.info["nchan"]))
# Convert to power and aggregate over time before computing rank
# XXX: need to change when other types of TFR are supported
data_arr = np.sum((data_arr * data_arr.conj()).real, axis=-1)
else:
data_arr = data

Expand Down
Loading