Commit 8eef97c
ref: inner train loop (intermediate step) 11/n (#3370)
* ref: inner train loop (intermediate step) 11/n

* ref: inner train loop (intermediate step) 11/n
williamFalcon authored Sep 6, 2020
1 parent 8542146 commit 8eef97c
Showing 2 changed files with 31 additions and 31 deletions.
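
Summary of the change: the per-epoch setup that train() used to do inline (distributed-sampler seeding, epoch bookkeeping, the gradient-accumulation scheduler, and the running loss accumulator) moves into TrainLoop.on_train_epoch_start(epoch) in training_loop_temp.py, and the hook call moves from run_training_epoch() up into train(). A rough sketch of the resulting control flow, using names from the diff with bodies elided; this is a simplified outline, not the full Lightning source:

# Simplified post-commit structure (sketch only).
class TrainLoop:
    def __init__(self, trainer):
        self.trainer = trainer

    def on_train_epoch_start(self, epoch):
        # sampler seeding, epoch bookkeeping, grad-accumulation state,
        # then the 'on_epoch_start' / 'on_train_epoch_start' hooks
        ...

class TrainerTrainLoopMixin:
    def train(self):
        for epoch in range(self.current_epoch, self.max_epochs):
            self.train_loop.on_train_epoch_start(epoch)  # setup + hooks
            self.run_training_epoch()                    # run batches for one epoch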
31 changes: 5 additions & 26 deletions pytorch_lightning/trainer/training_loop.py
@@ -176,10 +176,8 @@ def training_step(self, batch, batch_idx):
 from pytorch_lightning.core.step_result import EvalResult, Result
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator
 from pytorch_lightning.utilities import rank_zero_warn, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.parsing import AttributeDict
 from pytorch_lightning.utilities.model_utils import is_overridden
 from pytorch_lightning.trainer.training_loop_temp import TrainLoop
@@ -279,7 +277,6 @@ class TrainerTrainLoopMixin(ABC):
     on_epoch_end: Callable
     on_validation_end: Callable
     on_keyboard_interrupt: Callable
-    on_train_epoch_start: Callable
     on_train_epoch_end: Callable

     @abstractmethod
@@ -351,30 +348,15 @@ def train(self):
         try:
             # run all epochs
             for epoch in range(self.current_epoch, self.max_epochs):

                 # reset train dataloader
                 if self.reload_dataloaders_every_epoch:
                     self.reset_train_dataloader(model)
-                # set seed for distributed sampler (enables shuffling for each epoch)
-                if (self.use_ddp or self.use_horovod or self.on_tpu) \
-                        and hasattr(self.train_dataloader, 'sampler') \
-                        and hasattr(self.train_dataloader.sampler, 'set_epoch'):
-                    self.train_dataloader.sampler.set_epoch(epoch)
-
-                # update training progress in trainer and model
-                model.current_epoch = epoch
-                self.current_epoch = epoch
-
-                # changing gradient according accumulation_scheduler
-                self.accumulation_scheduler.on_epoch_start(self, self.get_model())
-
-                # stores accumulated grad fractions per batch
-                self.batch_loss_value = TensorRunningAccum(
-                    window_length=self.accumulate_grad_batches
-                )
-
-                # -----------------
-                # RUN TNG EPOCH
-                # -----------------
+                # hook
+                self.train_loop.on_train_epoch_start(epoch)
+
+                # run train epoch
                 self.run_training_epoch()

                 if self.max_steps and self.max_steps <= self.global_step:
@@ -419,9 +401,6 @@ def run_training_epoch(self):
         # get model
         model = self.get_model()

-        # hook
-        self.train_loop.on_train_epoch_start()
-
         # modify dataloader if needed (ddp, etc...)
         train_dataloader = self.accelerator_backend.process_dataloader(self.train_dataloader)
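Aside on the relocated sampler block: DistributedSampler derives its shuffle order from its internal epoch counter, so without set_epoch(epoch) every epoch replays the same ordering; that is the behavior the moved code preserves. A minimal standalone illustration in plain PyTorch (single process, so num_replicas/rank are pinned by hand instead of coming from an initialized process group):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100).float())
# num_replicas/rank given explicitly so no process group is needed here
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=8)

for epoch in range(3):
    sampler.set_epoch(epoch)  # re-seed the shuffle; skipping this repeats the order
    for (batch,) in loader:
        pass  # training step would go here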
31 changes: 26 additions & 5 deletions pytorch_lightning/trainer/training_loop_temp.py
@@ -3,7 +3,7 @@
 import torch
 import torch.distributed as torch_distrib
 from pytorch_lightning.utilities.model_utils import is_overridden
-from pytorch_lightning.trainer.supporters import Accumulator
+from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities.memory import recursive_detach
@@ -83,10 +83,27 @@ def check_checkpoint_callback(self, should_check_val):
         checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
         [c.on_validation_end(self.trainer, model) for c in checkpoint_callbacks]

-    def on_train_epoch_start(self):
-        # hook
-        self.trainer.call_hook('on_epoch_start')
-        self.trainer.call_hook('on_train_epoch_start')
+    def on_train_epoch_start(self, epoch):
+        model = self.trainer.get_model()
+
+        # set seed for distributed sampler (enables shuffling for each epoch)
+        # TODO: move to accelerators
+        if (self.trainer.use_ddp or self.trainer.use_horovod or self.trainer.on_tpu) \
+                and hasattr(self.trainer.train_dataloader, 'sampler') \
+                and hasattr(self.trainer.train_dataloader.sampler, 'set_epoch'):
+            self.trainer.train_dataloader.sampler.set_epoch(epoch)
+
+        # update training progress in trainer and model
+        model.current_epoch = epoch
+        self.trainer.current_epoch = epoch
+
+        # changing gradient according accumulation_scheduler
+        self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, self.trainer.get_model())
+
+        # stores accumulated grad fractions per batch
+        self.trainer.batch_loss_value = TensorRunningAccum(
+            window_length=self.trainer.accumulate_grad_batches
+        )

         # bookkeeping
         self.should_check_val = False
@@ -95,6 +112,10 @@ def on_train_epoch_start(self):
         self.early_stopping_accumulator = Accumulator()
         self.checkpoint_accumulator = Accumulator()

+        # hook
+        self.trainer.call_hook('on_epoch_start')
+        self.trainer.call_hook('on_train_epoch_start')
+
     def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
         # figure out what to track for epoch end
         self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)
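Note that the relocated method fires the 'on_epoch_start' and 'on_train_epoch_start' hooks only after the bookkeeping above, so user code keyed to those hooks sees current_epoch and the accumulators already updated. A hedged sketch of hooking in from user land, assuming the Callback API of this era already exposes on_train_epoch_start(trainer, pl_module); EpochStartLogger is a hypothetical name, not part of the commit:

from pytorch_lightning import Callback

class EpochStartLogger(Callback):
    # invoked when the trainer calls call_hook('on_train_epoch_start') above
    def on_train_epoch_start(self, trainer, pl_module):
        print(f"epoch {trainer.current_epoch} starting")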
