Commit

ref: move train outside of setup training (#3297)
* ref: move train outside of setup training
williamFalcon authored Sep 1, 2020
1 parent bcd13f7 commit b0298ce
Showing 10 changed files with 72 additions and 42 deletions.
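Across the accelerator backends below, the change is the same: the single `setup_training(model)` call, which previously both prepared the run and executed it, is split into an explicit setup step followed by a `train_or_test()` call. A minimal sketch of that calling pattern, using hypothetical stand-in classes rather than the real backends (which also handle device placement, DDP wrapping, etc.):

class StubTrainer:
    """Hypothetical stand-in for the Trainer used by the backends."""

    model = None

    def setup_training(self, model):
        print("setup only")      # previously this call also ran train/test

    def train_or_test(self):
        print("train or test")   # new single entry point for execution
        return []


class StubBackend:
    """Hypothetical stand-in for an accelerator backend."""

    def __init__(self, trainer):
        self.trainer = trainer

    def train(self):
        model = self.trainer.model

        # set up training routine
        self.trainer.setup_training(model)

        # train or test
        results = self.trainer.train_or_test()
        return results


results = StubBackend(StubTrainer()).train()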
7 changes: 6 additions & 1 deletion pytorch_lightning/accelerators/cpu_backend.py
@@ -40,7 +40,12 @@ def setup(self, model):

     def train(self):
         model = self.trainer.model
-        results = self.trainer.setup_training(model)
+
+        # set up training routine
+        self.trainer.setup_training(model)
+
+        # train or test
+        results = self.trainer.train_or_test()
         return results

     def training_step(self, args):
7 changes: 5 additions & 2 deletions pytorch_lightning/accelerators/ddp2_backend.py
@@ -150,8 +150,11 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         # allow user to configure ddp
         model = model.configure_ddp(model, device_ids)

-        # continue training routine
-        results = self.trainer.setup_training(model)
+        # set up training routine
+        self.trainer.setup_training(model)
+
+        # train or test
+        results = self.trainer.train_or_test()

         # get original model
         model = self.trainer.get_model()
7 changes: 5 additions & 2 deletions pytorch_lightning/accelerators/ddp_backend.py
@@ -234,8 +234,11 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         # allow user to configure ddp
         model = model.configure_ddp(model, device_ids)

-        # continue training routine
-        results = self.trainer.setup_training(model)
+        # set up training routine
+        self.trainer.setup_training(model)
+
+        # train or test
+        results = self.trainer.train_or_test()

         # get original model
         model = self.trainer.get_model()
7 changes: 5 additions & 2 deletions pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -161,8 +161,11 @@ def ddp_train(self, process_idx, mp_queue, model):
         # allow user to configure ddp
         model = model.configure_ddp(model, device_ids)

-        # continue training routine
-        results = self.trainer.setup_training(model)
+        # set up training routine
+        self.trainer.setup_training(model)
+
+        # train or test
+        results = self.trainer.train_or_test()

         # get original model
         model = self.trainer.get_model()
7 changes: 6 additions & 1 deletion pytorch_lightning/accelerators/dp_backend.py
@@ -96,7 +96,12 @@ def __init_nvidia_apex(self, model):

     def train(self):
         model = self.trainer.model
-        results = self.trainer.setup_training(model)
+        # set up training routine
+        self.trainer.setup_training(model)
+
+        # train or test
+        results = self.trainer.train_or_test()
+
         return results

     def teardown(self):
8 changes: 7 additions & 1 deletion pytorch_lightning/accelerators/gpu_backend.py
@@ -51,7 +51,13 @@ def setup(self, model):

     def train(self):
         model = self.trainer.model
-        results = self.trainer.setup_training(model)
+
+        # set up training routine
+        self.trainer.setup_training(model)
+
+        # train or test
+        results = self.trainer.train_or_test()
+
         return results

     def training_step(self, args):
8 changes: 6 additions & 2 deletions pytorch_lightning/accelerators/horovod_backend.py
@@ -105,11 +105,15 @@ def train(self):
                 # Synchronization will be performed explicitly following backward()
                 stack.enter_context(optimizer.skip_synchronize())

-            result = self.trainer.setup_training(self.trainer.model)
+            # set up training routine
+            self.trainer.setup_training(self.trainer.model)
+
+            # train or test
+            results = self.trainer.train_or_test()

         # Make sure all workers have finished training before returning to the user
         hvd.join()
-        return result
+        return results

     def teardown(self):
         pass
7 changes: 5 additions & 2 deletions pytorch_lightning/accelerators/tpu_backend.py
@@ -114,8 +114,11 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine
         # setup TPU training
         self.__setup_tpu_training(model, trainer)

-        # Run the pretrain routine
-        results = trainer.setup_training(model)
+        # set up training routine
+        self.trainer.setup_training(model)
+
+        # train or test
+        results = self.trainer.train_or_test()

         # save weights at the end of training
         self.__save_end_of_training_weights(model, trainer)
50 changes: 21 additions & 29 deletions pytorch_lightning/trainer/trainer.py
@@ -1190,40 +1190,32 @@ def setup_training(self, model: LightningModule):
         if self.is_function_implemented('on_pretrain_routine_end'):
             ref_model.on_pretrain_routine_end()

-        # --------------------------
-        # if test
-        # --------------------------
-        # when testing requested only run test and return
-        if self.testing:
-            # only load test dataloader for testing
-            # self.reset_test_dataloader(ref_model)
-            eval_loop_results, _ = self.run_evaluation(test_mode=True)
+    def run_test(self):
+        # only load test dataloader for testing
+        # self.reset_test_dataloader(ref_model)
+        eval_loop_results, _ = self.run_evaluation(test_mode=True)

-            if len(eval_loop_results) == 0:
-                return 1
+        if len(eval_loop_results) == 0:
+            return 1

-            # remove the tensors from the eval results
-            for i, result in enumerate(eval_loop_results):
-                if isinstance(result, dict):
-                    for k, v in result.items():
-                        if isinstance(v, torch.Tensor):
-                            result[k] = v.cpu().item()
+        # remove the tensors from the eval results
+        for i, result in enumerate(eval_loop_results):
+            if isinstance(result, dict):
+                for k, v in result.items():
+                    if isinstance(v, torch.Tensor):
+                        result[k] = v.cpu().item()

-            return eval_loop_results
-
-        # --------------------------
-        # sanity
-        # --------------------------
-        # run a few val batches before training starts
-        self._run_sanity_check(ref_model, model)
+        return eval_loop_results

-        # --------------------------
-        # TRAIN
-        # --------------------------
-        self.train()
+    def train_or_test(self):
+        if self.testing:
+            results = self.run_test()
+        else:
+            results = self.train()
+        return results

-    def _run_sanity_check(self, ref_model, model):
-        using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', self.get_model())
+    def run_sanity_check(self, ref_model):
+        using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model)
         should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0

         # run tiny validation (if validation defined)
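On the trainer side, the test branch that used to sit at the end of `setup_training` becomes `run_test`, and the new `train_or_test` method holds the dispatch. A condensed, hypothetical sketch of the resulting control flow (loop bodies replaced with placeholders; not the real Trainer):

class MiniTrainer:
    """Hypothetical, stripped-down illustration of the new dispatch."""

    def __init__(self, testing=False):
        self.testing = testing

    def setup_training(self, model):
        # loggers, checkpoints, hooks, dataloaders, ... -- preparation only
        pass

    def run_test(self):
        # run the evaluation loop in test mode and post-process its results
        return ["test results"]

    def train(self):
        # run the fit loop (which now starts with the sanity check)
        return ["fit results"]

    def train_or_test(self):
        # single place that decides between the test loop and the fit loop
        return self.run_test() if self.testing else self.train()


trainer = MiniTrainer(testing=True)
trainer.setup_training(model=None)
results = trainer.train_or_test()   # -> ["test results"]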
6 changes: 6 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
@@ -330,7 +330,13 @@ def call_hook(self, hook_name, *args, **kwargs):
     def has_arg(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""

+    @abstractmethod
+    def run_sanity_check(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
+
     def train(self):
+        self.run_sanity_check(self.get_model())
+
         # TODO: shrink
         # clear cache before training
         if self.on_gpu and self.root_gpu is not None:
