Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mne/channels/tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,8 @@ def test_interpolation_nirs():
name in raw_od.ch_names])[0][0]
bad_0_std_pre_interp = np.std(raw_od._data[bad_0])
bads_init = list(raw_od.info['bads'])
raw_od.interpolate_bads(exclude=bads_init[:1])
assert raw_od.info['bads'] == bads_init[:1]
raw_od.interpolate_bads(exclude=bads_init[:2])
assert raw_od.info['bads'] == bads_init[:2]
raw_od.interpolate_bads()
assert raw_od.info['bads'] == []
assert bad_0_std_pre_interp > np.std(raw_od._data[bad_0])
21 changes: 11 additions & 10 deletions mne/preprocessing/nirs/nirs.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def _check_channels_ordered(info, pair_vals):
' as source detector pairs with alternating'
f' {error_word}: {pair_vals[0]} & {pair_vals[1]}')

_fnirs_check_bads(info)
return picks_cw


Expand All @@ -183,29 +184,29 @@ def _validate_nirs_info(info):
return picks


def _fnirs_check_bads(raw):
def _fnirs_check_bads(info):
"""Check consistent labeling of bads across fnirs optodes."""
# For an optode pair, if one component (light frequency or chroma) is
# marked as bad then they all should be. This function checks that all
# optodes are marked bad consistently.
picks = _picks_to_idx(raw.info, 'fnirs', exclude=[])
picks = _picks_to_idx(info, 'fnirs', exclude=[], allow_empty=True)
for ii in picks[::2]:
bad_opto = set(raw.info['bads']).intersection(raw.ch_names[ii:ii + 2])
bad_opto = set(info['bads']).intersection(info.ch_names[ii:ii + 2])
if len(bad_opto) == 1:
raise RuntimeError('NIRS bad labelling is not consistent')


def _fnirs_spread_bads(raw):
def _fnirs_spread_bads(info):
"""Spread bad labeling across fnirs channels."""
# For an optode if any component (light frequency or chroma) is marked
# For an optode pair if any component (light frequency or chroma) is marked
# as bad, then they all should be. This function will find any pairs marked
# as bad and spread the bad marking to all components of the optode pair.
picks = _picks_to_idx(raw.info, 'fnirs', exclude=[])
picks = _picks_to_idx(info, 'fnirs', exclude=[], allow_empty=True)
new_bads = list()
for ii in picks[::2]:
bad_opto = set(raw.info['bads']).intersection(raw.ch_names[ii:ii + 2])
bad_opto = set(info['bads']).intersection(info.ch_names[ii:ii + 2])
if len(bad_opto) > 0:
new_bads.extend(raw.ch_names[ii:ii + 2])
raw.info['bads'] = new_bads
new_bads.extend(info.ch_names[ii:ii + 2])
info['bads'] = new_bads

return raw
return info
36 changes: 19 additions & 17 deletions mne/preprocessing/nirs/tests/test_nirs.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,29 +100,31 @@ def test_fnirs_check_bads(fname):
"""Test checking of bad markings."""
# No bad channels, so these should all pass
raw = read_raw_nirx(fname)
_fnirs_check_bads(raw)
_fnirs_check_bads(raw.info)
raw = optical_density(raw)
_fnirs_check_bads(raw)
_fnirs_check_bads(raw.info)
raw = beer_lambert_law(raw)
_fnirs_check_bads(raw)
_fnirs_check_bads(raw.info)

# Mark pairs of bad channels, so these should all pass
raw = read_raw_nirx(fname)
raw.info['bads'] = raw.ch_names[0:2]
_fnirs_check_bads(raw)
_fnirs_check_bads(raw.info)
raw = optical_density(raw)
_fnirs_check_bads(raw)
_fnirs_check_bads(raw.info)
raw = beer_lambert_law(raw)
_fnirs_check_bads(raw)
_fnirs_check_bads(raw.info)

# Mark single channel as bad, so these should all fail
raw = read_raw_nirx(fname)
raw.info['bads'] = raw.ch_names[0:1]
pytest.raises(RuntimeError, _fnirs_check_bads, raw)
raw = optical_density(raw)
pytest.raises(RuntimeError, _fnirs_check_bads, raw)
raw = beer_lambert_law(raw)
pytest.raises(RuntimeError, _fnirs_check_bads, raw)
pytest.raises(RuntimeError, _fnirs_check_bads, raw.info)
with pytest.raises(RuntimeError, match='bad labelling'):
raw = optical_density(raw)
pytest.raises(RuntimeError, _fnirs_check_bads, raw.info)
with pytest.raises(RuntimeError, match='bad labelling'):
raw = beer_lambert_law(raw)
pytest.raises(RuntimeError, _fnirs_check_bads, raw.info)


@testing.requires_testing_data
Expand All @@ -133,20 +135,20 @@ def test_fnirs_spread_bads(fname):
# Test spreading upwards in frequency and on raw data
raw = read_raw_nirx(fname)
raw.info['bads'] = ['S1_D1 760']
raw = _fnirs_spread_bads(raw)
assert raw.info['bads'] == ['S1_D1 760', 'S1_D1 850']
info = _fnirs_spread_bads(raw.info)
assert info['bads'] == ['S1_D1 760', 'S1_D1 850']

# Test spreading downwards in frequency and on od data
raw = optical_density(raw)
raw.info['bads'] = raw.ch_names[5:6]
raw = _fnirs_spread_bads(raw)
assert raw.info['bads'] == raw.ch_names[4:6]
info = _fnirs_spread_bads(raw.info)
assert info['bads'] == raw.ch_names[4:6]

# Test spreading multiple bads and on chroma data
raw = beer_lambert_law(raw)
raw.info['bads'] = [raw.ch_names[x] for x in [1, 8]]
raw = _fnirs_spread_bads(raw)
assert raw.info['bads'] == [raw.ch_names[x] for x in [0, 1, 8, 9]]
info = _fnirs_spread_bads(raw.info)
assert info['bads'] == [info.ch_names[x] for x in [0, 1, 8, 9]]


@testing.requires_testing_data
Expand Down