
Commit be93b77

FIX Unpickling without using torch.load (#1092)
Resolves #1090.

PyTorch plans to switch the default of `torch.load` to `weights_only=True`. We already partly dealt with that in #1064 when it comes to `save_params`/`load_params`. However, we still had a gap: when using pickle directly, i.e. when going through `__getstate__` and `__setstate__`, we were still calling `torch.save` and `torch.load` without handling `weights_only`. This will cause trouble in the future when the default is switched, but it is also annoying right now, because users get the FutureWarning about `weights_only` even if they correctly pass `torch_load_kwargs` (see #1090).

The reason why we used `torch.save`/`torch.load` for pickling is that those functions are basically _extended_ pickle functions with the benefit of supporting the `map_location` argument to handle the device of torch tensors, which plain pickle lacks. The `map_location` argument is important, e.g. when saving a net that uses CUDA and loading it on a machine without CUDA; without it, we would run into an error. However, with the move to `weights_only=True`, `torch.save`/`torch.load` become _reduced_ pickle functions, as they only support a small subset of objects by default. Therefore, we can no longer rely on them for pickling the whole skorch object.

In this PR, we thus move to using plain pickle. That raises the question of how to handle `map_location`. The solution I ended up with is to intercept torch's `_load_from_bytes` using a custom `Unpickler` and to call `torch.load` specifically there. That way, we can pass `map_location` and the other `torch_load_kwargs`, while the rest of the unpickling process works as normal. Yes, this is a private function, so we cannot be sure that it will keep working indefinitely. If there is a better suggestion, I'm open to it. However, the function has existed for 7 years, so it is unlikely to change anytime soon: https://github.com/pytorch/pytorch/blame/0674ab7e33c3f627ca6781ce98468ec1dd4743a5/torch/storage.py#L525

A drawback of this solution is that old skorch nets, which were saved with `torch.save`, cannot simply be loaded with `pickle.load`, because torch uses custom `persistent_load` functions. Trying to load them with pickle results in:

    _pickle.UnpicklingError: A load persistent id instruction was encountered, but no persistent_load function was specified.

Therefore, I had to keep `torch.load` as a fallback to avoid backwards incompatibility. The bad news is that for these old files the initial problem persists: even when passing `torch_load_kwargs`, users get the FutureWarning about `weights_only`. The good news is that users can just re-save their net with the new skorch version, and from then on they won't see the warning again (a short sketch of this workaround follows at the end of this description). Note that I didn't add a specific test for loading nets saved before this change, because `test_pickle_load`, which uses a checked-in pickle file, already covers this.

Other considered solutions:

1. Why not continue using `torch.save`/`torch.load` and just pass the `torch_load_kwargs` argument to it? This is unfortunately not that easy. When switching to `weights_only=True`, torch refuses to load any custom objects, e.g. `class MyModule`. There is a way to prevent that, namely `torch.serialization.add_safe_globals`, but it is a ton of work to register all required objects there, as even built-in Python types are mostly not supported.
2. We cannot use `with torch.device(...)`, as this is not honored during unpickling.
3. During `__getstate__`, we could recursively go through the state, pop all torch tensors, and replace them with, say, numpy arrays plus additional metadata like the device, then use this info to restore those objects during `__setstate__`. Even though this looks like a cleaner solution, it is much more complex and therefore, I'd argue, more error prone.
4. Don't do anything and just live with the warning: This would work until PyTorch switches the default, so we had to tackle this sooner or later.

Notes

While working on this, I thought that we could most likely remove `cuda_dependent_attributes_` (which contains `net.module_`, `net.optimizer_`, etc.). Its purpose was to call `torch.load` on these attributes specifically, but with the new `Unpickler`, loading should also work without it. However, I kept the attribute for now, mainly for these reasons:

1. I didn't want to change more than necessary, as these changes are delicate and I don't want to break any existing skorch code or pickle files.
2. The attribute itself is public, so in theory, users may rely on its existence (not sure if they do in practice). We would thus have to keep most of the code related to this attribute anyway.

But LMK if you think we should deprecate and eventually remove this attribute.
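As a minimal sketch of the re-save workaround mentioned above (the file names are illustrative, not part of this PR):

```python
import pickle

# 'old_net.pkl' stands for a pickle file written by a skorch version from
# before this change, i.e. written internally via torch.save.
with open('old_net.pkl', 'rb') as f:
    net = pickle.load(f)  # may emit the weights_only FutureWarning once

# Re-saving with the new skorch version writes a plain pickle stream ...
with open('new_net.pkl', 'wb') as f:
    pickle.dump(net, f)

# ... and subsequent loads go through the custom Unpickler, warning-free.
with open('new_net.pkl', 'rb') as f:
    net = pickle.load(f)
```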
1 parent bb1bac4 commit be93b77

File tree: 4 files changed (+90, -3 lines)

Diff for: CHANGES.md (+3)

```diff
@@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [Unreleased]
 ### Added
 ### Changed
+
+- Loading of skorch nets using pickle: When unpickling a skorch net, you may come across a PyTorch warning that goes: "FutureWarning: You are using torch.load with weights_only=False [...]"; to avoid this warning, pickle the net again and use the new pickle file (#1092)
+
 ### Fixed

 ## [1.1.0]
```

Diff for: skorch/net.py (+19, -2)

```diff
@@ -13,6 +13,7 @@
 from collections import OrderedDict
 from contextlib import contextmanager
 import os
+import pickle
 import tempfile
 import warnings

@@ -33,6 +34,7 @@
 from skorch.exceptions import SkorchTrainingImpossibleError
 from skorch.history import History
 from skorch.setter import optimizer_setter
+from skorch.utils import _TorchLoadUnpickler
 from skorch.utils import _identity
 from skorch.utils import _infer_predict_nonlinearity
 from skorch.utils import FirstStepAccumulator
@@ -2242,7 +2244,7 @@ def __getstate__(self):
                 state.pop(k)

         with tempfile.SpooledTemporaryFile() as f:
-            torch.save(cuda_attrs, f)
+            pickle.dump(cuda_attrs, f)
             f.seek(0)
             state['__cuda_dependent_attributes__'] = f.read()

@@ -2254,11 +2256,26 @@ def __setstate__(self, state):
         map_location = get_map_location(state['device'])
         load_kwargs = {'map_location': map_location}
         state['device'] = self._check_device(state['device'], map_location)
+        torch_load_kwargs = state.get('torch_load_kwargs') or get_default_torch_load_kwargs()

         with tempfile.SpooledTemporaryFile() as f:
+            unpickler = _TorchLoadUnpickler(
+                f,
+                map_location=map_location,
+                torch_load_kwargs=torch_load_kwargs,
+            )
             f.write(state['__cuda_dependent_attributes__'])
             f.seek(0)
-            cuda_attrs = torch.load(f, **load_kwargs)
+            try:
+                cuda_attrs = unpickler.load()
+            except pickle.UnpicklingError:
+                # This object was saved using skorch from before switching to the
+                # custom unpickler, i.e. with torch.save. Fall back to the old loading
+                # code using torch.load. Unfortunately, this means that the user may
+                # get the FutureWarning about weights_only=False. They need to re-save
+                # the net to get rid of the warning
+                f.seek(0)
+                cuda_attrs = torch.load(f, **load_kwargs)

         state.update(cuda_attrs)
         state.pop('__cuda_dependent_attributes__')
```
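To make the fallback above concrete, here is a standalone sketch (not part of the diff) showing that plain `pickle.load` raises `pickle.UnpicklingError` on data written by `torch.save`, while retrying with `torch.load` succeeds; the exact error message depends on the serialization format torch used:

```python
import io
import pickle

import torch

buf = io.BytesIO()
torch.save({'weight': torch.ones(3)}, buf)
buf.seek(0)

try:
    pickle.load(buf)  # plain pickle cannot handle torch.save output
except pickle.UnpicklingError:
    # e.g. "A load persistent id instruction was encountered ..." or
    # "invalid load key ...", depending on the torch format; either way,
    # we retry with torch.load, mirroring the fallback in __setstate__
    buf.seek(0)
    state = torch.load(buf, map_location='cpu')

print(state['weight'].device)  # cpu
```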

Diff for: skorch/tests/test_net.py (+35)

```diff
@@ -3081,6 +3081,7 @@ def test_torch_load_kwargs_auto_weights_false_pytorch_lt_2_6(
         # See discussion in 1063.
         from skorch._version import Version

+        # TODO remove once torch 2.5.0 is no longer supported
         if Version(torch.__version__) >= Version('2.6.0'):
             pytest.skip("Test only for torch < v2.6.0")

@@ -3097,6 +3098,40 @@ def test_torch_load_kwargs_auto_weights_false_pytorch_lt_2_6(
         del call_kwargs['map_location']  # we're not interested in that
         assert call_kwargs == expected_kwargs

+    def test_torch_load_kwargs_forwarded_to_torch_load_unpickle(
+            self, net_cls, module_cls, monkeypatch, tmp_path
+    ):
+        # See discussion in 1090
+        # Here we check that custom set torch load args are forwarded to
+        # torch.load even when using pickle. This is the same test otherwise as
+        # test_torch_load_kwargs_forwarded_to_torch_load
+        expected_kwargs = {'weights_only': 123, 'foo': 'bar'}
+        net = net_cls(module_cls, torch_load_kwargs=expected_kwargs).initialize()
+
+        original_torch_load = torch.load
+        # call original torch.load without extra params to prevent error:
+        mock_torch_load = Mock(
+            side_effect=lambda *args, **kwargs: original_torch_load(*args)
+        )
+        monkeypatch.setattr(torch, "load", mock_torch_load)
+        dumped = pickle.dumps(net)
+        pickle.loads(dumped)
+
+        call_kwargs = mock_torch_load.call_args_list[0].kwargs
+        del call_kwargs['map_location']  # we're not interested in that
+        assert call_kwargs == expected_kwargs
+
+    def test_unpickle_no_pytorch_warning(self, net_cls, module_cls, recwarn):
+        # See discussion 1090
+        # When using pickle, i.e. when going through __setstate__, we don't want to get
+        # any warnings about the usage of weights_only.
+        net = net_cls(module_cls).initialize()
+        dumped = pickle.dumps(net)
+        pickle.loads(dumped)
+
+        msg_content = "weights_only"
+        assert not any(msg_content in str(w.message) for w in recwarn.list)
+
     def test_custom_module_params_passed_to_optimizer(
             self, net_custom_module_cls, module_cls):
         # custom module parameters should automatically be passed to the optimizer
```

Diff for: skorch/utils.py (+33, -1)

```diff
@@ -11,11 +11,11 @@
 import io
 from itertools import tee
 import pathlib
+import pickle
 import warnings

 import numpy as np
 from scipy import sparse
-import sklearn
 from sklearn.exceptions import NotFittedError
 from sklearn.utils import _safe_indexing as safe_indexing
 from sklearn.utils.validation import check_is_fitted as sk_check_is_fitted
@@ -784,3 +784,35 @@ def get_default_torch_load_kwargs():
     if version_torch >= version_default_switch:
         return {"weights_only": True}
     return {"weights_only": False}
+
+
+class _TorchLoadUnpickler(pickle.Unpickler):
+    """
+    Subclass of pickle.Unpickler that intercepts 'torch.storage._load_from_bytes' calls
+    and uses `torch.load(..., map_location=..., torch_load_kwargs=...)`.
+
+    This way, we can use normal pickle when unpickling a skorch net but still benefit
+    from torch.load to handle the map_location. Note that `with torch.device(...)` does
+    not work for unpickling.
+
+    """
+
+    def __init__(self, *args, map_location, torch_load_kwargs, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.map_location = map_location
+        self.torch_load_kwargs = torch_load_kwargs
+
+    def find_class(self, module, name):
+        # The actual serialized data for PyTorch tensors references
+        # torch.storage._load_from_bytes internally. We intercept that call:
+        if (module == 'torch.storage') and (name == '_load_from_bytes'):
+            # Return a function that uses torch.load with our desired map_location
+            def _load_from_bytes(b):
+                return torch.load(
+                    io.BytesIO(b),
+                    map_location=self.map_location,
+                    **self.torch_load_kwargs
+                )
+            return _load_from_bytes
+
+        return super().find_class(module, name)
```
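For illustration, a minimal sketch of how the new class behaves on its own (a hedged example, assuming a skorch version with this change installed; the class is private, and the example object is arbitrary):

```python
import io
import pickle

import torch
from skorch.utils import _TorchLoadUnpickler  # private, added in this PR

# Plain pickle serializes a tensor's storage via a reference to
# torch.storage._load_from_bytes, which find_class above intercepts.
obj = {'weight': torch.ones(3)}
buf = io.BytesIO(pickle.dumps(obj))

unpickler = _TorchLoadUnpickler(
    buf,
    map_location='cpu',
    torch_load_kwargs={'weights_only': True},
)
restored = unpickler.load()
print(restored['weight'].device)  # cpu, regardless of the original device
```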
