Skip to content

Commit

Permalink
Merge pull request #9347 from PyTorchLightning/bugfix/timer-on-train-end
Browse files Browse the repository at this point in the history
fix signature in callbacks to prevent deprecation warning
  • Loading branch information
awaelchli authored Sep 7, 2021
1 parent 645eabe commit a61cc72
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 3 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [1.4.6] - unreleased

- Fixed signature of `Timer.on_train_epoch_end` and `StochasticWeightAveraging.on_train_epoch_end` to prevent unwanted deprecation warnings ([#9347](https://github.com/PyTorchLightning/pytorch-lightning/pull/9347))


## [1.4.5] - 2021-08-31

- Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/stochastic_weight_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo

trainer.accumulate_grad_batches = trainer.num_training_batches

def on_train_epoch_end(self, trainer: "pl.Trainer", *args):
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", unused: Optional = None):
trainer.fit_loop._skip_backward = False

def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
Expand Down
4 changes: 3 additions & 1 deletion pytorch_lightning/callbacks/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ def on_train_batch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None:
return
self._check_time_remaining(trainer)

def on_train_epoch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None:
def on_train_epoch_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", unused: Optional = None
) -> None:
if self._interval != Interval.epoch or self._duration is None:
return
self._check_time_remaining(trainer)
Expand Down
5 changes: 4 additions & 1 deletion tests/callbacks/test_timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
from tests.helpers.utils import no_warning_call


def test_trainer_flag(caplog):
Expand Down Expand Up @@ -106,7 +107,9 @@ def test_timer_stops_training(tmpdir, caplog):
timer = Timer(duration=duration)

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1000, callbacks=[timer])
with caplog.at_level(logging.INFO):
with caplog.at_level(logging.INFO), no_warning_call(
DeprecationWarning, match="The signature of `Callback.on_train_epoch_end` has changed in v1.3"
):
trainer.fit(model)
assert trainer.global_step > 1
assert trainer.current_epoch < 999
Expand Down

0 comments on commit a61cc72

Please sign in to comment.