ref: final inner eval loop hooks (#3154)
* final inner eval loop hooks

* final inner eval loop hooks
williamFalcon authored Aug 25, 2020
1 parent 6f634a1 commit 22b9642
Showing 2 changed files with 125 additions and 149 deletions.
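For orientation, the hook order this refactor settles on can be sketched as a small, self-contained toy condensed from the two diffs below. The class and function bodies are stand-ins for reading purposes only (no accelerator handling, sanity checking, per-batch start/end hooks, step_end, or Result objects); only the method names and call order mirror the diff.

class ToyEvaluationLoop:
    def __init__(self):
        self.outputs = []

    def setup(self, max_batches, dataloaders):
        # bookkeeping only in the refactored version; eval mode / no-grad moved to the caller
        self.outputs = []

    def on_evaluation_epoch_start(self):
        print("on_evaluation_epoch_start")

    def evaluation_step(self, batch, batch_idx, dataloader_idx):
        print(f"evaluation_step dl={dataloader_idx} batch={batch_idx}")
        return {"val_loss": float(batch)}

    def evaluation_epoch_end(self, num_dataloaders):
        # with a single dataloader don't pass a list, as in the diff
        return self.outputs if num_dataloaders > 1 else self.outputs[0]

    def on_evaluation_epoch_end(self, eval_results):
        print("on_evaluation_epoch_end:", eval_results)


def toy_evaluate(dataloaders, max_batches=2):
    loop = ToyEvaluationLoop()
    loop.setup(max_batches, dataloaders)
    loop.on_evaluation_epoch_start()
    for dataloader_idx, dataloader in enumerate(dataloaders):
        dl_outputs = []
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= max_batches:
                break
            dl_outputs.append(loop.evaluation_step(batch, batch_idx, dataloader_idx))
        loop.outputs.append(dl_outputs)
    eval_results = loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))
    loop.on_evaluation_epoch_end(eval_results)
    return eval_results


toy_evaluate([[1.0, 2.0, 3.0], [4.0, 5.0]])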
116 changes: 107 additions & 9 deletions pytorch_lightning/trainer/evaluate_loop.py
@@ -2,6 +2,7 @@
from pytorch_lightning.trainer.supporters import PredictionCollection
from pytorch_lightning.core.step_result import Result, EvalResult
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import flatten_dict


class EvaluationLoop(object):
@@ -18,16 +19,9 @@ def is_using_eval_results(self):
return using_eval_result

def setup(self, model, max_batches, dataloaders):
# enable eval mode
model.zero_grad()
model.eval()

# copy properties for forward overrides
self.trainer.copy_trainer_model_properties(model)

# disable gradients to save memory
torch.set_grad_enabled(False)

# bookkeeping
self.outputs = []
self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)
@@ -85,6 +79,103 @@ def evaluation_step_end(self, *args, **kwargs):
output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
return output

def evaluation_epoch_end(self, num_dataloaders):
using_eval_result = self.is_using_eval_results()

# call the model epoch end
eval_results = self.__run_eval_epoch_end(num_dataloaders, using_eval_result)
return eval_results

def log_epoch_metrics(self, eval_results):
using_eval_result = self.is_using_eval_results()
if using_eval_result:
if isinstance(eval_results, list):
for eval_result in eval_results:
self.trainer.callback_metrics = eval_result.callback_metrics
else:
self.trainer.callback_metrics = eval_results.callback_metrics
else:
if isinstance(eval_results, list):
for eval_result in eval_results:
# with a scalar return, auto set it to "val_loss" for callbacks
if isinstance(eval_result, torch.Tensor):
flat = {'val_loss': eval_result}
else:
flat = flatten_dict(eval_result)
self.trainer.callback_metrics.update(flat)
else:
# with a scalar return, auto set it to "val_loss" for callbacks
if isinstance(eval_results, torch.Tensor):
flat = {'val_loss': eval_results}
else:
flat = flatten_dict(eval_results)
self.trainer.callback_metrics.update(flat)

def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
model = self.trainer.get_model()

# with a single dataloader don't pass an array
outputs = self.outputs
eval_results = outputs
if num_dataloaders == 1:
eval_results = outputs[0]

user_reduced = False

if self.testing:
if self.trainer.is_overridden('test_epoch_end', model=model):
if using_eval_result:
eval_results = self.__gather_epoch_end_eval_results(outputs)

eval_results = model.test_epoch_end(eval_results)
user_reduced = True

else:
if self.trainer.is_overridden('validation_epoch_end', model=model):
if using_eval_result:
eval_results = self.__gather_epoch_end_eval_results(outputs)

eval_results = model.validation_epoch_end(eval_results)
user_reduced = True

if using_eval_result and not user_reduced:
eval_results = self.__auto_reduce_result_objs(outputs)

if not isinstance(eval_results, list):
eval_results = [eval_results]

return eval_results

def __gather_epoch_end_eval_results(self, outputs):
eval_results = []
for epoch_output in outputs:
result = epoch_output[0].__class__.gather(epoch_output)
if 'checkpoint_on' in result:
result.checkpoint_on = result.checkpoint_on.mean()
if 'early_stop_on' in result:
result.early_stop_on = result.early_stop_on.mean()

eval_results.append(result)

# with 1 dataloader don't pass in a list
if len(eval_results) == 1:
eval_results = eval_results[0]
return eval_results

def __auto_reduce_result_objs(self, outputs):
# outputs has a list of results per dataloader
eval_results = []
for dl_output in outputs:
result = dl_output[0]
result = result.__class__.reduce_on_epoch_end(dl_output)
if 'checkpoint_on' in result:
result.checkpoint_on = result.checkpoint_on.mean()
if 'early_stop_on' in result:
result.early_stop_on = result.early_stop_on.mean()
eval_results.append(result)

return eval_results

def on_evaluation_batch_start(self, *args, **kwargs):
if self.testing:
self.trainer.call_hook('on_test_batch_start', *args, **kwargs)
@@ -107,13 +198,20 @@ def evaluation_batch_end_cleanup(self, output, batch_idx, dataloader_idx):
# track debug metrics
self.trainer.dev_debugger.track_eval_loss_history(self.testing, batch_idx, dataloader_idx, output)

def on_evaluation_epoch_end(self, *args, **kwargs):
def on_evaluation_epoch_end(self, eval_results, *args, **kwargs):
# log epoch level metrics
self.log_epoch_metrics(eval_results)

# Write predictions to disk if they're available
self.predictions.to_disk()

# call the callback hook
if self.testing:
self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
else:
self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)

def log_metrics(self, output, batch_idx):
def log_step_metrics(self, output, batch_idx):
if self.trainer.running_sanity_check:
return

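The log_epoch_metrics method added above is what now populates trainer.callback_metrics at the end of an eval epoch. A minimal sketch of the mapping it applies for the non-EvalResult branch, with a plain dict standing in for the trainer and a hypothetical flatten() helper standing in for pytorch_lightning.utilities.flatten_dict (assumed here to collapse nested dicts into their leaf entries; the real helper is not shown in this diff):

import torch


def flatten(d, out=None):
    # assumed behaviour of flatten_dict: collapse nested dicts into a flat dict of leaves
    out = {} if out is None else out
    for key, value in d.items():
        if isinstance(value, dict):
            flatten(value, out)
        else:
            out[key] = value
    return out


def to_callback_metrics(eval_result):
    # mirrors the non-EvalResult branch of log_epoch_metrics above
    if isinstance(eval_result, torch.Tensor):
        return {'val_loss': eval_result}   # a bare scalar is auto-named val_loss
    return flatten(eval_result)


callback_metrics = {}
callback_metrics.update(to_callback_metrics(torch.tensor(0.25)))
callback_metrics.update(to_callback_metrics({'log': {'val_acc': 0.91}, 'val_loss': 0.25}))
print(callback_metrics)   # {'val_loss': 0.25, 'val_acc': 0.91}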
158 changes: 18 additions & 140 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -248,17 +248,22 @@ def _evaluate(
# set up the loop for val/test
self.evaluation_loop.testing = test_mode

# enable eval mode + no grads
model.zero_grad()
model.eval()
torch.set_grad_enabled(False)

# set up the eval loop
self.evaluation_loop.setup(model, max_batches, dataloaders)

# hook
self.evaluation_loop.on_evaluation_epoch_start()

# run validation
# run validation/testing
for dataloader_idx, dataloader in enumerate(dataloaders):
dl_outputs = []

# on TPU we have to wrap it under the ParallelLoader
# certain accelerators need to process the dataloader
dataloader = self.accelerator_backend.process_dataloader(dataloader)

# each dataloader has a max num batches
@@ -272,163 +277,36 @@
if batch_idx >= dl_max_batches:
break

# val loop hooks
# hook
self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)

# lightning module methods
output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
output = self.evaluation_loop.evaluation_step_end(output)

# hook
self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)

# clean up
self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
self.evaluation_loop.log_metrics(output, batch_idx)
self.evaluation_loop.log_step_metrics(output, batch_idx)

# track epoch level metrics
if output is not None:
dl_outputs.append(output)

self.evaluation_loop.outputs.append(dl_outputs)

# ---------------------
# EVAL_EPOCH_END
# ---------------------
using_eval_result = self.evaluation_loop.is_using_eval_results()
eval_results = self.__run_eval_epoch_end(
test_mode,
self.evaluation_loop.outputs,
dataloaders,
using_eval_result
)

# log callback metrics
self.__update_callback_metrics(eval_results, using_eval_result)
# lightning module method
eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))

# Write predictions to disk if they're available.
self.evaluation_loop.predictions.to_disk()
# hook
self.evaluation_loop.on_evaluation_epoch_end(eval_results)

# enable train mode again
model.train()

# re-enable gradients
torch.set_grad_enabled(True)

# --------------------------
# ON_EVAL_EPOCH_END hook
# --------------------------
self.evaluation_loop.on_evaluation_epoch_end()

return eval_results

def __update_callback_metrics(self, eval_results, using_eval_result):
if using_eval_result:
if isinstance(eval_results, list):
for eval_result in eval_results:
self.callback_metrics = eval_result.callback_metrics
else:
self.callback_metrics = eval_results.callback_metrics
else:
if isinstance(eval_results, list):
for eval_result in eval_results:
# with a scalar return, auto set it to "val_loss" for callbacks
if isinstance(eval_result, torch.Tensor):
flat = {'val_loss': eval_result}
else:
flat = flatten_dict(eval_result)
self.callback_metrics.update(flat)
else:
# with a scalar return, auto set it to "val_loss" for callbacks
if isinstance(eval_results, torch.Tensor):
flat = {'val_loss': eval_results}
else:
flat = flatten_dict(eval_results)
self.callback_metrics.update(flat)

def __run_eval_epoch_end(self, test_mode, outputs, dataloaders, using_eval_result):
model = self.get_model()

# with a single dataloader don't pass an array
eval_results = outputs
if len(dataloaders) == 1:
eval_results = outputs[0]

user_reduced = False

if test_mode:
if self.is_overridden('test_end', model=model):
# TODO: remove in v1.0.0
if using_eval_result:
eval_results = self.__gather_epoch_end_eval_results(outputs)

eval_results = model.test_end(eval_results)
user_reduced = True
rank_zero_warn(
'Method `test_end` was deprecated in v0.7 and will be removed in v1.0.'
' Use `test_epoch_end` instead.',
DeprecationWarning,
)

elif self.is_overridden('test_epoch_end', model=model):
if using_eval_result:
eval_results = self.__gather_epoch_end_eval_results(outputs)

eval_results = model.test_epoch_end(eval_results)
user_reduced = True

else:
if self.is_overridden('validation_end', model=model):
# TODO: remove in v1.0.0
if using_eval_result:
eval_results = self.__gather_epoch_end_eval_results(outputs)

eval_results = model.validation_end(eval_results)
user_reduced = True
rank_zero_warn(
'Method `validation_end` was deprecated in v0.7 and will be removed in v1.0.'
' Use `validation_epoch_end` instead.',
DeprecationWarning,
)

elif self.is_overridden('validation_epoch_end', model=model):
if using_eval_result:
eval_results = self.__gather_epoch_end_eval_results(outputs)

eval_results = model.validation_epoch_end(eval_results)
user_reduced = True

if using_eval_result and not user_reduced:
eval_results = self.__auto_reduce_result_objs(outputs)

if not isinstance(eval_results, list):
eval_results = [eval_results]

return eval_results

def __gather_epoch_end_eval_results(self, outputs):
eval_results = []
for epoch_output in outputs:
result = epoch_output[0].__class__.gather(epoch_output)
if 'checkpoint_on' in result:
result.checkpoint_on = result.checkpoint_on.mean()
if 'early_stop_on' in result:
result.early_stop_on = result.early_stop_on.mean()

eval_results.append(result)

# with 1 dataloader don't pass in a list
if len(eval_results) == 1:
eval_results = eval_results[0]
return eval_results

def __auto_reduce_result_objs(self, outputs):
# outputs has a list of results per dataloader
eval_results = []
for dl_output in outputs:
result = dl_output[0]
result = result.__class__.reduce_on_epoch_end(dl_output)
if 'checkpoint_on' in result:
result.checkpoint_on = result.checkpoint_on.mean()
if 'early_stop_on' in result:
result.early_stop_on = result.early_stop_on.mean()
eval_results.append(result)

return eval_results

def run_evaluation(self, test_mode: bool = False):

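For reference, both reduction helpers that moved into EvaluationLoop (__gather_epoch_end_eval_results and __auto_reduce_result_objs) finish by averaging checkpoint_on and early_stop_on across batches. A toy illustration of just that averaging step, using a hypothetical dict-based stand-in for the EvalResult object (the real gather / reduce_on_epoch_end do considerably more than this):

import torch


class ToyResult(dict):
    # hypothetical stand-in for pytorch_lightning.core.step_result.EvalResult
    @classmethod
    def gather(cls, outputs):
        # stack every key across the per-batch results
        gathered = cls()
        for key in outputs[0]:
            gathered[key] = torch.stack([out[key] for out in outputs])
        return gathered


per_batch = [
    ToyResult(checkpoint_on=torch.tensor(0.4), early_stop_on=torch.tensor(0.4)),
    ToyResult(checkpoint_on=torch.tensor(0.2), early_stop_on=torch.tensor(0.2)),
]

result = ToyResult.gather(per_batch)
if 'checkpoint_on' in result:
    result['checkpoint_on'] = result['checkpoint_on'].mean()   # as in the diff
if 'early_stop_on' in result:
    result['early_stop_on'] = result['early_stop_on'].mean()

print(result['checkpoint_on'], result['early_stop_on'])   # tensor(0.3000) tensor(0.3000)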