diff --git a/pytorch_lightning/accelerators/cpu_backend.py b/pytorch_lightning/accelerators/cpu_backend.py
index e1ed7275ceda6..4b830408bea44 100644
--- a/pytorch_lightning/accelerators/cpu_backend.py
+++ b/pytorch_lightning/accelerators/cpu_backend.py
@@ -40,7 +40,12 @@ def setup(self, model):
 
     def train(self):
         model = self.trainer.model
-        results = self.trainer.setup_training(model)
+
+        # set up training routine
+        self.trainer.setup_training(model)
+
+        # train or test
+        results = self.trainer.train_or_test()
         return results
 
     def training_step(self, args):
diff --git a/pytorch_lightning/accelerators/ddp2_backend.py b/pytorch_lightning/accelerators/ddp2_backend.py
index db63773df18da..1c40bcbda0e67 100644
--- a/pytorch_lightning/accelerators/ddp2_backend.py
+++ b/pytorch_lightning/accelerators/ddp2_backend.py
@@ -150,8 +150,11 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         # allow user to configure ddp
         model = model.configure_ddp(model, device_ids)
 
-        # continue training routine
-        results = self.trainer.setup_training(model)
+        # set up training routine
+        self.trainer.setup_training(model)
+
+        # train or test
+        results = self.trainer.train_or_test()
 
         # get original model
         model = self.trainer.get_model()
diff --git a/pytorch_lightning/accelerators/ddp_backend.py b/pytorch_lightning/accelerators/ddp_backend.py
index a878ac9769c5a..c21449cc750be 100644
--- a/pytorch_lightning/accelerators/ddp_backend.py
+++ b/pytorch_lightning/accelerators/ddp_backend.py
@@ -234,8 +234,11 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         # allow user to configure ddp
         model = model.configure_ddp(model, device_ids)
 
-        # continue training routine
-        results = self.trainer.setup_training(model)
+        # set up training routine
+        self.trainer.setup_training(model)
+
+        # train or test
+        results = self.trainer.train_or_test()
 
         # get original model
         model = self.trainer.get_model()
diff --git a/pytorch_lightning/accelerators/ddp_spawn_backend.py b/pytorch_lightning/accelerators/ddp_spawn_backend.py
index d38b333e7acbe..46d88e11bd5ae 100644
--- a/pytorch_lightning/accelerators/ddp_spawn_backend.py
+++ b/pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -161,8 +161,11 @@ def ddp_train(self, process_idx, mp_queue, model):
         # allow user to configure ddp
         model = model.configure_ddp(model, device_ids)
 
-        # continue training routine
-        results = self.trainer.setup_training(model)
+        # set up training routine
+        self.trainer.setup_training(model)
+
+        # train or test
+        results = self.trainer.train_or_test()
 
         # get original model
         model = self.trainer.get_model()
diff --git a/pytorch_lightning/accelerators/dp_backend.py b/pytorch_lightning/accelerators/dp_backend.py
index 593b51674a940..7652cd9bcf53c 100644
--- a/pytorch_lightning/accelerators/dp_backend.py
+++ b/pytorch_lightning/accelerators/dp_backend.py
@@ -96,7 +96,12 @@ def __init_nvidia_apex(self, model):
 
     def train(self):
         model = self.trainer.model
-        results = self.trainer.setup_training(model)
+        # set up training routine
+        self.trainer.setup_training(model)
+
+        # train or test
+        results = self.trainer.train_or_test()
+
         return results
 
     def teardown(self):
diff --git a/pytorch_lightning/accelerators/gpu_backend.py b/pytorch_lightning/accelerators/gpu_backend.py
index 46c29c9d3c503..a84d35b4cd77c 100644
--- a/pytorch_lightning/accelerators/gpu_backend.py
+++ b/pytorch_lightning/accelerators/gpu_backend.py
@@ -51,7 +51,13 @@ def setup(self, model):
 
     def train(self):
         model = self.trainer.model
-        results = self.trainer.setup_training(model)
+
+        # set up training routine
+        self.trainer.setup_training(model)
+
+        # train or test
+        results = self.trainer.train_or_test()
+
         return results
 
     def training_step(self, args):
diff --git a/pytorch_lightning/accelerators/horovod_backend.py b/pytorch_lightning/accelerators/horovod_backend.py
index 066bf4871d4cc..8c7a9ecf912fb 100644
--- a/pytorch_lightning/accelerators/horovod_backend.py
+++ b/pytorch_lightning/accelerators/horovod_backend.py
@@ -105,11 +105,15 @@ def train(self):
                 # Synchronization will be performed explicitly following backward()
                 stack.enter_context(optimizer.skip_synchronize())
 
-            result = self.trainer.setup_training(self.trainer.model)
+            # set up training routine
+            self.trainer.setup_training(self.trainer.model)
+
+            # train or test
+            results = self.trainer.train_or_test()
 
         # Make sure all workers have finished training before returning to the user
         hvd.join()
-        return result
+        return results
 
     def teardown(self):
         pass
diff --git a/pytorch_lightning/accelerators/tpu_backend.py b/pytorch_lightning/accelerators/tpu_backend.py
index b7a4cef62e764..e67055ae3f1f6 100644
--- a/pytorch_lightning/accelerators/tpu_backend.py
+++ b/pytorch_lightning/accelerators/tpu_backend.py
@@ -114,8 +114,11 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine
         # setup TPU training
         self.__setup_tpu_training(model, trainer)
 
-        # Run the pretrain routine
-        results = trainer.setup_training(model)
+        # set up training routine
+        self.trainer.setup_training(model)
+
+        # train or test
+        results = self.trainer.train_or_test()
 
         # save weights at the end of training
         self.__save_end_of_training_weights(model, trainer)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index cc7ac431fa445..b4bbc43d96345 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1190,40 +1190,32 @@ def setup_training(self, model: LightningModule):
         if self.is_function_implemented('on_pretrain_routine_end'):
             ref_model.on_pretrain_routine_end()
 
-        # --------------------------
-        # if test
-        # --------------------------
-        # when testing requested only run test and return
-        if self.testing:
-            # only load test dataloader for testing
-            # self.reset_test_dataloader(ref_model)
-            eval_loop_results, _ = self.run_evaluation(test_mode=True)
+    def run_test(self):
+        # only load test dataloader for testing
+        # self.reset_test_dataloader(ref_model)
+        eval_loop_results, _ = self.run_evaluation(test_mode=True)
 
-            if len(eval_loop_results) == 0:
-                return 1
+        if len(eval_loop_results) == 0:
+            return 1
 
-            # remove the tensors from the eval results
-            for i, result in enumerate(eval_loop_results):
-                if isinstance(result, dict):
-                    for k, v in result.items():
-                        if isinstance(v, torch.Tensor):
-                            result[k] = v.cpu().item()
+        # remove the tensors from the eval results
+        for i, result in enumerate(eval_loop_results):
+            if isinstance(result, dict):
+                for k, v in result.items():
+                    if isinstance(v, torch.Tensor):
+                        result[k] = v.cpu().item()
 
-            return eval_loop_results
-
-        # --------------------------
-        # sanity
-        # --------------------------
-        # run a few val batches before training starts
-        self._run_sanity_check(ref_model, model)
+        return eval_loop_results
 
-        # --------------------------
-        # TRAIN
-        # --------------------------
-        self.train()
+    def train_or_test(self):
+        if self.testing:
+            results = self.run_test()
+        else:
+            results = self.train()
+        return results
 
-    def _run_sanity_check(self, ref_model, model):
-        using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', self.get_model())
+    def run_sanity_check(self, ref_model):
+        using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model)
         should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0
 
         # run tiny validation (if validation defined)
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 01455cc226535..76f91039708cf 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -330,7 +330,13 @@ def call_hook(self, hook_name, *args, **kwargs):
     def has_arg(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
+    @abstractmethod
+    def run_sanity_check(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
+
     def train(self):
+        self.run_sanity_check(self.get_model())
+
         # TODO: shrink
         # clear cache before training
         if self.on_gpu and self.root_gpu is not None:
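For orientation, the sketch below illustrates the control flow this diff converges on: every accelerator backend calls trainer.setup_training(model) and then trainer.train_or_test(), which dispatches to run_test() when trainer.testing is set and otherwise to train(), which now runs the sanity check itself. It is a minimal standalone illustration with placeholder stubs and invented print output, not Lightning's actual implementation.

# Illustrative sketch only; class/stub names are assumptions, not Lightning's real code.

class SketchTrainer:
    def __init__(self, testing=False):
        self.testing = testing

    def setup_training(self, model):
        # pretrain routine: hooks, logger, checkpoint restore, etc.
        print(f"setup_training({model})")

    def run_test(self):
        # run only the test evaluation loop and return its results
        return ["test results"]

    def run_sanity_check(self, model):
        # a few val batches before training starts
        print(f"sanity check on {model}")

    def train(self):
        # the fit loop now owns the sanity check
        self.run_sanity_check("ref_model")
        return ["fit results"]

    def train_or_test(self):
        # single dispatch point shared by every accelerator backend
        return self.run_test() if self.testing else self.train()


def accelerator_train(trainer, model):
    # what each backend's train()/ddp_train() does after this diff
    trainer.setup_training(model)
    return trainer.train_or_test()


print(accelerator_train(SketchTrainer(testing=False), "model"))
print(accelerator_train(SketchTrainer(testing=True), "model"))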