remove trainer hidden state | sanity refactor [1 / n] #7437

Merged · 19 commits · May 11, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))

- Refactored Loops
  * Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7437))

- Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/))

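To make the entry above concrete, here is a minimal sketch of the ownership change, with toy stand-ins for the real `Trainer` and `TrainLoop` (only the attribute names come from this PR; everything else is assumed): the loop owns the progress counters, and the trainer merely reads them.

```python
from typing import Optional


class TrainLoop:
    """Toy stand-in: owns and mutates the training-progress state."""

    def __init__(self, max_epochs: Optional[int], max_steps: Optional[int]) -> None:
        self.global_step = 0
        self.current_epoch = 0
        self.max_epochs = max_epochs
        self.max_steps = max_steps


class Trainer:
    """Toy stand-in: exposes the loop's state through a read-only property."""

    def __init__(self) -> None:
        self.train_loop = TrainLoop(max_epochs=1000, max_steps=None)

    @property
    def global_step(self) -> int:
        # No setter is defined, so `trainer.global_step = ...` raises AttributeError.
        return self.train_loop.global_step


trainer = Trainer()
trainer.train_loop.global_step += 1  # only the loop advances the counter
assert trainer.global_step == 1
```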
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/progress.py
@@ -200,7 +200,7 @@ def on_init_end(self, trainer):
self._trainer = trainer

def on_train_start(self, trainer, pl_module):
self._train_batch_idx = trainer.batch_idx
self._train_batch_idx = trainer.train_loop.batch_idx

def on_train_epoch_start(self, trainer, pl_module):
self._train_batch_idx = 0
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -159,7 +159,7 @@ def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
self._max_epochs = trainer.max_epochs
if self._model_contains_batch_norm:
# virtually increase max_epochs to perform batch norm update on latest epoch.
trainer.max_epochs += 1
trainer.train_loop.max_epochs += 1

def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
if trainer.current_epoch == self.swa_start:
@@ -232,7 +232,7 @@ def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
# BatchNorm epoch update. Reset state
trainer.accumulate_grad_batches = self._accumulate_grad_batches
trainer.num_training_batches -= 1
trainer.max_epochs -= 1
trainer.train_loop.max_epochs -= 1
self.reset_momenta()
elif trainer.current_epoch == self.swa_end:
# Last SWA epoch. Transfer weights from average model to pl_module
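A hedged sketch of the "virtual epoch" trick the two hunks above implement (toy classes, not the real callback API): on fit start the callback bumps the loop's `max_epochs` by one so a final pass can update the BatchNorm statistics, and on train end it restores the configured value.

```python
class Loop:
    def __init__(self, max_epochs: int) -> None:
        self.max_epochs = max_epochs


class SWACallback:
    def on_fit_start(self, loop: Loop) -> None:
        # Reserve one extra epoch for the BatchNorm-statistics update.
        loop.max_epochs += 1

    def on_train_end(self, loop: Loop) -> None:
        # Restore the configured value once the extra epoch has run.
        loop.max_epochs -= 1


loop, swa = Loop(max_epochs=10), SWACallback()
swa.on_fit_start(loop)
assert loop.max_epochs == 11
swa.on_train_end(loop)
assert loop.max_epochs == 10
```

After this PR the increment targets `trainer.train_loop.max_epochs` directly, since `trainer.max_epochs` becomes a read-only property.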
pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -150,8 +150,8 @@ def restore_training_state(self, checkpoint, load_optimizer_states: bool = True)
# restore callback states
self.trainer.on_load_checkpoint(checkpoint)

self.trainer.global_step = checkpoint['global_step']
self.trainer.current_epoch = checkpoint['epoch']
self.trainer.train_loop.global_step = checkpoint['global_step']
self.trainer.train_loop.current_epoch = checkpoint['epoch']

        # crash if max_epochs is lower than the current epoch from the checkpoint
if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs:
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/connectors/debugging_connector.py
@@ -58,9 +58,9 @@ def on_init_start(
limit_val_batches = fast_dev_run
limit_test_batches = fast_dev_run
limit_predict_batches = fast_dev_run
self.trainer.max_steps = fast_dev_run
self.trainer.train_loop.max_steps = fast_dev_run
self.trainer.num_sanity_val_steps = 0
self.trainer.max_epochs = 1
self.trainer.train_loop.max_epochs = 1
val_check_interval = 1.0
self.trainer.check_val_every_n_epoch = 1
self.trainer.logger = DummyLogger()
@@ -253,7 +253,7 @@ def info(self):
"""
model_ref = self.trainer.lightning_module
return {
"batch_idx": self.trainer.batch_idx,
"batch_idx": self.trainer.train_loop.batch_idx,
"fx_name": model_ref._current_hook_fx_name or model_ref._current_fx_name,
"dataloader_idx": model_ref._current_dataloader_idx or 0,
"opt_idx": self._opt_idx or 0,
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/connectors/optimizer_connector.py
@@ -46,7 +46,7 @@ def update_learning_rates(
if isinstance(lr_scheduler['opt_idx'], int) and lr_scheduler['opt_idx'] not in opt_indices:
continue

current_idx = self.trainer.batch_idx if interval == 'step' else self.trainer.current_epoch
current_idx = self.trainer.train_loop.batch_idx if interval == 'step' else self.trainer.current_epoch
current_idx += 1 # account for both batch and epoch starts from 0
# Take step if call to update_learning_rates matches the interval key and
        # the current step modulo the scheduler's frequency is zero
@@ -86,7 +86,7 @@ def update_learning_rates(

if self.trainer.dev_debugger.enabled:
self.trainer.dev_debugger.track_lr_schedulers_update(
self.trainer.batch_idx,
self.trainer.train_loop.batch_idx,
interval,
scheduler_idx,
old_lr,
28 changes: 27 additions & 1 deletion pytorch_lightning/trainer/properties.py
@@ -33,6 +33,7 @@
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.states import RunningStage, TrainerState, TrainerStatus
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn
from pytorch_lightning.utilities.argparse import (
add_argparse_args,
@@ -49,7 +50,6 @@ class TrainerProperties(ABC):
_default_root_dir: str
_lightning_optimizers = None
_progress_bar_callback: ProgressBarBase
state: TrainerState
_weights_save_path: str

accelerator_connector: AcceleratorConnector
@@ -58,6 +58,8 @@
limit_val_batches: int
logger: LightningLoggerBase
logger_connector: LoggerConnector
state: TrainerState
train_loop: TrainLoop

@property
def accelerator(self) -> Accelerator:
@@ -485,6 +487,30 @@ def sanity_checking(self, val: bool) -> None:
elif self.sanity_checking:
self.state.stage = None

@property
def global_step(self) -> int:
return self.train_loop.global_step

@property
def current_epoch(self) -> int:
return self.train_loop.current_epoch

@property
def max_epochs(self) -> Optional[int]:
return self.train_loop.max_epochs

@property
def min_epochs(self) -> Optional[int]:
return self.train_loop.min_epochs

@property
def max_steps(self) -> Optional[int]:
return self.train_loop.max_steps

@property
def min_steps(self) -> Optional[int]:
return self.train_loop.min_steps
Comment on lines +491 to +512

Member: will this be deprecated in the future?

Contributor (author): I don't think so. The user is allowed to access it and it's useful, but it should be read-only :)

Contributor: is batch_idx intentionally left out of these read-only properties?

Contributor (author, @awaelchli, May 12, 2021): yes, it's not meant to be on the trainer, because 1) it is already accessible to the user through the hooks, and 2) it wouldn't be clear what the variable means anyway, since it is specific to the loop.
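A minimal sketch of point (1) in the answer above, assuming a toy module whose hook mirrors the `training_step(batch, batch_idx)` signature: `batch_idx` already reaches user code as a hook argument, so it does not need to live on the trainer.

```python
class ToyModule:
    def training_step(self, batch, batch_idx: int) -> None:
        # The loop hands batch_idx to the hook directly.
        print(f"training_step on batch {batch_idx}: {batch}")


module = ToyModule()
for batch_idx, batch in enumerate([[1, 2], [3, 4]]):
    module.training_step(batch, batch_idx)
```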



# Used to represent the concrete type TrainerProperties class methods are called on.
_T = TypeVar('_T', bound=TrainerProperties)
21 changes: 8 additions & 13 deletions pytorch_lightning/trainer/trainer.py
@@ -54,7 +54,7 @@
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.predict_loop import PredictLoop
from pytorch_lightning.trainer.properties import TrainerProperties
from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus
from pytorch_lightning.trainer.states import TrainerFn, TrainerState, TrainerStatus
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.tuner.lr_finder import _LRFinder
@@ -308,6 +308,7 @@ def __init__(
"""
super().__init__()
Trainer._log_api_event("init")
self.state = TrainerState()
distributed_backend = distributed_backend or accelerator

# init connectors
@@ -329,7 +330,9 @@
self.checkpoint_connector = CheckpointConnector(self)
self.slurm_connector = SLURMConnector(self)
self.tuner = Tuner(self)
self.train_loop = TrainLoop(self, multiple_trainloader_mode)
self.train_loop = TrainLoop(
self, multiple_trainloader_mode, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps
)
self.evaluation_loop = EvaluationLoop(self)
self.predict_loop = PredictLoop(self)

@@ -375,13 +378,6 @@ def __init__(
truncated_bptt_steps,
terminate_on_nan,
)
self.train_loop.on_trainer_init(
max_epochs,
min_epochs,
max_steps,
min_steps,
num_sanity_val_steps,
)
self.evaluation_loop.on_trainer_init()

# configure tuner
@@ -995,10 +991,9 @@ def run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
self.optimizer_connector.update_learning_rates(
interval='epoch',
opt_indices=[
opt_idx
for opt_idx, _ in self.train_loop.get_optimizers_iterable(batch_idx=(
self.total_batch_idx - 1
)) # Select the optimizers which were used in the last batch of the epoch
opt_idx for opt_idx, _ in self.train_loop.get_optimizers_iterable(
batch_idx=(self.train_loop.total_batch_idx - 1)
) # Select the optimizers which were used in the last batch of the epoch
],
)

67 changes: 32 additions & 35 deletions pytorch_lightning/trainer/training_loop.py
@@ -22,7 +22,6 @@
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType
from pytorch_lightning.utilities.distributed import rank_zero_info
@@ -37,7 +36,7 @@

class TrainLoop:

def __init__(self, trainer, multiple_trainloader_mode: str):
def __init__(
self,
trainer,
multiple_trainloader_mode: str,
max_epochs: Optional[int],
min_epochs: Optional[int],
max_steps: Optional[int],
min_steps: Optional[int],
num_sanity_val_steps: int,
):
self.trainer = trainer
self.accumulated_loss = None
self.warning_cache = WarningCache()
@@ -50,30 +58,21 @@ def __init__(self, trainer, multiple_trainloader_mode: str):
self.trainer._multiple_trainloader_mode = multiple_trainloader_mode
self._optimizer_freq_cumsum = None

def on_trainer_init(
self,
max_epochs: Optional[int],
min_epochs: Optional[int],
max_steps: Optional[int],
min_steps: Optional[int],
num_sanity_val_steps: int,
) -> None:
self.trainer.global_step = 0
self.trainer.current_epoch = 0
self.global_step = 0
self.current_epoch = 0
self.trainer.should_stop = False
self.trainer.state = TrainerState()

self.trainer.total_batch_idx = 0
self.trainer.batch_idx = 0
self.total_batch_idx = 0
self.batch_idx = 0
self.trainer.num_training_batches = 0
self.trainer.train_dataloader = None

        # If neither max_epochs nor max_steps is set, then use the existing default of max_epochs = 1000
self.trainer.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs
self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs
        # If neither min_epochs nor min_steps is set, then use the existing default of min_epochs = 1
self.trainer.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs
self.trainer.max_steps = max_steps
self.trainer.min_steps = min_steps
self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs
self.max_steps = max_steps
self.min_steps = min_steps

if num_sanity_val_steps == -1:
self.trainer.num_sanity_val_steps = float("inf")
@@ -91,9 +90,9 @@ def optimizer_freq_cumsum(self):
self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
return self._optimizer_freq_cumsum

def should_skip_training(self):
should_by_max_steps = self.trainer.max_steps is not None and self.trainer.global_step >= self.trainer.max_steps
should_by_epoch = self.trainer.max_epochs is not None and self.trainer.current_epoch >= self.trainer.max_epochs
def should_skip_training(self) -> bool:
should_by_max_steps = self.max_steps is not None and self.global_step >= self.max_steps
should_by_epoch = self.max_epochs is not None and self.current_epoch >= self.max_epochs
return should_by_max_steps or should_by_epoch or self.trainer.num_training_batches == 0

def on_train_start(self):
@@ -107,9 +106,9 @@ def on_train_end(self):

# trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
# when a checkpoint was saved at the last step
self.trainer.global_step -= 1
self.global_step -= 1
self.check_checkpoint_callback(should_update=True, is_last=True)
self.trainer.global_step += 1
self.global_step += 1

# hook
self.trainer.call_hook("on_train_end")
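A hedged sketch of the rollback in the hunk above (toy code; it assumes checkpoints are keyed by step, which is not how the real callback stores them): when training ends, `global_step` has already advanced past the last trained step, so it is decremented for the final checkpoint check to avoid writing a duplicate, then restored.

```python
class Loop:
    def __init__(self) -> None:
        self.global_step = 10   # already incremented past the last step (9)
        self.saved_steps = {9}  # a checkpoint was written at the last step

    def check_checkpoint_callback(self) -> None:
        if self.global_step not in self.saved_steps:
            self.saved_steps.add(self.global_step)

    def on_train_end(self) -> None:
        self.global_step -= 1             # temporarily step back...
        self.check_checkpoint_callback()  # ...so step 9 is not saved twice
        self.global_step += 1             # restore the counter


loop = Loop()
loop.on_train_end()
assert loop.saved_steps == {9}  # no duplicate checkpoint was written
```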
@@ -145,7 +144,7 @@ def check_checkpoint_callback(self, should_update, is_last=False):
def on_train_epoch_start(self, epoch):

# update training progress in trainer
self.trainer.current_epoch = epoch
self.current_epoch = epoch

model = self.trainer.lightning_module

@@ -242,7 +241,7 @@ def get_optimizers_iterable(self, batch_idx=None):
return list(enumerate(self.trainer.optimizers))

if batch_idx is None:
batch_idx = self.trainer.total_batch_idx
batch_idx = self.total_batch_idx

optimizers_loop_length = self.optimizer_freq_cumsum[-1]
current_place_in_loop = batch_idx % optimizers_loop_length
@@ -450,7 +449,7 @@ def track_and_norm_grad(self, optimizer):

def _track_gradient_norm(self):
grad_norm_dict = {}
if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0:
if (self.global_step + 1) % self.trainer.log_every_n_steps == 0:
if float(self.trainer.track_grad_norm) > 0:
model = self.trainer.lightning_module
grad_norm_dict = grad_norm(model, self.trainer.track_grad_norm)
@@ -480,7 +479,7 @@ def run_training_epoch(self):
is_last_batch = None

for batch_idx, (batch, is_last_batch) in train_dataloader:
self.trainer.batch_idx = batch_idx
self.batch_idx = batch_idx
self.trainer.is_last_batch = is_last_batch

# ------------------------------------
@@ -530,7 +529,7 @@

# max steps reached, end training
if (
self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step + 1
self.max_steps is not None and self.max_steps <= self.global_step + 1
and self._accumulated_batches_reached()
):
break
@@ -541,7 +540,7 @@
if self.trainer.should_stop:
break

self.trainer.total_batch_idx += 1
self.total_batch_idx += 1

# stop epoch if we limited the number of training batches
if self._num_training_batches_reached(is_last_batch):
@@ -887,15 +886,13 @@ def increment_accumulated_grad_global_step(self):

# progress global step according to grads progress
if num_accumulated_batches_reached or num_training_batches_reached:
self.trainer.global_step = self.trainer.accelerator.update_global_step(
self.trainer.total_batch_idx, self.trainer.global_step
)
self.global_step = self.trainer.accelerator.update_global_step(self.total_batch_idx, self.global_step)

def _accumulated_batches_reached(self):
return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0
return (self.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0

def _num_training_batches_reached(self, is_last_batch=False):
return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch
return (self.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch

def should_accumulate(self):
# checks if backward or backward + optimizer step (via closure)
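To summarize the constructor change in this file, a hedged before/after sketch (toy code, argument list abbreviated): previously the limits were attached to the trainer in a separate `on_trainer_init` call some time after construction, whereas now they are plain constructor arguments stored on the loop itself, so the state exists as soon as the object does.

```python
from typing import Optional


class TrainLoopBefore:
    def __init__(self, trainer) -> None:
        self.trainer = trainer

    def on_trainer_init(self, max_epochs: Optional[int], max_steps: Optional[int]) -> None:
        # State appears on the *trainer*, and only after this second call.
        self.trainer.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs
        self.trainer.max_steps = max_steps


class TrainLoopAfter:
    def __init__(self, trainer, max_epochs: Optional[int], max_steps: Optional[int]) -> None:
        self.trainer = trainer
        # State lives on the loop and is fully set at construction time.
        self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs
        self.max_steps = max_steps


loop = TrainLoopAfter(trainer=None, max_epochs=None, max_steps=None)
assert loop.max_epochs == 1000  # the default is applied eagerly
```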