Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix last checkpoint finding in filtered files with correct extension #17072

Merged
merged 14 commits into from
Nov 21, 2023
Prev Previous commit
Next Next commit
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Nov 19, 2023
commit 739231cc5a8b85013105281e3b4aabf0c793c8ad
15 changes: 9 additions & 6 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1511,12 +1511,15 @@ def test_resume_and_old_checkpoint_files_remain(same_resume_folder, tmp_path):
assert set(os.listdir(second)) == {"epoch=0-step=6.ckpt", "epoch=0-step=8.ckpt"}


@pytest.mark.parametrize("folder_contents, expected", [
([], []),
(["last"], []),
(["last", "last.ckpt"], ["last.ckpt"]),
(["log.txt", "last-v0.ckpt", "last-v1.ckpt"], ["last-v0.ckpt", "last-v1.ckpt"]),
])
@pytest.mark.parametrize(
"folder_contents, expected",
[
([], []),
(["last"], []),
(["last", "last.ckpt"], ["last.ckpt"]),
(["log.txt", "last-v0.ckpt", "last-v1.ckpt"], ["last-v0.ckpt", "last-v1.ckpt"]),
],
)
def test_find_last_checkpoints(folder_contents, expected, tmp_path):
for file in folder_contents:
(tmp_path / file).touch()