Skip to content

Commit

Permalink
Deprecate early_stop_callback Trainer argument (part 2) (#3845)
Browse files Browse the repository at this point in the history
* update tests with EarlyStopping default

* imports

* revert legacy tests

* fix test

* revert

* revert
  • Loading branch information
awaelchli authored Oct 4, 2020
1 parent 6723b92 commit cc9781a
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
current = logs.get(self.monitor)

# when in dev debugging
trainer.dev_debugger.track_early_stopping_history(current)
trainer.dev_debugger.track_early_stopping_history(self, current)

if not isinstance(current, torch.Tensor):
current = torch.tensor(current, device=pl_module.device)
Expand Down
7 changes: 3 additions & 4 deletions pytorch_lightning/utilities/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,14 @@ def track_pbar_metrics_history(self, metrics):
self.pbar_added_metrics.append(metrics)

@enabled_only
def track_early_stopping_history(self, current):
es = self.trainer.early_stop_callback
def track_early_stopping_history(self, callback, current):
debug_dict = {
'epoch': self.trainer.current_epoch,
'global_step': self.trainer.global_step,
'rank': self.trainer.global_rank,
'current': current,
'best': es.best_score,
'patience': es.wait_count
'best': callback.best_score,
'patience': callback.wait_count
}
self.early_stopping_history.append(debug_dict)

Expand Down
3 changes: 2 additions & 1 deletion tests/backends/test_ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import tests.base.develop_pipelines as tpipes
import tests.base.develop_utils as tutils
from pytorch_lightning.callbacks import EarlyStopping
from tests.base import EvalModelTemplate
from pytorch_lightning.core import memory
from pytorch_lightning.trainer import Trainer
Expand All @@ -15,7 +16,7 @@ def test_multi_gpu_early_stop_ddp_spawn(tmpdir):

trainer_options = dict(
default_root_dir=tmpdir,
early_stop_callback=True,
callbacks=[EarlyStopping()],
max_epochs=50,
limit_train_batches=10,
limit_val_batches=10,
Expand Down
2 changes: 1 addition & 1 deletion tests/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_early_stopping_no_extraneous_invocations(tmpdir):
expected_count = 4
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=True,
callbacks=[EarlyStopping()],
val_check_interval=1.0,
max_epochs=expected_count,
)
Expand Down
3 changes: 2 additions & 1 deletion tests/models/test_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import tests.base.develop_pipelines as tpipes
import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.core import memory
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand All @@ -25,7 +26,7 @@ def test_multi_gpu_early_stop_dp(tmpdir):

trainer_options = dict(
default_root_dir=tmpdir,
early_stop_callback=True,
callbacks=[EarlyStopping()],
max_epochs=50,
limit_train_batches=10,
limit_val_batches=10,
Expand Down
5 changes: 3 additions & 2 deletions tests/models/test_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tests.base.develop_pipelines as tpipes
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.accelerators import TPUBackend
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.base.datasets import TrialMNIST
Expand Down Expand Up @@ -155,7 +156,7 @@ def test_model_tpu_early_stop(tmpdir):
"""Test if single TPU core training works"""
model = EvalModelTemplate()
trainer = Trainer(
early_stop_callback=True,
callbacks=[EarlyStopping()],
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=50,
Expand Down Expand Up @@ -261,7 +262,7 @@ def test_result_obj_on_tpu(tmpdir):
trainer_options = dict(
default_root_dir=tmpdir,
max_epochs=epochs,
early_stop_callback=True,
callbacks=[EarlyStopping()],
row_log_interval=2,
limit_train_batches=batches,
weights_summary=None,
Expand Down

0 comments on commit cc9781a

Please sign in to comment.