Skip to content
Merged
21 changes: 13 additions & 8 deletions mne/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,18 @@ def grand_average(all_evoked, interpolate_bads=True):
return grand_average


def _check_evokeds_ch_names_times(all_evoked):
evoked = all_evoked[0]
ch_names = evoked.ch_names
for ev in all_evoked[1:]:
if not ev.ch_names == ch_names:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

more easily expressed as ev.ch_names != ch_names but you can keep it this way

raise ValueError(
"%s and %s do not contain the same channels" % (evoked, ev))
if not np.max(np.abs(ev.times - evoked.times)) < 1e-7:
raise ValueError("%s and %s do not contain the same time instants"
% (evoked, ev))


def combine_evoked(all_evoked, weights):
"""Merge evoked data by weighted addition or subtraction.

Expand Down Expand Up @@ -892,14 +904,7 @@ def combine_evoked(all_evoked, weights):
if weights.ndim != 1 or weights.size != len(all_evoked):
raise ValueError('weights must be the same size as all_evoked')

ch_names = evoked.ch_names
for e in all_evoked[1:]:
assert e.ch_names == ch_names, ValueError("%s and %s do not contain "
"the same channels"
% (evoked, e))
assert np.max(np.abs(e.times - evoked.times)) < 1e-7, \
ValueError("%s and %s do not contain the same time instants"
% (evoked, e))
_check_evokeds_ch_names_times(all_evoked)

# use union of bad channels
bads = list(set(evoked.info['bads']).union(*(ev.info['bads']
Expand Down
38 changes: 27 additions & 11 deletions mne/io/pick.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,17 +583,33 @@ def pick_types_forward(orig, meg=True, eeg=False, ref_meg=True, seeg=False,
return pick_channels_forward(orig, include_ch_names)


def channel_indices_by_type(info):
"""Get indices of channels by type."""
idx = dict((key, list()) for key in _PICK_TYPES_KEYS if
key not in ('meg', 'fnirs'))
idx.update(mag=list(), grad=list(), hbo=list(), hbr=list())
for k, ch in enumerate(info['chs']):
for key in idx.keys():
if channel_type(info, k) == key:
idx[key].append(k)

return idx
def channel_indices_by_type(info, picks=None):
"""Get indices of channels by type.

Parameters
----------
info : instance of mne.measuerment_info.Info
The info.
picks : None | list of int
The indices of channels from which to get the type

Returns
-------
idx_by_type : dict
The dictionary that maps each channel type to the list of
channel indidces.
"""
idx_by_type = dict((key, list()) for key in _PICK_TYPES_KEYS if
key not in ('meg', 'fnirs'))
idx_by_type.update(mag=list(), grad=list(), hbo=list(), hbr=list())
if picks is None:
picks = range(len(info["chs"]))
for k in picks:
ch_type = channel_type(info, k)
for key in idx_by_type.keys():
if ch_type == key:
idx_by_type[key].append(k)
return idx_by_type


def pick_channels_cov(orig, include=[], exclude='bads'):
Expand Down
3 changes: 3 additions & 0 deletions mne/viz/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ def plot_epochs_image(epochs, picks=None, sigma=0., vmin=None,
group_by = "type"
combine = "gfp"

if combine is not None:
ts_args["show_sensors"] = False

if picks is None:
picks = pick_types(epochs.info, meg=True, eeg=True, ref_meg=False,
exclude='bads')
Expand Down
Loading