From f37cc21fced5645e3489a3800389417e27a646e8 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Thu, 29 Jul 2021 16:45:21 +0200
Subject: [PATCH 1/5] Fix references for `ResultCollection.extra`

---
 .../loops/batch/training_batch_loop.py    |  1 -
 .../loops/epoch/training_epoch_loop.py    | 10 +++----
 tests/trainer/loops/test_training_loop.py | 27 +++++++++++++++++++
 3 files changed, 30 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index 152d34cbb028d..3cb59bef20ace 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -338,7 +338,6 @@ def _process_training_step_output(self, training_step_output: STEP_OUTPUT) -> Op
 
         loss = None
         hiddens = None
-        results.extra = {}
 
         # handle dict return
         if isinstance(training_step_output, dict):
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index db68691ebced4..52b31b67edfa9 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from typing import Any, Dict, Iterator, List, Optional, Union
 
 import torch
@@ -276,11 +275,7 @@ def _track_epoch_end_reduce_metrics(
         # track the outputs to reduce at the end of the epoch
         for opt_idx, opt_outputs in enumerate(batch_end_outputs):
             # with 1 step (no tbptt) don't use a sequence at epoch end
-            if (
-                isinstance(opt_outputs, list)
-                and len(opt_outputs) == 1
-                and not isinstance(opt_outputs[0], ResultCollection)
-            ):
+            if isinstance(opt_outputs, list) and len(opt_outputs) == 1:
                 opt_outputs = opt_outputs[0]
             epoch_output[opt_idx].append(opt_outputs)
 
@@ -320,9 +315,10 @@ def _prepare_outputs(
                     batch_outputs = [batch_outputs]
 
                 for tbptt_output in batch_outputs:
-                    out = tbptt_output.extra
+                    out = {}
                     if tbptt_output.minimize is not None:
                         out["loss"] = tbptt_output.minimize.detach()
+                    out.update(tbptt_output.extra)
                     processed_tbptt_outputs.append(out)
 
                 # if there was only one tbptt step then we can collapse that dimension
diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py
index a9df732d9c58d..22258b8e52eea 100644
--- a/tests/trainer/loops/test_training_loop.py
+++ b/tests/trainer/loops/test_training_loop.py
@@ -163,3 +163,30 @@ def training_step_end(self, outputs):
 
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
     trainer.fit(model)
+
+
+def test_prepare_outputs(tmpdir):
+    """
+    Test that the `extra` field of the saved `ResultCollection` objects for
+    `training_epoch_end` doesn't get accidentally modified by reference.
+    """
+
+    class TestModel(BoringModel):
+        on_train_batch_end_called = 0
+
+        def on_train_batch_end(self, outputs, *args, **kwargs):
+            epoch_outputs = self.trainer.fit_loop.epoch_loop._epoch_output
+            epoch_outputs = epoch_outputs[0]  # 1 optimizer
+            assert len(epoch_outputs) == self.on_train_batch_end_called
+            # `extra` should be empty for all `ResultCollection` objects
+            assert all(not out.extra for out in epoch_outputs)
+            self.on_train_batch_end_called += 1
+
+        def training_epoch_end(self, outputs) -> None:
+            # override so epoch outputs get stored
+            pass
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
+    trainer.fit(model)
+    assert model.on_train_batch_end_called == 2

From 34c4f236338f50d15a05a26b9ec57419b4f1d039 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Thu, 29 Jul 2021 16:51:29 +0200
Subject: [PATCH 2/5] Update CHANGELOG

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a2806602291e7..e6398c3db3811 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -77,7 +77,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 -
 
--
+- Fixed references for ResultCollection.extra ([#8621](https://github.com/PyTorchLightning/pytorch-lightning/pull/8621))
 
 - Fixed `trainer.fit_loop.split_idx` always returning `None` ([#8601](https://github.com/PyTorchLightning/pytorch-lightning/pull/8601))
 

From 62f4eb52c835e6d99a7d8fad6f408cff615b86f6 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Thu, 29 Jul 2021 17:01:14 +0200
Subject: [PATCH 3/5] Update CHANGELOG

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e6398c3db3811..3e014b2b0274f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -77,7 +77,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 -
 
-- Fixed references for ResultCollection.extra ([#8621](https://github.com/PyTorchLightning/pytorch-lightning/pull/8621))
+- Fix reference issues during epoch end result collection ([#8621](https://github.com/PyTorchLightning/pytorch-lightning/pull/8621))
 
 - Fixed `trainer.fit_loop.split_idx` always returning `None` ([#8601](https://github.com/PyTorchLightning/pytorch-lightning/pull/8601))
 

From d5469191e07a78e84d2d3a10ac9b7bb9def06653 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Thu, 29 Jul 2021 18:06:33 +0200
Subject: [PATCH 4/5] Reference fix

---
 pytorch_lightning/core/lightning.py                  | 2 +-
 pytorch_lightning/loops/batch/training_batch_loop.py | 8 +++++---
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index cbe4b4e86ff26..6aba17ec7b713 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -750,7 +750,7 @@ def training_step(self, batch, batch_idx):
 
                 out = self(x)
 
-                # softmax uses only a portion of the batch in the denomintaor
+                # softmax uses only a portion of the batch in the denominator
                 loss = self.softmax(out)
                 loss = nce_loss(loss)
                 return loss
diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index 3cb59bef20ace..a413db4548c26 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -341,11 +341,13 @@ def _process_training_step_output(self, training_step_output: STEP_OUTPUT) -> Op
 
         # handle dict return
         if isinstance(training_step_output, dict):
-            loss = training_step_output.pop("loss", None)
-            hiddens = training_step_output.pop("hiddens", None)
+            # this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
+            loss = training_step_output.get("loss")
+            hiddens = training_step_output.get("hiddens")
             # detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time`
             hiddens = apply_to_collection(hiddens, Tensor, lambda t: t.detach())
-            results.extra = training_step_output
+            # use the setter instead of `dict.update` because it calls `detach` on the tensor items
+            results.extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
 
         # handle scalar return
         elif isinstance(training_step_output, Tensor):

From df11b493c7b2a604fc727aaeb2521df6279514b7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Fri, 30 Jul 2021 02:39:58 +0200
Subject: [PATCH 5/5] Update CHANGELOG.md
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli
---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3e014b2b0274f..d740e9228db97 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -77,7 +77,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 -
 
-- Fix reference issues during epoch end result collection ([#8621](https://github.com/PyTorchLightning/pytorch-lightning/pull/8621))
+- Fixed reference issues during epoch end result collection ([#8621](https://github.com/PyTorchLightning/pytorch-lightning/pull/8621))
 
 - Fixed `trainer.fit_loop.split_idx` always returning `None` ([#8601](https://github.com/PyTorchLightning/pytorch-lightning/pull/8601))
 
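Note: the aliasing pitfall addressed by patches 1 and 4 can be illustrated with plain Python dicts. This is a minimal sketch for illustration only; the names `stored_extra` and `out` are made up and do not come from the Lightning codebase.

    # Before: reusing the stored dict aliases it, so later mutations leak back.
    stored_extra = {"some_metric": 1.0}
    out = stored_extra
    out["loss"] = 0.5
    assert "loss" in stored_extra  # the stored dict was modified by reference

    # After: start from a fresh dict and copy the stored values into it.
    stored_extra = {"some_metric": 1.0}
    out = {"loss": 0.5}
    out.update(stored_extra)
    assert "loss" not in stored_extra  # the stored dict is left untouched

The same reasoning applies to the user's `training_step_output` dict in patch 4: reading with `dict.get` and building a new dict for `results.extra` leaves the dict returned from `training_step` unmodified.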