diff --git a/doc/changes/devel/13101.bugfix.rst b/doc/changes/devel/13101.bugfix.rst new file mode 100644 index 00000000000..d24e55b5056 --- /dev/null +++ b/doc/changes/devel/13101.bugfix.rst @@ -0,0 +1 @@ +Take units (m or mm) into account when drawing :func:`~mne.viz.plot_evoked_field` on top of :class:`~mne.viz.Brain`, by `Marijn van Vliet`_. diff --git a/mne/dipole.py b/mne/dipole.py index a40e9708db2..33a99ca87c5 100644 --- a/mne/dipole.py +++ b/mne/dipole.py @@ -842,7 +842,7 @@ def _write_dipole_bdip(fname, dip): fid.write(np.array(has_errors, ">i4").tobytes()) # has_errors fid.write(np.zeros(1, ">f4").tobytes()) # noise level for key in _BDIP_ERROR_KEYS: - val = dip.conf[key][ti] if key in dip.conf else 0.0 + val = dip.conf[key][ti] if key in dip.conf else np.array(0.0) assert val.shape == () fid.write(np.array(val, ">f4").tobytes()) fid.write(np.zeros(25, ">f4").tobytes()) @@ -1503,7 +1503,7 @@ def fit_dipole( if not bem["is_sphere"]: # Find the best-fitting sphere inner_skull = _bem_find_surface(bem, "inner_skull") - inner_skull = inner_skull.copy() + inner_skull = deepcopy(inner_skull) R, r0 = _fit_sphere(inner_skull["rr"], disp=False) # r0 back to head frame for logging r0 = apply_trans(mri_head_t["trans"], r0[np.newaxis, :])[0] diff --git a/mne/gui/__init__.pyi b/mne/gui/__init__.pyi index 086c51a4904..8f66ad387eb 100644 --- a/mne/gui/__init__.pyi +++ b/mne/gui/__init__.pyi @@ -1,2 +1,3 @@ -__all__ = ["_GUIScraper", "coregistration"] +__all__ = ["_GUIScraper", "coregistration", "DipoleFitUI"] from ._gui import _GUIScraper, coregistration +from ._xfit import DipoleFitUI diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py new file mode 100644 index 00000000000..dabb4a020dc --- /dev/null +++ b/mne/gui/_xfit.py @@ -0,0 +1,825 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +from copy import deepcopy +from functools import partial +from pathlib import Path + +import numpy as np +import pyvista + +from .. import pick_types +from ..bem import ( + ConductorModel, + _ensure_bem_surfaces, + fit_sphere_to_headshape, + make_sphere_model, +) +from ..cov import make_ad_hoc_cov +from ..dipole import Dipole, fit_dipole +from ..forward import convert_forward_solution, make_field_map, make_forward_dipole +from ..minimum_norm import apply_inverse, make_inverse_operator +from ..surface import _normal_orth +from ..transforms import _get_trans, _get_transforms_to_coord_frame, apply_trans +from ..utils import _check_option, fill_doc, logger, verbose +from ..viz import EvokedField, create_3d_figure +from ..viz._3d import _plot_head_surface, _plot_sensors_3d +from ..viz.ui_events import link, subscribe +from ..viz.utils import _get_color_list + + +@fill_doc +@verbose +def dipolefit( + evoked, + cov=None, + bem=None, + initial_time=None, + trans=None, + rank=None, + show_density=True, + subject=None, + subjects_dir=None, + n_jobs=None, + verbose=None, +): + """GUI for interactive dipole fitting, inspired by MEGIN's XFit program. + + Parameters + ---------- + evoked : instance of Evoked + Evoked data to show fieldmap of and fit dipoles to. + cov : instance of Covariance | None + Noise covariance matrix. If ``None``, an ad-hoc covariance matrix is used. + bem : instance of ConductorModel | None + Boundary element model to use in forward calculations. If ``None``, a spherical + model is used. + initial_time : float | None + Initial time point to show. If ``None``, the time point of the maximum field + strength is used. + trans : instance of Transform | None + The transformation from head coordinates to MRI coordinates. If ``None``, the + identity matrix is used. + stc : instance of SourceEstimate | None + An optional distributed source estimate to show alongside the fieldmap. + %(rank)s + show_density : bool + Whether to show the density of the fieldmap. + subject : str | None + The subject name. If ``None``, no MRI data is shown. + %(subjects_dir)s + %(n_jobs)s + %(verbose)s + """ + return DipoleFitUI( + evoked=evoked, + cov=cov, + bem=bem, + initial_time=initial_time, + trans=trans, + stc=None, + rank=rank, + show_density=show_density, + subject=subject, + subjects_dir=subjects_dir, + n_jobs=n_jobs, + verbose=verbose, + ) + + +@fill_doc +class DipoleFitUI: + """GUI for interactive dipole fitting, inspired by MEGIN's XFit program. + + Parameters + ---------- + evoked : instance of Evoked + Evoked data to show fieldmap of and fit dipoles to. + cov : instance of Covariance | "baseline" | None + Noise covariance matrix. If ``None``, an ad-hoc covariance matrix is used with + default values for the diagonal elements (see Notes). If ``"baseline"``, the + diagonal elements is estimated from the baseline period of the evoked data. + bem : instance of ConductorModel | None + Boundary element model to use in forward calculations. If ``None``, a spherical + model is used. + initial_time : float | None + Initial time point to show. If ``None``, the time point of the maximum field + strength is used. + trans : instance of Transform | None + The transformation from head coordinates to MRI coordinates. If ``None``, + the identity matrix is used and everything will be done in head coordinates. + stc : instance of SourceEstimate | None + An optional distributed source estimate to show alongside the fieldmap. The time + samples need to match those of the evoked data. + subject : str | None + The subject name. If ``None``, no MRI data is shown. + %(subjects_dir)s + %(rank)s + show_density : bool + Whether to show the density of the fieldmap. + ch_type : "meg" | "eeg" | None + Type of channels to use for the dipole fitting. By default (``None``) both MEG + and EEG channels will be used. + %(n_jobs)s + %(verbose)s + + Notes + ----- + When using ``cov=None`` the default noise values are 5 fT/cm, 20 fT, and 0.2 µV for + gradiometers, magnetometers, and EEG channels respectively. + """ + + def __init__( + self, + evoked, + cov=None, + bem=None, + initial_time=None, + trans=None, + stc=None, + subject=None, + subjects_dir=None, + rank="info", + show_density=True, + ch_type=None, + n_jobs=None, + verbose=None, + ): + if cov is None: + cov = make_ad_hoc_cov(evoked.info) + elif cov == "baseline": + std = dict() + for typ in set(evoked.get_channel_types(only_data_chs=True)): + baseline = evoked.copy().pick(typ).crop(*evoked.baseline) + std[typ] = baseline.data.std(axis=1).mean() + cov = make_ad_hoc_cov(evoked.info, std) + if bem is None: + bem = make_sphere_model("auto", "auto", evoked.info) + bem = _ensure_bem_surfaces(bem, extra_allow=(ConductorModel, None)) + field_map = make_field_map( + evoked, + ch_type=ch_type, + trans=trans, + origin=bem["r0"] if bem["is_sphere"] else "auto", + subject=subject, + subjects_dir=subjects_dir, + n_jobs=n_jobs, + verbose=verbose, + ) + + if initial_time is None: + # Set initial time to moment of maximum field power. + data = evoked.copy().pick(field_map[0]["ch_names"]).data + initial_time = evoked.times[np.argmax(np.mean(data**2, axis=0))] + + if stc is not None: + if not np.allclose(stc.times, evoked.times): + raise ValueError( + "The time samples of the source estimate do not match those of the " + "evoked data." + ) + if trans is None: + raise ValueError( + "`trans` cannot be `None` when showing the fieldlines in " + "combination with a source estimate." + ) + + # Get transforms to convert all the various meshes to MRI space. + head_mri_t = _get_trans(trans, "head", "mri")[0] + to_cf_t = _get_transforms_to_coord_frame( + evoked.info, head_mri_t, coord_frame="mri" + ) + + # Initialize all the private attributes. + self._actors = dict() + self._bem = bem + self._ch_type = ch_type + self._cov = cov + self._current_time = initial_time + self._dipoles = dict() + self._evoked = evoked + self._field_map = field_map + self._fig_sensors = None + self._multi_dipole_method = "Multi dipole (MNE)" + self._show_density = show_density + self._stc = stc + self._subjects_dir = subjects_dir + self._subject = subject + self._time_line = None + self._head_mri_t = head_mri_t + self._to_cf_t = to_cf_t + self._rank = rank + self._verbose = verbose + + # Configure the GUI. + self._renderer = self._configure_main_display() + self._configure_dock() + + @property + def dipoles(self): + """A list of all the fitted dipoles that are enabled in the GUI.""" + return [d["dip"] for d in self._dipoles.values() if d["active"]] + + def _configure_main_display(self): + """Configure main 3D display of the GUI.""" + fig = create_3d_figure((1500, 1020), bgcolor="white", show=True) + + self._fig_stc = None + if self._stc is not None: + self._fig_stc = self._stc.plot( + subject=self._subject, + subjects_dir=self._subjects_dir, + surface="white", + hemi="both", + time_viewer=False, + initial_time=self._current_time, + brain_kwargs=dict(units="m"), + figure=fig, + ) + fig = self._fig_stc + self._actors["brain"] = fig._actors["data"] + + fig = EvokedField( + self._evoked, + self._field_map, + time=self._current_time, + interpolation="linear", + alpha=0, + show_density=self._show_density, + foreground="black", + background="white", + fig=fig, + ) + fig.separate_canvas = False # needed to plot the timeline later + fig.set_contour_line_width(2) + fig._renderer.set_camera( + focalpoint=fit_sphere_to_headshape(self._evoked.info)[1] + ) + + if self._stc is not None: + link(self._fig_stc, fig) + + for surf_map in fig._surf_maps: + if surf_map["map_kind"] == "meg": + helmet_mesh = surf_map["mesh"] + helmet_mesh._polydata.compute_normals() # needed later + helmet_mesh._actor.prop.culling = "back" + self._actors["helmet"] = helmet_mesh._actor + # For MEG fieldlines, we want to occlude the ones not facing us, + # otherwise it's hard to interpret them. Since the "contours" object + # does not support backface culling, we create an opaque mesh to put in + # front of the contour lines with frontface culling. + occl_surf = deepcopy(surf_map["surf"]) + occl_surf["rr"] -= 1e-3 * occl_surf["nn"] + occl_act, _ = fig._renderer.surface(occl_surf, color="white") + occl_act.prop.culling = "front" + occl_act.prop.lighting = False + self._actors["occlusion_surf"] = occl_act + elif surf_map["map_kind"] == "eeg": + head_mesh = surf_map["mesh"] + head_mesh._polydata.compute_normals() # needed later + head_mesh._actor.prop.culling = "back" + self._actors["head"] = head_mesh._actor + + show_meg = (self._ch_type is None or self._ch_type == "meg") and any( + [m["kind"] == "meg" for m in self._field_map] + ) + show_eeg = (self._ch_type is None or self._ch_type == "eeg") and any( + [m["kind"] == "eeg" for m in self._field_map] + ) + meg_picks = pick_types(self._evoked.info, meg=show_meg) + eeg_picks = pick_types(self._evoked.info, meg=False, eeg=show_eeg) + picks = np.concatenate((meg_picks, eeg_picks)) + self._ch_names = [self._evoked.ch_names[i] for i in picks] + + for m in self._field_map: + if m["kind"] == "eeg": + head_surf = m["surf"] + break + else: + self._actors["head"], _, head_surf = _plot_head_surface( + renderer=fig._renderer, + head="head", + subject=self._subject, + subjects_dir=self._subjects_dir, + bem=self._bem, + coord_frame="mri", + to_cf_t=self._to_cf_t, + alpha=0.2, + ) + self._actors["head"].prop.culling = "back" + + sensors = _plot_sensors_3d( + renderer=fig._renderer, + info=self._evoked.info, + to_cf_t=self._to_cf_t, + picks=picks, + meg=["sensors"] if show_meg else False, + eeg=["original"] if show_eeg else False, + fnirs=False, + warn_meg=False, + head_surf=head_surf, + units="m", + sensor_alpha=dict(meg=0.1, eeg=1.0), + orient_glyphs=False, + scale_by_distance=False, + project_points=False, + surf=None, + check_inside=None, + nearest=None, + sensor_colors=dict( + meg=["white" for _ in meg_picks], + eeg=["white" for _ in eeg_picks], + ), + ) + self._actors["sensors"] = list() + for s in sensors.values(): + self._actors["sensors"].extend(s) + + subscribe(fig, "time_change", self._on_time_change) + self._fig = fig + return fig._renderer + + def _configure_dock(self): + """Configure the left and right dock areas of the GUI.""" + r = self._renderer + + # Toggle buttons for various meshes + layout = r._dock_add_group_box("Meshes") + for actor_name in self._actors.keys(): + if actor_name == "occlusion_surf": + continue + r._dock_add_check_box( + name=actor_name, + value=True, + callback=partial(self.toggle_mesh, name=actor_name), + layout=layout, + ) + + # Right dock + r._dock_initialize(name="Dipole fitting", area="right") + r._dock_add_button("Sensor data", self._on_sensor_data) + r._dock_add_button("Fit dipole", self._on_fit_dipole) + methods = ["Multi dipole (MNE)", "Single dipole"] + r._dock_add_combo_box( + "Dipole model", + value="Multi dipole (MNE)", + rng=methods, + callback=self._on_select_method, + ) + self._dipole_box = r._dock_add_group_box(name="Dipoles") + r._dock_add_file_button( + name="save_dipoles", + desc="Save dipoles", + save=True, + func=self.save, + tooltip="Save the dipoles to disk", + filter_="Dipole files (*.bdip)", + initial_directory=".", + ) + r._dock_add_stretch() + + def toggle_mesh(self, name, show=None): + """Toggle a mesh on or off. + + Parameters + ---------- + name : str + Name of the mesh to toggle. + show : bool | None + Whether to show the mesh. If None, the visibility of the mesh is toggled. + """ + _check_option("name", name, self._actors.keys()) + actors = self._actors[name] + # self._actors[name] is sometimes a list and sometimes not. Make it + # always be a list to simplify the code. + if not isinstance(actors, list): + actors = [actors] + if show is None: + show = not actors[0].GetVisibility() + for act in actors: + act.SetVisibility(show) + self._renderer._update() + + def _on_time_change(self, event): + new_time = np.clip(event.time, self._evoked.times[0], self._evoked.times[-1]) + self._current_time = new_time + if self._time_line is not None: + self._time_line.set_xdata([new_time]) + self._renderer._mplcanvas.update_plot() + self._update_arrows() + + def _on_sensor_data(self): + """Show sensor data and allow sensor selection.""" + if self._fig_sensors is not None: + return + fig = self._evoked.plot_topo(select=True) + fig.canvas.mpl_connect("close_event", self._on_sensor_data_close) + subscribe(fig, "channels_select", self._on_channels_select) + self._fig_sensors = fig + + def _on_sensor_data_close(self, event): + """Handle closing of the sensor selection window.""" + self._fig_sensors = None + if "sensors" in self._actors: + for act in self._actors["sensors"]: + act.prop.SetColor(1, 1, 1) + self._renderer._update() + + def _on_channels_select(self, event): + """Color selected sensor meshes.""" + selected_channels = set(event.ch_names) + if "sensors" in self._actors: + for act, ch_name in zip(self._actors["sensors"], self._ch_names): + if ch_name in selected_channels: + act.prop.SetColor(0, 1, 0) + else: + act.prop.SetColor(1, 1, 1) + self._renderer._update() + + def _on_fit_dipole(self): + """Fit a single dipole.""" + evoked_picked = self._evoked.copy() + cov_picked = self._cov.copy().as_diag() # FIXME: as_diag necessary? + if self._fig_sensors is not None: + picks = self._fig_sensors.lasso.selection + if len(picks) > 0: + evoked_picked = evoked_picked.pick(picks) + evoked_picked.info.normalize_proj() + cov_picked = cov_picked.pick_channels(picks, ordered=False) + cov_picked["projs"] = evoked_picked.info["projs"] + evoked_picked.crop(self._current_time, self._current_time) + + dip = fit_dipole( + evoked_picked, + cov_picked, + self._bem, + trans=self._head_mri_t, + rank=self._rank, + verbose=False, + )[0] + + self.add_dipole(dip) + + def add_dipole(self, dipole): + """Add a dipole (or multiple dipoles) to the GUI. + + Parameters + ---------- + dipole : Dipole + The dipole to add. If the ``Dipole`` object defines multiple dipoles, they + will all be added. + """ + new_dipoles = list() + for dip_i in range(len(dipole)): + dip = dipole[dip_i] + + # Coordinates needed to draw the big arrow on the helmet. + helmet_coords, helmet_pos = self._get_helmet_coords(dip) + + # Collect all relevant information on the dipole in a dict. + colors = _get_color_list() + if len(self._dipoles) == 0: + dip_num = 0 + else: + dip_num = max(self._dipoles.keys()) + 1 + if dip.name is None: + dip.name = f"dip{dip_num}" + dip_color = colors[dip_num % len(colors)] + if helmet_coords is not None: + arrow_mesh = pyvista.PolyData(*_arrow_mesh()) + else: + arrow_mesh = None + dipole_dict = dict( + active=True, + brain_arrow_actor=None, + helmet_arrow_actor=None, + arrow_mesh=arrow_mesh, + color=dip_color, + dip=dip, + fix_ori=True, + fix_position=True, + helmet_coords=helmet_coords, + helmet_pos=helmet_pos, + num=dip_num, + fit_time=self._current_time, + ) + self._dipoles[dip_num] = dipole_dict + + # Add a row to the dipole list + r = self._renderer + hlayout = r._dock_add_layout(vertical=False) + widgets = [] + widgets.append( + r._dock_add_check_box( + name="", + value=True, + callback=partial(self._on_dipole_toggle, dip_num=dip_num), + layout=hlayout, + ) + ) + widgets.append( + r._dock_add_text( + name=dip.name, + value=dip.name, + placeholder="name", + callback=partial(self._on_dipole_set_name, dip_num=dip_num), + layout=hlayout, + ) + ) + widgets.append( + r._dock_add_check_box( + name="Fix ori", + value=True, + callback=partial( + self._on_dipole_toggle_fix_orientation, dip_num=dip_num + ), + layout=hlayout, + ) + ) + widgets.append( + r._dock_add_button( + name="", + icon="clear", + callback=partial(self._on_dipole_delete, dip_num=dip_num), + layout=hlayout, + ) + ) + dipole_dict["widgets"] = widgets + r._layout_add_widget(self._dipole_box, hlayout) + new_dipoles.append(dipole_dict) + + # Show the dipoles and arrows in the 3D view. Only do this after + # `_fit_timecourses` so that they have the correct size straight away. + self._fit_timecourses() + for dipole_dict in new_dipoles: + dip = dipole_dict["dip"] + dipole_dict["brain_arrow_actor"] = self._renderer.plotter.add_arrows( + apply_trans(self._head_mri_t, dip.pos[0]), + dip.ori[0], + color=dipole_dict["color"], + mag=0.05, + ) + if arrow_mesh is not None: + dipole_dict["helmet_arrow_actor"] = self._renderer.plotter.add_mesh( + arrow_mesh, + color=dipole_dict["color"], + culling="front", + ) + self._update_arrows() + + def _get_helmet_coords(self, dip): + """Compute the coordinate system used for drawing the big arrows on the helmet. + + In this coordinate system, Z is normal to the helmet surface, and XY + are tangential to the helmet surface. + """ + if "helmet" not in self._actors: + return None, None + + # Get the closest vertex (=point) of the helmet mesh + dip_pos = apply_trans(self._head_mri_t, dip.pos[0]) + helmet = self._actors["helmet"].GetMapper().GetInput() + distances = ((helmet.points - dip_pos) * helmet.point_normals).sum(axis=1) + closest_point = np.argmin(distances) + + # Compute the position of the projected dipole on the helmet + norm = helmet.point_normals[closest_point] + helmet_pos = dip_pos + (distances[closest_point] + 0.003) * norm + + # Create a coordinate system where X and Y are tangential to the helmet + helmet_coords = _normal_orth(norm) + + return helmet_coords, helmet_pos + + def _fit_timecourses(self): + """Compute (or re-compute) dipole timecourses. + + Called whenever a dipole is (de)-activated or the "Fix pos" box is toggled. + """ + active_dips = [d for d in self._dipoles.values() if d["active"]] + if len(active_dips) == 0: + return + + # Restrict the dipoles to only the time at which they were fitted. + for d in active_dips: + if len(d["dip"].times) > 1: + d["dip"] = d["dip"].crop(d["fit_time"], d["fit_time"]) + + if self._multi_dipole_method == "Multi dipole (MNE)": + fwd, _ = make_forward_dipole( + [d["dip"] for d in active_dips], + self._bem, + self._evoked.info, + trans=self._head_mri_t, + ) + fwd = convert_forward_solution(fwd, surf_ori=False) + + inv = make_inverse_operator( + self._evoked.info, + fwd, + self._cov, + fixed=False, + loose=1.0, + depth=0, + rank=self._rank, + ) + stc = apply_inverse( + self._evoked, inv, method="MNE", lambda2=1e-6, pick_ori="vector" + ) + + timecourses = stc.magnitude().data + orientations = (stc.data / timecourses[:, np.newaxis, :]).transpose(0, 2, 1) + fixed_timecourses = stc.project( + np.array([dip["dip"].ori[0] for dip in active_dips]) + )[0].data + + for i, dip in enumerate(active_dips): + if dip["fix_ori"]: + dip["timecourse"] = fixed_timecourses[i] + dip["orientation"] = dip["dip"].ori.repeat(len(stc.times), axis=0) + else: + dip["timecourse"] = timecourses[i] + dip["orientation"] = orientations[i] + elif self._multi_dipole_method == "Single dipole": + for dip in active_dips: + dip_with_timecourse, _ = fit_dipole( + self._evoked, + self._cov, + self._bem, + pos=dip["dip"].pos[0], # position is always fixed + ori=dip["dip"].ori[0] if dip["fix_ori"] else None, + trans=self._head_mri_t, + rank=self._rank, + verbose=False, + ) + if dip["fix_ori"]: + dip["timecourse"] = dip_with_timecourse.data[0] + dip["orientation"] = dip["dip"].ori.repeat( + len(dip_with_timecourse.times), axis=0 + ) + else: + dip["timecourse"] = dip_with_timecourse.amplitude + dip["orientation"] = dip_with_timecourse.ori + + # Update matplotlib canvas at the bottom of the window + canvas = self._setup_mplcanvas() + ymin, ymax = 0, 0 + for dip in active_dips: + if "line_artist" in dip: + dip["line_artist"].set_ydata(dip["timecourse"]) + else: + dip["line_artist"] = canvas.plot( + self._evoked.times, + dip["timecourse"], + label=dip["dip"].name, + color=dip["color"], + ) + ymin = min(ymin, 1.1 * dip["timecourse"].min()) + ymax = max(ymax, 1.1 * dip["timecourse"].max()) + canvas.axes.set_ylim(ymin, ymax) + canvas.update_plot() + self._update_arrows() + + @verbose + @fill_doc + def save(self, fname, verbose=None): + """Save the fitted dipoles to a file. + + Parameters + ---------- + fname : path-like + The name of the file. Should end in ``'.dip'`` to save in plain text format, + or in ``'.bdip'`` to save in binary format. + %(verbose)s + """ + logger.info("Saving dipoles as:") + fname = Path(fname) + + # Pack the dipoles into a single mne.Dipole object. + dip = Dipole( + times=np.array([d.times[0] for d in self.dipoles]), + pos=np.array([d.pos[0] for d in self.dipoles]), + amplitude=np.array([d.amplitude[0] for d in self.dipoles]), + ori=np.array([d.ori[0] for d in self.dipoles]), + gof=np.array([d.gof[0] for d in self.dipoles]), + khi2=np.array([d.khi2[0] for d in self.dipoles]), + nfree=np.array([d.nfree[0] for d in self.dipoles]), + conf={ + key: np.array([d.conf[key][0] for d in self.dipoles]) + for key in self.dipoles[0].conf.keys() + }, + name=",".join(d.name for d in self.dipoles), + ) + dip.save(fname, overwrite=True, verbose=verbose) + + def _update_arrows(self): + """Update the arrows to have the correct size and orientation.""" + active_dips = [d for d in self._dipoles.values() if d["active"]] + if len(active_dips) == 0: + return + orientations = [dip["orientation"] for dip in active_dips] + timecourses = [dip["timecourse"] for dip in active_dips] + arrow_scaling = 0.05 / np.max(np.abs(timecourses)) + for dip, ori, timecourse in zip(active_dips, orientations, timecourses): + helmet_coords = dip["helmet_coords"] + if helmet_coords is None: + continue + dip_ori = [ + np.interp(self._current_time, self._evoked.times, o) for o in ori.T + ] + dip_moment = np.interp(self._current_time, self._evoked.times, timecourse) + arrow_size = dip_moment * arrow_scaling + arrow_mesh = dip["arrow_mesh"] + + # Project the orientation of the dipole tangential to the helmet + dip_ori_tan = helmet_coords[:2] @ dip_ori @ helmet_coords[:2] + + # Rotate the coordinate system such that Y lies along the dipole + # orientation, now we have our desired coordinate system for the + # arrows. + arrow_coords = np.array( + [np.cross(dip_ori_tan, helmet_coords[2]), dip_ori_tan, helmet_coords[2]] + ) + arrow_coords /= np.linalg.norm(arrow_coords, axis=1, keepdims=True) + + # Update the arrow mesh to point in the right directions + arrow_mesh.points = (_arrow_mesh()[0] * arrow_size) @ arrow_coords + arrow_mesh.points += dip["helmet_pos"] + self._renderer._update() + + def _on_select_method(self, method): + """Select the method to use for multi-dipole timecourse fitting.""" + self._multi_dipole_method = method + self._fit_timecourses() + + def _on_dipole_toggle(self, active, dip_num): + """Toggle a dipole on or off.""" + dipole = self._dipoles[dip_num] + active = bool(active) + dipole["active"] = active + dipole["line_artist"].set_visible(active) + # Labels starting with "_" are hidden from the legend. + dipole["line_artist"].set_label(("" if active else "_") + dipole["dip"].name) + dipole["brain_arrow_actor"].visibility = active + dipole["helmet_arrow_actor"].visibility = active + self._fit_timecourses() + self._renderer._update() + self._renderer._mplcanvas.update_plot() + + def _on_dipole_set_name(self, name, dip_num): + """Set the name of a dipole.""" + self._dipoles[dip_num]["dip"].name = name + self._dipoles[dip_num]["line_artist"].set_label(name) + self._renderer._mplcanvas.update_plot() + + def _on_dipole_toggle_fix_orientation(self, fix, dip_num): + """Fix dipole orientation when fitting timecourse.""" + self._dipoles[dip_num]["fix_ori"] = bool(fix) + self._fit_timecourses() + + def _on_dipole_delete(self, dip_num): + """Delete previously fitted dipole.""" + dipole = self._dipoles[dip_num] + dipole["line_artist"].remove() + dipole["brain_arrow_actor"].visibility = False + dipole["helmet_arrow_actor"].visibility = False + for widget in dipole["widgets"]: + widget.hide() + del self._dipoles[dip_num] + self._fit_timecourses() + self._renderer._update() + self._renderer._mplcanvas.update_plot() + + def _setup_mplcanvas(self): + """Configure the matplotlib canvas at the bottom of the window.""" + if self._renderer._mplcanvas is None: + self._renderer._mplcanvas = self._renderer._window_get_mplcanvas( + self._fig, 0.3, False, False + ) + self._renderer._window_adjust_mplcanvas_layout() + if self._time_line is None: + self._time_line = self._renderer._mplcanvas.plot_time_line( + self._current_time, + label="time", + color="black", + ) + return self._renderer._mplcanvas + + +def _arrow_mesh(): + """Obtain a mesh of an arrow.""" + vertices = np.array( + [ + [0.0, 1.0, 0.0], + [0.3, 0.7, 0.0], + [0.1, 0.7, 0.0], + [0.1, -1.0, 0.0], + [-0.1, -1.0, 0.0], + [-0.1, 0.7, 0.0], + [-0.3, 0.7, 0.0], + ] + ) + faces = np.array([[7, 0, 1, 2, 3, 4, 5, 6]]) + return vertices, faces diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index 9c558d32a51..53ac0def755 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -83,13 +83,7 @@ ) from ._dipole import _check_concat_dipoles, _plot_dipole_3d, _plot_dipole_mri_outlines from .evoked_field import EvokedField -from .utils import ( - _check_time_unit, - _get_cmap, - _get_color_list, - figure_nobar, - plt_show, -) +from .utils import _check_time_unit, _get_cmap, _get_color_list, figure_nobar, plt_show verbose_dec = verbose FIDUCIAL_ORDER = (FIFF.FIFFV_POINT_LPA, FIFF.FIFFV_POINT_NASION, FIFF.FIFFV_POINT_RPA) @@ -864,6 +858,7 @@ def plot_alignment( renderer.set_interaction(interaction) # plot head + print(head, bem, subject) _, _, head_surf = _plot_head_surface( renderer, head, diff --git a/mne/viz/_3d_overlay.py b/mne/viz/_3d_overlay.py index 3ebc308c127..fd9bbf8e1bb 100644 --- a/mne/viz/_3d_overlay.py +++ b/mne/viz/_3d_overlay.py @@ -100,7 +100,7 @@ def _compute_over(self, B, A): C[:, :3] *= A_w C[:, :3] += B[:, :3] * B_w C[:, 3:] += B_w - C[:, :3] /= C[:, 3:] + C[:, :3] /= np.maximum(1e-20, C[:, 3:]) # avoid divide by zero return np.clip(C, 0, 1, out=C) def _compose_overlays(self): diff --git a/mne/viz/evoked_field.py b/mne/viz/evoked_field.py index cf5a9996216..4bfeec1e67e 100644 --- a/mne/viz/evoked_field.py +++ b/mne/viz/evoked_field.py @@ -7,6 +7,7 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +from copy import deepcopy from functools import partial import numpy as np @@ -49,8 +50,6 @@ class EvokedField: the average peak latency (across sensor types) is used. time_label : str | None How to print info about the time instant visualized. - %(n_jobs)s - fig : instance of Figure3D | None If None (default), a new figure will be created, otherwise it will plot into the given figure. @@ -68,6 +67,10 @@ class EvokedField: The number of contours. .. versionadded:: 0.21 + contour_line_width : float + The line_width of the contour lines. + + .. versionadded:: 1.6 show_density : bool Whether to draw the field density as an overlay on top of the helmet/head surface. Defaults to ``True``. @@ -90,6 +93,17 @@ class EvokedField: ``True`` if there is more than one time point and ``False`` otherwise. .. versionadded:: 1.6 + background : tuple(int, int, int) + The color definition of the background: (red, green, blue). + + .. versionadded:: 1.6 + foreground : matplotlib color + Color of the foreground (will be used for colorbars and text). + None (default) will use black or white depending on the value + of ``background``. + + .. versionadded:: 1.6 + %(n_jobs)s %(verbose)s Notes @@ -108,15 +122,18 @@ def __init__( *, time=None, time_label="t = %0.0f ms", - n_jobs=None, fig=None, vmax=None, n_contours=21, + contour_line_width=1.0, show_density=True, alpha=None, interpolation="nearest", interaction="terrain", time_viewer="auto", + background="black", + foreground=None, + n_jobs=None, verbose=None, ): from .backends.renderer import _get_3d_backend, _get_renderer @@ -133,6 +150,7 @@ def __init__( self._vmax = _validate_type(vmax, (None, "numeric", dict), "vmax") self._n_contours = _ensure_int(n_contours, "n_contours") + self._contour_line_width = contour_line_width self._time_interpolation = _check_option( "interpolation", interpolation, @@ -141,6 +159,10 @@ def __init__( self._interaction = _check_option( "interaction", interaction, ["trackball", "terrain"] ) + self._bg_color = _to_rgb(background, name="background") + if foreground is None: + foreground = "w" if sum(self._bg_color) < 2 else "k" + self._fg_color = _to_rgb(foreground, name="foreground") surf_map_kinds = [surf_map["kind"] for surf_map in surf_maps] if vmax is None: @@ -185,16 +207,16 @@ def __init__( if isinstance(fig, Brain): self._renderer = fig._renderer self._in_brain_figure = True + self._units = fig._units if _get_3d_backend() == "notebook": raise NotImplementedError( "Plotting on top of an existing Brain figure " "is currently not supported inside a notebook." ) else: - self._renderer = _get_renderer( - fig, bgcolor=(0.0, 0.0, 0.0), size=(600, 600) - ) + self._renderer = _get_renderer(fig, bgcolor=background, size=(600, 600)) self._in_brain_figure = False + self._units = "m" self.plotter = self._renderer.plotter self.interaction = interaction @@ -227,14 +249,17 @@ def current_time_func(): current_time_func=current_time_func, times=evoked.times, ) - if not self._in_brain_figure or "time_slider" not in fig.widgets: + if not self._in_brain_figure: # Draw the time label self._time_label = time_label if time_label is not None: if "%" in time_label: time_label = time_label % np.round(1e3 * time) self._time_label_actor = self._renderer.text2d( - x_window=0.01, y_window=0.01, text=time_label + x_window=0.01, + y_window=0.01, + text=time_label, + color=foreground, ) self._configure_dock() @@ -277,7 +302,8 @@ def _prepare_surf_map(self, surf_map, color, alpha): # Make a solid surface surf = surf_map["surf"] - if self._in_brain_figure: + if self._units == "mm": + surf = deepcopy(surf) surf["rr"] *= 1000 map_vmax = self._vmax.get(surf_map["kind"]) if map_vmax is None: @@ -355,6 +381,7 @@ def _update(self): vmin=-surf_map["map_vmax"], vmax=surf_map["map_vmax"], colormap=self._colormap_lines, + width=self._contour_line_width, ) if self._time_label is not None: if hasattr(self, "_time_label_actor"): @@ -365,7 +392,10 @@ def _update(self): if "%" in self._time_label: time_label = self._time_label % np.round(1e3 * self._current_time) self._time_label_actor = self._renderer.text2d( - x_window=0.01, y_window=0.01, text=time_label + x_window=0.01, + y_window=0.01, + text=time_label, + color=self._fg_color, ) self._renderer.plotter.update() @@ -434,6 +464,16 @@ def _callback(vmax, kind, scaling): callback=self.set_contours, layout=layout, ) + + self._widgets["contours_line_width"] = r._dock_add_slider( + name="Thickness", + value=1, + rng=[0, 10], + callback=self.set_contour_line_width, + double=True, + layout=layout, + ) + r._dock_finalize() def _on_time_change(self, event): @@ -495,9 +535,13 @@ def _on_contours(self, event): break surf_map["contours"] = event.contours self._n_contours = len(event.contours) + if event.line_width is not None: + self._contour_line_width = event.line_width with disable_ui_events(self): if "contours" in self._widgets: self._widgets["contours"].set_value(len(event.contours)) + if "contour_line_width" in self._widgets and event.line_width is not None: + self._widgets["contour_line_width"].set_value(event.line_width) self._update() def set_time(self, time): @@ -532,6 +576,7 @@ def set_contours(self, n_contours): contours=np.linspace( -surf_map["map_vmax"], surf_map["map_vmax"], n_contours ).tolist(), + line_width=self._contour_line_width, ), ) @@ -566,3 +611,14 @@ def _rescale(self): current_data = surf_map["data_interp"](self._current_time) vmax = float(np.max(current_data)) self.set_vmax(vmax, kind=surf_map["map_kind"]) + + def set_contour_line_width(self, line_width): + """Set the line_width of the contour lines. + + Parameters + ---------- + line_width : float + The desired line_width of the contour lines. + """ + self._contour_line_width = line_width + self.set_contours(self._n_contours) diff --git a/mne/viz/tests/test_3d.py b/mne/viz/tests/test_3d.py index 34022d59768..e3e4a2143d2 100644 --- a/mne/viz/tests/test_3d.py +++ b/mne/viz/tests/test_3d.py @@ -192,9 +192,18 @@ def test_plot_evoked_field(renderer): ) evoked.plot_field(maps, time=0.1, n_contours=n_contours) - # Test plotting inside an existing Brain figure. - brain = Brain("fsaverage", "lh", "inflated", subjects_dir=subjects_dir) - fig = evoked.plot_field(maps, time=0.1, fig=brain) + # Test plotting inside an existing Brain figure. Check that units are taken into + # account. + for units in ["mm", "m"]: + brain = Brain( + "fsaverage", "lh", "inflated", units=units, subjects_dir=subjects_dir + ) + fig = evoked.plot_field(maps, time=0.1, fig=brain) + assert brain._units == fig._units + scale = 1000 if units == "mm" else 1 + assert ( + fig._surf_maps[0]["surf"]["rr"][0, 0] == scale * maps[0]["surf"]["rr"][0, 0] + ) # Test some methods fig = evoked.plot_field(maps, time_viewer=True) diff --git a/mne/viz/ui_events.py b/mne/viz/ui_events.py index b8b3fe29a4d..a07514c0ebe 100644 --- a/mne/viz/ui_events.py +++ b/mne/viz/ui_events.py @@ -191,11 +191,13 @@ class Contours(UIEvent): Parameters ---------- kind : str - The kind of contours lines being changed. The Notes section of the drawing routine publishing this event should mention the possible kinds. contours : list of float The new values at which contour lines need to be drawn. + line_width : float | None + The line_width with which to draw the contour lines. Can be ``None`` to + indicate to keep using the current line_width. Attributes ---------- @@ -206,10 +208,14 @@ class Contours(UIEvent): kinds. contours : list of float The new values at which contour lines need to be drawn. + line_width : float | None + The line_width with which to draw the contour lines. Can be ``None`` to + indicate to keep using the current line_width. """ kind: str contours: list[str] + line_width: float | None @dataclass