diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 08a79316d77a0..1e9dcf0f48a68 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -23,6 +23,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue where setting `CHECKPOINT_JOIN_CHAR` or `CHECKPOINT_EQUALS_CHAR` would only work on the `ModelCheckpoint` class but not on an instance ([#19054](https://github.com/Lightning-AI/lightning/pull/19054)) +- Fixed `ModelCheckpoint` not expanding the `dirpath` if it has the `~` (home) prefix ([#19058](https://github.com/Lightning-AI/lightning/pull/19058)) + + + ## [2.1.2] - 2023-11-15 ### Fixed diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 4bc258ac13238..6fee5abf3412f 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -458,7 +458,7 @@ def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> self._fs = get_filesystem(dirpath if dirpath else "") if dirpath and _is_local_file_protocol(dirpath if dirpath else ""): - dirpath = os.path.realpath(dirpath) + dirpath = os.path.realpath(os.path.expanduser(dirpath)) self.dirpath = dirpath self.filename = filename diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 81cc98cf504ef..66764c78303ff 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -1536,3 +1536,17 @@ def test_find_last_checkpoints(name, extension, folder_contents, expected, tmp_p callback.FILE_EXTENSION = extension files = callback._find_last_checkpoints(trainer) assert files == {str(tmp_path / p) for p in expected} + + +def test_expand_home(): + """Test that the dirpath gets expanded if it contains `~`.""" + home_root = Path.home() + + checkpoint = ModelCheckpoint(dirpath="~/checkpoints") + assert checkpoint.dirpath == str(home_root / "checkpoints") + checkpoint = ModelCheckpoint(dirpath=Path("~/checkpoints")) + assert checkpoint.dirpath == str(home_root / "checkpoints") + + # it is possible to have a folder with the name `~` + checkpoint = ModelCheckpoint(dirpath="./~/checkpoints") + assert checkpoint.dirpath == str(Path.cwd() / "~" / "checkpoints")