From 8c167168e3bfbfd2a1f2394831bcf9c4c214d6f5 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 22 Oct 2024 11:40:56 +0200 Subject: [PATCH] Revert "Update docstrings" This reverts commit 82fc2f7fe450dee8445cb9b48993944336e2aedc. --- mne/time_frequency/multitaper.py | 5 ++-- mne/time_frequency/tfr.py | 50 +++++++++++++------------------- 2 files changed, 22 insertions(+), 33 deletions(-) diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index fc926af4863..f5f6f79a0b3 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -505,8 +505,7 @@ def tfr_array_multitaper( coherence across trials. return_weights : bool, default False - If True, return the taper weights. Only applies if ``output='complex'`` or - ``'phase'``. + If True, return the taper weights. Only applies if ``output="complex"``. .. versionadded:: 1.9.0 @@ -529,7 +528,7 @@ def tfr_array_multitaper( contain the average power and the imaginary values contain the inter-trial coherence: :math:`out = power_{avg} + i * ITC`. weights : array of shape (n_tapers, n_freqs) - The taper weights. Only returned if ``output='complex'`` or ``'phase'`` and + The taper weights. Only returned if ``output="complex"`` and ``return_weights=True``. See Also diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 91eaad159f5..908d25662e8 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -1214,6 +1214,9 @@ def __init__( f'{classname} got unsupported parameter value{_pl(problem)} ' f'{" and ".join(problem)}.' ) + # shim for tfr_array_morlet deprecation warning (TODO: remove after 1.7 release) + if method == "morlet": + method_kw.setdefault("zero_mean", True) # check method valid_methods = ["morlet", "multitaper"] if isinstance(inst, BaseEpochs): @@ -2693,12 +2696,9 @@ def to_data_frame( """ # check pandas once here, instead of in each private utils function pd = _check_pandas_installed() # noqa - # triage for Epoch-derived or unaggregated spectra - from_epo = isinstance(self, EpochsTFR) - unagg_mt = "taper" in self._dims # arg checking valid_index_args = ["time", "freq"] - if from_epo: + if isinstance(self, EpochsTFR): valid_index_args.extend(["epoch", "condition"]) valid_time_formats = ["ms", "timedelta"] index = _check_pandas_index_arguments(index, valid_index_args) @@ -2706,42 +2706,32 @@ def to_data_frame( # get data picks = _picks_to_idx(self.info, picks, "all", exclude=()) data, times, freqs = self.get_data(picks, return_times=True, return_freqs=True) - ch_axis = self._dims.index("channel") - if not from_epo: + axis = self._dims.index("channel") + if not isinstance(self, EpochsTFR): data = data[np.newaxis] # add singleton "epochs" axis - ch_axis += 1 - if not unagg_mt: - data = np.expand_dims(data, -3) # add singleton "tapers" axis - n_epochs, n_picks, n_tapers, n_freqs, n_times = data.shape - # reshape to (epochs*tapers*freqs*times) x signals - data = np.moveaxis(data, ch_axis, -1) - data = data.reshape(n_epochs * n_tapers * n_freqs * n_times, n_picks) + axis += 1 + n_epochs, n_picks, n_freqs, n_times = data.shape + # reshape to (epochs*freqs*times) x signals + data = np.moveaxis(data, axis, -1) + data = data.reshape(n_epochs * n_freqs * n_times, n_picks) # prepare extra columns / multiindex mindex = list() - default_index = list() times = _convert_times(times, time_format, self.info["meas_date"]) - times = np.tile(times, n_epochs * n_freqs * n_tapers) - freqs = np.tile(np.repeat(freqs, n_times * n_tapers), n_epochs) + times = np.tile(times, n_epochs * n_freqs) + freqs = np.tile(np.repeat(freqs, n_times), n_epochs) mindex.append(("time", times)) mindex.append(("freq", freqs)) - if from_epo: - mindex.append( - ("epoch", np.repeat(self.selection, n_times * n_freqs * n_tapers)) - ) + if isinstance(self, EpochsTFR): + mindex.append(("epoch", np.repeat(self.selection, n_times * n_freqs))) rev_event_id = {v: k for k, v in self.event_id.items()} conditions = [rev_event_id[k] for k in self.events[:, 2]] - mindex.append( - ("condition", np.repeat(conditions, n_times * n_freqs * n_tapers)) - ) - default_index.extend(["condition", "epoch"]) - default_index.extend(["freq", "time"]) - if unagg_mt: - name = "taper" - taper_nums = np.tile(np.arange(n_tapers), n_epochs * n_freqs * n_times) - mindex.append((name, taper_nums)) - default_index.append(name) + mindex.append(("condition", np.repeat(conditions, n_times * n_freqs))) assert all(len(mdx) == len(mindex[0]) for mdx in mindex[1:]) # build DataFrame + if isinstance(self, EpochsTFR): + default_index = ["condition", "epoch", "freq", "time"] + else: + default_index = ["freq", "time"] df = _build_data_frame( self, data, picks, long_format, mindex, index, default_index=default_index )