From a61cc72f82cefd45d2e7f4320a823f51e2ffd477 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 7 Sep 2021 18:16:24 +0200 Subject: [PATCH] Merge pull request #9347 from PyTorchLightning/bugfix/timer-on-train-end fix signature in callbacks to prevent deprecation warning --- CHANGELOG.md | 5 +++++ pytorch_lightning/callbacks/stochastic_weight_avg.py | 2 +- pytorch_lightning/callbacks/timer.py | 4 +++- tests/callbacks/test_timer.py | 5 ++++- 4 files changed, 13 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 21c2c1e5e3cb0..493be1077398e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## [1.4.6] - unreleased + +- Fixed signature of `Timer.on_train_epoch_end` and `StochasticWeightAveraging.on_train_epoch_end` to prevent unwanted deprecation warnings ([#9347](https://github.com/PyTorchLightning/pytorch-lightning/pull/9347)) + + ## [1.4.5] - 2021-08-31 - Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142)) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 28c19944ebd37..dbf173ccb6e51 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -216,7 +216,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo trainer.accumulate_grad_batches = trainer.num_training_batches - def on_train_epoch_end(self, trainer: "pl.Trainer", *args): + def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", unused: Optional = None): trainer.fit_loop._skip_backward = False def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): diff --git a/pytorch_lightning/callbacks/timer.py b/pytorch_lightning/callbacks/timer.py index f68ddb8611264..aff6b917096b5 100644 --- a/pytorch_lightning/callbacks/timer.py +++ b/pytorch_lightning/callbacks/timer.py @@ -148,7 +148,9 @@ def on_train_batch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None: return self._check_time_remaining(trainer) - def on_train_epoch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None: + def on_train_epoch_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", unused: Optional = None + ) -> None: if self._interval != Interval.epoch or self._duration is None: return self._check_time_remaining(trainer) diff --git a/tests/callbacks/test_timer.py b/tests/callbacks/test_timer.py index c7b636d3f843a..4d93733eea36e 100644 --- a/tests/callbacks/test_timer.py +++ b/tests/callbacks/test_timer.py @@ -24,6 +24,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel from tests.helpers.runif import RunIf +from tests.helpers.utils import no_warning_call def test_trainer_flag(caplog): @@ -106,7 +107,9 @@ def test_timer_stops_training(tmpdir, caplog): timer = Timer(duration=duration) trainer = Trainer(default_root_dir=tmpdir, max_epochs=1000, callbacks=[timer]) - with caplog.at_level(logging.INFO): + with caplog.at_level(logging.INFO), no_warning_call( + DeprecationWarning, match="The signature of `Callback.on_train_epoch_end` has changed in v1.3" + ): trainer.fit(model) assert trainer.global_step > 1 assert trainer.current_epoch < 999