ref: run_pretrain_routine -> setup_training (#3294)
* ref: .tune()

* ref: run_pretrain_routine -> setup_training
williamFalcon authored Aug 31, 2020
1 parent 805ff37 commit bcd13f7
Showing 13 changed files with 39 additions and 40 deletions.
2 changes: 1 addition & 1 deletion docs/source/trainer.rst
@@ -9,7 +9,7 @@ Trainer
:members: fit, test
:noindex:
:exclude-members:
run_pretrain_routine,
setup_training,
_abc_impl,
set_random_port,
_Trainer__set_root_gpu,
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/cpu_backend.py
@@ -40,7 +40,7 @@ def setup(self, model):

def train(self):
model = self.trainer.model
results = self.trainer.run_pretrain_routine(model)
results = self.trainer.setup_training(model)
return results

def training_step(self, args):
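
Every accelerator backend in this commit changes in the same one-line way: train() still fetches the model from the trainer and hands it straight back, only the trainer entry point is renamed. A minimal, self-contained sketch of that delegation, using simplified stand-ins (MiniTrainer and MiniBackend are illustrative, not the real Lightning classes):

# Sketch of the backend -> trainer hand-off that this commit renames.
# MiniTrainer / MiniBackend are simplified stand-ins for illustration only.

class MiniTrainer:
    def __init__(self):
        self.model = None

    def setup_training(self, model):
        # formerly run_pretrain_routine: run hooks, restore weights,
        # sanity-check, then enter the training loop
        return {"trained": repr(model)}


class MiniBackend:
    def __init__(self, trainer):
        self.trainer = trainer

    def setup(self, model):
        self.trainer.model = model

    def train(self):
        # mirrors cpu_backend.py / gpu_backend.py / dp_backend.py in this diff
        model = self.trainer.model
        results = self.trainer.setup_training(model)
        return results


if __name__ == "__main__":
    backend = MiniBackend(MiniTrainer())
    backend.setup("toy-model")
    print(backend.train())  # {'trained': "'toy-model'"}

The distributed backends below (ddp, ddp2, ddp_spawn, horovod, tpu) make exactly the same substitution inside their own process-launch paths.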
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp2_backend.py
@@ -151,7 +151,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
model = model.configure_ddp(model, device_ids)

# continue training routine
results = self.trainer.run_pretrain_routine(model)
results = self.trainer.setup_training(model)

# get original model
model = self.trainer.get_model()
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp_backend.py
@@ -235,7 +235,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
model = model.configure_ddp(model, device_ids)

# continue training routine
results = self.trainer.run_pretrain_routine(model)
results = self.trainer.setup_training(model)

# get original model
model = self.trainer.get_model()
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -162,7 +162,7 @@ def ddp_train(self, process_idx, mp_queue, model):
model = model.configure_ddp(model, device_ids)

# continue training routine
results = self.trainer.run_pretrain_routine(model)
results = self.trainer.setup_training(model)

# get original model
model = self.trainer.get_model()
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/dp_backend.py
@@ -96,7 +96,7 @@ def __init_nvidia_apex(self, model):

def train(self):
model = self.trainer.model
results = self.trainer.run_pretrain_routine(model)
results = self.trainer.setup_training(model)
return results

def teardown(self):
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/gpu_backend.py
@@ -51,7 +51,7 @@ def setup(self, model):

def train(self):
model = self.trainer.model
results = self.trainer.run_pretrain_routine(model)
results = self.trainer.setup_training(model)
return results

def training_step(self, args):
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/horovod_backend.py
@@ -105,7 +105,7 @@ def train(self):
# Synchronization will be performed explicitly following backward()
stack.enter_context(optimizer.skip_synchronize())

result = self.trainer.run_pretrain_routine(self.trainer.model)
result = self.trainer.setup_training(self.trainer.model)

# Make sure all workers have finished training before returning to the user
hvd.join()
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/tpu_backend.py
@@ -115,7 +115,7 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine
self.__setup_tpu_training(model, trainer)

# Run the pretrain routine
results = trainer.run_pretrain_routine(model)
results = trainer.setup_training(model)

# save weights at the end of training
self.__save_end_of_training_weights(model, trainer)
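
The rendezvous tag that appears later in trainer.py is renamed in lockstep, since it is only a label: every TPU process must call rendezvous() with the same string to reach the same barrier. A sketch of that barrier, assuming torch_xla is installed and the code runs inside an XLA multiprocessing context:

# Sketch only: each TPU ordinal blocks at rendezvous() until the last one
# arrives; the tag simply names the barrier, so it follows the method rename.
import torch_xla.core.xla_model as xm

def wait_for_weight_restore():
    # called by every process after weights are restored
    xm.rendezvous("pl.Trainer.setup_training")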
4 changes: 0 additions & 4 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -204,10 +204,6 @@ def num_gpus(self) -> int:
def copy_trainer_model_properties(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""

@abstractmethod
def run_pretrain_routine(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""

@abstractmethod
def init_optimizers(self, *args) -> Tuple[List, List, List]:
"""Warning: this is just empty shell for code implemented in other class."""
4 changes: 0 additions & 4 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -86,10 +86,6 @@ class TrainerDPMixin(ABC):
def call_setup_hook(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""

@abstractmethod
def run_pretrain_routine(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""

@abstractmethod
def init_optimizers(self, *args) -> Tuple[List, List, List]:
"""Warning: this is just empty shell for code implemented in other class."""
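
The two hunks above drop the abstract "empty shell" declarations because the method they pointed at no longer exists under that name. The Trainer is assembled from several mixins, and each mixin declares the cross-mixin methods it calls as abstract stubs so that calling them from mixin code is well defined. A small sketch of the pattern, with invented class names:

# Sketch of the "empty shell" mixin pattern; class names are made up.
from abc import ABC, abstractmethod


class TrainLoopMixin(ABC):
    @abstractmethod
    def setup_training(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    def run(self, model):
        # the mixin can call the method even though it is implemented elsewhere
        return self.setup_training(model)


class Trainer(TrainLoopMixin):
    def setup_training(self, model):
        # the concrete implementation satisfies the abstract declaration
        return f"training {model}"


print(Trainer().run("net"))  # -> training net

When a method is renamed or removed, every mixin that declared such a stub has to be updated, which is exactly what the two deletions above do.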
37 changes: 22 additions & 15 deletions pytorch_lightning/trainer/trainer.py
@@ -1118,12 +1118,15 @@ def can_prepare_data(self):
else:
return self.node_rank == 0 and self.local_rank == 0 and should_call_dm_prepare_data

def run_pretrain_routine(self, model: LightningModule):
def setup_training(self, model: LightningModule):
"""Sanity check a few things before starting actual training.
Args:
model: The model to run sanity test on.
"""
# --------------------------
# Setup??
# --------------------------
ref_model = model
if self.data_parallel:
ref_model = model.module
@@ -1151,7 +1154,7 @@ def run_pretrain_routine(self, model: LightningModule):
# wait for all models to restore weights
if self.on_tpu and XLA_AVAILABLE:
# wait for all processes to catch up
torch_xla.core.xla_model.rendezvous("pl.Trainer.run_pretrain_routine")
torch_xla.core.xla_model.rendezvous("pl.Trainer.setup_training")

elif self.use_horovod:
# wait for all processes to catch up
@@ -1160,6 +1163,9 @@ def run_pretrain_routine(self, model: LightningModule):
# register auto-resubmit when on SLURM
self.register_slurm_signal_handlers()

# --------------------------
# Pre-train
# --------------------------
# on pretrain routine start
self.on_pretrain_routine_start(ref_model)
if self.is_function_implemented('on_pretrain_routine_start'):
@@ -1179,6 +1185,14 @@ def run_pretrain_routine(self, model: LightningModule):
# restore training and model before hpc is called
self.restore_weights(model)

# on pretrain routine end
self.on_pretrain_routine_end(ref_model)
if self.is_function_implemented('on_pretrain_routine_end'):
ref_model.on_pretrain_routine_end()

# --------------------------
# if test
# --------------------------
# when testing requested only run test and return
if self.testing:
# only load test dataloader for testing
@@ -1197,22 +1211,15 @@ def run_pretrain_routine(self, model: LightningModule):

return eval_loop_results

# --------------------------
# sanity
# --------------------------
# run a few val batches before training starts
self._run_sanity_check(ref_model, model)

# clear cache before training
if self.on_gpu and self.root_gpu is not None:
# use context because of:
# https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
with torch.cuda.device(f'cuda:{self.root_gpu}'):
torch.cuda.empty_cache()

# on pretrain routine end
self.on_pretrain_routine_end(ref_model)
if self.is_function_implemented('on_pretrain_routine_end'):
ref_model.on_pretrain_routine_end()

# CORE TRAINING LOOP
# --------------------------
# TRAIN
# --------------------------
self.train()

def _run_sanity_check(self, ref_model, model):
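
Beyond the rename, the body of setup_training is regrouped into labelled phases: setup, pre-train hooks around the weight restore, an early return for testing, the validation sanity check, and finally the hand-off to train(). A condensed sketch of that control flow, with stub methods standing in for the real work (only the ordering is taken from the diff; everything else is simplified):

# Condensed sketch of the setup_training phases, in the order shown above.
# SketchTrainer and its stub methods are illustrative only.

class SketchTrainer:
    def __init__(self, testing=False):
        self.testing = testing

    def setup_training(self, model):
        ref_model = model  # with DP/DDP the reference model is model.module

        # pre-train: hooks around the weight restore
        self.on_pretrain_routine_start(ref_model)
        self.restore_weights(model)
        self.on_pretrain_routine_end(ref_model)

        # if test: run evaluation only and return its results
        if self.testing:
            return self.run_test()

        # sanity: a few val batches before real training starts
        self.run_sanity_check(ref_model)

        # train: hand off to the core training loop
        return self.train()

    # --- stubs standing in for the real implementations ---
    def on_pretrain_routine_start(self, m): print("pretrain start")
    def restore_weights(self, m): print("restore weights")
    def on_pretrain_routine_end(self, m): print("pretrain end")
    def run_test(self): return "test results"
    def run_sanity_check(self, m): print("sanity check")
    def train(self): return "train results"


print(SketchTrainer().setup_training(object()))

Note that the GPU cache clearing that previously sat between the sanity check and train() moves into the training loop itself (next file).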
16 changes: 8 additions & 8 deletions pytorch_lightning/trainer/training_loop.py
@@ -213,6 +213,7 @@ class TrainerTrainLoopMixin(ABC):
max_epochs: int
min_epochs: int
on_gpu: bool
root_gpu: ...
use_ddp: bool
use_dp: bool
use_ddp2: bool
@@ -330,14 +331,13 @@ def has_arg(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""

def train(self):
# add signal handlers for process kills
# def _signal_kill_handler(*args):
# return TrainerTrainLoopMixin.run_training_teardown(self)
#
# orig_signal_handlers = {}
# for sig_name in SIGNAL_TERMINATE:
# orig_signal_handlers[sig_name] = signal.signal(getattr(signal, sig_name),
# _signal_kill_handler)
# TODO: shrink
# clear cache before training
if self.on_gpu and self.root_gpu is not None:
# use context because of:
# https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
with torch.cuda.device(f'cuda:{self.root_gpu}'):
torch.cuda.empty_cache()

# get model
model = self.get_model()
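
The cache clearing that moved here from setup_training keeps its original shape: empty_cache() acts on the current CUDA device, so the root GPU is selected explicitly first to avoid implicitly touching cuda:0 (the reason given in the linked forum thread). A standalone sketch of the idiom, with root_gpu as a stand-in for the trainer's self.root_gpu:

# Sketch of the clear-cache-before-training idiom; root_gpu stands in
# for the trainer's self.root_gpu.
import torch

root_gpu = 0  # index of the device designated as root

if torch.cuda.is_available() and root_gpu is not None:
    # select the root device so empty_cache() does not create a context on cuda:0
    with torch.cuda.device(f"cuda:{root_gpu}"):
        torch.cuda.empty_cache()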
