Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/dev/13494.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug where :meth:`mne.channels.DigMontage.plot` would error when ``axes`` was passed by `Christian O'Reilly`_.
14 changes: 8 additions & 6 deletions mne/viz/montage.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,19 @@ def plot_montage(
)

if scale != 1.0:
axes = axes if axes else fig.axes[0]

# scale points
collection = fig.axes[0].collections[0]
collection = axes.collections[0]
collection.set_sizes([scale * 10])

# scale labels
labels = fig.findobj(match=plt.Text)
x_label, y_label = fig.axes[0].xaxis.label, fig.axes[0].yaxis.label
z_label = fig.axes[0].zaxis.label if kind == "3d" else None
tick_labels = fig.axes[0].get_xticklabels() + fig.axes[0].get_yticklabels()
labels = axes.findobj(match=plt.Text)
x_label, y_label = axes.xaxis.label, axes.yaxis.label
z_label = axes.zaxis.label if kind == "3d" else None
tick_labels = axes.get_xticklabels() + axes.get_yticklabels()
if kind == "3d":
tick_labels += fig.axes[0].get_zticklabels()
tick_labels += axes.get_zticklabels()
for label in labels:
if label not in [x_label, y_label, z_label] + tick_labels:
label.set_fontsize(label.get_fontsize() * scale)
Expand Down
14 changes: 14 additions & 0 deletions mne/viz/tests/test_montage.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pytest

from mne.channels import make_dig_montage, make_standard_montage, read_dig_fif
from mne.io import RawArray, create_info

p_dir = Path(__file__).parents[2] / "io" / "kit" / "tests" / "data"
elp = p_dir / "test_elp.txt"
Expand Down Expand Up @@ -86,3 +87,16 @@ def test_plot_digmontage():
)
montage.plot()
plt.close("all")


def test_plot_montage_scale():
"""Test montage.plot with non-default scale using subplot axes."""
montage = make_standard_montage("GSN-HydroCel-129")
_, ax = plt.subplots(2, 1)[1][1]
picks = montage.ch_names
info = create_info(montage.ch_names, sfreq=256, ch_types="eeg")
raw = RawArray(
np.zeros((len(montage.ch_names), 1)), info, copy=None, verbose=False
).set_montage(montage)
# test for gh-13438
raw.pick(picks).get_montage().plot(axes=ax, show_names=False, scale=0.1)
Loading