Skip to content

Commit

Permalink
First working version of lasso select in plot_evoked_topo
Browse files Browse the repository at this point in the history
  • Loading branch information
wmvanvliet committed Oct 4, 2023
1 parent 905c12c commit 4fe2135
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 66 deletions.
33 changes: 33 additions & 0 deletions mne/viz/evoked_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ class EvokedField:
The number of contours.
.. versionadded:: 0.21
contour_line_width : float
The line_width of the contour lines.
.. versionadded:: 1.6
show_density : bool
Whether to draw the field density as an overlay on top of the helmet/head
surface. Defaults to ``True``.
Expand Down Expand Up @@ -111,6 +115,7 @@ def __init__(
fig=None,
vmax=None,
n_contours=21,
contour_line_width=1.0,
show_density=True,
alpha=None,
interpolation="nearest",
Expand All @@ -132,6 +137,7 @@ def __init__(

self._vmax = _validate_type(vmax, (None, "numeric", dict), "vmax")
self._n_contours = _ensure_int(n_contours, "n_contours")
self._contour_line_width = contour_line_width
self._time_interpolation = _check_option(
"interpolation",
interpolation,
Expand Down Expand Up @@ -354,6 +360,7 @@ def _update(self):
vmin=-surf_map["map_vmax"],
vmax=surf_map["map_vmax"],
colormap=self._colormap_lines,
width=self._contour_line_width,
)
if self._time_label is not None:
if hasattr(self, "_time_label_actor"):
Expand Down Expand Up @@ -442,6 +449,16 @@ def _callback(vmax, type, scaling):
callback=self.set_contours,
layout=layout,
)

self._widgets["contours_line_width"] = r._dock_add_slider(
name="Thickness",
value=1,
rng=[0, 10],
callback=self.set_contour_line_width,
double=True,
layout=layout,
)

r._dock_finalize()

def _on_time_change(self, event):
Expand Down Expand Up @@ -503,9 +520,13 @@ def _on_contours(self, event):
break
surf_map["contours"] = event.contours
self._n_contours = len(event.contours)
if event.line_width is not None:
self._contour_line_width = event.line_width
with disable_ui_events(self):
if "contours" in self._widgets:
self._widgets["contours"].set_value(len(event.contours))
if "contour_line_width" in self._widgets and event.line_width is not None:
self._widgets["contour_line_width"].set_value(event.line_width)
self._update()

def set_time(self, time):
Expand Down Expand Up @@ -540,6 +561,7 @@ def set_contours(self, n_contours):
contours=np.linspace(
-surf_map["map_vmax"], surf_map["map_vmax"], n_contours
).tolist(),
line_width=self._contour_line_width,
),
)

Expand Down Expand Up @@ -574,3 +596,14 @@ def _rescale(self):
current_data = surf_map["data_interp"](self._current_time)
vmax = float(np.max(current_data))
self.set_vmax(vmax, type=surf_map["map_kind"])

def set_contour_line_width(self, line_width):
"""Set the line_width of the contour lines.
Parameters
----------
line_width : float
The desired line_width of the contour lines.
"""
self._contour_line_width = line_width
self.set_contours(self._n_contours)
32 changes: 23 additions & 9 deletions mne/viz/topo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
_setup_ax_spines,
_check_cov,
_plot_masked_image,
SelectFromCollection,
)


Expand Down Expand Up @@ -195,8 +196,11 @@ def format_coord_multiaxis(x, y, ch_name=None):
under_ax.set(xlim=[0, 1], ylim=[0, 1])

axs = list()

shown_ch_names = []
for idx, name in iter_ch:
ch_idx = ch_names.index(name)
shown_ch_names.append(name)
if not unified: # old, slow way
ax = plt.axes(pos[idx])
ax.patch.set_facecolor(axis_facecolor)
Expand Down Expand Up @@ -237,15 +241,22 @@ def format_coord_multiaxis(x, y, ch_name=None):
],
[1, 0, 2],
)
if not img:
under_ax.add_collection(
collections.PolyCollection(
verts,
facecolor=axis_facecolor,
edgecolor=axis_spinecolor,
linewidth=1.0,
)
) # Not needed for image plots.
if not img: # Not needed for image plots.
collection = collections.PolyCollection(
verts,
facecolor=axis_facecolor,
edgecolor=axis_spinecolor,
)
under_ax.add_collection(collection)
fig.lasso = SelectFromCollection(
ax=under_ax,
collection=collection,
names=shown_ch_names,
alpha_nonselected=0,
alpha_selected=1,
linewidth_nonselected=0,
linewidth_selected=0.7,
)
for ax in axs:
yield ax, ax._mne_ch_idx

Expand Down Expand Up @@ -344,6 +355,9 @@ def _plot_topo_onpick(event, show_func):
"""Onpick callback that shows a single channel in a new figure."""
# make sure that the swipe gesture in OS-X doesn't open many figures
orig_ax = event.inaxes
if orig_ax.figure.canvas._key in ["shift", "alt"]:
return

import matplotlib.pyplot as plt

try:
Expand Down
27 changes: 27 additions & 0 deletions mne/viz/ui_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@ class Contours(UIEvent):
kinds.
contours : list of float
The new values at which contour lines need to be drawn.
line_width : float | None
The line_width with which to draw the contour lines. Can be ``None`` to
indicate to keep using the current line_width.
Attributes
----------
Expand All @@ -186,10 +189,34 @@ class Contours(UIEvent):
kinds.
contours : list of float
The new values at which contour lines need to be drawn.
line_width : float | None
The line_width with which to draw the contour lines. Can be ``None`` to
indicate to keep using the current line_width.
"""

kind: str
contours: List[str]
line_width: Optional[float]


@dataclass
@fill_doc
class ChannelsSelect(UIEvent):
"""Indicates that the user has selected one or more channels.
Parameters
----------
ch_names : list of str
The names of the channels that were selected.
Attributes
----------
%(ui_event_name_source)s
ch_names : list of str
The names of the channels that were selected.
"""

ch_names: List[str]


def _get_event_channel(fig):
Expand Down
Loading

0 comments on commit 4fe2135

Please sign in to comment.