diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py
index a06dd3c0b5dd6..defb751cee97e 100644
--- a/pytorch_lightning/trainer/distrib_data_parallel.py
+++ b/pytorch_lightning/trainer/distrib_data_parallel.py
@@ -350,8 +350,9 @@ def save_spawn_weights(self, model):
         :param model:
         :return:
         """
-        path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
-        self.save_checkpoint(path)
+        if self.proc_rank == 0:
+            path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
+            self.save_checkpoint(path)
 
     def load_spawn_weights(self, original_model):
         """
@@ -370,6 +371,8 @@ def load_spawn_weights(self, original_model):
         # remove ddp weights
         os.remove(path)
 
+        return loaded_model
+
     def resolve_root_node_address(self, root_node):
         if '[' in root_node:
             name = root_node.split('[')[0]
diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py
index 286e71efdf050..991b1796fc223 100644
--- a/pytorch_lightning/trainer/distrib_parts.py
+++ b/pytorch_lightning/trainer/distrib_parts.py
@@ -337,6 +337,7 @@
 from abc import ABC, abstractmethod
 import logging as log
 import os
+import signal
 
 import torch
 
@@ -494,6 +495,8 @@ def tpu_train(self, tpu_core_idx, model):
         m = f'INIT TPU local core: {self.tpu_local_core_rank}, ' \
             f'global rank: {self.tpu_global_core_rank}'
         log.info(m)
+
+        # continue training routine
         self.run_pretrain_routine(model)
 
         self.save_spawn_weights(model)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index af6d36d140146..5647d24a03404 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -959,7 +959,14 @@ def fit(
                 self.ddp_train(task, model)
             else:
                 self.__set_random_port()
+
+                # track for predict
+                self.model = model
+
+                # train
                 mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,))
+
+                # load weights if not interrupted
                 self.load_spawn_weights(model)
                 self.model = model
 
@@ -976,7 +983,14 @@ def fit(
 
             # COLAB_GPU is an env var available by default in Colab environments.
             start_method = 'fork' if os.getenv('COLAB_GPU') else 'spawn'
+
+            # track for predict
+            self.model = model
+
+            # train
             xmp.spawn(self.tpu_train, args=(model,), nprocs=self.num_tpu_cores, start_method=start_method)
+
+            # load weights if not interrupted
             self.load_spawn_weights(model)
             self.model = model
 
@@ -1192,12 +1206,19 @@ def test(self, model: Optional[LightningModule] = None):
             trainer = Trainer()
             trainer.test(model)
         """
+        self.testing = True
         if model is not None:
             self.model = model
             self.fit(model)
-        elif self.model is not None and (self.use_ddp or self.use_tpu):
-            self.fit(self.model)
+        elif self.use_ddp or self.use_tpu:
+            # attempt to load weights from a spawn
+            path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
+            test_model = self.model
+            if os.path.exists(path):
+                test_model = self.load_spawn_weights(self.model)
+
+            self.fit(test_model)
         else:
             self.run_evaluation(test_mode=True)
 
@@ -1217,21 +1238,6 @@ def __call__(self) -> Union[List[DataLoader], DataLoader]:
         return self.dataloader
 
 
-class _PatchDataLoader(object):
-    r'''
-    Callable object for patching dataloaders passed into trainer.fit().
-    Use this class to override model.*_dataloader() and be pickle-compatible.
-
-    Args:
-        dataloader: Dataloader object to return when called.
-    '''
-    def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
-        self.dataloader = dataloader
-
-    def __call__(self) -> Union[List[DataLoader], DataLoader]:
-        return self.dataloader
-
-
 def _set_dataloader(model, dataloader, attribute):
     r'''
     Check dataloaders passed to .fit() method if they are pytorch DataLoader
diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index b4592923495e4..f62b840dac96d 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -6,6 +6,7 @@
 from abc import ABC
 from subprocess import call
 from typing import Union
+from copy import deepcopy
 
 import torch
 import torch.distributed as dist
@@ -233,7 +234,9 @@ def dump_checkpoint(self):
 
         # add the hparams and state_dict from the model
         model = self.get_model()
+
         checkpoint['state_dict'] = model.state_dict()
+
         if hasattr(model, "hparams"):
             checkpoint['hparams'] = vars(model.hparams)
         else: