Skip to content

Commit

Permalink
Fix ModelCheckpoint dirpath expanding home prefix (#19058)
Browse files Browse the repository at this point in the history
(cherry picked from commit 58c905b)
  • Loading branch information
awaelchli authored and lantiga committed Dec 20, 2023
1 parent d381b85 commit 1679458
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue where setting `CHECKPOINT_JOIN_CHAR` or `CHECKPOINT_EQUALS_CHAR` would only work on the `ModelCheckpoint` class but not on an instance ([#19054](https://github.com/Lightning-AI/lightning/pull/19054))


- Fixed `ModelCheckpoint` not expanding the `dirpath` if it has the `~` (home) prefix ([#19058](https://github.com/Lightning-AI/lightning/pull/19058))



## [2.1.2] - 2023-11-15

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/pytorch/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) ->
self._fs = get_filesystem(dirpath if dirpath else "")

if dirpath and _is_local_file_protocol(dirpath if dirpath else ""):
dirpath = os.path.realpath(dirpath)
dirpath = os.path.realpath(os.path.expanduser(dirpath))

self.dirpath = dirpath
self.filename = filename
Expand Down
14 changes: 14 additions & 0 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1536,3 +1536,17 @@ def test_find_last_checkpoints(name, extension, folder_contents, expected, tmp_p
callback.FILE_EXTENSION = extension
files = callback._find_last_checkpoints(trainer)
assert files == {str(tmp_path / p) for p in expected}


def test_expand_home():
"""Test that the dirpath gets expanded if it contains `~`."""
home_root = Path.home()

checkpoint = ModelCheckpoint(dirpath="~/checkpoints")
assert checkpoint.dirpath == str(home_root / "checkpoints")
checkpoint = ModelCheckpoint(dirpath=Path("~/checkpoints"))
assert checkpoint.dirpath == str(home_root / "checkpoints")

# it is possible to have a folder with the name `~`
checkpoint = ModelCheckpoint(dirpath="./~/checkpoints")
assert checkpoint.dirpath == str(Path.cwd() / "~" / "checkpoints")

0 comments on commit 1679458

Please sign in to comment.