diff --git a/CHANGELOG.md b/CHANGELOG.md index 45e62382da2a3..77eca7f2daacc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -111,6 +111,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * `trainer.{logged,progress_bar,callback}_metrics` are now updated on-demand ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882)) * Completely overhaul the `Result` object in favor of `ResultMetric` ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882)) * Improve epoch-level reduction time and overall memory usage ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882)) + * Allow passing `self.log(batch_size=...)` ([#7891](https://github.com/PyTorchLightning/pytorch-lightning/pull/7891)) + * Each of the training loops now keeps its own results collection ([#7891](https://github.com/PyTorchLightning/pytorch-lightning/pull/7891)) + - Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/)) @@ -164,6 +167,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `num_nodes` and `sync_batchnorm` arguments in `DDPPlugin` and `DDPSpawnPlugin` ([#7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026)) +- Deprecated `self.log(sync_dist_op)` in favor of `self.log(reduce_fx)`. ([#7891](https://github.com/PyTorchLightning/pytorch-lightning/pull/7891)) + + ### Removed - Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654)) diff --git a/docs/source/extensions/logging.rst b/docs/source/extensions/logging.rst index 107eca2dd9d74..12760f0ee6898 100644 --- a/docs/source/extensions/logging.rst +++ b/docs/source/extensions/logging.rst @@ -68,6 +68,10 @@ except functions with `batch_start` in their names. def training_step(self, batch, batch_idx): self.log('my_metric', x) + # or a dict + def training_step(self, batch, batch_idx): + self.log('performance', {'acc': acc, 'recall': recall}) + Depending on where log is called from, Lightning auto-determines the correct logging mode for you. \ But of course you can override the default behavior by manually setting the :func:`~~pytorch_lightning.core.lightning.LightningModule.log` parameters. 
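The snippet below is not part of the patch; it is a minimal usage sketch of the logging options this changeset adds or deprecates (`batch_size=...`, dict values, and `reduce_fx` as a string in place of `sync_dist_op`). The module, layer sizes, and metric names are hypothetical.

    import torch
    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):  # hypothetical module, for illustration only
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch  # assumed (inputs, targets) structure
            logits = self.layer(x)
            loss = torch.nn.functional.cross_entropy(logits, y)
            acc = (logits.argmax(dim=-1) == y).float().mean()
            # new in this changeset: pass the batch size explicitly when it cannot be inferred
            self.log("train_loss", loss, batch_size=x.size(0))
            # dict values (float / Tensor / Metric) can be logged under a single key
            self.log("performance", {"acc": acc, "loss": loss.detach()})
            # `reduce_fx` accepts a string and replaces the deprecated `sync_dist_op`
            self.log("loss_sum", loss, sync_dist=True, reduce_fx="sum")
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)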
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index bbba327e7856c..02633d3df16fa 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -24,9 +24,8 @@ import uuid from abc import ABC from argparse import Namespace -from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import torch from torch import ScriptModule, Tensor @@ -43,16 +42,13 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin -from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed +from pytorch_lightning.utilities.distributed import sync_ddp_if_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache -if TYPE_CHECKING: - from pytorch_lightning.trainer.connectors.logger_connector.result import Result - warning_cache = WarningCache() log = logging.getLogger(__name__) @@ -109,7 +105,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # optionally can be set by user self._example_input_array = None self._datamodule = None - self._results: Optional['Result'] = None self._current_fx_name: Optional[str] = None self._running_manual_backward: bool = False self._current_dataloader_idx: Optional[int] = None @@ -267,14 +262,15 @@ def log( logger: bool = True, on_step: Optional[bool] = None, on_epoch: Optional[bool] = None, - reduce_fx: Callable = torch.mean, + reduce_fx: Union[str, Callable] = 'default', # TODO: change to 'mean' when `sync_dist_op` is removed in 1.6 tbptt_reduce_fx: Optional = None, # noqa: Remove in 1.6 tbptt_pad_token: Optional = None, # noqa: Remove in 1.6 enable_graph: bool = False, sync_dist: bool = False, - sync_dist_op: Union[Any, str] = 'mean', + sync_dist_op: Optional = None, # noqa: Remove in 1.6 sync_dist_group: Optional[Any] = None, add_dataloader_idx: bool = True, + batch_size: Optional[int] = None, ) -> None: """ Log a key, value @@ -298,7 +294,7 @@ def log( Args: name: key to log - value: value to log + value: value to log. Can be a ``float``, ``Tensor``, ``Metric``, or a dictionary of the former. prog_bar: if True logs to the progress bar logger: if True logs to the logger on_step: if True logs at this step. None auto-logs at the training_step but not validation/test_step @@ -306,11 +302,12 @@ def log( reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default. enable_graph: if True, will not auto detach the graph sync_dist: if True, reduces the metric across GPUs/TPUs - sync_dist_op: the op to sync across GPUs/TPUs sync_dist_group: the ddp group to sync across add_dataloader_idx: if True, appends the index of the current dataloader to the name (when using multiple). If False, user needs to give unique names for each dataloader to not mix values + batch_size: Current batch_size. 
This will be directly inferred from the loaded batch, + but some data structures might need to explicitly provide it. """ if tbptt_reduce_fx is not None: rank_zero_deprecation( ' Please, open a discussion explaining your use-case in' ' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`' ) + if sync_dist_op is not None: + rank_zero_deprecation( + f"`self.log(sync_dist_op='{sync_dist_op}')` is deprecated and will be removed in v1.6." + f" Use `self.log(reduce_fx='{sync_dist_op}')` instead." + ) + if reduce_fx == 'default': + reduce_fx = sync_dist_op + elif reduce_fx == 'default': + reduce_fx = 'mean' # check for invalid values apply_to_collection(value, dict, self.__check_not_nested, name) @@ -335,8 +341,10 @@ def log( on_step = self.__auto_choose_log_on_step(on_step) on_epoch = self.__auto_choose_log_on_epoch(on_epoch) + results = self.trainer._results + assert results is not None assert self._current_fx_name is not None - self.trainer.logger_connector.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) + results.fx_validator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) # make sure user doesn't introduce logic for multi-dataloaders if "/dataloader_idx_" in name: @@ -345,18 +353,15 @@ def log( " but it should not contain information about `dataloader_idx`" ) - sync_fn = partial( - self.__sync, - sync_fn=self.trainer.training_type_plugin.reduce, - sync_dist=sync_dist, - sync_dist_op=sync_dist_op, - sync_dist_group=sync_dist_group, - device=self.device, - ) - value = apply_to_collection(value, (torch.Tensor, numbers.Number), sync_fn) + value = apply_to_collection(value, numbers.Number, self.__to_tensor) - assert self._results is not None - self._results.log( + if self.trainer.logger_connector.should_reset_tensors(self._current_fx_name): + # if we started a new epoch (running its first batch) the hook name has changed + # reset any tensors for the new hook name + results.reset(metrics=False, fx=self._current_fx_name) + + results.log( + self._current_fx_name, name, value, prog_bar=prog_bar, @@ -366,8 +371,14 @@ def log( reduce_fx=reduce_fx, enable_graph=enable_graph, dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None), + batch_size=batch_size, + sync_dist=sync_dist, + sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp_if_available, + sync_dist_group=sync_dist_group, ) + self.trainer.logger_connector._current_fx = self._current_fx_name + def log_dict( self, dictionary: Mapping[str, _METRIC_COLLECTION], @@ -375,12 +386,12 @@ def log_dict( logger: bool = True, on_step: Optional[bool] = None, on_epoch: Optional[bool] = None, - reduce_fx: Callable = torch.mean, + reduce_fx: Union[str, Callable] = 'default', # TODO: change to 'mean' when `sync_dist_op` is removed in 1.6 tbptt_reduce_fx: Optional = None, # noqa: Remove in 1.6 tbptt_pad_token: Optional = None, # noqa: Remove in 1.6 enable_graph: bool = False, sync_dist: bool = False, - sync_dist_op: Union[Any, str] = 'mean', + sync_dist_op: Optional = None, # noqa: Remove in 1.6 sync_dist_group: Optional[Any] = None, add_dataloader_idx: bool = True, ) -> None: @@ -393,7 +404,8 @@ def log_dict( self.log_dict(values) Args: - dictionary: key value pairs (str, tensors) + dictionary: key value pairs. + The values can be a ``float``, ``Tensor``, ``Metric``, or a dictionary of the former. prog_bar: if True logs to the progress bar logger: if True logs to the logger on_step: if True logs at this step.
None auto-logs for training_step but not validation/test_step @@ -401,7 +413,6 @@ def log_dict( reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default. enable_graph: if True, will not auto detach the graph sync_dist: if True, reduces the metric across GPUs/TPUs - sync_dist_op: the op to sync across GPUs/TPUs sync_dist_group: the ddp group sync across add_dataloader_idx: if True, appends the index of the current dataloader to the name (when using multiple). If False, user needs to give unique names for @@ -426,25 +437,7 @@ def log_dict( ) @staticmethod - def __sync( - value: Union[torch.Tensor, numbers.Number], - sync_fn: Optional[Callable] = None, - sync_dist: bool = False, - sync_dist_op: Union[Any, str] = 'mean', - sync_dist_group: Optional[Any] = None, - device: torch.device = None, - ) -> torch.Tensor: - """Sync across workers when using distributed training""" - if isinstance(value, numbers.Number): - value = torch.tensor(value, device=device, dtype=torch.float) - sync_fn = sync_fn or sync_ddp_if_available - dist_available = torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed() - if not sync_dist or not dist_available: - return value - return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op) - - @staticmethod - def __check_not_nested(value: dict, name: str) -> None: + def __check_not_nested(value: dict, name: str) -> dict: # self-imposed restriction. for simplicity if any(isinstance(v, dict) for v in value.values()): raise ValueError(f'`self.log({name}, {value})` was called, but nested dictionaries cannot be logged') @@ -454,6 +447,9 @@ def __check_not_nested(value: dict, name: str) -> None: def __check_allowed(v: Any, name: str, value: Any) -> None: raise ValueError(f'`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged') + def __to_tensor(self, value: numbers.Number) -> torch.Tensor: + return torch.tensor(value, device=self.device) + def log_grad_norm(self, grad_norm_dict: Dict[str, torch.Tensor]) -> None: """Override this method to change the default behaviour of ``log_grad_norm``. 
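As context for the `lightning.py` hunk above: the deprecation path only honours `sync_dist_op` when the caller left `reduce_fx` at its `'default'` sentinel, and otherwise resolves the sentinel to `'mean'`. A standalone sketch of that resolution logic follows; the helper name `_resolve_reduce_fx` is invented here for illustration and does not exist in the codebase.

    def _resolve_reduce_fx(reduce_fx="default", sync_dist_op=None):
        # mirrors the fallback order in `LightningModule.log` above (illustrative only)
        if sync_dist_op is not None:
            # deprecated argument: only used when `reduce_fx` was not set explicitly
            if reduce_fx == "default":
                reduce_fx = sync_dist_op
        elif reduce_fx == "default":
            reduce_fx = "mean"
        return reduce_fx


    assert _resolve_reduce_fx() == "mean"
    assert _resolve_reduce_fx(sync_dist_op="sum") == "sum"
    assert _resolve_reduce_fx(reduce_fx="max", sync_dist_op="sum") == "max"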
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py index 7268074ffe169..058c7575cb3fd 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py @@ -141,7 +141,7 @@ def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None # track batch_size - self.trainer.results.extract_batch_size(batch) + self.trainer._results.extract_batch_size(batch) self._batch_idx = batch_idx def update_eval_step_metrics(self) -> None: @@ -210,7 +210,7 @@ def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT: """ def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None: - self.trainer.results.extract_batch_size(split_batch) + self.trainer._results.extract_batch_size(split_batch) self._batch_idx = batch_idx self._split_idx = split_idx @@ -232,7 +232,7 @@ def update_train_epoch_metrics(self) -> None: self.log_metrics(metrics) # reset result collection for next epoch - self.trainer.results.reset(metrics=True) + self.trainer._results.reset(metrics=True) """ Utilities and properties @@ -273,7 +273,7 @@ def should_reset_tensors(self, fx: str) -> bool: return is_different_fx and is_first_batch def reset(self, metrics: Optional[bool] = None) -> None: - self.trainer.results.reset(metrics=metrics) + self.trainer._results.reset(metrics=metrics) self._batch_idx = None self._split_idx = None self._current_fx = None @@ -282,25 +282,25 @@ def reset(self, metrics: Optional[bool] = None) -> None: def metrics(self) -> Dict[MetricSource, Dict[str, _METRIC]]: """This function returns either batch or epoch metrics depending on ``_epoch_end_reached``.""" on_step = not self._epoch_end_reached - return self.trainer.results.metrics(on_step) + return self.trainer._results.metrics(on_step) @property def callback_metrics(self) -> Dict[str, _METRIC]: - if self.trainer.results: + if self.trainer._results: metrics = self.metrics[MetricSource.CALLBACK] self._callback_metrics.update(metrics) return self._callback_metrics @property def logged_metrics(self) -> Dict[str, _METRIC]: - if self.trainer.results: + if self.trainer._results: metrics = self.metrics[MetricSource.LOG] self._logged_metrics.update(metrics) return self._logged_metrics @property def progress_bar_metrics(self) -> Dict[str, float]: - if self.trainer.results: + if self.trainer._results: metrics = self.metrics[MetricSource.PBAR] self._progress_bar_metrics.update(metrics) return self._progress_bar_metrics diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result_new.py b/pytorch_lightning/trainer/connectors/logger_connector/result_new.py index 9141cb24bff36..bf155ad4210e5 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result_new.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result_new.py @@ -65,7 +65,6 @@ class _Metadata: reduce_fx: Union[str, Callable] = torch.mean enable_graph: bool = False dataloader_idx: Optional[int] = None - metric_attribute: Optional[str] = None sync: _Sync = field(default_factory=_Sync) def __post_init__(self) -> None: @@ -225,7 +224,7 @@ class ResultCollection(dict): Example: # `device` needs to be provided before logging - result = ResultCollection(True, torch.device("cpu")) + result = ResultCollection(training=True, torch.device("cpu")) # you 
can log to a specific collection. # arguments: fx, key, value, metadata @@ -303,7 +302,6 @@ def log( sync_dist_group: Optional[Any] = None, dataloader_idx: Optional[int] = None, batch_size: Optional[int] = None, - metric_attribute: Optional[str] = None, ) -> None: """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" # no metrics should be logged with graphs @@ -331,7 +329,6 @@ def log( reduce_fx=reduce_fx, enable_graph=enable_graph, dataloader_idx=dataloader_idx, - metric_attribute=metric_attribute, sync=_Sync( should=sync_dist, fn=sync_dist_fn, diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index dff3165e62853..85f937700a87f 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -17,7 +17,7 @@ from torch.utils.data import DataLoader import pytorch_lightning as pl -from pytorch_lightning.trainer.connectors.logger_connector.result import Result +from pytorch_lightning.trainer.connectors.logger_connector.result_new import ResultCollection from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.model_helpers import is_overridden @@ -34,6 +34,16 @@ def __init__(self, trainer: 'pl.Trainer'): self.max_batches: Optional[List[Union[int, float]]] = None self.warning_cache = WarningCache() self.num_dataloaders: Optional[int] = None + self._val_results = ResultCollection(training=False) + self._test_results = ResultCollection(training=False) + + @property + def results(self) -> Optional[ResultCollection]: + if self.trainer.validating or self.trainer.sanity_checking: + return self._val_results + elif self.trainer.testing: + return self._test_results + return None def on_trainer_init(self) -> None: self.trainer.num_sanity_val_batches = [] @@ -77,6 +87,10 @@ def should_skip_evaluation(self, max_batches: List[Union[int, float]]) -> bool: def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end() + + assert self.results is not None + self.results.to(device=self.trainer.lightning_module.device) + if self.trainer.testing: self.trainer.call_hook('on_test_start', *args, **kwargs) else: @@ -106,6 +120,9 @@ def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None: # summarize profile results self.trainer.profiler.describe() + # reset any `torchmetrics.Metric` and the logger connector state + self.trainer.logger_connector.reset(metrics=True) + def reload_evaluation_dataloaders(self) -> None: model = self.trainer.lightning_module if self.trainer.testing: @@ -126,6 +143,7 @@ def setup(self, max_batches: List[Union[int, float]], dataloaders: List[DataLoad self.num_dataloaders = self._get_num_dataloaders(dataloaders) def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None: + self.trainer.logger_connector.on_epoch_start() self.trainer.call_hook('on_epoch_start', *args, **kwargs) if self.trainer.testing: @@ -162,24 +180,15 @@ def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Op # configure step_kwargs step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx) - model_ref = self.trainer.lightning_module - model_ref._results = Result() - if self.trainer.testing: - model_ref._current_fx_name = "test_step" + self.trainer.lightning_module._current_fx_name = "test_step" with self.trainer.profiler.profile("test_step"): output = 
self.trainer.accelerator.test_step(step_kwargs) else: - model_ref._current_fx_name = "validation_step" + self.trainer.lightning_module._current_fx_name = "validation_step" with self.trainer.profiler.profile("validation_step"): output = self.trainer.accelerator.validation_step(step_kwargs) - # capture any logged information - self.trainer.logger_connector.cache_logged_metrics() - # track batch size for weighted average - if isinstance(output, Result): - output.track_batch_size(batch) - return output def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: @@ -197,12 +206,15 @@ def _should_track_batch_outputs_for_epoch_end(self) -> bool: return is_overridden('validation_epoch_end', model=model) def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: - # unset dataloder_idx in model - self.trainer.logger_connector.evaluation_epoch_end() + # inform logger the batch loop has finished + self.trainer.logger_connector.epoch_end_reached() # call the model epoch end model = self.trainer.lightning_module + # unset dataloader_idx in model + model._current_dataloader_idx = None + if self.trainer.testing: if is_overridden('test_epoch_end', model=model): model._current_fx_name = 'test_epoch_end' @@ -213,13 +225,12 @@ def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: model._current_fx_name = 'validation_epoch_end' model.validation_epoch_end(outputs) - # capture logging - self.trainer.logger_connector.cache_logged_metrics() - def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + self.trainer.logger_connector.on_batch_start() + # set dataloader_idx to model and track batch_size assert self.num_dataloaders is not None - self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self.num_dataloaders) + self.trainer.logger_connector.on_evaluation_batch_start(batch, batch_idx, dataloader_idx, self.num_dataloaders) if self.trainer.testing: self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx) @@ -238,13 +249,15 @@ def on_evaluation_batch_end( else: self.trainer.call_hook('on_validation_batch_end', output, batch, batch_idx, dataloader_idx) + self.trainer.logger_connector.on_batch_end() + # store predicitons if do_write_predictions and track eval loss history self.store_predictions(output, batch_idx, dataloader_idx) def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, dataloader_idx: int) -> None: # Add step predictions to prediction collection to write later if output is not None and self.predictions is not None: - if isinstance(output, Result) and self.trainer.testing: + if isinstance(output, ResultCollection) and self.trainer.testing: self.predictions.add(output.pop('predictions', None)) # track debug metrics @@ -254,3 +267,4 @@ def on_evaluation_epoch_end(self) -> None: hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end" self.trainer.call_hook(hook_name) self.trainer.call_hook('on_epoch_end') + self.trainer.logger_connector.on_epoch_end() diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index aa659ac766e85..b007af60393ae 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -32,7 +32,9 @@ from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.connectors.checkpoint_connector 
import CheckpointConnector -from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector +from pytorch_lightning.trainer.connectors.logger_connector.logger_connector_new import LoggerConnectorNew +from pytorch_lightning.trainer.connectors.logger_connector.result_new import ResultCollection +from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop from pytorch_lightning.trainer.states import RunningStage, TrainerState, TrainerStatus from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn @@ -58,9 +60,10 @@ class TrainerProperties(ABC): checkpoint_connector: CheckpointConnector limit_val_batches: int logger: LightningLoggerBase - logger_connector: LoggerConnector + logger_connector: LoggerConnectorNew state: TrainerState train_loop: TrainLoop + evaluation_loop: EvaluationLoop """ Accelerator properties """ @@ -504,6 +507,13 @@ def max_steps(self) -> Optional[int]: def min_steps(self) -> Optional[int]: return self.train_loop.min_steps + @property + def _active_loop(self) -> Optional[Union[TrainLoop, EvaluationLoop]]: + if self.training: + return self.train_loop + elif self.sanity_checking or self.evaluating: + return self.evaluation_loop + """ Logging properties """ @@ -512,25 +522,19 @@ def min_steps(self) -> Optional[int]: def callback_metrics(self) -> dict: return self.logger_connector.callback_metrics - @callback_metrics.setter - def callback_metrics(self, x: dict) -> None: - self.logger_connector.callback_metrics = x - @property def logged_metrics(self) -> dict: return self.logger_connector.logged_metrics - @logged_metrics.setter - def logged_metrics(self, x: dict) -> None: - self.logger_connector.logged_metrics = x - @property def progress_bar_metrics(self) -> dict: return self.logger_connector.progress_bar_metrics - @progress_bar_metrics.setter - def progress_bar_metrics(self, x: dict) -> None: - self.logger_connector.progress_bar_metrics = x + @property + def _results(self) -> Optional[ResultCollection]: + active_loop = self._active_loop + if active_loop is not None: + return active_loop.results """ Other diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 66047cd1110ff..5a2ce5a6f0821 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -46,8 +46,8 @@ from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars -from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector -from pytorch_lightning.trainer.connectors.logger_connector.result import Result +from pytorch_lightning.trainer.connectors.logger_connector.logger_connector_new import LoggerConnectorNew +from pytorch_lightning.trainer.connectors.logger_connector.result_new import ResultCollection from pytorch_lightning.trainer.connectors.model_connector import ModelConnector from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector @@ -326,7 +326,7 @@ def __init__( num_processes, tpu_cores, distributed_backend, auto_select_gpus, gpus, num_nodes, sync_batchnorm, benchmark, replace_sampler_ddp, deterministic, precision, amp_backend, amp_level, plugins ) - self.logger_connector = LoggerConnector(self, 
log_gpu_memory) + self.logger_connector = LoggerConnectorNew(self, log_gpu_memory) self.model_connector = ModelConnector(self) self.callback_connector = CallbackConnector(self) self.debugging_connector = DebuggingConnector(self) @@ -806,6 +806,7 @@ def _pre_dispatch(self): def _post_dispatch(self): self.accelerator.post_dispatch(self) self.accelerator.teardown() + self.logger_connector.teardown() def _dispatch(self): if self.evaluating: @@ -985,7 +986,7 @@ def _run_evaluation(self) -> _EVALUATE_OUTPUT: self.evaluation_loop.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx) # log batch metrics - self.logger_connector.log_evaluation_step_metrics() + self.logger_connector.update_eval_step_metrics() # track epoch level outputs dl_outputs = self._track_output_for_epoch_end(dl_outputs, output) @@ -1010,7 +1011,7 @@ def _run_evaluation(self) -> _EVALUATE_OUTPUT: self.evaluation_loop.on_evaluation_epoch_end() # log epoch metrics - eval_loop_results = self.logger_connector.get_evaluate_epoch_results() + eval_loop_results = self.logger_connector.update_eval_epoch_metrics() # hook self.evaluation_loop.on_evaluation_end() @@ -1021,16 +1022,13 @@ def _run_evaluation(self) -> _EVALUATE_OUTPUT: # enable train mode again self.evaluation_loop.on_evaluation_model_train() - # reset cached results - self.logger_connector.reset() - torch.set_grad_enabled(True) return eval_loop_results def _track_output_for_epoch_end(self, outputs, output): if output is not None: - if isinstance(output, Result): + if isinstance(output, ResultCollection): output = output.detach() if self.move_metrics_to_cpu: output = output.cpu() @@ -1115,12 +1113,16 @@ def _run_sanity_check(self, ref_model): self.on_sanity_check_end() - self.state.stage = stage + # reset validation metrics + self.logger_connector.reset() # reset the seed to what it was before sanity check # prevents sanity check to affect random sampling in training reset_seed() + # restore the previous stage when the sanity check is finished + self.state.stage = stage + def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]: if ckpt_path is None: return @@ -1194,32 +1196,14 @@ def _call_teardown_hook(self, model: LightningModule) -> None: model._current_fx_name = None model._current_dataloader_idx = None - def _reset_result_and_set_fx_name(self, hook_name: str) -> bool: - # on_before_zero_grad is called within training_step - # TODO(@carmocca): Result should handle this logic - if "batch_start" in hook_name or hook_name in ("on_before_zero_grad", "on_after_backward"): - return True - model_ref = self.lightning_module - if model_ref is not None: - # used to track current hook name called - model_ref._results = Result() - model_ref._current_fx_name = hook_name - return False - - def _cache_logged_metrics(self): - model_ref = self.lightning_module - if model_ref is not None: - # capture logging for this hook - self.logger_connector.cache_logged_metrics() - def call_hook(self, hook_name: str, *args, **kwargs) -> Any: # Note this implementation is copy/pasted into the TrainLoop class in TrainLoop._on_train_epoch_end_hook # This was done to manage the deprecation of the `outputs` argument to on_train_epoch_end # If making changes to this function, ensure that those changes are also made to # TrainLoop._on_train_epoch_end_hook - - # set hook_name to model + reset Result obj - skip = self._reset_result_and_set_fx_name(hook_name) + if self.lightning_module: + prev_fx_name = self.lightning_module._current_fx_name + self.lightning_module._current_fx_name
= hook_name # always profile hooks with self.profiler.profile(hook_name): @@ -1245,8 +1229,10 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any: # todo: move this data parallel logic into the data parallel plugin output = accelerator_output if output is None else output - if not skip: - self._cache_logged_metrics() + if self.lightning_module: + # restore current_fx when nested context + self.lightning_module._current_fx_name = prev_fx_name + return output @staticmethod diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 9b3dfed2c840b..2ad88c38a99ea 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -14,7 +14,6 @@ from collections import OrderedDict from contextlib import contextmanager, suppress -from copy import copy from functools import partial, update_wrapper from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union @@ -24,7 +23,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.plugins import ParallelPlugin -from pytorch_lightning.trainer.connectors.logger_connector.result import Result +from pytorch_lightning.trainer.connectors.logger_connector.result_new import ResultCollection from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType from pytorch_lightning.utilities.distributed import rank_zero_info @@ -83,6 +82,8 @@ def __init__( else: self.trainer.num_sanity_val_steps = num_sanity_val_steps + self.results = ResultCollection(training=True) + @property def num_active_optimizers(self) -> int: return len(self.get_active_optimizers()) @@ -99,7 +100,8 @@ def should_skip_training(self) -> bool: return should_by_max_steps or should_by_epoch or self.trainer.num_training_batches == 0 def on_train_start(self): - # hook + self.results.to(device=self.trainer.lightning_module.device) + self.trainer.call_hook("on_train_start") def on_train_end(self): @@ -167,6 +169,7 @@ def on_train_epoch_start(self, epoch): self.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches) # hook + self.trainer.logger_connector.on_epoch_start() self.trainer.call_hook("on_epoch_start") self.trainer.call_hook("on_train_epoch_start") @@ -178,13 +181,11 @@ def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, # hook self.trainer.call_hook('on_train_batch_end', processed_batch_end_outputs, batch, batch_idx, dataloader_idx) self.trainer.call_hook('on_batch_end') + self.trainer.logger_connector.on_batch_end() # figure out what to track for epoch end self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs) - # reset batch logger internals - self.trainer.logger_connector.on_train_batch_end() - def reset_train_val_dataloaders(self, model) -> None: """ Resets train and val dataloaders if none are attached to the trainer. 
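To summarise the ownership change spread across the hunks in this file and in `evaluation_loop.py`/`properties.py`: each loop now keeps its own `ResultCollection` (`training=True` for the train loop, separate validation and test collections for the evaluation loop), and `Trainer._results` forwards to whichever loop is active. Below is a stripped-down sketch of that pattern using simplified stand-in classes, not the real implementation.

    class ResultCollection(dict):  # stand-in for the real class
        def __init__(self, training: bool) -> None:
            super().__init__()
            self.training = training


    class TrainLoop:
        def __init__(self) -> None:
            self.results = ResultCollection(training=True)


    class EvaluationLoop:
        def __init__(self) -> None:
            self._val_results = ResultCollection(training=False)
            self._test_results = ResultCollection(training=False)
            self.testing = False  # toggled by the trainer state in the real code

        @property
        def results(self) -> ResultCollection:
            return self._test_results if self.testing else self._val_results


    class Trainer:
        def __init__(self) -> None:
            self.train_loop = TrainLoop()
            self.evaluation_loop = EvaluationLoop()
            self.training = True

        @property
        def _results(self) -> ResultCollection:
            # dispatch to whichever loop is currently active
            active = self.train_loop if self.training else self.evaluation_loop
            return active.results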
@@ -199,22 +200,17 @@ def reset_train_val_dataloaders(self, model) -> None: self.trainer.reset_val_dataloader(model) def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): hook_overridden = self._should_add_batch_output_to_epoch_output() + if not hook_overridden: + return # track the outputs to reduce at the end of the epoch for opt_idx, opt_outputs in enumerate(batch_end_outputs): - sample_output = opt_outputs[-1] - - # decide if we need to reduce at the end of the epoch automatically - auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end - - # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end - if not (hook_overridden or auto_reduce_tng_result): - continue - # with 1 step (no tbptt) don't use a sequence at epoch end - if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result): + if ( + isinstance(opt_outputs, list) and len(opt_outputs) == 1 + and not isinstance(opt_outputs[0], ResultCollection) + ): opt_outputs = opt_outputs[0] epoch_output[opt_idx].append(opt_outputs) @@ -256,8 +252,6 @@ def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[i return [(opt_idx, self.trainer.optimizers[opt_idx])] def on_after_backward(self, training_step_output, batch_idx, untouched_loss): - training_step_output.detach() - # insert after step hook self.trainer.call_hook("on_after_backward") @@ -289,55 +283,35 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens): # manually capture logged metrics model_ref._current_fx_name = 'training_step' - model_ref._results = Result() with self.trainer.profiler.profile("training_step"): training_step_output = self.trainer.accelerator.training_step(step_kwargs) self.trainer.accelerator.post_training_step() - self.trainer.logger_connector.cache_logged_metrics() - training_step_output = self.trainer.call_hook("training_step_end", training_step_output) self._check_training_step_output(training_step_output) - training_step_output_for_epoch_end, training_step_output = self._process_training_step_output( - training_step_output, split_batch - ) - if training_step_output_for_epoch_end is None: + training_step_output = self._process_training_step_output(training_step_output) + if training_step_output is None: return - # enable empty loss when using manual opt closure_loss = None - untouched_loss = None - + loss = None if self.trainer.lightning_module.automatic_optimization: # accumulate loss. if accumulate_grad_batches==1, no effect closure_loss = training_step_output.minimize / self.trainer.accumulate_grad_batches - # the loss will get scaled for amp.
avoid any modifications to it - untouched_loss = closure_loss.detach().clone() - - # result - result = AttributeDict( - closure_loss=closure_loss, - loss=untouched_loss, - training_step_output=training_step_output, - training_step_output_for_epoch_end=training_step_output_for_epoch_end, - ) - return result + loss = closure_loss.detach().clone() + return AttributeDict(closure_loss=closure_loss, loss=loss, training_step_output=training_step_output) - def _process_training_step_output(self, training_step_output, split_batch): - training_step_output_for_epoch_end = training_step_output - - # enable validation_step return None - if training_step_output_for_epoch_end is None: - return None, None - - result = self.trainer.lightning_module._results + def _process_training_step_output(self, training_step_output): + if training_step_output is None: + return None + results = self.results loss = None hiddens = None - result["extra"] = {} + results.extra = {} # handle dict return if isinstance(training_step_output, dict): @@ -345,44 +319,37 @@ def _process_training_step_output(self, training_step_output, split_batch): hiddens = training_step_output.pop("hiddens", None) if hiddens is not None: hiddens = hiddens.detach() - result["extra"] = training_step_output + results.extra = training_step_output # handle scalar return elif isinstance(training_step_output, torch.Tensor): loss = training_step_output # map to results under the hood - result.minimize = loss + results.minimize = loss self._hiddens = hiddens - # track batch for manual reduction with result - result.track_batch_size(len(split_batch)) - - # track metrics without grads for epoch reduction - training_step_output_for_epoch_end = copy(result) - training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach() if self.trainer.move_metrics_to_cpu: - training_step_output_for_epoch_end = training_step_output_for_epoch_end.cpu() - - return training_step_output_for_epoch_end, result + results.cpu() + return results @staticmethod def _prepare_outputs( - outputs: List[List[List[Result]]], + outputs: List[List[List['ResultCollection']]], batch_mode: bool, ) -> Union[List[List[List[Dict]]], List[List[Dict]], List[Dict], Dict]: """ Extract required information from batch or epoch end results. Args: - outputs: A 3-dimensional list of ``Result`` objects with dimensions: - [optimizer outs][batch outs][tbptt steps]. + outputs: A 3-dimensional list of ``ResultCollection`` objects with dimensions: + ``[optimizer outs][batch outs][tbptt steps]``. batch_mode: If True, ignore the batch output dimension. Returns: - The cleaned outputs with ``Result`` objects converted to dictionaries. All list dimensions of size one will - be collapsed. + The cleaned outputs with ``ResultCollection`` objects converted to dictionaries. + All list dimensions of size one will be collapsed. 
""" processed_outputs = [] for opt_outputs in outputs: @@ -398,6 +365,9 @@ def _prepare_outputs( for batch_outputs in opt_outputs: processed_tbptt_outputs = [] + if isinstance(batch_outputs, ResultCollection): + batch_outputs = [batch_outputs] + for tbptt_output in batch_outputs: out = tbptt_output.extra if tbptt_output.minimize is not None: @@ -499,19 +469,18 @@ def run_training_epoch(self): break # hook - # TODO: add outputs to batches self.on_train_batch_end( epoch_output, - batch_output.training_step_output_for_epoch_end, + batch_output.training_step_output, batch, batch_idx, dataloader_idx, ) # ----------------------------------------- - # SAVE METRICS TO LOGGERS + # SAVE METRICS TO LOGGERS AND PROGRESS_BAR # ----------------------------------------- - self.trainer.logger_connector.log_train_step_metrics() + self.trainer.logger_connector.update_train_step_metrics() # ----------------------------------------- # VALIDATE IF NEEDED @@ -553,7 +522,7 @@ def run_training_epoch(self): # TODO(@carmocca): deprecate and rename so users don't get confused self.global_step -= 1 # log epoch metrics - self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output) + self.trainer.logger_connector.update_train_epoch_metrics() self.global_step += 1 self.update_lr_schedulers('epoch') @@ -566,9 +535,9 @@ def run_training_epoch(self): self.check_checkpoint_callback(True) self.global_step += 1 - def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None: + def on_train_epoch_end(self, epoch_output: List[List[List['ResultCollection']]]) -> None: # inform logger the batch loop has finished - self.trainer.logger_connector.on_train_epoch_end() + self.trainer.logger_connector.epoch_end_reached() # prepare epoch output processed_epoch_output = TrainLoop._prepare_outputs(epoch_output, batch_mode=False) @@ -588,12 +557,10 @@ def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None: 'HINT: remove the return statement in training_epoch_end' ) - # capture logging - self.trainer.logger_connector.cache_logged_metrics() - # call train epoch end hooks self._on_train_epoch_end_hook(processed_epoch_output) self.trainer.call_hook('on_epoch_end') + self.trainer.logger_connector.on_epoch_end() def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: # We cannot rely on Trainer.call_hook because the signatures might be different across @@ -602,9 +569,8 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: # This implementation is copied from Trainer.call_hook hook_name = "on_train_epoch_end" - - # set hook_name to model + reset Result obj - skip = self.trainer._reset_result_and_set_fx_name(hook_name) + prev_fx_name = self.trainer.lightning_module._current_fx_name + self.trainer.lightning_module._current_fx_name = hook_name # always profile hooks with self.trainer.profiler.profile(hook_name): @@ -633,8 +599,8 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: accelerator_hook = getattr(self.trainer.accelerator, hook_name) accelerator_hook() - if not skip: - self.trainer._cache_logged_metrics() + # restore current_fx when nested context + self.trainer.lightning_module._current_fx_name = prev_fx_name def run_training_batch(self, batch, batch_idx, dataloader_idx): # bookkeeping @@ -647,9 +613,10 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): if batch is None: self.warning_cache.warn("train_dataloader yielded None. 
If this was on purpose, ignore this warning...") - return AttributeDict(signal=0, training_step_output_for_epoch_end=batch_outputs) + return AttributeDict(signal=0, training_step_output=batch_outputs) # hook + self.trainer.logger_connector.on_batch_start() response = self.trainer.call_hook("on_batch_start") if response == -1: return AttributeDict(signal=-1) @@ -665,25 +632,28 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): for split_idx, split_batch in enumerate(splits): self.split_idx = split_idx + # let logger connector extract batch size + self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch) + if self.trainer.lightning_module.automatic_optimization: for opt_idx, optimizer in self.get_active_optimizers(batch_idx): - result = self._run_optimization(batch_idx, split_idx, split_batch, opt_idx, optimizer) + result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer) if result: - batch_outputs[opt_idx].append(result.training_step_output_for_epoch_end) + batch_outputs[opt_idx].append(result.training_step_output) else: # in manual optimization, there is no looping over optimizers - result = self._run_optimization(batch_idx, split_idx, split_batch) + result = self._run_optimization(batch_idx, split_batch) if result: - batch_outputs[0].append(result.training_step_output_for_epoch_end) + batch_outputs[0].append(result.training_step_output) - return AttributeDict(signal=0, training_step_output_for_epoch_end=batch_outputs) + return AttributeDict(signal=0, training_step_output=batch_outputs) - def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimizer=None): + def _run_optimization(self, batch_idx, split_batch, opt_idx=0, optimizer=None): # TODO: In v1.5, when optimizer_idx gets removed from training_step in manual_optimization, change # opt_idx=0 to opt_idx=None in the signature here - # toggle model params + set info to logger_connector - self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) + # toggle model params + self.run_optimization_start(opt_idx, optimizer) result = AttributeDict() closure = self.make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens, result) @@ -770,9 +740,6 @@ def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) - if not opt_closure_result: return - # cache metrics - self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result) - # check if loss or model weights are nan if self.trainer.terminate_on_nan: self._check_finite(opt_closure_result.loss) @@ -939,16 +906,13 @@ def save_loggers_on_train_batch_end(self): if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: self.trainer.logger.save() - def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): + def run_optimization_start(self, opt_idx, optimizer): # make sure only the gradients of the current optimizer's parameters are calculated # in the training step to prevent dangling gradients in multiple-optimizer setup. 
if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1: model = self.trainer.lightning_module model.toggle_optimizer(optimizer, opt_idx) - # use to track metrics internally - self.trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch) - def update_running_loss(self, current_loss: torch.Tensor) -> None: if self.trainer.lightning_module.automatic_optimization: # track total loss for logging (avoid mem leaks) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 9a05b73cb97b3..4766f12bdd154 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -386,8 +386,9 @@ def test_tensor_to_float_conversion(tmpdir): class TestModel(BoringModel): def training_step(self, batch, batch_idx): - self.log('foo', torch.tensor(0.123), prog_bar=True) - self.log('bar', {"baz": torch.tensor([1])}, prog_bar=True) + self.log('a', torch.tensor(0.123), prog_bar=True, on_epoch=False) + self.log('b', {"b1": torch.tensor([1])}, prog_bar=True, on_epoch=False) + self.log('c', {"c1": 2}, prog_bar=True, on_epoch=False) return super().training_step(batch, batch_idx) trainer = Trainer( @@ -399,9 +400,12 @@ def training_step(self, batch, batch_idx): ) trainer.fit(TestModel()) + torch.testing.assert_allclose(trainer.progress_bar_metrics['a'], 0.123) + assert trainer.progress_bar_metrics['b'] == {'b1': 1.0} + assert trainer.progress_bar_metrics['c'] == {'c1': 2.0} pbar = trainer.progress_bar_callback.main_progress_bar actual = str(pbar.postfix) - assert actual.endswith("foo=0.123, bar={'baz': tensor([1])}") + assert actual.endswith("a=0.123, b={'b1': 1.0}, c={'c1': 2.0}"), actual @pytest.mark.parametrize( diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 8a636a0b15dd1..37dfc5adc07ef 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -66,9 +66,9 @@ def _ddp_test_fn(rank, worldsize): cumulative_sum += i - result.log('h', 'a', metric_a, on_step=True, on_epoch=True, metric_attribute="metric_a") - result.log('h', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b") - result.log('h', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c") + result.log('h', 'a', metric_a, on_step=True, on_epoch=True) + result.log('h', 'b', metric_b, on_step=False, on_epoch=True) + result.log('h', 'c', metric_c, on_step=True, on_epoch=False) batch_log = result.metrics(True)[MetricSource.LOG] assert batch_log == {"a_step": i, "c": i} @@ -109,9 +109,9 @@ def test_result_metric_integration(): cumulative_sum += i - result.log('h', 'a', metric_a, on_step=True, on_epoch=True, metric_attribute="metric_a") - result.log('h', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b") - result.log('h', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c") + result.log('h', 'a', metric_a, on_step=True, on_epoch=True) + result.log('h', 'b', metric_b, on_step=False, on_epoch=True) + result.log('h', 'c', metric_c, on_step=True, on_epoch=False) batch_log = result.metrics(True)[MetricSource.LOG] assert batch_log == {"a_step": i, "c": i} diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index 2851a81e968d2..1b4f6cacfef70 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -88,6 +88,19 @@ def training_step(self, *args): trainer.fit(TestModel()) 
+def test_v1_6_0_sync_dist_op(tmpdir): + + class TestModel(BoringModel): + + def training_step(self, *args): + self.log("foo", 1, sync_dist_op='sum') + return super().training_step(*args) + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + with pytest.deprecated_call(match=r"`self.log\(sync_dist_op='sum'\)` is deprecated"): + trainer.fit(TestModel()) + + def test_v1_6_0_datamodule_lifecycle_properties(tmpdir): dm = BoringDataModule() with pytest.deprecated_call(match=r"DataModule property `has_prepared_data` was deprecated in v1.4"): diff --git a/tests/models/test_grad_norm.py b/tests/models/test_grad_norm.py index 7f3a969df5232..384e643e184fe 100644 --- a/tests/models/test_grad_norm.py +++ b/tests/models/test_grad_norm.py @@ -87,7 +87,6 @@ def on_train_batch_end(self, *_) -> None: assert np.allclose(mod[k], log[k], rtol=rtol), k -@pytest.mark.skip("TODO: remove skip with #7631") @pytest.mark.parametrize("log_every_n_steps", [1, 2, 3]) def test_grad_tracking_interval(tmpdir, log_every_n_steps): """ Test that gradient norms get tracked in the right interval and that everytime the same keys get logged. """ @@ -109,5 +108,9 @@ def test_grad_tracking_interval(tmpdir, log_every_n_steps): if grad_norm_dict: grad_norm_dicts.append(grad_norm_dict) - assert len(grad_norm_dicts) == expected - assert all(grad_norm_dicts[0].keys() == g.keys() for g in grad_norm_dicts) + # logging on n steps + 1 epochs + assert len(grad_norm_dicts) == expected + 1 + # check all metrics derived from steps have the same keys + assert all(grad_norm_dicts[0].keys() == g.keys() for g in grad_norm_dicts[:-1]) + epoch_end_keys = [k.replace("step", "epoch") for k in grad_norm_dicts[0]] + assert epoch_end_keys == list(grad_norm_dicts[-1]) diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index 10f96845a7a48..a37935cb1b5de 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -296,7 +296,7 @@ def training_step(self, batch, batch_idx): self.training_step_called = True tensor = torch.tensor([1.0]) - self.log("test_tensor", tensor, sync_dist=True, sync_dist_op='sum', on_step=True, on_epoch=True) + self.log("test_tensor", tensor, sync_dist=True, reduce_fx='sum', on_step=True, on_epoch=True) res = self._results diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index ee288c2ebe078..f81481e1ee695 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -60,7 +60,6 @@ def validation_step(self, batch, batch_idx): ) trainer.fit(model) - # make sure all the metrics are available for callbacks assert set(trainer.logged_metrics) == { 'a2', 'a_step', @@ -224,7 +223,9 @@ def validation_epoch_end(self, outputs) -> None: # make sure values are correct assert trainer.logged_metrics['val_loss_epoch'] == model.manual_epoch_end_mean - assert trainer.callback_metrics['val_loss'] == trainer.logged_metrics['val_loss_step'] + assert trainer.callback_metrics['val_loss_epoch'] == model.manual_epoch_end_mean + assert trainer.callback_metrics['val_loss'] == model.manual_epoch_end_mean + assert trainer.logged_metrics["val_loss_step"] == model.val_losses[-1] @pytest.mark.parametrize(['batches', 'log_interval', 'max_epochs'], [(1, 1, 1), (64, 32, 2)]) @@ -299,176 +300,103 @@ def test_log_works_in_val_callback(tmpdir): class TestCallback(callbacks.Callback): - # helpers - count = 1 + count = 0 choices = [False, True] + # used to compute expected values - 
callback_funcs_called = collections.defaultdict(list) - funcs_called_count = collections.defaultdict(int) - funcs_attr = {} + logged_values = collections.defaultdict(list) + call_counter = collections.Counter() + logged_arguments = {} - def make_logging(self, pl_module, func_name, func_idx, on_steps=[], on_epochs=[], prob_bars=[]): - self.funcs_called_count[func_name] += 1 - product = [on_steps, on_epochs, prob_bars] - for idx, (on_step, on_epoch, prog_bar) in enumerate(list(itertools.product(*product))): - # run logging - custom_func_name = f"{func_idx}_{idx}_{func_name}" - pl_module.log( - custom_func_name, self.count * func_idx, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar - ) - # catch information for verification - self.callback_funcs_called[func_name].append([self.count * func_idx]) - self.funcs_attr[custom_func_name] = { - "on_step": on_step, - "on_epoch": on_epoch, - "prog_bar": prog_bar, - "forked": on_step and on_epoch, - "func_name": func_name - } + def make_logging(self, pl_module, func_name, on_steps, on_epochs, prob_bars): + self.call_counter.update([func_name]) - if on_step and on_epoch: - self.funcs_attr[f"{custom_func_name}_step"] = { - "on_step": True, - "on_epoch": False, - "prog_bar": prog_bar, - "forked": False, - "func_name": func_name - } + for idx, (on_step, on_epoch, prog_bar) in enumerate(itertools.product(on_steps, on_epochs, prob_bars)): + fx = f"{func_name}_{idx}" + pl_module.log(fx, self.count, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar) + self.logged_values[fx].append(self.count) + self.logged_arguments[fx] = {"on_step": on_step, "on_epoch": on_epoch, "prog_bar": prog_bar} + self.count += 1 - self.funcs_attr[f"{custom_func_name}_epoch"] = { - "on_step": False, - "on_epoch": True, - "prog_bar": prog_bar, - "forked": False, - "func_name": func_name - } - - def on_validation_start(self, trainer, pl_module): + def on_validation_start(self, _, pl_module): self.make_logging( - pl_module, 'on_validation_start', 1, on_steps=[False], on_epochs=[True], prob_bars=self.choices + pl_module, 'on_validation_start', on_steps=[False], on_epochs=[True], prob_bars=self.choices ) def on_epoch_start(self, trainer, pl_module): if trainer.validating: self.make_logging( - pl_module, 'on_epoch_start', 2, on_steps=[False], on_epochs=[True], prob_bars=self.choices + pl_module, 'on_epoch_start', on_steps=[False], on_epochs=[True], prob_bars=self.choices ) - def on_validation_epoch_start(self, trainer, pl_module): - self.make_logging( - pl_module, 'on_validation_epoch_start', 3, on_steps=[False], on_epochs=[True], prob_bars=self.choices - ) - - def on_batch_end(self, trainer, pl_module): + def on_validation_epoch_start(self, _, pl_module): self.make_logging( - pl_module, 'on_batch_end', 6, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices + pl_module, 'on_validation_epoch_start', on_steps=[False], on_epochs=[True], prob_bars=self.choices ) - def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_validation_batch_end(self, _, pl_module, *__): self.make_logging( pl_module, 'on_validation_batch_end', - 7, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices ) - # used to make sure aggregation works fine. 
- # we should obtain func[value * c for c in range(1, max_epochs * limit_validation_batches)]) - # with func = np.mean if on_epoch else func = np.max - self.count += 1 def on_epoch_end(self, trainer, pl_module): if trainer.validating: - self.make_logging( - pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=[True], prob_bars=self.choices - ) + self.make_logging(pl_module, 'on_epoch_end', on_steps=[False], on_epochs=[True], prob_bars=self.choices) - def on_validation_epoch_end(self, trainer, pl_module): + def on_validation_epoch_end(self, _, pl_module): self.make_logging( - pl_module, 'on_validation_epoch_end', 9, on_steps=[False], on_epochs=[True], prob_bars=self.choices + pl_module, 'on_validation_epoch_end', on_steps=[False], on_epochs=[True], prob_bars=self.choices ) class TestModel(BoringModel): def validation_step(self, batch, batch_idx): - output = self.layer(batch) - loss = self.loss(batch, output) + loss = super().validation_step(batch, batch_idx)['x'] self.log('val_loss', loss) - max_epochs = 1 model = TestModel() model.validation_epoch_end = None - test_callback = TestCallback() - + cb = TestCallback() trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=1, limit_val_batches=4, - limit_test_batches=0, - val_check_interval=0., num_sanity_val_steps=0, - max_epochs=max_epochs, - callbacks=[test_callback], + max_epochs=1, + callbacks=[cb], ) trainer.fit(model) - assert test_callback.funcs_called_count["on_epoch_start"] == 1 - # assert test_callback.funcs_called_count["on_batch_start"] == 1 - assert test_callback.funcs_called_count["on_batch_end"] == 1 - assert test_callback.funcs_called_count["on_validation_start"] == 1 - assert test_callback.funcs_called_count["on_validation_epoch_start"] == 1 - # assert test_callback.funcs_called_count["on_validation_batch_start"] == 4 - assert test_callback.funcs_called_count["on_epoch_end"] == 1 - assert test_callback.funcs_called_count["on_validation_batch_end"] == 4 - assert test_callback.funcs_called_count["on_validation_epoch_end"] == 1 - - # Make sure the func_name exists within callback_metrics. 
If not, we missed some - callback_metrics_keys = [*trainer.callback_metrics.keys()] - for func_name in test_callback.callback_funcs_called.keys(): - is_in = False - for callback_metrics_key in callback_metrics_keys: - if func_name in callback_metrics_key: - is_in = True - assert is_in, (func_name, callback_metrics_keys) + assert cb.call_counter == { + 'on_validation_batch_end': 4, + 'on_validation_start': 1, + 'on_epoch_start': 1, + 'on_validation_epoch_start': 1, + 'on_validation_epoch_end': 1, + 'on_epoch_end': 1 + } - # function used to describe expected return logic - def get_expected_output(func_attr, original_values): + def get_expected(on_epoch, values): + reduction = np.mean if on_epoch else np.max + return reduction(values) - if func_attr["on_epoch"] and not func_attr["on_step"]: - # Apply mean on values - expected_output = np.mean(original_values) - else: - # Keep the latest value - expected_output = np.max(original_values) - return expected_output - - # Make sure the func_name output equals the average from all logged values when on_epoch true - # pop extra keys - trainer.callback_metrics.pop("val_loss") - for func_name, output_value in trainer.callback_metrics.items(): - # not sure how to handle this now - if "epoch_0" in func_name: - func_name = '/'.join(func_name.split('/')[:-1]) + for fx, value in trainer.callback_metrics.items(): + actual = value.item() + if fx not in cb.logged_arguments: continue + on_epoch = cb.logged_arguments[fx]['on_epoch'] + values = cb.logged_values[fx] + expected = get_expected(on_epoch, values) + assert actual == expected - if torch.is_tensor(output_value): - output_value = output_value.item() - # get creation attr - func_attr = test_callback.funcs_attr[func_name] - - # retrived orginal logged values - original_values = test_callback.callback_funcs_called[func_attr["func_name"]] - - # compute expected output and compare to actual one - expected_output = get_expected_output(func_attr, original_values) - assert float(output_value) == float(expected_output) - - for func_name, func_attr in test_callback.funcs_attr.items(): - if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]) and not func_attr["forked"]: - assert func_name in trainer.logger_connector.progress_bar_metrics - else: - assert func_name not in trainer.logger_connector.progress_bar_metrics + for fx, attrs in cb.logged_arguments.items(): + should_include = attrs["prog_bar"] and attrs["on_step"] ^ attrs["on_epoch"] + is_included = fx in trainer.logger_connector.progress_bar_metrics + assert is_included if should_include else not is_included def test_log_works_in_test_callback(tmpdir): @@ -479,7 +407,7 @@ def test_log_works_in_test_callback(tmpdir): class TestCallback(callbacks.Callback): # helpers - count = 1 + count = 0 choices = [False, True] # used to compute expected values @@ -487,19 +415,15 @@ class TestCallback(callbacks.Callback): funcs_called_count = collections.defaultdict(int) funcs_attr = {} - def make_logging(self, pl_module, func_name, func_idx, on_steps=[], on_epochs=[], prob_bars=[]): + def make_logging(self, pl_module, func_name, on_steps, on_epochs, prob_bars): original_func_name = func_name[:] self.funcs_called_count[original_func_name] += 1 - product = [on_steps, on_epochs, prob_bars] - for idx, t in enumerate(list(itertools.product(*product))): - # run logging + + for idx, (on_step, on_epoch, prog_bar) in enumerate(itertools.product(on_steps, on_epochs, prob_bars)): func_name = original_func_name[:] - on_step, on_epoch, prog_bar = t - 
custom_func_name = f"{func_idx}_{idx}_{func_name}" + custom_func_name = f"{idx}_{func_name}" - pl_module.log( - custom_func_name, self.count * func_idx, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar - ) + pl_module.log(custom_func_name, self.count, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar) num_dl_ext = '' if pl_module._current_dataloader_idx is not None: @@ -508,12 +432,11 @@ def make_logging(self, pl_module, func_name, func_idx, on_steps=[], on_epochs=[] func_name += num_dl_ext # catch information for verification - self.callback_funcs_called[func_name].append([self.count * func_idx]) + self.callback_funcs_called[func_name].append([self.count]) self.funcs_attr[custom_func_name + num_dl_ext] = { "on_step": on_step, "on_epoch": on_epoch, "prog_bar": prog_bar, - "forked": on_step and on_epoch, "func_name": func_name } if on_step and on_epoch: @@ -521,7 +444,6 @@ def make_logging(self, pl_module, func_name, func_idx, on_steps=[], on_epochs=[] "on_step": True, "on_epoch": False, "prog_bar": prog_bar, - "forked": False, "func_name": func_name } @@ -529,125 +451,86 @@ def make_logging(self, pl_module, func_name, func_idx, on_steps=[], on_epochs=[] "on_step": False, "on_epoch": True, "prog_bar": prog_bar, - "forked": False, "func_name": func_name } - def on_test_start(self, trainer, pl_module): - self.make_logging(pl_module, 'on_test_start', 1, on_steps=[False], on_epochs=[True], prob_bars=self.choices) + def on_test_start(self, _, pl_module): + self.make_logging(pl_module, 'on_test_start', on_steps=[False], on_epochs=[True], prob_bars=self.choices) - def on_test_epoch_start(self, trainer, pl_module): + def on_test_epoch_start(self, _, pl_module): self.make_logging( - pl_module, 'on_test_epoch_start', 3, on_steps=[False], on_epochs=[True], prob_bars=self.choices + pl_module, 'on_test_epoch_start', on_steps=[False], on_epochs=[True], prob_bars=self.choices ) - def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_test_batch_end(self, _, pl_module, *__): self.make_logging( - pl_module, - 'on_test_batch_end', - 5, - on_steps=self.choices, - on_epochs=self.choices, - prob_bars=self.choices + pl_module, 'on_test_batch_end', on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices ) - # used to make sure aggregation works fine. 
- # we should obtain func[value * c for c in range(1, max_epochs * limit_test_batches)]) - # with func = np.mean if on_epoch else func = np.max - self.count += 1 - - def on_test_epoch_end(self, trainer, pl_module): + def on_test_epoch_end(self, _, pl_module): self.make_logging( - pl_module, 'on_test_epoch_end', 7, on_steps=[False], on_epochs=[True], prob_bars=self.choices + pl_module, 'on_test_epoch_end', on_steps=[False], on_epochs=[True], prob_bars=self.choices ) - max_epochs = 2 num_dataloaders = 2 class TestModel(BoringModel): - - manual_mean = collections.defaultdict(list) + seen_losses = {i: [] for i in range(num_dataloaders)} def test_step(self, batch, batch_idx, dataloader_idx=None): - output = self.layer(batch) - loss = self.loss(batch, output) + loss = super().test_step(batch, batch_idx)['y'] self.log('test_loss', loss) - self.manual_mean[str(dataloader_idx)].append(loss) + self.seen_losses[dataloader_idx].append(loss) def test_dataloader(self): return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(num_dataloaders)] model = TestModel() model.test_epoch_end = None - test_callback = TestCallback() - + cb = TestCallback() trainer = Trainer( default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=0, limit_test_batches=2, - val_check_interval=0., num_sanity_val_steps=0, - max_epochs=max_epochs, - callbacks=[test_callback], + max_epochs=2, + callbacks=[cb], ) trainer.test(model) - assert test_callback.funcs_called_count["on_test_start"] == 1 - assert test_callback.funcs_called_count["on_test_epoch_start"] == 1 - assert test_callback.funcs_called_count["on_test_batch_end"] == 4 - assert test_callback.funcs_called_count["on_test_epoch_end"] == 1 + assert cb.funcs_called_count["on_test_start"] == 1 + assert cb.funcs_called_count["on_test_epoch_start"] == 1 + assert cb.funcs_called_count["on_test_batch_end"] == 4 + assert cb.funcs_called_count["on_test_epoch_end"] == 1 - # Make sure the func_name exists within callback_metrics. 
If not, we missed some - callback_metrics_keys = [*trainer.callback_metrics.keys()] - - for func_name in test_callback.callback_funcs_called.keys(): + callback_metrics_keys = list(trainer.callback_metrics) + for func_name in cb.callback_funcs_called.keys(): is_in = False for callback_metrics_key in callback_metrics_keys: if func_name in callback_metrics_key: is_in = True assert is_in, (func_name, callback_metrics_keys) - # function used to describe expected return logic - def get_expected_output(func_attr, original_values): - # Apply mean on values - if func_attr["on_epoch"] and not func_attr["on_step"]: - expected_output = np.mean(original_values) - else: - expected_output = np.max(original_values) - return expected_output + def get_expected(on_epoch, values): + reduction = np.mean if on_epoch else np.max + return reduction(values) # Make sure the func_name output equals the average from all logged values when on_epoch true for dl_idx in range(num_dataloaders): key = f"test_loss/dataloader_idx_{dl_idx}" assert key in trainer.callback_metrics - assert torch.stack(model.manual_mean[str(dl_idx)]).mean() == trainer.callback_metrics[key] - trainer.callback_metrics.pop(key) + assert torch.stack(model.seen_losses[dl_idx]).mean() == trainer.callback_metrics.pop(key) for func_name, output_value in trainer.callback_metrics.items(): - # not sure how to handle this now - if "epoch_1" in func_name: - func_name = '/'.join(func_name.split('/')[:-1]) - continue - - if torch.is_tensor(output_value): - output_value = output_value.item() - - # get func attr - func_attr = test_callback.funcs_attr[func_name] - - # retrived orginal logged values - original_values = test_callback.callback_funcs_called[func_attr["func_name"]] - - # compute expected output and compare to actual one - expected_output = get_expected_output(func_attr, original_values) - assert float(output_value) == float(expected_output) - - for func_name, func_attr in test_callback.funcs_attr.items(): - if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]) and not func_attr["forked"]: - assert func_name in trainer.logger_connector.progress_bar_metrics - else: - assert func_name not in trainer.logger_connector.progress_bar_metrics + output_value = output_value.item() + func_attr = cb.funcs_attr[func_name] + original_values = cb.callback_funcs_called[func_attr["func_name"]] + expected_output = get_expected(func_attr['on_epoch'], original_values) + assert output_value == expected_output + + for fx, attrs in cb.funcs_attr.items(): + should_include = attrs["prog_bar"] and attrs["on_step"] ^ attrs["on_epoch"] + is_included = fx in trainer.logger_connector.progress_bar_metrics + assert is_included if should_include else not is_included @mock.patch("pytorch_lightning.loggers.TensorBoardLogger.log_metrics") diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index b7ee13fed6d8e..c791b0a5a83fa 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -11,11 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-""" -Tests to ensure that the training loop works with a dict (1.0) -""" -from copy import deepcopy -from typing import Any, Callable from unittest import mock import pytest @@ -27,250 +22,13 @@ from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator -from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder -from pytorch_lightning.trainer.connectors.logger_connector.result import Result from pytorch_lightning.trainer.connectors.logger_connector.result_new import MetricSource, ResultCollection +from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf -def decorator_with_arguments(fx_name: str = '', hook_fx_name: str = None) -> Callable: - - def decorator(func: Callable) -> Callable: - - def wrapper(self, *args, **kwargs) -> Any: - # Set information - self._current_fx_name = fx_name - self._current_hook_fx_name = hook_fx_name - self._results = Result() - - result = func(self, *args, **kwargs) - - # cache metrics - self.trainer.logger_connector.cache_logged_metrics() - return result - - return wrapper - - return decorator - - -def test__logger_connector__epoch_result_store__train(tmpdir): - """ - Tests that LoggerConnector will properly capture logged information - and reduce them - """ - - class TestModel(BoringModel): - - train_losses = [] - - @decorator_with_arguments(fx_name="training_step") - def training_step(self, batch, batch_idx): - output = self.layer(batch) - loss = self.loss(batch, output) - - self.train_losses.append(loss) - - self.log("train_loss", loss, on_step=True, on_epoch=True) - - return {"loss": loss} - - def training_step_end(self, *_): - self.train_results = deepcopy(self.trainer.logger_connector.cached_results) - - model = TestModel() - model.training_epoch_end = None - model.val_dataloader = None - - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=4, - max_epochs=1, - log_every_n_steps=1, - weights_summary=None, - ) - trainer.fit(model) - - train_results = model.train_results - - assert len(train_results(fx_name="training_step", dl_idx=0, opt_idx=0)) == 2 - generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, batch_idx=0, split_idx=0)["train_loss"] - assert generated == model.train_losses[0] - generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, batch_idx=1, split_idx=0)["train_loss"] - assert generated == model.train_losses[1] - - assert train_results.has_reduced is not True - - train_results.has_batch_loop_finished = True - - assert train_results.has_reduced is True - - generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, reduced=True)['train_loss_epoch'].item() - expected = torch.stack(model.train_losses).mean().item() - assert generated == expected - - -def test__logger_connector__epoch_result_store__train__tbptt(tmpdir): - """ - Tests that LoggerConnector will properly capture logged information with ttbt - and reduce them - """ - truncated_bptt_steps = 2 - sequence_size = 30 - batch_size = 30 - - x_seq = torch.rand(batch_size, sequence_size, 1) - y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist() - - class MockSeq2SeqDataset(torch.utils.data.Dataset): - - def __getitem__(self, i): - return x_seq, 
y_seq_list - - def __len__(self): - return 1 - - class TestModel(BoringModel): - - train_losses = [] - - def __init__(self): - super().__init__() - self.test_hidden = None - self.layer = torch.nn.Linear(2, 2) - - @decorator_with_arguments(fx_name="training_step") - def training_step(self, batch, batch_idx, hiddens): - assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps" - self.test_hidden = torch.rand(1) - - x_tensor, y_list = batch - assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed" - - y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype) - assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed" - - pred = self(x_tensor.view(batch_size, truncated_bptt_steps)) - loss = torch.nn.functional.mse_loss(pred, y_tensor.view(batch_size, truncated_bptt_steps)) - - self.train_losses.append(loss) - - self.log('a', loss, on_epoch=True) - - return {'loss': loss, 'hiddens': self.test_hidden} - - def on_train_epoch_start(self) -> None: - self.test_hidden = None - - def train_dataloader(self): - return torch.utils.data.DataLoader( - dataset=MockSeq2SeqDataset(), - batch_size=batch_size, - shuffle=False, - sampler=None, - ) - - def training_step_end(self, training_step_output): - self.train_results = deepcopy(self.trainer.logger_connector.cached_results) - # must return - return training_step_output - - model = TestModel() - model.training_epoch_end = None - model.example_input_array = torch.randn(5, truncated_bptt_steps) - - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=10, - limit_val_batches=0, - truncated_bptt_steps=truncated_bptt_steps, - max_epochs=1, - log_every_n_steps=1, - weights_summary=None, - ) - trainer.fit(model) - - train_results = model.train_results - - generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, batch_idx=0) - assert len(generated) == len(model.train_losses) - - # assert reduction didn't happen yet - assert train_results.has_reduced is False - - # Launch reduction - train_results.has_batch_loop_finished = True - - # assert reduction did happen - assert train_results.has_reduced is True - - generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, reduced=True)['a_epoch'].item() - assert generated == torch.stack(model.train_losses).mean().item() - - -@pytest.mark.parametrize('num_dataloaders', [1, 2]) -def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, num_dataloaders): - """ - Tests that LoggerConnector will properly capture logged information in multi dataloaders scenario - """ - - class TestModel(BoringModel): - test_losses = {dl_idx: [] for dl_idx in range(num_dataloaders)} - - @decorator_with_arguments(fx_name="test_step") - def test_step(self, batch, batch_idx, dl_idx=0): - output = self.layer(batch) - loss = self.loss(batch, output) - self.test_losses[dl_idx].append(loss) - self.log("test_loss", loss, on_step=True, on_epoch=True) - return {"test_loss": loss} - - def on_test_batch_end(self, *args, **kwargs): - # save objects as it will be reset at the end of epoch. - self.batch_results = deepcopy(self.trainer.logger_connector.cached_results) - - def on_test_epoch_end(self): - # save objects as it will be reset at the end of epoch. 
- self.reduce_results = deepcopy(self.trainer.logger_connector.cached_results) - - def test_dataloader(self): - return [super().test_dataloader()] * num_dataloaders - - model = TestModel() - model.test_epoch_end = None - limit_test_batches = 4 - - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=0, - limit_val_batches=0, - limit_test_batches=limit_test_batches, - max_epochs=1, - log_every_n_steps=1, - weights_summary=None, - ) - trainer.test(model) - - test_results = model.batch_results - - generated = test_results(fx_name="test_step") - assert len(generated) == num_dataloaders - - for dl_idx in range(num_dataloaders): - generated = test_results(fx_name="test_step", dl_idx=dl_idx) - assert len(generated) == limit_test_batches - - test_results = model.reduce_results - - for dl_idx in range(num_dataloaders): - expected = torch.stack(model.test_losses[dl_idx]).mean() - generated = test_results(fx_name="test_step", dl_idx=dl_idx, reduced=True)["test_loss_epoch"] - torch.testing.assert_allclose(generated, expected) - - def test_fx_validator(tmpdir): funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_')]) @@ -445,56 +203,6 @@ def test_dataloader(self): trainer.test(model, ckpt_path=None) -@pytest.mark.parametrize('to_float', [False, True]) -def test_metrics_holder(to_float, tmpdir): - - device = "cuda" if torch.cuda.is_available() else "cpu" - preds = torch.tensor([[0.9, 0.1]], device=device) - - def is_float(value: Any) -> bool: - return isinstance(value, float) - - expected_function = is_float if to_float else torch.is_tensor - targets = torch.tensor([1], device=device) - acc = Accuracy().to(device) - metric_holder = MetricsHolder(to_float=to_float) - metric_holder.update({ - "x": 1, - "y": torch.tensor(2), - "z": acc(preds, targets), - }) - metric_holder.convert(device) - metrics = metric_holder.metrics - assert expected_function(metrics["x"]) - assert expected_function(metrics["y"]) - assert expected_function(metrics["z"]) - - -def test_metric_holder_raises(tmpdir): - """Check that an error is raised when trying to convert non-scalar tensors""" - - class TestModel(BoringModel): - - def validation_step(self, batch, *args, **kwargs): - output = self(batch) - self.log('test', output) - - def test_step(self, *args, **kwargs): - return self.validation_step(*args, **kwargs) - - model = TestModel() - model.validation_epoch_end = None - model.test_epoch_end = None - - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - - match = "The metric `.*` does not contain a single element" - with pytest.raises(MisconfigurationException, match=match): - trainer.validate(model) - with pytest.raises(MisconfigurationException, match=match): - trainer.test(model) - - def test_can_return_tensor_with_more_than_one_element(tmpdir): """Ensure {validation,test}_step return values are not included as callback metrics. 
#6623""" @@ -646,26 +354,28 @@ def _assert_epoch_end(self, stage): acc = self._modules[f"acc_{stage}"] ap = self._modules[f"ap_{stage}"] - acc.reset.asset_not_called() - ap.reset.assert_not_called() + acc.reset.assert_called_once() + ap.reset.assert_called_once() - def on_train_epoch_end(self): - self._assert_epoch_end('train') + def teardown(self, stage): + if stage == TrainerFn.FITTING: + self._assert_epoch_end('train') + self._assert_epoch_end('val') - def on_validation_epoch_end(self): - self._assert_epoch_end('val') + elif stage == TrainerFn.VALIDATING: + self._assert_epoch_end('val') - def on_test_epoch_end(self): - self._assert_epoch_end('test') + elif stage == TrainerFn.TESTING: + self._assert_epoch_end('test') def _assert_called(model, stage): acc = model._modules[f"acc_{stage}"] ap = model._modules[f"ap_{stage}"] - acc.reset.assert_called_once() + assert acc.reset.call_count == 1 acc.reset.reset_mock() - ap.reset.assert_called_once() + assert ap.reset.call_count == 1 ap.reset.reset_mock() model = TestModel() @@ -676,6 +386,7 @@ def _assert_called(model, stage): limit_test_batches=2, max_epochs=1, progress_bar_refresh_rate=0, + num_sanity_val_steps=2, ) trainer.fit(model) diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index bdfdadff89d5f..bff558e81b29e 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -192,7 +192,13 @@ def training_epoch_end(self, outputs): assert set(trainer.callback_metrics) == (logged_metrics | pbar_metrics | {'a', 'b'}) - {'epoch'} -@pytest.mark.parametrize(['batches', 'fx', 'result'], [(3, min, 0), (3, max, 2), (11, max, 10)]) +@pytest.mark.parametrize(['batches', 'fx', 'result'], [ + (3, min, 0), + (3, torch.max, 2), + (11, max, 10), + (5, 'avg', 2), + (5, 'SUM', 10), +]) def test__training_step__log_max_reduce_fx(tmpdir, batches, fx, result): """ Tests that log works correctly with different tensor types @@ -329,6 +335,7 @@ def val_dataloader(self): limit_val_batches=2, max_epochs=1, weights_summary=None, + fast_dev_run=True, ) trainer.fit(model) @@ -342,185 +349,107 @@ def test_log_works_in_train_callback(tmpdir): class TestCallback(callbacks.Callback): - # helpers - count = 1 + count = 0 choices = [False, True] + # used to compute expected values - callback_funcs_called = collections.defaultdict(list) - funcs_called_count = collections.defaultdict(int) - funcs_attr = {} - - def make_logging( - self, pl_module: pl.LightningModule, func_name, func_idx, on_steps=[], on_epochs=[], prob_bars=[] - ): - self.funcs_called_count[func_name] += 1 - iterate = list(itertools.product(*[on_steps, on_epochs, prob_bars])) - for idx, (on_step, on_epoch, prog_bar) in enumerate(iterate): - # run logging - custom_func_name = f"{func_idx}_{idx}_{func_name}" - pl_module.log( - custom_func_name, self.count * func_idx, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar - ) - - # catch information for verification - - # on on_train_start is outside the main loop. Won't be called - if func_name == "on_train_start": - self.callback_funcs_called[func_name].append([self.count * func_idx]) - - # Saved only values from second epoch, so we can compute its mean or latest. 
- if pl_module.trainer.current_epoch == 1: - self.callback_funcs_called[func_name].append([self.count * func_idx]) - - forked = on_step and on_epoch - - self.funcs_attr[custom_func_name] = { - "on_step": on_step, - "on_epoch": on_epoch, - "prog_bar": prog_bar, - "forked": forked, - "func_name": func_name - } - - if on_step and on_epoch: - self.funcs_attr[f"{custom_func_name}_step"] = { - "on_step": True, - "on_epoch": False, - "prog_bar": prog_bar, - "forked": False, - "func_name": func_name - } - - self.funcs_attr[f"{custom_func_name}_epoch"] = { - "on_step": False, - "on_epoch": True, - "prog_bar": prog_bar, - "forked": False, - "func_name": func_name - } + logged_values = collections.defaultdict(list) + call_counter = collections.Counter() + logged_arguments = {} - def on_train_start(self, trainer, pl_module): - self.make_logging( - pl_module, 'on_train_start', 1, on_steps=[False], on_epochs=[True], prob_bars=self.choices - ) + def make_logging(self, pl_module, func_name, on_steps, on_epochs, prob_bars): + self.call_counter.update([func_name]) + + for idx, (on_step, on_epoch, prog_bar) in enumerate(itertools.product(on_steps, on_epochs, prob_bars)): + fx = f"{func_name}_{idx}" + pl_module.log(fx, self.count, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar) + self.logged_values[fx].append(self.count) + self.logged_arguments[fx] = {"on_step": on_step, "on_epoch": on_epoch, "prog_bar": prog_bar} + self.count += 1 + + def on_train_start(self, _, pl_module): + self.make_logging(pl_module, 'on_train_start', on_steps=[False], on_epochs=[True], prob_bars=self.choices) - def on_epoch_start(self, trainer, pl_module): + def on_epoch_start(self, _, pl_module): self.make_logging( - pl_module, 'on_epoch_start', 2, on_steps=[False], on_epochs=[True], prob_bars=self.choices + pl_module, 'on_epoch_start', on_steps=self.choices, on_epochs=[True], prob_bars=self.choices ) - def on_train_epoch_start(self, trainer, pl_module): + def on_train_epoch_start(self, _, pl_module): self.make_logging( - pl_module, 'on_train_epoch_start', 3, on_steps=[False], on_epochs=[True], prob_bars=self.choices + pl_module, 'on_train_epoch_start', on_steps=self.choices, on_epochs=[True], prob_bars=self.choices ) - def on_batch_end(self, trainer, pl_module): + def on_batch_end(self, _, pl_module): self.make_logging( - pl_module, 'on_batch_end', 6, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices + pl_module, 'on_batch_end', on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices ) - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, _, pl_module, *__): self.make_logging( - pl_module, - 'on_train_batch_end', - 7, - on_steps=self.choices, - on_epochs=self.choices, - prob_bars=self.choices + pl_module, 'on_train_batch_end', on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices ) - # used to make sure aggregation works fine. 
- # we should obtain func[value * c for c in range(1, max_epochs * limit_train_batches)]) - # with func = np.mean if on_epoch else func = np.max - self.count += 1 - def on_train_epoch_end(self, trainer, pl_module): + def on_train_epoch_end(self, _, pl_module): self.make_logging( - pl_module, 'on_train_epoch_end', 8, on_steps=[False], on_epochs=[True], prob_bars=self.choices + pl_module, 'on_train_epoch_end', on_steps=[False], on_epochs=[True], prob_bars=self.choices ) - def on_epoch_end(self, trainer, pl_module): - self.make_logging(pl_module, 'on_epoch_end', 9, on_steps=[False], on_epochs=[True], prob_bars=self.choices) + def on_epoch_end(self, _, pl_module): + self.make_logging(pl_module, 'on_epoch_end', on_steps=[False], on_epochs=[True], prob_bars=self.choices) class TestModel(BoringModel): - - manual_loss = [] + seen_losses = [] def training_step(self, batch, batch_idx): - output = self.layer(batch) - loss = self.loss(batch, output) - self.manual_loss.append(loss) - self.log('train_loss', loss) + loss = super().training_step(batch, batch_idx)['loss'] + self.seen_losses.append(loss) + self.log('train_loss', loss, prog_bar=True) return {"loss": loss} - max_epochs = 2 - limit_train_batches = 2 model = TestModel() - test_callback = TestCallback() - + cb = TestCallback() trainer = Trainer( default_root_dir=tmpdir, - limit_train_batches=limit_train_batches, + limit_train_batches=2, limit_val_batches=0, - limit_test_batches=0, - val_check_interval=0., num_sanity_val_steps=0, - max_epochs=max_epochs, - callbacks=[test_callback] + max_epochs=1, + callbacks=[cb] ) trainer.fit(model) - assert test_callback.funcs_called_count["on_train_start"] == 1 - assert test_callback.funcs_called_count["on_epoch_start"] == 2 - assert test_callback.funcs_called_count["on_train_epoch_start"] == 2 - assert test_callback.funcs_called_count["on_batch_end"] == 4 - assert test_callback.funcs_called_count["on_epoch_end"] == 2 - assert test_callback.funcs_called_count["on_train_batch_end"] == 4 - assert test_callback.funcs_called_count["on_epoch_end"] == 2 - assert test_callback.funcs_called_count["on_train_epoch_end"] == 2 - - # Make sure the func_name exists within callback_metrics. 
If not, we missed some - callback_metrics_keys = [*trainer.callback_metrics.keys()] - for func_name in test_callback.callback_funcs_called.keys(): - is_in = False - for callback_metrics_key in callback_metrics_keys: - if func_name in callback_metrics_key: - is_in = True - assert is_in, (func_name, callback_metrics_keys) - - # function used to describe expected return logic - def get_expected_output(func_attr, original_values): - if func_attr["on_epoch"] and not func_attr["on_step"]: - # Apply mean on values - expected_output = np.mean(original_values) - else: - # Keep the latest value - expected_output = np.max(original_values) - return expected_output - # Make sure the func_name output equals the average from all logged values when on_epoch true - # pop extra keys - assert trainer.logged_metrics["train_loss"] == model.manual_loss[-1] - assert trainer.callback_metrics["train_loss"] == model.manual_loss[-1] - trainer.callback_metrics.pop("train_loss") + assert trainer.progress_bar_dict["train_loss"] == model.seen_losses[-1] + assert trainer.callback_metrics["train_loss"] == model.seen_losses[-1] - for func_name, output_value in trainer.callback_metrics.items(): - if torch.is_tensor(output_value): - output_value = output_value.item() - # get creation attr - func_attr = test_callback.funcs_attr[func_name] + assert cb.call_counter == { + 'on_train_start': 1, + 'on_epoch_start': 1, + 'on_train_epoch_start': 1, + 'on_train_batch_end': 2, + 'on_batch_end': 2, + 'on_train_epoch_end': 1, + 'on_epoch_end': 1 + } - # retrived orginal logged values - original_values = test_callback.callback_funcs_called[func_attr["func_name"]] + def get_expected(on_epoch, values): + reduction = np.mean if on_epoch else np.max + return reduction(values) - # compute expected output and compare to actual one - expected_output = get_expected_output(func_attr, original_values) - assert float(output_value) == float(expected_output) + for fx, value in trainer.callback_metrics.items(): + actual = value.item() + if fx not in cb.logged_arguments: + continue + on_epoch = cb.logged_arguments[fx]['on_epoch'] + values = cb.logged_values[fx] + expected = get_expected(on_epoch, values) + assert actual == expected - for func_name, func_attr in test_callback.funcs_attr.items(): - if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]) and not func_attr["forked"]: - assert func_name in trainer.logger_connector.progress_bar_metrics - else: - assert func_name not in trainer.logger_connector.progress_bar_metrics + for fx, attrs in cb.logged_arguments.items(): + should_include = attrs["prog_bar"] and attrs["on_step"] ^ attrs["on_epoch"] + is_included = fx in trainer.logger_connector.progress_bar_metrics + assert is_included if should_include else not is_included @pytest.mark.parametrize('gpus', [None, pytest.param(1, marks=RunIf(min_gpus=1))]) @@ -533,12 +462,12 @@ def test_logging_sync_dist_true(tmpdir, gpus): class TestModel(BoringModel): def training_step(self, batch, batch_idx): - self.log('foo', fake_result, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum') - self.log('foo_2', 2, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum') + self.log('foo', fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx='sum') + self.log('foo_2', 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx='sum') return super().training_step(batch, batch_idx) def validation_step(self, batch, batch_idx): - self.log('bar', fake_result, on_step=False, on_epoch=True, sync_dist=True, 
sync_dist_op='sum') + self.log('bar', fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx='sum') return super().validation_step(batch, batch_idx) model = TestModel() @@ -567,14 +496,14 @@ class TestLoggingSyncDistModel(BoringModel): def training_step(self, batch, batch_idx): acc = self.step(batch[0]) - self.log('foo', 1, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='SUM') + self.log('foo', 1, on_step=False, on_epoch=True, sync_dist=True, reduce_fx='SUM') self.log('cho', acc, on_step=False, on_epoch=True) return acc def validation_step(self, batch, batch_idx): output = self.layer(batch) loss = self.loss(batch, output) - self.log('bar', 2, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='AVG') + self.log('bar', 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx='AVG') return {"x": loss} model = TestLoggingSyncDistModel() @@ -609,14 +538,14 @@ def on_train_epoch_end(self, *_): prog_bar=True, on_epoch=True, sync_dist=True, - sync_dist_op='sum' + reduce_fx='sum' ) self.on_train_epoch_end_called = True def on_epoch_end(self): assert self.trainer.progress_bar_dict["foo"] == self.current_epoch assert self.trainer.progress_bar_dict["foo_2"] == self.current_epoch - self.epoch_end_called = True + self.on_epoch_end_called = True trainer = Trainer( default_root_dir=tmpdir, @@ -631,7 +560,7 @@ def on_epoch_end(self): model = TestModel() trainer.fit(model) assert model.on_train_epoch_end_called - assert model.epoch_end_called + assert model.on_epoch_end_called def test_logging_in_callbacks_with_log_function(tmpdir): @@ -760,3 +689,37 @@ def training_step(self, batch, batch_idx): model = TestModel() with pytest.raises(MisconfigurationException, match='`self.log` with the key `foo/dataloader_idx_0`'): trainer.fit(model) + + class TestModel(BoringModel): + + def training_step(self, *args): + loss = super().training_step(*args)['loss'] + return {"loss": loss, 'foo': loss} + + trainer = Trainer(default_root_dir=tmpdir) + model = TestModel() + with pytest.raises(MisconfigurationException, match='You returned a tensor with `grad_fn`'): + trainer.fit(model) + + class TestModel(BoringModel): + + def training_step(self, *args): + self.log('foo', -1, prog_bar=False) + self.log('foo', -1, prog_bar=True) + return super().training_step(*args) + + trainer = Trainer(default_root_dir=tmpdir) + model = TestModel() + with pytest.raises(MisconfigurationException, match=r'self.log\(foo, ...\)` twice in `training_step`'): + trainer.fit(model) + + class TestModel(BoringModel): + + def training_step(self, *args): + self.log('foo', -1, reduce_fx=torch.argmax) + return super().training_step(*args) + + trainer = Trainer(default_root_dir=tmpdir) + model = TestModel() + with pytest.raises(MisconfigurationException, match=r'reduce_fx={min,max,mean,sum}\)` are currently supported'): + trainer.fit(model) diff --git a/tests/trainer/loops/test_evaluation_loop.py b/tests/trainer/loops/test_evaluation_loop.py index 6e601b577d648..40ed000449f87 100644 --- a/tests/trainer/loops/test_evaluation_loop.py +++ b/tests/trainer/loops/test_evaluation_loop.py @@ -47,12 +47,13 @@ def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir): @mock.patch( - "pytorch_lightning.trainer.connectors.logger_connector.logger_connector.LoggerConnector.get_evaluate_epoch_results" + "pytorch_lightning.trainer.connectors.logger_connector." 
+ "logger_connector_new.LoggerConnectorNew.update_eval_epoch_metrics" ) -def test_log_epoch_metrics_before_on_evaluation_end(get_evaluate_epoch_results_mock, tmpdir): +def test_log_epoch_metrics_before_on_evaluation_end(update_eval_epoch_metrics_mock, tmpdir): """Test that the epoch metrics are logged before the `on_evalutaion_end` hook is fired""" order = [] - get_evaluate_epoch_results_mock.side_effect = lambda: order.append("log_epoch_metrics") + update_eval_epoch_metrics_mock.side_effect = lambda: order.append("log_epoch_metrics") class LessBoringModel(BoringModel): diff --git a/tests/trainer/loops/test_evaluation_loop_flow.py b/tests/trainer/loops/test_evaluation_loop_flow.py index a187dd37254f6..ce93461026363 100644 --- a/tests/trainer/loops/test_evaluation_loop_flow.py +++ b/tests/trainer/loops/test_evaluation_loop_flow.py @@ -19,6 +19,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.trainer.states import RunningStage from tests.helpers.deterministic_model import DeterministicModel @@ -65,16 +66,17 @@ def backward(self, loss, optimizer, optimizer_idx): assert not model.validation_step_end_called assert not model.validation_epoch_end_called - # make sure training outputs what is expected + # simulate training manually + trainer.state.stage = RunningStage.TRAINING batch_idx, batch = 0, next(iter(model.train_dataloader())) out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - train_step_out = out.training_step_output_for_epoch_end + train_step_out = out.training_step_output assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] - assert isinstance(train_step_out['minimize'], torch.Tensor) - assert train_step_out['minimize'].item() == 171 + assert isinstance(train_step_out.minimize, torch.Tensor) + assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things opt_closure_result = trainer.train_loop.training_step_and_backward( @@ -135,16 +137,17 @@ def backward(self, loss, optimizer, optimizer_idx): assert model.validation_step_end_called assert not model.validation_epoch_end_called + trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected batch_idx, batch = 0, next(iter(model.train_dataloader())) out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - train_step_out = out.training_step_output_for_epoch_end + train_step_out = out.training_step_output assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] - assert isinstance(train_step_out['minimize'], torch.Tensor) - assert train_step_out['minimize'].item() == 171 + assert isinstance(train_step_out.minimize, torch.Tensor) + assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things opt_closure_result = trainer.train_loop.training_step_and_backward( diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py index ae1e17aeabfe8..b3c2997436bbf 100644 --- a/tests/trainer/loops/test_training_loop_flow_scalar.py +++ b/tests/trainer/loops/test_training_loop_flow_scalar.py @@ -11,10 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-""" -Tests to ensure that the training loop works with a dict (1.0) -""" - import pytest import torch from torch.utils.data import DataLoader @@ -22,6 +18,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.trainer.states import RunningStage from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.deterministic_model import DeterministicModel from tests.helpers.utils import no_warning_call @@ -149,16 +146,17 @@ def backward(self, loss, optimizer, optimizer_idx): assert len(trainer.logger_connector.callback_metrics) == 0 assert len(trainer.logger_connector.progress_bar_metrics) == 0 + trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected batch_idx, batch = 0, next(iter(model.train_dataloader())) out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - train_step_out = out.training_step_output_for_epoch_end + train_step_out = out.training_step_output assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] - assert isinstance(train_step_out['minimize'], torch.Tensor) - assert train_step_out['minimize'].item() == 171 + assert isinstance(train_step_out.minimize, torch.Tensor) + assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things opt_closure_result = trainer.train_loop.training_step_and_backward( @@ -226,16 +224,17 @@ def backward(self, loss, optimizer, optimizer_idx): assert len(trainer.logger_connector.callback_metrics) == 0 assert len(trainer.logger_connector.progress_bar_metrics) == 0 + trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected batch_idx, batch = 0, next(iter(model.train_dataloader())) out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - train_step_out = out.training_step_output_for_epoch_end + train_step_out = out.training_step_output assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] - assert isinstance(train_step_out['minimize'], torch.Tensor) - assert train_step_out['minimize'].item() == 171 + assert isinstance(train_step_out.minimize, torch.Tensor) + assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things opt_closure_result = trainer.train_loop.training_step_and_backward( @@ -310,11 +309,13 @@ def training_step(self, batch, batch_idx): with pytest.warns(UserWarning, match=r'.*training_step returned None.*'): trainer.fit(model) + trainer.state.stage = RunningStage.TRAINING + # manually check a few batches for batch_idx, batch in enumerate(model.train_dataloader()): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) if not batch_idx % 2: - assert out.training_step_output_for_epoch_end == [[]] + assert out.training_step_output == [[]] assert out.signal == 0 @@ -353,9 +354,11 @@ def train_dataloader(self): with pytest.warns(UserWarning, match=r'.*train_dataloader yielded None.*'): trainer.fit(model) + trainer.state.stage = RunningStage.TRAINING + # manually check a few batches for batch_idx, batch in enumerate(model.train_dataloader()): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) if not batch_idx % 2: - assert out.training_step_output_for_epoch_end == [[]] + assert out.training_step_output == [[]] assert out.signal == 0 diff --git a/tests/trainer/optimization/test_multiple_optimizers.py b/tests/trainer/optimization/test_multiple_optimizers.py index 
aba3b53248a57..495f51ab8d394 100644
--- a/tests/trainer/optimization/test_multiple_optimizers.py
+++ b/tests/trainer/optimization/test_multiple_optimizers.py
@@ -30,10 +30,7 @@ def configure_optimizers(self):
 def test_unbalanced_logging_with_multiple_optimizers(tmpdir):
-    """
-    This tests ensures reduction works in unbalanced logging settings,
-    even when a Callback also logs.
-    """
+    """This test ensures reduction works in unbalanced logging settings"""

     class TestModel(MultiOptModel):

@@ -49,22 +46,12 @@ def training_step(self, batch, batch_idx, optimizer_idx):
     model = TestModel()
     model.training_epoch_end = None

-    class TestCallback(pl.Callback):
-
-        def on_train_batch_end(self, trainer, pl_module, output, batch, batch_idx, dl_idx):
-            # when this is called, the EpochResultStore state has not been reset yet because we are still
-            # "INSIDE_BATCH_TRAIN_LOOP" and the LoggerConnector runs its `on_train_batch_end` after the
-            # Callback (see `TrainLoop.on_train_batch_end`). For this reason, opt_idx here is the index
-            # of the last optimizer updated (the second, index 1). This produced a KeyError as reported in #5459
-            pl_module.log("test_train_batch_end", trainer.logger_connector.cached_results._opt_idx)
-
     # Initialize a trainer
     trainer = pl.Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
         limit_train_batches=5,
         limit_val_batches=5,
-        callbacks=[TestCallback()],
         weights_summary=None,
     )
     trainer.fit(model)

@@ -74,8 +61,6 @@ def on_train_batch_end(self, trainer, pl_module, output, batch, batch_idx, dl_id
         # test loss is properly reduced
         torch.testing.assert_allclose(trainer.callback_metrics[f"loss_{k}_epoch"], torch.tensor(v).mean())

-    assert trainer.callback_metrics["test_train_batch_end"] == len(model.optimizers()) - 1
-

 def test_multiple_optimizers(tmpdir):
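For context on the pattern exercised by `test_unbalanced_logging_with_multiple_optimizers` above, here is a minimal illustrative sketch (not part of the patch; the class name and the batch-index condition are made up): each optimizer branch logs its own key, one key is logged on fewer batches than the other, and with `on_epoch=True` each key is still mean-reduced over exactly the values that were logged for it, which is what the `loss_{k}_epoch` assertions check.

import torch

from tests.helpers.boring_model import BoringModel  # same test helper used throughout these files


class UnbalancedLoggingModel(BoringModel):  # hypothetical, for illustration only

    def configure_optimizers(self):
        # two optimizers, so `training_step` also receives `optimizer_idx`
        return [torch.optim.SGD(self.layer.parameters(), lr=0.1) for _ in range(2)]

    def training_step(self, batch, batch_idx, optimizer_idx):
        out = super().training_step(batch, batch_idx)
        if optimizer_idx == 0 or batch_idx % 2 == 0:
            # "loss_1" is logged on fewer batches than "loss_0" (unbalanced), yet the
            # epoch-level value for each key is the mean over only the steps where it was logged
            self.log(f"loss_{optimizer_idx}", out["loss"], on_step=False, on_epoch=True)
        return out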