Skip to content

Commit

Permalink
Add possibility to show an stc with the fieldmap
Browse files Browse the repository at this point in the history
  • Loading branch information
wmvanvliet committed Feb 4, 2025
1 parent d878f60 commit b1ee63a
Showing 1 changed file with 47 additions and 26 deletions.
73 changes: 47 additions & 26 deletions mne/gui/_xfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,11 @@
from ..forward import convert_forward_solution, make_field_map, make_forward_dipole
from ..minimum_norm import apply_inverse, make_inverse_operator
from ..surface import _normal_orth
from ..transforms import (
_get_trans,
_get_transforms_to_coord_frame,
transform_surface_to,
)
from ..transforms import _get_trans, _get_transforms_to_coord_frame, apply_trans
from ..utils import _check_option, fill_doc, logger, verbose
from ..viz import EvokedField, create_3d_figure
from ..viz._3d import _plot_head_surface, _plot_sensors_3d
from ..viz.ui_events import subscribe
from ..viz.ui_events import link, subscribe
from ..viz.utils import _get_color_list


Expand All @@ -41,7 +37,7 @@ def dipolefit(
bem=None,
initial_time=None,
trans=None,
rank="info",
rank=None,
show_density=True,
subject=None,
subjects_dir=None,
Expand All @@ -60,11 +56,13 @@ def dipolefit(
Boundary element model to use in forward calculations. If ``None``, a spherical
model is used.
initial_time : float | None
Initial time point to show. If None, the time point of the maximum field
Initial time point to show. If ``None``, the time point of the maximum field
strength is used.
trans : instance of Transform | None
The transformation from head coordinates to MRI coordinates. If ``None``,
the identity matrix is used.
The transformation from head coordinates to MRI coordinates. If ``None``, the
identity matrix is used.
stc : instance of SourceEstimate | None
An optional distributed source estimate to show alongside the fieldmap.
%(rank)s
show_density : bool
Whether to show the density of the fieldmap.
Expand All @@ -80,6 +78,7 @@ def dipolefit(
bem=bem,
initial_time=initial_time,
trans=trans,
stc=None,
rank=rank,
show_density=show_density,
subject=subject,
Expand Down Expand Up @@ -108,6 +107,8 @@ class DipoleFitUI:
trans : instance of Transform | None
The transformation from head coordinates to MRI coordinates. If ``None``,
the identity matrix is used.
stc : instance of SourceEstimate | None
An optional distributed source estimate to show alongside the fieldmap.
%(rank)s
show_density : bool
Whether to show the density of the fieldmap.
Expand All @@ -125,6 +126,7 @@ def __init__(
bem=None,
initial_time=None,
trans=None,
stc=None,
rank="info",
show_density=True,
subject=None,
Expand Down Expand Up @@ -154,19 +156,12 @@ def __init__(
data = evoked.copy().pick(field_map[0]["ch_names"]).data
initial_time = evoked.times[np.argmax(np.mean(data**2, axis=0))]

# Get transforms to convert all the various meshes to head space.
# Get transforms to convert all the various meshes to MRI space.
head_mri_t = _get_trans(trans, "head", "mri")[0]
to_cf_t = _get_transforms_to_coord_frame(
evoked.info, head_mri_t, coord_frame="head"
evoked.info, head_mri_t, coord_frame="mri"
)

# Transform the fieldmap surfaces to head space if needed.
if trans is not None:
for fm in field_map:
fm["surf"] = transform_surface_to(
fm["surf"], "head", [to_cf_t["mri"], to_cf_t["head"]], copy=False
)

# Initialize all the private attributes.
self._actors = dict()
self._bem = bem
Expand All @@ -179,11 +174,12 @@ def __init__(
self._fig_sensors = None
self._multi_dipole_method = "Multi dipole (MNE)"
self._show_density = show_density
self._stc = stc
self._subjects_dir = subjects_dir
self._subject = subject
self._time_line = None
self._head_mri_t = head_mri_t
self._to_cf_t = to_cf_t
self._trans = trans
self._rank = rank
self._verbose = verbose

Expand All @@ -199,6 +195,22 @@ def dipoles(self):
def _configure_main_display(self):
"""Configure main 3D display of the GUI."""
fig = create_3d_figure((1500, 1020), bgcolor="white", show=True)

self._fig_stc = None
if self._stc is not None:
self._fig_stc = self._stc.plot(
subject=self._subject,
subjects_dir=self._subjects_dir,
surface="white",
hemi="both",
time_viewer=False,
initial_time=self._current_time,
brain_kwargs=dict(units="m"),
figure=fig,
)
fig = self._fig_stc
self._actors["brain"] = fig._actors["data"]

fig = EvokedField(
self._evoked,
self._field_map,
Expand All @@ -216,6 +228,9 @@ def _configure_main_display(self):
focalpoint=fit_sphere_to_headshape(self._evoked.info)[1]
)

if self._stc is not None:
link(self._fig_stc, fig)

for surf_map in fig._surf_maps:
if surf_map["map_kind"] == "meg":
helmet_mesh = surf_map["mesh"]
Expand Down Expand Up @@ -260,7 +275,7 @@ def _configure_main_display(self):
subject=self._subject,
subjects_dir=self._subjects_dir,
bem=self._bem,
coord_frame="head",
coord_frame="mri",
to_cf_t=self._to_cf_t,
alpha=0.2,
)
Expand Down Expand Up @@ -405,13 +420,16 @@ def _on_fit_dipole(self):
evoked_picked.info.normalize_proj()
cov_picked = cov_picked.pick_channels(picks, ordered=False)
cov_picked["projs"] = evoked_picked.info["projs"]
# Do we need to set the rank?
# for k, v in self._rank.items():
# self._rank[k] = min(v, len(cov_picked.ch_names))
evoked_picked.crop(self._current_time, self._current_time)

dip = fit_dipole(
evoked_picked,
cov_picked,
self._bem,
trans=self._trans,
trans=self._head_mri_t,
rank=self._rank,
verbose=False,
)[0]
Expand Down Expand Up @@ -512,7 +530,10 @@ def add_dipole(self, dipole):
for dipole_dict in new_dipoles:
dip = dipole_dict["dip"]
dipole_dict["brain_arrow_actor"] = self._renderer.plotter.add_arrows(
dip.pos[0], dip.ori[0], color=dipole_dict["color"], mag=0.05
apply_trans(self._head_mri_t, dip.pos[0]),
dip.ori[0],
color=dipole_dict["color"],
mag=0.05,
)
if arrow_mesh is not None:
dipole_dict["helmet_arrow_actor"] = self._renderer.plotter.add_mesh(
Expand All @@ -532,7 +553,7 @@ def _get_helmet_coords(self, dip):
return None, None

# Get the closest vertex (=point) of the helmet mesh
dip_pos = dip.pos[0]
dip_pos = apply_trans(self._head_mri_t, dip.pos[0])
helmet = self._actors["helmet"].GetMapper().GetInput()
distances = ((helmet.points - dip_pos) * helmet.point_normals).sum(axis=1)
closest_point = np.argmin(distances)
Expand Down Expand Up @@ -565,7 +586,7 @@ def _fit_timecourses(self):
[d["dip"] for d in active_dips],
self._bem,
self._evoked.info,
trans=self._trans,
trans=self._head_mri_t,
)
fwd = convert_forward_solution(fwd, surf_ori=False)

Expand Down Expand Up @@ -603,7 +624,7 @@ def _fit_timecourses(self):
self._bem,
pos=dip["dip"].pos[0], # position is always fixed
ori=dip["dip"].ori[0] if dip["fix_ori"] else None,
trans=self._trans,
trans=self._head_mri_t,
rank=self._rank,
verbose=False,
)
Expand Down

0 comments on commit b1ee63a

Please sign in to comment.