Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
e48e050
init
jona-sassenhagen Aug 31, 2017
d7556f9
make it work
jona-sassenhagen Sep 1, 2017
1b32fc7
fx
jona-sassenhagen Sep 1, 2017
36d3617
tests
jona-sassenhagen Sep 1, 2017
6dc0493
fix for mpl=1.3
jona-sassenhagen Sep 2, 2017
81db3a0
just pep8
jona-sassenhagen Sep 2, 2017
3feb2fd
...
jona-sassenhagen Sep 3, 2017
8b87314
tutorial
jona-sassenhagen Oct 19, 2017
9488c60
change cmap
jona-sassenhagen Oct 20, 2017
c2269b7
fixes
jona-sassenhagen Oct 20, 2017
4c49a87
fix docstring
jona-sassenhagen Oct 20, 2017
27f4371
pep8
jona-sassenhagen Oct 20, 2017
7c8c36b
pep8
jona-sassenhagen Oct 20, 2017
f4d29ac
heatmap for the cbar
jona-sassenhagen Oct 20, 2017
5a01d28
trigger circle
jona-sassenhagen Oct 20, 2017
63db06c
pep8
jona-sassenhagen Oct 21, 2017
1b8c75b
tiny fixes
jona-sassenhagen Oct 21, 2017
a414256
coverage
jona-sassenhagen Oct 21, 2017
51769c7
a few comments and docs
jona-sassenhagen Oct 21, 2017
eb17427
fix
jona-sassenhagen Oct 22, 2017
5e79013
tiny stuff
jona-sassenhagen Oct 23, 2017
9eeb617
doc
jona-sassenhagen Oct 24, 2017
55c7298
only my pep8
jona-sassenhagen Oct 24, 2017
21833fb
rather important fix
jona-sassenhagen Oct 27, 2017
f957146
cbar label
jona-sassenhagen Oct 27, 2017
4374c42
another fix
jona-sassenhagen Oct 27, 2017
db4d561
another fix
jona-sassenhagen Oct 27, 2017
cfb5ce6
cosmits
agramfort Nov 3, 2017
c1f8e34
address mikolajs comments
jona-sassenhagen Nov 8, 2017
e7524a7
strings for mpl locs
jona-sassenhagen Nov 8, 2017
5d6d436
pep8
jona-sassenhagen Nov 8, 2017
f62ebba
more pep8
jona-sassenhagen Nov 9, 2017
41eaa40
bool ci
jona-sassenhagen Nov 13, 2017
66bd908
speed up tests a bit
jona-sassenhagen Nov 13, 2017
2b2dc53
fix bad rebase
jona-sassenhagen Nov 13, 2017
fd23602
fix legend default
jona-sassenhagen Nov 13, 2017
fb37f87
*curses*
jona-sassenhagen Nov 13, 2017
4c24b4d
fix plot_epochs_image
jona-sassenhagen Nov 13, 2017
06a285a
pep
jona-sassenhagen Nov 13, 2017
a5fc30a
fixes
jona-sassenhagen Nov 14, 2017
b9afa02
comment
jona-sassenhagen Nov 14, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mne/viz/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ def plot_epochs_image(epochs, picks=None, sigma=0., vmin=None,
group_by = "type"
combine = "gfp"

if combine is not None:
ts_args["show_sensors"] = False

if picks is None:
picks = pick_types(epochs.info, meg=True, eeg=True, ref_meg=False,
exclude='bads')
Expand Down
203 changes: 165 additions & 38 deletions mne/viz/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from numbers import Integral

import numpy as np
import matplotlib.lines as mlines

from ..io.pick import (channel_type, pick_types, _picks_by_type,
_pick_data_channels, _VALID_CHANNEL_TYPES)
Expand Down Expand Up @@ -1440,11 +1441,26 @@ def _truncate_yaxis(axes, ymin, ymax, orig_ymin, orig_ymax, fraction,
return ymin_bound, ymax_bound


def _check_loc_legal(loc, what='your choice'):
"""Check if a loc is a legal for MPL."""
true_default = {"show_legend": 3, "show_sensors": 4}.get(what, 1)
loc_dict = {'upper right': 1, 'upper left': 2, 'lower left': 3,
'lower right': 4, 'right': 5, 'center left': 6,
'center right': 7, 'lower center': 8, 'upper center': 9,
'center': 10, True: true_default}
loc_ = loc_dict.get(loc, loc)
if loc_ not in range(11):
raise ValueError(str(loc) + " is not a legal MPL loc, please supply"
"another value for " + what + ".")
return loc_


def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None,
linestyles=['-'], styles=None, vlines=list((0.,)),
ci=0.95, truncate_yaxis=False, truncate_xaxis=True,
ylim=dict(), invert_y=False, show_sensors=None,
show_legend=True, axes=None, title=None, show=True):
linestyles=['-'], styles=None, cmap=None,
vlines=list((0.,)), ci=0.95, truncate_yaxis=False,
truncate_xaxis=True, ylim=dict(), invert_y=False,
show_sensors=None, show_legend=True,
split_legend=False, axes=None, title=None, show=True):
"""Plot evoked time courses for one or multiple channels and conditions.

This function is useful for comparing ER[P/F]s at a specific location. It
Expand All @@ -1454,7 +1470,7 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None,
Parameters
----------
evokeds : instance of mne.Evoked | list | dict
If a single evoked instance, it is plotted as a time series.
If a single Evoked instance, it is plotted as a time series.
If a dict whose values are Evoked objects, the contents are plotted as
single time series each and the keys are used as condition labels.
If a list of Evokeds, the contents are plotted with indices as labels.
Expand Down Expand Up @@ -1483,6 +1499,9 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None,
"Vis/L", "Vis/R", `colors` can be `dict(Aud='r', Vis='b')` to map both
Aud/L and Aud/R to the color red and both Visual conditions to blue.
If None (default), a sequence of desaturated colors is used.
If `cmap` is None, `colors` will indicate how each condition is
colored with reference to its position on the colormap - see `cmap`
below.
linestyles : list | dict
If a list, will be sequentially and repeatedly used for evoked plot
linestyles.
Expand All @@ -1498,21 +1517,44 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None,
parameters will be passed to the line plot call of the corresponding
condition, overriding defaults.
E.g., if evokeds is a dict with the keys "Aud/L", "Aud/R",
"Vis/L", "Vis/R", `styles` can be `{"Aud/L":{"linewidth":1}}` to set
"Vis/L", "Vis/R", `styles` can be `{"Aud/L": {"linewidth": 1}}` to set
the linewidth for "Aud/L" to 1. Note that HED ('/'-separated) tags are
not supported.
cmap : None | str | tuple
If not None, plot evoked activity with colors from a color gradient
(indicated by a str referencing a matplotlib colormap - e.g., "viridis"
or "Reds").
If ``evokeds`` is a list and ``colors`` is `None`, the color will
depend on the list position. If ``colors`` is a list, it must contain
integers where the list positions correspond to ``evokeds``, and the
value corresponds to the position on the colorbar.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure this complexity is needed. If you want a different order of colors, you can reorder Evokeds, and when you need specific colors you can get these from the colormap yourself and pass to colors.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reordering Evokeds won't work for the dict case (which is my preferred case), and won't work for the important case where multiple evokeds get the same color.

Taking colors from an MPL (or, easier, seaborn) colormap is nice and all, but getting everything. set up - particularly the split legends - is hell on the user.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, after looking at how easy and elegant the example some lines below is - I tend to agree.
However as I mentioned in another comment it would be more flexible if colors could get floats between 0 and 1 and the linear order was not assumed.

If ``evokeds`` is a dict, ``colors`` should be a dict mapping from
(potentially HED-style) condition tags to numbers corresponding to
rank order positions on the colorbar. E.g., ::
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aparat from the remark above - this is documentation for cmap, but it is mostly devoted to colors kwarg - maybe that should be there?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see what you mean. However, here's my rationale. colors has essentially two modes: a categorical mode, and a gradient mode. The gradient mode depends on camp.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I get that, it is just a bit confusing, no big deal.


evokeds={"cond1/A": ev1, "cond2/A": ev2, "cond3/A": ev3, "B": ev4},
cmap='viridis', colors=dict(cond1=1 cond2=2, cond3=3),
linestyles={"A": "-", "B": ":"}

If ``cmap`` is a tuple of length 2, the first item must be
a string which will become the colorbar label, and the second one
must indicate a colormap, e.g. ::

cmap=('conds', 'viridis'), colors=dict(cond1=1 cond2=2, cond3=3),

vlines : list of int
A list of integers corresponding to the positions, in seconds,
at which to plot dashed vertical lines.
ci : float | callable | None
ci : float | callable | None | bool
If not None and ``evokeds`` is a [list/dict] of lists, a shaded
confidence interval is drawn around the individual time series. If
float, a percentile bootstrap method is used to estimate the confidence
interval and this value determines the CI width. E.g., if this value is
.95 (the default), the 95% confidence interval is drawn. If a callable,
it must take as its single argument an array (observations x times) and
return the upper and lower confidence bands.
If None, no confidence band is plotted.
If None or False, no confidence band is plotted.
If True, the 95% confidence interval is drawn.
truncate_yaxis : bool | str
If True, the left y axis spine is truncated to reduce visual clutter.
If 'max_ticks', the spine is truncated at the minimum and maximum
Expand All @@ -1529,21 +1571,25 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None,
invert_y : bool
If True, negative values are plotted up (as is sometimes done
for ERPs out of tradition). Defaults to False.
show_sensors: bool | int | None
show_sensors: bool | int | str | None
If not False, channel locations are plotted on a small head circle.
If an int, the position of the axes (forwarded to
If int or str, the position of the axes (forwarded to
``mpl_toolkits.axes_grid1.inset_locator.inset_axes``).
If None, defaults to True if ``gfp`` is False, else to False.
show_legend : bool | int
If not False, show a legend. If int, the position of the axes
show_legend : bool | str | int
If not False, show a legend. If int or str, the position of the axes
(forwarded to ``mpl_toolkits.axes_grid1.inset_locator.inset_axes``).
split_legend : bool
If True, the legend shows color and linestyle separately; `colors` must
not be None. Defaults to True if ``cmap`` is not None, else defaults to
False.
axes : None | `matplotlib.axes.Axes` instance | list of `axes`
What axes to plot to. If None, a new axes is created.
When plotting multiple channel types, can also be a list of axes, one
per channel type.
title : None | str
If str, will be plotted as figure title. If None, the channel
names will be shown.
If str, will be plotted as figure title. If None, the channel names
will be shown.
show : bool
If True, show the figure.

Expand All @@ -1558,7 +1604,10 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None,
# set up labels and instances
if isinstance(evokeds, Evoked):
evokeds = dict(Evoked=evokeds) # title becomes 'Evoked'
elif not isinstance(evokeds, dict):
elif not isinstance(evokeds, dict): # it's assumed to be a list
if (cmap is not None) and (colors is None):
colors = dict(
(str(ii + 1), ii) for ii, evoked in enumerate(evokeds))
evokeds = dict((str(ii + 1), evoked)
for ii, evoked in enumerate(evokeds))
for cond in evokeds.keys():
Expand All @@ -1581,13 +1630,12 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None,
if isinstance(picks, Integral):
picks = [picks]
elif len(picks) == 0:
warn("No picks, plotting the GFP ...")
gfp = True
picks = _pick_data_channels(example.info)

if len(picks) == 0:
raise ValueError("No valid channels were found to plot the GFP. " +
"Use 'picks' instead to select them manually.")
warn("No picks, plotting the GFP ...")

if ylim is None:
ylim = dict()
Expand All @@ -1602,14 +1650,17 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None,
else:
if not isinstance(picks[0], (int, np.integer)):
msg = "'picks' must be int or a list of int, not {0}."
raise ValueError(msg.format(type(picks)))
raise TypeError(msg.format(type(picks)))
show_sensors = True if show_sensors is None else show_sensors
ch_names = [example.ch_names[pick] for pick in picks]
ch_types = list(set(channel_type(example.info, pick_)
for pick_ in picks))
# XXX: could possibly be refactored; plot_joint is doing a similar thing
if any([type_ not in _VALID_CHANNEL_TYPES for type_ in ch_types]):
raise ValueError("Non-data channel picked.")
non_data_channels = [str(pick) for pick, type_ in zip(picks, ch_types)
if type_ not in _VALID_CHANNEL_TYPES]
if len(non_data_channels) > 0:
msg = "Non-data channel(s) {0} were picked."
raise ValueError(msg.format(", ".join(non_data_channels)))
if len(ch_types) > 1:
warn("Multiple channel types selected, returning one figure per type.")
if axes is not None:
Expand All @@ -1619,7 +1670,7 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None,
for ii, t in enumerate(ch_types):
picks_ = [idx for idx in picks
if channel_type(example.info, idx) == t]
title_ = "GFP, " + t if not title and gfp is True else title
title_ = "GFP, " + t if (not title and (gfp is True)) else title
ax_ = axes[ii] if axes is not None else None
figs.append(
plot_compare_evokeds(
Expand All @@ -1633,8 +1684,13 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None,
ymin, ymax = ylim.get(ch_type, [None, None])

# deal with dict/list of lists and the CI
if ci is not None and not (isinstance(ci, np.float) or callable(ci)):
raise TypeError('ci must be float or callable, got ' + str(type(ci)))
if ci is None:
ci = False
if ci is True:
ci = .95
elif ci is not False and not (isinstance(ci, np.float) or callable(ci)):
raise TypeError('ci must be None, bool, float or callable, got ' +
str(type(ci)))

scaling = _handle_default("scalings")[ch_type]
unit = _handle_default("units")[ch_type]
Expand All @@ -1650,8 +1706,6 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None,
ymin = 0. # 'grad' and GFP are plotted as all-positive

# if we have a dict/list of lists, we compute the grand average and the CI
if ci is None:
ci = False
if not all([isinstance(evoked_, Evoked) for evoked_ in evokeds.values()]):
if ci is not False:
if callable(ci):
Expand Down Expand Up @@ -1705,13 +1759,67 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None,
"conditions. Condition " + style_ +
" was not found in the supplied data.")

# second, color
# third, color
# check: is color a list?
if (colors is not None and not isinstance(colors, string_types) and
not isinstance(colors, dict) and len(colors) > 1):
colors = dict((condition, color) for condition, color
in zip(conditions, colors))

if cmap is not None:
if not isinstance(cmap, string_types) and len(cmap) == 2:
cmap_label, cmap = cmap
else:
cmap_label = ""

# dealing with a split legend
if split_legend is None:
split_legend = cmap is not None # default to True iff cmap is given
if split_legend is True:
if colors is None:
raise ValueError(
"If `split_legend` is True, `colors` must not be None.")
# mpl 1.3 requires us to split it like this. with recent mpl,
# we could use the label parameter of the Line2D
legend_lines, legend_labels = list(), list()
if cmap is None: # ... one set of lines for the colors
for color in sorted(colors.keys()):
line = mlines.Line2D([], [], linestyle="-",
color=colors[color])
legend_lines.append(line)
legend_labels.append(color)
if len(list(linestyles)) > 1: # ... one set for the linestyle
for style, s in linestyles.items():
line = mlines.Line2D([], [], color='k', linestyle=s)
legend_lines.append(line)
legend_labels.append(style)

# dealing with continuous colors
if cmap is not None:
for color_value in colors.values():
try:
float(color_value)
except ValueError:
raise TypeError("If ``cmap`` is not None, the values of "
"``colors`` must be numeric. Got " +
str(type(color_value)))
cmapper = getattr(plt.cm, cmap, cmap)
color_conds = list(colors.keys())
all_colors = [colors[cond] for cond in color_conds]
n_colors = len(all_colors)
color_order = np.array(all_colors).argsort()
color_indices = color_order.argsort()

the_colors = cmapper(np.linspace(0, 1, n_colors))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe it would be better to ask users to give colors as 0 - 1 floats, then non-linear mapping would be possible (for example only picking values [0.3, 0.5, 0.8])? But I am still not sure if this is better than getting the colors and passing them yourself.


colors_ = {cond: ind for cond, ind in zip(color_conds, color_indices)}
colors = dict()
for cond in evokeds.keys():
for cond_number, color in colors_.items():
if cond_number in cond:
colors[cond] = the_colors[color]
continue

if not isinstance(colors, dict): # default colors from M Waskom's Seaborn
# XXX should put a good list of default colors into defaults.py
colors_ = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00',
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just BTW - I am not sure about that, if one is using some matplotlib style maybe we should honor that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, we need default colors listed explicitly somewhere anyways. That's the way the code works.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

technically you could probably extract the default matplotlib color cycle, but I don't mind customizing it here (we do it elsewhere, too, I think)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note if we use the MPL defaults as a manual list, that still won't help people who use custom MPL colors.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I was thinking about the current color cycle, not the default one. But I like the default colors, so nevermind for now. :)

Expand Down Expand Up @@ -1756,14 +1864,14 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None,
d = np.sqrt((d * d).mean(axis=-1))
else:
d = d.mean(-1)
axes.plot(times, d, zorder=1000, label=condition, **styles[condition])
axes.plot(times, d, zorder=100, label=condition, **styles[condition])
if any(d > 0) or all_positive:
any_positive = True
if np.any(d < 0):
any_negative = True

# plot the confidence interval
if ci and (gfp is not True):
if ci:
ci_ = ci_array[condition]
axes.fill_between(times, ci_[0].flatten(), ci_[1].flatten(),
zorder=9, color=styles[condition]['c'], alpha=.3)
Expand Down Expand Up @@ -1808,6 +1916,8 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None,
_setup_ax_spines(axes, vlines, tmin, tmax, invert_y, ymax_bound, unit,
truncate_xaxis)

# and now for 3 "legends" ..
# a head plot showing the sensors that are being plotted
if show_sensors:
try:
pos = _auto_topomap_coords(
Expand All @@ -1818,19 +1928,36 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None,
else:
head_pos = {'center': (0, 0), 'scale': (0.5, 0.5)}
pos, outlines = _check_outlines(pos, np.array([1, 1]), head_pos)
if not isinstance(show_sensors, (np.int, bool)):
raise TypeError("`show_sensors` must be numeric or bool, not" +
str(type(show_sensors)))
if show_sensors is True:
show_sensors = 2
if not isinstance(show_sensors, (np.int, bool, str)):
raise TypeError("show_sensors must be numeric, str or bool, "
"not " + str(type(show_sensors)))
show_sensors = _check_loc_legal(show_sensors, "show_sensors")
_plot_legend(pos, ["k" for pick in picks], axes, list(), outlines,
show_sensors, size=20)

if show_legend and len(conditions) > 1:
if show_legend is True:
show_legend = 'best'
axes.legend(loc=show_legend, ncol=1 + (len(conditions) // 5),
frameon=True)
# the condition legend
if len(conditions) > 1 and show_legend is not False:
show_legend = _check_loc_legal(show_legend, "show_legend")
legend_params = dict(loc=show_legend, frameon=True)
if split_legend:
if len(legend_lines) > 1:
axes.legend(legend_lines, legend_labels, # see above: mpl 1.3
ncol=1 + (len(legend_lines) // 4), **legend_params)
else:
axes.legend(ncol=1 + (len(conditions) // 5), **legend_params)

# the colormap, if `cmap` is provided
if split_legend and cmap is not None:
# plot the colorbar ... complicated cause we don't have a heatmap
from mpl_toolkits.axes_grid1 import make_axes_locatable
divider = make_axes_locatable(axes)
ax_cb = divider.append_axes("right", size="5%", pad=0.05)
ax_cb.imshow(the_colors[:, np.newaxis, :], interpolation='none')
ax_cb.set_yticks(np.arange(len(the_colors)))
ax_cb.set_yticklabels(np.array(color_conds)[color_order])
ax_cb.yaxis.tick_right()
ax_cb.set_xticks(())
ax_cb.set_ylabel(cmap_label)

plt_show(show)
return fig
Loading