Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH, MRG] Allow epoch construction from annotations #12311

Merged
merged 34 commits into from
Dec 31, 2023
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
4fa044e
[ENH] Allow epoch construction from annotations
alexrockhill Dec 19, 2023
eb73298
ruff
alexrockhill Dec 19, 2023
daee067
go through all examples using events_from_annotations
alexrockhill Dec 19, 2023
4447ed8
ruff
alexrockhill Dec 19, 2023
24270d1
Merge branch 'main' into epochs
alexrockhill Dec 19, 2023
7b5b9d8
fix ssvep tut
alexrockhill Dec 19, 2023
d4bc5c9
ruff
alexrockhill Dec 19, 2023
1378048
style
alexrockhill Dec 19, 2023
241e80b
oops don't modify raw annotations
alexrockhill Dec 19, 2023
3bc03fb
fix example
alexrockhill Dec 20, 2023
c2098a4
fix logic
alexrockhill Dec 20, 2023
47cc039
Richard review
alexrockhill Dec 20, 2023
5776741
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
alexrockhill Dec 20, 2023
3914d0b
add a version changed note
alexrockhill Dec 20, 2023
029d131
style
alexrockhill Dec 20, 2023
a44554a
fix test
alexrockhill Dec 20, 2023
7391b0f
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
alexrockhill Dec 20, 2023
f102534
merge
alexrockhill Dec 20, 2023
bf97c0a
fix versionchanged formatting
alexrockhill Dec 20, 2023
57477b5
add note to raw param
alexrockhill Dec 21, 2023
5b9e571
events->annotations
alexrockhill Dec 21, 2023
517c789
Merge branch 'main' into epochs
alexrockhill Dec 21, 2023
1290aac
Dan review
alexrockhill Dec 21, 2023
4c46e11
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 21, 2023
27615c8
style
alexrockhill Dec 21, 2023
7db0b59
Merge branch 'main' into epochs
alexrockhill Dec 21, 2023
fb87172
Merge branch 'epochs' of https://github.com/alexrockhill/mne-python i…
alexrockhill Dec 21, 2023
7ef087d
space
alexrockhill Dec 21, 2023
c3ae450
fix outdated example
alexrockhill Dec 21, 2023
f9401bd
style
alexrockhill Dec 21, 2023
cad31f4
review
alexrockhill Dec 22, 2023
2c33dab
Merge branch 'main' into epochs
hoechenberger Dec 27, 2023
52542ea
ignore warning
alexrockhill Dec 30, 2023
135f523
fix xvfb
alexrockhill Dec 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/devel/12311.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:class:`mne.Epochs` can now be constructed using :class:`mne.Annotations` stored in the ``raw`` object, by specifying ``events=None``. By `Alex Rockhill`_.
13 changes: 5 additions & 8 deletions examples/decoding/decoding_csp_eeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from sklearn.model_selection import ShuffleSplit, cross_val_score
from sklearn.pipeline import Pipeline

from mne import Epochs, events_from_annotations, pick_types
from mne import Epochs, pick_types
from mne.channels import make_standard_montage
from mne.datasets import eegbci
from mne.decoding import CSP
Expand All @@ -41,7 +41,6 @@
# avoid classification of evoked responses by using epochs that start 1s after
# cue onset.
tmin, tmax = -1.0, 4.0
event_id = dict(hands=2, feet=3)
subject = 1
runs = [6, 10, 14] # motor imagery: hands vs feet

Expand All @@ -50,22 +49,20 @@
eegbci.standardize(raw) # set channel names
montage = make_standard_montage("standard_1005")
raw.set_montage(montage)
raw.annotations.rename(dict(T1="hands", T2="feet"))

# Apply band-pass filter
raw.filter(7.0, 30.0, fir_design="firwin", skip_by_annotation="edge")

events, _ = events_from_annotations(raw, event_id=dict(T1=2, T2=3))

picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads")

# Read epochs (train will be done only between 1 and 2s)
# Testing will be done with a running classifier
epochs = Epochs(
raw,
events,
event_id,
tmin,
tmax,
event_id=["hands", "feet"],
tmin=tmin,
tmax=tmax,
proj=True,
picks=picks,
baseline=None,
Expand Down
19 changes: 8 additions & 11 deletions examples/decoding/decoding_csp_timefreq.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,22 @@
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import LabelEncoder

from mne import Epochs, create_info, events_from_annotations
from mne import Epochs, create_info
from mne.datasets import eegbci
from mne.decoding import CSP
from mne.io import concatenate_raws, read_raw_edf
from mne.time_frequency import AverageTFR

# %%
# Set parameters and read data
event_id = dict(hands=2, feet=3) # motor imagery: hands vs feet
subject = 1
runs = [6, 10, 14]
raw_fnames = eegbci.load_data(subject, runs)
raw = concatenate_raws([read_raw_edf(f) for f in raw_fnames])
raw.annotations.rename(dict(T1="hands", T2="feet"))

# Extract information from the raw file
sfreq = raw.info["sfreq"]
events, _ = events_from_annotations(raw, event_id=dict(T1=2, T2=3))
raw.pick(picks="eeg", exclude="bads")
raw.load_data()

Expand Down Expand Up @@ -95,10 +94,9 @@
# Extract epochs from filtered data, padded by window size
epochs = Epochs(
raw_filter,
events,
event_id,
tmin - w_size,
tmax + w_size,
event_id=["hands", "feet"],
tmin=tmin - w_size,
tmax=tmax + w_size,
proj=False,
baseline=None,
preload=True,
Expand Down Expand Up @@ -148,10 +146,9 @@
# Extract epochs from filtered data, padded by window size
epochs = Epochs(
raw_filter,
events,
event_id,
tmin - w_size,
tmax + w_size,
event_id=["hands", "feet"],
tmin=tmin - w_size,
tmax=tmax + w_size,
proj=False,
baseline=None,
preload=True,
Expand Down
11 changes: 5 additions & 6 deletions examples/time_frequency/time_frequency_erds.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@
raw = concatenate_raws([read_raw_edf(f, preload=True) for f in fnames])

raw.rename_channels(lambda x: x.strip(".")) # remove dots from channel names

events, _ = mne.events_from_annotations(raw, event_id=dict(T1=2, T2=3))
# rename descriptions to be more easily interpretable
raw.annotations.rename(dict(T1="hands", T2="feet"))

# %%
# Now we can create 5-second epochs around events of interest.
Expand All @@ -64,10 +64,9 @@

epochs = mne.Epochs(
raw,
events,
event_ids,
tmin - 0.5,
tmax + 0.5,
event_id=["hands", "feet"],
tmin=tmin - 0.5,
tmax=tmax + 0.5,
picks=("C3", "Cz", "C4"),
baseline=None,
preload=True,
Expand Down
6 changes: 1 addition & 5 deletions examples/visualization/eyetracking_plot_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,8 @@

mne.preprocessing.eyetracking.interpolate_blinks(raw, interpolate_gaze=True)
raw.annotations.rename({"dvns": "natural"}) # more intuitive
event_ids = {"natural": 1}
events, event_dict = mne.events_from_annotations(raw, event_id=event_ids)

epochs = mne.Epochs(
raw, events=events, event_id=event_dict, tmin=0, tmax=20, baseline=None
)
epochs = mne.Epochs(raw, event_id=["natural"], tmin=0, tmax=20, baseline=None)


# %%
Expand Down
47 changes: 45 additions & 2 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
EpochAnnotationsMixin,
_read_annotations_fif,
_write_annotations,
events_from_annotations,
)
from .baseline import _check_baseline, _log_rescale, rescale
from .bem import _check_origin
Expand Down Expand Up @@ -3111,7 +3112,16 @@ class Epochs(BaseEpochs):
Parameters
----------
%(raw_epochs)s

.. note::
If ``raw`` contains annotations, ``Epochs`` can be constructed around
``raw.annotations.onset``, but note that the durations of the annotations
are ignored in this case.
%(events_epochs)s

.. versionchanged:: 1.7
Allow ``events=None`` to use ``raw.annotations.onset`` as the source of
epoch times.
%(event_id)s
%(epochs_tmin_tmax)s
%(baseline_epochs)s
Expand Down Expand Up @@ -3212,7 +3222,7 @@ class Epochs(BaseEpochs):
def __init__(
self,
raw,
events,
events=None,
hoechenberger marked this conversation as resolved.
Show resolved Hide resolved
event_id=None,
tmin=-0.2,
tmax=0.5,
Expand Down Expand Up @@ -3240,6 +3250,7 @@ def __init__(
"instance of mne.io.BaseRaw"
)
info = deepcopy(raw.info)
annotations = raw.annotations.copy()

# proj is on when applied in Raw
proj = proj or raw.proj
Expand All @@ -3249,6 +3260,38 @@ def __init__(
# keep track of original sfreq (needed for annotations)
raw_sfreq = raw.info["sfreq"]

# get events from annotations if no events given
if events is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would put this block in a private function to be clear about what it needs and returns

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, the __init__ is already long enough as it is and a separate function would help with that

events, event_id_tmp = events_from_annotations(raw)
if events.size == 0:
raise RuntimeError(
"No usable annotations found in the raw object. "
"Either `events` must be provided or the raw "
"object must have annotations to construct epochs"
)
if any(raw.annotations.duration > 0):
logger.info(
"Ignoring annotation durations and creating fixed-duration epochs "
"around annotation onsets."
)
if event_id is None:
event_id = event_id_tmp
# if event_id is the names of events, map to events integers
if isinstance(event_id, str):
event_id = [event_id]
if isinstance(event_id, (list, tuple, set)):
if set(event_id).issubset(set(event_id_tmp)):
event_id = {my_id: event_id_tmp[my_id] for my_id in event_id}
# remove any non-selected annotations
annotations.delete(
~np.isin(raw.annotations.description, list(event_id))
)
else:
raise RuntimeError(
f"event_id(s) {set(event_id) - set(event_id_tmp)} "
"not found in annotations"
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic here should probably mimic what we do already for missing event IDs, e.g.,. having Epochs(raw, event_id=["auditory", "visual"], on_missing="ignore") should be okay when there are only "auditory" events in the annotations. Maybe you can remove some of the logic here to allow the existing event_id+on_missing checking/logic to take care of some stuff? (And please add a test for this case!)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this comment was not addressed yet and was errantly resolved by GitHub because the code was moved

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh I see there is an _on_missing in there now, never mind! Maybe we could deduplicate more at some point (?) but seems okay for now


# call BaseEpochs constructor
super(Epochs, self).__init__(
info,
Expand All @@ -3273,7 +3316,7 @@ def __init__(
event_repeated=event_repeated,
verbose=verbose,
raw_sfreq=raw_sfreq,
annotations=raw.annotations,
annotations=annotations,
)

@verbose
Expand Down
16 changes: 16 additions & 0 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,22 @@ def test_filter(tmp_path):
assert_allclose(epochs.get_data(), data_filt, atol=1e-17)


def test_epochs_from_annotations():
"""Test epoch instantiation using annotations."""
raw, events = _get_data()[:2]
with pytest.raises(
RuntimeError, match="No usable annotations found in the raw object"
):
Epochs(raw)
raw.set_annotations(
mne.annotations_from_events(
events, raw.info["sfreq"], first_samp=raw.first_samp
)
)
with pytest.raises(RuntimeError, match="not found in annotations"):
Epochs(raw, event_id="foo")


def test_epochs_hash():
"""Test epoch hashing."""
raw, events = _get_data()[:2]
Expand Down
8 changes: 5 additions & 3 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,12 +1107,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
"""

docdict["event_id"] = """
event_id : int | list of int | dict | None
event_id : int | list of int | dict | str | list of str | None
The id of the :term:`events` to consider. If dict, the keys can later be
used to access associated :term:`events`. Example:
dict(auditory=1, visual=3). If int, a dict will be created with the id as
string. If a list, all :term:`events` with the IDs specified in the list
are used. If None, all :term:`events` will be used and a dict is created
string. If a list of int, all :term:`events` with the IDs specified in the list
are used. If a str or list of str, ``events`` must be ``None`` to use annotations
and then the IDs must be the name(s) of the annotations to use.
If None, all :term:`events` will be used and a dict is created
with string integer names corresponding to the event id integers."""

docdict["event_id_ecg"] = """
Expand Down
3 changes: 1 addition & 2 deletions tutorials/clinical/20_seeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@

raw = mne.io.read_raw(misc_path / "seeg" / "sample_seeg_ieeg.fif")

events, event_id = mne.events_from_annotations(raw)
epochs = mne.Epochs(raw, events, event_id, detrend=1, baseline=None)
epochs = mne.Epochs(raw, detrend=1, baseline=None)
epochs = epochs["Response"][0] # just process one epoch of data for speed

# %%
Expand Down
6 changes: 1 addition & 5 deletions tutorials/clinical/30_ecog.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,11 @@
# at the posterior commissure)
raw.set_montage(montage)

# Find the annotated events
events, event_id = mne.events_from_annotations(raw)

# Make a 25 second epoch that spans before and after the seizure onset
epoch_length = 25 # seconds
epochs = mne.Epochs(
raw,
events,
event_id=event_id["onset"],
event_id="onset",
tmin=13,
tmax=13 + epoch_length,
baseline=None,
Expand Down
14 changes: 6 additions & 8 deletions tutorials/time-freq/50_ssvep.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,12 @@
raw.filter(l_freq=0.1, h_freq=None, fir_design="firwin", verbose=False)

# Construct epochs
event_id = {"12hz": 255, "15hz": 155}
events, _ = mne.events_from_annotations(raw, verbose=False)
raw.annotations.rename({"Stimulus/S255": "12hz", "Stimulus/S155": "15hz"})
tmin, tmax = -1.0, 20.0 # in s
baseline = None
epochs = mne.Epochs(
raw,
events=events,
event_id=[event_id["12hz"], event_id["15hz"]],
event_id=["12hz", "15hz"],
tmin=tmin,
tmax=tmax,
baseline=baseline,
Expand Down Expand Up @@ -356,8 +354,8 @@ def snr_spectrum(psd, noise_n_neighbor_freqs=1, noise_skip_neighbor_freqs=1):
# Get indices for the different trial types
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

i_trial_12hz = np.where(epochs.events[:, 2] == event_id["12hz"])[0]
i_trial_15hz = np.where(epochs.events[:, 2] == event_id["15hz"])[0]
i_trial_12hz = np.where(epochs.annotations.description == "12hz")[0]
i_trial_15hz = np.where(epochs.annotations.description == "15hz")[0]

# %%
# Get indices of EEG channels forming the ROI
Expand Down Expand Up @@ -604,7 +602,7 @@ def snr_spectrum(psd, noise_n_neighbor_freqs=1, noise_skip_neighbor_freqs=1):
window_snrs = [[]] * len(window_lengths)
for i_win, win in enumerate(window_lengths):
# compute spectrogram
this_spectrum = epochs[str(event_id["12hz"])].compute_psd(
this_spectrum = epochs["12hz"].compute_psd(
"welch",
n_fft=int(sfreq * win),
n_overlap=0,
Expand Down Expand Up @@ -688,7 +686,7 @@ def snr_spectrum(psd, noise_n_neighbor_freqs=1, noise_skip_neighbor_freqs=1):

for i_win, win in enumerate(window_starts):
# compute spectrogram
this_spectrum = epochs[str(event_id["12hz"])].compute_psd(
this_spectrum = epochs["12hz"].compute_psd(
"welch",
n_fft=int(sfreq * window_length) - 1,
n_overlap=0,
Expand Down