Commit
ref: inner train loop (intermediate step) 12/n (#3372)
williamFalcon authored Sep 6, 2020
1 parent d091faf commit 9939f53
Showing 5 changed files with 84 additions and 171 deletions.
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/base_backend.py
@@ -149,3 +149,6 @@ def _clip_gradients(self, optimizer):
        clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
        for p in parameters:
            p.grad.data.mul_(clip_coef.to(p.grad.data.device))

    def on_train_epoch_end(self):
        pass
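
The no-op default above turns on_train_epoch_end into an accelerator-level extension point: a backend only overrides it when it has end-of-epoch work to flush, as the Horovod backend does in the next file. A minimal sketch of the pattern for a hypothetical custom backend (the class name MyClusterBackend and its helper are illustrative only, and it assumes the base class in base_backend.py is named Accelerator):

# illustrative sketch -- not part of this commit
from pytorch_lightning.accelerators.base_backend import Accelerator

class MyClusterBackend(Accelerator):
    def on_train_epoch_end(self):
        # flush any collective work the cluster library buffered during the epoch
        self._flush_pending_collectives()  # hypothetical helper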
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/horovod_backend.py
@@ -162,3 +162,6 @@ def test_step(self, args):
    def backward(self, closure_loss, optimizer, opt_idx):
        super().backward(closure_loss, optimizer, opt_idx)
        optimizer.synchronize()

    def on_train_epoch_end(self):
        hvd.join(hvd.local_rank() if self.trainer.on_gpu else -1)
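
Here hvd.join() makes the worker block until every rank has finished the epoch, passing the local GPU index on GPU runs and -1 otherwise. The call-site for the new hook sits in the training loop, which is not shown in this view, so the following dispatch sketch is an assumption rather than a quote from the diff:

# assumed call-site sketch -- the training-loop change itself is not expanded here
def on_train_epoch_end(self):
    # let the active accelerator backend synchronize at the end of the epoch
    # (hvd.join() for Horovod; the base-class default is a no-op)
    if self.trainer.accelerator_backend is not None:
        self.trainer.accelerator_backend.on_train_epoch_end()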
66 changes: 66 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -56,6 +56,7 @@
from pytorch_lightning.trainer.data_connector import DataConnector
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.training_loop_temp import TrainLoop
from pytorch_lightning import _logger as log

from pytorch_lightning.utilities.model_utils import is_overridden

@@ -1135,6 +1136,71 @@ def setup_training(self, model: LightningModule):
        if self.is_function_implemented('on_pretrain_routine_end'):
            ref_model.on_pretrain_routine_end()

    def train(self):
        self.run_sanity_check(self.get_model())

        # enable train mode
        model = self.get_model()
        model.train()
        torch.set_grad_enabled(True)

        # reload data when needed
        self.train_loop.reset_train_val_dataloaders(model)

        # hook
        self.train_loop.on_train_start()

        try:
            # run all epochs
            for epoch in range(self.current_epoch, self.max_epochs):

                # reset train dataloader
                if self.reload_dataloaders_every_epoch:
                    self.reset_train_dataloader(model)

                # hook
                self.train_loop.on_train_epoch_start(epoch)

                # run train epoch
                self.run_training_epoch()

                if self.max_steps and self.max_steps <= self.global_step:

                    # hook
                    self.train_loop.on_train_end()
                    return

                # update LR schedulers
                self.update_learning_rates(interval='epoch')

                # early stopping
                met_min_epochs = epoch >= self.min_epochs - 1
                met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

                if self.should_stop:
                    if (met_min_epochs and met_min_steps):
                        self.train_loop.on_train_end()
                        return
                    else:
                        log.info('Trainer was signaled to stop but required minimum epochs'
                                 f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                                 ' not been met. Training will continue...')

            # hook
            self.train_loop.on_train_end()

        except KeyboardInterrupt:
            rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')

            # user could press ctrl+c many times... only shutdown once
            if not self.interrupted:
                self.interrupted = True
                self._state = TrainerState.INTERRUPTED
                self.on_keyboard_interrupt()

            # hook
            self.train_loop.on_train_end()

    def run_test(self):
        # only load test dataloader for testing
        # self.reset_test_dataloader(ref_model)
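
The train() method added above is reached through Trainer.fit(): it runs the validation sanity check, switches the model to train mode, iterates epochs from current_epoch up to max_epochs, returns as soon as max_steps is hit, and honors a stop request (should_stop) only once both min_epochs and min_steps are satisfied. A minimal usage sketch, assuming LitModel is a LightningModule defined elsewhere:

# usage sketch -- LitModel is a placeholder model, not part of this commit
from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=10, min_epochs=3, min_steps=500, max_steps=5000)
trainer.fit(LitModel())
# a callback may set trainer.should_stop, but the loop above keeps training
# until at least min_epochs epochs and min_steps global steps have completed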
[The remaining changed files in this commit are not expanded in this view.]
