From 9cbeb5ff3b70be649030770e57ce5bca280610f1 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Thu, 9 Sep 2021 18:32:54 +0100
Subject: [PATCH 1/4] update

---
 pytorch_lightning/utilities/fetching.py | 41 ++++++++++++++++++++-----
 tests/utilities/test_fetching.py        | 23 ++++++--------
 2 files changed, 44 insertions(+), 20 deletions(-)

diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py
index f0f09401ab47e..0c088a8f8b92b 100644
--- a/pytorch_lightning/utilities/fetching.py
+++ b/pytorch_lightning/utilities/fetching.py
@@ -375,20 +375,37 @@ class StepFuncDataLoaderIter:
     """This class is a wrapper to keep track of dataloader iterator fetching event while left entirely to user
     control."""
 
-    def __init__(self, iterator: Iterator, data_fetcher: "AbstractDataFetcher"):
+    def __init__(
+        self,
+        iterator: Iterator,
+        parent_data_fetcher: "AbstractDataFetcher",
+        stage: Optional[str] = None,
+        batch_to_device: Optional[Callable] = None,
+        profiler: "Optional[pl.profiler.base.BaseProfiler]" = None,
+    ):
         self.iterator = iterator
-        self.data_fetcher = data_fetcher
+        self.parent_data_fetcher = parent_data_fetcher
+        self.stage = stage
+        self.batch_to_device = batch_to_device
+        self.profiler = profiler
+        self.training_step_data_fetcher: Optional[DataFetcher] = None
 
     def __iter__(self) -> "StepFuncDataLoaderIter":
+        training_step_data_fetcher = DataFetcher()
+        training_step_data_fetcher.setup(
+            self.iterator,
+        )
+        self.training_step_data_fetcher = iter(training_step_data_fetcher)
         return self
 
     def __next__(self) -> Any:
         try:
-            data = next(self.iterator)
-            self.data_fetcher.fetched += 1
-            return data
+            self.parent_data_fetcher.fetched += 1
+            data = next(self.training_step_data_fetcher)
+            return data[0]
         except StopIteration:
-            self.data_fetcher.done = True
+            self.parent_data_fetcher.done = True
+            self.training_step_data_fetcher = None
             raise StopIteration
 
 
@@ -417,9 +434,19 @@ def __init__(self):
         self.store_on_device = True
 
     def prefetching(self, prefetch_batches: int) -> None:
-        self.iterator = iter(StepFuncDataLoaderIter(self.dataloader_iter, self))
+        self.iterator = iter(
+            StepFuncDataLoaderIter(
+                parent_data_fetcher=self,
+                iterator=self.dataloader_iter,
+                stage=self.stage,
+                batch_to_device=self.batch_to_device,
+                profiler=self.profiler,
+            )
+        )
 
     def fetching_function(self):
         while not self.done:
             return self.fetched, (self.iterator, self.done)
+        # this is used to stop the loop
+        self.iterator = None
         raise StopIteration
diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py
index 86d04af9d2eb3..4d7cff5f6306a 100644
--- a/tests/utilities/test_fetching.py
+++ b/tests/utilities/test_fetching.py
@@ -309,35 +309,32 @@ def test_training_step_with_dataloader_access(tmpdir) -> None:
     assert m.num_batches_processed == DATASET_LEN, f"Expect all {DATASET_LEN} batches to be processed."
 
 
-@pytest.mark.parametrize("trigger_stop_iteration", [False, True])
-def test_stop_iteration(trigger_stop_iteration, tmpdir):
+def test_stop_iteration(tmpdir):
     """Verify that StopIteration properly terminates the training when this is trigged from the current
     `dataloader_iter`"""
     EXPECT_NUM_BATCHES_PROCESSED = 2
 
     class TestModel(AsyncBoringModel):
-        def __init__(self, trigger_stop_iteration) -> None:
+        def __init__(self) -> None:
             super().__init__()
-            self.trigger_stop_iteration = trigger_stop_iteration
+            self.has_raised = False
 
         def training_step(self, dataloader_iter: Iterator, *args) -> STEP_OUTPUT:
-            output = super().training_step(dataloader_iter)
-            if self.trigger_stop_iteration and args[0] == EXPECT_NUM_BATCHES_PROCESSED:
+            try:
+                output = super().training_step(dataloader_iter)
+            except StopIteration:
+                self.has_raised = True
                 raise StopIteration
             return output
 
         def train_dataloader(self):
-            if self.trigger_stop_iteration:
-                return DataLoader(RandomDataset(BATCH_SIZE, 2 * EXPECT_NUM_BATCHES_PROCESSED))
             return DataLoader(RandomDataset(BATCH_SIZE, EXPECT_NUM_BATCHES_PROCESSED))
 
     trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
-    m = TestModel(trigger_stop_iteration)
+    m = TestModel()
     trainer.fit(m)
-    expected = EXPECT_NUM_BATCHES_PROCESSED
-    if trigger_stop_iteration:
-        expected *= 2
-    assert m.num_batches_processed == expected
+    assert not m.has_raised
+    assert m.num_batches_processed == EXPECT_NUM_BATCHES_PROCESSED
 
 
 def test_on_train_batch_start_overridden(tmpdir) -> None:

From d9322a4f28af3d340acd86ce3b69678f0052839d Mon Sep 17 00:00:00 2001
From: tchaton
Date: Thu, 9 Sep 2021 18:36:42 +0100
Subject: [PATCH 2/4] update changelog

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index bb9a9f1db8be3..d25cb9fa5299f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -166,6 +166,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Executing the `optimizer_closure` is now required when overriding the `optimizer_step` hook ([#9360](https://github.com/PyTorchLightning/pytorch-lightning/pull/9360))
 
+- Prevent `training_step` from raising `StopIteration` when `dataloader_iter` is provided to `training_step` ([#9409](https://github.com/PyTorchLightning/pytorch-lightning/pull/9409))
+
+
 ### Deprecated
 
 - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`

From 46c83398bcbbb77e60880b4cc1c9a0dffb7421af Mon Sep 17 00:00:00 2001
From: tchaton
Date: Thu, 9 Sep 2021 18:37:58 +0100
Subject: [PATCH 3/4] update

---
 pytorch_lightning/utilities/fetching.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py
index 0c088a8f8b92b..7c2c506aa4d62 100644
--- a/pytorch_lightning/utilities/fetching.py
+++ b/pytorch_lightning/utilities/fetching.py
@@ -388,24 +388,24 @@ def __init__(
         self.stage = stage
         self.batch_to_device = batch_to_device
         self.profiler = profiler
-        self.training_step_data_fetcher: Optional[DataFetcher] = None
+        self.pl_module_step_data_fetcher: Optional[DataFetcher] = None
 
     def __iter__(self) -> "StepFuncDataLoaderIter":
-        training_step_data_fetcher = DataFetcher()
-        training_step_data_fetcher.setup(
+        pl_module_step_data_fetcher = DataFetcher()
+        pl_module_step_data_fetcher.setup(
             self.iterator,
         )
-        self.training_step_data_fetcher = iter(training_step_data_fetcher)
+        self.pl_module_step_data_fetcher = iter(pl_module_step_data_fetcher)
         return self
 
     def __next__(self) -> Any:
         try:
             self.parent_data_fetcher.fetched += 1
-            data = next(self.training_step_data_fetcher)
+            data = next(self.pl_module_step_data_fetcher)
             return data[0]
         except StopIteration:
             self.parent_data_fetcher.done = True
-            self.training_step_data_fetcher = None
+            self.pl_module_step_data_fetcher = None
             raise StopIteration
 
 

From 6eb232d999ce264079a6b004916b80bf1de94890 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Thu, 9 Sep 2021 18:40:32 +0100
Subject: [PATCH 4/4] update

---
 pytorch_lightning/utilities/fetching.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py
index 7c2c506aa4d62..4dd78244e9711 100644
--- a/pytorch_lightning/utilities/fetching.py
+++ b/pytorch_lightning/utilities/fetching.py
@@ -401,8 +401,8 @@ def __iter__(self) -> "StepFuncDataLoaderIter":
     def __next__(self) -> Any:
         try:
             self.parent_data_fetcher.fetched += 1
-            data = next(self.pl_module_step_data_fetcher)
-            return data[0]
+            data, _ = next(self.pl_module_step_data_fetcher)
+            return data
         except StopIteration:
             self.parent_data_fetcher.done = True
             self.pl_module_step_data_fetcher = None
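
Editor's note: the sketch below is not part of the patch series. It illustrates, for review purposes, the user-facing `dataloader_iter` flavor of `training_step` that these patches exercise: the LightningModule pulls batches itself with `next()`, while the fetcher counts fetches and stops the training loop instead of letting `StopIteration` escape into `training_step` (the behavior asserted by `test_stop_iteration`). The module and dataset names here are hypothetical; the signature mirrors `AsyncBoringModel` in the tests above, and the exact Trainer wiring in this Lightning version is assumed.

    import torch
    from torch.utils.data import DataLoader, Dataset
    from pytorch_lightning import LightningModule, Trainer


    class RandomDataset(Dataset):
        # tiny stand-in dataset, analogous to the RandomDataset helper used in the tests
        def __init__(self, size: int, length: int):
            self.data = torch.randn(length, size)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]


    class IteratorModel(LightningModule):
        # hypothetical module: `training_step` receives the dataloader iterator itself
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)
            self.num_batches_processed = 0

        def training_step(self, dataloader_iter, *args):
            # with these patches, exhaustion is handled by the fetcher, so no
            # StopIteration is expected to surface here
            batch = next(dataloader_iter)
            self.num_batches_processed += 1
            return {"loss": self.layer(batch).sum()}

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

        def train_dataloader(self):
            return DataLoader(RandomDataset(32, 4), batch_size=2)


    if __name__ == "__main__":
        trainer = Trainer(max_epochs=1)
        trainer.fit(IteratorModel())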