
Commit

Rename and move Result (#7736)
Co-authored-by: Kaushik B <[email protected]>
carmocca and kaushikb11 authored May 27, 2021
1 parent 906c067 commit 9304c0d
Showing 14 changed files with 24 additions and 17 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -77,6 +77,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Simplified "should run validation" logic ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682))
* Refactored "should run validation" logic when the trainer is signaled to stop ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701))


- Refactored logging
* Renamed and moved `core/step_result.py` to `trainer/connectors/logger_connector/result.py` ([#7736](https://github.com/PyTorchLightning/pytorch-lightning/pull/7736))


- Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/))


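The CHANGELOG entry above is the user-facing summary: `Result` now lives under the logger connector package. For code that imported it from the old location, the update is a one-line import change, sketched below (before/after shown as comments; this snippet is illustrative and not part of the diff):

```python
# Old location, removed by this commit:
# from pytorch_lightning.core.step_result import Result

# New location introduced by this commit:
from pytorch_lightning.trainer.connectors.logger_connector.result import Result
```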
8 changes: 5 additions & 3 deletions pytorch_lightning/core/lightning.py
@@ -26,7 +26,7 @@
from argparse import Namespace
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union

import torch
from torch import ScriptModule, Tensor
@@ -38,7 +38,6 @@
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -50,6 +49,9 @@
from pytorch_lightning.utilities.types import _METRIC, EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

if TYPE_CHECKING:
from pytorch_lightning.trainer.connectors.logger_connector.result import Result

warning_cache = WarningCache()
log = logging.getLogger(__name__)

@@ -106,7 +108,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
# optionally can be set by user
self._example_input_array = None
self._datamodule = None
self._results: Optional[Result] = None
self._results: Optional['Result'] = None
self._current_fx_name: Optional[str] = None
self._running_manual_backward: bool = False
self._current_dataloader_idx: Optional[int] = None
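The `lightning.py` hunks above move the `Result` import behind `typing.TYPE_CHECKING` and quote the annotation, so the type hint remains visible to static checkers without importing the module at runtime (which would otherwise risk a circular import between `core.lightning` and the trainer package). A minimal, self-contained sketch of that pattern, with a hypothetical module name standing in for the real one:

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime,
    # so it cannot introduce an import cycle.
    from mypackage.results import Result  # hypothetical module, for illustration only


class Example:
    def __init__(self) -> None:
        # The quoted annotation is resolved lazily and matches the guarded import above.
        self._results: Optional["Result"] = None
```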
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/ddp2.py
@@ -13,8 +13,8 @@
# limitations under the License.
import torch

from pytorch_lightning.core.step_result import Result
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result import Result


class DDP2Plugin(DDPPlugin):
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/dp.py
@@ -16,9 +16,9 @@
import torch
from torch.nn import DataParallel

from pytorch_lightning.core.step_result import Result
from pytorch_lightning.overrides.data_parallel import LightningParallelModule
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result import Result
from pytorch_lightning.utilities.apply_func import apply_to_collection


@@ -18,7 +18,7 @@
import torch

import pytorch_lightning as pl
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer.connectors.logger_connector.result import Result
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import DistributedType, LightningEnum

@@ -19,11 +19,11 @@
import torch

from pytorch_lightning.core import memory
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator
from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
from pytorch_lightning.trainer.connectors.logger_connector.result import Result
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities import DeviceType
from pytorch_lightning.utilities.metrics import metrics_to_scalars
pytorch_lightning/core/step_result.py → pytorch_lightning/trainer/connectors/logger_connector/result.py: file renamed without changes.
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/evaluation_loop.py
@@ -17,7 +17,7 @@
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer.connectors.logger_connector.result import Result
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.trainer.supporters import PredictionCollection
from pytorch_lightning.utilities.model_helpers import is_overridden
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -28,7 +28,6 @@
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.plugins import Plugin
from pytorch_lightning.plugins.environments import ClusterEnvironment
@@ -48,6 +47,7 @@
from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector
from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.logger_connector.result import Result
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/training_loop.py
@@ -23,8 +23,8 @@
from torch.optim import Optimizer

from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result import Result
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType
from pytorch_lightning.utilities.distributed import rank_zero_info
@@ -742,9 +742,9 @@ def training_step_and_backward_closure(
return_result: AttributeDict,
) -> Optional[torch.Tensor]:

step_result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens)
if step_result is not None:
return_result.update(step_result)
result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens)
if result is not None:
return_result.update(result)
return return_result.loss

def make_closure(self, *closure_args, **closure_kwargs: Any) -> Callable:
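Besides the import move, the second hunk above only renames a local variable inside the training-step closure (`step_result` → `result`); the surrounding mechanics are the standard PyTorch closure pattern, where the optimizer is handed a callable that recomputes the loss and runs backward. A rough, Lightning-independent sketch of that general pattern:

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 1)


def closure() -> torch.Tensor:
    # Re-run the forward pass, backpropagate, and return the loss so the
    # optimizer (e.g. LBFGS) can re-evaluate it as often as it needs.
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    return loss


optimizer.step(closure)
```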
2 changes: 1 addition & 1 deletion tests/core/test_metric_result_integration.py
@@ -18,7 +18,7 @@
from torchmetrics import Metric

import tests.helpers.utils as tutils
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer.connectors.logger_connector.result import Result
from tests.helpers.runif import RunIf


2 changes: 1 addition & 1 deletion tests/core/test_results.py
@@ -22,7 +22,7 @@

import tests.helpers.utils as tutils
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer.connectors.logger_connector.result import Result
from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf

2 changes: 1 addition & 1 deletion tests/models/test_tpu.py
@@ -24,8 +24,8 @@
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.plugins import TPUSpawnPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result import Result
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2 changes: 1 addition & 1 deletion tests/trainer/logging_/test_logger_connector.py
@@ -26,10 +26,10 @@

from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator
from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
from pytorch_lightning.trainer.connectors.logger_connector.result import Result
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
