diff --git a/doc/changes/authors.inc b/doc/changes/authors.inc index 445bb7509..5c0eaacef 100644 --- a/doc/changes/authors.inc +++ b/doc/changes/authors.inc @@ -1,4 +1,5 @@ .. _Adam Li: https://github.com/adam2392 .. _Mathieu Scheltienne: https://github.com/mscheltienne .. _Jacob Feitelberg: https://github.com/jacobf18 -.. _Anand Saini: https://github.com/anandsaini024 \ No newline at end of file +.. _Anand Saini: https://github.com/anandsaini024 +.. _Scott Huberty: https://github.com/scott-huberty \ No newline at end of file diff --git a/doc/changes/latest.rst b/doc/changes/latest.rst index 374472273..1e8f4ada3 100644 --- a/doc/changes/latest.rst +++ b/doc/changes/latest.rst @@ -18,3 +18,4 @@ Version 0.7 =========== - Raise helpful error message when montage is incomplete (:pr:`181` by `Mathieu Scheltienne`_) +- Explicitly pass ``weights_only=True`` in all instances of ``torch.load`` used by mne-icalabel, both to suppress a warning in PyTorch 2.4 and to follow best security practices (:pr:`193` by `Scott Huberty`_) diff --git a/mne_icalabel/iclabel/network/tests/test_network.py b/mne_icalabel/iclabel/network/tests/test_network.py index f0626ec9e..99f61f506 100644 --- a/mne_icalabel/iclabel/network/tests/test_network.py +++ b/mne_icalabel/iclabel/network/tests/test_network.py @@ -36,7 +36,7 @@ def test_weights_pytorch(): """Compare the weights of pytorch model and matconvnet model.""" - network_python = torch.load(torch_iclabel_path) + network_python = torch.load(torch_iclabel_path, weights_only=True) network_matlab = loadmat(matconvnet_iclabel_path) # load weights from matlab network @@ -119,7 +119,7 @@ def test_network_outputs_pytorch(): # run the forward pass on pytorch iclabel_net = ICLabelNet() - iclabel_net.load_state_dict(torch.load(torch_iclabel_path)) + iclabel_net.load_state_dict(torch.load(torch_iclabel_path, weights_only=True)) torch_labels = iclabel_net(images, psd, autocorr) torch_labels = torch_labels.detach().numpy() # (30, 7) diff --git a/mne_icalabel/iclabel/network/torch.py b/mne_icalabel/iclabel/network/torch.py index 1a9b493e5..0e16b3022 100644 --- a/mne_icalabel/iclabel/network/torch.py +++ b/mne_icalabel/iclabel/network/torch.py @@ -198,7 +198,7 @@ def _run_iclabel(images: ArrayLike, psds: ArrayLike, autocorr: ArrayLike) -> NDA # load weights network_file = files("mne_icalabel.iclabel.network") / "assets" / "ICLabelNet.pt" iclabel_net = ICLabelNet() - iclabel_net.load_state_dict(torch.load(network_file)) + iclabel_net.load_state_dict(torch.load(network_file, weights_only=True)) # format inputs and run forward pass labels = iclabel_net( *_format_input_for_torch(*_format_input(images, psds, autocorr))