Fix load_from_checkpoint to return model on correct device #17308
Merged: Borda merged 24 commits into Lightning-AI:master from ryan597:bug/17304_load_checkpoint_location on Apr 15, 2023
Conversation
ryan597 force-pushed the bug/17304_load_checkpoint_location branch from 24def53 to 0b56d6e on April 9, 2023 17:52
ryan597 requested review from awaelchli, carmocca, justusschock and williamFalcon as code owners on April 9, 2023 22:00
ryan597 commented on Apr 9, 2023
carmocca approved these changes on Apr 10, 2023
carmocca added the bug (Something isn't working) and community (This PR is from the community) labels on Apr 10, 2023
carmocca reviewed on Apr 10, 2023
ryan597 force-pushed the bug/17304_load_checkpoint_location branch from e67b8c2 to 6f924c3 on April 11, 2023 21:13
Borda changed the title from "Fix load_from_checkpoint to return model on correct device" to "Fix `load_from_checkpoint` to return model on correct device" on Apr 14, 2023
Borda approved these changes on Apr 14, 2023
carmocca approved these changes on Apr 14, 2023
- map_location was not used to move created LightningModules. We now find the restore location and map the Module to it.
- [pre-commit.ci] for more information, see https://pre-commit.ci
- The code checker was failing because load_state can return a LightningDataModule, which doesn't have a .device attribute. Additional tests cover this too.
- The model now loads on the same device it was saved from; we no longer need to move to the CPU to check params.
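The type-checker fix mentioned in the commit messages amounts to narrowing the returned object before touching `.device`. A minimal, torch-free sketch of that guard, assuming stand-in classes (`FakeLightningModule` and `FakeDataModule` are hypothetical, not Lightning's actual classes):

```python
# Sketch of guarding the .device access: the load step can return either
# a module (which has .device and can be moved) or a datamodule (which
# has neither), so the attribute is only touched after an isinstance check.
# FakeLightningModule / FakeDataModule are illustrative stand-ins.

class FakeLightningModule:
    device = "cpu"  # modules start on the CPU

    def to(self, device):
        self.device = device
        return self


class FakeDataModule:
    pass  # no .device attribute, cannot be moved


def maybe_move(obj, device):
    if isinstance(obj, FakeLightningModule):  # only modules can be moved
        return obj.to(device)
    return obj                                # datamodules pass through untouched


print(maybe_move(FakeLightningModule(), "cuda:0").device)          # cuda:0
print(hasattr(maybe_move(FakeDataModule(), "cuda:0"), "device"))   # False
```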
auto-merge was automatically disabled on April 14, 2023 12:56 (head branch was pushed to by a user without write access)
ryan597 force-pushed the bug/17304_load_checkpoint_location branch from d2caff6 to 3cc815d on April 14, 2023 12:56
Borda (three times) and lantiga pushed commits referencing this pull request on Apr 24, 2023, and lantiga pushed another on Apr 26, 2023; each carries the same message:
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos Mocholí <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> (cherry picked from commit e1ce887)
What does this PR do?
Fixes #17304

When using `load_from_checkpoint`, the returned model is always on the CPU because `_load_from_checkpoint` does not apply map_location to the returned object: https://github.com/Lightning-AI/lightning/blob/fd4697c62c059fc7b9946e84d91625ecb6efdbe5/src/lightning/pytorch/core/saving.py#L51-L92

This PR returns the created model on the correct mapped location. It also adds tests checking that the model is loaded onto the GPU after the checkpoint load. Tests for mapping to the CPU are included for completeness, although these do not currently fail; only the GPU tests do.
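The essence of the fix can be illustrated with a minimal, torch-free simulation. `FakeModule` and the simplified loader below are hypothetical stand-ins for Lightning's actual classes, not the real implementation: the point is that the restored module is moved to the device resolved from map_location instead of being left on the CPU.

```python
# Minimal simulation of the fix: the loader must move the restored
# module to the device requested via map_location, instead of
# returning it on the CPU unconditionally.
# FakeModule and load_from_checkpoint here are illustrative stand-ins.

class FakeModule:
    def __init__(self):
        self.device = "cpu"  # modules are constructed on the CPU

    def to(self, device):
        self.device = device
        return self


def load_from_checkpoint(map_location=None):
    model = FakeModule()          # created on the CPU, as in _load_state
    if map_location is not None:  # the fix: honour the requested device
        model = model.to(map_location)
    return model


print(load_from_checkpoint().device)          # cpu (unchanged behaviour)
print(load_from_checkpoint("cuda:0").device)  # cuda:0 (previously stayed cpu)
```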
Created tests failing on master

All tests checking for the model to be on `cuda` will fail on the current master branch. All tests checking for the model to be on `cpu` will pass, but are included for completeness.

What can be improved

When the object is created on the GPU without setting map_location (i.e. when the model checkpoint is from the GPU), shouldn't it automatically load onto the GPU, as it does when you use `torch.load("boring.pth")`? The issue here is that when the object is created in `_load_state()` it is not created on or moved to the GPU, so when map_location is set to `lambda storage, loc: storage` (as it is when map_location=None), the object is simply returned on the CPU. Retrieving this information from `checkpoint['state_dict']` requires some ugly work that I'm not sure will work in all cases, as I have only tested with the BoringModel so far. This works (for the BoringModel, anyway), but I'm sure it can be done much better if someone has better ideas.
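The "ugly work" of recovering the save-time device from the state dict could look roughly like the following torch-free sketch. The `FakeTensor` class and the first-parameter heuristic are assumptions for illustration (with real torch tensors one would read `tensor.device` the same way), not the PR's actual code:

```python
# Sketch of inferring the save-time device from checkpoint["state_dict"]:
# peek at the device of the first parameter tensor and use it as the
# default restore target. The "first parameter wins" heuristic is an
# assumption and may not hold for models sharded across devices.
# FakeTensor is an illustrative stand-in for a torch tensor.

class FakeTensor:
    def __init__(self, device):
        self.device = device


def infer_restore_device(checkpoint, default="cpu"):
    state_dict = checkpoint.get("state_dict", {})
    for tensor in state_dict.values():
        return tensor.device  # device of the first parameter found
    return default            # empty state dict: fall back to the default


ckpt = {"state_dict": {"layer.weight": FakeTensor("cuda:0")}}
print(infer_restore_device(ckpt))  # cuda:0
print(infer_restore_device({}))    # cpu
```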
Failing tests / Breaking changes

`tests/tests_pytorch/strategies/test_fsdp.py` contains tests which were failing. They load checkpoints with the changed `load_from_checkpoint`; however, the assertion moves one of the tensors to the `cpu`, producing an assertion error because the loaded checkpoint is now correctly on `cuda`: https://github.com/Lightning-AI/lightning/blob/fd4697c62c059fc7b9946e84d91625ecb6efdbe5/tests/tests_pytorch/strategies/test_fsdp.py#L142-L152

After changing this assertion to keep both state_dicts on their devices, all tests pass.
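The adjusted comparison can be sketched as follows, again with torch-free stand-ins (the real test compares torch tensors): both state dicts stay on their own devices, rather than one side being pulled to the CPU before comparing.

```python
# Sketch of the adjusted assertion: keep both state dicts on their
# devices and compare values directly, rather than calling .cpu() on
# one side (which made devices diverge once loading became correct).
# StubTensor is an illustrative stand-in for a torch tensor.

class StubTensor:
    def __init__(self, value, device):
        self.value, self.device = value, device

    def equal(self, other):
        # value comparison; both tensors live on the same device now,
        # so no explicit move to the CPU is needed
        return self.value == other.value


saved = {"w": StubTensor([1.0, 2.0], "cuda:0")}
loaded = {"w": StubTensor([1.0, 2.0], "cuda:0")}

same = all(saved[k].equal(loaded[k]) for k in saved)
print(same)  # True
```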
If there are other cases where users have similar conditions or rely on the current behavior, this PR could break them, although it would only be unexpected in the map_location=None case.

Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet list:
Reviewer checklist