diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index f9d67a42c24..9a81afdad08 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -298,8 +298,8 @@ def test_interpolation_nirs(): name in raw_od.ch_names])[0][0] bad_0_std_pre_interp = np.std(raw_od._data[bad_0]) bads_init = list(raw_od.info['bads']) - raw_od.interpolate_bads(exclude=bads_init[:1]) - assert raw_od.info['bads'] == bads_init[:1] + raw_od.interpolate_bads(exclude=bads_init[:2]) + assert raw_od.info['bads'] == bads_init[:2] raw_od.interpolate_bads() assert raw_od.info['bads'] == [] assert bad_0_std_pre_interp > np.std(raw_od._data[bad_0]) diff --git a/mne/preprocessing/nirs/nirs.py b/mne/preprocessing/nirs/nirs.py index b6e69d64fe8..093eb2954ca 100644 --- a/mne/preprocessing/nirs/nirs.py +++ b/mne/preprocessing/nirs/nirs.py @@ -169,6 +169,7 @@ def _check_channels_ordered(info, pair_vals): ' as source detector pairs with alternating' f' {error_word}: {pair_vals[0]} & {pair_vals[1]}') + _fnirs_check_bads(info) return picks_cw @@ -183,29 +184,29 @@ def _validate_nirs_info(info): return picks -def _fnirs_check_bads(raw): +def _fnirs_check_bads(info): """Check consistent labeling of bads across fnirs optodes.""" # For an optode pair, if one component (light frequency or chroma) is # marked as bad then they all should be. This function checks that all # optodes are marked bad consistently. - picks = _picks_to_idx(raw.info, 'fnirs', exclude=[]) + picks = _picks_to_idx(info, 'fnirs', exclude=[], allow_empty=True) for ii in picks[::2]: - bad_opto = set(raw.info['bads']).intersection(raw.ch_names[ii:ii + 2]) + bad_opto = set(info['bads']).intersection(info.ch_names[ii:ii + 2]) if len(bad_opto) == 1: raise RuntimeError('NIRS bad labelling is not consistent') -def _fnirs_spread_bads(raw): +def _fnirs_spread_bads(info): """Spread bad labeling across fnirs channels.""" - # For an optode if any component (light frequency or chroma) is marked + # For an optode pair if any component (light frequency or chroma) is marked # as bad, then they all should be. This function will find any pairs marked # as bad and spread the bad marking to all components of the optode pair. - picks = _picks_to_idx(raw.info, 'fnirs', exclude=[]) + picks = _picks_to_idx(info, 'fnirs', exclude=[], allow_empty=True) new_bads = list() for ii in picks[::2]: - bad_opto = set(raw.info['bads']).intersection(raw.ch_names[ii:ii + 2]) + bad_opto = set(info['bads']).intersection(info.ch_names[ii:ii + 2]) if len(bad_opto) > 0: - new_bads.extend(raw.ch_names[ii:ii + 2]) - raw.info['bads'] = new_bads + new_bads.extend(info.ch_names[ii:ii + 2]) + info['bads'] = new_bads - return raw + return info diff --git a/mne/preprocessing/nirs/tests/test_nirs.py b/mne/preprocessing/nirs/tests/test_nirs.py index 2db3abf9544..eed99dfa021 100644 --- a/mne/preprocessing/nirs/tests/test_nirs.py +++ b/mne/preprocessing/nirs/tests/test_nirs.py @@ -100,29 +100,31 @@ def test_fnirs_check_bads(fname): """Test checking of bad markings.""" # No bad channels, so these should all pass raw = read_raw_nirx(fname) - _fnirs_check_bads(raw) + _fnirs_check_bads(raw.info) raw = optical_density(raw) - _fnirs_check_bads(raw) + _fnirs_check_bads(raw.info) raw = beer_lambert_law(raw) - _fnirs_check_bads(raw) + _fnirs_check_bads(raw.info) # Mark pairs of bad channels, so these should all pass raw = read_raw_nirx(fname) raw.info['bads'] = raw.ch_names[0:2] - _fnirs_check_bads(raw) + _fnirs_check_bads(raw.info) raw = optical_density(raw) - _fnirs_check_bads(raw) + _fnirs_check_bads(raw.info) raw = beer_lambert_law(raw) - _fnirs_check_bads(raw) + _fnirs_check_bads(raw.info) # Mark single channel as bad, so these should all fail raw = read_raw_nirx(fname) raw.info['bads'] = raw.ch_names[0:1] - pytest.raises(RuntimeError, _fnirs_check_bads, raw) - raw = optical_density(raw) - pytest.raises(RuntimeError, _fnirs_check_bads, raw) - raw = beer_lambert_law(raw) - pytest.raises(RuntimeError, _fnirs_check_bads, raw) + pytest.raises(RuntimeError, _fnirs_check_bads, raw.info) + with pytest.raises(RuntimeError, match='bad labelling'): + raw = optical_density(raw) + pytest.raises(RuntimeError, _fnirs_check_bads, raw.info) + with pytest.raises(RuntimeError, match='bad labelling'): + raw = beer_lambert_law(raw) + pytest.raises(RuntimeError, _fnirs_check_bads, raw.info) @testing.requires_testing_data @@ -133,20 +135,20 @@ def test_fnirs_spread_bads(fname): # Test spreading upwards in frequency and on raw data raw = read_raw_nirx(fname) raw.info['bads'] = ['S1_D1 760'] - raw = _fnirs_spread_bads(raw) - assert raw.info['bads'] == ['S1_D1 760', 'S1_D1 850'] + info = _fnirs_spread_bads(raw.info) + assert info['bads'] == ['S1_D1 760', 'S1_D1 850'] # Test spreading downwards in frequency and on od data raw = optical_density(raw) raw.info['bads'] = raw.ch_names[5:6] - raw = _fnirs_spread_bads(raw) - assert raw.info['bads'] == raw.ch_names[4:6] + info = _fnirs_spread_bads(raw.info) + assert info['bads'] == raw.ch_names[4:6] # Test spreading multiple bads and on chroma data raw = beer_lambert_law(raw) raw.info['bads'] = [raw.ch_names[x] for x in [1, 8]] - raw = _fnirs_spread_bads(raw) - assert raw.info['bads'] == [raw.ch_names[x] for x in [0, 1, 8, 9]] + info = _fnirs_spread_bads(raw.info) + assert info['bads'] == [info.ch_names[x] for x in [0, 1, 8, 9]] @testing.requires_testing_data