Move logger and profiler finalization to trainer's teardown #8685

Merged — 8 commits, Aug 5, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -102,6 +102,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
[#8627](https://github.com/PyTorchLightning/pytorch-lightning/pull/8627))


- Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8333](https://github.com/PyTorchLightning/pytorch-lightning/issues/8333))


## [1.4.0] - 2021-07-27

### Added
5 changes: 0 additions & 5 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -20,7 +20,6 @@
from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT

@@ -207,10 +206,6 @@ def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
else:
self.trainer.call_hook("on_validation_end", *args, **kwargs)

if self.trainer.state.fn != TrainerFn.FITTING:
# summarize profile results
self.trainer.profiler.describe()

# reset any `torchmetrics.Metric` and the logger connector state
self.trainer.logger_connector.reset(metrics=True)

2 changes: 0 additions & 2 deletions pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -119,8 +119,6 @@ def on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
Returns:
the results for all dataloaders
"""
self.trainer.profiler.describe()

results = self.predictions

self.trainer.call_hook("on_predict_epoch_end", results)
9 changes: 0 additions & 9 deletions pytorch_lightning/loops/fit_loop.py
@@ -225,15 +225,6 @@ def on_run_end(self) -> None:
# hook
self.trainer.call_hook("on_train_end")

# todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
# It might be related to xla tensors blocked when moving the cpu
# kill loggers
if self.trainer.logger is not None:
self.trainer.logger.finalize("success")

# summarize profile results
self.trainer.profiler.describe()

# give accelerators a chance to finish
self.trainer.accelerator.on_train_end()

3 changes: 3 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -203,6 +203,9 @@ def new_process(self, process_idx, trainer, mp_queue):
# persist info in ddp_spawn
self.transfer_distrib_spawn_state_on_fit_end(results)

# ensure that spawned processes go through teardown before joining
trainer._call_teardown_hook(self.lightning_module)

def post_dispatch(self):
# restore main state with best weights
best_path = self.mp_queue.get()
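Why the teardown call has to happen inside `new_process`: with spawn-based plugins each worker is a child process, and once a child returns and joins, its in-process state (open logger handles, profiler records) is gone — the parent only receives whatever was explicitly pushed through `mp_queue`. The same ordering applies to the `tpu_spawn` plugin below. A minimal, standalone sketch of that constraint (illustrative only; `_worker` and the prints are not from the PR):

import torch.multiprocessing as mp

def _worker(rank, queue):
    # ... training work would run here ...
    queue.put(f"rank {rank} finished")  # persist results for the parent
    # teardown must happen here, inside the child, before returning:
    # after the process exits and joins, its logger/profiler state is lost
    print(f"rank {rank}: running teardown before join")

if __name__ == "__main__":
    # use the "spawn" context so the queue is compatible with mp.spawn workers
    queue = mp.get_context("spawn").SimpleQueue()
    mp.spawn(_worker, args=(queue,), nprocs=2, join=True)
    print(queue.get())  # the parent only sees what went through the queue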
3 changes: 3 additions & 0 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -173,6 +173,9 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
if self.local_rank == 0:
time.sleep(2)

# ensure that spawned processes go through teardown before joining
trainer._call_teardown_hook(self.lightning_module)

@parameter_validation
def model_to_device(self) -> None:
self.model = self.wrapped_model.to(self.root_device)
17 changes: 14 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -77,6 +77,7 @@
)
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.enums import DistributedType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -942,8 +943,10 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
if self.state.fn == TrainerFn.FITTING:
self.call_hook("on_fit_end")

# teardown
self._call_teardown_hook(model)
# teardown if necessary (similar calls for spawn plugins are excluded as they have
# been included at the end of `new_process` functions)
if self._distrib_type not in DistributedType.interactive_compatible_types():
self._call_teardown_hook(model)

if self.state.status != TrainerStatus.INTERRUPTED:
self.state.status = TrainerStatus.FINISHED
@@ -1208,7 +1211,7 @@ def _call_teardown_hook(self, model: "pl.LightningModule") -> None:

if self.datamodule is not None:
self.datamodule.teardown(stage=fn)
self.profiler.teardown(stage=fn)

self.teardown(stage=fn)
model.teardown(stage=fn)

@@ -1217,6 +1220,14 @@ def _call_teardown_hook(self, model: "pl.LightningModule") -> None:
# these could have become stale if metrics are defined in `setup`
model._metric_attributes = None

# todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
# It might be related to xla tensors blocked when moving the cpu
# kill loggers
if self.logger is not None:
self.logger.finalize("success")

# summarize profile results
self.profiler.describe()

def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
if self.lightning_module:
prev_fx_name = self.lightning_module._current_fx_name
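Taken together, the trainer-side hunks above consolidate finalization into one code path. A condensed sketch of the resulting flow, assembled from the diff as an approximation (free functions instead of methods; not a verbatim copy of trainer.py):

from pytorch_lightning.utilities.enums import DistributedType

def _run(trainer, model):
    ...
    # spawn plugins already ran `_call_teardown_hook` at the end of
    # `new_process`, so only non-spawn runs tear down here
    if trainer._distrib_type not in DistributedType.interactive_compatible_types():
        trainer._call_teardown_hook(model)

def _call_teardown_hook(trainer, model, fn):
    if trainer.datamodule is not None:
        trainer.datamodule.teardown(stage=fn)
    trainer.teardown(stage=fn)
    model.teardown(stage=fn)
    model._metric_attributes = None  # could be stale if defined in `setup`
    # logger and profiler finalization now run exactly once per entry point
    # (fit / validate / test / predict), no matter which loop executed
    if trainer.logger is not None:
        trainer.logger.finalize("success")
    trainer.profiler.describe()

The net effect is that the per-loop calls removed from `evaluation_loop.py`, `prediction_loop.py`, and `fit_loop.py` can no longer be skipped or double-executed: every entry point funnels through the same teardown hook.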
68 changes: 68 additions & 0 deletions tests/trainer/logging_/test_distributed_logging.py
@@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict, Optional, Union
from unittest import mock
from unittest.mock import Mock

import pytorch_lightning as pl
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers.base import LightningLoggerBase
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf

@@ -101,3 +104,68 @@ def on_train_start(self, trainer, pl_module):
callbacks=[LoggerCallsObserver()],
)
trainer.fit(model)


def test_logger_after_fit_predict_test_calls(tmpdir):
"""
Make sure logger outputs are finalized after fit, prediction, and test calls.
"""

class BufferLogger(LightningLoggerBase):
def __init__(self):
super().__init__()
self.buffer = {}
self.logs = {}

def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
self.buffer.update(metrics)

def finalize(self, status: str) -> None:
self.logs.update(self.buffer)
self.buffer = {}

@property
def experiment(self) -> Any:
return None

@property
def version(self) -> Union[int, str]:
return 1

@property
def name(self) -> str:
return "BufferLogger"

def log_hyperparams(self, *args, **kwargs) -> None:
return None

class LoggerCallsObserver(Callback):
def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
trainer.logger.log_metrics({"fit": 1})

def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
trainer.logger.log_metrics({"validate": 1})

def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
trainer.logger.log_metrics({"predict": 1})

def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
trainer.logger.log_metrics({"test": 1})

model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=1,
limit_val_batches=1,
max_epochs=1,
logger=BufferLogger(),
callbacks=[LoggerCallsObserver()],
)

assert not trainer.logger.logs
trainer.fit(model)
assert trainer.logger.logs == {"fit": 1, "validate": 1}
trainer.test(model)
assert trainer.logger.logs == {"fit": 1, "validate": 1, "test": 1}
trainer.predict(model)
assert trainer.logger.logs == {"fit": 1, "validate": 1, "test": 1, "predict": 1}
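As a user-facing consequence, loggers that defer their writes to `finalize` (like the `BufferLogger` above) now behave consistently across `fit`, `test`, and `predict`. A minimal, hypothetical example of such a logger — `JsonFileLogger` is an illustrative name, not part of Lightning:

import json
from typing import Any, Dict, Optional, Union

from pytorch_lightning.loggers.base import LightningLoggerBase

class JsonFileLogger(LightningLoggerBase):
    """Buffers metrics in memory and writes them to disk only in `finalize`."""

    def __init__(self, path: str):
        super().__init__()
        self.path = path
        self.metrics: Dict[str, float] = {}

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        self.metrics.update(metrics)

    def finalize(self, status: str) -> None:
        # with this PR, teardown also calls `finalize` after predict runs,
        # so metrics logged during prediction are no longer dropped
        with open(self.path, "w") as f:
            json.dump({"status": status, **self.metrics}, f)

    def log_hyperparams(self, *args: Any, **kwargs: Any) -> None:
        pass

    @property
    def experiment(self) -> Any:
        return None

    @property
    def name(self) -> str:
        return "json_file"

    @property
    def version(self) -> Union[int, str]:
        return 0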