Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delete TensorBoardLogger experiment before spawning the processes. #10777

Merged
merged 16 commits into from
Nov 26, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `_compare_version` for python packages ([#10762](https://github.com/PyTorchLightning/pytorch-lightning/pull/10762))


- Fixed TensorBoardLogger `SummaryWriter` not close before spawning the processes ([#10777](https://github.com/PyTorchLightning/pytorch-lightning/pull/10777))


## [1.5.2] - 2021-11-16

### Fixed
Expand Down
15 changes: 15 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from torch.nn.parallel.distributed import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
Expand Down Expand Up @@ -147,14 +148,17 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st
return {"nprocs": self.num_processes}

def start_training(self, trainer: "pl.Trainer") -> None:
self._clean_logger(trainer)
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)
# reset optimizers, since main process is never used for training and thus does not have a valid optim state
trainer.optimizers = []

def start_evaluating(self, trainer: "pl.Trainer") -> None:
self._clean_logger(trainer)
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)

def start_predicting(self, trainer: "pl.Trainer") -> None:
self._clean_logger(trainer)
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)

def spawn(self, function: Callable, *args: Any, return_result: bool = True, **kwargs: Any) -> Optional[Any]:
Expand Down Expand Up @@ -415,3 +419,14 @@ def teardown(self) -> None:
self.lightning_module.cpu()
# clean up memory
torch.cuda.empty_cache()

@staticmethod
def _clean_logger(trainer: "pl.Trainer") -> None:
loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger]
for logger in loggers:
if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
rank_zero_warn(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tchaton Is it necessary to have this warning? This info won't be necessary for majority of users.

And this warning pops up every time a user trains with ddp_spawn, ddp_spawn_sharded and tpu_spawn as TensorBoard logger is the default.

"When using `ddp_spawn`, the `TensorBoardLogger` experiment should be `None`. Setting it to `None`."
)
del logger._experiment
logger._experiment = None
tchaton marked this conversation as resolved.
Show resolved Hide resolved
14 changes: 14 additions & 0 deletions tests/loggers/test_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.loggers.base import LoggerCollection
from pytorch_lightning.utilities.imports import _compare_version
from tests.helpers import BoringModel

Expand Down Expand Up @@ -332,3 +333,16 @@ def test_tensorboard_missing_folder_warning(tmpdir, caplog):
assert logger.version == 0

assert "Missing logger folder:" in caplog.text


@pytest.mark.parametrize("use_list", [False, True])
def test_tensorboard_ddp_spawn_cleanup(use_list, tmpdir):
tensorboard_logger = TensorBoardLogger(save_dir=tmpdir)
tensorboard_logger.experiment
assert tensorboard_logger._experiment
logger = [tensorboard_logger] if use_list else tensorboard_logger
trainer = Trainer(strategy="ddp_spawn", devices=2, accelerator="auto", logger=logger)
trainer.training_type_plugin._clean_logger(trainer)
if use_list:
assert isinstance(trainer.logger, LoggerCollection)
assert tensorboard_logger._experiment is None