From e750d7daa6c5d17158f75804032aa300598264a2 Mon Sep 17 00:00:00 2001 From: Scott Huberty Date: Thu, 25 Jul 2024 07:32:17 -0700 Subject: [PATCH 1/6] FIX: explicitly set weights_only to avoid FutureWarning - If I'm understanding the Torch 2.4 changelog correctly, you just need to explicitly pass a value to weights_only. Since the default is alraedy False, I am just explicitly setting it here so this should be backward compatible --- mne_icalabel/iclabel/network/torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_icalabel/iclabel/network/torch.py b/mne_icalabel/iclabel/network/torch.py index 1a9b493e5..53dbc8f1a 100644 --- a/mne_icalabel/iclabel/network/torch.py +++ b/mne_icalabel/iclabel/network/torch.py @@ -198,7 +198,7 @@ def _run_iclabel(images: ArrayLike, psds: ArrayLike, autocorr: ArrayLike) -> NDA # load weights network_file = files("mne_icalabel.iclabel.network") / "assets" / "ICLabelNet.pt" iclabel_net = ICLabelNet() - iclabel_net.load_state_dict(torch.load(network_file)) + iclabel_net.load_state_dict(torch.load(network_file, weights_only=False)) # format inputs and run forward pass labels = iclabel_net( *_format_input_for_torch(*_format_input(images, psds, autocorr)) From a54d9fb1dde8d68651545e1477b5b07273d90857 Mon Sep 17 00:00:00 2001 From: Scott Huberty Date: Thu, 25 Jul 2024 08:01:36 -0700 Subject: [PATCH 2/6] FIX: fix other uses of torch.load in codebase --- mne_icalabel/iclabel/network/tests/test_network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mne_icalabel/iclabel/network/tests/test_network.py b/mne_icalabel/iclabel/network/tests/test_network.py index f0626ec9e..ca6008d0f 100644 --- a/mne_icalabel/iclabel/network/tests/test_network.py +++ b/mne_icalabel/iclabel/network/tests/test_network.py @@ -36,8 +36,8 @@ def test_weights_pytorch(): """Compare the weights of pytorch model and matconvnet model.""" - network_python = torch.load(torch_iclabel_path) - network_matlab = loadmat(matconvnet_iclabel_path) + network_python = torch.load(torch_iclabel_path, weights_only=False) + network_matlab = loadmat(matconvnet_iclabel_path, weights_only=False) # load weights from matlab network weights_matlab = network_matlab["params"]["value"][0, :] From e92fbdaa9130583418ba97e515df3c754271b7f2 Mon Sep 17 00:00:00 2001 From: Scott Huberty Date: Thu, 25 Jul 2024 08:06:20 -0700 Subject: [PATCH 3/6] DOC: Update changelog --- doc/changes/authors.inc | 3 ++- doc/changes/latest.rst | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/changes/authors.inc b/doc/changes/authors.inc index 445bb7509..5c0eaacef 100644 --- a/doc/changes/authors.inc +++ b/doc/changes/authors.inc @@ -1,4 +1,5 @@ .. _Adam Li: https://github.com/adam2392 .. _Mathieu Scheltienne: https://github.com/mscheltienne .. _Jacob Feitelberg: https://github.com/jacobf18 -.. _Anand Saini: https://github.com/anandsaini024 \ No newline at end of file +.. _Anand Saini: https://github.com/anandsaini024 +.. _Scott Huberty: https://github.com/scott-huberty \ No newline at end of file diff --git a/doc/changes/latest.rst b/doc/changes/latest.rst index 374472273..9b9418073 100644 --- a/doc/changes/latest.rst +++ b/doc/changes/latest.rst @@ -18,3 +18,4 @@ Version 0.7 =========== - Raise helpful error message when montage is incomplete (:pr:`181` by `Mathieu Scheltienne`_) +- Explicitly pass ``weights_only=False`` in all instances of ``torch.load`` used by mne-icalabel, to suppress a warning in PyTorch 2.4 (:pr:`193` by `Scott Huberty`_) From 53fb65a0e3765780af288d6b403ba03b28fdf971 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Thu, 25 Jul 2024 17:30:57 +0200 Subject: [PATCH 4/6] Update mne_icalabel/iclabel/network/tests/test_network.py --- mne_icalabel/iclabel/network/tests/test_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_icalabel/iclabel/network/tests/test_network.py b/mne_icalabel/iclabel/network/tests/test_network.py index ca6008d0f..744805034 100644 --- a/mne_icalabel/iclabel/network/tests/test_network.py +++ b/mne_icalabel/iclabel/network/tests/test_network.py @@ -37,7 +37,7 @@ def test_weights_pytorch(): """Compare the weights of pytorch model and matconvnet model.""" network_python = torch.load(torch_iclabel_path, weights_only=False) - network_matlab = loadmat(matconvnet_iclabel_path, weights_only=False) + network_matlab = loadmat(matconvnet_iclabel_path) # load weights from matlab network weights_matlab = network_matlab["params"]["value"][0, :] From 2069de6d2edc6252d9cd5ab414a2f95b4497fa6f Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Thu, 25 Jul 2024 17:33:41 +0200 Subject: [PATCH 5/6] fix one more --- mne_icalabel/iclabel/network/tests/test_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_icalabel/iclabel/network/tests/test_network.py b/mne_icalabel/iclabel/network/tests/test_network.py index 744805034..d300d6f42 100644 --- a/mne_icalabel/iclabel/network/tests/test_network.py +++ b/mne_icalabel/iclabel/network/tests/test_network.py @@ -119,7 +119,7 @@ def test_network_outputs_pytorch(): # run the forward pass on pytorch iclabel_net = ICLabelNet() - iclabel_net.load_state_dict(torch.load(torch_iclabel_path)) + iclabel_net.load_state_dict(torch.load(torch_iclabel_path, weights_only=False)) torch_labels = iclabel_net(images, psd, autocorr) torch_labels = torch_labels.detach().numpy() # (30, 7) From 433f5a7129e030c47dd4ca8659123692d281fda9 Mon Sep 17 00:00:00 2001 From: Scott Huberty Date: Thu, 25 Jul 2024 10:52:49 -0700 Subject: [PATCH 6/6] Change weights_only to True --- doc/changes/latest.rst | 2 +- mne_icalabel/iclabel/network/tests/test_network.py | 4 ++-- mne_icalabel/iclabel/network/torch.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/doc/changes/latest.rst b/doc/changes/latest.rst index 9b9418073..1e8f4ada3 100644 --- a/doc/changes/latest.rst +++ b/doc/changes/latest.rst @@ -18,4 +18,4 @@ Version 0.7 =========== - Raise helpful error message when montage is incomplete (:pr:`181` by `Mathieu Scheltienne`_) -- Explicitly pass ``weights_only=False`` in all instances of ``torch.load`` used by mne-icalabel, to suppress a warning in PyTorch 2.4 (:pr:`193` by `Scott Huberty`_) +- Explicitly pass ``weights_only=True`` in all instances of ``torch.load`` used by mne-icalabel, both to suppress a warning in PyTorch 2.4 and to follow best security practices (:pr:`193` by `Scott Huberty`_) diff --git a/mne_icalabel/iclabel/network/tests/test_network.py b/mne_icalabel/iclabel/network/tests/test_network.py index d300d6f42..99f61f506 100644 --- a/mne_icalabel/iclabel/network/tests/test_network.py +++ b/mne_icalabel/iclabel/network/tests/test_network.py @@ -36,7 +36,7 @@ def test_weights_pytorch(): """Compare the weights of pytorch model and matconvnet model.""" - network_python = torch.load(torch_iclabel_path, weights_only=False) + network_python = torch.load(torch_iclabel_path, weights_only=True) network_matlab = loadmat(matconvnet_iclabel_path) # load weights from matlab network @@ -119,7 +119,7 @@ def test_network_outputs_pytorch(): # run the forward pass on pytorch iclabel_net = ICLabelNet() - iclabel_net.load_state_dict(torch.load(torch_iclabel_path, weights_only=False)) + iclabel_net.load_state_dict(torch.load(torch_iclabel_path, weights_only=True)) torch_labels = iclabel_net(images, psd, autocorr) torch_labels = torch_labels.detach().numpy() # (30, 7) diff --git a/mne_icalabel/iclabel/network/torch.py b/mne_icalabel/iclabel/network/torch.py index 53dbc8f1a..0e16b3022 100644 --- a/mne_icalabel/iclabel/network/torch.py +++ b/mne_icalabel/iclabel/network/torch.py @@ -198,7 +198,7 @@ def _run_iclabel(images: ArrayLike, psds: ArrayLike, autocorr: ArrayLike) -> NDA # load weights network_file = files("mne_icalabel.iclabel.network") / "assets" / "ICLabelNet.pt" iclabel_net = ICLabelNet() - iclabel_net.load_state_dict(torch.load(network_file, weights_only=False)) + iclabel_net.load_state_dict(torch.load(network_file, weights_only=True)) # format inputs and run forward pass labels = iclabel_net( *_format_input_for_torch(*_format_input(images, psds, autocorr))