Commit

Move training_output validation to after train_step_end (#7868)
* move validation to after aggregation

* changelog

* add test for training_step_end

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
kandluis and pre-commit-ci[bot] authored Jun 8, 2021
1 parent 3427cb7 commit f9fccdf
Showing 3 changed files with 25 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -177,6 +177,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

+- Fixed `_check_training_step_output` to be called after `training_step_end` to support more flexible accommodations ([#7868](https://github.com/PyTorchLightning/pytorch-lightning/pull/7868))
+
 - Fixed `apply_to_collection` works on Custom Collections now ([#7851](https://github.com/PyTorchLightning/pytorch-lightning/pull/7851))

 - Fixed ambiguous warning when both overfit and train dataloader shuffling are enabled ([#7685](https://github.com/PyTorchLightning/pytorch-lightning/pull/7685))
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
@@ -296,10 +296,10 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):

             self.trainer.logger_connector.cache_logged_metrics()

-            self._check_training_step_output(training_step_output)
-
             training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

+            self._check_training_step_output(training_step_output)
+
             training_step_output_for_epoch_end, training_step_output = self._process_training_step_output(
                 training_step_output, split_batch
             )
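
The effect of moving the check can be sketched without Lightning at all. The block below is a simplified illustration, not the library's code: `FakeTensor`, `check_output`, and `run_step` are hypothetical stand-ins, and the validity rule (a tensor, a dict with key `'loss'`, or `None`) is an assumption modelled on the invalid outputs used in the tests of this commit.

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so the sketch has no dependencies."""
    value: float


def check_output(output):
    """Simplified stand-in for _check_training_step_output (assumed rule)."""
    is_valid = (
        output is None
        or isinstance(output, FakeTensor)
        or (isinstance(output, dict) and "loss" in output)
    )
    if not is_valid:
        raise ValueError("output must be a tensor, a dict with key 'loss', or None")


def training_step(batch):
    # Returns intermediate values only; the loss is computed in the end hook.
    return {"output": batch * 2.0, "batch": batch}


def training_step_end(outputs):
    # Reduces the intermediate values to a single loss.
    return FakeTensor(abs(outputs["output"] - outputs["batch"]))


def run_step(batch, validate_before_hook):
    output = training_step(batch)
    if validate_before_hook:
        check_output(output)  # old order: the intermediate dict is rejected
    output = training_step_end(output)
    if not validate_before_hook:
        check_output(output)  # new order: only the final value is validated
    return output
```

With `validate_before_hook=True` the intermediate dict fails validation before `training_step_end` can turn it into a loss; with `False` the same model passes, which is the behaviour the new test exercises.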
23 changes: 21 additions & 2 deletions tests/trainer/loops/test_training_loop.py
@@ -150,12 +150,12 @@ def validation_step(self, *args):
 @pytest.mark.parametrize(['output'], [(5., ), ({'a': 5}, )])
 def test_warning_invalid_trainstep_output(tmpdir, output):

-    class TestModel(BoringModel):
+    class InvalidTrainStepModel(BoringModel):

         def training_step(self, batch, batch_idx):
             return output

-    model = TestModel()
+    model = InvalidTrainStepModel()

     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
     with pytest.raises(
@@ -166,3 +166,22 @@ def training_step(self, batch, batch_idx):
         )
     ):
         trainer.fit(model)
+
+
+def test_warning_valid_train_step_end(tmpdir):
+
+    class ValidTrainStepEndModel(BoringModel):
+
+        def training_step(self, batch, batch_idx):
+            output = self(batch)
+            return {'output': output, 'batch': batch}
+
+        def training_step_end(self, outputs):
+            loss = self.loss(outputs['batch'], outputs['output'])
+            return loss
+
+    # No error is raised
+    model = ValidTrainStepEndModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
+
+    trainer.fit(model)
