Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delete TensorBoardLogger experiment before spawning the processes. #10777

Merged
merged 16 commits into from
Nov 26, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `_compare_version` for python packages ([#10762](https://github.com/PyTorchLightning/pytorch-lightning/pull/10762))


- Fixed TensorBoardLogger `SummaryWriter` not close before spawning the processes ([#10777](https://github.com/PyTorchLightning/pytorch-lightning/pull/10777))


## [1.5.2] - 2021-11-16

### Fixed
Expand Down
12 changes: 12 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from torch.nn.parallel.distributed import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
Expand Down Expand Up @@ -147,14 +148,17 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st
return {"nprocs": self.num_processes}

def start_training(self, trainer: "pl.Trainer") -> None:
self._clean_logger(trainer)
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)
# reset optimizers, since main process is never used for training and thus does not have a valid optim state
trainer.optimizers = []

def start_evaluating(self, trainer: "pl.Trainer") -> None:
self._clean_logger(trainer)
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)

def start_predicting(self, trainer: "pl.Trainer") -> None:
self._clean_logger(trainer)
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)

def spawn(self, function: Callable, *args: Any, return_result: bool = True, **kwargs: Any) -> Optional[Any]:
Expand Down Expand Up @@ -415,3 +419,11 @@ def teardown(self) -> None:
self.lightning_module.cpu()
# clean up memory
torch.cuda.empty_cache()

@staticmethod
def _clean_logger(trainer: "pl.Trainer") -> None:
loggers = trainer._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger]
for logger in loggers:
if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
rank_zero_warn("When using `ddp_spawn`, the Tensorboard experiment should be `None`.")
tchaton marked this conversation as resolved.
Show resolved Hide resolved
logger._experiment = None
tchaton marked this conversation as resolved.
Show resolved Hide resolved