Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH, MRG] Allow epoch construction from annotations #12311

Merged
merged 34 commits into from
Dec 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
4fa044e
[ENH] Allow epoch construction from annotations
alexrockhill Dec 19, 2023
eb73298
ruff
alexrockhill Dec 19, 2023
daee067
go through all examples using events_from_annotations
alexrockhill Dec 19, 2023
4447ed8
ruff
alexrockhill Dec 19, 2023
24270d1
Merge branch 'main' into epochs
alexrockhill Dec 19, 2023
7b5b9d8
fix ssvep tut
alexrockhill Dec 19, 2023
d4bc5c9
ruff
alexrockhill Dec 19, 2023
1378048
style
alexrockhill Dec 19, 2023
241e80b
oops don't modify raw annotations
alexrockhill Dec 19, 2023
3bc03fb
fix example
alexrockhill Dec 20, 2023
c2098a4
fix logic
alexrockhill Dec 20, 2023
47cc039
Richard review
alexrockhill Dec 20, 2023
5776741
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
alexrockhill Dec 20, 2023
3914d0b
add a version changed note
alexrockhill Dec 20, 2023
029d131
style
alexrockhill Dec 20, 2023
a44554a
fix test
alexrockhill Dec 20, 2023
7391b0f
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
alexrockhill Dec 20, 2023
f102534
merge
alexrockhill Dec 20, 2023
bf97c0a
fix versionchanged formatting
alexrockhill Dec 20, 2023
57477b5
add note to raw param
alexrockhill Dec 21, 2023
5b9e571
events->annotations
alexrockhill Dec 21, 2023
517c789
Merge branch 'main' into epochs
alexrockhill Dec 21, 2023
1290aac
Dan review
alexrockhill Dec 21, 2023
4c46e11
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 21, 2023
27615c8
style
alexrockhill Dec 21, 2023
7db0b59
Merge branch 'main' into epochs
alexrockhill Dec 21, 2023
fb87172
Merge branch 'epochs' of https://github.com/alexrockhill/mne-python i…
alexrockhill Dec 21, 2023
7ef087d
space
alexrockhill Dec 21, 2023
c3ae450
fix outdated example
alexrockhill Dec 21, 2023
f9401bd
style
alexrockhill Dec 21, 2023
cad31f4
review
alexrockhill Dec 22, 2023
2c33dab
Merge branch 'main' into epochs
hoechenberger Dec 27, 2023
52542ea
ignore warning
alexrockhill Dec 30, 2023
135f523
fix xvfb
alexrockhill Dec 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/devel/12311.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:class:`mne.Epochs` can now be constructed using :class:`mne.Annotations` stored in the ``raw`` object, by specifying ``events=None``. By `Alex Rockhill`_.
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,6 +1310,7 @@ def reset_warnings(gallery_conf, fname):
for key in (
"invalid version and will not be supported", # pyxdf
"distutils Version classes are deprecated", # seaborn and neo
"is_categorical_dtype is deprecated", # seaborn
"`np.object` is a deprecated alias for the builtin `object`", # pyxdf
# nilearn, should be fixed in > 0.9.1
"In future, it will be an error for 'np.bool_' scalars to",
Expand Down
13 changes: 5 additions & 8 deletions examples/decoding/decoding_csp_eeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from sklearn.model_selection import ShuffleSplit, cross_val_score
from sklearn.pipeline import Pipeline

from mne import Epochs, events_from_annotations, pick_types
from mne import Epochs, pick_types
from mne.channels import make_standard_montage
from mne.datasets import eegbci
from mne.decoding import CSP
Expand All @@ -41,7 +41,6 @@
# avoid classification of evoked responses by using epochs that start 1s after
# cue onset.
tmin, tmax = -1.0, 4.0
event_id = dict(hands=2, feet=3)
subject = 1
runs = [6, 10, 14] # motor imagery: hands vs feet

Expand All @@ -50,22 +49,20 @@
eegbci.standardize(raw) # set channel names
montage = make_standard_montage("standard_1005")
raw.set_montage(montage)
raw.annotations.rename(dict(T1="hands", T2="feet"))

# Apply band-pass filter
raw.filter(7.0, 30.0, fir_design="firwin", skip_by_annotation="edge")

events, _ = events_from_annotations(raw, event_id=dict(T1=2, T2=3))

picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads")

# Read epochs (train will be done only between 1 and 2s)
# Testing will be done with a running classifier
epochs = Epochs(
raw,
events,
event_id,
tmin,
tmax,
event_id=["hands", "feet"],
tmin=tmin,
tmax=tmax,
proj=True,
picks=picks,
baseline=None,
Expand Down
19 changes: 8 additions & 11 deletions examples/decoding/decoding_csp_timefreq.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,22 @@
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import LabelEncoder

from mne import Epochs, create_info, events_from_annotations
from mne import Epochs, create_info
from mne.datasets import eegbci
from mne.decoding import CSP
from mne.io import concatenate_raws, read_raw_edf
from mne.time_frequency import AverageTFR

# %%
# Set parameters and read data
event_id = dict(hands=2, feet=3) # motor imagery: hands vs feet
subject = 1
runs = [6, 10, 14]
raw_fnames = eegbci.load_data(subject, runs)
raw = concatenate_raws([read_raw_edf(f) for f in raw_fnames])
raw.annotations.rename(dict(T1="hands", T2="feet"))

# Extract information from the raw file
sfreq = raw.info["sfreq"]
events, _ = events_from_annotations(raw, event_id=dict(T1=2, T2=3))
raw.pick(picks="eeg", exclude="bads")
raw.load_data()

Expand Down Expand Up @@ -95,10 +94,9 @@
# Extract epochs from filtered data, padded by window size
epochs = Epochs(
raw_filter,
events,
event_id,
tmin - w_size,
tmax + w_size,
event_id=["hands", "feet"],
tmin=tmin - w_size,
tmax=tmax + w_size,
proj=False,
baseline=None,
preload=True,
Expand Down Expand Up @@ -148,10 +146,9 @@
# Extract epochs from filtered data, padded by window size
epochs = Epochs(
raw_filter,
events,
event_id,
tmin - w_size,
tmax + w_size,
event_id=["hands", "feet"],
tmin=tmin - w_size,
tmax=tmax + w_size,
proj=False,
baseline=None,
preload=True,
Expand Down
11 changes: 5 additions & 6 deletions examples/time_frequency/time_frequency_erds.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@
raw = concatenate_raws([read_raw_edf(f, preload=True) for f in fnames])

raw.rename_channels(lambda x: x.strip(".")) # remove dots from channel names

events, _ = mne.events_from_annotations(raw, event_id=dict(T1=2, T2=3))
# rename descriptions to be more easily interpretable
raw.annotations.rename(dict(T1="hands", T2="feet"))

# %%
# Now we can create 5-second epochs around events of interest.
Expand All @@ -64,10 +64,9 @@

epochs = mne.Epochs(
raw,
events,
event_ids,
tmin - 0.5,
tmax + 0.5,
event_id=["hands", "feet"],
tmin=tmin - 0.5,
tmax=tmax + 0.5,
picks=("C3", "Cz", "C4"),
baseline=None,
preload=True,
Expand Down
6 changes: 1 addition & 5 deletions examples/visualization/eyetracking_plot_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,8 @@

mne.preprocessing.eyetracking.interpolate_blinks(raw, interpolate_gaze=True)
raw.annotations.rename({"dvns": "natural"}) # more intuitive
event_ids = {"natural": 1}
events, event_dict = mne.events_from_annotations(raw, event_id=event_ids)

epochs = mne.Epochs(
raw, events=events, event_id=event_dict, tmin=0, tmax=20, baseline=None
)
epochs = mne.Epochs(raw, event_id=["natural"], tmin=0, tmax=20, baseline=None)


# %%
Expand Down
60 changes: 54 additions & 6 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
EpochAnnotationsMixin,
_read_annotations_fif,
_write_annotations,
events_from_annotations,
)
from .baseline import _check_baseline, _log_rescale, rescale
from .bem import _check_origin
Expand Down Expand Up @@ -487,10 +488,7 @@ def __init__(
if events is not None: # RtEpochs can have events=None
for key, val in self.event_id.items():
if val not in events[:, 2]:
msg = "No matching events found for %s " "(event id %i)" % (
key,
val,
)
msg = f"No matching events found for {key} (event id {val})"
_on_missing(on_missing, msg)

# ensure metadata matches original events size
Expand Down Expand Up @@ -3104,14 +3102,57 @@ def _ensure_list(x):
return metadata, events, event_id


def _events_from_annotations(raw, events, event_id, annotations, on_missing):
"""Generate events and event_ids from annotations."""
events, event_id_tmp = events_from_annotations(raw)
if events.size == 0:
raise RuntimeError(
"No usable annotations found in the raw object. "
"Either `events` must be provided or the raw "
"object must have annotations to construct epochs"
)
if any(raw.annotations.duration > 0):
logger.info(
"Ignoring annotation durations and creating fixed-duration epochs "
"around annotation onsets."
)
if event_id is None:
event_id = event_id_tmp
# if event_id is the names of events, map to events integers
if isinstance(event_id, str):
event_id = [event_id]
if isinstance(event_id, (list, tuple, set)):
if not set(event_id).issubset(set(event_id_tmp)):
msg = (
"No matching annotations found for event_id(s) "
f"{set(event_id) - set(event_id_tmp)}"
)
_on_missing(on_missing, msg)
# remove extras if on_missing not error
event_id = set(event_id) & set(event_id_tmp)
event_id = {my_id: event_id_tmp[my_id] for my_id in event_id}
# remove any non-selected annotations
annotations.delete(~np.isin(raw.annotations.description, list(event_id)))
return events, event_id, annotations


@fill_doc
class Epochs(BaseEpochs):
"""Epochs extracted from a Raw instance.

Parameters
----------
%(raw_epochs)s

.. note::
If ``raw`` contains annotations, ``Epochs`` can be constructed around
``raw.annotations.onset``, but note that the durations of the annotations
are ignored in this case.
%(events_epochs)s

.. versionchanged:: 1.7
Allow ``events=None`` to use ``raw.annotations.onset`` as the source of
epoch times.
%(event_id)s
%(epochs_tmin_tmax)s
%(baseline_epochs)s
Expand Down Expand Up @@ -3212,7 +3253,7 @@ class Epochs(BaseEpochs):
def __init__(
self,
raw,
events,
events=None,
hoechenberger marked this conversation as resolved.
Show resolved Hide resolved
event_id=None,
tmin=-0.2,
tmax=0.5,
Expand Down Expand Up @@ -3240,6 +3281,7 @@ def __init__(
"instance of mne.io.BaseRaw"
)
info = deepcopy(raw.info)
annotations = raw.annotations.copy()

# proj is on when applied in Raw
proj = proj or raw.proj
Expand All @@ -3249,6 +3291,12 @@ def __init__(
# keep track of original sfreq (needed for annotations)
raw_sfreq = raw.info["sfreq"]

# get events from annotations if no events given
if events is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would put this block in a private function to be clear about what it needs and returns

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, the __init__ is already long enough as it is and a separate function would help with that

events, event_id, annotations = _events_from_annotations(
raw, events, event_id, annotations, on_missing
)

# call BaseEpochs constructor
super(Epochs, self).__init__(
info,
Expand All @@ -3273,7 +3321,7 @@ def __init__(
event_repeated=event_repeated,
verbose=verbose,
raw_sfreq=raw_sfreq,
annotations=raw.annotations,
annotations=annotations,
)

@verbose
Expand Down
20 changes: 20 additions & 0 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,26 @@ def test_filter(tmp_path):
assert_allclose(epochs.get_data(), data_filt, atol=1e-17)


def test_epochs_from_annotations():
"""Test epoch instantiation using annotations."""
raw, events = _get_data()[:2]
with pytest.raises(
RuntimeError, match="No usable annotations found in the raw object"
):
Epochs(raw)
raw.set_annotations(
mne.annotations_from_events(
events, raw.info["sfreq"], first_samp=raw.first_samp
)
)
# test on_missing
with pytest.raises(ValueError, match="No matching annotations"):
Epochs(raw, event_id="foo")
# test on_missing warn
with pytest.warns(match="No matching annotations"):
Epochs(raw, event_id=["1", "foo"], on_missing="warn")


def test_epochs_hash():
"""Test epoch hashing."""
raw, events = _get_data()[:2]
Expand Down
8 changes: 5 additions & 3 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,12 +1107,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
"""

docdict["event_id"] = """
event_id : int | list of int | dict | None
event_id : int | list of int | dict | str | list of str | None
The id of the :term:`events` to consider. If dict, the keys can later be
used to access associated :term:`events`. Example:
dict(auditory=1, visual=3). If int, a dict will be created with the id as
string. If a list, all :term:`events` with the IDs specified in the list
are used. If None, all :term:`events` will be used and a dict is created
string. If a list of int, all :term:`events` with the IDs specified in the list
are used. If a str or list of str, ``events`` must be ``None`` to use annotations
and then the IDs must be the name(s) of the annotations to use.
If None, all :term:`events` will be used and a dict is created
with string integer names corresponding to the event id integers."""

docdict["event_id_ecg"] = """
Expand Down
2 changes: 1 addition & 1 deletion tools/setup_xvfb.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ done

# This also includes the libraries necessary for PyQt5/PyQt6
sudo apt update
sudo apt install -yqq xvfb libxkbcommon-x11-0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 libxcb-xinerama0 libxcb-xfixes0 libopengl0 libegl1 libosmesa6 mesa-utils libxcb-shape0 libxcb-cursor0
sudo apt install -yqq xvfb libxkbcommon-x11-0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 libxcb-xinerama0 libxcb-xfixes0 libopengl0 libegl1 libosmesa6 mesa-utils libxcb-shape0 libxcb-cursor0 libxml2
/sbin/start-stop-daemon --start --quiet --pidfile /tmp/custom_xvfb_99.pid --make-pidfile --background --exec /usr/bin/Xvfb -- :99 -screen 0 1400x900x24 -ac +extension GLX +render -noreset
3 changes: 1 addition & 2 deletions tutorials/clinical/20_seeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@

raw = mne.io.read_raw(misc_path / "seeg" / "sample_seeg_ieeg.fif")

events, event_id = mne.events_from_annotations(raw)
epochs = mne.Epochs(raw, events, event_id, detrend=1, baseline=None)
epochs = mne.Epochs(raw, detrend=1, baseline=None)
epochs = epochs["Response"][0] # just process one epoch of data for speed

# %%
Expand Down
6 changes: 1 addition & 5 deletions tutorials/clinical/30_ecog.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,11 @@
# at the posterior commissure)
raw.set_montage(montage)

# Find the annotated events
events, event_id = mne.events_from_annotations(raw)

# Make a 25 second epoch that spans before and after the seizure onset
epoch_length = 25 # seconds
epochs = mne.Epochs(
raw,
events,
event_id=event_id["onset"],
event_id="onset",
tmin=13,
tmax=13 + epoch_length,
baseline=None,
Expand Down
14 changes: 6 additions & 8 deletions tutorials/time-freq/50_ssvep.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,12 @@
raw.filter(l_freq=0.1, h_freq=None, fir_design="firwin", verbose=False)

# Construct epochs
event_id = {"12hz": 255, "15hz": 155}
events, _ = mne.events_from_annotations(raw, verbose=False)
raw.annotations.rename({"Stimulus/S255": "12hz", "Stimulus/S155": "15hz"})
tmin, tmax = -1.0, 20.0 # in s
baseline = None
epochs = mne.Epochs(
raw,
events=events,
event_id=[event_id["12hz"], event_id["15hz"]],
event_id=["12hz", "15hz"],
tmin=tmin,
tmax=tmax,
baseline=baseline,
Expand Down Expand Up @@ -356,8 +354,8 @@ def snr_spectrum(psd, noise_n_neighbor_freqs=1, noise_skip_neighbor_freqs=1):
# Get indices for the different trial types
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

i_trial_12hz = np.where(epochs.events[:, 2] == event_id["12hz"])[0]
i_trial_15hz = np.where(epochs.events[:, 2] == event_id["15hz"])[0]
i_trial_12hz = np.where(epochs.annotations.description == "12hz")[0]
i_trial_15hz = np.where(epochs.annotations.description == "15hz")[0]

# %%
# Get indices of EEG channels forming the ROI
Expand Down Expand Up @@ -604,7 +602,7 @@ def snr_spectrum(psd, noise_n_neighbor_freqs=1, noise_skip_neighbor_freqs=1):
window_snrs = [[]] * len(window_lengths)
for i_win, win in enumerate(window_lengths):
# compute spectrogram
this_spectrum = epochs[str(event_id["12hz"])].compute_psd(
this_spectrum = epochs["12hz"].compute_psd(
"welch",
n_fft=int(sfreq * win),
n_overlap=0,
Expand Down Expand Up @@ -688,7 +686,7 @@ def snr_spectrum(psd, noise_n_neighbor_freqs=1, noise_skip_neighbor_freqs=1):

for i_win, win in enumerate(window_starts):
# compute spectrogram
this_spectrum = epochs[str(event_id["12hz"])].compute_psd(
this_spectrum = epochs["12hz"].compute_psd(
"welch",
n_fft=int(sfreq * window_length) - 1,
n_overlap=0,
Expand Down