diff --git a/pytorch_lightning/accelerators/cpu_backend.py b/pytorch_lightning/accelerators/cpu_backend.py index 88c7329e6c500..abdc962a01fce 100644 --- a/pytorch_lightning/accelerators/cpu_backend.py +++ b/pytorch_lightning/accelerators/cpu_backend.py @@ -36,8 +36,10 @@ def setup(self, model): self.trainer.optimizers = optimizers self.trainer.lr_schedulers = lr_schedulers self.trainer.optimizer_frequencies = optimizer_frequencies + self.trainer.model = model - def train(self, model): + def train(self): + model = self.trainer.model results = self.trainer.run_pretrain_routine(model) return results diff --git a/pytorch_lightning/accelerators/ddp2_backend.py b/pytorch_lightning/accelerators/ddp2_backend.py index 45af3b32087ac..1e8c8bac614d3 100644 --- a/pytorch_lightning/accelerators/ddp2_backend.py +++ b/pytorch_lightning/accelerators/ddp2_backend.py @@ -46,6 +46,8 @@ def __init__(self, trainer): def setup(self, model): self._resolve_task_idx() + self.trainer.model = model + def _resolve_task_idx(self): if self.trainer.is_slurm_managing_tasks: self.task_idx = int(os.environ['SLURM_LOCALID']) @@ -57,7 +59,8 @@ def _resolve_task_idx(self): m = 'ddp2 only works in SLURM or via torchelastic with the WORLD_SIZE, LOCAL_RANK, GROUP_RANK flags' raise MisconfigurationException(m) - def train(self, model): + def train(self): + model = self.trainer.model self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model) def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0): diff --git a/pytorch_lightning/accelerators/ddp_backend.py b/pytorch_lightning/accelerators/ddp_backend.py index 4f0de214066d5..70bf888fa500c 100644 --- a/pytorch_lightning/accelerators/ddp_backend.py +++ b/pytorch_lightning/accelerators/ddp_backend.py @@ -57,13 +57,16 @@ def setup(self, model): elif self.mode == 'torchelastic_ddp': self.__torchelastic_setup() + self.trainer.model = model + def __slurm_setup(self): self.task_idx = int(os.environ['SLURM_LOCALID']) def __torchelastic_setup(self): self.task_idx = int(os.environ['LOCAL_RANK']) - def train(self, model): + def train(self): + model = self.trainer.model self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model) def spawn_ddp_children(self, model): diff --git a/pytorch_lightning/accelerators/ddp_spawn_backend.py b/pytorch_lightning/accelerators/ddp_spawn_backend.py index 7e20b668fcf8f..684a9ebe0b89e 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_backend.py +++ b/pytorch_lightning/accelerators/ddp_spawn_backend.py @@ -41,7 +41,11 @@ def setup(self, model): smp = mp.get_context('spawn') self.mp_queue = smp.SimpleQueue() - def train(self, model): + self.trainer.model = model + + def train(self): + model = self.trainer.model + # train in children process mp.spawn(self.ddp_train, nprocs=self.nprocs, args=(self.mp_queue, model,)) diff --git a/pytorch_lightning/accelerators/gpu_backend.py b/pytorch_lightning/accelerators/gpu_backend.py index 13a100f9d2140..68dbe4384d0ff 100644 --- a/pytorch_lightning/accelerators/gpu_backend.py +++ b/pytorch_lightning/accelerators/gpu_backend.py @@ -46,9 +46,11 @@ def setup(self, model): if self.trainer.amp_backend == AMPType.APEX: model = self._setup_nvidia_apex(model) - return model - def train(self, model): + self.trainer.model = model + + def train(self): + model = self.trainer.model results = self.trainer.run_pretrain_routine(model) return results diff --git a/pytorch_lightning/accelerators/tpu_backend.py b/pytorch_lightning/accelerators/tpu_backend.py index 672c547d07f1c..d260136a8cc26 100644 --- a/pytorch_lightning/accelerators/tpu_backend.py +++ b/pytorch_lightning/accelerators/tpu_backend.py @@ -54,6 +54,8 @@ def setup(self, model): smp = mp.get_context(self.start_method) self.mp_queue = smp.SimpleQueue() + self.trainer.model = model + def teardown(self, model): # restore main state with best weights best_path = self.mp_queue.get() @@ -75,8 +77,8 @@ def teardown(self, model): self.__load_weights_on_main_process() return results - def train(self, model: LightningModule): - self.trainer.model = model + def train(self): + model = self.trainer.model # train if self.trainer.tpu_id is not None: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f305a712aa234..60ae60f4c5240 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1040,26 +1040,26 @@ def fit( if self.use_ddp2: self.accelerator_backend = DDP2Backend(self) self.accelerator_backend.setup(model) - results = self.accelerator_backend.train(model) + results = self.accelerator_backend.train() self.accelerator_backend.teardown() elif use_slurm_ddp: self.accelerator_backend = DDPBackend(self, mode='slurm_ddp') self.accelerator_backend.setup(model) - results = self.accelerator_backend.train(model) + results = self.accelerator_backend.train() self.accelerator_backend.teardown() elif use_torchelastic_ddp: self.accelerator_backend = DDPBackend(self, mode='torchelastic_ddp') self.accelerator_backend.setup(model) - results = self.accelerator_backend.train(model) + results = self.accelerator_backend.train() self.accelerator_backend.teardown() # regular ddp using .spawn elif use_ddp_spawn: self.accelerator_backend = DDPSpawnBackend(self, nprocs=self.num_processes) self.accelerator_backend.setup(model) - results = self.accelerator_backend.train(model) + results = self.accelerator_backend.train() self.accelerator_backend.teardown() # ddp @@ -1082,20 +1082,20 @@ def fit( elif self.use_single_gpu: self.accelerator_backend = GPUBackend(self) - model = self.accelerator_backend.setup(model) - results = self.accelerator_backend.train(model) + self.accelerator_backend.setup(model) + results = self.accelerator_backend.train() self.accelerator_backend.teardown() elif self.use_tpu: self.accelerator_backend = TPUBackend(self) self.accelerator_backend.setup(model) - self.accelerator_backend.train(model) + self.accelerator_backend.train() self.accelerator_backend.teardown(model) else: self.accelerator_backend = CPUBackend(self) self.accelerator_backend.setup(model) - results = self.accelerator_backend.train(model) + results = self.accelerator_backend.train() self.accelerator_backend.teardown() # hook