class HEDAnnotations(Annotations):
    """Annotations object for annotating segments of raw data with HED tags.

    Parameters
    ----------
    onset : array of float, shape (n_annotations,)
        The starting time of annotations in seconds after ``orig_time``.
    duration : array of float, shape (n_annotations,) | float
        Durations of the annotations in seconds. If a float, all the
        annotations are given the same duration.
    description : array of str, shape (n_annotations,) | str
        Array of strings containing description for each annotation. If a
        string, all the annotations are given the same description. To reject
        epochs, use description starting with keyword 'bad'.
    hed_strings : array of str, shape (n_annotations,) | str
        Sequence of strings containing a HED tag (or comma-separated list of HED tags)
        for each annotation. If a single string is provided, all annotations are
        assigned the same HED string.
    hed_version : str
        The HED schema version against which to validate the HED strings.
    orig_time : float | str | datetime | tuple of int | None
        A POSIX Timestamp, datetime or a tuple containing the timestamp as the
        first element and microseconds as the second element. Determines the
        starting time of annotation acquisition. If None (default),
        starting time is determined from beginning of raw data acquisition.
        In general, ``raw.info['meas_date']`` (or None) can be used for syncing
        the annotations with raw data if their acquisition is started at the
        same time. If it is a string, it should conform to the ISO8601 format.
        More precisely to this '%%Y-%%m-%%d %%H:%%M:%%S.%%f' particular case of
        the ISO8601 format where the delimiter between date and time is ' '.
    %(ch_names_annot)s

    See Also
    --------
    mne.Annotations

    Notes
    -----

    .. versionadded:: 1.10
    """

    def __init__(
        self,
        onset,
        duration,
        description,
        hed_strings,
        hed_version="8.3.0",  # TODO @VisLab what is a sensible default here?
        orig_time=None,
        ch_names=None,
    ):
        self.hed = _soft_import("hed", "validation of HED tags in annotations")

        # ``Annotations._sort`` permutes ``self.hed_strings`` whenever that
        # attribute exists. Seed it with positional indices before the parent
        # initializes, so that any reordering the parent performs is recorded and
        # can be replayed on the user's HED strings below. If the parent does not
        # reorder anything, the indices stay in ascending order and the replay is
        # a no-op.
        self.hed_strings = np.arange(np.size(onset))
        super().__init__(
            onset=onset,
            duration=duration,
            description=description,
            orig_time=orig_time,
            ch_names=ch_names,
        )
        order = self.hed_strings
        self.hed_version = hed_version
        self._update_hed_strings(hed_strings=hed_strings)
        # Keep HED strings aligned with the (possibly reordered) annotations.
        self.hed_strings = self.hed_strings[order]

    def _update_hed_strings(self, hed_strings):
        """Validate HED strings and store them as a NumPy array.

        Parameters
        ----------
        hed_strings : array of str, shape (n_annotations,) | str
            One HED string per annotation, in the current annotation order. A
            single string is repeated for every annotation.

        Raises
        ------
        ValueError
            If the number of HED strings does not match the number of
            annotations, or if any HED string fails validation.
        """
        # NB: must import; calling self.hed.validator.HedValidator doesn't work
        from hed.validator import HedValidator

        n_annot = len(self)
        if isinstance(hed_strings, str):
            # A lone string applies to every annotation; without this its
            # *characters* would be counted by the length check below.
            hed_strings = [hed_strings] * n_annot
        # An ndarray is required so the strings can be indexed with slices and
        # index lists (as done in ``_sort`` and ``__getitem__``).
        hed_strings = np.array(hed_strings, dtype=str, ndmin=1)
        if len(hed_strings) != n_annot:
            raise ValueError(
                f"Number of HED strings ({len(hed_strings)}) must match the number of "
                f"annotations ({n_annot})."
            )
        # validation of HED strings
        schema = self.hed.load_schema_version(self.hed_version)
        validator = HedValidator(schema)
        error_handler = self.hed.errors.ErrorHandler(check_for_warnings=False)
        error_strs = [
            self._validate_one_hed_string(str(hs), schema, validator, error_handler)
            for hs in hed_strings
        ]
        # Valid strings yield an empty issue string; report only real failures.
        failures = [err for err in error_strs if err]
        if failures:
            raise ValueError(
                "Some HED strings in your annotations failed to validate:\n  - "
                + "\n  - ".join(failures)
            )
        self.hed_strings = hed_strings

    def _validate_one_hed_string(self, hed_string, schema, validator, error_handler):
        """Validate a user-provided HED string.

        Returns the printable issue string produced by the ``hed`` package for
        ``hed_string``; callers treat an empty string as "no issues".
        """
        hs = self.hed.HedString(hed_string, schema)
        issues = validator.validate(
            hs, allow_placeholders=False, error_handler=error_handler
        )
        return self.hed.get_printable_issue_string(issues)

    def __eq__(self, other):
        """Compare to another HEDAnnotations instance."""
        # Guard first so that ``other.hed_strings`` is never accessed on an
        # object that does not carry HED information.
        if not isinstance(other, HEDAnnotations):
            return False
        return (
            super().__eq__(other)
            and np.array_equal(self.hed_strings, other.hed_strings)
            and self.hed_version == other.hed_version
        )

    def __repr__(self):
        """Show a textual summary of the object."""
        counter = Counter(self.hed_strings)
        kinds = ", ".join(f"{tag} ({num})" for tag, num in sorted(counter.items()))
        kinds = (": " if len(kinds) > 0 else "") + kinds
        ch_specific = ", channel-specific" if self._any_ch_names() else ""
        s = (
            f"HEDAnnotations | {len(self.onset)} segment"
            f"{_pl(len(self.onset))}{ch_specific}{kinds}"
        )
        return "<" + shorten(s, width=77, placeholder=" ...") + ">"

    def __getitem__(self, key, *, with_ch_names=None):
        """Propagate indexing and slicing to the underlying numpy structure.

        When the parent returns an ``OrderedDict`` (a single annotation), the
        matching HED string is added under the ``"hed_strings"`` key. Otherwise
        a new ``HEDAnnotations`` holding the selected annotations is returned.
        """
        result = super().__getitem__(key, with_ch_names=with_ch_names)
        if isinstance(result, OrderedDict):
            result["hed_strings"] = self.hed_strings[key]
            return result
        # NumPy treats a tuple as a multi-dimensional index; convert it so that
        # it selects several entries along the single annotation axis instead.
        key = list(key) if isinstance(key, tuple) else key
        hed_strings = self.hed_strings[key]
        return HEDAnnotations(
            result.onset,
            result.duration,
            result.description,
            hed_strings,
            hed_version=self.hed_version,
            orig_time=self.orig_time,
            ch_names=result.ch_names,
        )

    # NOTE(review): the five methods below are unimplemented placeholders. Each
    # one overrides the inherited ``Annotations`` method of the same name with a
    # body that does nothing and returns None. They should be implemented (or
    # the overrides removed) before release.

    def append(self, onset, duration, description, ch_names=None):
        """TODO: not implemented yet; currently does nothing and returns None.

        NOTE(review): the signature has no way to pass HED strings for the
        appended annotations.
        """
        pass

    def count(self):
        """TODO: not implemented yet; currently does nothing and returns None.

        Unlike Annotations.count, keys should be HED tags not descriptions.
        """
        pass

    def crop(
        self, tmin=None, tmax=None, emit_warning=False, use_orig_time=True, verbose=None
    ):
        """TODO: not implemented yet; currently does nothing and returns None."""
        pass

    def delete(self, idx):
        """TODO: not implemented yet; currently does nothing and returns None."""
        pass

    def to_data_frame(self, time_format="datetime"):
        """TODO: not implemented yet; currently does nothing and returns None."""
        pass