Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate early_stop_callback Trainer argument (part 2) #3845

Merged
merged 6 commits into from
Oct 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
current = logs.get(self.monitor)

# when in dev debugging
trainer.dev_debugger.track_early_stopping_history(current)
trainer.dev_debugger.track_early_stopping_history(self, current)

if not isinstance(current, torch.Tensor):
current = torch.tensor(current, device=pl_module.device)
Expand Down
7 changes: 3 additions & 4 deletions pytorch_lightning/utilities/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,14 @@ def track_pbar_metrics_history(self, metrics):
self.pbar_added_metrics.append(metrics)

@enabled_only
def track_early_stopping_history(self, current):
es = self.trainer.early_stop_callback
def track_early_stopping_history(self, callback, current):
debug_dict = {
'epoch': self.trainer.current_epoch,
'global_step': self.trainer.global_step,
'rank': self.trainer.global_rank,
'current': current,
'best': es.best_score,
'patience': es.wait_count
'best': callback.best_score,
'patience': callback.wait_count
}
self.early_stopping_history.append(debug_dict)

Expand Down
3 changes: 2 additions & 1 deletion tests/backends/test_ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import tests.base.develop_pipelines as tpipes
import tests.base.develop_utils as tutils
from pytorch_lightning.callbacks import EarlyStopping
from tests.base import EvalModelTemplate
from pytorch_lightning.core import memory
from pytorch_lightning.trainer import Trainer
Expand All @@ -15,7 +16,7 @@ def test_multi_gpu_early_stop_ddp_spawn(tmpdir):

trainer_options = dict(
default_root_dir=tmpdir,
early_stop_callback=True,
callbacks=[EarlyStopping()],
max_epochs=50,
limit_train_batches=10,
limit_val_batches=10,
Expand Down
2 changes: 1 addition & 1 deletion tests/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_early_stopping_no_extraneous_invocations(tmpdir):
expected_count = 4
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=True,
callbacks=[EarlyStopping()],
val_check_interval=1.0,
max_epochs=expected_count,
)
Expand Down
3 changes: 2 additions & 1 deletion tests/models/test_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import tests.base.develop_pipelines as tpipes
import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.core import memory
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand All @@ -25,7 +26,7 @@ def test_multi_gpu_early_stop_dp(tmpdir):

trainer_options = dict(
default_root_dir=tmpdir,
early_stop_callback=True,
callbacks=[EarlyStopping()],
max_epochs=50,
limit_train_batches=10,
limit_val_batches=10,
Expand Down
5 changes: 3 additions & 2 deletions tests/models/test_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tests.base.develop_pipelines as tpipes
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.accelerators import TPUBackend
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.base.datasets import TrialMNIST
Expand Down Expand Up @@ -155,7 +156,7 @@ def test_model_tpu_early_stop(tmpdir):
"""Test if single TPU core training works"""
model = EvalModelTemplate()
trainer = Trainer(
early_stop_callback=True,
callbacks=[EarlyStopping()],
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=50,
Expand Down Expand Up @@ -261,7 +262,7 @@ def test_result_obj_on_tpu(tmpdir):
trainer_options = dict(
default_root_dir=tmpdir,
max_epochs=epochs,
early_stop_callback=True,
callbacks=[EarlyStopping()],
row_log_interval=2,
limit_train_batches=batches,
weights_summary=None,
Expand Down