Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bugfix] Prevent `StopIteration` from being raised in the `training_step` #9409

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Executing the `optimizer_closure` is now required when overriding the `optimizer_step` hook ([#9360](https://github.com/PyTorchLightning/pytorch-lightning/pull/9360))


- Prevented `training_step` from raising `StopIteration` when `dataloader_iter` is provided to the `training_step` ([#9409](https://github.com/PyTorchLightning/pytorch-lightning/pull/9409))


### Deprecated

- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
Expand Down
39 changes: 33 additions & 6 deletions pytorch_lightning/utilities/fetching.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,20 +375,37 @@ class StepFuncDataLoaderIter:
"""This class is a wrapper to keep track of dataloader iterator fetching event while left entirely to user
control."""

def __init__(self, iterator: Iterator, data_fetcher: "AbstractDataFetcher"):
def __init__(
self,
iterator: Iterator,
parent_data_fetcher: "AbstractDataFetcher",
stage: Optional[str] = None,
batch_to_device: Optional[Callable] = None,
profiler: "Optional[pl.profiler.base.BaseProfiler]" = None,
):
self.iterator = iterator
self.data_fetcher = data_fetcher
self.parent_data_fetcher = parent_data_fetcher
self.stage = stage
self.batch_to_device = batch_to_device
self.profiler = profiler
self.pl_module_step_data_fetcher: Optional[DataFetcher] = None

def __iter__(self) -> "StepFuncDataLoaderIter":
pl_module_step_data_fetcher = DataFetcher()
pl_module_step_data_fetcher.setup(
self.iterator,
)
self.pl_module_step_data_fetcher = iter(pl_module_step_data_fetcher)
return self

def __next__(self) -> Any:
try:
data = next(self.iterator)
self.data_fetcher.fetched += 1
self.parent_data_fetcher.fetched += 1
data, _ = next(self.pl_module_step_data_fetcher)
return data
except StopIteration:
self.data_fetcher.done = True
self.parent_data_fetcher.done = True
self.pl_module_step_data_fetcher = None
raise StopIteration


Expand Down Expand Up @@ -417,9 +434,19 @@ def __init__(self):
self.store_on_device = True

def prefetching(self, prefetch_batches: int) -> None:
self.iterator = iter(StepFuncDataLoaderIter(self.dataloader_iter, self))
self.iterator = iter(
StepFuncDataLoaderIter(
parent_data_fetcher=self,
iterator=self.dataloader_iter,
stage=self.stage,
batch_to_device=self.batch_to_device,
profiler=self.profiler,
)
)

def fetching_function(self):
while not self.done:
return self.fetched, (self.iterator, self.done)
# this is used to stop the loop
self.iterator = None
raise StopIteration
23 changes: 10 additions & 13 deletions tests/utilities/test_fetching.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,35 +309,32 @@ def test_training_step_with_dataloader_access(tmpdir) -> None:
assert m.num_batches_processed == DATASET_LEN, f"Expect all {DATASET_LEN} batches to be processed."


@pytest.mark.parametrize("trigger_stop_iteration", [False, True])
def test_stop_iteration(trigger_stop_iteration, tmpdir):
def test_stop_iteration(tmpdir):
"""Verify that StopIteration properly terminates the training when this is trigged from the current
`dataloader_iter`"""
EXPECT_NUM_BATCHES_PROCESSED = 2

class TestModel(AsyncBoringModel):
def __init__(self, trigger_stop_iteration) -> None:
def __init__(self) -> None:
super().__init__()
self.trigger_stop_iteration = trigger_stop_iteration
self.has_raised = False

def training_step(self, dataloader_iter: Iterator, *args) -> STEP_OUTPUT:
output = super().training_step(dataloader_iter)
if self.trigger_stop_iteration and args[0] == EXPECT_NUM_BATCHES_PROCESSED:
try:
output = super().training_step(dataloader_iter)
except StopIteration:
self.has_raised = True
raise StopIteration
return output

def train_dataloader(self):
if self.trigger_stop_iteration:
return DataLoader(RandomDataset(BATCH_SIZE, 2 * EXPECT_NUM_BATCHES_PROCESSED))
return DataLoader(RandomDataset(BATCH_SIZE, EXPECT_NUM_BATCHES_PROCESSED))

trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
m = TestModel(trigger_stop_iteration)
m = TestModel()
trainer.fit(m)
expected = EXPECT_NUM_BATCHES_PROCESSED
if trigger_stop_iteration:
expected *= 2
assert m.num_batches_processed == expected
assert not m.has_raised
assert m.num_batches_processed == EXPECT_NUM_BATCHES_PROCESSED


def test_on_train_batch_start_overridden(tmpdir) -> None:
Expand Down