diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 1be5491f71a..0edc421c8c1 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -61,6 +61,8 @@ Enhancements - :func:`mne.viz.plot_evoked` and :meth:`mne.Evoked.plot` gained a new parameter, ``highlight``, to visually highlight time periods of interest (:gh:`10614` by `Richard Höchenberger`_) +- Added fNIRS support to :func:`mne.Info.get_montage` (:gh:`10611` by `Robert Luke`_) + Bugs ~~~~ - Make ``color`` parameter check in in :func:`mne.viz.plot_evoked_topo` consistent (:gh:`10217` by :newcontrib:`T. Wang` and `Stefan Appelhoff`_) diff --git a/mne/channels/tests/test_montage.py b/mne/channels/tests/test_montage.py index 3f3840a4348..a2b7f4310b0 100644 --- a/mne/channels/tests/test_montage.py +++ b/mne/channels/tests/test_montage.py @@ -41,7 +41,7 @@ from mne.io.kit import read_mrk from mne.io import (read_raw_brainvision, read_raw_egi, read_raw_fif, - read_fiducials, __file__ as _MNE_IO_FILE) + read_fiducials, __file__ as _MNE_IO_FILE, read_raw_nirx) from mne.io import RawArray from mne.datasets import testing @@ -63,6 +63,8 @@ bdf_fname2 = op.join(data_path, 'BDF', 'test_bdf_stim_channel.bdf') egi_fname1 = op.join(data_path, 'EGI', 'test_egi.mff') cnt_fname = op.join(data_path, 'CNT', 'scan41_short.cnt') +fnirs_dname = op.join(data_path, 'NIRx', 'nirscout', + 'nirx_15_2_recording_w_short') subjects_dir = op.join(data_path, 'subjects') io_dir = op.dirname(_MNE_IO_FILE) @@ -1663,3 +1665,30 @@ def test_make_wrong_dig_montage(): make_dig_montage(ch_pos={'A1': ['a', 'b', 'c']}) with pytest.raises(TypeError, match="instance of ndarray, list, or tuple"): make_dig_montage(ch_pos={'A1': 5}) + + +@testing.requires_testing_data +def test_fnirs_montage(): + """Ensure fNIRS montages can be get and set.""" + raw = read_raw_nirx(fnirs_dname) + info_orig = raw.copy().info + mtg = raw.get_montage() + + num_sources = np.sum(["S" in optode for optode in mtg.ch_names]) + num_detectors = np.sum(["D" in optode for optode in mtg.ch_names]) + assert num_sources == 5 + assert num_detectors == 13 + + # Make a change to the montage before setting + raw.info['chs'][2]['loc'][:3] = [1., 2, 3] + # Set montage back to original + raw.set_montage(mtg) + + for ch in range(len(raw.ch_names)): + assert_array_equal(info_orig['chs'][ch]['loc'], + raw.info['chs'][ch]['loc']) + + # Mixed channel types not supported yet + raw.set_channel_types({ch_name: 'eeg' for ch_name in raw.ch_names[-2:]}) + with pytest.raises(ValueError, match='mix of fNIRS'): + raw.get_montage() diff --git a/mne/io/meas_info.py b/mne/io/meas_info.py index d9aa8662ceb..4eeb8bda4d8 100644 --- a/mne/io/meas_info.py +++ b/mne/io/meas_info.py @@ -175,6 +175,17 @@ def get_montage(self): # so use loc[:3] instead ch_pos = {ch_names[ii]: chs[ii]['loc'][:3] for ii in picks} + # fNIRS uses multiple channels for the same sensors, we use + # a private function to format these for dig montage. + fnirs_picks = pick_types(info, fnirs=True, exclude=[]) + if len(ch_pos) == len(fnirs_picks): + ch_pos = _get_fnirs_ch_pos(info) + elif len(fnirs_picks) > 0: + raise ValueError("MNE does not support getting the montage " + "for a mix of fNIRS and other data types. " + "Please raise a GitHub issue if you " + "require this feature.") + # create montage montage = make_dig_montage( ch_pos=ch_pos, @@ -2903,3 +2914,19 @@ def _ensure_infos_match(info1, info2, name, *, on_mismatch='raise'): f"runs to a common head position.") _on_missing(on_missing=on_mismatch, msg=msg, name='on_mismatch') + + +def _get_fnirs_ch_pos(info): + """Return positions of each fNIRS optode. + + fNIRS uses two types of optodes, sources and detectors. + There can be multiple connections between each source + and detector at different wavelengths. This function + returns the location of each source and detector. + """ + from ..preprocessing.nirs import _fnirs_optode_names, _optode_position + srcs, dets = _fnirs_optode_names(info) + ch_pos = {} + for optode in [*srcs, *dets]: + ch_pos[optode] = _optode_position(info, optode) + return ch_pos diff --git a/mne/io/nirx/tests/test_nirx.py b/mne/io/nirx/tests/test_nirx.py index 73172bd4184..7d5b7edb153 100644 --- a/mne/io/nirx/tests/test_nirx.py +++ b/mne/io/nirx/tests/test_nirx.py @@ -151,7 +151,7 @@ def test_nirsport_v2(): np.diff(raw.annotations.onset), [2.3, 3.1], atol=0.1) mon = raw.get_montage() - assert len(mon.dig) == 43 + assert len(mon.dig) == 27 @requires_testing_data diff --git a/mne/io/snirf/tests/test_snirf.py b/mne/io/snirf/tests/test_snirf.py index 9709976f198..938247c3ea2 100644 --- a/mne/io/snirf/tests/test_snirf.py +++ b/mne/io/snirf/tests/test_snirf.py @@ -304,7 +304,7 @@ def test_snirf_nirsport2_w_positions(): mni_locs[34], [0.0828, -0.046, 0.0285], atol=allowed_dist_error) mon = raw.get_montage() - assert len(mon.dig) == 43 + assert len(mon.dig) == 27 @requires_testing_data diff --git a/mne/minimum_norm/tests/test_inverse.py b/mne/minimum_norm/tests/test_inverse.py index ae1dcfa1f75..22049b1fb6e 100644 --- a/mne/minimum_norm/tests/test_inverse.py +++ b/mne/minimum_norm/tests/test_inverse.py @@ -332,7 +332,7 @@ def test_localization_bias_fixed(bias_params_fixed, method, lower, upper, ('MNE', 89, 92, dict(limit_depth_chs='whiten'), 0.2), # sparse default ('dSPM', 85, 87, 0.8, 0.2), ('sLORETA', 100, 100, 0.8, 0.2), - ('eLORETA', 99, 100, None, 0.2), + pytest.param('eLORETA', 99, 100, None, 0.2, marks=pytest.mark.slowtest), pytest.param('eLORETA', 99, 100, 0.8, 0.2, marks=pytest.mark.slowtest), pytest.param('eLORETA', 99, 100, 0.8, 0.001, marks=pytest.mark.slowtest), ]) @@ -831,6 +831,7 @@ def test_inverse_operator_volume(evoked, tmp_path): apply_inverse(evoked, inv_vol, pick_ori='normal') +@pytest.mark.slowtest def test_inverse_operator_discrete(evoked, tmp_path): """Test MNE inverse computation on discrete source space.""" # Make discrete source space diff --git a/mne/preprocessing/nirs/__init__.py b/mne/preprocessing/nirs/__init__.py index 2af2e3832cc..c9cc6d11374 100644 --- a/mne/preprocessing/nirs/__init__.py +++ b/mne/preprocessing/nirs/__init__.py @@ -9,7 +9,7 @@ from .nirs import (short_channels, source_detector_distances, _check_channels_ordered, _channel_frequencies, _fnirs_check_bads, _fnirs_spread_bads, _channel_chromophore, - _validate_nirs_info) + _validate_nirs_info, _fnirs_optode_names, _optode_position) from ._optical_density import optical_density from ._beer_lambert_law import beer_lambert_law from ._scalp_coupling_index import scalp_coupling_index diff --git a/mne/preprocessing/nirs/nirs.py b/mne/preprocessing/nirs/nirs.py index f29b6fa019f..982bf41347f 100644 --- a/mne/preprocessing/nirs/nirs.py +++ b/mne/preprocessing/nirs/nirs.py @@ -222,3 +222,40 @@ def _fnirs_spread_bads(info): info['bads'] = new_bads return info + + +def _fnirs_optode_names(info): + """Return list of unique optode names.""" + picks_wave = _picks_to_idx(info, ['fnirs_cw_amplitude', 'fnirs_od'], + exclude=[], allow_empty=True) + picks_chroma = _picks_to_idx(info, ['hbo', 'hbr'], + exclude=[], allow_empty=True) + + if len(picks_wave) > 0: + regex = _S_D_F_RE + elif len(picks_chroma) > 0: + regex = _S_D_H_RE + else: + return [], [] + + sources = np.unique([int(regex.match(ch).groups()[0]) + for ch in info.ch_names]) + detectors = np.unique([int(regex.match(ch).groups()[1]) + for ch in info.ch_names]) + + src_names = [f"S{s}" for s in sources] + det_names = [f"D{d}" for d in detectors] + + return src_names, det_names + + +def _optode_position(info, optode): + """Find the position of an optode.""" + idx = [optode in a for a in info.ch_names].index(True) + + if "S" in optode: + loc_idx = range(3, 6) + elif "D" in optode: + loc_idx = range(6, 9) + + return info["chs"][idx]["loc"][loc_idx] diff --git a/mne/preprocessing/nirs/tests/test_nirs.py b/mne/preprocessing/nirs/tests/test_nirs.py index 703c9ebd5a3..9646fe13065 100644 --- a/mne/preprocessing/nirs/tests/test_nirs.py +++ b/mne/preprocessing/nirs/tests/test_nirs.py @@ -8,7 +8,7 @@ import pytest import numpy as np -from numpy.testing import assert_array_equal +from numpy.testing import assert_array_equal, assert_array_almost_equal from mne import create_info from mne.datasets.testing import data_path @@ -16,7 +16,8 @@ from mne.preprocessing.nirs import (optical_density, beer_lambert_law, _fnirs_check_bads, _fnirs_spread_bads, _check_channels_ordered, - _channel_frequencies, _channel_chromophore) + _channel_frequencies, _channel_chromophore, + _fnirs_optode_names, _optode_position) from mne.io.pick import _picks_to_idx from mne.datasets import testing @@ -374,3 +375,30 @@ def test_fnirs_channel_naming_and_order_custom_chroma(): raw = RawArray(data, info, verbose=True) with pytest.raises(ValueError, match='can not be parsed'): _check_channels_ordered(raw.info, ["hbo", "hbr"]) + + +def test_optode_names(): + """Ensure optode name extraction is correct.""" + ch_names = ['S11_D2 760', 'S11_D2 850', 'S3_D1 760', + 'S3_D1 850', 'S2_D13 760', 'S2_D13 850'] + ch_types = np.repeat("fnirs_od", 6) + info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=1.0) + src_names, det_names = _fnirs_optode_names(info) + assert_array_equal(src_names, [f"S{n}" for n in ["2", "3", "11"]]) + assert_array_equal(det_names, [f"D{n}" for n in ["1", "2", "13"]]) + + ch_names = ['S1_D11 hbo', 'S1_D11 hbr', 'S2_D17 hbo', 'S2_D17 hbr', + 'S3_D1 hbo', 'S3_D1 hbr'] + ch_types = np.tile(["hbo", "hbr"], 3) + info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=1.0) + src_names, det_names = _fnirs_optode_names(info) + assert_array_equal(src_names, [f"S{n}" for n in range(1, 4)]) + assert_array_equal(det_names, [f"D{n}" for n in ["1", "11", "17"]]) + + +@testing.requires_testing_data +def test_optode_loc(): + """Ensure optode location extraction is correct.""" + raw = read_raw_nirx(fname_nirx_15_2_short) + loc = _optode_position(raw.info, "D3") + assert_array_almost_equal(loc, [0.082804, 0.01573, 0.024852]) diff --git a/mne/preprocessing/tests/test_maxwell.py b/mne/preprocessing/tests/test_maxwell.py index 3382a777438..4ede4916ed6 100644 --- a/mne/preprocessing/tests/test_maxwell.py +++ b/mne/preprocessing/tests/test_maxwell.py @@ -538,6 +538,7 @@ def test_bads_reconstruction(): assert_meg_snr(raw_sss, read_crop(sss_bad_recon_fname), 300.) +@pytest.mark.slowtest @buggy_mkl_svd @testing.requires_testing_data def test_spatiotemporal(): diff --git a/mne/tests/test_cov.py b/mne/tests/test_cov.py index 805459437cc..17c04d44b54 100644 --- a/mne/tests/test_cov.py +++ b/mne/tests/test_cov.py @@ -257,7 +257,11 @@ def test_io_cov(tmp_path): read_cov(cov_badname) -@pytest.mark.parametrize('method', (None, 'empirical', 'shrunk')) +@pytest.mark.parametrize('method', [ + None, + 'empirical', + pytest.param('shrunk', marks=pytest.mark.slowtest), +]) def test_cov_estimation_on_raw(method, tmp_path): """Test estimation from raw (typically empty room).""" if method == 'shrunk':