Reset all results on epoch end (#14061)
carmocca authored Aug 9, 2022
1 parent 56abd60 commit d850854
Showing 3 changed files with 31 additions and 4 deletions.
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -61,6 +61,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed the `NeptuneLogger` dependency being unrecognized ([#13988](https://github.com/Lightning-AI/lightning/pull/13988))


- Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061))


- Fixed an issue where users would be warned about unset `max_epochs` even when `fast_dev_run` was set ([#13262](https://github.com/Lightning-AI/lightning/pull/13262))


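For context on the epoch-end reset entry above, here is a minimal user-level sketch of the scenario being fixed (not part of this diff; it uses only a plain LightningModule and standard Trainer arguments): a value logged once in on_train_start should reach the logger once, rather than being carried into the epoch-end metrics of later epochs by a result collection that was never reset.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class StartLoggingModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def on_train_start(self):
        # Logged once, before the first epoch; after the fix it is not
        # re-emitted with the epoch-end metrics of later epochs.
        self.log("foo", 123)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        loss = self.layer(x).sum()
        self.log("train_sum", loss)  # logged every step, as usual
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(4, 32)), batch_size=2)


trainer = pl.Trainer(max_epochs=2, limit_val_batches=0, log_every_n_steps=1)
trainer.fit(StartLoggingModel())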
3 changes: 1 addition & 2 deletions src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -163,8 +163,7 @@ def update_train_epoch_metrics(self) -> None:
self.log_metrics(self.metrics["log"])

# reset result collection for next epoch
assert self.trainer._results is not None
self.trainer._results.reset(metrics=True)
self.reset_results()

"""
Utilities and properties
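The body of reset_results is not shown in this hunk. As a rough sketch (an assumption, not part of the diff), the connector helper plausibly amounts to the following; the practical difference from the removed lines is that reset() is called without metrics=True, which presumably clears the plain logged tensors as well as the torchmetrics state, matching the commit title "Reset all results on epoch end".

def reset_results(self) -> None:
    # Assumed LoggerConnector helper: clear everything accumulated in the
    # result collection, not only the torchmetrics objects.
    results = self.trainer._results
    if results is not None:
        results.reset()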
29 changes: 27 additions & 2 deletions tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
@@ -569,11 +569,12 @@ def on_train_epoch_end(self, trainer, pl_module):
"accelerator",
[
pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)),
"cpu",
],
)
def test_metric_are_properly_reduced(tmpdir, accelerator):
class TestingModel(BoringModel):
def __init__(self, *args, **kwargs) -> None:
def __init__(self) -> None:
super().__init__()
self.val_acc = Accuracy()

@@ -592,7 +593,6 @@ def validation_step(self, batch, batch_idx):
return super().validation_step(batch, batch_idx)

early_stop = EarlyStopping(monitor="val_acc", mode="max")

checkpoint = ModelCheckpoint(monitor="val_acc", save_last=True, save_top_k=2, mode="max")

model = TestingModel()
@@ -812,3 +812,28 @@ def training_step(self, batch, batch_idx):
call(metrics={"foo_epoch": 0.0, "epoch": 1}, step=3),
]
)


@mock.patch("pytorch_lightning.loggers.TensorBoardLogger.log_metrics")
def test_log_on_train_start(mock_log_metrics, tmpdir):
"""Tests that logged metrics on_train_start get reset after the first epoch."""

    class MyModel(BoringModel):
        def on_train_start(self):
            self.log("foo", 123)

    model = MyModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=1,
        limit_val_batches=0,
        max_epochs=2,
        log_every_n_steps=1,
        enable_model_summary=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
    )
    trainer.fit(model)

    assert mock_log_metrics.mock_calls == [call(metrics={"foo": 123.0, "epoch": 0}, step=0)]
    assert trainer.max_epochs > 1
