Skip to content

Commit

Permalink
Disallow aggregating tapers in combine_tfr
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Dec 10, 2024
1 parent aaef4b7 commit 999d122
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 2 deletions.
40 changes: 38 additions & 2 deletions mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ def test_average_tfr_init(full_evoked):

@pytest.mark.parametrize("inst", ("raw_tfr", "epochs_tfr", "average_tfr"))
def test_tfr_init_errors(inst, request, average_tfr):
"""Test __init__ for Raw/Epochs/AverageTFR."""
"""Test __init__ for {Raw,Epochs,Average}TFR."""
# Load data
inst = _get_inst(inst, request, average_tfr=average_tfr)
state = inst.__getstate__()
Expand Down Expand Up @@ -1587,7 +1587,7 @@ def test_epochstfr_iter_evoked(epochs_tfr, copy):

@pytest.mark.parametrize("inst", ("raw", "epochs", "evoked"))
def test_tfrarray_tapered_spectra(inst, evoked, request):
"""Test Raw/Epochs/AverageTFRArray instantiation with tapered spectra."""
"""Test {Raw,Epochs,Average}TFRArray instantiation with tapered spectra."""
# Load data object
inst = _get_inst(inst, request, evoked=evoked)
inst.pick("mag")
Expand Down Expand Up @@ -1802,3 +1802,39 @@ def test_tfr_plot_topomap(inst, ch_type, full_average_tfr, request):
assert re.match(
rf"Average over \d{{1,3}} {ch_type} channels\.", popup_fig.axes[0].get_title()
)


def test_combine_tfr_error_catch(request, average_tfr):
"""Test combine_tfr() catches errors."""
# check unrecognised weights string caught
with pytest.raises(ValueError, match='Weights must be .* "nave" or "equal"'):
combine_tfr([average_tfr, average_tfr], weights="foo")
# check bad weights size caught
with pytest.raises(ValueError, match="Weights must be the same size as all_tfr"):
combine_tfr([average_tfr, average_tfr], weights=[1, 1, 1])
# check different channel names caught
state = average_tfr.__getstate__()
new_info = average_tfr.info.copy()
average_tfr_bad = AverageTFR(
inst=state | dict(info=new_info.rename_channels({new_info.ch_names[0]: "foo"}))
)
with pytest.raises(AssertionError, match=".* do not contain the same channels"):
combine_tfr([average_tfr, average_tfr_bad])
# check different times caught
average_tfr_bad = AverageTFR(inst=state | dict(times=average_tfr.times + 1))
with pytest.raises(
AssertionError, match=".* do not contain the same time instants"
):
combine_tfr([average_tfr, average_tfr_bad])
# check taper dim caught
n_tapers = 3 # anything >= 1 should do
weights = np.ones((n_tapers, average_tfr.shape[1])) # tapers x freqs
state["data"] = np.repeat(np.expand_dims(average_tfr.data, 1), n_tapers, axis=1)
state["weights"] = weights
state["dims"] = ("channel", "taper", "freq", "time")
average_tfr_taper = AverageTFR(inst=state)
with pytest.raises(
NotImplementedError,
match="Aggregating multitaper tapers across TFR datasets is not supported.",
):
combine_tfr([average_tfr_taper, average_tfr_taper])
8 changes: 8 additions & 0 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3941,8 +3941,16 @@ def combine_tfr(all_tfr, weights="nave"):
Notes
-----
Aggregating multitaper TFR datasets with a taper dimension such as for complex or
phase data is not supported.
.. versionadded:: 0.11.0
"""
if any("taper" in tfr._dims for tfr in all_tfr):
raise NotImplementedError(
"Aggregating multitaper tapers across TFR datasets is not supported."
)

tfr = all_tfr[0].copy()
if isinstance(weights, str):
if weights not in ("nave", "equal"):
Expand Down
3 changes: 3 additions & 0 deletions mne/utils/numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,9 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True):
Notes
-----
Aggregating multitaper TFR datasets with a taper dimension such as for complex or
phase data is not supported.
.. versionadded:: 0.11.0
"""
# check if all elements in the given list are evoked data
Expand Down

0 comments on commit 999d122

Please sign in to comment.