
Commit

[bugfix] Prevent on_before_batch_transfer to be called twice (#9715)
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
3 people authored Sep 27, 2021
1 parent 64bbebc · commit 131176b
Showing 2 changed files with 50 additions and 3 deletions.
2 changes: 0 additions & 2 deletions pytorch_lightning/utilities/fetching.py
@@ -305,8 +305,6 @@ def _consume_prefetched_batches(self) -> Generator:
     def _get_queued_batch(self) -> Tuple[Any, bool]:
         self.wait()
         batch = self.batches.pop(0)
-        if not self.store_on_device:
-            batch = self.move_data_to_device(batch)
         is_last = len(self.batches) == 0
         return batch, is_last
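
The two deleted lines sent each queued batch through move_data_to_device, and the training loop performs its own device transfer as well, so it appears the same batch could run through the batch-transfer hooks twice. Below is a minimal, self-contained sketch of that double-call pattern, under the assumption that the fetcher's move_data_to_device and the training loop ultimately invoke the same batch_to_device callable; ToyFetcher, run, and batch_to_device are hypothetical names for illustration, not Lightning internals.

class ToyFetcher:
    def __init__(self, batch_to_device, move_in_fetcher):
        self.batch_to_device = batch_to_device
        self.move_in_fetcher = move_in_fetcher
        self.batches = [1, 2]

    def _get_queued_batch(self):
        batch = self.batches.pop(0)
        if self.move_in_fetcher:  # stand-in for the code path removed by this commit
            batch = self.batch_to_device(batch)
        return batch, len(self.batches) == 0


def run(move_in_fetcher):
    calls = 0

    def batch_to_device(batch):
        nonlocal calls
        calls += 1  # stands in for the on_before/transfer/on_after batch hooks
        return batch

    fetcher = ToyFetcher(batch_to_device, move_in_fetcher)
    done = False
    while not done:
        batch, done = fetcher._get_queued_batch()
        batch_to_device(batch)  # the training loop's own device transfer
    return calls


assert run(move_in_fetcher=True) == 4   # two hook calls per batch: the bug
assert run(move_in_fetcher=False) == 2  # one hook call per batch: the fixed behaviour

With the branch removed, only the training loop triggers the transfer, which matches the counts asserted in the new test below.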

51 changes: 50 additions & 1 deletion tests/utilities/test_fetching.py
@@ -21,7 +21,7 @@
 from torch import tensor
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 
-from pytorch_lightning import Callback, Trainer
+from pytorch_lightning import Callback, LightningDataModule, Trainer
 from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher
@@ -392,3 +392,52 @@ def __init__(self) -> None:
     m = InvalidModel()
     with pytest.raises(MisconfigurationException, match="is incompatible with `truncated_bptt_steps > 0`."):
         trainer.fit(m)
+
+
+def test_transfer_hooks_with_unpacking(tmpdir):
+
+    """This test asserts the `transfer_batch` hooks are called only once per batch."""
+
+    class RandomDictDataset(RandomDataset):
+        def __getitem__(self, index):
+            return {"x": self.data[index], "y_true": torch.ones((2,)), "other": torch.ones((1,))}
+
+    class BoringDataModule(LightningDataModule):
+
+        count_called_on_before_batch_transfer = 0
+        count_called_transfer_batch_to_device = 0
+        count_called_on_after_batch_transfer = 0
+
+        def train_dataloader(self):
+            return DataLoader(RandomDictDataset(32, 2))
+
+        def val_dataloader(self):
+            return DataLoader(RandomDictDataset(32, 2))
+
+        def on_before_batch_transfer(self, batch, dataloader_idx: int):
+            self.count_called_on_before_batch_transfer += 1
+            return batch["x"], batch["y_true"]
+
+        def transfer_batch_to_device(self, *args, **kwargs):
+            self.count_called_transfer_batch_to_device += 1
+            return super().transfer_batch_to_device(*args, **kwargs)
+
+        def on_after_batch_transfer(self, batch, dataloader_idx: int):
+            self.count_called_on_after_batch_transfer += 1
+            return super().on_after_batch_transfer(batch, dataloader_idx)
+
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            x, _ = batch
+            return super().training_step(x, batch_idx)
+
+        def validation_step(self, batch, batch_idx):
+            x, _ = batch
+            return super().validation_step(x, batch_idx)
+
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, num_sanity_val_steps=0)
+    dm = BoringDataModule()
+    trainer.fit(TestModel(), datamodule=dm)
+    assert dm.count_called_on_before_batch_transfer == 4
+    assert dm.count_called_transfer_batch_to_device == 4
+    assert dm.count_called_on_after_batch_transfer == 4
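
A note on the expected counts: RandomDictDataset(32, 2) has length 2, so with the DataLoader's default batch size of 1 a single epoch runs two training batches and two validation batches, and each transfer hook should fire exactly four times; num_sanity_val_steps=0 keeps the sanity check from adding extra calls. Assuming a standard Lightning development checkout, the test can be run in isolation with `pytest tests/utilities/test_fetching.py -k test_transfer_hooks_with_unpacking`.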
