diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 065b29c75da37..3a94a531a6500 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -11,16 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os +import tempfile import collections import copy import inspect -import os import re -import tempfile from abc import ABC from argparse import Namespace -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from pytorch_lightning import _logger as log @@ -28,16 +27,17 @@ from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO +from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities import rank_zero_warn, AMPType from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities.parsing import ( AttributeDict, collect_init_args, get_init_args, ) +from pytorch_lightning.callbacks import Callback from torch import ScriptModule, Tensor from torch.nn import Module from torch.optim.optimizer import Optimizer @@ -111,6 +111,8 @@ def __init__(self, *args, **kwargs): self._datamodule = None self._results: Optional[Result] = None self._current_fx_name = '' + self._current_hook_fx_name = '' + self._current_dataloader_idx = None def optimizers(self): opts = self.trainer.optimizers @@ -244,6 +246,17 @@ def log( on_step = self.__auto_choose_log_on_step(on_step) on_epoch = self.__auto_choose_log_on_epoch(on_epoch) + if self._current_hook_fx_name != '': + self.trainer.logger_connector.callback_logging_validator\ + .validate_callback_logging_arguments(self._current_hook_fx_name, + on_step=on_step, + on_epoch=on_epoch) + + # make sure user doesn't introduce logic for multi-dataloaders + if "/dataloader_idx_" in name: + raise MisconfigurationException( + f"Logged key: {name} should not contain information about dataloader_idx.") + self._results.log( name, value, @@ -257,7 +270,8 @@ def log( enable_graph, sync_dist, sync_dist_op, - sync_dist_group + sync_dist_group, + self._current_dataloader_idx, ) def log_dict( @@ -950,7 +964,8 @@ def configure_optimizers( - Single optimizer. - List or Tuple - List of optimizers. - Two lists - The first list has multiple optimizers, the second a list of LR schedulers (or lr_dict). - - Dictionary, with an 'optimizer' key, and (optionally) a 'lr_scheduler' key which value is a single LR scheduler or lr_dict. + - Dictionary, with an 'optimizer' key, and (optionally) a 'lr_scheduler' key which value is a single LR + scheduler or lr_dict. - Tuple of dictionaries as described, with an optional 'frequency' key. - None - Fit will run without any optimizer. 
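Illustrative sketch (not part of the patch): the hunks above change `LightningModule.log` so that, when several val/test dataloaders are active, logged keys come back suffixed with the dataloader index, while user-supplied keys that already contain the suffix are rejected. The helpers below only mirror that intent; `suffix_key` and `check_user_key` are hypothetical names, not Lightning API.
from typing import Optional

def suffix_key(key: str, dataloader_idx: Optional[int], add_dataloader_idx: bool) -> str:
    # mirrors the intent of Result._add_dataloader_idx: suffix only when an idx is set and suffixing is requested
    if dataloader_idx is not None and add_dataloader_idx:
        return f"{key}/dataloader_idx_{dataloader_idx}"
    return key

def check_user_key(key: str) -> None:
    # mirrors the new guard in LightningModule.log: the suffix is reserved for Lightning itself
    if "/dataloader_idx_" in key:
        raise ValueError(f"Logged key: {key} should not contain information about dataloader_idx.")

check_user_key("val_loss")                 # passes
print(suffix_key("val_loss", 1, True))     # -> "val_loss/dataloader_idx_1"
print(suffix_key("val_loss", None, True))  # -> "val_loss" (single dataloader: idx is None)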
@@ -1278,11 +1293,11 @@ def tbptt_split_batch(self, batch, split_size): batch_split = [] for i, x in enumerate(batch): if isinstance(x, torch.Tensor): - split_x = x[:, t : t + split_size] + split_x = x[:, t: t + split_size] elif isinstance(x, collections.Sequence): split_x = [None] * len(x) for batch_idx in range(len(x)): - split_x[batch_idx] = x[batch_idx][t : t + split_size] + split_x[batch_idx] = x[batch_idx][t: t + split_size] batch_split.append(split_x) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 650c1876d0cd0..fe42a2d6013d1 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -124,6 +124,7 @@ def log( sync_dist: bool = False, sync_dist_op: Union[Any, str] = 'mean', sync_dist_group: Optional[Any] = None, + dataloader_idx: Optional[int] = None, ): # no metrics should be logged with graphs if not enable_graph and isinstance(value, torch.Tensor): @@ -144,6 +145,7 @@ def log( # set step version step_name = f'{name}_step' + self.__set_meta( step_name, value, @@ -154,12 +156,15 @@ def log( reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token, - forked=False + forked=False, + dataloader_idx=dataloader_idx, ) + self.__setitem__(step_name, value) # set epoch version epoch_name = f'{name}_epoch' + self.__set_meta( epoch_name, value, @@ -170,7 +175,8 @@ def log( reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token, - forked=False + forked=False, + dataloader_idx=dataloader_idx, ) self.__setitem__(epoch_name, value) @@ -185,7 +191,8 @@ def log( reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token, - forked=was_forked + forked=was_forked, + dataloader_idx=dataloader_idx, ) # set the value @@ -202,7 +209,8 @@ def __set_meta( reduce_fx: Callable, tbptt_pad_token: int, tbptt_reduce_fx: Callable, - forked: bool + forked: bool, + dataloader_idx: Union[int, None] ): # set the meta for the item meta_value = value @@ -215,7 +223,8 @@ def __set_meta( value=meta_value, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token, - forked=forked + forked=forked, + dataloader_idx=dataloader_idx, ) self['meta'][name] = meta @@ -242,7 +251,13 @@ def get_callback_metrics(self) -> dict: return result - def get_batch_log_metrics(self, include_forked_originals=True) -> dict: + def _add_dataloader_idx(self, k: str, dataloader_idx: Union[int, None], add_dataloader_idx: bool) -> str: + if dataloader_idx is not None and add_dataloader_idx: + return f"{k}/dataloader_idx_{dataloader_idx}" + else: + return k + + def get_batch_log_metrics(self, include_forked_originals=True, add_dataloader_idx=False) -> dict: """ Gets the metrics to log at the end of the batch step @@ -257,15 +272,17 @@ def get_batch_log_metrics(self, include_forked_originals=True) -> dict: if options['forked'] and not include_forked_originals: continue + dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) + if options['logger'] and options['on_step']: if isinstance(self[k], Metric): - result[k] = self[k]._forward_cache + result[dl_key] = self[k]._forward_cache else: - result[k] = self[k] + result[dl_key] = self[k] return result - def get_epoch_log_metrics(self) -> dict: + def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict: """ Gets the metrics to log at the end of epoch """ @@ -279,11 +296,13 @@ def get_epoch_log_metrics(self) -> dict: if options['forked']: continue + dl_key = self._add_dataloader_idx(k, 
options["dataloader_idx"], add_dataloader_idx) + if options['logger'] and options['on_epoch']: if isinstance(self[k], Metric): - result[k] = self[k].compute() + result[dl_key] = self[k].compute() else: - result[k] = self[k] + result[dl_key] = self[k] if k in self and not options['on_epoch'] and isinstance(self[k], Metric): # compute metric on epoch anyway so state does not accumulate @@ -291,7 +310,7 @@ def get_epoch_log_metrics(self) -> dict: return result - def get_epoch_pbar_metrics(self): + def get_epoch_pbar_metrics(self, add_dataloader_idx=False): """ Gets the metrics to log at the end of epoch """ @@ -305,11 +324,13 @@ def get_epoch_pbar_metrics(self): if options['forked']: continue + dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) + if options['prog_bar'] and options['on_epoch']: if isinstance(self[k], Metric): - result[k] = self[k].compute() + result[dl_key] = self[k].compute() else: - result[k] = self[k] + result[dl_key] = self[k] if k in self and not options['on_epoch'] and isinstance(self[k], Metric): # compute metric on epoch anyway so state does not accumulate @@ -317,7 +338,7 @@ def get_epoch_pbar_metrics(self): return result - def get_forked_metrics(self): + def get_forked_metrics(self, add_dataloader_idx=False): """ Gets the metrics to log at the end of epoch """ @@ -328,12 +349,14 @@ def get_forked_metrics(self): if k == '_internal': continue + dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) + if options['forked']: - result[k] = self[k] + result[dl_key] = self[k] return result - def get_batch_pbar_metrics(self, include_forked_originals=True): + def get_batch_pbar_metrics(self, include_forked_originals=True, add_dataloader_idx=False): """ Gets the metrics to log at the end of the batch step """ @@ -347,11 +370,13 @@ def get_batch_pbar_metrics(self, include_forked_originals=True): if options['forked'] and not include_forked_originals: continue + dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) + if options['prog_bar'] and options['on_step']: if isinstance(self[k], Metric): - result[k] = self[k]._forward_cache + result[dl_key] = self[k]._forward_cache else: - result[k] = self[k] + result[dl_key] = self[k] return result diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 187ff237056a2..b805c887ec7ad 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os +from abc import ABC from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -61,6 +62,8 @@ def init_default_checkpoint_callback(self, checkpoint_callback): checkpoint_callback = ModelCheckpoint(dirpath=None, filename=None) elif checkpoint_callback is False: checkpoint_callback = None + if checkpoint_callback: + checkpoint_callback.save_function = self.trainer.save_checkpoint return checkpoint_callback @@ -81,5 +84,4 @@ def configure_progress_bar(self, refresh_rate=1, process_position=0): self.trainer.callbacks.append(progress_bar_callback) else: progress_bar_callback = None - return progress_bar_callback diff --git a/pytorch_lightning/trainer/connectors/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector.py index 893eab5a16a3d..a65975fddeedd 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector.py @@ -19,20 +19,40 @@ from pytorch_lightning.utilities.model_utils import is_overridden from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.trainer.connectors.logger_connector_utils import LoggingCallbackValidator, CacheInternalMetrics from pprint import pprint -from typing import Iterable +from typing import Iterable, Union from copy import deepcopy -from collections import ChainMap +from collections import defaultdict, ChainMap class LoggerConnector: + __stages = ["train", "val", "test"] + __lookup_stages = {"0": "test", "1": "val", "True": "test", "False": "val"} + def __init__(self, trainer): self.trainer = trainer self.callback_metrics = {} self.logged_metrics = {} self.progress_bar_metrics = {} self.eval_loop_results = [] + self.callback_logging_validator = LoggingCallbackValidator() + self._cache_internal_metrics = {stage: CacheInternalMetrics() for stage in self.__stages} + + def cached_metrics(self, stage_or_testing: Union[str, bool]) -> Union[CacheInternalMetrics, None]: + stage_or_testing = str(stage_or_testing) + stages = self.__stages + if stage_or_testing in self.__stages: + return self._cache_internal_metrics[stage_or_testing] + if stage_or_testing in self.__lookup_stages: + # Access using trainer.testing + stage = self.__lookup_stages[stage_or_testing] + return self._cache_internal_metrics[stage] + raise MisconfigurationException( + f"Provided stage_or_testing {stage_or_testing} doesn't belong to either {self.__stages}" + f" or {self.__lookup_stages.keys()}" + ) def on_trainer_init(self, logger, flush_logs_every_n_steps, log_every_n_steps): # logging @@ -76,7 +96,8 @@ def log_metrics(self, metrics, grad_norm_dic, step=None): metrics.update(mem_map) # add norms - metrics.update(grad_norm_dic) + if grad_norm_dic is not None: + metrics.update(grad_norm_dic) # turn all tensors to scalars scalar_metrics = self.trainer.metrics_to_scalars(metrics) @@ -108,17 +129,15 @@ def add_progress_bar_metrics(self, metrics): self.trainer.dev_debugger.track_pbar_metrics_history(metrics) - def on_evaluation_epoch_end(self, deprecated_eval_results, epoch_logs, using_eval_result, test_mode): + def track_metrics_evaluation_epoch_end(self, deprecated_eval_results, epoch_logs, using_eval_result, test_mode): self._track_callback_metrics(deprecated_eval_results, using_eval_result) -
self._log_on_evaluation_epoch_end_metrics(epoch_logs) + metrics_to_log = self.cached_metrics(self.trainer.testing)\ + .get_as_list("before_on_batch_start", "epoch_log_metrics") + self._track_callback_metrics_1_0(epoch_logs, metrics_to_log, reduce_on_epoch=True) # TODO: deprecate parts of this for 1.0 (when removing results) self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results, test_mode) - # get the final loop results - eval_loop_results = self._get_evaluate_epoch_results(test_mode) - return eval_loop_results - def _get_evaluate_epoch_results(self, test_mode): # log results of test if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test: @@ -128,13 +147,25 @@ def _get_evaluate_epoch_results(self, test_mode): pprint(results) print('-' * 80) - results = self.eval_loop_results + if self.trainer.testing: + callback_metrics = deepcopy(self.callback_metrics) + if self.trainer.dev_debugger.enabled: + callback_metrics.pop("debug_epoch") + self.eval_loop_results.append(callback_metrics) + results = [dict(ChainMap(*self.eval_loop_results))] + else: + results = self.eval_loop_results # clear mem self.eval_loop_results = [] return results - def _log_on_evaluation_epoch_end_metrics(self, epoch_logs): + def track_metrics_on_evaluation_epoch_start(self, logs, metrics_to_log=[]): + batch_logger_metrics = logs.get_batch_log_metrics() + if len(batch_logger_metrics) > 0: + metrics_to_log.append(batch_logger_metrics) + + def _track_callback_metrics_1_0(self, logs, metrics_to_log=[], reduce_on_epoch=False): step_metrics = self.trainer.evaluation_loop.step_metrics num_loaders = len(step_metrics) @@ -145,16 +176,13 @@ def _log_on_evaluation_epoch_end_metrics(self, epoch_logs): if self.trainer.running_sanity_check: return - # track all metrics we want to log - metrics_to_log = [] - # --------------------------- # UPDATE EPOCH LOGGED METRICS # --------------------------- # (ie: in methods at the val_epoch_end level) # union the epoch logs with whatever was returned from loaders and reduced - epoch_logger_metrics = epoch_logs.get_epoch_log_metrics() - epoch_pbar_metrics = epoch_logs.get_epoch_pbar_metrics() + epoch_logger_metrics = logs.get_epoch_log_metrics(add_dataloader_idx=True) + epoch_pbar_metrics = logs.get_epoch_pbar_metrics(add_dataloader_idx=True) self.logged_metrics.update(epoch_logger_metrics) self.add_progress_bar_metrics(epoch_pbar_metrics) @@ -171,50 +199,49 @@ def _log_on_evaluation_epoch_end_metrics(self, epoch_logs): # -------------------------------- # each dataloader aggregated metrics # now we log all of them - for dl_idx, dl_metrics in enumerate(step_metrics): - if len(dl_metrics) == 0: - # Ensure custom logged metrics are included if not included with step metrics - if len(epoch_logger_metrics) > 0: - self.eval_loop_results.append(epoch_logger_metrics) - continue - - reduced_epoch_metrics = dl_metrics[0].__class__.reduce_on_epoch_end(dl_metrics) - # make the keys 'k/dl' - reduced_epoch_metrics = self.__rename_keys_by_dataloader_idx(reduced_epoch_metrics, dl_idx, num_loaders) - - # track the metrics - logger_metrics = reduced_epoch_metrics.get_epoch_log_metrics() - pbar_metrics = reduced_epoch_metrics.get_epoch_pbar_metrics() - self.logged_metrics.update(logger_metrics) - self.add_progress_bar_metrics(pbar_metrics) - - # enable the metrics to be monitored - self.callback_metrics.update(logger_metrics) - self.callback_metrics.update(pbar_metrics) - - # forked metrics were dropped, enable them for callbacks - forked_metrics = 
reduced_epoch_metrics.get_forked_metrics() - self.callback_metrics.update(forked_metrics) - - # track the final results for the dataloader - self.eval_loop_results.append(deepcopy(self.callback_metrics)) - - # actually log - if len(logger_metrics) > 0: - metrics_to_log.append(logger_metrics) - - # log all the metrics as a s single dict + if reduce_on_epoch: + for dl_idx, dl_metrics in enumerate(step_metrics): + if len(dl_metrics) == 0: + continue + + reduced_epoch_metrics = dl_metrics[0].__class__.reduce_on_epoch_end(dl_metrics) + logger_metrics = reduced_epoch_metrics.get_epoch_log_metrics(add_dataloader_idx=True) + pbar_metrics = reduced_epoch_metrics.get_epoch_pbar_metrics(add_dataloader_idx=True) + forked_metrics = reduced_epoch_metrics.get_forked_metrics(add_dataloader_idx=True) + + # track the metrics + self.logged_metrics.update(logger_metrics) + self.add_progress_bar_metrics(pbar_metrics) + + # enable the metrics to be monitored + self.callback_metrics.update(logger_metrics) + self.callback_metrics.update(pbar_metrics) + + # forked metrics were dropped, enable them for callbacks + self.callback_metrics.update(forked_metrics) + + # track the final results for the dataloader + self.add_to_eval_loop_results(dl_idx) + + # actually log + if len(logger_metrics) > 0: + metrics_to_log.append(logger_metrics) + + def add_to_eval_loop_results(self, dl_idx): + callback_metrics = deepcopy(self.callback_metrics) + for key in list(callback_metrics.keys()): + if "/dataloader_idx_" in key: + dl_idx_in_key = int(key.split("_")[-1]) + # remove dl_idx from self.callback_metrics not belonging to this dataset. + if dl_idx_in_key != dl_idx: + del callback_metrics[key] + self.eval_loop_results.append(callback_metrics) + + def log_epoch_metrics_on_evaluation_end(self, metrics_to_log): metrics_to_log = dict(ChainMap(*metrics_to_log)) if len(metrics_to_log) > 0: self.log_metrics(metrics_to_log, {}) - def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx, num_loaders): - if num_loaders == 1: - return metrics - - result = {f'{k}/dataloader_idx_{dataloader_idx}': v for k, v in metrics.items()} - return result - def _track_callback_metrics(self, eval_results, using_eval_result): if ( len(eval_results) > 0 and @@ -347,6 +374,13 @@ def log_train_epoch_end_metrics(self, epoch_log_metrics.update(epoch_end_log_result.get_epoch_log_metrics()) epoch_progress_bar_metrics.update(epoch_end_log_result.get_epoch_pbar_metrics()) + cache_internal_epoch_log_metrics = self.cached_metrics("train")\ + .get_as_dict("after_on_batch_end", "epoch_log_metrics") + epoch_log_metrics.update(cache_internal_epoch_log_metrics) + + cache_internal_epoch_pbar_metrics = self.cached_metrics("train")\ + .get_as_dict("after_on_batch_end", "epoch_pbar_metrics") + epoch_progress_bar_metrics.update(cache_internal_epoch_pbar_metrics) # TODO: deprecate 1.0 else: out = self.__run_legacy_training_epoch_end( @@ -532,6 +566,10 @@ def log_train_step_metrics(self, batch_output): # logs user requested information to logger metrics = batch_output.batch_log_metrics grad_norm_dic = batch_output.grad_norm_dic + if metrics is None: + metrics = {} + if grad_norm_dic is None: + grad_norm_dic = {} if len(metrics) > 0 or len(grad_norm_dic) > 0: self.log_metrics(metrics, grad_norm_dic) self.callback_metrics.update(metrics) diff --git a/pytorch_lightning/trainer/connectors/logger_connector_utils/__init__.py b/pytorch_lightning/trainer/connectors/logger_connector_utils/__init__.py new file mode 100644 index 0000000000000..2870560e5a7b3 --- /dev/null +++ 
b/pytorch_lightning/trainer/connectors/logger_connector_utils/__init__.py @@ -0,0 +1,2 @@ +from pytorch_lightning.trainer.connectors.logger_connector_utils.cache_metrics import CacheInternalMetrics +from pytorch_lightning.trainer.connectors.logger_connector_utils.callback_logging_validator import LoggingCallbackValidator diff --git a/pytorch_lightning/trainer/connectors/logger_connector_utils/cache_metrics.py b/pytorch_lightning/trainer/connectors/logger_connector_utils/cache_metrics.py new file mode 100644 index 0000000000000..99b95634b3a47 --- /dev/null +++ b/pytorch_lightning/trainer/connectors/logger_connector_utils/cache_metrics.py @@ -0,0 +1,53 @@ +from collections import defaultdict, ChainMap + + +class CacheInternalMetrics: + """ + This class is a helper to cache model._results logged values before / after entering the batch loop. + As on every `run_training_batch`, we apply model._results = Result() + and therefore delete any previously logged values + + before_on_batch_start is responsible for catching values logged from `on_start` to `on_batch_start` + after_on_batch_end is responsible for catching values logged from `on_batch_end` to `on_epoch_end` + """ + + stages = ["before_on_batch_start", "after_on_batch_end"] + + def __init__(self): + self.reset() + + def append(self, stage: str, key: str, value) -> None: + assert stage in self.stages, f"Provided stage {stage} should be within {self.stages}" + self._internal_dict[stage][key].append(value) + + def get_as_dict(self, stage, key): + _internal_metrics = self.get_as_list(stage, key) + return dict(ChainMap(*_internal_metrics)) + + def get_as_list(self, stage, key): + assert stage in self.stages, f"Provided stage {stage} should be within {self.stages}" + return self._internal_dict[stage][key] + + def __repr__(self): + return self._internal_dict.__repr__() + + def update(self, trainer, stage: str) -> None: + """ + This function is used to cache any logged information + between "on_train_start" and "on_train_epoch_start" callback hooks + """ + assert stage in self.stages, f"Provided stage {stage} should be within {self.stages}" + if not trainer.running_sanity_check: + model_ref = trainer.get_model() + + # save epoch metrics + self.append(stage, "epoch_log_metrics", model_ref._results.get_epoch_log_metrics(add_dataloader_idx=True)) + self.append(stage, "epoch_pbar_metrics", model_ref._results.get_epoch_pbar_metrics(add_dataloader_idx=True)) + + # save step/batch metrics + self.append(stage, "batch_log_metrics", model_ref._results.get_batch_log_metrics(add_dataloader_idx=True)) + self.append(stage, "batch_pbar_metrics", model_ref._results.get_batch_pbar_metrics(include_forked_originals=False, + add_dataloader_idx=True)) + + def reset(self): + self._internal_dict = {stage: defaultdict(list) for stage in self.stages} \ No newline at end of file diff --git a/pytorch_lightning/trainer/connectors/logger_connector_utils/callback_logging_validator.py b/pytorch_lightning/trainer/connectors/logger_connector_utils/callback_logging_validator.py new file mode 100644 index 0000000000000..263c4b890968a --- /dev/null +++ b/pytorch_lightning/trainer/connectors/logger_connector_utils/callback_logging_validator.py @@ -0,0 +1,198 @@ +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class LoggingCallbackValidator: + + @staticmethod + def validate_callback_logging_arguments(current_hook_fx_name: str = None, on_step: bool = None, + on_epoch: bool = None) -> None: + current_callback_hook_auth_args = getattr(LoggingCallbackValidator,
f"_{current_hook_fx_name}_log")() + + if current_callback_hook_auth_args is not None: + m = "{} function supports only {} in {}. Provided {}" + if on_step not in current_callback_hook_auth_args["on_step"]: + msg = m.format(current_hook_fx_name, "on_step", current_callback_hook_auth_args["on_step"], on_step) + raise MisconfigurationException(msg) + + if on_epoch not in current_callback_hook_auth_args["on_epoch"]: + msg = m.format(current_hook_fx_name, "on_epoch", current_callback_hook_auth_args["on_epoch"], on_epoch) + raise MisconfigurationException(msg) + else: + raise MisconfigurationException( + f"{current_hook_fx_name} function doesn't support logging using self.log() yet." + ) + + @staticmethod + def _setup_log(): + """Called when fit or test begins""" + return None + + @staticmethod + def _teardown_log(): + """Called at the end of fit and test""" + return None + + @staticmethod + def _on_init_start_log(): + """Called when the trainer initialization begins, model has not yet been set.""" + return None + + @staticmethod + def _on_init_end_log(): + """Called when the trainer initialization ends, model has not yet been set.""" + return None + + @staticmethod + def _on_fit_start_log(): + """Called when the trainer initialization begins, model has not yet been set.""" + return None + + @staticmethod + def _on_fit_end_log(): + """Called when the trainer initialization begins, model has not yet been set.""" + return None + + @staticmethod + def _on_sanity_check_start_log(): + """Called when the validation sanity check starts.""" + return None + + @staticmethod + def _on_sanity_check_end_log(): + """Called when the validation sanity check ends.""" + return None + + @staticmethod + def _on_train_epoch_start_log(): + """Called when the epoch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_train_epoch_end_log(): + """Called when the epoch ends.""" + return {"on_step": [False], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_epoch_start_log(): + """Called when the epoch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_epoch_end_log(): + """Called when the epoch ends.""" + return {"on_step": [False], "on_epoch": [False, True]} + + @staticmethod + def _on_test_epoch_start_log(): + """Called when the epoch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_test_epoch_end_log(): + """Called when the epoch ends.""" + return {"on_step": [False], "on_epoch": [False, True]} + + @staticmethod + def _on_epoch_start_log(): + """Called when the epoch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_epoch_end_log(): + """Called when the epoch ends.""" + return {"on_step": [False], "on_epoch": [False, True]} + + @staticmethod + def _on_train_start_log(): + """Called when the train begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_train_end_log(): + """Called when the train ends.""" + return None + + @staticmethod + def _on_pretrain_routine_start_log(): + """Called when the train begins.""" + return None + + @staticmethod + def _on_pretrain_routine_end_log(): + """Called when the train ends.""" + return None + + @staticmethod + def _on_batch_start_log(): + """Called when the training batch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_batch_end_log(): + """Called when 
the training batch ends.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_train_batch_start_log(): + """Called when the training batch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_train_batch_end_log(): + """Called when the training batch ends.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_batch_start_log(): + """Called when the validation batch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_batch_end_log(): + """Called when the validation batch ends.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_test_batch_start_log(): + """Called when the test batch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_test_batch_end_log(): + """Called when the test batch ends.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_start_log(): + """Called when the validation loop begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_end_log(): + """Called when the validation loop ends.""" + return {"on_step": [False], "on_epoch": [False, True]} + + @staticmethod + def _on_test_start_log(): + """Called when the test begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_test_end_log(): + """Called when the test ends.""" + return None + + @staticmethod + def _on_keyboard_interrupt_log(): + """Called when the training is interrupted by KeyboardInterrupt.""" + return None + + @staticmethod + def _on_save_checkpoint_log(): + """Called when saving a model checkpoint.""" + return None + + @staticmethod + def _on_load_checkpoint_log(): + """Called when loading a model checkpoint.""" + return None diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 9dab036583dd8..47be3a3892a0a 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -29,6 +29,7 @@ def __init__(self, trainer): self.predictions = None self.max_batches = None self.warning_cache = WarningCache() + self.num_dataloaders = None def on_trainer_init(self): self.trainer.num_val_batches = [] @@ -83,6 +84,7 @@ def should_skip_evaluation(self, dataloaders, max_batches): return False def on_evaluation_start(self, *args, **kwargs): + self.trainer.logger_connector.cached_metrics(self.testing).reset() if self.testing: self.trainer.call_hook('on_test_start', *args, **kwargs) else: @@ -102,12 +104,22 @@ def on_evaluation_model_train(self, *args, **kwargs): else: model_ref.on_validation_model_train() + def _update_metrics_to_log_after_evaluation_epoch_end(self): + if not self.trainer.running_sanity_check: + self.trainer.logger_connector.cached_metrics(self.testing).update(self.trainer, "after_on_batch_end") + metrics_to_log = self.trainer.logger_connector.cached_metrics(self.testing)\ + .get_as_list("after_on_batch_end", "epoch_log_metrics") + self.trainer.logger_connector._track_callback_metrics_1_0(self.trainer.get_model()._results, + metrics_to_log=metrics_to_log) + def on_evaluation_end(self, *args, **kwargs): if self.testing: self.trainer.call_hook('on_test_end', *args, **kwargs) else: self.trainer.call_hook('on_validation_end', *args, **kwargs) + 
self._update_metrics_to_log_after_evaluation_epoch_end() + + def reload_evaluation_dataloaders(self): model = self.trainer.get_model() if self.testing: @@ -133,6 +145,34 @@ def setup(self, model, max_batches, dataloaders): max_batches = [max_batches] * len(dataloaders) self.max_batches = max_batches + self.num_dataloaders = len(dataloaders) + + def _update_logger_connector_metrics(self): + model = self.trainer.get_model() + + # set batch_pbar_metrics cached before the evaluation batch loop starts + cache_internal_batch_pbar_metrics = self.trainer.logger_connector.cached_metrics(self.testing).get_as_dict( + "before_on_batch_start", "batch_pbar_metrics") + if len(cache_internal_batch_pbar_metrics) > 0: + self.trainer.logger_connector.add_progress_bar_metrics(cache_internal_batch_pbar_metrics) + + # set batch_log_metrics cached before the evaluation batch loop starts + cache_internal_batch_log_metrics = self.trainer.logger_connector.cached_metrics(self.testing).get_as_dict( + "before_on_batch_start", "batch_log_metrics") + if len(cache_internal_batch_log_metrics) > 0: + self.trainer.logger_connector.callback_metrics.update(cache_internal_batch_log_metrics) + + # set epoch_pbar_metrics cached before the evaluation batch loop starts + cache_internal_epoch_pbar_metrics = self.trainer.logger_connector.cached_metrics(self.testing).get_as_dict( + "before_on_batch_start", "epoch_pbar_metrics") + if len(cache_internal_epoch_pbar_metrics) > 0: + self.trainer.logger_connector.add_progress_bar_metrics(cache_internal_epoch_pbar_metrics) + + # set epoch_log_metrics cached before the evaluation batch loop starts + cache_internal_epoch_log_metrics = self.trainer.logger_connector.cached_metrics(self.testing).get_as_dict( + "before_on_batch_start", "epoch_log_metrics") + if len(cache_internal_epoch_log_metrics) > 0: + self.trainer.logger_connector.callback_metrics.update(cache_internal_epoch_log_metrics) def on_evaluation_epoch_start(self, *args, **kwargs): if self.testing: @@ -140,6 +180,10 @@ def on_evaluation_epoch_start(self, *args, **kwargs): else: self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs) + # cache model._results logged values as they will be reset in the next hook + self.trainer.logger_connector.cached_metrics(self.testing).update(self.trainer, "before_on_batch_start") + self._update_logger_connector_metrics() + def build_args(self, test_mode, batch, batch_idx, dataloader_idx): # make dataloader_idx arg in validation_step optional args = [batch, batch_idx] @@ -190,6 +234,7 @@ def evaluation_step_end(self, *args, **kwargs): return output def evaluation_epoch_end(self, num_dataloaders): + self._unset_dataloader_idx() using_eval_result = self.is_using_eval_results() # call the model epoch end @@ -203,16 +248,29 @@ def evaluation_epoch_end(self, num_dataloaders): if not isinstance(r, (dict, Result, torch.Tensor)): deprecated_results[i] = [] + # track and reduce metrics + self.track_metrics_evaluation_epoch_end( + deprecated_results, epoch_logs, self.testing) + return deprecated_results, epoch_logs - def log_epoch_metrics(self, deprecated_eval_results, epoch_logs, test_mode): + def track_metrics_evaluation_epoch_end(self, deprecated_eval_results, epoch_logs, test_mode): using_eval_result = self.is_using_eval_results() - eval_loop_results = self.trainer.logger_connector.on_evaluation_epoch_end( + self.trainer.logger_connector.track_metrics_evaluation_epoch_end( deprecated_eval_results, epoch_logs, using_eval_result, test_mode ) + + def
log_epoch_metrics_on_evaluation_end(self): + metrics_to_log = self.trainer.logger_connector.cached_metrics(self.testing)\ + .get_as_list("before_on_batch_start", "epoch_log_metrics") + metrics_to_log += self.trainer.logger_connector.cached_metrics(self.testing)\ + .get_as_list("after_on_batch_end", "epoch_log_metrics") + self.trainer.logger_connector.log_epoch_metrics_on_evaluation_end(metrics_to_log) + # get the final loop results + eval_loop_results = self.trainer.logger_connector._get_evaluate_epoch_results(self.testing) return eval_loop_results def __run_eval_epoch_end(self, num_dataloaders, using_eval_result): @@ -292,6 +350,16 @@ def __auto_reduce_result_objs(self, outputs): return eval_results + def _unset_dataloader_idx(self): + # reset the result of the PL module + model = self.trainer.get_model() + model._current_dataloader_idx = None + + def set_dataloader_idx(self, dl_idx): + # reset the result of the PL module + model = self.trainer.get_model() + model._current_dataloader_idx = dl_idx if self.num_dataloaders > 1 else None + def on_evaluation_batch_start(self, *args, **kwargs): # reset the result of the PL module model = self.trainer.get_model() @@ -320,12 +388,17 @@ def evaluation_batch_end_cleanup(self, output, batch_idx, dataloader_idx): self.trainer.dev_debugger.track_eval_loss_history(self.testing, batch_idx, dataloader_idx, output) def on_evaluation_epoch_end(self, *args, **kwargs): + # reset model result + model_ref = self.trainer.get_model() + model_ref._results = Result() # call the callback hook if self.testing: self.trainer.call_hook('on_test_epoch_end', *args, **kwargs) else: self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs) + self._update_metrics_to_log_after_evaluation_epoch_end() + def log_evaluation_step_metrics(self, batch, batch_idx): results = self.trainer.get_model()._results if len(results) == 1: @@ -346,7 +419,10 @@ def log_evaluation_step_metrics_legacy(self, output, batch_idx): def __log_result_step_metrics(self, output, batch_idx): step_log_metrics = output.get_batch_log_metrics(include_forked_originals=False) - step_pbar_metrics = output.get_batch_pbar_metrics(include_forked_originals=False) + step_pbar_metrics = output.get_batch_pbar_metrics(include_forked_originals=False, add_dataloader_idx=True) + + batch_pbar_metrics = self.trainer.get_model()._results.batch_pbar_metrics + step_pbar_metrics.update(batch_pbar_metrics) if len(step_log_metrics) > 0: # make the metrics appear as a different line in the same graph diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 44250ae905aba..b563004714cdc 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -22,7 +22,8 @@ from pytorch_lightning.callbacks import Callback, ModelCheckpoint from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.core.step_result import EvalResult +from pytorch_lightning.core.memory import ModelSummary +from pytorch_lightning.core.step_result import Result, EvalResult from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.profiler import BaseProfiler from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin @@ -230,6 +231,8 @@ def __init__( num_nodes: number of GPU nodes for distributed training. 
+ num_processes: number of processes for distributed training with distributed_backend="ddp_cpu" + num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine. Set it to `-1` to run all batches in all validation dataloaders. Default: 2 @@ -529,8 +532,13 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None): if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches): return [], [] - # enable eval mode + no grads + # Load model and reset Result model = self.get_model() + + # reset result + model._results = Result() + + # enable eval mode + no grads self.evaluation_loop.on_evaluation_model_eval() model.zero_grad() @@ -554,6 +562,9 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None): dataloader = self.accelerator_backend.process_dataloader(dataloader) dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx] + # set dataloader idx in pl_model, so we can handle multi dataloaders logging. + self.evaluation_loop.set_dataloader_idx(dataloader_idx) + for batch_idx, batch in enumerate(dataloader): if batch is None: continue @@ -596,20 +607,20 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None): num_dataloaders=len(dataloaders) ) - # bookkeeping - eval_loop_results = self.evaluation_loop.log_epoch_metrics(deprecated_eval_results, epoch_logs, test_mode) - self.evaluation_loop.predictions.to_disk() - # hook self.evaluation_loop.on_evaluation_epoch_end() + # hook + self.evaluation_loop.on_evaluation_end() + + # bookkeeping and logging + eval_loop_results = self.evaluation_loop.log_epoch_metrics_on_evaluation_end() + self.evaluation_loop.predictions.to_disk() + # enable train mode again self.evaluation_loop.on_evaluation_model_train() torch.set_grad_enabled(True) - # hook - self.evaluation_loop.on_evaluation_end() - return eval_loop_results, deprecated_eval_results def run_test(self): @@ -820,8 +831,15 @@ def call_hook(self, hook_name, *args, **kwargs): # first call trainer hook if hasattr(self, hook_name): + model_ref = self.get_model() + if model_ref is not None: + # used to track current hook name called + model_ref._current_hook_fx_name = hook_name trainer_hook = getattr(self, hook_name) trainer_hook(*args, **kwargs) + if model_ref is not None: + # set back current_hook_fx_name to its default value + model_ref._current_hook_fx_name = '' # next call hook in lightningModule output = None diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index d32f47dbbd485..dc52fd2726923 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import subprocess from copy import copy, deepcopy import numpy as np @@ -25,6 +24,7 @@ from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator +from pytorch_lightning.trainer.connectors.logger_connector import CacheInternalMetrics from pytorch_lightning.utilities import parsing, AMPType from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -86,7 +86,10 @@ def on_train_start(self): torch.cuda.empty_cache() # hook - self.trainer.call_hook("on_train_start") + model_ref = self.trainer.get_model() + model_ref._results = Result() + self.trainer.logger_connector.cached_metrics("train").reset() + self.trainer.call_hook('on_train_start') def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule): # bind logger and other properties @@ -238,16 +241,23 @@ def on_train_epoch_start(self, epoch): self.checkpoint_accumulator = Accumulator() # hook - self.trainer.call_hook("on_epoch_start") - self.trainer.call_hook("on_train_epoch_start") + self.trainer.call_hook('on_epoch_start') + self.trainer.call_hook('on_train_epoch_start') + self.trainer.logger_connector.cached_metrics("train").update(self.trainer, "before_on_batch_start") + + def on_train_batch_end(self, batch_output, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx): + # hook + model_ref = self.trainer.get_model() + model_ref._results = Result() + self.trainer.call_hook('on_batch_end') + self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch, batch_idx, dataloader_idx) - def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx): # figure out what to track for epoch end self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs) - # hook - self.trainer.call_hook("on_batch_end") - self.trainer.call_hook("on_train_batch_end", epoch_end_outputs, batch, batch_idx, dataloader_idx) + batch_pbar_metrics = model_ref._results.get_batch_pbar_metrics() + if len(batch_pbar_metrics) > 0: + self.trainer.logger_connector.add_progress_bar_metrics(batch_pbar_metrics) def reset_train_val_dataloaders(self, model): if not self.trainer.reload_dataloaders_every_epoch: @@ -256,12 +266,30 @@ def reset_train_val_dataloaders(self, model): if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch: self.trainer.reset_val_dataloader(model) + def _extend_epoch_end_outputs(self, opt_outputs): + """ + This function extend `opt_outputs` from `epoch_end_outputs` with any extra `epoch_log_metrics` + from 'on_batch_end' or 'on_train_batch_end' hooks. 
+ """ + model_ref = self.trainer.get_model() + valid_keys = model_ref._results.epoch_log_metrics + for opt_output in opt_outputs: + if isinstance(opt_output, dict): + opt_output.update(valid_keys) + + _internal = {k: v for k, v in model_ref._results["meta"]["_internal"].items() if k in valid_keys} + meta = {k: v for k, v in model_ref._results["meta"].items() if k in valid_keys} + + opt_output["meta"]["_internal"].update(_internal) + opt_output["meta"].update(meta) + def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs): # track the outputs to reduce at the end of the epoch for opt_idx, opt_outputs in enumerate(epoch_end_outputs): # with 1 step (no tbptt) don't use a sequence at epoch end if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result): opt_outputs = opt_outputs[0] + self._extend_epoch_end_outputs(opt_outputs) epoch_output[opt_idx].append(opt_outputs) def get_optimizers_iterable(self): @@ -297,8 +325,7 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss): def training_step(self, split_batch, batch_idx, opt_idx, hiddens): # give the PL module a result for logging model = self.trainer.get_model() - model._results = Result() - model._current_fx_name = "training_step" + model._current_fx_name = 'training_step' with self.trainer.profiler.profile("model_forward"): args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens) @@ -385,7 +412,8 @@ def _process_training_step_output(self, training_step_output, split_batch): return training_step_output_for_epoch_end, training_step_output def _process_training_step_output_1_0(self, training_step_output, split_batch): - result = self.trainer.get_model()._results + model_ref = self.trainer.get_model() + result = model_ref._results loss = None hiddens = None @@ -498,6 +526,13 @@ def log_training_step_metrics(self, opt_closure_result, batch_callback_metrics, # track batch log metrics batch_log_metrics.append(metrics_to_log) + # add initially computed step metrics. 
+ cache_internal_batch_pbar_metrics = self.trainer.logger_connector.cached_metrics("train").get_as_dict( + "before_on_batch_start", "batch_pbar_metrics") + if len(cache_internal_batch_pbar_metrics) > 0: + self.trainer.logger_connector.add_progress_bar_metrics(cache_internal_batch_pbar_metrics) + self.trainer.logger_connector.callback_metrics.update(cache_internal_batch_pbar_metrics) + # track progress bar metrics if len(step_pbar_metrics) > 0: self.trainer.logger_connector.add_progress_bar_metrics(step_pbar_metrics) @@ -557,7 +592,7 @@ def run_training_epoch(self): # hook # TODO: add outputs to batches - self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx) + self.on_train_batch_end(batch_output, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx) # ----------------------------------------- # SAVE METRICS TO LOGGERS @@ -605,34 +640,59 @@ def run_training_epoch(self): self.trainer.checkpoint_connector.has_trained = True - # log epoch metrics - self.trainer.logger_connector.log_train_epoch_end_metrics( - epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers - ) - # hook self.trainer.logger_connector.on_train_epoch_end(epoch_output) - # when no val loop is present or fast-dev-run still need to call checkpoints - self.check_checkpoint_callback(not (should_check_val or is_overridden("validation_step", model))) - # epoch end hook self.run_on_epoch_end_hook(epoch_output) + # log epoch metrics + self.trainer.logger_connector.log_train_epoch_end_metrics( + epoch_output, + self.checkpoint_accumulator, + self.early_stopping_accumulator, + self.num_optimizers + ) + + # when no val loop is present or fast-dev-run still need to call checkpoints + self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model))) + # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() + def _update_logger_connector_progress_bar_metrics(self): + model = self.trainer.get_model() + + # set batch_pbar_metrics cached from "on_train_start" to "on_train_epoch_start" + cache_internal_batch_pbar_metrics = self.trainer.logger_connector.cached_metrics("train").get_as_dict( + "before_on_batch_start", "batch_pbar_metrics") + if len(cache_internal_batch_pbar_metrics) > 0: + self.trainer.logger_connector.add_progress_bar_metrics(cache_internal_batch_pbar_metrics) + + # set epoch_pbar_metrics cached from "on_train_start" to "on_train_epoch_start" + cache_internal_epoch_pbar_metrics = self.trainer.logger_connector.cached_metrics("train").get_as_dict( + "before_on_batch_start", "epoch_pbar_metrics") + if len(cache_internal_epoch_pbar_metrics) > 0: + self.trainer.logger_connector.add_progress_bar_metrics(cache_internal_epoch_pbar_metrics) + + # set batch_pbar_metrics cached from "on_batch_start" to "on_train_batch_start" + batch_pbar_metrics = model._results.batch_pbar_metrics + if len(batch_pbar_metrics) > 0: + self.trainer.logger_connector.add_progress_bar_metrics(batch_pbar_metrics) + def run_training_batch(self, batch, batch_idx, dataloader_idx): + + # reset results + model = self.trainer.get_model() + model._results = Result() + # track grad norms grad_norm_dic = {} # track all metrics for callbacks batch_callback_metrics = [] - # track metrics to log - batch_log_metrics = [] - # bookkeeping using_results_obj = False self.trainer.hiddens = None @@ -653,6 +713,13 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): if 
response == -1: return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) + # update progress bar metrics with pbar called within callback + self._update_logger_connector_progress_bar_metrics() + + # track metrics to log + batch_log_metrics = {} + batch_log_metrics.update(model._results.batch_log_metrics) + batch_log_metrics = [batch_log_metrics] # checks if backward or backward + optimizer step (via closure) accumulation_done = self._accumulated_batches_reached() is_final_batch = self._num_training_batches_reached() @@ -713,11 +780,13 @@ def train_step_and_backward_closure(): self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) else: - self._curr_step_result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens) + self._curr_step_result = self.training_step(split_batch, batch_idx, opt_idx, + self.trainer.hiddens) if self._curr_step_result is None: - # user decided to skip optimization - continue + results = self.trainer.get_model()._results + batch_log_metrics.append(results.get_batch_log_metrics(include_forked_originals=False)) + batch_log_metrics.append(self.trainer.metrics_to_scalars(results.epoch_log_metrics)) batch_outputs = self._process_closure_result( batch_callback_metrics=batch_callback_metrics, @@ -836,8 +905,12 @@ def update_train_loop_lr_schedulers(self, monitor_metrics=None): self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics) def run_on_epoch_end_hook(self, epoch_output): - self.trainer.call_hook("on_epoch_end") - self.trainer.call_hook("on_train_epoch_end", epoch_output) + # reset result + internal metris to catch epoch end logging + model_ref = self.trainer.get_model() + model_ref._results = Result() + self.trainer.call_hook('on_epoch_end') + self.trainer.call_hook('on_train_epoch_end', epoch_output) + self.trainer.logger_connector.cached_metrics("train").update(self.trainer, "after_on_batch_end") def increment_accumulated_grad_global_step(self): num_accumulated_batches_reached = self._accumulated_batches_reached() diff --git a/tests/base/__init__.py b/tests/base/__init__.py index a337d443b4384..faefa623dfee7 100644 --- a/tests/base/__init__.py +++ b/tests/base/__init__.py @@ -3,4 +3,4 @@ from tests.base.datasets import TrialMNIST from tests.base.model_template import EvalModelTemplate, GenericEvalModelTemplate from tests.base.simple_model import SimpleModule -from tests.base.boring_model import BoringModel +from tests.base.boring_model import BoringModel, RandomDataset diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 976a91f551e0a..0e344051e1fb6 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -424,7 +424,7 @@ def test_default_checkpoint_behavior(tmpdir): trainer.fit(model) results = trainer.test() - assert len(results) == 1 + assert len(results) == 1, results assert results[0]['test_acc'] >= 0.80 assert len(trainer.dev_debugger.checkpoint_callback_history) == 3 diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 886e0db4e7854..93d796cd45721 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -307,7 +307,7 @@ def on_test_model_train(self): trainer.fit(model) - assert model.called == [ + excepted = [ 'on_fit_start', 'on_pretrain_routine_start', 'on_pretrain_routine_end', @@ -333,18 +333,20 @@ def on_test_model_train(self): 'on_validation_batch_start', 'on_validation_batch_end', 
'on_validation_epoch_end', - 'on_validation_model_train', 'on_save_checkpoint', + 'on_validation_model_train', 'on_epoch_end', 'on_train_epoch_end', 'on_train_end', 'on_fit_end', ] + assert model.called == excepted + model2 = HookedModel() trainer.test(model2) - assert model2.called == [ + expected = [ 'on_fit_start', 'on_pretrain_routine_start', 'on_pretrain_routine_end', @@ -356,3 +358,5 @@ def on_test_model_train(self): 'on_test_model_train', 'on_fit_end', ] + + assert model2.called == expected diff --git a/tests/trainer/logging/test_eval_loop_logging_1_0.py b/tests/trainer/logging/test_eval_loop_logging_1_0.py index bce4a23dda157..39d87b5e19d48 100644 --- a/tests/trainer/logging/test_eval_loop_logging_1_0.py +++ b/tests/trainer/logging/test_eval_loop_logging_1_0.py @@ -15,13 +15,19 @@ Tests to ensure that the training loop works with a dict (1.0) """ from pytorch_lightning.core.lightning import LightningModule +import os +import collections +import itertools +import pytest +import numpy as np +import torch + +import pytorch_lightning as pl from pytorch_lightning import Trainer from pytorch_lightning import callbacks, seed_everything + from tests.base.deterministic_model import DeterministicModel -from tests.base import SimpleModule, BoringModel -import os -import torch -import pytest +from tests.base import SimpleModule, BoringModel, RandomDataset def test__validation_step__log(tmpdir): @@ -320,7 +326,6 @@ def test_eval_epoch_only_logging(tmpdir, batches, log_interval, max_epochs): """ Tests that only test_epoch_end can be used to log, and we return them in the results. """ - os.environ['PL_DEV_DEBUG'] = '1' class TestModel(BoringModel): def test_epoch_end(self, outputs): @@ -358,3 +363,352 @@ def test_monitor_val_epoch_end(tmpdir): checkpoint_callback=checkpoint_callback, ) trainer.fit(model) + + +def test_log_works_in_val_callback(tmpdir): + """ + Tests that log can be called within callback + """ + os.environ['PL_DEV_DEBUG'] = '1' + + class TestCallback(callbacks.Callback): + + # helpers + count = 1 + choices = [False, True] + # used to compute expected values + callback_funcs_called = collections.defaultdict(list) + funcs_called_count = collections.defaultdict(int) + funcs_attr = {} + + def make_logging(self, pl_module: pl.LightningModule, func_name, func_idx, on_steps=[], on_epochs=[], prob_bars=[]): + """ + This function is used to log metrics and make sure everything is properly logged. 
+ """ + + self.funcs_called_count[func_name] += 1 + for idx, t in enumerate(list(itertools.product(*[on_steps, on_epochs, prob_bars]))): + # run logging + on_step, on_epoch, prog_bar = t + custom_func_name = f"{func_idx}_{idx}_{func_name}" + pl_module.log(custom_func_name, self.count * func_idx, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar) + # catch information for verification + self.callback_funcs_called[func_name].append([self.count * func_idx]) + self.funcs_attr[custom_func_name] = { + "on_step": on_step, + "on_epoch": on_epoch, + "prog_bar": prog_bar, + "func_name": func_name} + + if on_step and on_epoch: + self.funcs_attr[f"{custom_func_name}_step"] = { + "on_step": True, + "on_epoch": False, + "prog_bar": prog_bar, + "func_name": func_name} + + self.funcs_attr[f"{custom_func_name}_epoch"] = { + "on_step": False, + "on_epoch": True, + "prog_bar": prog_bar, + "func_name": func_name} + + def on_validation_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_validation_start', 1, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_epoch_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_epoch_start', 2, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_validation_epoch_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_validation_epoch_start', 3, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_batch_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_batch_start', 4, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + self.make_logging(pl_module, 'on_validation_batch_start', 5, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_batch_end(self, trainer, pl_module): + self.make_logging(pl_module, 'on_batch_end', 6, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + self.make_logging(pl_module, 'on_validation_batch_end', 7, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + # used to make sure aggregation works fine. 
+ # we should obtain func[value * c for c in range(1, max_epochs * limit_validation_batches)]) + # with func = np.mean if on_epoch else func = np.max + self.count += 1 + + def on_epoch_end(self, trainer, pl_module): + self.make_logging(pl_module, 'on_epoch_end', 8, on_steps=[False], + on_epochs=self.choices, prob_bars=self.choices) + + def on_validation_epoch_end(self, trainer, pl_module): + self.make_logging(pl_module, 'on_validation_epoch_end', 9, on_steps=[False], + on_epochs=self.choices, prob_bars=self.choices) + + class TestModel(BoringModel): + + def validation_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('val_loss', loss) + + max_epochs = 1 + model = TestModel() + model.validation_epoch_end = None + test_callback = TestCallback() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=0, + limit_val_batches=4, + limit_test_batches=0, + val_check_interval=0., + num_sanity_val_steps=0, + max_epochs=max_epochs, + callbacks=[test_callback], + ) + trainer.fit(model) + trainer.test() + + assert test_callback.funcs_called_count["on_epoch_start"] == 1 + assert test_callback.funcs_called_count["on_batch_start"] == 1 + assert test_callback.funcs_called_count["on_batch_end"] == 1 + assert test_callback.funcs_called_count["on_validation_start"] == 1 + assert test_callback.funcs_called_count["on_validation_epoch_start"] == 1 + assert test_callback.funcs_called_count["on_validation_batch_start"] == 4 + assert test_callback.funcs_called_count["on_validation_batch_end"] == 4 + assert test_callback.funcs_called_count["on_validation_epoch_end"] == 1 + assert test_callback.funcs_called_count["on_epoch_end"] == 1 + + # Make sure the func_name exists within callback_metrics. If not, we missed some + callback_metrics_keys = [*trainer.callback_metrics.keys()] + for func_name in test_callback.callback_funcs_called.keys(): + is_in = False + for callback_metrics_key in callback_metrics_keys: + if func_name in callback_metrics_key: + is_in = True + assert is_in, (func_name, callback_metrics_keys) + + # function used to describe expected return logic + def get_expected_output(func_attr, original_values): + + if func_attr["on_epoch"] and not func_attr["on_step"]: + # Apply mean on values + expected_output = np.mean(original_values) + else: + # Keep the latest value + expected_output = np.max(original_values) + return expected_output + + # Make sure the func_name output equals the average from all logged values when on_epoch true + # pop extra keys + trainer.callback_metrics.pop("debug_epoch") + trainer.callback_metrics.pop("val_loss") + for func_name, output_value in trainer.callback_metrics.items(): + + if torch.is_tensor(output_value): + output_value = output_value.item() + # get creation attr + func_attr = test_callback.funcs_attr[func_name] + + # retrived orginal logged values + original_values = test_callback.callback_funcs_called[func_attr["func_name"]] + + # compute expected output and compare to actual one + expected_output = get_expected_output(func_attr, original_values) + assert float(output_value) == float(expected_output) + + for func_name, func_attr in test_callback.funcs_attr.items(): + if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]): + assert func_name in trainer.logger_connector.progress_bar_metrics + else: + assert func_name not in trainer.logger_connector.progress_bar_metrics + + +def test_log_works_in_test_callback(tmpdir): + """ + Tests that log can be called within callback + """ + 
os.environ['PL_DEV_DEBUG'] = '1' + + class TestCallback(callbacks.Callback): + + # helpers + count = 1 + choices = [False, True] + + # used to compute expected values + callback_funcs_called = collections.defaultdict(list) + funcs_called_count = collections.defaultdict(int) + funcs_attr = {} + + def make_logging(self, pl_module: pl.LightningModule, func_name, func_idx, on_steps=[], on_epochs=[], prob_bars=[]): + original_func_name = func_name[:] + self.funcs_called_count[original_func_name] += 1 + for idx, t in enumerate(list(itertools.product(*[on_steps, on_epochs, prob_bars]))): + # run logging + func_name = original_func_name[:] + on_step, on_epoch, prog_bar = t + custom_func_name = f"{func_idx}_{idx}_{func_name}" + + pl_module.log(custom_func_name, self.count * func_idx, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar) + + num_dl_ext = '' + if pl_module._current_dataloader_idx is not None: + dl_idx = pl_module._current_dataloader_idx + num_dl_ext = f"/dataloader_idx_{dl_idx}" + func_name += num_dl_ext + + # catch information for verification + self.callback_funcs_called[func_name].append([self.count * func_idx]) + self.funcs_attr[custom_func_name + num_dl_ext] = { + "on_step": on_step, + "on_epoch": on_epoch, + "prog_bar": prog_bar, + "func_name": func_name, + "forked": on_step and on_epoch} + if on_step and on_epoch: + self.funcs_attr[f"{custom_func_name}_step" + num_dl_ext] = { + "on_step": True, + "on_epoch": False, + "prog_bar": prog_bar, + "func_name": func_name, + "forked": False} + + self.funcs_attr[f"{custom_func_name}_epoch" + num_dl_ext] = { + "on_step": False, + "on_epoch": True, + "prog_bar": prog_bar, + "func_name": func_name, + "forked": False} + + def on_test_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_test_start', 1, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_epoch_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_epoch_start', 2, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_test_epoch_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_test_epoch_start', 3, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + self.make_logging(pl_module, 'on_test_batch_start', 4, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + self.make_logging(pl_module, 'on_test_batch_end', 5, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + # used to make sure aggregation works fine. 
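+            # `self.count` grows by one per test batch, making the logged values strictly increasing: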
+ # we should obtain func[value * c for c in range(1, max_epochs * limit_test_batches)]) + # with func = np.mean if on_epoch else func = np.max + self.count += 1 + + def on_epoch_end(self, trainer, pl_module): + self.make_logging(pl_module, 'on_epoch_end', 6, on_steps=[False], + on_epochs=self.choices, prob_bars=self.choices) + + def on_test_epoch_end(self, trainer, pl_module): + self.make_logging(pl_module, 'on_test_epoch_end', 7, on_steps=[False], + on_epochs=self.choices, prob_bars=self.choices) + + max_epochs = 1 + num_dataloaders = 2 + + class TestModel(BoringModel): + + manual_mean = collections.defaultdict(list) + + def test_step(self, batch, batch_idx, dataloader_idx=None): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('test_loss', loss) + self.manual_mean[str(dataloader_idx)].append(loss) + + def test_dataloader(self): + return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(num_dataloaders)] + + model = TestModel() + model.test_epoch_end = None + test_callback = TestCallback() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=0, + limit_val_batches=0, + limit_test_batches=2, + val_check_interval=0., + num_sanity_val_steps=0, + max_epochs=max_epochs, + callbacks=[test_callback], + ) + trainer.fit(model) + trainer.test() + + assert test_callback.funcs_called_count["on_epoch_start"] == 1 + assert test_callback.funcs_called_count["on_test_start"] == 1 + assert test_callback.funcs_called_count["on_test_epoch_start"] == 1 + assert test_callback.funcs_called_count["on_test_batch_start"] == 4 + assert test_callback.funcs_called_count["on_test_batch_end"] == 4 + assert test_callback.funcs_called_count["on_test_epoch_end"] == 1 + assert test_callback.funcs_called_count["on_epoch_end"] == 1 + + # Make sure the func_name exists within callback_metrics. 
If not, we missed some + callback_metrics_keys = [*trainer.callback_metrics.keys()] + + for func_name in test_callback.callback_funcs_called.keys(): + is_in = False + for callback_metrics_key in callback_metrics_keys: + if func_name in callback_metrics_key: + is_in = True + assert is_in, (func_name, callback_metrics_keys) + + # function used to describe expected return logic + def get_expected_output(func_attr, original_values): + # Apply mean on values + if func_attr["on_epoch"] and not func_attr["on_step"]: + expected_output = np.mean(original_values) + else: + expected_output = np.max(original_values) + return expected_output + + # Make sure the func_name output equals the average from all logged values when on_epoch true + # pop extra keys + assert "debug_epoch" in trainer.callback_metrics + trainer.callback_metrics.pop("debug_epoch") + for dl_idx in range(num_dataloaders): + key = f"test_loss/dataloader_idx_{dl_idx}" + assert key in trainer.callback_metrics + assert torch.stack(model.manual_mean[str(dl_idx)]).mean() == trainer.callback_metrics[key] + trainer.callback_metrics.pop(key) + + for func_name, output_value in trainer.callback_metrics.items(): + if torch.is_tensor(output_value): + output_value = output_value.item() + + # get func attr + func_attr = test_callback.funcs_attr[func_name] + + # retrived orginal logged values + original_values = test_callback.callback_funcs_called[func_attr["func_name"]] + + # compute expected output and compare to actual one + expected_output = get_expected_output(func_attr, original_values) + assert float(output_value) == float(expected_output) + + for func_name, func_attr in test_callback.funcs_attr.items(): + if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]) and not func_attr["forked"]: + assert func_name in trainer.logger_connector.progress_bar_metrics + else: + assert func_name not in trainer.logger_connector.progress_bar_metrics \ No newline at end of file diff --git a/tests/trainer/logging/test_train_loop_logging_1_0.py b/tests/trainer/logging/test_train_loop_logging_1_0.py index 414264894e639..a6ac27e50b14a 100644 --- a/tests/trainer/logging/test_train_loop_logging_1_0.py +++ b/tests/trainer/logging/test_train_loop_logging_1_0.py @@ -17,9 +17,13 @@ from pytorch_lightning.core.lightning import LightningModule from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset import os +import numpy as np +import collections import torch import pytest - +import itertools +import pytorch_lightning as pl +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning import Trainer, callbacks from tests.base.deterministic_model import DeterministicModel from torch.utils.data import Dataset @@ -489,3 +493,169 @@ def validation_step(self, batch, batch_idx): weights_summary=None, ) trainer.fit(model, train_data, val_data) + + +def test_log_works_in_train_callback(tmpdir): + """ + Tests that log can be called within callback + """ + + os.environ['PL_DEV_DEBUG'] = '1' + + class TestCallback(callbacks.Callback): + + # helpers + count = 1 + choices = [False, True] + # used to compute expected values + callback_funcs_called = collections.defaultdict(list) + funcs_called_count = collections.defaultdict(int) + funcs_attr = {} + + def make_logging(self, pl_module: pl.LightningModule, func_name, func_idx, + on_steps=[], on_epochs=[], prob_bars=[]): + self.funcs_called_count[func_name] += 1 + for idx, t in enumerate(list(itertools.product(*[on_steps, on_epochs, 
prob_bars]))): + # run logging + on_step, on_epoch, prog_bar = t + custom_func_name = f"{func_idx}_{idx}_{func_name}" + pl_module.log(custom_func_name, self.count * func_idx, on_step=on_step, + on_epoch=on_epoch, prog_bar=prog_bar) + # catch information for verification + self.callback_funcs_called[func_name].append([self.count * func_idx]) + self.funcs_attr[custom_func_name] = { + "on_step": on_step, + "on_epoch": on_epoch, + "prog_bar": prog_bar, + "func_name": func_name} + + if on_step and on_epoch: + self.funcs_attr[f"{custom_func_name}_step"] = { + "on_step": True, + "on_epoch": False, + "prog_bar": prog_bar, + "func_name": func_name} + + self.funcs_attr[f"{custom_func_name}_epoch"] = { + "on_step": False, + "on_epoch": True, + "prog_bar": prog_bar, + "func_name": func_name} + + def on_train_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_train_start', 1, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_epoch_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_epoch_start', 2, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_train_epoch_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_train_epoch_start', 3, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_batch_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_batch_start', 4, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + self.make_logging(pl_module, 'on_train_batch_start', 5, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_batch_end(self, trainer, pl_module): + self.make_logging(pl_module, 'on_batch_end', 6, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + self.make_logging(pl_module, 'on_train_batch_end', 7, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + # used to make sure aggregation works fine. 
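+            # `self.count` is bumped after every training batch, so the values logged above keep increasing: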
+ # we should obtain func[value * c for c in range(1, max_epochs * limit_train_batches)]) + # with func = np.mean if on_epoch else func = np.max + self.count += 1 + + def on_epoch_end(self, trainer, pl_module): + self.make_logging(pl_module, 'on_epoch_end', 8, on_steps=[False], + on_epochs=self.choices, prob_bars=self.choices) + + def on_train_epoch_end(self, trainer, pl_module, outputs): + self.make_logging(pl_module, 'on_train_epoch_end', 9, on_steps=[False], + on_epochs=self.choices, prob_bars=self.choices) + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('train_loss', loss) + return {"loss": loss} + + max_epochs = 1 + limit_train_batches = 2 + model = TestModel() + test_callback = TestCallback() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=0, + limit_test_batches=0, + val_check_interval=0., + num_sanity_val_steps=0, + max_epochs=max_epochs, + callbacks=[test_callback] + ) + trainer.fit(model) + + assert test_callback.funcs_called_count["on_train_start"] == 1 + assert test_callback.funcs_called_count["on_epoch_start"] == 1 + assert test_callback.funcs_called_count["on_train_epoch_start"] == 1 + assert test_callback.funcs_called_count["on_batch_start"] == 2 + assert test_callback.funcs_called_count["on_train_batch_start"] == 2 + assert test_callback.funcs_called_count["on_batch_end"] == 2 + assert test_callback.funcs_called_count["on_train_batch_end"] == 2 + assert test_callback.funcs_called_count["on_epoch_end"] == 1 + assert test_callback.funcs_called_count["on_train_epoch_end"] == 1 + + # Make sure the func_name exists within callback_metrics. If not, we missed some + callback_metrics_keys = [*trainer.callback_metrics.keys()] + for func_name in test_callback.callback_funcs_called.keys(): + is_in = False + for callback_metrics_key in callback_metrics_keys: + if func_name in callback_metrics_key: + is_in = True + assert is_in, (func_name, callback_metrics_keys) + + # function used to describe expected return logic + def get_expected_output(func_attr, original_values): + if func_attr["on_epoch"] and not func_attr["on_step"]: + # Apply mean on values + expected_output = np.mean(original_values) + else: + # Keep the latest value + expected_output = np.max(original_values) + return expected_output + + # Make sure the func_name output equals the average from all logged values when on_epoch true + # pop extra keys + trainer.callback_metrics.pop("debug_epoch") + trainer.callback_metrics.pop("train_loss") + + for func_name, output_value in trainer.callback_metrics.items(): + if torch.is_tensor(output_value): + output_value = output_value.item() + # get creation attr + func_attr = test_callback.funcs_attr[func_name] + + # retrived orginal logged values + original_values = test_callback.callback_funcs_called[func_attr["func_name"]] + + # compute expected output and compare to actual one + expected_output = get_expected_output(func_attr, original_values) + assert float(output_value) == float(expected_output) + + for func_name, func_attr in test_callback.funcs_attr.items(): + if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]): + assert func_name in trainer.logger_connector.progress_bar_metrics + else: + assert func_name not in trainer.logger_connector.progress_bar_metrics diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 
5f279c0b0a4db..8f8e1d421292d 100644
--- a/tests/trainer/optimization/test_manual_optimization.py
+++ b/tests/trainer/optimization/test_manual_optimization.py
@@ -117,8 +117,6 @@ def training_step(self, batch, batch_idx, optimizer_idx):
         opt_b.zero_grad()
         assert torch.all(self.layer.weight.grad == 0)
 
-        return {'something': 'else'}
-
     def training_epoch_end(self, outputs) -> None:
         # outputs should be an array with an entry per optimizer
         assert len(outputs) == 2
@@ -183,8 +181,6 @@ def training_step(self, batch, batch_idx, optimizer_idx):
         opt_b.zero_grad()
         assert torch.all(self.layer.weight.grad == 0)
 
-        return {'something': 'else'}
-
     def training_epoch_end(self, outputs) -> None:
         # outputs should be an array with an entry per optimizer
         assert len(outputs) == 2
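
For reference, a minimal standalone sketch (not part of the patch; the helper name `expected_value` is illustrative) of the expected-value rule the new callback-logging tests assert through their `get_expected_output` helpers: metrics logged with `on_epoch=True` and `on_step=False` are mean-reduced across batches, otherwise the latest value is kept (`np.max` is equivalent here because the per-batch counter makes the logged values increase):

import numpy as np

def expected_value(on_step: bool, on_epoch: bool, logged_values):
    # mirrors get_expected_output() in the tests above
    if on_epoch and not on_step:
        # epoch-only metrics are averaged over the batches of the epoch
        return float(np.mean(logged_values))
    # otherwise the latest value wins; max() matches it since the values increase
    return float(np.max(logged_values))

# e.g. a metric logged as count * func_idx over 4 validation batches with func_idx == 3
logged = [1 * 3, 2 * 3, 3 * 3, 4 * 3]
assert expected_value(on_step=False, on_epoch=True, logged_values=logged) == 7.5
assert expected_value(on_step=True, on_epoch=True, logged_values=logged) == 12.0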