remove trainer hidden state | sanity refactor [1 / n] (#7437)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
awaelchli and pre-commit-ci[bot] authored May 11, 2021
1 parent 4a1134d commit ad9118f
Showing 16 changed files with 100 additions and 80 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))

+ - Refactored Loops
+   * Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7437))

- `DataModule`s now avoid duplicate `{setup,teardown,prepare_data}` calls for the same stage ([#7238](https://github.com/PyTorchLightning/pytorch-lightning/pull/7238))

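In outline, the refactor replaces writable loop-state attributes on `Trainer` with state owned by `TrainLoop`, exposed back through read-only `Trainer` properties. A minimal, self-contained sketch of that pattern (class bodies heavily simplified; only the attribute and property names are taken from the diff below, everything else is illustrative):

```python
from typing import Optional


class TrainLoop:
    """Owns the mutable loop-progress state (simplified)."""

    def __init__(self, max_epochs: Optional[int], max_steps: Optional[int]):
        self.global_step = 0
        self.current_epoch = 0
        # the max_epochs = 1000 default applies only when neither limit is given,
        # mirroring the __init__ logic in training_loop.py below
        self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs
        self.max_steps = max_steps


class Trainer:
    def __init__(self, max_epochs: Optional[int] = None, max_steps: Optional[int] = None):
        self.train_loop = TrainLoop(max_epochs, max_steps)

    @property
    def global_step(self) -> int:
        # reads still go through the trainer; there is deliberately no setter
        return self.train_loop.global_step

    @property
    def current_epoch(self) -> int:
        return self.train_loop.current_epoch

    @property
    def max_epochs(self) -> Optional[int]:
        return self.train_loop.max_epochs
```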
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/progress.py
@@ -200,7 +200,7 @@ def on_init_end(self, trainer):
self._trainer = trainer

def on_train_start(self, trainer, pl_module):
- self._train_batch_idx = trainer.batch_idx
+ self._train_batch_idx = trainer.train_loop.batch_idx

def on_train_epoch_start(self, trainer, pl_module):
self._train_batch_idx = 0
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -159,7 +159,7 @@ def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
self._max_epochs = trainer.max_epochs
if self._model_contains_batch_norm:
# virtually increase max_epochs to perform batch norm update on latest epoch.
- trainer.max_epochs += 1
+ trainer.train_loop.max_epochs += 1

def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
if trainer.current_epoch == self.swa_start:
@@ -232,7 +232,7 @@ def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
# BatchNorm epoch update. Reset state
trainer.accumulate_grad_batches = self._accumulate_grad_batches
trainer.num_training_batches -= 1
- trainer.max_epochs -= 1
+ trainer.train_loop.max_epochs -= 1
self.reset_momenta()
elif trainer.current_epoch == self.swa_end:
# Last SWA epoch. Transfer weights from average model to pl_module
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -150,8 +150,8 @@ def restore_training_state(self, checkpoint, load_optimizer_states: bool = True)
# restore callback states
self.trainer.on_load_checkpoint(checkpoint)

- self.trainer.global_step = checkpoint['global_step']
- self.trainer.current_epoch = checkpoint['epoch']
+ self.trainer.train_loop.global_step = checkpoint['global_step']
+ self.trainer.train_loop.current_epoch = checkpoint['epoch']

# crash if max_epochs is lower then the current epoch from the checkpoint
if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs:
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/connectors/debugging_connector.py
@@ -58,9 +58,9 @@ def on_init_start(
limit_val_batches = fast_dev_run
limit_test_batches = fast_dev_run
limit_predict_batches = fast_dev_run
- self.trainer.max_steps = fast_dev_run
+ self.trainer.train_loop.max_steps = fast_dev_run
self.trainer.num_sanity_val_steps = 0
- self.trainer.max_epochs = 1
+ self.trainer.train_loop.max_epochs = 1
val_check_interval = 1.0
self.trainer.check_val_every_n_epoch = 1
self.trainer.logger = DummyLogger()
@@ -253,7 +253,7 @@ def info(self):
"""
model_ref = self.trainer.lightning_module
return {
"batch_idx": self.trainer.batch_idx,
"batch_idx": self.trainer.train_loop.batch_idx,
"fx_name": model_ref._current_hook_fx_name or model_ref._current_fx_name,
"dataloader_idx": model_ref._current_dataloader_idx or 0,
"opt_idx": self._opt_idx or 0,
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/connectors/optimizer_connector.py
@@ -46,7 +46,7 @@ def update_learning_rates(
if isinstance(lr_scheduler['opt_idx'], int) and lr_scheduler['opt_idx'] not in opt_indices:
continue

- current_idx = self.trainer.batch_idx if interval == 'step' else self.trainer.current_epoch
+ current_idx = self.trainer.train_loop.batch_idx if interval == 'step' else self.trainer.current_epoch
current_idx += 1 # account for both batch and epoch starts from 0
# Take step if call to update_learning_rates matches the interval key and
# the current step modulo the schedulers frequency is zero
@@ -86,7 +86,7 @@ def update_learning_rates(

if self.trainer.dev_debugger.enabled:
self.trainer.dev_debugger.track_lr_schedulers_update(
- self.trainer.batch_idx,
+ self.trainer.train_loop.batch_idx,
interval,
scheduler_idx,
old_lr,
28 changes: 27 additions & 1 deletion pytorch_lightning/trainer/properties.py
@@ -33,6 +33,7 @@
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.states import RunningStage, TrainerState, TrainerStatus
+ from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn
from pytorch_lightning.utilities.argparse import (
add_argparse_args,
@@ -49,7 +50,6 @@ class TrainerProperties(ABC):
_default_root_dir: str
_lightning_optimizers = None
_progress_bar_callback: ProgressBarBase
- state: TrainerState
_weights_save_path: str

accelerator_connector: AcceleratorConnector
@@ -58,6 +58,8 @@
limit_val_batches: int
logger: LightningLoggerBase
logger_connector: LoggerConnector
+ state: TrainerState
+ train_loop: TrainLoop

@property
def accelerator(self) -> Accelerator:
@@ -485,6 +487,30 @@ def sanity_checking(self, val: bool) -> None:
elif self.sanity_checking:
self.state.stage = None

+ @property
+ def global_step(self) -> int:
+     return self.train_loop.global_step
+
+ @property
+ def current_epoch(self) -> int:
+     return self.train_loop.current_epoch
+
+ @property
+ def max_epochs(self) -> Optional[int]:
+     return self.train_loop.max_epochs
+
+ @property
+ def min_epochs(self) -> Optional[int]:
+     return self.train_loop.min_epochs
+
+ @property
+ def max_steps(self) -> Optional[int]:
+     return self.train_loop.max_steps
+
+ @property
+ def min_steps(self) -> Optional[int]:
+     return self.train_loop.min_steps


# Used to represent the concrete type TrainerProperties class methods are called on.
_T = TypeVar('_T', bound=TrainerProperties)
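Because the new properties define no setters, external writes to `trainer.max_epochs` and friends now fail, which is the point of removing the hidden state: mutation has to go through the loop object, exactly as the SWA and checkpoint-connector hunks above do. A hypothetical usage sketch, reusing the simplified `Trainer`/`TrainLoop` classes from the first sketch:

```python
trainer = Trainer(max_epochs=10)

print(trainer.max_epochs)            # 10 -- the read path is unchanged
trainer.train_loop.max_epochs += 1   # the new write path, as in StochasticWeightAveraging

try:
    trainer.max_epochs += 1          # the old write path
except AttributeError:
    # a property without a setter rejects assignment
    print("trainer.max_epochs is now read-only")
```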
21 changes: 8 additions & 13 deletions pytorch_lightning/trainer/trainer.py
@@ -54,7 +54,7 @@
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.predict_loop import PredictLoop
from pytorch_lightning.trainer.properties import TrainerProperties
- from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus
+ from pytorch_lightning.trainer.states import TrainerFn, TrainerState, TrainerStatus
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.tuner.lr_finder import _LRFinder
@@ -308,6 +308,7 @@ def __init__(
"""
super().__init__()
Trainer._log_api_event("init")
+ self.state = TrainerState()
distributed_backend = distributed_backend or accelerator

# init connectors
@@ -329,7 +330,9 @@ def __init__(
self.checkpoint_connector = CheckpointConnector(self)
self.slurm_connector = SLURMConnector(self)
self.tuner = Tuner(self)
- self.train_loop = TrainLoop(self, multiple_trainloader_mode)
+ self.train_loop = TrainLoop(
+     self, multiple_trainloader_mode, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps
+ )
self.evaluation_loop = EvaluationLoop(self)
self.predict_loop = PredictLoop(self)

@@ -375,13 +378,6 @@ def __init__(
truncated_bptt_steps,
terminate_on_nan,
)
- self.train_loop.on_trainer_init(
-     max_epochs,
-     min_epochs,
-     max_steps,
-     min_steps,
-     num_sanity_val_steps,
- )
self.evaluation_loop.on_trainer_init()

# configure tuner
@@ -995,10 +991,9 @@ def run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
self.optimizer_connector.update_learning_rates(
interval='epoch',
opt_indices=[
- opt_idx
- for opt_idx, _ in self.train_loop.get_optimizers_iterable(batch_idx=(
-     self.total_batch_idx - 1
- )) # Select the optimizers which were used in the last batch of the epoch
+ opt_idx for opt_idx, _ in self.train_loop.get_optimizers_iterable(
+     batch_idx=(self.train_loop.total_batch_idx - 1)
+ ) # Select the optimizers which were used in the last batch of the epoch
],
)

67 changes: 32 additions & 35 deletions pytorch_lightning/trainer/training_loop.py
@@ -22,7 +22,6 @@
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.plugins import ParallelPlugin
- from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType
from pytorch_lightning.utilities.distributed import rank_zero_info
@@ -37,7 +36,16 @@

class TrainLoop:

- def __init__(self, trainer, multiple_trainloader_mode: str):
+ def __init__(
+     self,
+     trainer,
+     multiple_trainloader_mode: str,
+     max_epochs: Optional[int],
+     min_epochs: Optional[int],
+     max_steps: Optional[int],
+     min_steps: Optional[int],
+     num_sanity_val_steps: int,
+ ):
self.trainer = trainer
self.accumulated_loss = None
self.warning_cache = WarningCache()
@@ -50,30 +58,21 @@ def __init__(self, trainer, multiple_trainloader_mode: str):
self.trainer._multiple_trainloader_mode = multiple_trainloader_mode
self._optimizer_freq_cumsum = None

- def on_trainer_init(
-     self,
-     max_epochs: Optional[int],
-     min_epochs: Optional[int],
-     max_steps: Optional[int],
-     min_steps: Optional[int],
-     num_sanity_val_steps: int,
- ) -> None:
- self.trainer.global_step = 0
- self.trainer.current_epoch = 0
+ self.global_step = 0
+ self.current_epoch = 0
self.trainer.should_stop = False
- self.trainer.state = TrainerState()

- self.trainer.total_batch_idx = 0
- self.trainer.batch_idx = 0
+ self.total_batch_idx = 0
+ self.batch_idx = 0
self.trainer.num_training_batches = 0
self.trainer.train_dataloader = None

# If neither max_epochs or max_steps is set, then use existing default of max_epochs = 1000
- self.trainer.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs
+ self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs
# If neither min_epochs or min_steps is set, then use existing default of min_epochs = 1
- self.trainer.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs
- self.trainer.max_steps = max_steps
- self.trainer.min_steps = min_steps
+ self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs
+ self.max_steps = max_steps
+ self.min_steps = min_steps

if num_sanity_val_steps == -1:
self.trainer.num_sanity_val_steps = float("inf")
@@ -91,9 +90,9 @@ def optimizer_freq_cumsum(self):
self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
return self._optimizer_freq_cumsum

- def should_skip_training(self):
-     should_by_max_steps = self.trainer.max_steps is not None and self.trainer.global_step >= self.trainer.max_steps
-     should_by_epoch = self.trainer.max_epochs is not None and self.trainer.current_epoch >= self.trainer.max_epochs
+ def should_skip_training(self) -> bool:
+     should_by_max_steps = self.max_steps is not None and self.global_step >= self.max_steps
+     should_by_epoch = self.max_epochs is not None and self.current_epoch >= self.max_epochs
return should_by_max_steps or should_by_epoch or self.trainer.num_training_batches == 0

def on_train_start(self):
@@ -107,9 +106,9 @@ def on_train_end(self):

# trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
# when a checkpoint was saved at the last step
- self.trainer.global_step -= 1
+ self.global_step -= 1
self.check_checkpoint_callback(should_update=True, is_last=True)
- self.trainer.global_step += 1
+ self.global_step += 1

# hook
self.trainer.call_hook("on_train_end")
@@ -145,7 +144,7 @@ def check_checkpoint_callback(self, should_update, is_last=False):
def on_train_epoch_start(self, epoch):

# update training progress in trainer
- self.trainer.current_epoch = epoch
+ self.current_epoch = epoch

model = self.trainer.lightning_module

@@ -242,7 +241,7 @@ def get_optimizers_iterable(self, batch_idx=None):
return list(enumerate(self.trainer.optimizers))

if batch_idx is None:
- batch_idx = self.trainer.total_batch_idx
+ batch_idx = self.total_batch_idx

optimizers_loop_length = self.optimizer_freq_cumsum[-1]
current_place_in_loop = batch_idx % optimizers_loop_length
@@ -450,7 +449,7 @@ def track_and_norm_grad(self, optimizer):

def _track_gradient_norm(self):
grad_norm_dict = {}
- if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0:
+ if (self.global_step + 1) % self.trainer.log_every_n_steps == 0:
if float(self.trainer.track_grad_norm) > 0:
model = self.trainer.lightning_module
grad_norm_dict = grad_norm(model, self.trainer.track_grad_norm)
@@ -480,7 +479,7 @@ def run_training_epoch(self):
is_last_batch = None

for batch_idx, (batch, is_last_batch) in train_dataloader:
- self.trainer.batch_idx = batch_idx
+ self.batch_idx = batch_idx
self.trainer.is_last_batch = is_last_batch

# ------------------------------------
@@ -530,7 +529,7 @@

# max steps reached, end training
if (
-     self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step + 1
+     self.max_steps is not None and self.max_steps <= self.global_step + 1
and self._accumulated_batches_reached()
):
break
@@ -541,7 +540,7 @@
if self.trainer.should_stop:
break

- self.trainer.total_batch_idx += 1
+ self.total_batch_idx += 1

# stop epoch if we limited the number of training batches
if self._num_training_batches_reached(is_last_batch):
@@ -887,15 +886,13 @@ def increment_accumulated_grad_global_step(self):

# progress global step according to grads progress
if num_accumulated_batches_reached or num_training_batches_reached:
- self.trainer.global_step = self.trainer.accelerator.update_global_step(
-     self.trainer.total_batch_idx, self.trainer.global_step
- )
+ self.global_step = self.trainer.accelerator.update_global_step(self.total_batch_idx, self.global_step)

def _accumulated_batches_reached(self):
- return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0
+ return (self.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0

def _num_training_batches_reached(self, is_last_batch=False):
- return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch
+ return (self.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch

def should_accumulate(self):
# checks if backward or backward + optimizer step (via closure)
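Finally, a short round trip under the same simplified sketch, mirroring what `restore_training_state` in the checkpoint connector now does: state is written through `trainer.train_loop` and read back through the `Trainer` properties (the checkpoint dict is a hypothetical stand-in for a real checkpoint payload):

```python
trainer = Trainer(max_epochs=5)

checkpoint = {"global_step": 41, "epoch": 2}   # hypothetical payload
trainer.train_loop.global_step = checkpoint["global_step"]
trainer.train_loop.current_epoch = checkpoint["epoch"]

assert trainer.global_step == 41     # the read-only views see the restored state
assert trainer.current_epoch == 2
```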