FIX Unpickling without using torch.load #1092
Merged
Conversation
Resolves #1090, #1091.

This one got more complicated than I initially thought. So here it goes:

Description

PyTorch plans to make the switch to weights_only=True for torch.load. We already partly dealt with that in #1064 when it comes to save_params/load_params. However, we still had a gap: when using pickle directly, i.e. when going through __getstate__ and __setstate__, we were still using torch.load and torch.save without handling weights_only. This will cause trouble in the future when the default is switched. But it's also annoying right now, because users get the FutureWarning about weights_only even if they correctly pass torch_load_kwargs (see #1090).

The reason why we use torch.save/torch.load for pickling is that those functions are basically extended pickle functions with the benefit of supporting the map_location argument to handle the device of torch tensors, which plain pickle lacks. The map_location argument is important: when saving a net that uses CUDA and loading it on a machine without CUDA, we would otherwise run into an error.

However, with the move to weights_only=True, torch.save/torch.load become reduced pickle functions, as they only support a small subset of objects by default. Therefore, we can no longer rely on torch.save/torch.load for pickling the whole skorch object.

Solution

(thanks ChatGPT o1 for helping with this)

In this PR, we thus move to using plain pickle. However, now we run into the issue of how to handle map_location. The solution I ended up with is to intercept torch's _load_from_bytes using a custom Unpickler and to specifically use torch.load there. That way, we can pass map_location and the other torch_load_kwargs. The rest of the unpickling process works as normal.

Yes, _load_from_bytes is a private function, so we cannot be sure that it will keep working indefinitely. If there is a better suggestion, I'm open to it. However, the function has existed for 7 years, so it's not very likely to change anytime soon: https://github.com/pytorch/pytorch/blame/0674ab7e33c3f627ca6781ce98468ec1dd4743a5/torch/storage.py#L525

A drawback of this solution is that we cannot load old skorch nets that were saved with torch.save by using pickle.load, because torch uses custom persistent_load functions. When trying to load them with pickle, we get:

_pickle.UnpicklingError: A load persistent id instruction was encountered, but no persistent_load function was specified.

Therefore, I had to keep torch.load as a fallback to avoid breaking backwards compatibility. The bad news is that for such old files the initial problem persists: even when passing torch_load_kwargs, users get the FutureWarning about weights_only. The good news is that users can simply re-save their net with the new skorch version, and from then on they won't see the warning again.

Note that I didn't add a specific test for loading nets saved before this change, because test_pickle_load, which uses a checked-in pickle file, already covers this.
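To make the mechanism concrete, here is a minimal sketch of the interception plus the torch.load fallback. It is illustrative rather than the exact code from this PR: TorchTensorUnpickler and load_net are made-up names, and it relies on the fact that plain pickling of a tensor reduces its storage to torch.storage._load_from_bytes:

```python
import io
import pickle

import torch


class TorchTensorUnpickler(pickle.Unpickler):
    """Unpickler that routes torch storage deserialization through torch.load.

    Plain pickle reduces a tensor's storage to torch.storage._load_from_bytes;
    by intercepting that lookup, we can call torch.load ourselves and pass
    map_location and the other torch_load_kwargs.
    """

    def __init__(self, file, torch_load_kwargs=None):
        super().__init__(file)
        self.torch_load_kwargs = torch_load_kwargs or {}

    def find_class(self, module, name):
        if module == 'torch.storage' and name == '_load_from_bytes':
            # swap in a reconstructor that honors the user-supplied kwargs
            return lambda data: torch.load(
                io.BytesIO(data), **self.torch_load_kwargs
            )
        return super().find_class(module, name)


def load_net(file, torch_load_kwargs=None):
    """Unpickle a net, falling back to torch.load for files that were
    saved with torch.save by older versions (file must be seekable)."""
    try:
        return TorchTensorUnpickler(
            file, torch_load_kwargs=torch_load_kwargs
        ).load()
    except pickle.UnpicklingError:
        # old files need torch's persistent_load machinery
        file.seek(0)
        return torch.load(file, **(torch_load_kwargs or {}))
```

With something like this, a net pickled on a CUDA machine could be restored on a CPU-only machine via load_net(f, torch_load_kwargs={'map_location': 'cpu'}), while everything besides the tensor storages unpickles exactly as before.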
Other considered solutions

1. Why not continue using torch.save/torch.load and just pass the torch_load_kwargs argument to it? This is unfortunately not that easy. When switching to weights_only=True, torch refuses to load any custom objects, e.g. class MyModule. There is a way to prevent that, namely via torch.serialization.add_safe_globals, but it is a ton of work to add all required objects there, as even builtin Python types are mostly not supported (see the sketch at the end of this description).
2. We cannot use with torch.device(map_location), as this is not honored during unpickling.
3. During __getstate__, we could recursively go through the state, pop all torch tensors, and replace them with, say, numpy arrays plus additional metadata like the device, then use this info to restore those objects during __setstate__. Even though this looks like a cleaner solution, it is much more complex and therefore, I'd argue, more error prone.

Notes

While working on this, I thought that we could most likely remove cuda_dependent_attributes_ (which contains net.module_, net.optimizer_, etc.). Its purpose was to call torch.load on these attributes specifically, but with the new Unpickler, everything should also work without it. However, I kept the attribute for now, mainly for these reasons:

1. I didn't want to change more than necessary, as these changes are delicate and I don't want to break any existing skorch code or pickle files.
2. The attribute itself is public, so in theory, users may rely on its existence (not sure about in practice). We would thus have to keep most of the code related to this attribute.

But LMK if you think we should deprecate and eventually remove this attribute.
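To make the point in option 1 above concrete, here is a rough sketch of what the rejected add_safe_globals route would look like. This is illustrative only: MyModule and net.pt are made-up names, and a real net would require allowlisting far more types than shown here.

```python
import torch
from torch import nn


class MyModule(nn.Module):
    """Stand-in for a user-defined module class."""


torch.save(MyModule(), 'net.pt')

# With weights_only=True, torch.load only accepts an allowlist of types,
# so every custom class appearing in the saved object (and many builtin
# or stdlib types besides) would have to be registered first:
torch.serialization.add_safe_globals([MyModule])
obj = torch.load('net.pt', weights_only=True)
```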
githubnemo approved these changes (Jan 27, 2025)
LGTM, seems to be a good complexity/impact trade-off.
BenjaminBossan added a commit that referenced this pull request (Jan 31, 2025):
- Add test for new default of weights_only
- Update pickle file test artifact (explained in #1092)
- Update some comments
BenjaminBossan added a commit that referenced this pull request (Feb 4, 2025):
- Add torch 2.6.0 to CI
- Remove torch 2.2.2
- Update torch install instructions, as they no longer provide conda packages
- Add test for new default of weights_only
- Update pickle file test artifact (explained in #1092)
- Update some comments
- Conditionally install triton 3.1 for torch < 2.6