diff --git a/mne/preprocessing/nirs/__init__.py b/mne/preprocessing/nirs/__init__.py index 6734cf53573..f101867f789 100644 --- a/mne/preprocessing/nirs/__init__.py +++ b/mne/preprocessing/nirs/__init__.py @@ -8,7 +8,7 @@ from .nirs import (short_channels, source_detector_distances, _check_channels_ordered, _channel_frequencies, - _fnirs_check_bads, _fnirs_spread_bads) + _fnirs_check_bads, _fnirs_spread_bads, _channel_chromophore) from ._optical_density import optical_density from ._beer_lambert_law import beer_lambert_law from ._scalp_coupling_index import scalp_coupling_index diff --git a/mne/preprocessing/nirs/_optical_density.py b/mne/preprocessing/nirs/_optical_density.py index 1efa72377a5..feb4ebb6637 100644 --- a/mne/preprocessing/nirs/_optical_density.py +++ b/mne/preprocessing/nirs/_optical_density.py @@ -10,6 +10,7 @@ from ...io.constants import FIFF from ...utils import _validate_type, warn from ...io.pick import _picks_to_idx +from ..nirs import _channel_frequencies, _check_channels_ordered def optical_density(raw): @@ -27,6 +28,8 @@ def optical_density(raw): """ raw = raw.copy().load_data() _validate_type(raw, BaseRaw, 'raw') + _check_channels_ordered(raw, np.unique(_channel_frequencies(raw))) + picks = _picks_to_idx(raw.info, 'fnirs_cw_amplitude') data_means = np.mean(raw.get_data(), axis=1) diff --git a/mne/preprocessing/nirs/_tddr.py b/mne/preprocessing/nirs/_tddr.py index e4daa039188..935a0b98d5a 100644 --- a/mne/preprocessing/nirs/_tddr.py +++ b/mne/preprocessing/nirs/_tddr.py @@ -10,6 +10,7 @@ from ...io import BaseRaw from ...utils import _validate_type, verbose from ...io.pick import _picks_to_idx +from ..nirs import _channel_frequencies, _check_channels_ordered @verbose @@ -42,6 +43,7 @@ def temporal_derivative_distribution_repair(raw, *, verbose=None): """ raw = raw.copy().load_data() _validate_type(raw, BaseRaw, 'raw') + _check_channels_ordered(raw, np.unique(_channel_frequencies(raw))) if not len(pick_types(raw.info, fnirs='fnirs_od')): raise RuntimeError('TDDR should be run on optical density data.') diff --git a/mne/preprocessing/nirs/nirs.py b/mne/preprocessing/nirs/nirs.py index 28dabd1cfdb..eea96594fa8 100644 --- a/mne/preprocessing/nirs/nirs.py +++ b/mne/preprocessing/nirs/nirs.py @@ -66,8 +66,19 @@ def _channel_frequencies(raw): return freqs +def _channel_chromophore(raw): + """Return the chromophore of each channel.""" + # Only valid for fNIRS data after conversion to haemoglobin + picks = _picks_to_idx(raw.info, ['hbo', 'hbr'], + exclude=[], allow_empty=True) + chroma = [] + for ii in picks: + chroma.append(raw.ch_names[ii].split(" ")[1]) + return chroma + + def _check_channels_ordered(raw, freqs): - """Check channels followed expected fNIRS format.""" + """Check channels follow expected fNIRS format.""" # Every second channel should be same SD pair # and have the specified light frequencies. picks = _picks_to_idx(raw.info, ['fnirs_cw_amplitude', 'fnirs_od'], @@ -77,6 +88,13 @@ def _check_channels_ordered(raw, freqs): 'NIRS channels not ordered correctly. An even number of NIRS ' 'channels is required. %d channels were provided: %r' % (len(raw.ch_names), raw.ch_names)) + + all_freqs = [raw.info["chs"][ii]["loc"][9] for ii in picks] + if np.any(np.isnan(all_freqs)): + raise ValueError( + 'NIRS channels is missing wavelength information in the' + f'info["chs"] structure. The encoded wavelengths are {all_freqs}.') + for ii in picks[::2]: ch1_name_info = re.match(r'S(\d+)_D(\d+) (\d+)', raw.info['chs'][ii]['ch_name']) @@ -101,7 +119,8 @@ def _check_channels_ordered(raw, freqs): (int(ch2_name_info.groups()[2]) != freqs[1]): raise ValueError( 'NIRS channels not ordered correctly. Channels must be ordered' - ' as source detector pairs with frequencies: %d & %d' + ' as source detector pairs with alternating' + ' frequencies: %d & %d' % (freqs[0], freqs[1])) return picks diff --git a/mne/preprocessing/nirs/tests/test_beer_lambert_law.py b/mne/preprocessing/nirs/tests/test_beer_lambert_law.py index 2e52d80cf5f..aca6c3d218c 100644 --- a/mne/preprocessing/nirs/tests/test_beer_lambert_law.py +++ b/mne/preprocessing/nirs/tests/test_beer_lambert_law.py @@ -70,7 +70,7 @@ def test_beer_lambert_unordered_errors(): # Test that an error is thrown if inconsistent frequencies used in data raw_od.info['chs'][2]['loc'][9] = 770.0 - with pytest.raises(ValueError, match='pairs with frequencies'): + with pytest.raises(ValueError, match='with alternating frequencies'): beer_lambert_law(raw_od) diff --git a/mne/preprocessing/nirs/tests/test_nirs.py b/mne/preprocessing/nirs/tests/test_nirs.py index 3a85589ef1c..55845521f0b 100644 --- a/mne/preprocessing/nirs/tests/test_nirs.py +++ b/mne/preprocessing/nirs/tests/test_nirs.py @@ -10,12 +10,13 @@ import numpy as np from numpy.testing import assert_array_equal +from mne import create_info from mne.datasets.testing import data_path -from mne.io import read_raw_nirx +from mne.io import read_raw_nirx, RawArray from mne.preprocessing.nirs import (optical_density, beer_lambert_law, _fnirs_check_bads, _fnirs_spread_bads, _check_channels_ordered, - _channel_frequencies) + _channel_frequencies, _channel_chromophore) from mne.io.pick import _picks_to_idx from mne.datasets import testing @@ -158,6 +159,8 @@ def test_fnirs_channel_naming_and_order_readers(fname): raw = read_raw_nirx(fname) freqs = np.unique(_channel_frequencies(raw)) assert_array_equal(freqs, [760, 850]) + chroma = np.unique(_channel_chromophore(raw)) + assert len(chroma) == 0 picks = _check_channels_ordered(raw, freqs) assert len(picks) == len(raw.ch_names) # as all fNIRS only data @@ -182,9 +185,124 @@ def test_fnirs_channel_naming_and_order_readers(fname): raw = optical_density(raw) freqs = np.unique(_channel_frequencies(raw)) assert_array_equal(freqs, [760, 850]) + chroma = np.unique(_channel_chromophore(raw)) + assert len(chroma) == 0 picks = _check_channels_ordered(raw, freqs) assert len(picks) == len(raw.ch_names) # as all fNIRS only data + # Check on haemoglobin data raw = beer_lambert_law(raw) freqs = np.unique(_channel_frequencies(raw)) assert len(freqs) == 0 + assert len(_channel_chromophore(raw)) == len(raw.ch_names) + chroma = np.unique(_channel_chromophore(raw)) + assert_array_equal(chroma, ["hbo", "hbr"]) + + +def test_fnirs_channel_naming_and_order_custom_raw(): + """Ensure fNIRS channel checking on manually created data.""" + data = np.random.normal(size=(6, 10)) + + # Start with a correctly named raw intensity dataset + # These are the steps required to build an fNIRS Raw object from scratch + ch_names = ['S1_D1 760', 'S1_D1 850', 'S2_D1 760', 'S2_D1 850', + 'S3_D1 760', 'S3_D1 850'] + ch_types = np.repeat("fnirs_cw_amplitude", 6) + info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=1.0) + raw = RawArray(data, info, verbose=True) + freqs = np.tile([760, 850], 3) + for idx, f in enumerate(freqs): + raw.info["chs"][idx]["loc"][9] = f + + freqs = np.unique(_channel_frequencies(raw)) + picks = _check_channels_ordered(raw, freqs) + assert len(picks) == len(raw.ch_names) + assert len(picks) == 6 + + # Different systems use different frequencies, so ensure that works + ch_names = ['S1_D1 920', 'S1_D1 850', 'S2_D1 920', 'S2_D1 850', + 'S3_D1 920', 'S3_D1 850'] + ch_types = np.repeat("fnirs_cw_amplitude", 6) + info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=1.0) + raw = RawArray(data, info, verbose=True) + freqs = np.tile([920, 850], 3) + for idx, f in enumerate(freqs): + raw.info["chs"][idx]["loc"][9] = f + + picks = _check_channels_ordered(raw, [920, 850]) + assert len(picks) == len(raw.ch_names) + assert len(picks) == 6 + + # Catch expected errors + + # The frequencies named in the channel names must match the info loc field + ch_names = ['S1_D1 760', 'S1_D1 850', 'S2_D1 760', 'S2_D1 850', + 'S3_D1 760', 'S3_D1 850'] + ch_types = np.repeat("fnirs_cw_amplitude", 6) + info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=1.0) + raw = RawArray(data, info, verbose=True) + freqs = np.tile([920, 850], 3) + for idx, f in enumerate(freqs): + raw.info["chs"][idx]["loc"][9] = f + with pytest.raises(ValueError, match='name and NIRS frequency do not'): + _check_channels_ordered(raw, [920, 850]) + + # Catch if someone doesn't set the info field + ch_names = ['S1_D1 760', 'S1_D1 850', 'S2_D1 760', 'S2_D1 850', + 'S3_D1 760', 'S3_D1 850'] + ch_types = np.repeat("fnirs_cw_amplitude", 6) + info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=1.0) + raw = RawArray(data, info, verbose=True) + with pytest.raises(ValueError, match='missing wavelength information'): + _check_channels_ordered(raw, [920, 850]) + + # I have seen data encoded not in alternating frequency, but blocked. + ch_names = ['S1_D1 760', 'S2_D1 760', 'S3_D1 760', + 'S1_D1 850', 'S2_D1 850', 'S3_D1 850'] + ch_types = np.repeat("fnirs_cw_amplitude", 6) + info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=1.0) + raw = RawArray(data, info, verbose=True) + freqs = np.repeat([760, 850], 3) + for idx, f in enumerate(freqs): + raw.info["chs"][idx]["loc"][9] = f + with pytest.raises(ValueError, match='channels not ordered correctly'): + _check_channels_ordered(raw, [760, 850]) + # and this is how you would fix the ordering, then it should pass + raw.pick(picks=[0, 3, 1, 4, 2, 5]) + _check_channels_ordered(raw, [760, 850]) + + +def test_fnirs_channel_naming_and_order_custom_optical_density(): + """Ensure fNIRS channel checking on manually created data.""" + data = np.random.normal(size=(6, 10)) + + # Start with a correctly named raw intensity dataset + # These are the steps required to build an fNIRS Raw object from scratch + ch_names = ['S1_D1 760', 'S1_D1 850', 'S2_D1 760', 'S2_D1 850', + 'S3_D1 760', 'S3_D1 850'] + ch_types = np.repeat("fnirs_od", 6) + info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=1.0) + raw = RawArray(data, info, verbose=True) + freqs = np.tile([760, 850], 3) + for idx, f in enumerate(freqs): + raw.info["chs"][idx]["loc"][9] = f + + freqs = np.unique(_channel_frequencies(raw)) + picks = _check_channels_ordered(raw, freqs) + assert len(picks) == len(raw.ch_names) + assert len(picks) == 6 + + # Check block naming for optical density + ch_names = ['S1_D1 760', 'S2_D1 760', 'S3_D1 760', + 'S1_D1 850', 'S2_D1 850', 'S3_D1 850'] + ch_types = np.repeat("fnirs_od", 6) + info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=1.0) + raw = RawArray(data, info, verbose=True) + freqs = np.repeat([760, 850], 3) + for idx, f in enumerate(freqs): + raw.info["chs"][idx]["loc"][9] = f + with pytest.raises(ValueError, match='channels not ordered correctly'): + _check_channels_ordered(raw, [760, 850]) + # and this is how you would fix the ordering, then it should pass + raw.pick(picks=[0, 3, 1, 4, 2, 5]) + _check_channels_ordered(raw, [760, 850])