diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index dcf028dfa78..b17a22cf4f0 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -57,6 +57,8 @@ Enhancements - :func:`mne.viz.plot_evoked_topomap` and :meth:`mne.Evoked.plot_topomap` now display the time range the map was averaged over if ``average`` was passed (:gh:`10606` by `Richard Höchenberger`_) +- :func:`mne.viz.plot_evoked_topomap` and :meth:`mne.Evoked.plot_topomap` can now average the topographic maps across different time periods for each time point. To do this, pass a list of periods via the ``average`` parameter (:gh:`10610` by `Richard Höchenberger`_) + Bugs ~~~~ - Make ``color`` parameter check in in :func:`mne.viz.plot_evoked_topo` consistent (:gh:`10217` by :newcontrib:`T. Wang` and `Stefan Appelhoff`_) diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 0331d28e3d2..ba91eb0cc34 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -233,15 +233,6 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionadded:: 0.13.0 """ -docdict['average_topomap'] = """ -average : float | None - The time window (in seconds) around a given time point to be used for - averaging. For example, 0.2 would translate into a time window that starts - 0.1 s before and ends 0.1 s after the given time point. If the time window - exceeds the duration of the data, it will be clipped. If ``None`` - (default), no averaging will take place. -""" - docdict['axes_psd_topo'] = """ axes : list of Axes | None List of axes to plot consecutive topographies to. If ``None`` the axes diff --git a/mne/viz/tests/test_topomap.py b/mne/viz/tests/test_topomap.py index 46427934d33..42623b87f1b 100644 --- a/mne/viz/tests/test_topomap.py +++ b/mne/viz/tests/test_topomap.py @@ -297,9 +297,9 @@ def proc_names(x): plt_topomap(times, ch_type='grad', mask=mask, show_names=True, mask_params={'marker': 'x'}) plt.close('all') - with pytest.raises(ValueError, match='number of seconds; got -'): + with pytest.raises(ValueError, match='number of seconds.* got -'): plt_topomap(times, ch_type='eeg', average=-1e3) - with pytest.raises(TypeError, match='number of seconds; got type'): + with pytest.raises(TypeError, match='number of seconds.* got type'): plt_topomap(times, ch_type='eeg', average='x') p = plt_topomap(times, ch_type='grad', image_interp='bilinear', @@ -333,7 +333,7 @@ def get_texts(p): assert_equal(texts[0], 'Custom') plt.close('all') - # Test averaging + # Test averaging with a scalar input averaging_times = [ev_bad.times[0], times[0], ev_bad.times[-1]] p = plt_topomap(averaging_times, ch_type='eeg', average=0.01) @@ -345,6 +345,26 @@ def get_texts(p): for idx, expected_title in enumerate(expected_ax_titles): assert p.axes[idx].get_title() == expected_title + # Test averaging with an array-like input + averaging_durations = [0.01, 0.02, None] + p = plt_topomap( + averaging_times, ch_type='eeg', average=averaging_durations + ) + expected_ax_titles = ( + '-0.200 – -0.195 s', # clipped on the left + '0.090 – 0.110 s', # full range + '0.499 s' # No averaging + ) + for idx, expected_title in enumerate(expected_ax_titles): + assert p.axes[idx].get_title() == expected_title + + # Test averaging with array-like input, but n_times != n_average + averaging_durations = [0.01, 0.02] + with pytest.raises(ValueError, match='3 time points.*2 periods'): + plt_topomap( + averaging_times, ch_type='eeg', average=averaging_durations + ) + del averaging_times, expected_ax_titles, expected_title # delaunay triangulation warning diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index 5c8d33e103d..41319e84f36 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -1573,7 +1573,17 @@ def plot_evoked_topomap(evoked, times="auto", ch_type=None, %(outlines_topomap)s %(contours_topomap)s %(image_interp_topomap)s - %(average_topomap)s + average : float | array-like of float, shape (n_times,) | None + The time window (in seconds) around a given time point to be used for + averaging. For example, 0.2 would translate into a time window that + starts 0.1 s before and ends 0.1 s after the given time point. If the + time window exceeds the duration of the data, it will be clipped. + Different time windows (one per time point) can be provided by + passing an ``array-like`` object (e.g., ``[0.1, 0.2, 0.3]``). If + ``None`` (default), no averaging will take place. + + .. versionchanged:: 1.1 + Support for ``array-like`` input. %(axes_topomap)s %(extrapolate_topomap)s @@ -1707,25 +1717,55 @@ def plot_evoked_topomap(evoked, times="auto", ch_type=None, sfreq=evoked.info['sfreq']))[0][0] for t in times] # do averaging if requested - avg_err = '"average" must be `None` or a positive number of seconds' + avg_err = ('"average" must be `None`, a positive number of seconds, or ' + 'an array-like object of the previous') + averaged_times = [] if average is None: + average = np.array([None] * n_times) data = data[np.ix_(picks, time_idx)] - elif not _is_numeric(average): - raise TypeError(f'{avg_err}; got type {type(average)}.') - elif average <= 0: - raise ValueError(f'{avg_err}; got {average}.') else: + if _is_numeric(average): + average = np.array([average] * n_times) + elif np.array(average).ndim == 0: + # It should be an array-like object + raise TypeError(f'{avg_err}; got type: {type(average)}.') + else: + average = np.array(average) + + if len(average) != n_times: + raise ValueError( + f'You requested to plot topographic maps for {n_times} time ' + f'points, but provided {len(average)} periods for ' + f'averaging. The number of time points and averaging periods ' + f'must be equal.' + ) data_ = np.zeros((len(picks), len(time_idx))) - ave_time = average / 2. - iter_times = evoked.times[time_idx] - for ii, (idx, tmin_, tmax_) in enumerate(zip(time_idx, - iter_times - ave_time, - iter_times + ave_time)): - my_range = (tmin_ < evoked.times) & (evoked.times < tmax_) - data_[:, ii] = data[picks][:, my_range].mean(-1) - averaged_times.append(evoked.times[my_range]) + + for average_idx, (this_average, this_time, this_time_idx) in enumerate( + zip(average, evoked.times[time_idx], time_idx) + ): + if ( + (_is_numeric(this_average) and this_average <= 0) or + (not _is_numeric(this_average) and this_average is not None) + ): + if len(average) == 1: + msg = f'{avg_err}; got {this_average}' + else: + msg = f'{avg_err}; got {this_average} in {average}' + raise ValueError(msg) + + if this_average is None: + data_[:, average_idx] = data[picks][:, this_time_idx] + averaged_times.append([this_time]) + else: + tmin_ = this_time - this_average / 2 + tmax_ = this_time + this_average / 2 + time_mask = (tmin_ < evoked.times) & (evoked.times < tmax_) + data_[:, average_idx] = data[picks][:, time_mask].mean(-1) + averaged_times.append(evoked.times[time_mask]) data = data_ + # apply scalings and merge channels data *= scaling if merge_channels: @@ -1758,21 +1798,26 @@ def plot_evoked_topomap(evoked, times="auto", ch_type=None, ch_type=ch_type) images, contours_ = [], [] # loop over times - for idx, time in enumerate(times): - adjust_for_cbar = colorbar and ncols is not None and idx >= ncols - 1 - ax_idx = idx + 1 if adjust_for_cbar else idx + for average_idx, (time, this_average) in enumerate( + zip(times, average) + ): + adjust_for_cbar = (colorbar and + ncols is not None and + average_idx >= ncols - 1) + ax_idx = average_idx + 1 if adjust_for_cbar else average_idx tp, cn, interp = _plot_topomap( - data[:, idx], pos, axes=axes[ax_idx], - mask=mask_[:, idx] if mask is not None else None, **kwargs) + data[:, average_idx], pos, axes=axes[ax_idx], + mask=mask_[:, average_idx] if mask is not None else None, **kwargs) images.append(tp) if cn is not None: contours_.append(cn) if time_format != '': - if average is None: + if this_average is None: axes_title = time_format % (time * scaling_time) else: - tmin_, tmax_ = averaged_times[idx][0], averaged_times[idx][-1] + tmin_ = averaged_times[average_idx][0] + tmax_ = averaged_times[average_idx][-1] from_time = time_format % (tmin_ * scaling_time) from_time = from_time.split(' ')[0] # Remove unit to_time = time_format % (tmax_ * scaling_time) diff --git a/tutorials/evoked/20_visualize_evoked.py b/tutorials/evoked/20_visualize_evoked.py index 1db5cf0c250..2aa6c8ad00b 100644 --- a/tutorials/evoked/20_visualize_evoked.py +++ b/tutorials/evoked/20_visualize_evoked.py @@ -99,8 +99,17 @@ # %% -fig = evks['aud/left'].plot_topomap(ch_type='mag', times=0.09, average=0.1) -fig.text(0.5, 0.05, 'average from 40-140 ms', ha='center') +fig = evks['aud/left'].plot_topomap(ch_type='mag', times=times, average=0.1) + +# %% +# It is also possible to pass different time durations to average over for each +# time point. Passing a value of ``None`` will disable averaging for that +# time point: + +averaging_durations = [0.01, 0.02, 0.03, None, None] +fig = evks['aud/left'].plot_topomap( + ch_type='mag', times=times, average=averaging_durations +) # %% # Additional examples of plotting scalp topographies can be found in