Skip to content

Commit

Permalink
Make ModelCheckpoint._format_checkpoint_name an instance method (#19054)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Nov 23, 2023
1 parent dbea69b commit 9a26da8
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 16 deletions.
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed automatic detection of 'last.ckpt' files to respect the extension when filtering ([#17072](https://github.com/Lightning-AI/lightning/pull/17072))


- Fixed an issue where setting `CHECKPOINT_JOIN_CHAR` or `CHECKPOINT_EQUALS_CHAR` would only work on the `ModelCheckpoint` class but not on an instance ([#19054](https://github.com/Lightning-AI/lightning/pull/19054))


## [2.1.2] - 2023-11-15

Expand Down
9 changes: 4 additions & 5 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,17 +524,16 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] =

return should_update_best_and_save

@classmethod
def _format_checkpoint_name(
cls,
self,
filename: Optional[str],
metrics: Dict[str, Tensor],
prefix: str = "",
auto_insert_metric_name: bool = True,
) -> str:
if not filename:
# filename is not set, use default name
filename = "{epoch}" + cls.CHECKPOINT_JOIN_CHAR + "{step}"
filename = "{epoch}" + self.CHECKPOINT_JOIN_CHAR + "{step}"

# check and parse user passed keys in the string
groups = re.findall(r"(\{.*?)[:\}]", filename)
Expand All @@ -547,7 +546,7 @@ def _format_checkpoint_name(
name = group[1:]

if auto_insert_metric_name:
filename = filename.replace(group, name + cls.CHECKPOINT_EQUALS_CHAR + "{" + name)
filename = filename.replace(group, name + self.CHECKPOINT_EQUALS_CHAR + "{" + name)

# support for dots: https://stackoverflow.com/a/7934969
filename = filename.replace(group, f"{{0[{name}]")
Expand All @@ -557,7 +556,7 @@ def _format_checkpoint_name(
filename = filename.format(metrics)

if prefix:
filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename])
filename = self.CHECKPOINT_JOIN_CHAR.join([prefix, filename])

return filename

Expand Down
24 changes: 13 additions & 11 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,34 +402,36 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):


def test_model_checkpoint_format_checkpoint_name(tmpdir, monkeypatch):
model_checkpoint = ModelCheckpoint(dirpath=tmpdir)

# empty filename:
ckpt_name = ModelCheckpoint._format_checkpoint_name("", {"epoch": 3, "step": 2})
ckpt_name = model_checkpoint._format_checkpoint_name("", {"epoch": 3, "step": 2})
assert ckpt_name == "epoch=3-step=2"

ckpt_name = ModelCheckpoint._format_checkpoint_name(None, {"epoch": 3, "step": 2}, prefix="test")
ckpt_name = model_checkpoint._format_checkpoint_name(None, {"epoch": 3, "step": 2}, prefix="test")
assert ckpt_name == "test-epoch=3-step=2"

# no groups case:
ckpt_name = ModelCheckpoint._format_checkpoint_name("ckpt", {}, prefix="test")
ckpt_name = model_checkpoint._format_checkpoint_name("ckpt", {}, prefix="test")
assert ckpt_name == "test-ckpt"

# no prefix
ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch:03d}-{acc}", {"epoch": 3, "acc": 0.03})
ckpt_name = model_checkpoint._format_checkpoint_name("{epoch:03d}-{acc}", {"epoch": 3, "acc": 0.03})
assert ckpt_name == "epoch=003-acc=0.03"

# one metric name is substring of another
ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch:03d}-{epoch_test:03d}", {"epoch": 3, "epoch_test": 3})
ckpt_name = model_checkpoint._format_checkpoint_name("{epoch:03d}-{epoch_test:03d}", {"epoch": 3, "epoch_test": 3})
assert ckpt_name == "epoch=003-epoch_test=003"

# prefix
monkeypatch.setattr(ModelCheckpoint, "CHECKPOINT_JOIN_CHAR", "@")
ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch},{acc:.5f}", {"epoch": 3, "acc": 0.03}, prefix="test")
model_checkpoint.CHECKPOINT_JOIN_CHAR = "@"
ckpt_name = model_checkpoint._format_checkpoint_name("{epoch},{acc:.5f}", {"epoch": 3, "acc": 0.03}, prefix="test")
assert ckpt_name == "test@epoch=3,acc=0.03000"
monkeypatch.undo()

# non-default char for equals sign
monkeypatch.setattr(ModelCheckpoint, "CHECKPOINT_EQUALS_CHAR", ":")
ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch:03d}-{acc}", {"epoch": 3, "acc": 0.03})
model_checkpoint.CHECKPOINT_EQUALS_CHAR = ":"
ckpt_name = model_checkpoint._format_checkpoint_name("{epoch:03d}-{acc}", {"epoch": 3, "acc": 0.03})
assert ckpt_name == "epoch:003-acc:0.03"
monkeypatch.undo()

Expand All @@ -454,13 +456,13 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir, monkeypatch):
assert ckpt_name == "epoch=4_val/loss=0.03000.ckpt"

# auto_insert_metric_name=False
ckpt_name = ModelCheckpoint._format_checkpoint_name(
ckpt_name = model_checkpoint._format_checkpoint_name(
"epoch={epoch:03d}-val_acc={val/acc}", {"epoch": 3, "val/acc": 0.03}, auto_insert_metric_name=False
)
assert ckpt_name == "epoch=003-val_acc=0.03"

# dots in the metric name
ckpt_name = ModelCheckpoint._format_checkpoint_name(
ckpt_name = model_checkpoint._format_checkpoint_name(
"mAP@0.50={val/mAP@0.50:.4f}", {"val/mAP@0.50": 0.2}, auto_insert_metric_name=False
)
assert ckpt_name == "mAP@0.50=0.2000"
Expand Down

0 comments on commit 9a26da8

Please sign in to comment.