diff --git a/doc/api/preprocessing.rst b/doc/api/preprocessing.rst index 07443e518aa..2a396684dec 100644 --- a/doc/api/preprocessing.rst +++ b/doc/api/preprocessing.rst @@ -221,6 +221,7 @@ EEG referencing: get_chpi_info head_pos_to_trans_rot_t read_head_pos + refit_hpi write_head_pos :py:mod:`mne.transforms` @@ -235,6 +236,7 @@ EEG referencing: :toctree: ../generated/ Transform + angle_distance_between_rigid quat_to_rot rot_to_quat read_ras_mni_t diff --git a/doc/changes/dev/13484.newfeature.rst b/doc/changes/dev/13484.newfeature.rst new file mode 100644 index 00000000000..d8c792eca72 --- /dev/null +++ b/doc/changes/dev/13484.newfeature.rst @@ -0,0 +1 @@ +Add ability to refit HPI order and device-to-head transform via :func:`mne.chpi.refit_hpi` and compute distances between transforms with :func:`mne.transforms.angle_distance_between_rigid` by `Eric Larson`_. diff --git a/mne/_fiff/meas_info.py b/mne/_fiff/meas_info.py index e6d7bd4690a..708b7135a9c 100644 --- a/mne/_fiff/meas_info.py +++ b/mne/_fiff/meas_info.py @@ -1999,6 +1999,10 @@ def ch_names(self): return ch_names + @property + def _cals(self): + return np.array([ch["range"] * ch["cal"] for ch in self["chs"]], float) + @repr_html def _repr_html_(self): """Summarize info for HTML representation.""" diff --git a/mne/chpi.py b/mne/chpi.py index 1cc168f3e20..711474338c9 100644 --- a/mne/chpi.py +++ b/mne/chpi.py @@ -45,7 +45,7 @@ from .event import find_events from .fixes import jit from .forward import _concatenate_coils, _create_meg_coils, _magnetic_dipole_field_vec -from .io import BaseRaw +from .io import BaseRaw, RawArray from .io.ctf.trans import _make_ctf_coord_trans_set from .io.kit.constants import KIT from .io.kit.kit import RawKIT as _RawKIT @@ -61,6 +61,7 @@ _fit_matched_points, _quat_to_affine, als_ras_trans, + angle_distance_between_rigid, apply_trans, invert_transform, quat_to_rot, @@ -420,13 +421,12 @@ def _get_hpi_initial_fit(info, adjust=False, verbose=None): raise RuntimeError("no initial cHPI head localization performed") hpi_result = info["hpi_results"][-1] - hpi_dig = sorted( - [d for d in info["dig"] if d["kind"] == FIFF.FIFFV_POINT_HPI], - key=lambda x: x["ident"], - ) # ascending (dig) order - if len(hpi_dig) == 0: # CTF data, probably + hpi_dig = _sorted_hpi_dig(info["dig"]) + CTF_KINDS = (FIFF.FIFFV_POINT_HPI, FIFF.FIFFV_POINT_CARDINAL) + if len(hpi_dig) == 0: msg = "HPIFIT: No HPI dig points, using hpifit result" - hpi_dig = sorted(hpi_result["dig_points"], key=lambda x: x["ident"]) + # For CTF data, these can get stored as cardinal points + hpi_dig = _sorted_hpi_dig(hpi_result["dig_points"], kinds=CTF_KINDS) if all( d["coord_frame"] in (FIFF.FIFFV_COORD_DEVICE, FIFF.FIFFV_COORD_UNKNOWN) for d in hpi_dig @@ -462,11 +462,12 @@ def _get_hpi_initial_fit(info, adjust=False, verbose=None): f"HPIFIT: {len(used)} coils accepted: {' '.join(str(h) for h in used)}" ) hpi_rrs = np.array([d["r"] for d in hpi_dig])[pos_order] - assert len(hpi_rrs) >= 3 + assert len(hpi_rrs) >= 3, len(hpi_rrs) # Fitting errors - hpi_rrs_fit = sorted( - [d for d in info["hpi_results"][-1]["dig_points"]], key=lambda x: x["ident"] + hpi_rrs_fit = _sorted_hpi_dig( + info["hpi_results"][-1]["dig_points"], + kinds=CTF_KINDS, ) hpi_rrs_fit = np.array([d["r"] for d in hpi_rrs_fit]) # hpi_result['dig_points'] are in FIFFV_COORD_UNKNOWN coords, but this @@ -475,7 +476,7 @@ def _get_hpi_initial_fit(info, adjust=False, verbose=None): assert hpi_result["coord_trans"]["to"] == FIFF.FIFFV_COORD_HEAD hpi_rrs_fit = apply_trans(hpi_result["coord_trans"]["trans"], hpi_rrs_fit) if "moments" in hpi_result: - logger.debug(f"Hpi coil moments {hpi_result['moments'].shape[::-1]}:") + logger.debug(f"HPI coil moments {hpi_result['moments'].shape[::-1]}:") for moment in hpi_result["moments"]: logger.debug(f"{moment[0]:g} {moment[1]:g} {moment[2]:g}") errors = np.linalg.norm(hpi_rrs - hpi_rrs_fit, axis=1) @@ -583,14 +584,14 @@ def _fit_chpi_quat(coil_dev_rrs, coil_head_rrs): denom = np.linalg.norm(coil_head_rrs - np.mean(coil_head_rrs, axis=0)) denom *= denom # We could try to solve it the analytic way: - # XXX someday we could choose to weight these points by their goodness - # of fit somehow. + # TODO someday we could choose to weight these points by their goodness + # of fit somehow, see also https://github.com/mne-tools/mne-python/issues/11330 quat = _fit_matched_points(coil_dev_rrs, coil_head_rrs)[0] gof = 1.0 - _chpi_objective(quat, coil_dev_rrs, coil_head_rrs) / denom return quat, gof -def _fit_coil_order_dev_head_trans(dev_pnts, head_pnts, bias=True): +def _fit_coil_order_dev_head_trans(dev_pnts, head_pnts, *, bias=True, prefix=""): """Compute Device to Head transform allowing for permutiatons of points.""" id_quat = np.zeros(6) best_order = None @@ -616,6 +617,13 @@ def _fit_coil_order_dev_head_trans(dev_pnts, head_pnts, bias=True): # Convert Quaterion to transform dev_head_t = _quat_to_affine(best_quat) + ang, dist = angle_distance_between_rigid( + dev_head_t, angle_units="deg", distance_units="mm" + ) + logger.info( + f"{prefix}Fitted dev_head_t {ang:0.1f}° and {dist:0.1f} mm " + f"from device origin (GOF: {out_g:.3f})" + ) return dev_head_t, best_order, out_g @@ -755,6 +763,9 @@ def _setup_ext_proj(info, ext_order): proj_op, _ = setup_proj( info, add_eeg_ref=False, activate=False, verbose=_verbose_safe_false() ) + # Can be None if ext_order = 0 + if proj_op is None: + proj_op = np.eye(len(meg_picks)) assert proj_op.shape == (len(meg_picks),) * 2 return proj, proj_op, meg_picks @@ -870,8 +881,7 @@ def _check_chpi_param(chpi_, name): want_keys = list(want_ndims.keys()) + extra_keys if set(want_keys).symmetric_difference(chpi_): raise ValueError( - f"{name} must be a dict with entries {want_keys}, got " - f"{sorted(chpi_.keys())}" + f"{name} must be a dict with entries {want_keys}, got {sorted(chpi_)}" ) n_times = None for key, want_ndim in want_ndims.items(): @@ -1615,3 +1625,311 @@ def get_active_chpi(raw, *, on_missing="raise", verbose=None): chpi_ts = raw[chpi_info[1]][0].astype(int) chpi_active = (chpi_ts & chpi_info[2][:, np.newaxis]).astype(bool) return chpi_active.sum(axis=0) + + +@verbose +def refit_hpi( + info, + *, + amplitudes=True, + locs=True, + order=True, + ext_order=1, + gof_limit=0.98, + dist_limit=0.005, + use=None, + colinearity_limit=0, + verbose=None, +): + """Refit HPI coil order. + + This operates inplace on ``info``, and will typically be called via + ``refit_hpi(raw.info)`` before further processing. + + Parameters + ---------- + info : instance of Info + The measurement info. + amplitudes : bool + Whether to recompute the HPI amplitudes (slopes) from the raw data obtained + during the original fit using :func:`~mne.chpi.compute_chpi_amplitudes`, + or used the already-computed ones stored in ``info['hpi_meas']``. + If this is True, ``locs`` and ``order`` must also be True. + locs : bool + Whether to recompute the HPI coil locations using + :func:`~mne.chpi.compute_chpi_locs`, or use the already-computed ones stored in + ``info['hpi_results']``. + If this is True, ``order`` must also be True. + order : bool + Whether to refit the coil order by testing all permutations for the best + goodness of fit between digitized coil locations and (rigid-transformed) + fitted coil locations. + %(ext_order_chpi)s + gof_limit : float + The goodness-of-fit limit to use when choosing which coils to use for refitting. + dist_limit : float + The distance limit (in meters) to use when choosing which coils to use for + refitting. + use : int | None + The maximum number of coils to use when testing different coil orderings. + The default for ``hpifit`` in MEGIN software is 3. Default (None) means to + use all coils above ``gof_limit``. Can also be an ndarray of coil indices + (0-indexed!) to use, e.g., ``[1, 2, 4]``. + colinearity_limit : float + The RMS limit (in meters) to use when checking for colinearity of coil + locations. If the RMS difference between the used points and a best-fit linear + approximation is less than this limit, a warning is emitted. + The default (``0``) avoids colinearity warnings altogether. + The appropriate value here is dataset dependent, but for one problematic + dataset the value of 0.03 worked well. + %(verbose)s + + Returns + ------- + info : instance of Info + The modified measurement info (same as input). + + Notes + ----- + This adds additional entries to ``info["hpi_meas"]`` and + ``info["hpi_results"]``, leaving the existing ones intact. + It will always modify ``info["dev_head_t"]`` inplace. + + The algorithm is as follows: + + 1. Optionally recompute HPI amplitudes (sinusoidal fit for each channel) using + :func:`~mne.chpi.compute_chpi_amplitudes`. + 2. Optionally use HPI amplitudes to fit HPI coil locations using + :func:`~mne.chpi.compute_chpi_locs`. + 3. Optionally determine coil digitization order by testing all permutations + for the best goodness of fit between digitized coil locations and + (rigid-transformed) fitted coil locations. + 4. Subselect coils to use for fitting ``dev_head_t`` based on ``gof_limit``, + ``dist_limit``, and ``use``. + 5. Update info inplace by modifying ``info["dev_head_t"]`` and appending new entries + to ``info["hpi_meas"]`` and ``info["hpi_results"]``. + + .. versionadded:: 1.11 + """ + _validate_type(info, Info, "info") + fit_info = pick_info(info, pick_types(info, meg=True, exclude=())) + # Set bads to empty list here. In theory flux jumps etc. or even flat channels + # shouldn't affect the fit much. At some point we could allow ignoring bads, + # but it would make the API more complex (KISS) and make the info accounting harder + # (e.g., slopes must always have shape[-1] == len(all_meg_chs)). + fit_info["bads"] = [] # for backward compat... maybe shouldn't do this + vf = _verbose_safe_false() + hpi = _setup_hpi_amplitude_fitting(fit_info, 1.0, ext_order=ext_order, verbose=vf) + n_coils = info["hpi_subsystem"]["ncoil"] + if use is not None: + use = np.array(use) + if use.dtype.kind != "i": + raise ValueError( + f"use must be an integer or array-like of integers, got {use.dtype}" + ) + if use.shape == (): + use = int(use.item()) + if use < 3: + raise ValueError(f"max_use must be at least 3, got {use}") + elif ( + use.ndim != 1 + or not np.array_equal(np.sort(use), np.unique(use)) + or not np.isin(use, np.arange(info["hpi_subsystem"]["ncoil"])).all() + ): + raise ValueError( + "use must be a 1D array of unique integers in the range [0, " + f"{n_coils - 1}]" + ) + assert use is None or isinstance(use, int | np.ndarray) # we have this now + _validate_type(amplitudes, bool, "amplitudes") + _validate_type(locs, bool, "locs") + _validate_type(order, bool, "order") + _validate_type(dist_limit, "numeric", "dist_limit") + _validate_type(colinearity_limit, "numeric", "colinearity_limit") + if amplitudes and not locs: + raise ValueError( + "If amplitudes is True, locs must also be True (otherwise " + "recomputing amplitudes has no effect)" + ) + if locs and not order: + raise ValueError( + "If locs is True, order must also be True (otherwise " + "recomputing locations has no effect)" + ) + logger.info(f"Refitting HPI coil order for {n_coils} coils ...") + old_meas = info["hpi_meas"][-1] + old_results = info["hpi_results"][-1] + slopes = np.array([[old_meas["hpi_coils"][ci]["slopes"] for ci in range(n_coils)]]) + + # 1. Compute HPI amplitudes + if amplitudes: + epoch = old_meas["hpi_coils"][0]["epoch"] + cals = info._cals[pick_types(info, meg=True, exclude=()), np.newaxis] + assert cals.shape[0] == epoch.shape[0], "Calibration shape mismatch" + data = epoch * cals + fit_raw = RawArray(data, fit_info) + stop = fit_raw.times[-1] + fit_amps = compute_chpi_amplitudes( + fit_raw, t_step_min=stop, t_window=stop + 1.0 / info["sfreq"] + ) + for ci, slope in enumerate(fit_amps["slopes"][0]): + old_slope = old_meas["hpi_coils"][ci]["slopes"] + corr = np.abs(np.corrcoef(slope, old_slope)[0, 1]) + logger.info(f" Coil {ci + 1}: slope correlation with old = {corr:.3f}") + else: + fit_amps = dict( + times=np.array([old_meas["first_samp"] / info["sfreq"]]), + slopes=slopes, + proj=hpi["proj"], + ) + del amplitudes + + # 2. Compute HPI locations + if locs: + fit_locs = compute_chpi_locs(fit_info, fit_amps) + for ci in range(n_coils): + dist = 1e3 * np.linalg.norm( + fit_locs["rrs"][0][ci] - old_results["dig_points"][ci]["r"] + ) + logger.info( + f" Coil {ci + 1}: location difference with old = {dist:.1f} mm" + ) + else: + fit_locs = dict( + rrs=[np.array([d["r"] for d in old_results["dig_points"]], float)], + gofs=[old_results["goodness"]], + moments=[old_results["moments"]], + ) + del fit_info, locs + + # 3. Determine coil order + hpi_dig = _sorted_hpi_dig(info["dig"]) + assert all(d["coord_frame"] == FIFF.FIFFV_COORD_HEAD for d in hpi_dig) # should be + hpi_head = np.array([d["r"] for d in hpi_dig]).astype(float) + del hpi_dig + hpi_dev = fit_locs["rrs"][0].astype(float) + hpi_gofs = fit_locs["gofs"][0] + gofs_str = " ".join(f"{g:.3f}" for g in hpi_gofs) + logger.info(f" Coil goodness-of-fits: {gofs_str}") + assert len(hpi_head) == len(hpi_dev) == n_coils + if order: + fit_dev_head_t, fit_order, _g = _fit_coil_order_dev_head_trans( + hpi_dev, + hpi_head, + prefix=" ", + ) + else: + fit_order = info["hpi_results"][-1]["order"] - 1 # make 0-indexed + fit_dev_head_t = info["dev_head_t"]["trans"] + + # 4. Subselect usable coils and determine final dev_head_t + if isinstance(use, int) or use is None: + used = np.where(hpi_gofs >= gof_limit)[0] + if len(used) < 3: + gofs = ", ".join(f"{g:.3f}" for g in hpi_gofs) + raise RuntimeError( + f"Only {len(used)} coil{_pl(used)} with goodness of fit >= {gof_limit}" + f", need at least 3 to refit HPI order (got {gofs})." + ) + quat, _g = _fit_chpi_quat(hpi_dev[used], hpi_head[fit_order][used]) + fit_dev_head_t = _quat_to_affine(quat) + hpi_head_got = apply_trans(fit_dev_head_t, hpi_dev) + dists = np.linalg.norm(hpi_head_got - hpi_head[fit_order], axis=1) + dist_str = " ".join(f"{dist * 1e3:.1f}" for dist in dists) + logger.info(f" Coil distances after initial fit: {dist_str} mm") + good_dists_idx = np.where(dists[used] <= dist_limit)[0] + if not len(good_dists_idx) >= 3: + raise RuntimeError( + f"Only {len(good_dists_idx)} coil{_pl(good_dists_idx)} have distance " + f"<= {dist_limit * 1e3:.1f} mm, need at least 3 to refit HPI order " + f"(got distances: {np.round(1e3 * dists, 1)})." + ) + used = used[good_dists_idx] + if use is not None: + used = np.sort(used[np.argsort(hpi_gofs[used])[-use:]]) + else: + used = use + del use + used_str = " ".join(str(u + 1) for u in used) + logger.info(f" Using coils {used_str} to compute final dev_head_t") + + # Sanity check linearity of points + # The threshold of 3cm was empirically determined by looking at a known problematic + # dataset and seeing when it was fixed. + limit_cm = colinearity_limit * 1e2 + min_kind, min_cm = "", np.inf + for kind, pts in [ + ("digitized", hpi_head[fit_order][used]), + ("fitted", hpi_dev[used]), + ]: + centered = pts - pts.mean(axis=0) + v = np.linalg.svd(centered, full_matrices=False)[2] + resid = pts - (pts @ v[0])[:, np.newaxis] * v[0] + # check if RMS error is less than 5 mm + rms_cm = 1e2 * np.sqrt(np.mean(resid * resid)) + if rms_cm < min_cm: + min_cm = rms_cm + min_kind = kind + if min_cm < limit_cm: + extra = "" + if len(used) < n_coils: + extra += ( + ", consider including more coils by adjusting the gof_limit, " + "dist_limit, amplitudes, locs, or manually setting order" + ) + warn( + f"The {len(used)} {min_kind} coil locations {used_str} are approximately " + f"colinear (RMS error {min_cm:.1f} cm from a linear fit). The fit may be " + f"unstable and be fit as an incorrect rotation about a line{extra}" + ) + quat, _g = _fit_chpi_quat(hpi_dev[used], hpi_head[fit_order][used]) + assert np.linalg.det(quat_to_rot(quat[:3])) > 0.9999 + fit_dev_head_t = _quat_to_affine(quat) + + # 5. Adjust metadata + info["dev_head_t"]["trans"][:] = fit_dev_head_t + results = copy.deepcopy(old_results) + results["coord_trans"]["trans"][:] = fit_dev_head_t + results["accept"] = 1 + results["used"] = used + 1 # make 1-indexed (can be different length from previous) + results["order"][:] = fit_order + 1 # make 1-indexed + results["goodness"][:] = hpi_gofs + results["good_limit"] = gof_limit + results["moments"][:] = fit_locs["moments"][0] # ndarray, shape (n_coils, 3) + results["dist_limit"] = dist_limit + del fit_locs + for ci, loc in enumerate(hpi_dev): + results["dig_points"][ci]["r"][:] = loc.astype(float) + del hpi_dev, fit_dev_head_t, fit_order + meas = copy.deepcopy(old_meas) + meas["used"] = np.arange(1, n_coils + 1) # we use all of them + for ci, slope in enumerate(fit_amps["slopes"][0]): + meas["hpi_coils"][ci]["slopes"][:] = slope + # print out some stats about the refit + to_print = dict(old=old_results, new=results) + for kind, result in to_print.items(): + this_order = result["order"] + msg = f" {kind.capitalize()} order {this_order} errors: " + # errors + this_dev_head_t = result["coord_trans"] + this_hpi_dev = np.array([d["r"] for d in result["dig_points"]]).astype(float) + diffs = apply_trans(this_dev_head_t, this_hpi_dev) - hpi_head[this_order - 1] + dists = 1e3 * np.linalg.norm(diffs, axis=1) + for dist in dists: + msg += f"{dist:5.1f} " + msg += "mm" + logger.info(msg) + # In theory these are lists, but it seems like maxfilter only likes having a single + # entry. At some point we should try recording data where multiple fits are stored + # (maybe there actually aren't any...) + info["hpi_meas"][-1] = meas + info["hpi_results"][-1] = result + return info + + +def _sorted_hpi_dig(dig, *, kinds=(FIFF.FIFFV_POINT_HPI,)): + return sorted( + # need .get here because the hpi_result["dig_points"] does not set it + (d for d in dig if d.get("kind", FIFF.FIFFV_POINT_HPI) in kinds), + key=lambda d: d["ident"], + ) diff --git a/mne/conftest.py b/mne/conftest.py index 57d14205e17..5a7fa4fed40 100644 --- a/mne/conftest.py +++ b/mne/conftest.py @@ -19,8 +19,12 @@ import numpy as np import pytest -from pytest import StashKey +from pytest import StashKey, register_assert_rewrite +# Any `assert` statements in our testing functions should be verbose versions +register_assert_rewrite("mne.utils._testing") + +# ruff: noqa: E402 import mne from mne import Epochs, pick_types, read_events from mne.channels import read_layout @@ -39,11 +43,10 @@ check_version, numerics, ) - -# data from sample dataset from mne.viz._figure import use_browser_backend from mne.viz.backends._utils import _init_mne_qtapp +# data from sample dataset test_path = testing.data_path(download=False) s_path = op.join(test_path, "MEG", "sample") fname_evoked = op.join(s_path, "sample_audvis_trunc-ave.fif") diff --git a/mne/datasets/config.py b/mne/datasets/config.py index 10d3dea7fa4..23c1cf9e78b 100644 --- a/mne/datasets/config.py +++ b/mne/datasets/config.py @@ -87,7 +87,7 @@ # update the checksum in the MNE_DATASETS dict below, and change version # here: ↓↓↓↓↓↓↓↓ RELEASES = dict( - testing="0.168", + testing="0.169", misc="0.27", phantom_kit="0.2", ucl_opm_auditory="0.2", @@ -115,7 +115,7 @@ # Testing and misc are at the top as they're updated most often MNE_DATASETS["testing"] = dict( archive_name=f"{TESTING_VERSIONED}.tar.gz", - hash="md5:7782a64f170b9435b0fd126862b0cf63", + hash="md5:bb0524db8605e96fde6333893a969766", url=( "https://codeload.github.com/mne-tools/mne-testing-data/" f"tar.gz/{RELEASES['testing']}" diff --git a/mne/io/artemis123/artemis123.py b/mne/io/artemis123/artemis123.py index 56534c57ca0..177b57c5db7 100644 --- a/mne/io/artemis123/artemis123.py +++ b/mne/io/artemis123/artemis123.py @@ -434,13 +434,13 @@ def __init__( ) # compute initial head to dev transform and hpi ordering - head_to_dev_t, order, trans_g = _fit_coil_order_dev_head_trans( + dev_head_t, order, trans_g = _fit_coil_order_dev_head_trans( hpi_dev, hpi_head ) # set the device to head transform self.info["dev_head_t"] = Transform( - FIFF.FIFFV_COORD_DEVICE, FIFF.FIFFV_COORD_HEAD, head_to_dev_t + FIFF.FIFFV_COORD_DEVICE, FIFF.FIFFV_COORD_HEAD, dev_head_t ) # add hpi_meg_dev to dig... diff --git a/mne/io/artemis123/tests/test_artemis123.py b/mne/io/artemis123/tests/test_artemis123.py index b39f35febb9..e40913d0ed9 100644 --- a/mne/io/artemis123/tests/test_artemis123.py +++ b/mne/io/artemis123/tests/test_artemis123.py @@ -12,7 +12,7 @@ from mne.io import read_raw_artemis123 from mne.io.artemis123.utils import _generate_mne_locs_file, _load_mne_locs from mne.io.tests.test_raw import _test_raw_reader -from mne.transforms import _angle_between_quats, rot_to_quat +from mne.utils._testing import assert_trans_allclose artemis123_dir = testing.data_path(download=False) / "ARTEMIS123" short_HPI_dip_fname = ( @@ -28,17 +28,8 @@ # (old or new) def _assert_trans(actual, desired, dist_tol=0.017, angle_tol=5.0): __tracebackhide__ = True - trans_est = actual[0:3, 3] - quat_est = rot_to_quat(actual[0:3, 0:3]) - trans = desired[0:3, 3] - quat = rot_to_quat(desired[0:3, 0:3]) - - angle = np.rad2deg(_angle_between_quats(quat_est, quat)) - dist = np.linalg.norm(trans - trans_est) - assert dist <= dist_tol, ( - f"{1000 * dist:0.3f} > {1000 * dist_tol:0.3f} mm translation" - ) - assert angle <= angle_tol, f"{angle:0.3f} > {angle_tol:0.3f}° rotation" + # To minimize diff, keep this trivial wrapper + assert_trans_allclose(actual, desired, dist_tol=dist_tol, angle_tol=angle_tol) @testing.requires_testing_data diff --git a/mne/io/fiff/raw.py b/mne/io/fiff/raw.py index f9e10a0039d..d4e1e40a7aa 100644 --- a/mne/io/fiff/raw.py +++ b/mne/io/fiff/raw.py @@ -356,11 +356,7 @@ def _read_raw_file( raw.orig_format = orig_format # Add the calibration factors - cals = np.zeros(info["nchan"]) - for k in range(info["nchan"]): - cals[k] = info["chs"][k]["range"] * info["chs"][k]["cal"] - - raw._cals = cals + raw._cals = info._cals raw._raw_extras = raw_extras logger.info( " Range : %d ... %d = %9.3f ... %9.3f secs", diff --git a/mne/preprocessing/tests/test_artifact_detection.py b/mne/preprocessing/tests/test_artifact_detection.py index 91ef4743f95..ccd8893ba11 100644 --- a/mne/preprocessing/tests/test_artifact_detection.py +++ b/mne/preprocessing/tests/test_artifact_detection.py @@ -17,7 +17,7 @@ compute_average_dev_head_t, ) from mne.tests.test_annotations import _assert_annotations_equal -from mne.transforms import _angle_dist_between_rigid, quat_to_rot, rot_to_quat +from mne.transforms import angle_distance_between_rigid, quat_to_rot, rot_to_quat data_path = testing.data_path(download=False) sss_path = data_path / "SSS" @@ -119,22 +119,22 @@ def test_movement_annotation_head_correction(meas_date): "trans" ] unit_kw = dict(distance_units="mm", angle_units="deg") - deg_annot_combo, mm_annot_combo = _angle_dist_between_rigid( + deg_annot_combo, mm_annot_combo = angle_distance_between_rigid( dev_head_t, dev_head_t_combo, **unit_kw, ) - deg_unannot_combo, mm_unannot_combo = _angle_dist_between_rigid( + deg_unannot_combo, mm_unannot_combo = angle_distance_between_rigid( dev_head_t_unannot, dev_head_t_combo, **unit_kw, ) - deg_annot_unannot, mm_annot_unannot = _angle_dist_between_rigid( + deg_annot_unannot, mm_annot_unannot = angle_distance_between_rigid( dev_head_t, dev_head_t_unannot, **unit_kw, ) - deg_combo_naive, mm_combo_naive = _angle_dist_between_rigid( + deg_combo_naive, mm_combo_naive = angle_distance_between_rigid( dev_head_t_combo, dev_head_t_naive, **unit_kw, diff --git a/mne/preprocessing/tests/test_fine_cal.py b/mne/preprocessing/tests/test_fine_cal.py index 26a51b0f472..449930bebc2 100644 --- a/mne/preprocessing/tests/test_fine_cal.py +++ b/mne/preprocessing/tests/test_fine_cal.py @@ -19,7 +19,7 @@ write_fine_calibration, ) from mne.preprocessing.tests.test_maxwell import _assert_shielding -from mne.transforms import _angle_dist_between_rigid +from mne.transforms import angle_distance_between_rigid from mne.utils import catch_logging, object_diff # Define fine calibration filepaths @@ -127,19 +127,19 @@ def test_compute_fine_cal(kind): orig_trans = _loc_to_coil_trans(orig_locs) want_trans = _loc_to_coil_trans(want_locs) got_trans = _loc_to_coil_trans(got_locs) - want_orig_angles, want_orig_dist = _angle_dist_between_rigid( + want_orig_angles, want_orig_dist = angle_distance_between_rigid( want_trans, orig_trans, angle_units="deg", distance_units="mm", ) - got_want_angles, got_want_dist = _angle_dist_between_rigid( + got_want_angles, got_want_dist = angle_distance_between_rigid( got_trans, want_trans, angle_units="deg", distance_units="mm", ) - got_orig_angles, got_orig_dist = _angle_dist_between_rigid( + got_orig_angles, got_orig_dist = angle_distance_between_rigid( got_trans, orig_trans, angle_units="deg", diff --git a/mne/tests/test_chpi.py b/mne/tests/test_chpi.py index 790e93992d7..0ba13f8c708 100644 --- a/mne/tests/test_chpi.py +++ b/mne/tests/test_chpi.py @@ -28,6 +28,7 @@ get_chpi_info, head_pos_to_trans_rot_t, read_head_pos, + refit_hpi, write_head_pos, ) from mne.datasets import testing @@ -41,7 +42,11 @@ read_raw_kit, ) from mne.simulation import add_chpi -from mne.transforms import _angle_between_quats, rot_to_quat +from mne.transforms import ( + _angle_between_quats, + angle_distance_between_rigid, + rot_to_quat, +) from mne.utils import ( _record_warnings, assert_meg_snr, @@ -49,6 +54,7 @@ object_diff, verbose, ) +from mne.utils._testing import assert_trans_allclose from mne.viz import plot_head_positions base_dir = Path(__file__).parents[1] / "io" / "tests" / "data" @@ -66,6 +72,7 @@ chpi5_pos_fname = data_path / "SSS" / "chpi5_raw_mc.pos" ctf_chpi_fname = data_path / "CTF" / "testdata_ctf_mc.ds" ctf_chpi_pos_fname = data_path / "CTF" / "testdata_ctf_mc.pos" +chpi_problem_fname = data_path / "SSS" / "chpi_problematic-info.fif" art_fname = ( data_path @@ -96,7 +103,7 @@ def test_chpi_adjust(): msg = [ "HPIFIT: 5 coils digitized in order 5 1 4 3 2", "HPIFIT: 3 coils accepted: 1 2 4", - "Hpi coil moments (3, 5):", + "HPI coil moments (3, 5):", "2.08542e-15 -1.52486e-15 -1.53484e-15", "2.14516e-15 2.09608e-15 7.30303e-16", "-3.2318e-16 -4.25666e-16 2.69997e-15", @@ -859,3 +866,148 @@ def test_get_active_chpi_neuromag(): get_active_chpi(raw_no_chpi, on_missing="ignore"), np.zeros_like(raw_no_chpi.times), ) + + +def assert_slopes_correlated(actual_meas, desired_meas, *, lim=(0.99, 1.0)): + """Assert that slopes in two coil info dicts are all close.""" + __tracebackhide__ = True + assert len(actual_meas["hpi_coils"]) == len(desired_meas["hpi_coils"]) + for ci, (c1, c2) in enumerate( + zip(actual_meas["hpi_coils"], desired_meas["hpi_coils"]) + ): + corr = np.abs(np.corrcoef(c1["slopes"].ravel(), c2["slopes"].ravel())[0, 1]) + assert lim[0] <= corr <= lim[1], f"meas['hpi_coils'][{ci}] corr: {corr}" + + +@pytest.mark.slowtest +@testing.requires_testing_data +def test_refit_hpi_locs_basic(): + """Test that HPI locations can be refit.""" + raw = read_raw_fif(chpi_fif_fname, allow_maxshield="yes").crop(0, 2).load_data() + # These should be similar (and both should work) + locs = compute_chpi_amplitudes(raw, t_step_min=2, t_window=1) + locs_2 = compute_chpi_amplitudes(raw, t_step_min=2, t_window=1, ext_order=0) + corr = np.corrcoef(locs["slopes"].ravel(), locs_2["slopes"].ravel())[0, 1] + assert 0.999 < corr < 1.0 + info = raw.info + del raw + + # Refit on these data won't change much + info_new = info.copy() + assert len(info["hpi_results"][-1]["used"]) == 3 + refit_hpi(info_new, amplitudes=False, locs=False, use=3) + assert len(info_new["hpi_results"]) == len(info["hpi_results"]) == 1 + assert len(info_new["hpi_meas"]) == len(info["hpi_meas"]) == 1 + assert_trans_allclose( + info_new["dev_head_t"], + info["dev_head_t"], + dist_tol=0.1e-3, + angle_tol=0.1, + ) + # Refit with more coils than hpifit (our default is use=None) + refit_hpi(info_new, amplitudes=False, locs=False) + assert len(info_new["hpi_results"][-1]["used"]) == 5 + assert_trans_allclose( + info_new["dev_head_t"], + info["dev_head_t"], + dist_tol=3e-3, + angle_tol=2, + ) + # Refit locations + refit_hpi(info_new, amplitudes=False) # default: locs=True + assert_trans_allclose( + info_new["dev_head_t"], + info["dev_head_t"], + dist_tol=2e-3, + angle_tol=2, + ) + assert_array_equal( + info_new["hpi_results"][-1]["order"], info["hpi_results"][-1]["order"] + ) + assert_slopes_correlated( + info_new["hpi_meas"][-1], + info["hpi_meas"][-1], + lim=(0.999999, 1.0), + ) + with pytest.raises(ValueError, match="must also be True"): + refit_hpi(info_new, locs=False) + # Refit locations and amplitudes (with ext_order=0 just to make sure it works) + refit_hpi(info_new, ext_order=0) + assert_trans_allclose( + info_new["dev_head_t"], + info["dev_head_t"], + dist_tol=2e-3, + angle_tol=2, + ) + assert_array_equal( + info_new["hpi_results"][-1]["order"], info["hpi_results"][-1]["order"] + ) + assert_slopes_correlated( + info_new["hpi_meas"][-1], info["hpi_meas"][-1], lim=(0.99, 0.999999) + ) + + +@testing.requires_testing_data +def test_refit_hpi_locs_problematic(): + """Test that we can fix problematic HPI fits.""" + info_bad = read_info(chpi_problem_fname) + ang, dist = angle_distance_between_rigid( + info_bad["dev_head_t"]["trans"], angle_units="deg", distance_units="mm" + ) + assert_allclose(ang, 177, atol=1) # upside-down! + assert_allclose(dist, 61, atol=1) + orig_order = [4, 2, 1, 3, 5] + good_order = [1, 2, 4, 3, 5] + assert_array_equal(info_bad["hpi_results"][-1]["order"], orig_order) + orig_use = info_bad["hpi_results"][-1]["used"] + assert_array_equal(orig_use, [2, 3, 5]) + with pytest.warns(RuntimeWarning, match="colinear"): + info_new = refit_hpi( + info_bad.copy(), + amplitudes=False, + locs=False, + order=False, + use=orig_use - 1, + dist_limit=np.inf, + colinearity_limit=0.03, + ) + assert_array_equal(info_new["hpi_results"][-1]["order"], orig_order) + assert_array_equal(info_new["hpi_results"][-1]["used"], orig_use) + assert_trans_allclose( + info_new["dev_head_t"], + info_bad["dev_head_t"], + dist_tol=1e-3, + angle_tol=1, + ) + # Even just allowing our permutation checker to run helps + with pytest.raises(RuntimeError, match="need at least 3"): + refit_hpi(info_bad.copy(), amplitudes=False, locs=False, order=False) + info_new = refit_hpi( + info_bad.copy(), + amplitudes=False, + locs=False, + dist_limit=0.02, + colinearity_limit=0.03, + ) + assert_array_equal(info_new["hpi_results"][-1]["order"], good_order) + ang, dist = angle_distance_between_rigid( + info_new["dev_head_t"]["trans"], + angle_units="deg", + distance_units="mm", + ) + assert 10 < ang < 15 # much more upright! + assert 75 < dist < 80 + with pytest.warns(RuntimeWarning, match="Discrepancy"): + # We can run this with amplitudes=True, but it's much faster not to + # (and the result is very similar) + info_new = refit_hpi( + info_bad.copy(), amplitudes=False, dist_limit=0.01, colinearity_limit=0.03 + ) + assert_array_equal(info_new["hpi_results"][-1]["order"], good_order) + ang, dist = angle_distance_between_rigid( + info_new["dev_head_t"]["trans"], + angle_units="deg", + distance_units="mm", + ) + assert 3 < ang < 6 + assert 82 < dist < 87 diff --git a/mne/tests/test_transforms.py b/mne/tests/test_transforms.py index 05cf3e6b90f..69f8160f3c1 100644 --- a/mne/tests/test_transforms.py +++ b/mne/tests/test_transforms.py @@ -41,6 +41,7 @@ _topo_to_sph, _validate_pipeline, _write_fs_xfm, + angle_distance_between_rigid, apply_trans, combine_transforms, get_ras_to_neuromag_trans, @@ -563,6 +564,12 @@ def test_fit_matched_points(quats, scaling, do_scale): dist = np.linalg.norm(est[3:] - translation) assert_array_less(dist_bounds[0], dist) assert_array_less(dist, dist_bounds[1]) + # check our public function as well + a = _quat_to_affine(est) + b = _quat_to_affine(np.r_[quat, translation]) + angle_, dist_ = angle_distance_between_rigid(a, b, angle_units="deg") + assert_allclose(angle, angle_) + assert_allclose(dist, dist_) def test_euler(quats): diff --git a/mne/transforms.py b/mne/transforms.py index 1c7f26ab1e0..76289504785 100644 --- a/mne/transforms.py +++ b/mne/transforms.py @@ -1354,24 +1354,50 @@ def rot_to_quat(rot): def _quat_to_affine(quat): - assert quat.shape == (6,) + assert quat.shape == (6,), quat.shape affine = np.eye(4) affine[:3, :3] = quat_to_rot(quat[:3]) affine[:3, 3] = quat[3:] return affine -def _affine_to_quat(affine): - assert affine.shape[-2:] == (4, 4) +def _affine_to_quat(affine, *, name="affine"): + _validate_type(affine, np.ndarray, name) + if affine.shape[-2:] != (4, 4): + raise ValueError(f"{name} must be of shape (..., 4, 4), got {affine.shape}") return np.concatenate( [rot_to_quat(affine[..., :3, :3]), affine[..., :3, 3]], axis=-1, ) -def _angle_dist_between_rigid(a, b=None, *, angle_units="rad", distance_units="m"): - a = _affine_to_quat(a) - b = np.zeros(6) if b is None else _affine_to_quat(b) +def angle_distance_between_rigid(a, b=None, *, angle_units="rad", distance_units="m"): + """Compute the angle and distance between two rigid transforms. + + Parameters + ---------- + a : array, shape (..., 4, 4) + First rigid transform. + b : array, shape (..., 4, 4) | None + Second rigid transform. If None, the identity transform is used. + angle_units : str + Units for the angle output, either "rad" or "deg". + distance_units : str + Units for the distance output, either "m" or "mm". + + Returns + ------- + angles : array, shape (...) + The angles between the two transforms. + distances : array, shape (...) + The distances between the two transforms. + + Notes + ----- + .. versionadded:: 1.11 + """ + a = _affine_to_quat(a, name="a") + b = np.zeros(6) if b is None else _affine_to_quat(b, name="b") ang = _angle_between_quats(a[..., :3], b[..., :3]) dist = np.linalg.norm(a[..., 3:] - b[..., 3:], axis=-1) assert isinstance(angle_units, str) and angle_units in ("rad", "deg") @@ -1877,7 +1903,9 @@ def _compute_volume_registration( # report some useful information if step in ("translation", "rigid"): - angle, dist = _angle_dist_between_rigid(reg_affine, angle_units="deg") + angle, dist = angle_distance_between_rigid( + reg_affine, angle_units="deg" + ) logger.info(f" Translation: {dist:6.1f} mm") if step == "rigid": logger.info(f" Rotation: {angle:6.1f}°") diff --git a/mne/utils/_testing.py b/mne/utils/_testing.py index 63e0d1036b9..189d6877134 100644 --- a/mne/utils/_testing.py +++ b/mne/utils/_testing.py @@ -169,18 +169,24 @@ def dec(*args, **kwargs): def assert_and_remove_boundary_annot(annotations, n=1): """Assert that there are boundary annotations and remove them.""" + __tracebackhide__ = True + from ..io import BaseRaw if isinstance(annotations, BaseRaw): # allow either input annotations = annotations.annotations for key in ("EDGE", "BAD"): idx = np.where(annotations.description == f"{key} boundary")[0] - assert len(idx) == n + assert len(idx) == n, ( + f"Got {len(idx)} '{key} boundary' annotations, expected {n}" + ) annotations.delete(idx) def assert_object_equal(a, b, *, err_msg="Object mismatch", allclose=False): """Assert two objects are equal.""" + __tracebackhide__ = True + d = object_diff(a, b, allclose=allclose) assert d == "", f"{err_msg}\n{d}" @@ -214,6 +220,8 @@ def _get_data(x, ch_idx): def _check_snr(actual, desired, picks, min_tol, med_tol, msg, kind="MEG"): """Check the SNR of a set of channels.""" + __tracebackhide__ = True + actual_data = _get_data(actual, picks) desired_data = _get_data(desired, picks) bench_rms = np.sqrt(np.mean(desired_data * desired_data, axis=1)) @@ -242,6 +250,8 @@ def assert_meg_snr( Mostly useful for operations like Maxwell filtering that modify MEG channels while leaving EEG and others intact. """ + __tracebackhide__ = True + from .._fiff.pick import pick_types picks = pick_types(desired.info, meg=True, exclude=[]) @@ -273,19 +283,31 @@ def assert_meg_snr( def assert_snr(actual, desired, tol): """Assert actual and desired arrays are within some SNR tolerance.""" + __tracebackhide__ = True + with np.errstate(divide="ignore"): # allow infinite snr = linalg.norm(desired, ord="fro") / linalg.norm(desired - actual, ord="fro") - assert snr >= tol, f"{snr} < {tol}" + assert snr >= tol, f"{snr=} < {tol=}" def assert_stcs_equal(stc1, stc2): """Check that two STC are equal.""" - assert_allclose(stc1.times, stc2.times) - assert_allclose(stc1.data, stc2.data) - assert_array_equal(stc1.vertices[0], stc2.vertices[0]) - assert_array_equal(stc1.vertices[1], stc2.vertices[1]) - assert_allclose(stc1.tmin, stc2.tmin) - assert_allclose(stc1.tstep, stc2.tstep) + __tracebackhide__ = True + + assert_allclose(stc1.times, stc2.times, err_msg="Times mismatch") + assert_allclose(stc1.data, stc2.data, err_msg="Data mismatch") + assert_array_equal( + stc1.vertices[0], + stc2.vertices[0], + err_msg="Left vertices mismatch", + ) + assert_array_equal( + stc1.vertices[1], + stc2.vertices[1], + err_msg="Right vertices mismatch", + ) + assert_allclose(stc1.tmin, stc2.tmin, err_msg="tmin mismatch") + assert_allclose(stc1.tstep, stc2.tstep, err_msg="tstep mismatch") def _dig_sort_key(dig): @@ -295,6 +317,8 @@ def _dig_sort_key(dig): def assert_dig_allclose(info_py, info_bin, limit=None): """Assert dig allclose.""" + __tracebackhide__ = True + from .._fiff.constants import FIFF from .._fiff.meas_info import Info from ..bem import fit_sphere_to_headshape @@ -303,21 +327,21 @@ def assert_dig_allclose(info_py, info_bin, limit=None): # test dig positions dig_py, dig_bin = info_py, info_bin if isinstance(dig_py, Info): - assert isinstance(dig_bin, Info) + assert isinstance(dig_bin, Info), "Both must be Info or DigMontage" dig_py, dig_bin = dig_py["dig"], dig_bin["dig"] else: - assert isinstance(dig_bin, DigMontage) - assert isinstance(dig_py, DigMontage) + assert isinstance(dig_bin, DigMontage), "Both must be Info or DigMontage" + assert isinstance(dig_py, DigMontage), "Both must be Info or DigMontage" dig_py, dig_bin = dig_py.dig, dig_bin.dig info_py = info_bin = None - assert isinstance(dig_py, list) - assert isinstance(dig_bin, list) + assert isinstance(dig_py, list), "dig_py must be a list" + assert isinstance(dig_bin, list), "dig_bin must be a list" dig_py = sorted(dig_py, key=_dig_sort_key) dig_bin = sorted(dig_bin, key=_dig_sort_key) - assert len(dig_py) == len(dig_bin) + assert len(dig_py) == len(dig_bin), "Different number of dig points" for ii, (d_py, d_bin) in enumerate(zip(dig_py[:limit], dig_bin[:limit])): for key in ("ident", "kind", "coord_frame"): - assert d_py[key] == d_bin[key], key + assert d_py[key] == d_bin[key], f"{key=} mismatch on point {ii}" assert_allclose( d_py["r"], d_bin["r"], @@ -332,9 +356,21 @@ def assert_dig_allclose(info_py, info_bin, limit=None): r_py, o_head_py, o_dev_py = fit_sphere_to_headshape( info_py, units="m", verbose="error" ) - assert_allclose(r_py, r_bin, atol=1e-6) - assert_allclose(o_dev_py, o_dev_bin, rtol=1e-5, atol=1e-6) - assert_allclose(o_head_py, o_head_bin, rtol=1e-5, atol=1e-6) + assert_allclose(r_py, r_bin, atol=1e-6, err_msg="Sphere radius mismatch") + assert_allclose( + o_dev_py, + o_dev_bin, + rtol=1e-5, + atol=1e-6, + err_msg="Sphere device origin mismatch", + ) + assert_allclose( + o_head_py, + o_head_bin, + rtol=1e-5, + atol=1e-6, + err_msg="Sphere origin mismatch", + ) def _click_ch_name(fig, ch_index=0, button=1): @@ -357,3 +393,27 @@ def _get_suptitle(fig): else: # unreliable hack; should work in most tests as we rarely use `sup_{x,y}label` return fig.texts[0].get_text() + + +def assert_trans_allclose(actual, desired, dist_tol=0.0, angle_tol=0.0): + __tracebackhide__ = True + + from ..transforms import Transform, angle_distance_between_rigid + + if isinstance(actual, Transform): + assert isinstance(desired, Transform), "Both must be Transform or ndarray" + assert actual["from"] == desired["from"], "'from' frame mismatch" + assert actual["to"] == desired["to"], "'to' frame mismatch" + actual = actual["trans"] + desired = desired["trans"] + assert isinstance(actual, np.ndarray), "actual should be ndarray" + assert isinstance(desired, np.ndarray), "desired should be ndarray" + assert actual.shape == (4, 4), "actual.shape should be (4, 4)" + assert desired.shape == (4, 4), "desired.shape should be (4, 4)" + angle, dist = angle_distance_between_rigid( + actual, desired, angle_units="deg", distance_units="m" + ) + assert dist <= dist_tol, ( + f"{1000 * dist:0.3f} > {1000 * dist_tol:0.3f} mm translation" + ) + assert angle <= angle_tol, f"{angle:0.3f} > {angle_tol:0.3f}° rotation" diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index f844d9b54e5..cc660fe4986 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -51,13 +51,13 @@ from ..transforms import ( Transform, _angle_between_quats, - _angle_dist_between_rigid, _ensure_trans, _find_trans, _frame_to_str, _get_trans, _get_transforms_to_coord_frame, _print_coord_trans, + angle_distance_between_rigid, apply_trans, combine_transforms, read_ras_mni_t, @@ -328,7 +328,7 @@ def plot_head_positions( for ax, val in zip(axes[:3].ravel(), vals): ax.axhline(val, color="r", ls=":", zorder=2, lw=1.0) if totals: - dest_ang, dest_dist = _angle_dist_between_rigid( + dest_ang, dest_dist = angle_distance_between_rigid( destination, angle_units="deg", distance_units="mm",