Always run validation inside the training loop epoch (#7357)
Co-authored-by: Adrian Wälchli <[email protected]>
carmocca and awaelchli authored May 26, 2021
1 parent 037a71b commit 311d9fe
Showing 8 changed files with 64 additions and 86 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -62,6 +62,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))


- Validation is now always run inside the training epoch scope ([#7357](https://github.com/PyTorchLightning/pytorch-lightning/pull/7357))


- Refactored Loops
* Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7437))
* Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506))
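
The user-visible effect of this change is hook ordering: validation hooks now fire inside the training epoch, before `training_epoch_end`/`on_train_epoch_end`/`on_epoch_end` (see the updated expectations in tests/models/test_hooks.py below). A minimal sketch of how to observe the new order, assuming a 1.3-era PyTorch Lightning install that includes this commit (later releases removed `training_epoch_end` and some of these Trainer flags); `HookOrderModel` is an illustrative name, not part of the repository:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class HookOrderModel(pl.LightningModule):
    """Illustrative model that records hook calls; not part of this commit."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.calls = []

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def validation_step(self, batch, batch_idx):
        self.calls.append("validation_step")

    def training_epoch_end(self, outputs):
        self.calls.append("training_epoch_end")

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def make_loader():
    return DataLoader(TensorDataset(torch.randn(8, 32)), batch_size=4)


model = HookOrderModel()
trainer = pl.Trainer(
    max_epochs=1,
    num_sanity_val_steps=0,
    progress_bar_refresh_rate=0,
    checkpoint_callback=False,
    logger=False,
)
trainer.fit(model, make_loader(), make_loader())
# With this commit, the "validation_step" entries come first; before it,
# "training_epoch_end" was recorded before validation ran.
print(model.calls)
```
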
13 changes: 1 addition & 12 deletions pytorch_lightning/trainer/trainer.py
@@ -928,7 +928,7 @@ def _run_train(self) -> None:
self.state.stage = None
raise

def _run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
def _run_evaluation(self) -> _EVALUATE_OUTPUT:
if not (self.evaluating or self.sanity_checking):
rank_zero_warn(
f"`trainer._run_evaluation()` was called but the running stage is set to {self.state.stage}."
@@ -1010,17 +1010,6 @@ def _run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
# hook
self.evaluation_loop.on_evaluation_epoch_end()

# update epoch-level lr_schedulers
if on_epoch:
self.optimizer_connector.update_learning_rates(
interval='epoch',
opt_indices=[
opt_idx for opt_idx, _ in self.train_loop.get_active_optimizers(
batch_idx=(self.train_loop.total_batch_idx - 1)
) # Select the optimizers which were used in the last batch of the epoch
],
)

# log epoch metrics
eval_loop_results = self.logger_connector.get_evaluate_epoch_results()

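
With the `on_epoch` parameter gone from `_run_evaluation`, the epoch-interval `update_learning_rates` call deleted above moves into the training loop, where it now runs unconditionally. Condensed from the training_loop.py diff below, the end-of-epoch sequence reads roughly as follows (a paraphrase for orientation, not verbatim source):

```python
# End of `run_training_epoch` after this commit (condensed paraphrase).
# Validation, if due, has already run inside the batch loop by this point.
self.on_train_epoch_end(epoch_output)

self.global_step -= 1  # backwards-compatibility shim so epoch-end metrics log at the same step
self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output)
self.global_step += 1

self.update_lr_schedulers('epoch')  # previously skipped when an epoch-end val run was pending

did_train_only = self.trainer.disable_validation or self.trainer.evaluation_loop.should_skip_evaluation(
    self.trainer.num_val_batches
)
if did_train_only:
    self.global_step -= 1
    self.check_checkpoint_callback(True)
    self.global_step += 1
```
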
58 changes: 22 additions & 36 deletions pytorch_lightning/trainer/training_loop.py
@@ -479,7 +479,6 @@ def run_training_epoch(self):
train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
dataloader_idx = 0
batch_idx = None
is_last_batch = None

for batch_idx, (batch, is_last_batch) in train_dataloader:
self.batch_idx = batch_idx
@@ -529,44 +528,38 @@

self.total_batch_idx += 1

max_steps_reached = (
self.max_steps is not None and self.max_steps <= self.global_step + 1
and self._accumulated_batches_reached()
)
if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch):
break

# progress global step according to grads progress
self.increment_accumulated_grad_global_step()

max_steps_reached = (self.max_steps is not None and self.max_steps <= self.global_step)
if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch):
break

if batch_idx is None:
# dataloader/iterator did not produce a batch
return

# handle epoch_output on epoch end
self.on_train_epoch_end(epoch_output)

# the global step is manually decreased here due to backwards compatibility with existing loggers
# as they expect that the same step is used when logging epoch end metrics even when the batch loop has
# finished. this means the attribute does not exactly track the number of optimizer steps applied.
# TODO(@carmocca): deprecate and rename so users don't get confused
self.global_step -= 1
# log epoch metrics
self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output)
self.global_step += 1

should_check_val = self._should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
should_train_only = self.trainer.disable_validation or should_skip_eval

# update epoch level lr_schedulers if no val loop outside train loop is triggered
if not should_check_val or should_train_only:
self.update_lr_schedulers('epoch')
self.update_lr_schedulers('epoch')

if should_train_only:
did_train_only = self.trainer.disable_validation or self.trainer.evaluation_loop.should_skip_evaluation(
self.trainer.num_val_batches
)
if did_train_only:
self.global_step -= 1
self.check_checkpoint_callback(True)

if should_check_val:
self.trainer.validating = True
self.trainer._run_evaluation(on_epoch=True)
self.trainer.training = True

if batch_output.signal != -1:
self.increment_accumulated_grad_global_step()
self.global_step += 1

def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None:
# inform logger the batch loop has finished
@@ -882,7 +875,7 @@ def should_accumulate(self):
is_final_batch = self._num_training_batches_reached()
return not (accumulation_done or is_final_batch)

def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool:
def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
""" Decide if we should run validation. """
if not self.trainer.enable_validation:
return False
@@ -893,26 +886,19 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool:

# val_check_batch is inf for iterable datasets with no length defined
is_infinite_dataset = self.trainer.val_check_batch == float('inf')
if on_epoch and is_last_batch and is_infinite_dataset:
if is_last_batch and is_infinite_dataset:
return True

if self.trainer.should_stop:
return True

# TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
is_val_check_batch = False
if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'):
is_val_check_batch = is_last_batch
if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
elif self.trainer.val_check_batch != float('inf'):
is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0

# Note: num_training_batches is also inf for iterable datasets with no length defined
epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0

if on_epoch:
return is_val_check_batch and epoch_end_val_check
else:
return is_val_check_batch and not epoch_end_val_check
return is_val_check_batch

def _build_kwargs(self, batch, batch_idx, opt_idx, hiddens):
# enable not needing to add opt_idx to training_step
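
For reference, the simplified `_should_check_val_fx` (its `on_epoch` parameter is gone, along with the `epoch_end_val_check` split) can be restated as a standalone function. This is a sketch mirroring the post-commit logic; the free function and its keyword parameters are illustrative, not library API, and the `enable_validation`/`check_val_every_n_epoch` guards at the top of the real method are elided:

```python
def should_check_val(
    batch_idx: int,
    is_last_batch: bool,
    *,
    val_check_batch: float,      # mirrors trainer.val_check_batch
    limit_train_batches,         # mirrors trainer.limit_train_batches (int or float fraction)
    should_stop: bool = False,   # mirrors trainer.should_stop
) -> bool:
    # val_check_batch is inf for iterable datasets with no length defined
    is_infinite_dataset = val_check_batch == float('inf')
    if is_last_batch and is_infinite_dataset:
        return True
    if should_stop:
        return True

    is_val_check_batch = is_last_batch
    if isinstance(limit_train_batches, int) and is_infinite_dataset:
        is_val_check_batch = (batch_idx + 1) % limit_train_batches == 0
    elif val_check_batch != float('inf'):
        is_val_check_batch = (batch_idx + 1) % val_check_batch == 0
    return is_val_check_batch


# With val_check_batch == 3, validation is requested after every third batch,
# including the batch that closes the epoch, since the on_epoch distinction is gone:
assert should_check_val(2, False, val_check_batch=3, limit_train_batches=1.0)
assert not should_check_val(3, False, val_check_batch=3, limit_train_batches=1.0)
assert should_check_val(5, True, val_check_batch=3, limit_train_batches=1.0)
```
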
4 changes: 2 additions & 2 deletions tests/callbacks/test_callbacks.py
@@ -83,8 +83,6 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_after_backward(trainer, model),
call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
call.on_batch_end(trainer, model),
call.on_train_epoch_end(trainer, model, ANY),
call.on_epoch_end(trainer, model),
call.on_validation_start(trainer, model),
call.on_epoch_start(trainer, model),
call.on_validation_epoch_start(trainer, model),
@@ -94,6 +92,8 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC
call.on_train_epoch_end(trainer, model, ANY),
call.on_epoch_end(trainer, model),
call.on_train_end(trainer, model),
call.on_fit_end(trainer, model),
call.teardown(trainer, model, 'fit'),
2 changes: 1 addition & 1 deletion tests/callbacks/test_early_stopping.py
@@ -169,7 +169,7 @@ def training_epoch_end(self, outputs):
model.validation_step = None

early_stop_callback = EarlyStopping(
monitor="train_loss", patience=patience, verbose=True, check_on_train_epoch_end=validation_step_none
monitor="train_loss", patience=patience, verbose=True, check_on_train_epoch_end=True
)
trainer = Trainer(
default_root_dir=tmpdir,
61 changes: 30 additions & 31 deletions tests/checkpointing/test_model_checkpoint.py
@@ -60,9 +60,8 @@ def validation_epoch_end(self, outputs):
"validation_step_none,val_dataloaders_none,monitor",
[
(False, False, 'val_log'),
(False, False, 'train_log_epoch'),
(True, False, 'train_log_epoch'),
(False, True, 'train_log_epoch'),
(False, True, 'val_log'),
],
)
@pytest.mark.parametrize('reduce_lr_on_plateau', [False, True])
@@ -76,7 +75,7 @@ def test_model_checkpoint_score_and_ckpt(
max_epochs = 3
limit_train_batches = 5
limit_val_batches = 7
lr = 1e-1
lr, gamma = 1e-1, 2

class CustomBoringModel(BoringModel):

@@ -106,7 +105,7 @@ def configure_optimizers(self):
'strict': True,
}
else:
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)

return [optimizer], [lr_scheduler]

@@ -153,9 +152,12 @@ def configure_optimizers(self):
assert mc_specific_data['current_score'] == score

if not reduce_lr_on_plateau:
lr_scheduler_specific_data = chk['lr_schedulers'][0]
assert lr_scheduler_specific_data['_step_count'] == epoch + 2
assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + 1))
actual_step_count = chk['lr_schedulers'][0]['_step_count']
actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]
# if validation_step_none, the checkpoint gets saved after the learning rate update
# so we need to increase the count by one
assert actual_step_count == epoch + 1 + validation_step_none
assert actual_lr == lr * gamma**(epoch + validation_step_none)

assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
@@ -180,23 +182,21 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
max_epochs = 3
limit_train_batches = 12
limit_val_batches = 7
lr = 1e-1
lr, gamma = 1e-1, 2
monitor = 'val_log'
per_epoch_steps = int(limit_train_batches * val_check_interval)
per_epoch_call_count = limit_train_batches // per_epoch_steps
left_over_steps = limit_train_batches % per_epoch_steps
per_val_train_batches = int(limit_train_batches * val_check_interval)
per_epoch_val_checks, leftover_train_batches = divmod(limit_train_batches, per_val_train_batches)

class CustomBoringModel(BoringModel):

def __init__(self):
super().__init__()
self.val_logs = torch.randn(per_epoch_call_count * max_epochs, limit_val_batches)
self.val_logs = torch.randn(per_epoch_val_checks * max_epochs, limit_val_batches)
self.val_loop_count = 0

def validation_step(self, batch, batch_idx):
log_value = self.val_logs[self.val_loop_count, batch_idx]
self.log('val_log', log_value)
self.log('epoch', self.current_epoch, on_epoch=True)
return super().validation_step(batch, batch_idx)

def validation_epoch_end(self, outputs):
@@ -213,7 +213,7 @@ def configure_optimizers(self):
'strict': True,
}
else:
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)

return [optimizer], [lr_scheduler]

@@ -241,26 +241,27 @@ def configure_optimizers(self):

# on_train_end ckpt callback is called which creates an additional ckpt in case no ckpt is created at the
# end of epoch, thus if val_check_interval doesn't align with the training steps we create an additional ckpt
additional_ckpt, additional_ckpt_path = 0, None
additional_ckpt, additional_ckpt_path = False, None
if not epoch_aligned:
additional_ckpt_path = [f for f in ckpt_files if 'v1' in f.stem][0]
additional_ckpt = 1
additional_ckpt = True

additional_ckpt = 1 if not epoch_aligned else 0
assert len(ckpt_files) == len(scores) + additional_ckpt == per_epoch_call_count * max_epochs + additional_ckpt
assert len(ckpt_files) == len(scores) + additional_ckpt == per_epoch_val_checks * max_epochs + additional_ckpt
assert len(lr_scheduler_debug) == max_epochs

def _make_assertions(epoch, ix, add=''):
global_ix = ix + per_epoch_call_count * epoch
def _make_assertions(epoch, ix, version=''):
global_ix = ix + per_epoch_val_checks * epoch
duplicated = bool(version)

score = scores[global_ix]
expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
expected_filename = f'{monitor}={score:.4f}-epoch={epoch}{add}.ckpt'
expected_filename = f'{monitor}={score:.4f}-epoch={epoch}{version}.ckpt'
assert math.isclose(score, expected_score, rel_tol=1e-4)

chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
assert chk['epoch'] == epoch + 1
epoch_num = epoch + (1 if add else 0)
expected_global_step = per_epoch_steps * (global_ix + 1) + (left_over_steps * epoch_num)
epoch_num = epoch + duplicated
expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num)
assert chk['global_step'] == expected_global_step

mc_specific_data = chk['callbacks'][type(checkpoint)]
@@ -269,25 +270,23 @@ def _make_assertions(epoch, ix, add=''):
assert mc_specific_data['current_score'] == score

if not reduce_lr_on_plateau:
lr_scheduler_specific_data = chk['lr_schedulers'][0]
did_update = 1 if (ix + 1 == per_epoch_call_count) and (epoch_aligned or add) else 0
assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + did_update))
actual_step_count = chk['lr_schedulers'][0]['_step_count']
actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]
assert actual_step_count == epoch + 1 + duplicated
assert actual_lr == lr * gamma**(epoch + duplicated)

return score

for epoch in range(max_epochs):
for i in range(per_epoch_call_count):
for i in range(per_epoch_val_checks):
score = _make_assertions(epoch, i)

assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)

# check the ckpt file saved on_train_end
if additional_ckpt_path:
epoch = max_epochs - 1
i = per_epoch_call_count - 1
_make_assertions(epoch, i, add='-v1')
_make_assertions(max_epochs - 1, per_epoch_val_checks - 1, version='-v1')


@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
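
The rewritten scheduler assertions above rely on plain `StepLR` arithmetic: with an explicit `gamma=2`, the decayed rate after `n` epoch-level steps is `lr * gamma**n`, whereas the old expression `lr * lr**n` only held because StepLR's default `gamma=0.1` happened to equal the initial `lr`. A quick standalone check (plain PyTorch, no Lightning):

```python
import torch
from torch import optim

lr, gamma = 1e-1, 2
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.SGD([param], lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)

for epoch in range(3):
    optimizer.step()   # normally driven by the batch loop
    scheduler.step()   # the epoch-level update the training loop now always performs
    # _step_count starts at 1 on construction, which is why the expected
    # counts in the tests are offset by one from the number of completed epochs
    assert scheduler.state_dict()['_step_count'] == epoch + 2
    assert scheduler.get_last_lr()[0] == lr * gamma ** (epoch + 1)
```
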
6 changes: 3 additions & 3 deletions tests/models/test_hooks.py
@@ -469,9 +469,6 @@ def test_trainer_model_hook_system_fit(tmpdir):
'on_epoch_start',
'on_train_epoch_start',
*(model.train_batch * train_batches),
'training_epoch_end',
'on_train_epoch_end',
'on_epoch_end',
'on_validation_model_eval',
'on_validation_start',
'on_epoch_start',
@@ -483,6 +480,9 @@ def test_trainer_model_hook_system_fit(tmpdir):
'on_save_checkpoint',
'on_validation_end',
'on_validation_model_train',
'training_epoch_end',
'on_train_epoch_end',
'on_epoch_end',
'on_train_end',
'on_fit_end',
'teardown_fit',
2 changes: 1 addition & 1 deletion tests/trainer/loops/test_training_loop.py
@@ -141,4 +141,4 @@ def validation_step(self, *args):

assert trainer.current_epoch == 0
assert trainer.global_step == 5
assert model.validation_called_at == (0, 4) # TODO(@carmocca): should be 5 - will be fixed in next PR
assert model.validation_called_at == (0, 4)
