From 9304c0df8f89048e5bc8eabad0798595d37d91f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 27 May 2021 14:27:52 +0200 Subject: [PATCH] Rename and move Result (#7736) Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com> --- CHANGELOG.md | 5 +++++ pytorch_lightning/core/lightning.py | 8 +++++--- pytorch_lightning/plugins/training_type/ddp2.py | 2 +- pytorch_lightning/plugins/training_type/dp.py | 2 +- .../connectors/logger_connector/epoch_result_store.py | 2 +- .../connectors/logger_connector/logger_connector.py | 2 +- .../connectors/logger_connector/result.py} | 0 pytorch_lightning/trainer/evaluation_loop.py | 2 +- pytorch_lightning/trainer/trainer.py | 2 +- pytorch_lightning/trainer/training_loop.py | 8 ++++---- tests/core/test_metric_result_integration.py | 2 +- tests/core/test_results.py | 2 +- tests/models/test_tpu.py | 2 +- tests/trainer/logging_/test_logger_connector.py | 2 +- 14 files changed, 24 insertions(+), 17 deletions(-) rename pytorch_lightning/{core/step_result.py => trainer/connectors/logger_connector/result.py} (100%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 58cce920af23e..4259684748e85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -77,6 +77,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Simplified "should run validation" logic ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682)) * Refactored "should run validation" logic when the trainer is signaled to stop ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701)) + +- Refactored logging + * Renamed and moved `core/step_result.py` to `trainer/connectors/logger_connector/result.py` ([#7736](https://github.com/PyTorchLightning/pytorch-lightning/pull/7736)) + + - Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 74c1ef442f993..cf6e25c54f336 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -26,7 +26,7 @@ from argparse import Namespace from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union import torch from torch import ScriptModule, Tensor @@ -38,7 +38,6 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES -from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem @@ -50,6 +49,9 @@ from pytorch_lightning.utilities.types import _METRIC, EPOCH_OUTPUT, STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache +if TYPE_CHECKING: + from pytorch_lightning.trainer.connectors.logger_connector.result import Result + warning_cache = WarningCache() log = logging.getLogger(__name__) @@ -106,7 +108,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # optionally can be set by user self._example_input_array = None self._datamodule = None - self._results: Optional[Result] = None + self._results: Optional['Result'] = None self._current_fx_name: Optional[str] = None self._running_manual_backward: bool = False self._current_dataloader_idx: Optional[int] = None diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py index b6d21904d1933..ecf6997cba321 100644 --- a/pytorch_lightning/plugins/training_type/ddp2.py +++ b/pytorch_lightning/plugins/training_type/ddp2.py @@ -13,8 +13,8 @@ # limitations under the License. import torch -from pytorch_lightning.core.step_result import Result from pytorch_lightning.plugins.training_type.ddp import DDPPlugin +from pytorch_lightning.trainer.connectors.logger_connector.result import Result class DDP2Plugin(DDPPlugin): diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 08caa7398ab8c..bb6f25a0eed36 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -16,9 +16,9 @@ import torch from torch.nn import DataParallel -from pytorch_lightning.core.step_result import Result from pytorch_lightning.overrides.data_parallel import LightningParallelModule from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin +from pytorch_lightning.trainer.connectors.logger_connector.result import Result from pytorch_lightning.utilities.apply_func import apply_to_collection diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 3d6370e3eb658..2ec7dd9460fa4 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -18,7 +18,7 @@ import torch import pytorch_lightning as pl -from pytorch_lightning.core.step_result import Result +from pytorch_lightning.trainer.connectors.logger_connector.result import Result from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import DistributedType, LightningEnum diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index a16f5119abff2..7bd834d5925b4 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -19,11 +19,11 @@ import torch from pytorch_lightning.core import memory -from pytorch_lightning.core.step_result import Result from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder +from pytorch_lightning.trainer.connectors.logger_connector.result import Result from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import DeviceType from pytorch_lightning.utilities.metrics import metrics_to_scalars diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py similarity index 100% rename from pytorch_lightning/core/step_result.py rename to pytorch_lightning/trainer/connectors/logger_connector/result.py diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index f048297892533..05bd1e48e5f74 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -17,7 +17,7 @@ from torch.utils.data import DataLoader import pytorch_lightning as pl -from pytorch_lightning.core.step_result import Result +from pytorch_lightning.trainer.connectors.logger_connector.result import Result from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.model_helpers import is_overridden diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b01f4fa36bd33..ded6e3395e30c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -28,7 +28,6 @@ from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.memory import ModelSummary -from pytorch_lightning.core.step_result import Result from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.plugins import Plugin from pytorch_lightning.plugins.environments import ClusterEnvironment @@ -48,6 +47,7 @@ from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector +from pytorch_lightning.trainer.connectors.logger_connector.result import Result from pytorch_lightning.trainer.connectors.model_connector import ModelConnector from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 09a32c3c96aad..7a12d7e766dae 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -23,8 +23,8 @@ from torch.optim import Optimizer from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.core.step_result import Result from pytorch_lightning.plugins import ParallelPlugin +from pytorch_lightning.trainer.connectors.logger_connector.result import Result from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType from pytorch_lightning.utilities.distributed import rank_zero_info @@ -742,9 +742,9 @@ def training_step_and_backward_closure( return_result: AttributeDict, ) -> Optional[torch.Tensor]: - step_result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) - if step_result is not None: - return_result.update(step_result) + result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) + if result is not None: + return_result.update(result) return return_result.loss def make_closure(self, *closure_args, **closure_kwargs: Any) -> Callable: diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 734b9e7f56152..fd08890604807 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -18,7 +18,7 @@ from torchmetrics import Metric import tests.helpers.utils as tutils -from pytorch_lightning.core.step_result import Result +from pytorch_lightning.trainer.connectors.logger_connector.result import Result from tests.helpers.runif import RunIf diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 02d30d9f79ee3..ef8f1403057ad 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -22,7 +22,7 @@ import tests.helpers.utils as tutils from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.core.step_result import Result +from pytorch_lightning.trainer.connectors.logger_connector.result import Result from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index f7d0aea829ced..d5c9ae65e6afc 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -24,8 +24,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.accelerators import TPUAccelerator from pytorch_lightning.callbacks import EarlyStopping -from pytorch_lightning.core.step_result import Result from pytorch_lightning.plugins import TPUSpawnPlugin +from pytorch_lightning.trainer.connectors.logger_connector.result import Result from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index e0e1c3cdf42ec..75b81392ff916 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -26,10 +26,10 @@ from pytorch_lightning import LightningModule from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.core.step_result import Result from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder +from pytorch_lightning.trainer.connectors.logger_connector.result import Result from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf