Skip to content
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/api/preprocessing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ EEG referencing:
get_chpi_info
head_pos_to_trans_rot_t
read_head_pos
refit_hpi
write_head_pos

:py:mod:`mne.transforms`
Expand All @@ -235,6 +236,7 @@ EEG referencing:
:toctree: ../generated/

Transform
angle_distance_between_rigid
quat_to_rot
rot_to_quat
read_ras_mni_t
1 change: 1 addition & 0 deletions doc/changes/dev/13484.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add ability to refit HPI order and device-to-head transform via :func:`mne.chpi.refit_hpi` and compute distances between transforms with :func:`mne.transforms.angle_distance_between_rigid` by `Eric Larson`_.
4 changes: 4 additions & 0 deletions mne/_fiff/meas_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -1946,6 +1946,10 @@ def ch_names(self):

return ch_names

@property
def _cals(self):
return np.array([ch["range"] * ch["cal"] for ch in self["chs"]], float)

@repr_html
def _repr_html_(self):
"""Summarize info for HTML representation."""
Expand Down
349 changes: 333 additions & 16 deletions mne/chpi.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions mne/datasets/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
# update the checksum in the MNE_DATASETS dict below, and change version
# here: ↓↓↓↓↓↓↓↓
RELEASES = dict(
testing="0.168",
testing="0.169",
misc="0.27",
phantom_kit="0.2",
ucl_opm_auditory="0.2",
Expand Down Expand Up @@ -115,7 +115,7 @@
# Testing and misc are at the top as they're updated most often
MNE_DATASETS["testing"] = dict(
archive_name=f"{TESTING_VERSIONED}.tar.gz",
hash="md5:7782a64f170b9435b0fd126862b0cf63",
hash="md5:bb0524db8605e96fde6333893a969766",
url=(
"https://codeload.github.com/mne-tools/mne-testing-data/"
f"tar.gz/{RELEASES['testing']}"
Expand Down
4 changes: 2 additions & 2 deletions mne/io/artemis123/artemis123.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,13 +434,13 @@ def __init__(
)

# compute initial head to dev transform and hpi ordering
head_to_dev_t, order, trans_g = _fit_coil_order_dev_head_trans(
dev_head_t, order, trans_g = _fit_coil_order_dev_head_trans(
hpi_dev, hpi_head
)

# set the device to head transform
self.info["dev_head_t"] = Transform(
FIFF.FIFFV_COORD_DEVICE, FIFF.FIFFV_COORD_HEAD, head_to_dev_t
FIFF.FIFFV_COORD_DEVICE, FIFF.FIFFV_COORD_HEAD, dev_head_t
)

# add hpi_meg_dev to dig...
Expand Down
15 changes: 3 additions & 12 deletions mne/io/artemis123/tests/test_artemis123.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from mne.io import read_raw_artemis123
from mne.io.artemis123.utils import _generate_mne_locs_file, _load_mne_locs
from mne.io.tests.test_raw import _test_raw_reader
from mne.transforms import _angle_between_quats, rot_to_quat
from mne.utils._testing import assert_trans_allclose

artemis123_dir = testing.data_path(download=False) / "ARTEMIS123"
short_HPI_dip_fname = (
Expand All @@ -28,17 +28,8 @@
# (old or new)
def _assert_trans(actual, desired, dist_tol=0.017, angle_tol=5.0):
__tracebackhide__ = True
trans_est = actual[0:3, 3]
quat_est = rot_to_quat(actual[0:3, 0:3])
trans = desired[0:3, 3]
quat = rot_to_quat(desired[0:3, 0:3])

angle = np.rad2deg(_angle_between_quats(quat_est, quat))
dist = np.linalg.norm(trans - trans_est)
assert dist <= dist_tol, (
f"{1000 * dist:0.3f} > {1000 * dist_tol:0.3f} mm translation"
)
assert angle <= angle_tol, f"{angle:0.3f} > {angle_tol:0.3f}° rotation"
# To minimize diff, keep this trivial wrapper
assert_trans_allclose(actual, desired, dist_tol=dist_tol, angle_tol=angle_tol)


@testing.requires_testing_data
Expand Down
6 changes: 1 addition & 5 deletions mne/io/fiff/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,11 +356,7 @@ def _read_raw_file(
raw.orig_format = orig_format

# Add the calibration factors
cals = np.zeros(info["nchan"])
for k in range(info["nchan"]):
cals[k] = info["chs"][k]["range"] * info["chs"][k]["cal"]

raw._cals = cals
raw._cals = info._cals
raw._raw_extras = raw_extras
logger.info(
" Range : %d ... %d = %9.3f ... %9.3f secs",
Expand Down
10 changes: 5 additions & 5 deletions mne/preprocessing/tests/test_artifact_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
compute_average_dev_head_t,
)
from mne.tests.test_annotations import _assert_annotations_equal
from mne.transforms import _angle_dist_between_rigid, quat_to_rot, rot_to_quat
from mne.transforms import angle_distance_between_rigid, quat_to_rot, rot_to_quat

data_path = testing.data_path(download=False)
sss_path = data_path / "SSS"
Expand Down Expand Up @@ -119,22 +119,22 @@ def test_movement_annotation_head_correction(meas_date):
"trans"
]
unit_kw = dict(distance_units="mm", angle_units="deg")
deg_annot_combo, mm_annot_combo = _angle_dist_between_rigid(
deg_annot_combo, mm_annot_combo = angle_distance_between_rigid(
dev_head_t,
dev_head_t_combo,
**unit_kw,
)
deg_unannot_combo, mm_unannot_combo = _angle_dist_between_rigid(
deg_unannot_combo, mm_unannot_combo = angle_distance_between_rigid(
dev_head_t_unannot,
dev_head_t_combo,
**unit_kw,
)
deg_annot_unannot, mm_annot_unannot = _angle_dist_between_rigid(
deg_annot_unannot, mm_annot_unannot = angle_distance_between_rigid(
dev_head_t,
dev_head_t_unannot,
**unit_kw,
)
deg_combo_naive, mm_combo_naive = _angle_dist_between_rigid(
deg_combo_naive, mm_combo_naive = angle_distance_between_rigid(
dev_head_t_combo,
dev_head_t_naive,
**unit_kw,
Expand Down
8 changes: 4 additions & 4 deletions mne/preprocessing/tests/test_fine_cal.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
write_fine_calibration,
)
from mne.preprocessing.tests.test_maxwell import _assert_shielding
from mne.transforms import _angle_dist_between_rigid
from mne.transforms import angle_distance_between_rigid
from mne.utils import catch_logging, object_diff

# Define fine calibration filepaths
Expand Down Expand Up @@ -127,19 +127,19 @@ def test_compute_fine_cal(kind):
orig_trans = _loc_to_coil_trans(orig_locs)
want_trans = _loc_to_coil_trans(want_locs)
got_trans = _loc_to_coil_trans(got_locs)
want_orig_angles, want_orig_dist = _angle_dist_between_rigid(
want_orig_angles, want_orig_dist = angle_distance_between_rigid(
want_trans,
orig_trans,
angle_units="deg",
distance_units="mm",
)
got_want_angles, got_want_dist = _angle_dist_between_rigid(
got_want_angles, got_want_dist = angle_distance_between_rigid(
got_trans,
want_trans,
angle_units="deg",
distance_units="mm",
)
got_orig_angles, got_orig_dist = _angle_dist_between_rigid(
got_orig_angles, got_orig_dist = angle_distance_between_rigid(
got_trans,
orig_trans,
angle_units="deg",
Expand Down
156 changes: 154 additions & 2 deletions mne/tests/test_chpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
get_chpi_info,
head_pos_to_trans_rot_t,
read_head_pos,
refit_hpi,
write_head_pos,
)
from mne.datasets import testing
Expand All @@ -41,14 +42,19 @@
read_raw_kit,
)
from mne.simulation import add_chpi
from mne.transforms import _angle_between_quats, rot_to_quat
from mne.transforms import (
_angle_between_quats,
angle_distance_between_rigid,
rot_to_quat,
)
from mne.utils import (
_record_warnings,
assert_meg_snr,
catch_logging,
object_diff,
verbose,
)
from mne.utils._testing import assert_trans_allclose
from mne.viz import plot_head_positions

base_dir = Path(__file__).parents[1] / "io" / "tests" / "data"
Expand All @@ -66,6 +72,7 @@
chpi5_pos_fname = data_path / "SSS" / "chpi5_raw_mc.pos"
ctf_chpi_fname = data_path / "CTF" / "testdata_ctf_mc.ds"
ctf_chpi_pos_fname = data_path / "CTF" / "testdata_ctf_mc.pos"
chpi_problem_fname = data_path / "SSS" / "chpi_problematic-info.fif"

art_fname = (
data_path
Expand Down Expand Up @@ -96,7 +103,7 @@ def test_chpi_adjust():
msg = [
"HPIFIT: 5 coils digitized in order 5 1 4 3 2",
"HPIFIT: 3 coils accepted: 1 2 4",
"Hpi coil moments (3, 5):",
"HPI coil moments (3, 5):",
"2.08542e-15 -1.52486e-15 -1.53484e-15",
"2.14516e-15 2.09608e-15 7.30303e-16",
"-3.2318e-16 -4.25666e-16 2.69997e-15",
Expand Down Expand Up @@ -859,3 +866,148 @@ def test_get_active_chpi_neuromag():
get_active_chpi(raw_no_chpi, on_missing="ignore"),
np.zeros_like(raw_no_chpi.times),
)


def assert_slopes_correlated(actual_meas, desired_meas, *, lim=(0.99, 1.0)):
"""Assert that slopes in two coil info dicts are all close."""
__tracebackhide__ = True
assert len(actual_meas["hpi_coils"]) == len(desired_meas["hpi_coils"])
for ci, (c1, c2) in enumerate(
zip(actual_meas["hpi_coils"], desired_meas["hpi_coils"])
):
corr = np.abs(np.corrcoef(c1["slopes"].ravel(), c2["slopes"].ravel())[0, 1])
assert lim[0] <= corr <= lim[1], f"meas['hpi_coils'][{ci}] corr: {corr}"


@pytest.mark.slowtest
@testing.requires_testing_data
def test_refit_hpi_locs_basic():
"""Test that HPI locations can be refit."""
raw = read_raw_fif(chpi_fif_fname, allow_maxshield="yes").crop(0, 2).load_data()
# These should be similar (and both should work)
locs = compute_chpi_amplitudes(raw, t_step_min=2, t_window=1)
locs_2 = compute_chpi_amplitudes(raw, t_step_min=2, t_window=1, ext_order=0)
corr = np.corrcoef(locs["slopes"].ravel(), locs_2["slopes"].ravel())[0, 1]
assert 0.999 < corr < 1.0
info = raw.info
del raw

# Refit on these data won't change much
info_new = info.copy()
assert len(info["hpi_results"][-1]["used"]) == 3
refit_hpi(info_new, amplitudes=False, locs=False, use=3)
assert len(info_new["hpi_results"]) == len(info["hpi_results"]) == 1
assert len(info_new["hpi_meas"]) == len(info["hpi_meas"]) == 1
assert_trans_allclose(
info_new["dev_head_t"],
info["dev_head_t"],
dist_tol=0.1e-3,
angle_tol=0.1,
)
# Refit with more coils than hpifit (our default is use=None)
refit_hpi(info_new, amplitudes=False, locs=False)
assert len(info_new["hpi_results"][-1]["used"]) == 5
assert_trans_allclose(
info_new["dev_head_t"],
info["dev_head_t"],
dist_tol=3e-3,
angle_tol=2,
)
# Refit locations
refit_hpi(info_new, amplitudes=False) # default: locs=True
assert_trans_allclose(
info_new["dev_head_t"],
info["dev_head_t"],
dist_tol=2e-3,
angle_tol=2,
)
assert_array_equal(
info_new["hpi_results"][-1]["order"], info["hpi_results"][-1]["order"]
)
assert_slopes_correlated(
info_new["hpi_meas"][-1],
info["hpi_meas"][-1],
lim=(0.999999, 1.0),
)
with pytest.raises(ValueError, match="must also be True"):
refit_hpi(info_new, locs=False)
# Refit locations and amplitudes (with ext_order=0 just to make sure it works)
refit_hpi(info_new, ext_order=0)
assert_trans_allclose(
info_new["dev_head_t"],
info["dev_head_t"],
dist_tol=2e-3,
angle_tol=2,
)
assert_array_equal(
info_new["hpi_results"][-1]["order"], info["hpi_results"][-1]["order"]
)
assert_slopes_correlated(
info_new["hpi_meas"][-1], info["hpi_meas"][-1], lim=(0.99, 0.999999)
)


@testing.requires_testing_data
def test_refit_hpi_locs_problematic():
"""Test that we can fix problematic HPI fits."""
info_bad = read_info(chpi_problem_fname)
ang, dist = angle_distance_between_rigid(
info_bad["dev_head_t"]["trans"], angle_units="deg", distance_units="mm"
)
assert_allclose(ang, 177, atol=1) # upside-down!
assert_allclose(dist, 61, atol=1)
orig_order = [4, 2, 1, 3, 5]
good_order = [1, 2, 4, 3, 5]
assert_array_equal(info_bad["hpi_results"][-1]["order"], orig_order)
orig_use = info_bad["hpi_results"][-1]["used"]
assert_array_equal(orig_use, [2, 3, 5])
with pytest.warns(RuntimeWarning, match="colinear"):
info_new = refit_hpi(
info_bad.copy(),
amplitudes=False,
locs=False,
order=False,
use=orig_use - 1,
dist_limit=np.inf,
linearity_limit=0.03,
)
assert_array_equal(info_new["hpi_results"][-1]["order"], orig_order)
assert_array_equal(info_new["hpi_results"][-1]["used"], orig_use)
assert_trans_allclose(
info_new["dev_head_t"],
info_bad["dev_head_t"],
dist_tol=1e-3,
angle_tol=1,
)
# Even just allowing our permutation checker to run helps
with pytest.raises(RuntimeError, match="need at least 3"):
refit_hpi(info_bad.copy(), amplitudes=False, locs=False, order=False)
info_new = refit_hpi(
info_bad.copy(),
amplitudes=False,
locs=False,
dist_limit=0.02,
linearity_limit=0.03,
)
assert_array_equal(info_new["hpi_results"][-1]["order"], good_order)
ang, dist = angle_distance_between_rigid(
info_new["dev_head_t"]["trans"],
angle_units="deg",
distance_units="mm",
)
assert 10 < ang < 15 # much more upright!
assert 75 < dist < 80
with pytest.warns(RuntimeWarning, match="Discrepancy"):
# We can run this with amplitudes=True, but it's much faster not to
# (and the result is very similar)
info_new = refit_hpi(
info_bad.copy(), amplitudes=False, dist_limit=0.01, linearity_limit=0.03
)
assert_array_equal(info_new["hpi_results"][-1]["order"], good_order)
ang, dist = angle_distance_between_rigid(
info_new["dev_head_t"]["trans"],
angle_units="deg",
distance_units="mm",
)
assert 3 < ang < 6
assert 82 < dist < 87
7 changes: 7 additions & 0 deletions mne/tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
_topo_to_sph,
_validate_pipeline,
_write_fs_xfm,
angle_distance_between_rigid,
apply_trans,
combine_transforms,
get_ras_to_neuromag_trans,
Expand Down Expand Up @@ -563,6 +564,12 @@ def test_fit_matched_points(quats, scaling, do_scale):
dist = np.linalg.norm(est[3:] - translation)
assert_array_less(dist_bounds[0], dist)
assert_array_less(dist, dist_bounds[1])
# check our public function as well
a = _quat_to_affine(est)
b = _quat_to_affine(np.r_[quat, translation])
angle_, dist_ = angle_distance_between_rigid(a, b, angle_units="deg")
assert_allclose(angle, angle_)
assert_allclose(dist, dist_)


def test_euler(quats):
Expand Down
Loading