
Commit

ref: remove _evaluate fx (#3197)
* remove _evaluate

* remove _evaluate

* remove _evaluate

* remove _evaluate

* remove _evaluate

* remove _evaluate

* remove _evaluate

* remove _evaluate
williamFalcon authored Aug 26, 2020
1 parent d9ea255 commit a170544
Showing 8 changed files with 36 additions and 106 deletions.
6 changes: 6 additions & 0 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -142,9 +142,15 @@ def load_state_dict(self, state_dict):
         self.patience = state_dict['patience']
 
     def on_validation_end(self, trainer, pl_module):
+        if trainer.running_sanity_check:
+            return
+
         self._run_early_stopping_check(trainer, pl_module)
 
     def on_validation_epoch_end(self, trainer, pl_module):
+        if trainer.running_sanity_check:
+            return
+
         val_es_key = 'val_early_stop_on'
         if trainer.callback_metrics.get(val_es_key) is not None:
             self.monitor = val_es_key
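Both hooks now bail out while the trainer is running its pre-training sanity check. A minimal sketch of the same guard in a user-defined callback, assuming only the `trainer.running_sanity_check` flag and the hook signatures shown above (the class name is invented for illustration):

from pytorch_lightning.callbacks import Callback

class SkipDuringSanityCheck(Callback):
    """Illustrative sketch: ignore validation results produced by the sanity check."""

    def on_validation_end(self, trainer, pl_module):
        if trainer.running_sanity_check:
            return
        # real work goes here, e.g. reading trainer.callback_metrics
        print(trainer.callback_metrics)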
5 changes: 4 additions & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
@@ -269,7 +269,7 @@ def format_checkpoint_name(self, epoch, metrics, ver=None):
         return filepath
 
     @rank_zero_only
-    def on_train_start(self, trainer, pl_module):
+    def on_pretrain_routine_start(self, trainer, pl_module):
         """
         Determines model checkpoint save directory at runtime. References attributes from the
         trainer's logger to determine where to save checkpoints.
@@ -330,6 +330,9 @@ def on_validation_end(self, trainer, pl_module):
         if trainer.global_rank != 0:
             return
 
+        if trainer.running_sanity_check:
+            return
+
         # TODO: remove when dict results are deprecated
         self.__warn_deprecated_monitor_key()
 
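The directory-resolution logic also moves from `on_train_start` to `on_pretrain_routine_start`, so the checkpoint path is determined before training begins rather than at train start. A rough probe of that ordering, assuming both hooks are exposed on the `Callback` base class (class name invented for illustration):

from pytorch_lightning.callbacks import Callback

class HookOrderProbe(Callback):
    """Illustrative sketch: record the order in which the two hooks fire."""

    def __init__(self):
        self.calls = []

    def on_pretrain_routine_start(self, trainer, pl_module):
        # ModelCheckpoint now resolves its save directory at this earlier point
        self.calls.append('on_pretrain_routine_start')

    def on_train_start(self, trainer, pl_module):
        # previously the directory was only resolved here
        self.calls.append('on_train_start')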
9 changes: 6 additions & 3 deletions pytorch_lightning/trainer/evaluate_loop.py
@@ -13,7 +13,7 @@ def __init__(self, trainer):
         self.predictions = None
         self.max_batches = None
 
-    def get_evaluation_dataloaders(self):
+    def get_evaluation_dataloaders(self, max_batches):
         # select dataloaders
         model = self.trainer.get_model()
 
@@ -22,14 +22,17 @@ def get_evaluation_dataloaders(self):
             self.trainer.reset_test_dataloader(model)
 
             dataloaders = self.trainer.test_dataloaders
-            max_batches = self.trainer.num_test_batches
+            new_max_batches = self.trainer.num_test_batches
         else:
             # val
             if self.trainer.val_dataloaders is None:
                 self.trainer.reset_val_dataloader(model)
 
             dataloaders = self.trainer.val_dataloaders
-            max_batches = self.trainer.num_val_batches
+            new_max_batches = self.trainer.num_val_batches
+
+        if max_batches is None:
+            max_batches = new_max_batches
 
         return dataloaders, max_batches
 
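`get_evaluation_dataloaders` now accepts an optional cap: an explicit `max_batches` (for example the sanity-check limit) takes precedence, otherwise the trainer's configured `num_val_batches` / `num_test_batches` is used. The precedence in isolation, as a small sketch (the helper name is invented for illustration):

def resolve_max_batches(requested, trainer_default):
    # mirrors the fallback added above: a caller-supplied cap wins,
    # otherwise fall back to the trainer's own limit
    return trainer_default if requested is None else requested

assert resolve_max_batches(None, 50) == 50   # normal validation/test run
assert resolve_max_batches(2, 50) == 2       # e.g. a sanity check capped at 2 batches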
96 changes: 9 additions & 87 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -229,92 +229,10 @@ def reset_val_dataloader(self, *args):
     def call_hook(self, hook_name, *args, **kwargs):
         """Warning: this is just empty shell for code implemented in other class."""
 
-    def _evaluate(
-        self,
-        model: LightningModule,
-        dataloaders: List[DataLoader],
-        max_batches: Union[int, List[int]],
-        test_mode: bool = False,
-    ):
-        """Run evaluation code.
-        Args:
-            model: The model to evaluate.
-            dataloaders: A list of PyTorch dataloaders.
-            max_batches: An integer or list of integers with length of the number of dataloaders. Each
-                entry is the number of batches to process in the corresponding dataloader.
-            test_mode:
-        """
-
-        # enable eval mode + no grads
-        model.zero_grad()
-        model.eval()
-        torch.set_grad_enabled(False)
-
-        # set up the eval loop
-        self.evaluation_loop.setup(model, max_batches, dataloaders)
-
-        # hook
-        self.evaluation_loop.on_evaluation_epoch_start()
-
-        # run validation/testing
-        for dataloader_idx, dataloader in enumerate(dataloaders):
-            dl_outputs = []
-
-            # certain accelerators need to process the dataloader
-            dataloader = self.accelerator_backend.process_dataloader(dataloader)
-
-            # each dataloader has a max num batches
-            dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]
-
-            for batch_idx, batch in enumerate(dataloader):
-                if batch is None:
-                    continue
-
-                # stop short when running on limited batches
-                if batch_idx >= dl_max_batches:
-                    break
-
-                # hook
-                self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
-
-                # lightning module methods
-                output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
-                output = self.evaluation_loop.evaluation_step_end(output)
-
-                # hook
-                self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)
-
-                # clean up
-                self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
-                self.evaluation_loop.log_step_metrics(output, batch_idx)
-
-                # track epoch level metrics
-                if output is not None:
-                    dl_outputs.append(output)
-
-            self.evaluation_loop.outputs.append(dl_outputs)
-
-        # lightning module method
-        eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))
-
-        # log epoch level metrics
-        self.evaluation_loop.log_epoch_metrics(eval_results)
-        self.evaluation_loop.predictions.to_disk()
-
-        # hook
-        self.evaluation_loop.on_evaluation_epoch_end()
-
-        # enable train mode again
-        model.train()
-        torch.set_grad_enabled(True)
-
-        return eval_results
-
-    def run_evaluation(self, test_mode: bool = False):
+    def run_evaluation(self, test_mode: bool = False, max_batches=None):
         # bookkeeping
         self.evaluation_loop.testing = test_mode
-        dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders()
+        dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches)
         if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches):
             return [], []
 
@@ -333,11 +251,12 @@ def run_evaluation(self, test_mode: bool = False):
         # set up the eval loop
         self.evaluation_loop.setup(model, max_batches, dataloaders)
 
+        # hook
+        # TODO: needs to move inside the loop but breaks early stopping
+        self.evaluation_loop.on_evaluation_epoch_start()
+
         # run validation/testing
         for dataloader_idx, dataloader in enumerate(dataloaders):
-            # hook
-            self.evaluation_loop.on_evaluation_epoch_start()
-
             # bookkeeping
             dl_outputs = []
             dataloader = self.accelerator_backend.process_dataloader(dataloader)
@@ -402,6 +321,9 @@ def run_evaluation(self, test_mode: bool = False):
         return eval_loop_results, eval_results
 
     def __log_evaluation_epoch_metrics(self, eval_results, test_mode):
+        if self.running_sanity_check:
+            return
+
         eval_loop_results = []
         if eval_results is not None and len(eval_results) > 0:
 
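The per-dataloader batch cap that the removed `_evaluate` documented (an int, or a list of ints with one entry per dataloader) survives unchanged inside `run_evaluation`. A self-contained sketch of that capping behaviour, with an invented helper name purely for illustration:

def iter_capped(dataloaders, max_batches):
    # max_batches: an int, or a list of ints with one entry per dataloader,
    # as the removed _evaluate docstring describes it
    if isinstance(max_batches, int):
        max_batches = [max_batches] * len(dataloaders)
    for dataloader_idx, dataloader in enumerate(dataloaders):
        for batch_idx, batch in enumerate(dataloader):
            # stop short when running on limited batches
            if batch_idx >= max_batches[dataloader_idx]:
                break
            yield dataloader_idx, batch_idx, batch

# two toy "dataloaders" of 10 batches each, capped at 2 and 3 batches respectively
assert len(list(iter_capped([range(10), range(10)], [2, 3]))) == 5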
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -1266,7 +1266,8 @@ def _run_sanity_check(self, ref_model, model):
         self.running_sanity_check = True
         self.on_sanity_check_start()
 
-        eval_results = self._evaluate(model, self.val_dataloaders, self.num_sanity_val_batches, False)
+        # run eval step
+        _, eval_results = self.run_evaluation(test_mode=False, max_batches=self.num_sanity_val_batches)
 
         # allow no returns from eval
         if eval_results is not None and len(eval_results) > 0:
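The sanity check is therefore just a capped call into the normal evaluation path. On the user side, the number of sanity-check batches is controlled from the `Trainer` constructor, which is also how the updated early-stopping test silences it; a small usage sketch:

from pytorch_lightning import Trainer

# Illustrative only: 0 disables the pre-training sanity run entirely
# (as tests/callbacks/test_early_stopping.py now does); a positive value caps it.
trainer = Trainer(num_sanity_val_steps=0)
trainer = Trainer(num_sanity_val_steps=2)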
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/debugging.py
@@ -39,7 +39,7 @@ class InternalDebugger(object):
 
     def __init__(self, trainer):
 
-        self.enabled = 'PL_DEV_DEBUG' in os.environ
+        self.enabled = os.environ.get('PL_DEV_DEBUG', '0') == '1'
         self.trainer = trainer
         self.logged_metrics = []
         self.pbar_added_metrics = []
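The internal debugger is now enabled only when the variable is explicitly set to '1'; merely having `PL_DEV_DEBUG` present (even as '0', as the logger tests set it below) no longer arms it. The behavioural difference in isolation:

import os

os.environ['PL_DEV_DEBUG'] = '0'
old_enabled = 'PL_DEV_DEBUG' in os.environ                 # True -> armed by accident
new_enabled = os.environ.get('PL_DEV_DEBUG', '0') == '1'  # False -> stays off
assert old_enabled and not new_enabled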
19 changes: 6 additions & 13 deletions tests/callbacks/test_early_stopping.py
@@ -1,3 +1,4 @@
+import os
 from copy import deepcopy
 import pickle
 
@@ -42,6 +43,7 @@ def on_train_start(self, trainer, pl_module):
         default_root_dir=tmpdir,
         checkpoint_callback=checkpoint_callback,
         early_stop_callback=early_stop_callback,
+        num_sanity_val_steps=0,
         max_epochs=4,
     )
     trainer.fit(model)
@@ -67,29 +69,20 @@ def on_train_start(self, trainer, pl_module):
 
 def test_early_stopping_no_extraneous_invocations(tmpdir):
     """Test to ensure that callback methods aren't being invoked outside of the callback handler."""
-    class EarlyStoppingTestInvocations(EarlyStopping):
-        def __init__(self, expected_count):
-            super().__init__()
-            self.count = 0
-            self.expected_count = expected_count
-
-        def on_validation_end(self, trainer, pl_module):
-            self.count += 1
-
-        def on_train_end(self, trainer, pl_module):
-            assert self.count == self.expected_count
+    os.environ['PL_DEV_DEBUG'] = '1'
 
     model = EvalModelTemplate()
     expected_count = 4
-    early_stop_callback = EarlyStoppingTestInvocations(expected_count)
     trainer = Trainer(
         default_root_dir=tmpdir,
-        early_stop_callback=early_stop_callback,
+        early_stop_callback=True,
         val_check_interval=1.0,
         max_epochs=expected_count,
     )
     trainer.fit(model)
 
+    assert len(trainer.dev_debugger.early_stopping_history) == expected_count
+
 
 @pytest.mark.parametrize('loss_values, patience, expected_stop_epoch', [
     ([6, 5, 5, 5, 5, 5], 3, 4),
2 changes: 2 additions & 0 deletions tests/loggers/test_all.py
@@ -43,6 +43,8 @@ def _get_logger_args(logger_class, save_dir):
 @mock.patch('pytorch_lightning.loggers.wandb.wandb')
 def test_loggers_fit_test(wandb, tmpdir, monkeypatch, logger_class):
     """Verify that basic functionality of all loggers."""
+    os.environ['PL_DEV_DEBUG'] = '0'
+
     if logger_class == CometLogger:
         # prevent comet logger from trying to print at exit, since
         # pytest's stdout/stderr redirection breaks it
