
Fix load_from_checkpoint to return model on correct device #17308

Merged: 24 commits merged into Lightning-AI:master on Apr 15, 2023

Conversation

@ryan597 (Contributor) commented on Apr 9, 2023

What does this PR do?

Fixes #17304

When using

model = BoringModel.load_from_checkpoint('path.ckpt', map_location='cuda')

the returned model is always on the CPU, because _load_from_checkpoint does not apply map_location to the returned object.

https://github.com/Lightning-AI/lightning/blob/fd4697c62c059fc7b9946e84d91625ecb6efdbe5/src/lightning/pytorch/core/saving.py#L51-L92

This PR returns the created model with the correct mapped location.
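Conceptually, the fix amounts to applying the resolved map_location to the instantiated module, not only to the raw state dict. A minimal sketch of the idea (not the exact merged diff; _apply_map_location is a hypothetical helper name):

import torch

def _apply_map_location(model, map_location):
    # Hypothetical sketch of the fix. String and torch.device arguments map
    # directly to a target device; callables and dicts are already handled by
    # torch.load while reading the state dict, so the module is left where
    # those storages landed.
    if isinstance(map_location, (str, torch.device)):
        return model.to(torch.device(map_location))
    return model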

It also adds tests checking that the model is loaded onto the GPU after the checkpoint load. I have included tests for mapping to the CPU for completeness, although these do not currently fail; only the GPU tests do.

Created tests failing on master

All tests checking that the model is on cuda fail on the current master branch.
All tests checking that the model is on cpu pass on the current branch, but are included for completeness.

import pytest
import torch

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
from tests_pytorch.helpers.runif import RunIf


def create_boring_checkpoint(tmpdir, accelerator="gpu", testmodel=BoringModel):
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename="checkpoint")
    model = testmodel()
    trainer = pl.Trainer(
        devices=1,
        accelerator=accelerator,
        max_epochs=1,
        enable_model_summary=False,
        enable_progress_bar=False,
        callbacks=[checkpoint_callback],
    )
    trainer.fit(model)


@pytest.mark.parametrize(
    "map_location", (None, "cpu", torch.device("cpu"), lambda storage, loc: storage, {"cpu": "cpu"})
)
def test_load_from_checkpoint_map_location_cpu(tmpdir, map_location, testmodel=BoringModel):
    create_boring_checkpoint(tmpdir, accelerator="cpu", testmodel=testmodel)
    model = testmodel.load_from_checkpoint(f"{tmpdir}/checkpoint.ckpt", map_location=map_location)
    assert model.device.type == "cpu"


@RunIf(min_cuda_gpus=1)
@pytest.mark.parametrize(
    "map_location", (None, "cuda", torch.device("cuda"), lambda storage, loc: storage.cuda(), {"cpu": "cuda"})
)
def test_load_from_checkpoint_map_location_gpu(tmpdir, map_location, testmodel=BoringModel):
    create_boring_checkpoint(tmpdir, accelerator="gpu", testmodel=testmodel)
    model = testmodel.load_from_checkpoint(f"{tmpdir}/checkpoint.ckpt", map_location=map_location)
    assert model.device.type == "cuda"


@RunIf(min_cuda_gpus=1)
@pytest.mark.parametrize("map_location", ("cpu", torch.device("cpu"), lambda storage, loc: storage, {"cuda": "cpu"}))
def test_load_from_checkpoint_map_location_gpu_to_cpu(tmpdir, map_location, testmodel=BoringModel):
    create_boring_checkpoint(tmpdir, accelerator="gpu", testmodel=testmodel)
    model = testmodel.load_from_checkpoint(f"{tmpdir}/checkpoint.ckpt", map_location=map_location)
    assert model.device.type == "cpu"


@RunIf(min_cuda_gpus=1)
@pytest.mark.parametrize(
    "map_location", ("cuda", torch.device("cuda"), lambda storage, loc: storage.cuda(), {"cpu": "cuda"})
)
def test_load_from_checkpoint_map_location_cpu_to_gpu(tmpdir, map_location, testmodel=BoringModel):
    create_boring_checkpoint(tmpdir, accelerator="cpu", testmodel=testmodel)
    model = testmodel.load_from_checkpoint(f"{tmpdir}/checkpoint.ckpt", map_location=map_location)
    assert model.device.type == "cuda"

What can be improved

When the checkpoint comes from a model on the GPU and map_location is not set, shouldn't the model automatically load onto the GPU, just as it does when you use torch.load("torch.pth")?

import torch
import torch.nn as nn

class TestModel(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.fc = nn.Linear(32, 2)

    def forward(self, x):
        return self.fc(x)

testmodel = TestModel().to("cuda")

torch.save(testmodel, "torch.pth")
testmodel = torch.load("torch.pth")
print(next(testmodel.parameters()).device)  # Prints cuda:0

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel

checkpoint_callback = ModelCheckpoint(dirpath='', filename="boring")

model = BoringModel()
trainer = pl.Trainer(
    max_epochs=1,
    enable_progress_bar=False,
    enable_model_summary=False,
    callbacks=[checkpoint_callback],
    accelerator="gpu",
    devices=1
)
trainer.fit(model)

model = BoringModel.load_from_checkpoint("boring.ckpt") 
print(model.device.type)  # Should print cuda but prints cpu

So the issue here is that when the object is created in _load_state() it is not created on or moved to the GPU; when map_location is the default lambda storage, loc: storage (as it is when map_location=None), the object is simply returned on the CPU.
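For context, torch.load's default map_location and the identity callable behave differently; a minimal illustration (assuming a CUDA device is available):

import torch

t = torch.ones(2, device="cuda")
torch.save(t, "t.pt")

# map_location=None restores tensors to the device recorded in the file
print(torch.load("t.pt").device)  # cuda:0

# the identity callable keeps every storage on the CPU
print(torch.load("t.pt", map_location=lambda storage, loc: storage).device)  # cpu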

Retrieving this information from checkpoint['state_dict'] requires some ugly work that I'm not sure will hold in all cases, as I have only tested with the BoringModel so far:

...
if issubclass(cls, pl.LightningModule):
    loaded_location = list(checkpoint["state_dict"].items())[0][1].device  # This could be done better?
    if map_location is None:
        map_location = loaded_location

    storage = _load_state(cls, checkpoint, strict=strict, **kwargs)
    restore_location = torch.serialization._get_restore_location(map_location)
    if isinstance(map_location, dict) and isinstance(storage, pl.LightningModule):
        return restore_location(storage, str(storage.device))

    return restore_location(storage, map_location)

This works (for the BoringModel, anyway), but I'm sure it can be done much better if someone has ideas.
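For what it's worth, a slightly more defensive version of the device lookup, as an untested sketch that falls back to CPU for parameter-less modules:

import torch

def _first_tensor_device(state_dict):
    # Sketch: return the device of the first tensor in the state dict,
    # falling back to CPU when the state dict contains no tensors at all.
    for value in state_dict.values():
        if isinstance(value, torch.Tensor):
            return value.device
    return torch.device("cpu")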

Failing tests / Breaking changes

tests/tests_pytorch/strategies/test_fsdp.py contains tests which now fail. They load checkpoints with the changed load_from_checkpoint, but the assertion moves one of the tensors to the CPU, which raises an assertion error because the loaded checkpoint is now correctly on cuda.

https://github.com/Lightning-AI/lightning/blob/fd4697c62c059fc7b9946e84d91625ecb6efdbe5/tests/tests_pytorch/strategies/test_fsdp.py#L142-L152

Changing this assertion to keep both state dicts on their own devices makes all tests pass:

- assert torch.equal(ddp_param.float().cpu(), shard_param) 
+ assert torch.equal(ddp_param, shard_param) 

There may be other cases where users rely on the current behavior, and this PR could break them, although the change should only be unexpected in the map_location=None case.
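For example, hypothetical user code like the following would only see new behavior in the first call:

# Before this PR, a checkpoint saved from a CUDA model came back on the CPU
# when map_location was left unset; an explicit map_location is unaffected.
model = BoringModel.load_from_checkpoint("boring.ckpt")                       # now cuda
model = BoringModel.load_from_checkpoint("boring.ckpt", map_location="cpu")   # cpu, as before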

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet list:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

@github-actions bot added the pl (Generic label for PyTorch Lightning package) label on Apr 9, 2023
@ryan597 force-pushed the bug/17304_load_checkpoint_location branch from 24def53 to 0b56d6e on April 9, 2023 17:52
@ryan597 marked this pull request as ready for review on April 9, 2023 22:00
@carmocca added the bug (Something isn't working) and community (This PR is from the community) labels on Apr 10, 2023
@carmocca added this to the v1.9.x milestone on Apr 10, 2023
@ryan597 force-pushed the bug/17304_load_checkpoint_location branch from e67b8c2 to 6f924c3 on April 11, 2023 21:13
@Borda changed the title Fix load_from_checkpoint to return model on correct device on Apr 14, 2023
@mergify bot removed the has conflicts label on Apr 14, 2023
@mergify bot added the ready (PRs ready to be merged) label on Apr 14, 2023
@carmocca enabled auto-merge (squash) on April 14, 2023 12:27
ryan597 and others added 9 commits April 14, 2023 13:56
auto-merge was automatically disabled April 14, 2023 12:56

Head branch was pushed to by a user without write access

@ryan597 force-pushed the bug/17304_load_checkpoint_location branch from d2caff6 to 3cc815d on April 14, 2023 12:56
@Borda enabled auto-merge (squash) on April 14, 2023 21:54
@Borda merged commit e1ce887 into Lightning-AI:master on Apr 15, 2023
Borda and lantiga pushed commits referencing this pull request on Apr 24-26, 2023 (cherry picked from commit e1ce887).
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Labels
bug (Something isn't working), community (This PR is from the community), pl (Generic label for PyTorch Lightning package), ready (PRs ready to be merged)

Development
Successfully merging this pull request may close these issues: load_from_checkpoint does not work as expected (#17304)