Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH, MRG] Allow epoch construction from annotations #12311

Merged
merged 34 commits into from
Dec 31, 2023
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
4fa044e
[ENH] Allow epoch construction from annotations
alexrockhill Dec 19, 2023
eb73298
ruff
alexrockhill Dec 19, 2023
daee067
go through all examples using events_from_annotations
alexrockhill Dec 19, 2023
4447ed8
ruff
alexrockhill Dec 19, 2023
24270d1
Merge branch 'main' into epochs
alexrockhill Dec 19, 2023
7b5b9d8
fix ssvep tut
alexrockhill Dec 19, 2023
d4bc5c9
ruff
alexrockhill Dec 19, 2023
1378048
style
alexrockhill Dec 19, 2023
241e80b
oops don't modify raw annotations
alexrockhill Dec 19, 2023
3bc03fb
fix example
alexrockhill Dec 20, 2023
c2098a4
fix logic
alexrockhill Dec 20, 2023
47cc039
Richard review
alexrockhill Dec 20, 2023
5776741
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
alexrockhill Dec 20, 2023
3914d0b
add a version changed note
alexrockhill Dec 20, 2023
029d131
style
alexrockhill Dec 20, 2023
a44554a
fix test
alexrockhill Dec 20, 2023
7391b0f
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
alexrockhill Dec 20, 2023
f102534
merge
alexrockhill Dec 20, 2023
bf97c0a
fix versionchanged formatting
alexrockhill Dec 20, 2023
57477b5
add note to raw param
alexrockhill Dec 21, 2023
5b9e571
events->annotations
alexrockhill Dec 21, 2023
517c789
Merge branch 'main' into epochs
alexrockhill Dec 21, 2023
1290aac
Dan review
alexrockhill Dec 21, 2023
4c46e11
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 21, 2023
27615c8
style
alexrockhill Dec 21, 2023
7db0b59
Merge branch 'main' into epochs
alexrockhill Dec 21, 2023
fb87172
Merge branch 'epochs' of https://github.com/alexrockhill/mne-python i…
alexrockhill Dec 21, 2023
7ef087d
space
alexrockhill Dec 21, 2023
c3ae450
fix outdated example
alexrockhill Dec 21, 2023
f9401bd
style
alexrockhill Dec 21, 2023
cad31f4
review
alexrockhill Dec 22, 2023
2c33dab
Merge branch 'main' into epochs
hoechenberger Dec 27, 2023
52542ea
ignore warning
alexrockhill Dec 30, 2023
135f523
fix xvfb
alexrockhill Dec 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/devel/12311.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:class:`mne.Epochs` can now be constructed using :class:`mne.Annotations` stored in the ``raw`` object, allowing ``events=None``. By `Alex Rockhill`_.
alexrockhill marked this conversation as resolved.
Show resolved Hide resolved
11 changes: 4 additions & 7 deletions examples/decoding/decoding_csp_eeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from sklearn.model_selection import ShuffleSplit, cross_val_score
from sklearn.pipeline import Pipeline

from mne import Epochs, events_from_annotations, pick_types
from mne import Epochs, pick_types
from mne.channels import make_standard_montage
from mne.datasets import eegbci
from mne.decoding import CSP
Expand All @@ -54,18 +54,15 @@
# Apply band-pass filter
raw.filter(7.0, 30.0, fir_design="firwin", skip_by_annotation="edge")

events, _ = events_from_annotations(raw, event_id=dict(T1=2, T2=3))

picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads")

# Read epochs (train will be done only between 1 and 2s)
# Testing will be done with a running classifier
epochs = Epochs(
raw,
events,
event_id,
tmin,
tmax,
event_id=dict(T1=2, T2=3),
tmin=tmin,
tmax=tmax,
proj=True,
picks=picks,
baseline=None,
Expand Down
19 changes: 8 additions & 11 deletions examples/decoding/decoding_csp_timefreq.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,22 @@
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import LabelEncoder

from mne import Epochs, create_info, events_from_annotations
from mne import Epochs, create_info
from mne.datasets import eegbci
from mne.decoding import CSP
from mne.io import concatenate_raws, read_raw_edf
from mne.time_frequency import AverageTFR

# %%
# Set parameters and read data
event_id = dict(hands=2, feet=3) # motor imagery: hands vs feet
subject = 1
runs = [6, 10, 14]
raw_fnames = eegbci.load_data(subject, runs)
raw = concatenate_raws([read_raw_edf(f) for f in raw_fnames])
raw.annotations.rename(dict(T1="hands", T2="feet"))

# Extract information from the raw file
sfreq = raw.info["sfreq"]
events, _ = events_from_annotations(raw, event_id=dict(T1=2, T2=3))
raw.pick(picks="eeg", exclude="bads")
raw.load_data()

Expand Down Expand Up @@ -95,10 +94,9 @@
# Extract epochs from filtered data, padded by window size
epochs = Epochs(
raw_filter,
events,
event_id,
tmin - w_size,
tmax + w_size,
event_id=["hands", "feet"],
tmin=tmin - w_size,
tmax=tmax + w_size,
proj=False,
baseline=None,
preload=True,
Expand Down Expand Up @@ -148,10 +146,9 @@
# Extract epochs from filtered data, padded by window size
epochs = Epochs(
raw_filter,
events,
event_id,
tmin - w_size,
tmax + w_size,
event_id=["hands", "feet"],
tmin=tmin - w_size,
tmax=tmax + w_size,
proj=False,
baseline=None,
preload=True,
Expand Down
11 changes: 5 additions & 6 deletions examples/time_frequency/time_frequency_erds.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@
raw = concatenate_raws([read_raw_edf(f, preload=True) for f in fnames])

raw.rename_channels(lambda x: x.strip(".")) # remove dots from channel names

events, _ = mne.events_from_annotations(raw, event_id=dict(T1=2, T2=3))
# rename descriptions to be more easily interpretable
raw.annotations.rename(dict(T1="hands", T2="feet"))

# %%
# Now we can create 5-second epochs around events of interest.
Expand All @@ -64,10 +64,9 @@

epochs = mne.Epochs(
raw,
events,
event_ids,
tmin - 0.5,
tmax + 0.5,
event_id=["hands", "feet"],
tmin=tmin - 0.5,
tmax=tmax + 0.5,
picks=("C3", "Cz", "C4"),
baseline=None,
preload=True,
Expand Down
6 changes: 1 addition & 5 deletions examples/visualization/eyetracking_plot_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,8 @@

mne.preprocessing.eyetracking.interpolate_blinks(raw, interpolate_gaze=True)
raw.annotations.rename({"dvns": "natural"}) # more intuitive
event_ids = {"natural": 1}
events, event_dict = mne.events_from_annotations(raw, event_id=event_ids)

epochs = mne.Epochs(
raw, events=events, event_id=event_dict, tmin=0, tmax=20, baseline=None
)
epochs = mne.Epochs(raw, event_id=["natural"], tmin=0, tmax=20, baseline=None)


# %%
Expand Down
49 changes: 47 additions & 2 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
EpochAnnotationsMixin,
_read_annotations_fif,
_write_annotations,
events_from_annotations,
)
from .baseline import _check_baseline, _log_rescale, rescale
from .bem import _check_origin
Expand Down Expand Up @@ -3112,6 +3113,9 @@ class Epochs(BaseEpochs):
----------
%(raw_epochs)s
%(events_epochs)s

.. versionchanged:: 1.7
Allow ``events=None`` to use ``raw.annotations``.
alexrockhill marked this conversation as resolved.
Show resolved Hide resolved
%(event_id)s
%(epochs_tmin_tmax)s
%(baseline_epochs)s
Expand Down Expand Up @@ -3174,6 +3178,10 @@ class Epochs(BaseEpochs):

Notes
-----
When Epochs are constructed using only a raw object (from the annotations
stored in the raw object), the duration of the events are ignored since
Epochs must be the same time length by design.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this note should go into a .. note:: section for the raw parameter – whose docstring should be updated anyway to suggest that annotations will be used if no events are being passed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, this note is also confusing, because it mixes annotations and events ("the duration of the events are ignored").

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this comment wasn't addressed so far :)


When accessing data, Epochs are detrended, baseline-corrected, and
decimated, then projectors are (optionally) applied.

Expand Down Expand Up @@ -3212,7 +3220,7 @@ class Epochs(BaseEpochs):
def __init__(
self,
raw,
events,
events=None,
hoechenberger marked this conversation as resolved.
Show resolved Hide resolved
event_id=None,
tmin=-0.2,
tmax=0.5,
Expand Down Expand Up @@ -3240,6 +3248,7 @@ def __init__(
"instance of mne.io.BaseRaw"
)
info = deepcopy(raw.info)
annotations = raw.annotations.copy()

# proj is on when applied in Raw
proj = proj or raw.proj
Expand All @@ -3249,6 +3258,42 @@ def __init__(
# keep track of original sfreq (needed for annotations)
raw_sfreq = raw.info["sfreq"]

# get events from annotations if no events given
if events is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would put this block in a private function to be clear about what it needs and returns

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, the __init__ is already long enough as it is and a separate function would help with that

events, event_id_tmp = events_from_annotations(raw)
if events.size == 0:
raise RuntimeError(
"No usable annotations found in the raw object. "
"Either `events` must be provided or the raw "
"object must have annotations to construct epochs"
)
if any(raw.annotations.duration > 0):
logger.info(
"Ignoring annotation durations, only fixed "
"duration epochs are currently supported"
alexrockhill marked this conversation as resolved.
Show resolved Hide resolved
)
if event_id is None:
event_id = event_id_tmp
# if event_id is the names of events, map to events integers
elif isinstance(event_id, (str, list, tuple)):
if isinstance(event_id, str):
event_id = [event_id]
if all([my_id in event_id_tmp for my_id in event_id]):
alexrockhill marked this conversation as resolved.
Show resolved Hide resolved
event_id = {my_id: event_id_tmp[my_id] for my_id in event_id}
# remove any non-selected annotations
annotations.delete(
[
i
for i, desc in enumerate(raw.annotations.description)
if desc not in event_id
]
)
else:
raise RuntimeError(
f"event_id(s) {set(event_id).difference(set(event_id_tmp))} "
alexrockhill marked this conversation as resolved.
Show resolved Hide resolved
"not found in annotations"
)

# call BaseEpochs constructor
super(Epochs, self).__init__(
info,
Expand All @@ -3273,7 +3318,7 @@ def __init__(
event_repeated=event_repeated,
verbose=verbose,
raw_sfreq=raw_sfreq,
annotations=raw.annotations,
annotations=annotations,
)

@verbose
Expand Down
18 changes: 18 additions & 0 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,24 @@ def test_filter(tmp_path):
assert_allclose(epochs.get_data(), data_filt, atol=1e-17)


def test_epochs_from_annotations():
"""Test epoch instantiation using annotations."""
raw, events = _get_data()[:2]
with pytest.raises(
RuntimeError, match="No usable annotations found in the raw object"
):
Epochs(raw)
raw.set_annotations(
mne.annotations_from_events(
events, raw.info["sfreq"], first_samp=raw.first_samp
)
)
with pytest.raises(RuntimeError, match="not found in annotations"):
Epochs(raw, event_id="foo")
with pytest.raises(RuntimeError, match="not found in annotations"):
Epochs(raw, event_id=["foo"])
alexrockhill marked this conversation as resolved.
Show resolved Hide resolved


def test_epochs_hash():
"""Test epoch hashing."""
raw, events = _get_data()[:2]
Expand Down
8 changes: 5 additions & 3 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,12 +1107,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
"""

docdict["event_id"] = """
event_id : int | list of int | dict | None
event_id : int | list of int | dict | str | list of str | None
The id of the :term:`events` to consider. If dict, the keys can later be
used to access associated :term:`events`. Example:
dict(auditory=1, visual=3). If int, a dict will be created with the id as
string. If a list, all :term:`events` with the IDs specified in the list
are used. If None, all :term:`events` will be used and a dict is created
string. If a list of int, all :term:`events` with the IDs specified in the list
are used. If a str or list of str, ``events`` must be ``None`` to use annotations
and then the IDs must be the name(s) of the annotations to use.
If None, all :term:`events` will be used and a dict is created
with string integer names corresponding to the event id integers."""

docdict["event_id_ecg"] = """
Expand Down
3 changes: 1 addition & 2 deletions tutorials/clinical/20_seeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@

raw = mne.io.read_raw(misc_path / "seeg" / "sample_seeg_ieeg.fif")

events, event_id = mne.events_from_annotations(raw)
epochs = mne.Epochs(raw, events, event_id, detrend=1, baseline=None)
epochs = mne.Epochs(raw, detrend=1, baseline=None)
epochs = epochs["Response"][0] # just process one epoch of data for speed

# %%
Expand Down
6 changes: 1 addition & 5 deletions tutorials/clinical/30_ecog.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,11 @@
# at the posterior commissure)
raw.set_montage(montage)

# Find the annotated events
events, event_id = mne.events_from_annotations(raw)

# Make a 25 second epoch that spans before and after the seizure onset
epoch_length = 25 # seconds
epochs = mne.Epochs(
raw,
events,
event_id=event_id["onset"],
event_id="onset",
tmin=13,
tmax=13 + epoch_length,
baseline=None,
Expand Down
14 changes: 6 additions & 8 deletions tutorials/time-freq/50_ssvep.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,12 @@
raw.filter(l_freq=0.1, h_freq=None, fir_design="firwin", verbose=False)

# Construct epochs
event_id = {"12hz": 255, "15hz": 155}
events, _ = mne.events_from_annotations(raw, verbose=False)
raw.annotations.rename({"Stimulus/S255": "12hz", "Stimulus/S155": "15hz"})
tmin, tmax = -1.0, 20.0 # in s
baseline = None
epochs = mne.Epochs(
raw,
events=events,
event_id=[event_id["12hz"], event_id["15hz"]],
event_id=["12hz", "15hz"],
tmin=tmin,
tmax=tmax,
baseline=baseline,
Expand Down Expand Up @@ -356,8 +354,8 @@ def snr_spectrum(psd, noise_n_neighbor_freqs=1, noise_skip_neighbor_freqs=1):
# Get indices for the different trial types
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

i_trial_12hz = np.where(epochs.events[:, 2] == event_id["12hz"])[0]
i_trial_15hz = np.where(epochs.events[:, 2] == event_id["15hz"])[0]
i_trial_12hz = np.where(epochs.annotations.description == "12hz")[0]
i_trial_15hz = np.where(epochs.annotations.description == "15hz")[0]

# %%
# Get indices of EEG channels forming the ROI
Expand Down Expand Up @@ -604,7 +602,7 @@ def snr_spectrum(psd, noise_n_neighbor_freqs=1, noise_skip_neighbor_freqs=1):
window_snrs = [[]] * len(window_lengths)
for i_win, win in enumerate(window_lengths):
# compute spectrogram
this_spectrum = epochs[str(event_id["12hz"])].compute_psd(
this_spectrum = epochs["12hz"].compute_psd(
"welch",
n_fft=int(sfreq * win),
n_overlap=0,
Expand Down Expand Up @@ -688,7 +686,7 @@ def snr_spectrum(psd, noise_n_neighbor_freqs=1, noise_skip_neighbor_freqs=1):

for i_win, win in enumerate(window_starts):
# compute spectrogram
this_spectrum = epochs[str(event_id["12hz"])].compute_psd(
this_spectrum = epochs["12hz"].compute_psd(
"welch",
n_fft=int(sfreq * window_length) - 1,
n_overlap=0,
Expand Down
Loading