Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ Enhancements
~~~~~~~~~~~~
.. - Add something cool (:gh:`9192` **by new contributor** |New Contributor|_)

- Add the "exclude" keyword to the psd_plot in :func:`mne.time_frequency.psd.py` (:gh:`9379` by `Eduard ort`_)

- New function :func:`mne.chpi.get_chpi_info` to retrieve basic information about the cHPI system used when recording MEG data (:gh:`9369` by `Richard Höchenberger`_)


Expand Down
13 changes: 7 additions & 6 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,16 +1116,17 @@ def plot(self, picks=None, scalings=None, n_epochs=20, n_channels=20,
@copy_function_doc_to_method_doc(plot_epochs_psd)
def plot_psd(self, fmin=0, fmax=np.inf, tmin=None, tmax=None,
proj=False, bandwidth=None, adaptive=False, low_bias=True,
normalization='length', picks=None, ax=None, color='black',
xscale='linear', area_mode='std', area_alpha=0.33,
dB=True, estimate='auto', show=True, n_jobs=1,
average=False, line_alpha=None, spatial_colors=True,
normalization='length', picks=None, exclude='bads', ax=None,
color='black', xscale='linear', area_mode='std',
area_alpha=0.33, dB=True, estimate='auto', show=True,
n_jobs=1, average=False, line_alpha=None, spatial_colors=True,
sphere=None, verbose=None):
return plot_epochs_psd(self, fmin=fmin, fmax=fmax, tmin=tmin,
tmax=tmax, proj=proj, bandwidth=bandwidth,
adaptive=adaptive, low_bias=low_bias,
normalization=normalization, picks=picks, ax=ax,
color=color, xscale=xscale, area_mode=area_mode,
normalization=normalization, picks=picks,
exclude=exclude, ax=ax, color=color,
xscale=xscale, area_mode=area_mode,
area_alpha=area_alpha, dB=dB, estimate=estimate,
show=show, n_jobs=n_jobs, average=average,
line_alpha=line_alpha,
Expand Down
19 changes: 10 additions & 9 deletions mne/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,18 +1540,19 @@ def plot(self, events=None, duration=10.0, start=0.0, n_channels=20,
@copy_function_doc_to_method_doc(plot_raw_psd)
def plot_psd(self, fmin=0, fmax=np.inf, tmin=None, tmax=None, proj=False,
n_fft=None, n_overlap=0, reject_by_annotation=True,
picks=None, ax=None, color='black', xscale='linear',
area_mode='std', area_alpha=0.33, dB=True, estimate='auto',
show=True, n_jobs=1, average=False, line_alpha=None,
spatial_colors=True, sphere=None, window='hamming',
verbose=None):
picks=None, exclude='bads', ax=None, color='black',
xscale='linear', area_mode='std', area_alpha=0.33, dB=True,
estimate='auto', show=True, n_jobs=1, average=False,
line_alpha=None, spatial_colors=True, sphere=None,
window='hamming', verbose=None):
return plot_raw_psd(self, fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax,
proj=proj, n_fft=n_fft, n_overlap=n_overlap,
reject_by_annotation=reject_by_annotation,
picks=picks, ax=ax, color=color, xscale=xscale,
area_mode=area_mode, area_alpha=area_alpha,
dB=dB, estimate=estimate, show=show, n_jobs=n_jobs,
average=average, line_alpha=line_alpha,
picks=picks, exclude=exclude, ax=ax, color=color,
xscale=xscale, area_mode=area_mode,
area_alpha=area_alpha, dB=dB, estimate=estimate,
show=show, n_jobs=n_jobs, average=average,
line_alpha=line_alpha,
spatial_colors=spatial_colors, sphere=sphere,
window=window, verbose=verbose)

Expand Down
21 changes: 18 additions & 3 deletions mne/viz/_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@
_get_color_list, logger, _validate_if_list_of_axes,
_plot_psd)
from ..defaults import _handle_default
from ..utils import set_config, _check_option, _check_sphere, Bunch
from ..utils import (set_config, _check_option, _check_sphere, _validate_type,
Bunch)
from ..annotations import _sync_onset
from ..time_frequency import psd_welch, psd_multitaper
from ..io.pick import (pick_types, _picks_to_idx, channel_indices_by_type,
Expand Down Expand Up @@ -2313,8 +2314,8 @@ def _line_figure(inst, axes=None, picks=None, **kwargs):
return fig, axes


def _psd_figure(inst, proj, picks, axes, area_mode, tmin, tmax, fmin, fmax,
n_jobs, color, area_alpha, dB, estimate, average,
def _psd_figure(inst, proj, picks, exclude, axes, area_mode, tmin, tmax, fmin,
fmax, n_jobs, color, area_alpha, dB, estimate, average,
spatial_colors, xscale, line_alpha, sphere, window, **kwargs):
"""Instantiate a new power spectral density figure."""
from .. import BaseEpochs
Expand Down Expand Up @@ -2351,6 +2352,20 @@ def _psd_figure(inst, proj, picks, axes, area_mode, tmin, tmax, fmin, fmax,
titles_list = list()
scalings_list = list()
psd_list = list()
# exclude channels
inst.info._check_consistency()
_validate_type(exclude, (list, tuple, str), 'exclude')
if isinstance(exclude, str):
_check_option('exclude', exclude, 'bads', extra='when str')
exclude = inst.info['bads']
for ei, ex in enumerate(exclude):
if not isinstance(ex, str) or ex not in inst.info['ch_names']:
raise ValueError(f'exclude[{ei}] ({repr(ex)}) '
'not found in info["ch_names"]')
if len(exclude) > 0:
picks = np.array([pick for pick in picks if
inst.info['ch_names'][pick] not in exclude], int)

# initialize figure
fig, axes = _line_figure(inst, axes, picks, **kwargs)
# split picks, units, etc, for each subplot
Expand Down
27 changes: 15 additions & 12 deletions mne/viz/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,11 +899,11 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, n_channels=20,
@verbose
def plot_epochs_psd(epochs, fmin=0, fmax=np.inf, tmin=None, tmax=None,
proj=False, bandwidth=None, adaptive=False, low_bias=True,
normalization='length', picks=None, ax=None, color='black',
xscale='linear', area_mode='std', area_alpha=0.33,
dB=True, estimate='auto', show=True, n_jobs=1,
average=False, line_alpha=None, spatial_colors=True,
sphere=None, verbose=None):
normalization='length', picks=None, exclude='bads',
ax=None, color='black', xscale='linear', area_mode='std',
area_alpha=0.33, dB=True, estimate='auto', show=True,
n_jobs=1, average=False, line_alpha=None,
spatial_colors=True, sphere=None, verbose=None):
"""%(plot_psd_doc)s.

Parameters
Expand Down Expand Up @@ -934,6 +934,9 @@ def plot_epochs_psd(epochs, fmin=0, fmax=np.inf, tmin=None, tmax=None,
be normalized by the sampling rate as well as the length of
the signal (as in nitime).
%(plot_psd_picks_good_data)s
exclude : list of str | 'bads'
Channels names to exclude from being shown. If 'bads', the
bad channels are excluded.
ax : instance of Axes | None
Axes to plot into. If None, axes will be created.
%(plot_psd_color)s
Expand Down Expand Up @@ -961,12 +964,12 @@ def plot_epochs_psd(epochs, fmin=0, fmax=np.inf, tmin=None, tmax=None,
# epochs always use multitaper, not Welch, so no need to allow "window"
# param above
fig = _psd_figure(
inst=epochs, proj=proj, picks=picks, axes=ax, tmin=tmin, tmax=tmax,
fmin=fmin, fmax=fmax, sphere=sphere, xscale=xscale, dB=dB,
average=average, estimate=estimate, area_mode=area_mode,
line_alpha=line_alpha, area_alpha=area_alpha, color=color,
spatial_colors=spatial_colors, n_jobs=n_jobs, bandwidth=bandwidth,
adaptive=adaptive, low_bias=low_bias, normalization=normalization,
window='hamming')
inst=epochs, proj=proj, picks=picks, exclude=exclude, axes=ax,
tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, sphere=sphere,
xscale=xscale, dB=dB, average=average, estimate=estimate,
area_mode=area_mode, line_alpha=line_alpha, area_alpha=area_alpha,
color=color, spatial_colors=spatial_colors, n_jobs=n_jobs,
bandwidth=bandwidth, adaptive=adaptive, low_bias=low_bias,
normalization=normalization, window='hamming')
plt_show(show)
return fig
17 changes: 10 additions & 7 deletions mne/viz/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,11 +367,11 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=20,
@verbose
def plot_raw_psd(raw, fmin=0, fmax=np.inf, tmin=None, tmax=None, proj=False,
n_fft=None, n_overlap=0, reject_by_annotation=True,
picks=None, ax=None, color='black', xscale='linear',
area_mode='std', area_alpha=0.33, dB=True, estimate='auto',
show=True, n_jobs=1, average=False, line_alpha=None,
spatial_colors=True, sphere=None, window='hamming',
verbose=None):
picks=None, exclude='bads', ax=None, color='black',
xscale='linear', area_mode='std', area_alpha=0.33, dB=True,
estimate='auto', show=True, n_jobs=1, average=False,
line_alpha=None, spatial_colors=True, sphere=None,
window='hamming', verbose=None):
"""%(plot_psd_doc)s.

Parameters
Expand All @@ -397,6 +397,9 @@ def plot_raw_psd(raw, fmin=0, fmax=np.inf, tmin=None, tmax=None, proj=False,
is 0 (no overlap).
%(reject_by_annotation_raw)s
%(plot_psd_picks_good_data)s
exclude : list of str | 'bads'
Channels names to exclude from being shown. If 'bads', the
bad channels are excluded.
ax : instance of Axes | None
Axes to plot into. If None, axes will be created.
%(plot_psd_color)s
Expand Down Expand Up @@ -430,8 +433,8 @@ def plot_raw_psd(raw, fmin=0, fmax=np.inf, tmin=None, tmax=None, proj=False,
n_fft = min(np.diff(raw.time_as_index([tmin, tmax]))[0] + 1, 2048)
# generate figure
fig = _psd_figure(
inst=raw, proj=proj, picks=picks, axes=ax, tmin=tmin, tmax=tmax,
fmin=fmin, fmax=fmax, sphere=sphere, xscale=xscale, dB=dB,
inst=raw, proj=proj, picks=picks, exclude=exclude, axes=ax, tmin=tmin,
tmax=tmax, fmin=fmin, fmax=fmax, sphere=sphere, xscale=xscale, dB=dB,
average=average, estimate=estimate, area_mode=area_mode,
line_alpha=line_alpha, area_alpha=area_alpha, color=color,
spatial_colors=spatial_colors, n_jobs=n_jobs, n_fft=n_fft,
Expand Down
2 changes: 2 additions & 0 deletions mne/viz/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,8 @@ def test_plot_psd_epochs(epochs):
epochs.plot_psd(average=True, spatial_colors=False)
epochs.plot_psd(average=False, spatial_colors=True)
epochs.plot_psd(average=False, spatial_colors=False)
epochs.plot_psd(average=True, exclude='bads')

# test plot_psd_topomap errors
with pytest.raises(RuntimeError, match='No frequencies in band'):
epochs.plot_psd_topomap(bands=[(0, 0.01, 'foo')])
Expand Down
2 changes: 1 addition & 1 deletion mne/viz/tests/test_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ def test_browse_figure_constructor():
def test_psd_figure_constructor():
"""Test error handling in MNELineFigure constructor."""
with pytest.raises(TypeError, match='an instance of Raw or Epochs, got'):
_psd_figure('foo', *((None,) * 19))
_psd_figure('foo', *((None,) * 20))
5 changes: 5 additions & 0 deletions mne/viz/tests/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,8 @@ def test_plot_raw_psd(raw):
picks = pick_types(raw.info, meg='mag', eeg=False)[:4]
raw.plot_psd(tmax=None, picks=picks, area_mode='range', average=False,
spatial_colors=True)
raw.plot_psd(tmax=None, picks=picks, exclude='bads', area_mode='range',
average=False, spatial_colors=True)
raw.plot_psd(tmax=20., color='yellow', dB=False, line_alpha=0.4,
n_overlap=0.1, average=False)
plt.close('all')
Expand Down Expand Up @@ -663,6 +665,9 @@ def test_plot_raw_psd(raw):
verbose='error')
fig = raw.plot_psd()
assert len(fig.axes) == 10
# test excludeing one channel
fig = raw.plot_psd(exclude=['MEG 0113'])
assert len(fig.axes) == 9
plt.close('all')

# gh-7631
Expand Down