diff --git a/CHANGELOG.md b/CHANGELOG.md
index a2806602291e7..d740e9228db97 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -77,7 +77,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 -
 
--
+- Fixed reference issues during epoch end result collection ([#8621](https://github.com/PyTorchLightning/pytorch-lightning/pull/8621))
 
 - Fixed `trainer.fit_loop.split_idx` always returning `None` ([#8601](https://github.com/PyTorchLightning/pytorch-lightning/pull/8601))
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index cbe4b4e86ff26..6aba17ec7b713 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -750,7 +750,7 @@ def training_step(self, batch, batch_idx):
 
                 out = self(x)
 
-                # softmax uses only a portion of the batch in the denomintaor
+                # softmax uses only a portion of the batch in the denominator
                 loss = self.softmax(out)
                 loss = nce_loss(loss)
                 return loss
diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index 152d34cbb028d..a413db4548c26 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -338,15 +338,16 @@ def _process_training_step_output(self, training_step_output: STEP_OUTPUT) -> Op
 
         loss = None
         hiddens = None
-        results.extra = {}
 
         # handle dict return
        if isinstance(training_step_output, dict):
-            loss = training_step_output.pop("loss", None)
-            hiddens = training_step_output.pop("hiddens", None)
+            # this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
+            loss = training_step_output.get("loss")
+            hiddens = training_step_output.get("hiddens")
             # detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time`
             hiddens = apply_to_collection(hiddens, Tensor, lambda t: t.detach())
-            results.extra = training_step_output
+            # use the setter instead of `dict.update` because it calls `detach` on the tensor items
+            results.extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
 
         # handle scalar return
         elif isinstance(training_step_output, Tensor):
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index db68691ebced4..52b31b67edfa9 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from typing import Any, Dict, Iterator, List, Optional, Union
 
 import torch
@@ -276,11 +275,7 @@ def _track_epoch_end_reduce_metrics(
         # track the outputs to reduce at the end of the epoch
         for opt_idx, opt_outputs in enumerate(batch_end_outputs):
             # with 1 step (no tbptt) don't use a sequence at epoch end
-            if (
-                isinstance(opt_outputs, list)
-                and len(opt_outputs) == 1
-                and not isinstance(opt_outputs[0], ResultCollection)
-            ):
+            if isinstance(opt_outputs, list) and len(opt_outputs) == 1:
                 opt_outputs = opt_outputs[0]
 
             epoch_output[opt_idx].append(opt_outputs)
@@ -320,9 +315,10 @@ def _prepare_outputs(
                     batch_outputs = [batch_outputs]
 
                 for tbptt_output in batch_outputs:
-                    out = tbptt_output.extra
+                    out = {}
                     if tbptt_output.minimize is not None:
                         out["loss"] = tbptt_output.minimize.detach()
+                    out.update(tbptt_output.extra)
                     processed_tbptt_outputs.append(out)
 
                 # if there was only one tbptt step then we can collapse that dimension
diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py
index a9df732d9c58d..22258b8e52eea 100644
--- a/tests/trainer/loops/test_training_loop.py
+++ b/tests/trainer/loops/test_training_loop.py
@@ -163,3 +163,30 @@ def training_step_end(self, outputs):
 
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
     trainer.fit(model)
+
+
+def test_prepare_outputs(tmpdir):
+    """
+    Test that the `extra` field of the saved `ResultCollection` objects for
+    `training_epoch_end` doesn't get accidentally modified by reference.
+    """
+
+    class TestModel(BoringModel):
+        on_train_batch_end_called = 0
+
+        def on_train_batch_end(self, outputs, *args, **kwargs):
+            epoch_outputs = self.trainer.fit_loop.epoch_loop._epoch_output
+            epoch_outputs = epoch_outputs[0]  # 1 optimizer
+            assert len(epoch_outputs) == self.on_train_batch_end_called
+            # `extra` should be empty for all `ResultCollection` objects
+            assert all(not out.extra for out in epoch_outputs)
+            self.on_train_batch_end_called += 1
+
+        def training_epoch_end(self, outputs) -> None:
+            # override so epoch outputs get stored
+            pass
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
+    trainer.fit(model)
+    assert model.on_train_batch_end_called == 2
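
Reviewer note: the sketch below is a minimal, self-contained illustration of the by-reference bug this patch fixes. The names (`stored_extras`, `process_step_output`) are hypothetical and not part of Lightning; the point is only that the old code popped keys from the user's `training_step_output` dict and then stored that same dict object, so later mutations (by the user after `training_step_end`, or by `_prepare_outputs` writing `out["loss"]`) leaked into every stored reference.

    stored_extras = []

    def process_step_output(step_output: dict) -> None:
        # OLD behavior: mutate the caller's dict and keep a reference to it
        step_output.pop("loss", None)  # the caller's dict also loses "loss"!
        stored_extras.append(step_output)

    def process_step_output_fixed(step_output: dict) -> None:
        # NEW behavior: leave the caller's dict alone and store a fresh copy
        stored_extras.append({k: v for k, v in step_output.items() if k != "loss"})

    out = {"loss": 0.5, "log_value": 1}
    process_step_output(out)
    out["log_value"] = 2  # user keeps using the dict after the step
    assert stored_extras[0]["log_value"] == 2  # the stored output changed too

    out2 = {"loss": 0.5, "log_value": 1}
    process_step_output_fixed(out2)
    out2["log_value"] = 2
    assert stored_extras[1]["log_value"] == 1  # the stored copy is isolated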
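For context on the "use the setter instead of `dict.update`" comment in `training_batch_loop.py`: below is a simplified, hypothetical sketch of a property setter that detaches tensor values on assignment, which is the behavior that comment attributes to `results.extra`. `ResultsSketch` is a stand-in; the real `ResultCollection` implementation may differ in detail.

    import torch
    from torch import Tensor

    class ResultsSketch:
        """Hypothetical stand-in for the relevant part of ResultCollection."""

        def __init__(self) -> None:
            self._extra: dict = {}

        @property
        def extra(self) -> dict:
            return self._extra

        @extra.setter
        def extra(self, extra: dict) -> None:
            # detach tensor values so stored outputs don't keep the autograd
            # graph alive; a plain `self._extra.update(extra)` would skip this
            self._extra = {k: v.detach() if isinstance(v, Tensor) else v for k, v in extra.items()}

    results = ResultsSketch()
    t = torch.ones(1, requires_grad=True) * 2  # tensor attached to a graph
    results.extra = {"pred": t}
    assert not results.extra["pred"].requires_grad  # detached on assignment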