-
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
[DO NOT MRG] plot_compare_evoked improved line colouring #4526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e48e050
d7556f9
1b32fc7
36d3617
6dc0493
81db3a0
3feb2fd
8b87314
9488c60
c2269b7
4c49a87
27f4371
7c8c36b
f4d29ac
5a01d28
63db06c
1b8c75b
a414256
51769c7
eb17427
5e79013
9eeb617
55c7298
21833fb
f957146
4374c42
db4d561
cfb5ce6
c1f8e34
e7524a7
5d6d436
f62ebba
41eaa40
66bd908
2b2dc53
fd23602
fb37f87
4c24b4d
06a285a
a5fc30a
b9afa02
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,7 @@ | |
| from numbers import Integral | ||
|
|
||
| import numpy as np | ||
| import matplotlib.lines as mlines | ||
|
|
||
| from ..io.pick import (channel_type, pick_types, _picks_by_type, | ||
| _pick_data_channels, _VALID_CHANNEL_TYPES) | ||
|
|
@@ -1440,11 +1441,26 @@ def _truncate_yaxis(axes, ymin, ymax, orig_ymin, orig_ymax, fraction, | |
| return ymin_bound, ymax_bound | ||
|
|
||
|
|
||
| def _check_loc_legal(loc, what='your choice'): | ||
| """Check if a loc is a legal for MPL.""" | ||
| true_default = {"show_legend": 3, "show_sensors": 4}.get(what, 1) | ||
| loc_dict = {'upper right': 1, 'upper left': 2, 'lower left': 3, | ||
| 'lower right': 4, 'right': 5, 'center left': 6, | ||
| 'center right': 7, 'lower center': 8, 'upper center': 9, | ||
| 'center': 10, True: true_default} | ||
| loc_ = loc_dict.get(loc, loc) | ||
| if loc_ not in range(11): | ||
| raise ValueError(str(loc) + " is not a legal MPL loc, please supply" | ||
| "another value for " + what + ".") | ||
| return loc_ | ||
|
|
||
|
|
||
| def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None, | ||
| linestyles=['-'], styles=None, vlines=list((0.,)), | ||
| ci=0.95, truncate_yaxis=False, truncate_xaxis=True, | ||
| ylim=dict(), invert_y=False, show_sensors=None, | ||
| show_legend=True, axes=None, title=None, show=True): | ||
| linestyles=['-'], styles=None, cmap=None, | ||
| vlines=list((0.,)), ci=0.95, truncate_yaxis=False, | ||
| truncate_xaxis=True, ylim=dict(), invert_y=False, | ||
| show_sensors=None, show_legend=True, | ||
| split_legend=False, axes=None, title=None, show=True): | ||
| """Plot evoked time courses for one or multiple channels and conditions. | ||
|
|
||
| This function is useful for comparing ER[P/F]s at a specific location. It | ||
|
|
@@ -1454,7 +1470,7 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None, | |
| Parameters | ||
| ---------- | ||
| evokeds : instance of mne.Evoked | list | dict | ||
| If a single evoked instance, it is plotted as a time series. | ||
| If a single Evoked instance, it is plotted as a time series. | ||
| If a dict whose values are Evoked objects, the contents are plotted as | ||
| single time series each and the keys are used as condition labels. | ||
| If a list of Evokeds, the contents are plotted with indices as labels. | ||
|
|
@@ -1483,6 +1499,9 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None, | |
| "Vis/L", "Vis/R", `colors` can be `dict(Aud='r', Vis='b')` to map both | ||
| Aud/L and Aud/R to the color red and both Visual conditions to blue. | ||
| If None (default), a sequence of desaturated colors is used. | ||
| If `cmap` is None, `colors` will indicate how each condition is | ||
| colored with reference to its position on the colormap - see `cmap` | ||
| below. | ||
| linestyles : list | dict | ||
| If a list, will be sequentially and repeatedly used for evoked plot | ||
| linestyles. | ||
|
|
@@ -1498,21 +1517,44 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None, | |
| parameters will be passed to the line plot call of the corresponding | ||
| condition, overriding defaults. | ||
| E.g., if evokeds is a dict with the keys "Aud/L", "Aud/R", | ||
| "Vis/L", "Vis/R", `styles` can be `{"Aud/L":{"linewidth":1}}` to set | ||
| "Vis/L", "Vis/R", `styles` can be `{"Aud/L": {"linewidth": 1}}` to set | ||
| the linewidth for "Aud/L" to 1. Note that HED ('/'-separated) tags are | ||
| not supported. | ||
| cmap : None | str | tuple | ||
| If not None, plot evoked activity with colors from a color gradient | ||
| (indicated by a str referencing a matplotlib colormap - e.g., "viridis" | ||
| or "Reds"). | ||
| If ``evokeds`` is a list and ``colors`` is `None`, the color will | ||
| depend on the list position. If ``colors`` is a list, it must contain | ||
| integers where the list positions correspond to ``evokeds``, and the | ||
| value corresponds to the position on the colorbar. | ||
| If ``evokeds`` is a dict, ``colors`` should be a dict mapping from | ||
| (potentially HED-style) condition tags to numbers corresponding to | ||
| rank order positions on the colorbar. E.g., :: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. aparat from the remark above - this is documentation for
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see what you mean. However, here's my rationale.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I get that, it is just a bit confusing, no big deal. |
||
|
|
||
| evokeds={"cond1/A": ev1, "cond2/A": ev2, "cond3/A": ev3, "B": ev4}, | ||
| cmap='viridis', colors=dict(cond1=1 cond2=2, cond3=3), | ||
| linestyles={"A": "-", "B": ":"} | ||
|
|
||
| If ``cmap`` is a tuple of length 2, the first item must be | ||
| a string which will become the colorbar label, and the second one | ||
| must indicate a colormap, e.g. :: | ||
|
|
||
| cmap=('conds', 'viridis'), colors=dict(cond1=1 cond2=2, cond3=3), | ||
|
|
||
| vlines : list of int | ||
| A list of integers corresponding to the positions, in seconds, | ||
| at which to plot dashed vertical lines. | ||
| ci : float | callable | None | ||
| ci : float | callable | None | bool | ||
| If not None and ``evokeds`` is a [list/dict] of lists, a shaded | ||
| confidence interval is drawn around the individual time series. If | ||
| float, a percentile bootstrap method is used to estimate the confidence | ||
| interval and this value determines the CI width. E.g., if this value is | ||
| .95 (the default), the 95% confidence interval is drawn. If a callable, | ||
| it must take as its single argument an array (observations x times) and | ||
| return the upper and lower confidence bands. | ||
| If None, no confidence band is plotted. | ||
| If None or False, no confidence band is plotted. | ||
| If True, the 95% confidence interval is drawn. | ||
| truncate_yaxis : bool | str | ||
| If True, the left y axis spine is truncated to reduce visual clutter. | ||
| If 'max_ticks', the spine is truncated at the minimum and maximum | ||
|
|
@@ -1529,21 +1571,25 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None, | |
| invert_y : bool | ||
| If True, negative values are plotted up (as is sometimes done | ||
| for ERPs out of tradition). Defaults to False. | ||
| show_sensors: bool | int | None | ||
| show_sensors: bool | int | str | None | ||
| If not False, channel locations are plotted on a small head circle. | ||
| If an int, the position of the axes (forwarded to | ||
| If int or str, the position of the axes (forwarded to | ||
| ``mpl_toolkits.axes_grid1.inset_locator.inset_axes``). | ||
| If None, defaults to True if ``gfp`` is False, else to False. | ||
| show_legend : bool | int | ||
| If not False, show a legend. If int, the position of the axes | ||
| show_legend : bool | str | int | ||
| If not False, show a legend. If int or str, the position of the axes | ||
| (forwarded to ``mpl_toolkits.axes_grid1.inset_locator.inset_axes``). | ||
| split_legend : bool | ||
| If True, the legend shows color and linestyle separately; `colors` must | ||
| not be None. Defaults to True if ``cmap`` is not None, else defaults to | ||
| False. | ||
| axes : None | `matplotlib.axes.Axes` instance | list of `axes` | ||
| What axes to plot to. If None, a new axes is created. | ||
| When plotting multiple channel types, can also be a list of axes, one | ||
| per channel type. | ||
| title : None | str | ||
| If str, will be plotted as figure title. If None, the channel | ||
| names will be shown. | ||
| If str, will be plotted as figure title. If None, the channel names | ||
| will be shown. | ||
| show : bool | ||
| If True, show the figure. | ||
|
|
||
|
|
@@ -1558,7 +1604,10 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None, | |
| # set up labels and instances | ||
| if isinstance(evokeds, Evoked): | ||
| evokeds = dict(Evoked=evokeds) # title becomes 'Evoked' | ||
| elif not isinstance(evokeds, dict): | ||
| elif not isinstance(evokeds, dict): # it's assumed to be a list | ||
| if (cmap is not None) and (colors is None): | ||
| colors = dict( | ||
| (str(ii + 1), ii) for ii, evoked in enumerate(evokeds)) | ||
| evokeds = dict((str(ii + 1), evoked) | ||
| for ii, evoked in enumerate(evokeds)) | ||
| for cond in evokeds.keys(): | ||
|
|
@@ -1581,13 +1630,12 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None, | |
| if isinstance(picks, Integral): | ||
| picks = [picks] | ||
| elif len(picks) == 0: | ||
| warn("No picks, plotting the GFP ...") | ||
| gfp = True | ||
| picks = _pick_data_channels(example.info) | ||
|
|
||
| if len(picks) == 0: | ||
| raise ValueError("No valid channels were found to plot the GFP. " + | ||
| "Use 'picks' instead to select them manually.") | ||
| warn("No picks, plotting the GFP ...") | ||
|
|
||
| if ylim is None: | ||
| ylim = dict() | ||
|
|
@@ -1602,14 +1650,17 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None, | |
| else: | ||
| if not isinstance(picks[0], (int, np.integer)): | ||
| msg = "'picks' must be int or a list of int, not {0}." | ||
| raise ValueError(msg.format(type(picks))) | ||
| raise TypeError(msg.format(type(picks))) | ||
| show_sensors = True if show_sensors is None else show_sensors | ||
| ch_names = [example.ch_names[pick] for pick in picks] | ||
| ch_types = list(set(channel_type(example.info, pick_) | ||
| for pick_ in picks)) | ||
| # XXX: could possibly be refactored; plot_joint is doing a similar thing | ||
| if any([type_ not in _VALID_CHANNEL_TYPES for type_ in ch_types]): | ||
| raise ValueError("Non-data channel picked.") | ||
| non_data_channels = [str(pick) for pick, type_ in zip(picks, ch_types) | ||
| if type_ not in _VALID_CHANNEL_TYPES] | ||
| if len(non_data_channels) > 0: | ||
| msg = "Non-data channel(s) {0} were picked." | ||
| raise ValueError(msg.format(", ".join(non_data_channels))) | ||
| if len(ch_types) > 1: | ||
| warn("Multiple channel types selected, returning one figure per type.") | ||
| if axes is not None: | ||
|
|
@@ -1619,7 +1670,7 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None, | |
| for ii, t in enumerate(ch_types): | ||
| picks_ = [idx for idx in picks | ||
| if channel_type(example.info, idx) == t] | ||
| title_ = "GFP, " + t if not title and gfp is True else title | ||
| title_ = "GFP, " + t if (not title and (gfp is True)) else title | ||
| ax_ = axes[ii] if axes is not None else None | ||
| figs.append( | ||
| plot_compare_evokeds( | ||
|
|
@@ -1633,8 +1684,13 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None, | |
| ymin, ymax = ylim.get(ch_type, [None, None]) | ||
|
|
||
| # deal with dict/list of lists and the CI | ||
| if ci is not None and not (isinstance(ci, np.float) or callable(ci)): | ||
| raise TypeError('ci must be float or callable, got ' + str(type(ci))) | ||
| if ci is None: | ||
| ci = False | ||
| if ci is True: | ||
| ci = .95 | ||
| elif ci is not False and not (isinstance(ci, np.float) or callable(ci)): | ||
| raise TypeError('ci must be None, bool, float or callable, got ' + | ||
| str(type(ci))) | ||
|
|
||
| scaling = _handle_default("scalings")[ch_type] | ||
| unit = _handle_default("units")[ch_type] | ||
|
|
@@ -1650,8 +1706,6 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None, | |
| ymin = 0. # 'grad' and GFP are plotted as all-positive | ||
|
|
||
| # if we have a dict/list of lists, we compute the grand average and the CI | ||
| if ci is None: | ||
| ci = False | ||
| if not all([isinstance(evoked_, Evoked) for evoked_ in evokeds.values()]): | ||
| if ci is not False: | ||
| if callable(ci): | ||
|
|
@@ -1705,13 +1759,67 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None, | |
| "conditions. Condition " + style_ + | ||
| " was not found in the supplied data.") | ||
|
|
||
| # second, color | ||
| # third, color | ||
| # check: is color a list? | ||
| if (colors is not None and not isinstance(colors, string_types) and | ||
| not isinstance(colors, dict) and len(colors) > 1): | ||
| colors = dict((condition, color) for condition, color | ||
| in zip(conditions, colors)) | ||
|
|
||
| if cmap is not None: | ||
| if not isinstance(cmap, string_types) and len(cmap) == 2: | ||
| cmap_label, cmap = cmap | ||
| else: | ||
| cmap_label = "" | ||
|
|
||
| # dealing with a split legend | ||
| if split_legend is None: | ||
| split_legend = cmap is not None # default to True iff cmap is given | ||
| if split_legend is True: | ||
| if colors is None: | ||
| raise ValueError( | ||
| "If `split_legend` is True, `colors` must not be None.") | ||
| # mpl 1.3 requires us to split it like this. with recent mpl, | ||
| # we could use the label parameter of the Line2D | ||
| legend_lines, legend_labels = list(), list() | ||
| if cmap is None: # ... one set of lines for the colors | ||
| for color in sorted(colors.keys()): | ||
| line = mlines.Line2D([], [], linestyle="-", | ||
| color=colors[color]) | ||
| legend_lines.append(line) | ||
| legend_labels.append(color) | ||
| if len(list(linestyles)) > 1: # ... one set for the linestyle | ||
| for style, s in linestyles.items(): | ||
| line = mlines.Line2D([], [], color='k', linestyle=s) | ||
| legend_lines.append(line) | ||
| legend_labels.append(style) | ||
|
|
||
| # dealing with continuous colors | ||
| if cmap is not None: | ||
| for color_value in colors.values(): | ||
| try: | ||
| float(color_value) | ||
| except ValueError: | ||
| raise TypeError("If ``cmap`` is not None, the values of " | ||
| "``colors`` must be numeric. Got " + | ||
| str(type(color_value))) | ||
| cmapper = getattr(plt.cm, cmap, cmap) | ||
| color_conds = list(colors.keys()) | ||
| all_colors = [colors[cond] for cond in color_conds] | ||
| n_colors = len(all_colors) | ||
| color_order = np.array(all_colors).argsort() | ||
| color_indices = color_order.argsort() | ||
|
|
||
| the_colors = cmapper(np.linspace(0, 1, n_colors)) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe it would be better to ask users to give colors as 0 - 1 floats, then non-linear mapping would be possible (for example only picking values [0.3, 0.5, 0.8])? But I am still not sure if this is better than getting the colors and passing them yourself. |
||
|
|
||
| colors_ = {cond: ind for cond, ind in zip(color_conds, color_indices)} | ||
| colors = dict() | ||
| for cond in evokeds.keys(): | ||
| for cond_number, color in colors_.items(): | ||
| if cond_number in cond: | ||
| colors[cond] = the_colors[color] | ||
| continue | ||
|
|
||
| if not isinstance(colors, dict): # default colors from M Waskom's Seaborn | ||
| # XXX should put a good list of default colors into defaults.py | ||
| colors_ = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00', | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. just BTW - I am not sure about that, if one is using some matplotlib style maybe we should honor that?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well, we need default colors listed explicitly somewhere anyways. That's the way the code works.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. technically you could probably extract the default
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note if we use the MPL defaults as a manual list, that still won't help people who use custom MPL colors.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I was thinking about the current color cycle, not the default one. But I like the default colors, so nevermind for now. :) |
||
|
|
@@ -1756,14 +1864,14 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None, | |
| d = np.sqrt((d * d).mean(axis=-1)) | ||
| else: | ||
| d = d.mean(-1) | ||
| axes.plot(times, d, zorder=1000, label=condition, **styles[condition]) | ||
| axes.plot(times, d, zorder=100, label=condition, **styles[condition]) | ||
| if any(d > 0) or all_positive: | ||
| any_positive = True | ||
| if np.any(d < 0): | ||
| any_negative = True | ||
|
|
||
| # plot the confidence interval | ||
| if ci and (gfp is not True): | ||
| if ci: | ||
| ci_ = ci_array[condition] | ||
| axes.fill_between(times, ci_[0].flatten(), ci_[1].flatten(), | ||
| zorder=9, color=styles[condition]['c'], alpha=.3) | ||
|
|
@@ -1808,6 +1916,8 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None, | |
| _setup_ax_spines(axes, vlines, tmin, tmax, invert_y, ymax_bound, unit, | ||
| truncate_xaxis) | ||
|
|
||
| # and now for 3 "legends" .. | ||
| # a head plot showing the sensors that are being plotted | ||
| if show_sensors: | ||
| try: | ||
| pos = _auto_topomap_coords( | ||
|
|
@@ -1818,19 +1928,36 @@ def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None, | |
| else: | ||
| head_pos = {'center': (0, 0), 'scale': (0.5, 0.5)} | ||
| pos, outlines = _check_outlines(pos, np.array([1, 1]), head_pos) | ||
| if not isinstance(show_sensors, (np.int, bool)): | ||
| raise TypeError("`show_sensors` must be numeric or bool, not" + | ||
| str(type(show_sensors))) | ||
| if show_sensors is True: | ||
| show_sensors = 2 | ||
| if not isinstance(show_sensors, (np.int, bool, str)): | ||
| raise TypeError("show_sensors must be numeric, str or bool, " | ||
| "not " + str(type(show_sensors))) | ||
| show_sensors = _check_loc_legal(show_sensors, "show_sensors") | ||
| _plot_legend(pos, ["k" for pick in picks], axes, list(), outlines, | ||
| show_sensors, size=20) | ||
|
|
||
| if show_legend and len(conditions) > 1: | ||
| if show_legend is True: | ||
| show_legend = 'best' | ||
| axes.legend(loc=show_legend, ncol=1 + (len(conditions) // 5), | ||
| frameon=True) | ||
| # the condition legend | ||
| if len(conditions) > 1 and show_legend is not False: | ||
| show_legend = _check_loc_legal(show_legend, "show_legend") | ||
| legend_params = dict(loc=show_legend, frameon=True) | ||
| if split_legend: | ||
| if len(legend_lines) > 1: | ||
| axes.legend(legend_lines, legend_labels, # see above: mpl 1.3 | ||
| ncol=1 + (len(legend_lines) // 4), **legend_params) | ||
| else: | ||
| axes.legend(ncol=1 + (len(conditions) // 5), **legend_params) | ||
|
|
||
| # the colormap, if `cmap` is provided | ||
| if split_legend and cmap is not None: | ||
| # plot the colorbar ... complicated cause we don't have a heatmap | ||
| from mpl_toolkits.axes_grid1 import make_axes_locatable | ||
| divider = make_axes_locatable(axes) | ||
| ax_cb = divider.append_axes("right", size="5%", pad=0.05) | ||
| ax_cb.imshow(the_colors[:, np.newaxis, :], interpolation='none') | ||
| ax_cb.set_yticks(np.arange(len(the_colors))) | ||
| ax_cb.set_yticklabels(np.array(color_conds)[color_order]) | ||
| ax_cb.yaxis.tick_right() | ||
| ax_cb.set_xticks(()) | ||
| ax_cb.set_ylabel(cmap_label) | ||
|
|
||
| plt_show(show) | ||
| return fig | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure this complexity is needed. If you want a different order of colors, you can reorder Evokeds, and when you need specific colors you can get these from the colormap yourself and pass to
colors.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reordering Evokeds won't work for the dict case (which is my preferred case), and won't work for the important case where multiple evokeds get the same color.
Taking colors from an MPL (or, easier, seaborn) colormap is nice and all, but getting everything. set up - particularly the split legends - is hell on the user.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, after looking at how easy and elegant the example some lines below is - I tend to agree.
However as I mentioned in another comment it would be more flexible if colors could get floats between 0 and 1 and the linear order was not assumed.