diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 73235eaa..0bf902e3 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -12,6 +12,12 @@ from mne.epochs import BaseEpochs from mne.parallel import parallel_func from mne.source_estimate import _BaseSourceEstimate +from mne.time_frequency import ( + EpochsSpectrum, + EpochsSpectrumArray, + EpochsTFR, + EpochsTFRArray, +) from mne.time_frequency.multitaper import ( _compute_mt_params, _csd_from_mt, @@ -19,12 +25,7 @@ _psd_from_mt, _psd_from_mt_adaptive, ) -from mne.time_frequency.spectrum import ( - BaseSpectrum, - EpochsSpectrum, - EpochsSpectrumArray, -) -from mne.time_frequency.tfr import cwt, morlet +from mne.time_frequency.tfr import _tfr_from_mt, cwt, morlet from mne.utils import _arange_div, _check_option, _time_mask, logger, verbose, warn from ..base import SpectralConnectivity, SpectroTemporalConnectivity @@ -161,17 +162,21 @@ def _prepare_connectivity( """Check and precompute dimensions of results data.""" first_epoch = epoch_block[0] - # Sort times and freqs - if spectrum_computed: + # Sort times + if spectrum_computed and times_in is None: # is a Spectrum object n_signals = first_epoch[0].shape[0] times = None - n_times = None - times_in = None - n_times_in = None + n_times = 0 + n_times_in = 0 tmin_idx = None tmax_idx = None warn_times = False - else: + else: # data has a time dimension (timeseries or TFR object) + if spectrum_computed: # is a TFR object + if mode == "cwt_morlet": + first_epoch = (first_epoch[0][:, 0],) # just take first freq + else: # multitaper + first_epoch = (first_epoch[0][:, 0, 0],) # take first taper and freq ( n_signals, times, @@ -184,6 +189,9 @@ def _prepare_connectivity( ) = _check_times( data=first_epoch, sfreq=sfreq, times=times_in, tmin=tmin, tmax=tmax ) + + # Sort freqs + if not spectrum_computed: # is an (ordinary) timeseries # check that fmin corresponds to at least 5 cycles fmin = _check_freqs(sfreq=sfreq, fmin=fmin, n_times=n_times) # compute frequencies to analyze based on number of samples, sampling rate, @@ -306,6 +314,7 @@ def _assemble_spectral_params( spectral_params = dict(eigvals=None, window_fun=None, wavelets=None, weights=None) n_tapers = None n_times_spectrum = 0 + is_tfr_con = False if mode == "multitaper": window_fun, eigvals, mt_adaptive = _compute_mt_params( n_times, sfreq, mt_bandwidth, mt_low_bias, mt_adaptive @@ -333,9 +342,10 @@ def _assemble_spectral_params( wavelets=morlet(sfreq, freqs, n_cycles=cwt_n_cycles, zero_mean=True) ) n_times_spectrum = n_times + is_tfr_con = True else: raise ValueError("mode has an invalid value") - return spectral_params, mt_adaptive, n_times_spectrum, n_tapers + return spectral_params, mt_adaptive, n_times_spectrum, n_tapers, is_tfr_con ######################################################################## @@ -434,6 +444,37 @@ def _compute_spectra( return x_t, this_psd, weights +def _tfr_csd_from_mt(x_mt, y_mt, weights_x, weights_y): + """Compute time-frequency CSD from tapered spectra. + + Parameters + ---------- + x_mt : array, shape (..., n_tapers, n_freqs, n_times) + The tapered time-frequency spectra for signals x. + y_mt : array, shape (..., n_tapers, n_freqs, n_times) + The tapered time-frequency spectra for signals y. + weights_x : array, shape (n_tapers, n_freqs) + Weights to use for combining the tapered spectra of x_mt. + weights_y : array, shape (n_tapers, n_freqs) + Weights to use for combining the tapered spectra of y_mt. + + Returns + ------- + csd : array, shape (..., n_freqs, n_times) + The CSD between x and y. + """ + # expand weights dims to match x_mt and y_mt + weights_x = np.expand_dims(weights_x, axis=(*np.arange(x_mt.ndim - 3), -1)) + weights_y = np.expand_dims(weights_y, axis=(*np.arange(y_mt.ndim - 3), -1)) + # compute CSD + csd = np.sum(weights_x * x_mt * (weights_y * y_mt).conj(), axis=-3) + denom = np.sqrt((weights_x * weights_x.conj()).real.sum(axis=-3)) * np.sqrt( + (weights_y * weights_y.conj()).real.sum(axis=-3) + ) + csd *= 2 / denom + return csd + + def _epoch_spectral_connectivity( data, sig_idx, @@ -461,6 +502,7 @@ def _epoch_spectral_connectivity( gc_n_lags, n_components, spectrum_computed, + is_tfr_con, accumulate_inplace=True, ): """Estimate connectivity for one epoch (see spectral_connectivity).""" @@ -511,14 +553,22 @@ def _epoch_spectral_connectivity( # compute tapered spectra if spectrum_computed: # use existing spectral info - # XXX: Will need to distinguish time-resolved spectra here if support added - # Select signals & freqs of interest (flexible indexing for optional tapers dim) - x_t = np.array(data)[:, sig_idx][..., freq_mask] # split dims to avoid np.ix_ - if weights is None: # also assumes no tapers dim - x_t = np.expand_dims(x_t, axis=2) # CSD construction expects a tapers dim - weights = np.ones((1, 1, 1)) + # Select entries of interest (flexible indexing for optional tapers dim) + if tmin_idx is not None and tmax_idx is not None: # TFR spectra + x_t = np.asarray(data)[:, sig_idx][..., freq_mask, tmin_idx:tmax_idx] + else: # normal spectra + x_t = np.asarray(data)[:, sig_idx][..., freq_mask] + if weights is None: # assumes no tapers dim, i.e., for Fourier/Welch mode + x_t = np.expand_dims(x_t, axis=2) # CSD construction expects tapers dim + weights = np.ones((1, 1, 1)) # assign dummy weights if accumulate_psd: - this_psd = _psd_from_mt(x_t, weights) + if weights is not None: # mode == 'multitaper' or 'fourier' + if not is_tfr_con: # normal spectra (multitaper or Fourier) + this_psd = _psd_from_mt(x_t, weights) + else: # TFR spectra (multitaper) + this_psd = np.array([_tfr_from_mt(epo_x, weights) for epo_x in x_t]) + else: # mode == 'cwt_morlet' + this_psd = (x_t * x_t.conj()).real else: # compute spectral info from scratch x_t, this_psd, weights = _compute_spectra( data=data, @@ -549,28 +599,29 @@ def _epoch_spectral_connectivity( psd = None # tell the methods that a new epoch starts - for method in con_methods: - method.start_epoch() + for this_method in con_methods: + this_method.start_epoch() # accumulate connectivity scores if mode in ["multitaper", "fourier"]: for i in range(0, n_con_signals, block_size): n_extra = max(0, i + block_size - n_con_signals) con_idx = slice(i, i + block_size - n_extra) + compute_csd = _csd_from_mt if not is_tfr_con else _tfr_csd_from_mt if mt_adaptive: - csd = _csd_from_mt( + csd = compute_csd( x_t[idx_map[0][con_idx]], x_t[idx_map[1][con_idx]], weights[idx_map[0][con_idx]], weights[idx_map[1][con_idx]], ) else: - csd = _csd_from_mt( + csd = compute_csd( x_t[idx_map[0][con_idx]], x_t[idx_map[1][con_idx]], weights, weights ) - for method in con_methods: - method.accumulate(con_idx, csd) + for this_method in con_methods: + this_method.accumulate(con_idx, csd) else: # mode == 'cwt_morlet' # reminder to add alternative TFR methods for i in range(0, n_con_signals, block_size): n_extra = max(0, i + block_size - n_con_signals) @@ -578,9 +629,9 @@ def _epoch_spectral_connectivity( # this codes can be very slow csd = x_t[idx_map[0][con_idx]] * x_t[idx_map[1][con_idx]].conjugate() - for method in con_methods: - method.accumulate(con_idx, csd) - # future estimator types need to be explicitly handled here + for this_method in con_methods: + this_method.accumulate(con_idx, csd) + # future estimator types need to be explicitly handled here return con_methods, psd @@ -727,13 +778,14 @@ def spectral_connectivity_epochs( Parameters ---------- - data : array-like, shape=(n_epochs, n_signals, n_times) | Epochs | ~mne.time_frequency.EpochsSpectrum + data : array-like, shape=(n_epochs, n_signals, n_times) | ~mne.Epochs | ~mne.time_frequency.EpochsSpectrum | ~mne.time_frequency.EpochsTFR The data from which to compute connectivity. Can be epoched timeseries data as an :term:`array-like` or :class:`~mne.Epochs` object, or Fourier coefficients - for each epoch as an :class:`~mne.time_frequency.EpochsSpectrum` object. If - timeseries data, the spectral information will be computed according to the - spectral estimation mode (see the ``mode`` parameter). If an - :class:`~mne.time_frequency.EpochsSpectrum` object, this spectral information + for each epoch as an :class:`~mne.time_frequency.EpochsSpectrum` or + :class:`~mne.time_frequency.EpochsTFR` object. If timeseries data, the spectral + information will be computed according to the spectral estimation mode (see the + ``mode`` parameter). If an :class:`~mne.time_frequency.EpochsSpectrum` or + :class:`~mne.time_frequency.EpochsTFR` object, existing spectral information will be used and the ``mode`` parameter will be ignored. Note that it is also possible to combine multiple timeseries signals by @@ -748,8 +800,11 @@ def spectral_connectivity_epochs( .. versionchanged:: 0.8 Fourier coefficients stored in an :class:`~mne.time_frequency.EpochsSpectrum` - or :class:`~mne.time_frequency.EpochsSpectrumArray` object can also be passed - in as data. Storing Fourier coefficients requires ``mne >= 1.8``. + or :class:`~mne.time_frequency.EpochsTFR` object can also be passed in as + data. Storing Fourier coefficients in + :class:`~mne.time_frequency.EpochsSpectrum` objects requires ``mne >= 1.8``. + Storing multitaper weights in :class:`~mne.time_frequency.EpochsTFR` objects + requires ``mne >= 1.10``. %(names)s method : str | list of str Connectivity measure(s) to compute. These can be ``['coh', 'cohy', @@ -789,7 +844,8 @@ def spectral_connectivity_epochs( mode : str Spectrum estimation mode can be either: 'multitaper', 'fourier', or 'cwt_morlet'. Ignored if ``data`` is an - :class:`~mne.time_frequency.EpochsSpectrum` object. + :class:`~mne.time_frequency.EpochsSpectrum` or + :class:`~mne.time_frequency.EpochsTFR` object. fmin : float | tuple of float | None The lower frequency of interest. Multiple bands are defined using a tuple, e.g., (8., 20.) for two bands with 8 Hz and 20 Hz lower freq. @@ -821,24 +877,27 @@ def spectral_connectivity_epochs( mt_bandwidth : float | None The bandwidth of the multitaper windowing function in Hz. Only used in 'multitaper' mode. Ignored if ``data`` is an - :class:`~mne.time_frequency.EpochsSpectrum` object. + :class:`~mne.time_frequency.EpochsSpectrum` or + :class:`~mne.time_frequency.EpochsTFR` object. mt_adaptive : bool Use adaptive weights to combine the tapered spectra into PSD. Only used in 'multitaper' mode. Ignored if ``data`` is an - :class:`~mne.time_frequency.EpochsSpectrum` object. + :class:`~mne.time_frequency.EpochsSpectrum` or + :class:`~mne.time_frequency.EpochsTFR` object. mt_low_bias : bool Only use tapers with more than 90 percent spectral concentration within bandwidth. Only used in 'multitaper' mode. Ignored if ``data`` is an - :class:`~mne.time_frequency.EpochsSpectrum` object. + :class:`~mne.time_frequency.EpochsSpectrum` or + :class:`~mne.time_frequency.EpochsTFR` object. cwt_freqs : array Array of frequencies of interest. Only used in 'cwt_morlet' mode. Only the frequencies within the range specified by ``fmin`` and ``fmax`` are - used. Ignored if ``data`` is an - :class:`~mne.time_frequency.EpochsSpectrum` object. + used. Ignored if ``data`` is an :class:`~mne.time_frequency.EpochsSpectrum` or + :class:`~mne.time_frequency.EpochsTFR` object. cwt_n_cycles : float | array of float Number of cycles. Fixed number or one per frequency. Only used in 'cwt_morlet' - mode. Ignored if ``data`` is an :class:`~mne.time_frequency.EpochsSpectrum` - object. + mode. Ignored if ``data`` is an :class:`~mne.time_frequency.EpochsSpectrum` or + :class:`~mne.time_frequency.EpochsTFR` object. gc_n_lags : int Number of lags to use for the vector autoregressive model when computing Granger causality. Higher values increase computational cost, @@ -1110,7 +1169,8 @@ def spectral_connectivity_epochs( weights = None metadata = None spectrum_computed = False - if isinstance(data, BaseEpochs | EpochsSpectrum | EpochsSpectrumArray): + is_tfr_con = False + if isinstance(data, BaseEpochs | EpochsSpectrum | EpochsTFR): names = data.ch_names sfreq = data.info["sfreq"] @@ -1131,28 +1191,50 @@ def spectral_connectivity_epochs( data.add_annotations_to_metadata(overwrite=True) metadata = data.metadata - if isinstance(data, EpochsSpectrum | EpochsSpectrumArray): - # XXX: Will need to be updated if new Spectrum methods are added + if isinstance(data, EpochsSpectrum | EpochsTFR): + # XXX: Will need to be updated if new Spectrum/TFR methods are added if not np.iscomplexobj(data.get_data()): raise TypeError( - "if `data` is an EpochsSpectrum object, it must contain " - "complex-valued Fourier coefficients, such as that returned from " - "Epochs.compute_psd(output='complex')" + "if `data` is an EpochsSpectrum or EpochsTFR object, it must " + "contain complex-valued Fourier coefficients, such as that " + "returned from Epochs.compute_psd/tfr() with `output='complex'`" ) if "segment" in data._dims: raise ValueError( "`data` cannot contain Fourier coefficients for individual segments" ) - if isinstance(data, EpochsSpectrum): # mode can be read mode from Spectrum - mode = data.method - mode = "fourier" if mode == "welch" else mode - else: # spectral method is "unknown", so take mode from data dimensions - # Currently, actual mode doesn't matter as long as we handle tapers and - # their weights in the same way as for multitaper spectra - mode = "multitaper" if "taper" in data._dims else "fourier" + mode = data.method + if isinstance(data, EpochsSpectrum): + if isinstance(data, EpochsSpectrumArray): # infer mode from dimensions + # Currently, actual mode doesn't matter as long as we handle tapers + # and their weights in the same way as for multitaper spectra + mode = "multitaper" if "taper" in data._dims else "fourier" + else: # read mode from object + mode = "fourier" if mode == "welch" else mode + else: + if isinstance(data, EpochsTFRArray): # infer mode from dimensions + # Currently, actual mode doesn't matter as long as we handle tapers + # and their weights in the same way as for multitaper spectra + mode = "multitaper" if "taper" in data._dims else "morlet" + else: + mode = "cwt_morlet" if mode == "morlet" else mode + is_tfr_con = True + times_in = data.times spectrum_computed = True freqs = data.freqs - weights = data.weights + # Extract weights from the EpochsSpectrum/TFR object + if not hasattr(data, "weights") or ( + data.weights is None and mode == "multitaper" + ): + # XXX: Remove logic when support for mne<1.10 is dropped + raise AttributeError( + "weights are required for multitaper coefficients stored in " + "EpochsSpectrum (requires mne >= 1.8) and EpochsTFR (requires " + "mne >= 1.10) objects; objects saved from older versions of mne " + "will need to be recomputed." + ) + if hasattr(data, "weights"): + weights = data.weights else: times_in = data.times # input times for Epochs input type elif sfreq is None: @@ -1222,7 +1304,7 @@ def spectral_connectivity_epochs( # get the window function, wavelets, etc for different modes if not spectrum_computed: - spectral_params, mt_adaptive, n_times_spectrum, n_tapers = ( + spectral_params, mt_adaptive, n_times_spectrum, n_tapers, is_tfr_con = ( _assemble_spectral_params( mode=mode, n_times=n_times, @@ -1240,8 +1322,8 @@ def spectral_connectivity_epochs( spectral_params = dict( eigvals=None, window_fun=None, wavelets=None, weights=weights ) - n_times_spectrum = 0 - n_tapers = None if weights is None else weights.size + n_times_spectrum = n_times # 0 if no times + n_tapers = None if weights is None else weights.shape[0] # unique signals for which we actually need to compute PSD etc. if multivariate_con: @@ -1294,7 +1376,7 @@ def spectral_connectivity_epochs( logger.info(f" the following metrics will be computed: {metrics_str}") # check dimensions and time scale - if not spectrum_computed: # XXX: Can we assume upstream checks sufficient? + if not spectrum_computed: for this_epoch in epoch_block: _, _, _, warn_times = _get_and_verify_data_sizes( this_epoch, @@ -1327,6 +1409,7 @@ def spectral_connectivity_epochs( gc_n_lags=gc_n_lags, n_components=n_components, spectrum_computed=spectrum_computed, + is_tfr_con=is_tfr_con, accumulate_inplace=True if n_jobs == 1 else False, ) call_params.update(**spectral_params) @@ -1474,7 +1557,11 @@ def spectral_connectivity_epochs( freqs=freqs, method=_method, n_nodes=n_nodes, - spec_method=mode if not isinstance(data, BaseSpectrum) else data.method, + spec_method=( + mode + if not isinstance(data, EpochsSpectrum | EpochsTFR) + else data.method + ), indices=indices, n_epochs_used=n_epochs, freqs_used=freqs_used, @@ -1489,10 +1576,9 @@ def spectral_connectivity_epochs( if n_components and _method in _multicomp_methods: kwargs.update(components=np.arange(n_components) + 1) # create the connectivity container - if mode in ["multitaper", "fourier"]: + if not is_tfr_con: klass = SpectralConnectivity else: - assert mode == "cwt_morlet" klass = SpectroTemporalConnectivity kwargs.update(times=times) conn_list.append(klass(**kwargs)) diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index 44e91213..4257bd40 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -15,8 +15,9 @@ import numpy as np from mne.epochs import BaseEpochs from mne.parallel import parallel_func -from mne.time_frequency import EpochsSpectrum, EpochsSpectrumArray +from mne.time_frequency import EpochsSpectrum, EpochsTFR from mne.time_frequency.multitaper import _psd_from_mt +from mne.time_frequency.tfr import _tfr_from_mt from mne.utils import ProgressBar, _validate_type, logger @@ -32,7 +33,7 @@ def _check_rank_input(rank, data, indices): if "copy" in inspect.getfullargspec(data.get_data).kwonlyargs: kwargs["copy"] = False data_arr = data.get_data(**kwargs) - elif isinstance(data, EpochsSpectrum | EpochsSpectrumArray): + elif isinstance(data, EpochsSpectrum): # Spectrum objs will drop bad channels, so specify picking all channels data_arr = data.get_data(picks=np.arange(data.info["nchan"])) # Convert to power (and aggregate over tapers) before computing rank @@ -40,6 +41,16 @@ def _check_rank_input(rank, data, indices): data_arr = _psd_from_mt(data_arr, data.weights) else: data_arr = (data_arr * data_arr.conj()).real + elif isinstance(data, EpochsTFR): + # TFR objs will drop bad channels, so specify picking all channels + data_arr = data.get_data(picks=np.arange(data.info["nchan"])) + # Convert to power and aggregate over time before computing rank + if "taper" in data._dims: + data_arr = np.sum( + [_tfr_from_mt(epoch, data.weights) for epoch in data_arr], axis=-1 + ) + else: + data_arr = np.sum((data_arr * data_arr.conj()).real, axis=-1) else: data_arr = data diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index d13cbea4..bb5abc26 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -12,6 +12,7 @@ from mne_connectivity import ( SpectralConnectivity, + SpectroTemporalConnectivity, make_signals_in_freq_bands, read_connectivity, seed_target_indices, @@ -299,9 +300,9 @@ def test_spectral_connectivity(method, mode): ), con.get_data()[1, 0, gidx[0] : gidx[1]].min() # we see something for zero-lag assert_array_less(con.get_data(output="dense")[1, 0, : bidx[0]], lower_t) - assert np.all( - con.get_data(output="dense")[1, 0, bidx[1] :] < lower_t - ), con.get_data()[1, 0, bidx[1:]].max() + assert np.all(con.get_data(output="dense")[1, 0, bidx[1] :] < lower_t), ( + con.get_data()[1, 0, bidx[1:]].max() + ) elif method == "cohy": # imaginary coh will be zero check = np.imag(con.get_data(output="dense")[1, 0, gidx[0] : gidx[1]]) @@ -323,9 +324,9 @@ def test_spectral_connectivity(method, mode): ) assert_array_less(con.get_data(output="dense")[1, 0, : bidx[0]], lower_t) assert_array_less(con.get_data(output="dense")[1, 0, bidx[1] :], lower_t) - assert np.all( - con.get_data(output="dense")[1, 0, bidx[1] :] < lower_t - ), con.get_data()[1, 0, bidx[1] :].max() + assert np.all(con.get_data(output="dense")[1, 0, bidx[1] :] < lower_t), ( + con.get_data()[1, 0, bidx[1] :].max() + ) # compute a subset of connections using indices and 2 jobs indices = (np.array([2, 1]), np.array([0, 0])) @@ -469,17 +470,28 @@ def test_spectral_connectivity(method, mode): assert out_lens[0] == 10 -# Fourier coeffs in Spectrum objects added in MNE v1.8.0 @pytest.mark.skipif( - not check_version("mne", "1.8"), reason="Requires MNE v1.8.0 or higher" -) + not check_version("mne", "1.10"), reason="Requires MNE v1.10.0 or higher" +) # Taper weights in TFR objects added in MNE v1.10.0 @pytest.mark.parametrize("method", ["coh", "cacoh"]) -@pytest.mark.parametrize("mode", ["multitaper", "fourier"]) -def test_spectral_connectivity_epochs_spectrum_input(method, mode): - """Test spec_conn_epochs works with EpochsSpectrum data as input. +@pytest.mark.parametrize( + "mode, spectra_as_tfr", + [ + ("multitaper", False), # test multitaper in normal... + ("multitaper", True), # ... and TFR mode + ("fourier", False), + ("cwt_morlet", True), + ], +) +def test_spectral_connectivity_epochs_spectrum_tfr_input(method, mode, spectra_as_tfr): + """Test spec_conn_epochs works with EpochsSpectrum/TFR data as input. Important to test both bivariate and multivariate methods, as the latter involves additional steps (e.g., rank computation). + + Since spec_conn_epochs doesn't have a way to compute multitaper TFR from timeseries + data, we can't compare the results, but we can check that the connectivity values + are in an expected range. """ # Simulation parameters & data generation sfreq = 100.0 # Hz @@ -489,7 +501,7 @@ def test_spectral_connectivity_epochs_spectrum_input(method, mode): n_epochs = 30 n_times = 200 # samples trans_bandwidth = 1.0 # Hz - delay = 10 # samples + delay = 5 # samples data = make_signals_in_freq_bands( n_seeds=n_seeds, @@ -499,7 +511,7 @@ def test_spectral_connectivity_epochs_spectrum_input(method, mode): n_times=n_times, sfreq=sfreq, trans_bandwidth=trans_bandwidth, - snr=0.5, + snr=0.7, connection_delay=delay, rng_seed=44, ) @@ -511,34 +523,56 @@ def test_spectral_connectivity_epochs_spectrum_input(method, mode): else: indices = ([np.arange(n_seeds)], [np.arange(n_targets) + n_seeds]) - # Compute Fourier coefficients + # Compute spectral coefficients + tfr_freqs = np.arange(10, 50) # similar to Fourier & multitaper modes kwargs = dict() if mode == "fourier": kwargs.update(window="hann") # default is Hamming, but we need Hanning - coeffs = data.compute_psd( - method="welch" if mode == "fourier" else mode, output="complex", **kwargs - ) + spec_mode = "welch" + elif mode == "cwt_morlet": + kwargs.update(freqs=tfr_freqs) + spec_mode = "morlet" + else: # multitaper + if spectra_as_tfr: + kwargs.update(freqs=tfr_freqs) + spec_mode = mode + compute_coeffs_method = data.compute_tfr if spectra_as_tfr else data.compute_psd + coeffs = compute_coeffs_method(method=spec_mode, output="complex", **kwargs) # Compute connectivity con = spectral_connectivity_epochs(data=coeffs, method=method, indices=indices) - # Check connectivity from Epochs and Spectrum are equivalent; - # Works for multitaper, but Welch of Spectrum and Fourier of spec_conn are slightly - # off (max. abs. diff. ~0.006) even when what should be identical settings are used - con_from_epochs = spectral_connectivity_epochs( - data=data, method=method, indices=indices, mode=mode - ) - if mode == "multitaper": - atol = 0 + # Check connectivity classes are correct and that freqs/times match input data + if spectra_as_tfr: + assert isinstance(con, SpectroTemporalConnectivity), "wrong class type" + assert np.all(con.times == coeffs.times), "times do not match input data" else: - atol = 7e-3 - # spec_conn_epochs excludes freqs without at least 5 cycles, but not Spectrum - fstart = con.freqs.index(con_from_epochs.freqs[0]) - assert_allclose( - np.abs(con.get_data()[:, fstart:]), - np.abs(con_from_epochs.get_data()), - atol=atol, - ) + assert isinstance(con, SpectralConnectivity), "wrong class type" + assert np.all(con.freqs == coeffs.freqs), "freqs do not match input data" + + # Check connectivity from Epochs and EpochsSpectrum/TFR are equivalent + if mode == "multitaper" and spectra_as_tfr: + pass # no multitaper TFR computation from timeseries in spec_conn_epochs + else: + con_from_epochs = spectral_connectivity_epochs( + data=data, method=method, indices=indices, mode=mode, cwt_freqs=tfr_freqs + ) + # Works for multitaper & Morlet, but Welch of Spectrum and Fourier of spec_conn + # are slightly off (max. abs. diff. ~0.006). This is due to the Spectrum object + # using scipy.signal.spectrogram to compute the coefficients, while spec_conn + # uses scipy.signal.rfft, which give slightly different outputs even with + # identical settings. + if mode == "fourier": + atol = 7e-3 + else: + atol = 0 + # spec_conn_epochs excludes freqs without at least 5 cycles, but not Spectrum + fstart = con.freqs.index(con_from_epochs.freqs[0]) + assert_allclose( + np.abs(con.get_data()[:, fstart:]), + np.abs(con_from_epochs.get_data()), + atol=atol, + ) # Check connectivity values are as expected freqs = np.array(con.freqs) @@ -546,25 +580,20 @@ def test_spectral_connectivity_epochs_spectrum_input(method, mode): freqs_noise = (freqs < fband[0] - trans_bandwidth * 2) | ( freqs > fband[1] + trans_bandwidth * 2 ) - - # nothing for CaCoh to optimise, so use same thresholds for CaCoh and Coh - if mode == "multitaper": # lower baseline for multitaper - con_thresh = (0.1, 0.3) - else: # higher baseline for Welch/Fourier - con_thresh = (0.2, 0.4) - + WEAK_CONN_OR_NOISE = 0.3 # conn values outside of simulated fband should be < this + STRONG_CONN = 0.6 # conn values inside simulated fband should be > this # check freqs of simulated interaction show strong connectivity - assert_array_less(con_thresh[1], np.abs(con.get_data()[:, freqs_con].mean())) + assert_array_less(STRONG_CONN, np.abs(con.get_data()[:, freqs_con].mean())) # check freqs of no simulated interaction (just noise) show weak connectivity - assert_array_less(np.abs(con.get_data()[:, freqs_noise].mean()), con_thresh[0]) + assert_array_less(np.abs(con.get_data()[:, freqs_noise].mean()), WEAK_CONN_OR_NOISE) # TODO: Add general test for error catching for spec_conn_epochs @pytest.mark.skipif( - not check_version("mne", "1.8"), reason="Requires MNE v1.8.0 or higher" -) -def test_spectral_connectivity_epochs_spectrum_input_error_catch(): - """Test spec_conn_epochs catches error with EpochsSpectrum data as input.""" + not check_version("mne", "1.10"), reason="Requires MNE v1.10.0 or higher" +) # Taper weights in TFR objects added in MNE v1.10.0 +def test_spectral_connectivity_epochs_spectrum_tfr_input_error_catch(): + """Test spec_conn_epochs catches errors with EpochsSpectrum/TFR data as input.""" # Generate data rng = np.random.default_rng(44) n_epochs, n_chans, n_times = (5, 2, 50) @@ -577,12 +606,25 @@ def test_spectral_connectivity_epochs_spectrum_input_error_catch(): with pytest.raises(TypeError, match="must contain complex-valued Fourier coeff"): spectrum = data.compute_psd(output="power") spectral_connectivity_epochs(data=spectrum) + with pytest.raises(TypeError, match="must contain complex-valued Fourier coeff"): + tfr = data.compute_tfr(method="morlet", freqs=np.arange(15, 20), output="power") + spectral_connectivity_epochs(data=tfr) # Test unaggregated segments caught with pytest.raises(ValueError, match=r"cannot contain Fourier coeff.*segments"): spectrum = data.compute_psd(method="welch", average=False, output="complex") spectral_connectivity_epochs(data=spectrum) + # Simulate missing weights attr in EpochsSpectrum/TFR object + spectrum = data.compute_psd(method="multitaper", output="complex") + with pytest.raises(AttributeError, match="weights are required for multitaper"): + spectrum_copy = spectrum.copy() + del spectrum_copy._weights + spectral_connectivity_epochs(data=spectrum_copy) + with pytest.raises(AttributeError, match="weights are required for multitaper"): + spectrum._weights = None + spectral_connectivity_epochs(data=spectrum) + _gc_marks = [] if platform.system() == "Darwin" and platform.processor() == "arm": @@ -1310,7 +1352,7 @@ def test_spectral_connectivity_time_delayed(): N.B.: the spectral_connectivity_time method seems to be more unstable than spectral_connectivity_epochs for GC estimation. Accordingly, we assess Granger scores only in the context of the noise-corrected TRGC metric, - where the true directionality of the connections seems to identified. + where the true directionality of the connections seems to be identified. """ mode = "multitaper" # stick with single mode in interest of time @@ -1684,6 +1726,89 @@ def test_multivar_spectral_connectivity_time_shapes( assert np.all(np.array(con.indices) == np.array(([[0, 1, 2]], [[3, 4, -1]]))) +@pytest.mark.skipif( + not check_version("mne", "1.10"), reason="Requires MNE v1.10.0 or higher" +) # Taper weights in TFR objects added in MNE v1.10.0 +@pytest.mark.parametrize("method", ["coh", "cacoh"]) +@pytest.mark.parametrize("mode", ["multitaper", "cwt_morlet"]) +def test_spectral_connectivity_time_tfr_input(method, mode): + """Test spec_conn_time works with EpochsTFR data as input. + + Important to test both bivariate and multivariate methods, as the latter involves + additional steps (e.g., rank computation). + """ + # Simulation parameters & data generation + n_seeds = 2 + n_targets = 2 + fband = (15, 20) # Hz + trans_bandwidth = 1.0 # Hz + + data = make_signals_in_freq_bands( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=fband, + n_epochs=30, + n_times=200, + sfreq=100, + trans_bandwidth=trans_bandwidth, + snr=0.7, + connection_delay=5, + rng_seed=44, + ) + + if method == "coh": + indices = seed_target_indices( + seeds=np.arange(n_seeds), targets=np.arange(n_targets) + n_seeds + ) + else: + indices = ([np.arange(n_seeds)], [np.arange(n_targets) + n_seeds]) + + # Compute TFR + freqs = np.arange(10, 50) + n_cycles = 5.0 # non-default value to avoid warning in spec_conn_time + mt_bandwidth = 4.0 + kwargs = dict() + if mode == "cwt_morlet": + kwargs.update(zero_mean=False) # default in spec_conn_time + spec_mode = "morlet" + else: + kwargs.update(time_bandwidth=mt_bandwidth) + spec_mode = mode + coeffs = data.compute_tfr( + method=spec_mode, freqs=freqs, n_cycles=n_cycles, output="complex", **kwargs + ) + + # Compute connectivity + con_kwargs = dict( + method=method, + indices=indices, + mode=mode, + freqs=freqs, + n_cycles=n_cycles, + mt_bandwidth=mt_bandwidth, + average=True, + ) + con = spectral_connectivity_time(data=coeffs, **con_kwargs) + + # Check connectivity from Epochs and EpochsTFR are equivalent (small but non-zero + # tolerance given due to some platform-dependent variation) + con_from_epochs = spectral_connectivity_time(data=data, **con_kwargs) + assert_allclose( + np.abs(con.get_data()), np.abs(con_from_epochs.get_data()), atol=1e-7 + ) + + # Check connectivity values are as expected + freqs_con = (freqs >= fband[0]) & (freqs <= fband[1]) + freqs_noise = (freqs < fband[0] - trans_bandwidth * 2) | ( + freqs > fband[1] + trans_bandwidth * 2 + ) + # check freqs of simulated interaction show strong connectivity + assert_array_less(0.6, np.abs(con.get_data()[:, freqs_con].mean())) + # check freqs of no simulated interaction (just noise) show weak connectivity + assert_array_less(np.abs(con.get_data()[:, freqs_noise].mean()), 0.3) + + +# TODO: Add general test for error catching for spec_conn_time @pytest.mark.parametrize("method", ["cacoh", "mic", "mim", _gc, _gc_tr]) @pytest.mark.parametrize("mode", ["multitaper", "cwt_morlet"]) def test_multivar_spectral_connectivity_time_error_catch(method, mode): @@ -1705,7 +1830,7 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): freqs = np.arange(10, 25 + 1) # test type-checking of data - with pytest.raises(TypeError, match="must be an instance of Epochs or a NumPy arr"): + with pytest.raises(TypeError, match="Epochs, EpochsTFR, or a NumPy arr"): spectral_connectivity_time(data="foo", freqs=freqs) # check bad indices without nested array caught @@ -1822,6 +1947,40 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): ) +@pytest.mark.skipif( + not check_version("mne", "1.10"), reason="Requires MNE v1.10.0 or higher" +) # Taper weights in TFR objects added in MNE v1.10.0 +def test_spectral_connectivity_time_tfr_input_error_catch(): + """Test spec_conn_time catches errors with EpochsTFR data as input.""" + # Generate data + rng = np.random.default_rng(44) + n_epochs, n_chans, n_times = (5, 2, 100) + sfreq = 50 + data = rng.random((n_epochs, n_chans, n_times)) + info = create_info(ch_names=n_chans, sfreq=sfreq, ch_types="eeg") + data = EpochsArray(data=data, info=info) + freqs = np.arange(10, 20) + + # Test not Fourier coefficients caught + with pytest.raises(TypeError, match="must contain complex-valued Fourier coeff"): + tfr = data.compute_tfr(method="morlet", freqs=freqs, output="power") + spectral_connectivity_time(data=tfr, freqs=freqs) + + # Simulate missing weights attr in EpochsTFR object + tfr = data.compute_tfr(method="multitaper", output="complex", freqs=freqs) + with pytest.raises(AttributeError, match="weights are required for multitaper"): + tfr_copy = tfr.copy() + del tfr_copy._weights + spectral_connectivity_time(data=tfr_copy) + with pytest.raises(AttributeError, match="weights are required for multitaper"): + tfr._weights = None + spectral_connectivity_time(data=tfr) + + # Test no freqs caught for non-TFR input + with pytest.raises(TypeError, match="`freqs` must be specified"): + spectral_connectivity_time(data=data) + + def test_save(tmp_path): """Test saving results of spectral connectivity.""" epochs = make_signals_in_freq_bands( diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index a8bcdfb6..1876fc83 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -10,8 +10,14 @@ import xarray as xr from mne.epochs import BaseEpochs from mne.parallel import parallel_func -from mne.time_frequency import dpss_windows, tfr_array_morlet, tfr_array_multitaper -from mne.utils import _validate_type, logger, verbose +from mne.time_frequency import ( + EpochsTFR, + EpochsTFRArray, + dpss_windows, + tfr_array_morlet, + tfr_array_multitaper, +) +from mne.utils import _check_option, _validate_type, logger, verbose from ..base import EpochSpectralConnectivity, SpectralConnectivity from ..utils import _check_multivariate_indices, check_indices, fill_doc @@ -32,7 +38,7 @@ @fill_doc def spectral_connectivity_time( data, - freqs, + freqs=None, method="coh", average=False, indices=None, @@ -66,12 +72,26 @@ def spectral_connectivity_time( Parameters ---------- - data : array_like, shape (n_epochs, n_signals, n_times) | Epochs - The data from which to compute connectivity. - freqs : array_like - Array of frequencies of interest for time-frequency decomposition. - Only the frequencies within the range specified by ``fmin`` and - ``fmax`` are used. + data : array_like, shape (n_epochs, n_signals, n_times) | ~mne.Epochs | ~mne.time_frequency.EpochsTFR + The data from which to compute connectivity. Can be epoched timeseries data as + an :term:`array-like` or :class:`~mne.Epochs` object, or Fourier coefficients + for each epoch as an :class:`~mne.time_frequency.EpochsTFR` object. If + timeseries data, the spectral information will be computed according to the + spectral estimation mode (see the ``mode`` parameter). If an + :class:`~mne.time_frequency.EpochsTFR` object, existing spectral information + will be used and the ``mode`` parameter will be ignored. + + .. versionchanged:: 0.8 + Fourier coefficients stored in an :class:`~mne.time_frequency.EpochsTFR` + object can also be passed in as data. Storing multitaper weights in + :class:`~mne.time_frequency.EpochsTFR` objects requires ``mne >= 1.10``. + freqs : array_like | None + Array of frequencies of interest for time-frequency decomposition. Only the + frequencies within the range specified by ``fmin`` and ``fmax`` are used. If + ``data`` is an :term:`array-like` or :class:`~mne.Epochs` object, the + frequencies must be specified. If ``data`` is an + :class:`~mne.time_frequency.EpochsTFR` object, ``data.freqs`` is used and this + parameter is ignored. method : str | list of str Connectivity measure(s) to compute. These can be ``['coh', 'cacoh', 'mic', 'mim', 'plv', 'ciplv', 'pli', 'wpli', 'gc', 'gc_tr']``. These @@ -104,8 +124,8 @@ def spectral_connectivity_time( connections between all channels are computed, unless a Granger causality method is called, in which case an error is raised. sfreq : float - The sampling frequency. Required if data is not - :class:`Epochs `. + The sampling frequency. Required if ``data`` is not an :class:`~mne.Epochs` or + :class:`~mne.time_frequency.EpochsTFR` object. fmin : float | tuple of float | None The lower frequency of interest. Multiple bands are defined using a tuple, e.g., ``(8., 20.)`` for two bands with 8 Hz and 20 Hz lower @@ -133,21 +153,24 @@ def spectral_connectivity_time( Amount of time to consider as padding at the beginning and end of each epoch in seconds. See Notes for more information. mode : str - Time-frequency decomposition method. Can be either: 'multitaper', or - 'cwt_morlet'. See :func:`mne.time_frequency.tfr_array_multitaper` and - :func:`mne.time_frequency.tfr_array_morlet` for reference. + Time-frequency decomposition method. Can be either: ``'multitaper'``, or + ``'cwt_morlet'``. See :func:`mne.time_frequency.tfr_array_multitaper` and + :func:`mne.time_frequency.tfr_array_morlet` for reference. Ignored if ``data`` + is an :class:`~mne.time_frequency.EpochsTFR` object. mt_bandwidth : float | None - Product between the temporal window length (in seconds) and the full - frequency bandwidth (in Hz). This product can be seen as the surface - of the window on the time/frequency plane and controls the frequency - bandwidth (thus the frequency resolution) and the number of good - tapers. See :func:`mne.time_frequency.tfr_array_multitaper` - documentation. + Product between the temporal window length (in seconds) and the full frequency + bandwidth (in Hz). This product can be seen as the surface of the window on the + time/frequency plane and controls the frequency bandwidth (thus the frequency + resolution) and the number of good tapers. See + :func:`mne.time_frequency.tfr_array_multitaper` documentation. Ignored if + ``data`` is an :class:`~mne.time_frequency.EpochsTFR` object. n_cycles : float | array_like of float - Number of cycles in the wavelet, either a fixed number or one per - frequency. The number of cycles ``n_cycles`` and the frequencies of - interest ``cwt_freqs`` define the temporal window length. For details, - see :func:`mne.time_frequency.tfr_array_morlet` documentation. + Number of cycles in the wavelet, either a fixed number or one per frequency. The + number of cycles ``n_cycles`` and the frequencies of interest ``freqs`` define + the temporal window length. For details, see + :func:`mne.time_frequency.tfr_array_multitaper` and + :func:`mne.time_frequency.tfr_array_morlet` documentation. Ignored if ``data`` + is an :class:`~mne.time_frequency.EpochsTFR` object. gc_n_lags : int Number of lags to use for the vector autoregressive model when computing Granger causality. Higher values increase computational cost, @@ -183,10 +206,10 @@ def spectral_connectivity_time( instances corresponding to connectivity measures if several connectivity measures are specified. The shape of each connectivity dataset is ([n_epochs,] n_cons, [n_comps,] n_freqs). ``n_comps`` is present for valid multivariate - methods if ``n_components > 1``.When "indices" is None and a bivariate method is - called, "n_cons = n_signals ** 2", or if a multivariate method is called "n_cons - = 1". When "indices" is specified, "n_con = len(indices[0])" for bivariate and - multivariate methods. + methods if ``n_components > 1``. When "indices" is None and a bivariate method + is called, "n_cons = n_signals ** 2", or if a multivariate method is called + "n_cons = 1". When "indices" is specified, "n_con = len(indices[0])" for + bivariate and multivariate methods. See Also -------- @@ -369,12 +392,23 @@ def spectral_connectivity_time( References ---------- .. footbibliography:: - """ + """ # noqa: E501 events = None event_id = None # extract data from Epochs object - _validate_type(data, (np.ndarray, BaseEpochs), "`data`", "Epochs or a NumPy array") - if isinstance(data, BaseEpochs): + _validate_type( + data, + (np.ndarray, BaseEpochs, EpochsTFR), + "`data`", + "Epochs, EpochsTFR, or a NumPy array", + ) + if not isinstance(data, EpochsTFR) and freqs is None: + raise TypeError( + "`freqs` must be specified when `data` is not an EpochsTFR object" + ) + weights = None + spectrum_computed = False + if isinstance(data, BaseEpochs | EpochsTFR): names = data.ch_names sfreq = data.info["sfreq"] events = data.events @@ -392,12 +426,42 @@ def spectral_connectivity_time( if hasattr(data, "annotations") and not annots_in_metadata: data.add_annotations_to_metadata(overwrite=True) metadata = data.metadata - # XXX: remove logic once support for mne<1.6 is dropped - kwargs = dict() - if "copy" in inspect.getfullargspec(data.get_data).kwonlyargs: - kwargs["copy"] = False - data = data.get_data(**kwargs) - n_epochs, n_signals, n_times = data.shape + if isinstance(data, BaseEpochs): + # XXX: remove logic once support for mne<1.6 is dropped + kwargs = dict() + if "copy" in inspect.getfullargspec(data.get_data).kwonlyargs: + kwargs["copy"] = False + data = data.get_data(**kwargs) + n_epochs, n_signals, n_times = data.shape + else: + freqs = data.freqs # use freqs from EpochsTFR object + if isinstance(data, EpochsTFRArray): # infer mode from dimensions + mode = "multitaper" if "taper" in data._dims else "cwt_morlet" + else: # read mode from object + mode = "cwt_morlet" if data.method == "morlet" else data.method + # Extract weights from the EpochsTFR object + if not hasattr(data, "weights") or ( + data.weights is None and mode == "multitaper" + ): + # XXX: Remove logic when support for mne<1.10 is dropped + raise AttributeError( + "weights are required for multitaper coefficients stored in " + "EpochsTFR objects (requires mne >= 1.10); objects saved from " + "older versions of mne will need to be recomputed." + ) + if hasattr(data, "weights"): + weights = data.weights + # TFR objs will drop bad channels, so specify picking all channels + data = data.get_data(picks=np.arange(data.info["nchan"])) + if not np.iscomplexobj(data): + raise TypeError( + "if `data` is an EpochsTFR object, it must contain complex-valued " + "Fourier coefficients, such as that returned from " + "Epochs.compute_tfr() with `output='complex'`" + ) + n_epochs, n_signals = data.shape[:2] + n_times = data.shape[-1] + spectrum_computed = True else: data = np.asarray(data) n_epochs, n_signals, n_times = data.shape @@ -520,8 +584,7 @@ def spectral_connectivity_time( n_components = _check_n_components_input(n_components, rank) if n_components == 1: # n_components=0 means space for a components dimension is not allocated in - # the results, similar to how n_times_spectrum=0 is used to indicate that - # time is not a dimension in the results + # the results n_components = 0 else: rank = None @@ -620,13 +683,15 @@ def spectral_connectivity_time( padding=padding, kw_cwt={}, kw_mt={}, + weights=weights, + multivariate_con=multivariate_con, + spectrum_computed=spectrum_computed, n_jobs=n_jobs, verbose=verbose, - multivariate_con=multivariate_con, ) for epoch_idx in np.arange(n_epochs): - logger.info(f" Processing epoch {epoch_idx+1} / {n_epochs} ...") + logger.info(f" Processing epoch {epoch_idx + 1} / {n_epochs} ...") scores, patterns = _spectral_connectivity(data[epoch_idx], **call_params) for m in method: conn[m][epoch_idx] = np.stack(scores[m], axis=0) @@ -713,16 +778,18 @@ def _spectral_connectivity( padding, kw_cwt, kw_mt, + weights, + multivariate_con, + spectrum_computed, n_jobs, verbose, - multivariate_con, ): """Estimate time-resolved connectivity for one epoch. Parameters ---------- - data : array_like, shape (n_channels, n_times) - Time-series data. + data : array_like, shape (channels, [freqs,] [tapers,] times) + Time-series data or time-frequency data. method : list of str List of connectivity metrics to compute. kernel : array_like, shape (n_sm_fres, n_sm_times) @@ -764,8 +831,12 @@ def _spectral_connectivity( padding : float Amount of time to consider as padding at the beginning and end of each epoch in seconds. + weights : array, shape (n_tapers, n_freqs) | None + Taper weights for multitaper spectral estimation. multivariate_con : bool Whether or not multivariate connectivity is to be computed. + spectrum_computed : bool + Whether or not the time-frequency decomposition has already been computed. Returns ------- @@ -774,7 +845,6 @@ def _spectral_connectivity( ``method``. Each element is an array of shape (n_cons, [n_comps], n_freqs) or (n_cons, [n_comps], n_fbands) if ``faverage`` is `True`. ``n_comps`` is present for valid multivariate methods if ``n_components > 0``. - patterns : dict Dictionary containing the connectivity patterns (for reconstructing the connectivity components in channel-space) corresponding to the metrics in @@ -784,51 +854,63 @@ def _spectral_connectivity( and target signals (respectively). ``n_comps`` is present for valid multivariate methods if ``n_components > 0``. """ - n_cons = len(source_idx) - data = np.expand_dims(data, axis=0) - kw_cwt.setdefault("zero_mean", False) # avoid FutureWarning - if mode == "cwt_morlet": - out = tfr_array_morlet( - data, - sfreq, - freqs, - n_cycles=n_cycles, - output="complex", - decim=decim, - n_jobs=n_jobs, - **kw_cwt, - ) - out = np.expand_dims(out, axis=2) # same dims with multitaper - weights = None - elif mode == "multitaper": - out = tfr_array_multitaper( - data, - sfreq, - freqs, - n_cycles=n_cycles, - time_bandwidth=mt_bandwidth, - output="complex", - decim=decim, - n_jobs=n_jobs, - **kw_mt, - ) - if isinstance(n_cycles, int | float): - n_cycles = [n_cycles] * len(freqs) - mt_bandwidth = mt_bandwidth if mt_bandwidth else 4 - n_tapers = int(np.floor(mt_bandwidth - 1)) - weights = np.zeros((n_tapers, len(freqs), out.shape[-1])) - for i, (f, n_c) in enumerate(zip(freqs, n_cycles)): - window_length = np.arange(0.0, n_c / float(f), 1.0 / sfreq).shape[0] - half_nbw = mt_bandwidth / 2.0 - n_tapers = int(np.floor(mt_bandwidth - 1)) - _, eigvals = dpss_windows(window_length, half_nbw, n_tapers, sym=False) - weights[:, i, :] = np.sqrt(eigvals[:, np.newaxis]) - # weights have shape (n_tapers, n_freqs, n_times) + # check that spectral mode is recognised + _check_option("mode", mode, ("cwt_morlet", "multitaper")) + + # compute time-frequency decomposition + mt_bandwidth = mt_bandwidth if mt_bandwidth else 4 + if not spectrum_computed: + data = np.expand_dims(data, axis=0) + kw_cwt.setdefault("zero_mean", False) # avoid FutureWarning + if mode == "cwt_morlet": + out = tfr_array_morlet( + data, + sfreq, + freqs, + n_cycles=n_cycles, + output="complex", + decim=decim, + n_jobs=n_jobs, + **kw_cwt, + ) + else: + out = tfr_array_multitaper( + data, + sfreq, + freqs, + n_cycles=n_cycles, + time_bandwidth=mt_bandwidth, + output="complex", + decim=decim, + n_jobs=n_jobs, + **kw_mt, + ) + out = np.squeeze(out, axis=0) else: - raise ValueError("Mode must be 'cwt_morlet' or 'multitaper'.") - - out = np.squeeze(out, axis=0) + out = data + # give tapers dim to cwt_morlet output + if mode == "cwt_morlet": + out = np.expand_dims(out, axis=1) + + # compute taper weights + if mode == "multitaper": + if not spectrum_computed: # compute from scratch + if isinstance(n_cycles, int | float): + n_cycles = [n_cycles] * len(freqs) + n_tapers = out.shape[-3] + n_times = out.shape[-1] + half_nbw = mt_bandwidth / 2.0 + weights = np.zeros((n_tapers, len(freqs), n_times)) + for i, (f, n_c) in enumerate(zip(freqs, n_cycles)): + window_length = np.arange(0.0, n_c / float(f), 1.0 / sfreq).shape[0] + _, eigvals = dpss_windows(window_length, half_nbw, n_tapers, sym=False) + weights[:, i, :] = np.sqrt(eigvals[:, np.newaxis]) + # weights have shape (n_tapers, n_freqs, n_times) + else: # add time dimension to existing weights + weights = np.repeat(weights[..., np.newaxis], out.shape[-1], axis=-1) + + # pad spectrum and weights if padding: if padding < 0: raise ValueError(f"Padding cannot be negative, got {padding}.") @@ -843,6 +925,7 @@ def _spectral_connectivity( # compute for each connectivity method scores = {} patterns = {} + n_cons = len(source_idx) conn = _parallel_con( out, method, @@ -929,7 +1012,7 @@ def _parallel_con( Number of pairs of signals. faverage : bool Average over frequency bands. - weights : array_like, shape (n_tapers, n_freqs, n_times) + weights : array_like, shape (n_tapers, n_freqs, n_times) | None Multitaper weights. multivariate_con : bool Whether or not multivariate connectivity is being computed. @@ -1112,7 +1195,6 @@ def _multivariate_con( connectivity method. Each element is an array with shape ([n_comps], n_freqs) or ([n_comps], n_fbands) depending on ``faverage``. ``n_comps`` is present for valid multivariate methods if ``n_components > 0``. - patterns : list List of connectivity patterns between seed and target signals for each connectivity method. Each element is an array of length 2 corresponding to the