Skip to content

Commit

Permalink
Begin add support for tapers in array objs
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Dec 9, 2024
1 parent 6a23556 commit a107991
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 7 deletions.
70 changes: 63 additions & 7 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,7 +1421,6 @@ def __setstate__(self, state):

defaults = dict(
method="unknown",
dims=("epoch", "channel", "freq", "time")[-state["data"].ndim :],
baseline=None,
decim=1,
data_type="TFR",
Expand All @@ -1445,7 +1444,7 @@ def __setstate__(self, state):
unknown_class = Epochs if "epoch" in self._dims else Evoked
inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Unknown=unknown_class)
self._inst_type = inst_types[defaults["inst_type_str"]]
# sanity check data/freqs/times/info agreement
# sanity check data/freqs/times/info/weights agreement
self._check_state()

def __repr__(self):
Expand Down Expand Up @@ -1498,14 +1497,26 @@ def _check_compatibility(self, other):
raise RuntimeError(msg.format(problem, extra))

def _check_state(self):
"""Check data/freqs/times/info agreement during __setstate__."""
"""Check data/freqs/times/info/weights agreement during __setstate__."""
msg = "{} axis of data ({}) doesn't match {} attribute ({})"
n_chan_info = len(self.info["chs"])
n_chan = self._data.shape[self._dims.index("channel")]
n_taper = (
self._data.shape[self._dims.index("taper")]
if "taper" in self._dims
else None
)
n_freq = self._data.shape[self._dims.index("freq")]
n_time = self._data.shape[self._dims.index("time")]
if n_chan_info != n_chan:
msg = msg.format("Channel", n_chan, "info", n_chan_info)
elif n_taper is not None:
if self._weights is None:
raise RuntimeError("Taper dimension in data, but no weights found.")
if n_taper != self._weights.shape[0]:
msg = msg.format("Taper", n_taper, "weights", self._weights.shape[0])
elif n_freq != self._weights.shape[1]:
msg = msg.format("Frequency", n_freq, "weights", self._weights.shape[1])
elif n_freq != len(self.freqs):
msg = msg.format("Frequency", n_freq, "freqs", self.freqs.size)
elif n_time != len(self.times):
Expand Down Expand Up @@ -2788,6 +2799,7 @@ class AverageTFR(BaseTFR):
%(nave_tfr_attr)s
%(sfreq_tfr_attr)s
%(shape_tfr_attr)s
%(weights_tfr_attr)s
See Also
--------
Expand Down Expand Up @@ -2904,6 +2916,10 @@ def __getstate__(self):

def __setstate__(self, state):
"""Unpack AverageTFR from serialized format."""
if state["data"].ndim != 3:
raise ValueError(f"RawTFR data should be 3D, got {state['data'].ndim}.")
# Set dims now since optional tapers makes it difficult to disentangle later
state["dims"] = ("channel", "freq", "time")
super().__setstate__(state)
self._comment = state.get("comment", "")
self._nave = state.get("nave", 1)
Expand Down Expand Up @@ -3059,6 +3075,7 @@ class EpochsTFR(BaseTFR, GetEpochsMixin):
%(selection_attr)s
%(sfreq_tfr_attr)s
%(shape_tfr_attr)s
%(weights_tfr_attr)s
See Also
--------
Expand Down Expand Up @@ -3143,8 +3160,15 @@ def __getstate__(self):

def __setstate__(self, state):
"""Unpack EpochsTFR from serialized format."""
if state["data"].ndim != 4:
raise ValueError(f"EpochsTFR data should be 4D, got {state['data'].ndim}.")
if state["data"].ndim not in [4, 5]:
raise ValueError(
f"EpochsTFR data should be 4D or 5D, got {state['data'].ndim}."
)
# Set dims now since optional tapers makes it difficult to disentangle later
state["dims"] = ("epoch", "channel")
if state["data"].ndim == 5:
state["dims"] += ("taper",)
state["dims"] += ("freq", "time")
super().__setstate__(state)
self._metadata = state.get("metadata", None)
n_epochs = self.shape[0]
Expand Down Expand Up @@ -3248,7 +3272,16 @@ def average(self, method="mean", *, dim="epochs", copy=False):
See discussion here:
https://github.com/scipy/scipy/pull/12676#issuecomment-783370228
Averaging is not supported for data containing a taper dimension.
"""
if "taper" in self._dims:
raise NotImplementedError(
"Averaging multitaper tapers across epochs, frequencies, or times is "
"not supported. If averaging across epochs, consider averaging the "
"epochs before computing the complex/phase spectrum."
)

_check_option("dim", dim, ("epochs", "freqs", "times"))
axis = self._dims.index(dim[:-1]) # self._dims entries aren't plural

Expand Down Expand Up @@ -3620,6 +3653,7 @@ class EpochsTFRArray(EpochsTFR):
%(selection)s
%(drop_log)s
%(metadata_epochstfr)s
%(weights_tfr_array)s
Attributes
----------
Expand All @@ -3636,6 +3670,7 @@ class EpochsTFRArray(EpochsTFR):
%(selection_attr)s
%(sfreq_tfr_attr)s
%(shape_tfr_attr)s
%(weights_tfr_attr)s
See Also
--------
Expand All @@ -3658,6 +3693,7 @@ def __init__(
selection=None,
drop_log=None,
metadata=None,
weights=None,
):
state = dict(info=info, data=data, times=times, freqs=freqs)
optional = dict(
Expand All @@ -3668,6 +3704,7 @@ def __init__(
selection=selection,
drop_log=drop_log,
metadata=metadata,
weights=weights,
)
for name, value in optional.items():
if value is not None:
Expand Down Expand Up @@ -3710,6 +3747,7 @@ class RawTFR(BaseTFR):
method : str
The method used to compute the spectra (``'morlet'``, ``'multitaper'``
or ``'stockwell'``).
%(weights_tfr_attr)s
See Also
--------
Expand Down Expand Up @@ -3759,6 +3797,19 @@ def __init__(
**method_kw,
)

def __setstate__(self, state):
"""Unpack RawTFR from serialized format."""
if state["data"].ndim not in [3, 4]:
raise ValueError(
f"RawTFR data should be 3D or 4D, got {state['data'].ndim}."
)
# Set dims now since optional tapers makes it difficult to disentangle later
state["dims"] = ("channel",)
if state["data"].ndim == 4:
state["dims"] += ("taper",)
state["dims"] += ("freq", "time")
super().__setstate__(state)

def __getitem__(self, item):
"""Get RawTFR data.
Expand Down Expand Up @@ -3824,6 +3875,7 @@ class RawTFRArray(RawTFR):
%(times)s
%(freqs_tfr_array)s
%(method_tfr_array)s
%(weights_tfr_array)s
Attributes
----------
Expand All @@ -3834,6 +3886,7 @@ class RawTFRArray(RawTFR):
%(method_tfr_attr)s
%(sfreq_tfr_attr)s
%(shape_tfr_attr)s
%(weights_tfr_attr)s
See Also
--------
Expand All @@ -3851,10 +3904,13 @@ def __init__(
freqs,
*,
method=None,
weights=None,
):
state = dict(info=info, data=data, times=times, freqs=freqs)
if method is not None:
state["method"] = method
optional = dict(method=method, weights=weights)
for name, value in optional.items():
if value is not None:
state[name] = value
self.__setstate__(state)


Expand Down
10 changes: 10 additions & 0 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5008,6 +5008,16 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
solution.
"""

docdict["weight_tfr_array"] = """
weights : array of shape (n_tapers, n_freqs) | None
The weights for each taper. Must be provided if ``data`` has a taper dimension, such
as for complex or phase multitaper data.
"""
docdict["weight_tfr_attr"] = """
weights : array of shape (n_tapers, n_freqs) | None
The weights for each taper, if present in the data.
"""

docdict["window_psd"] = """\
window : str | float | tuple
Windowing function to use. See :func:`scipy.signal.get_window`.
Expand Down

0 comments on commit a107991

Please sign in to comment.