Fix failure when DataLoader(batch_size=None) is passed (#10345)
Co-authored-by: Carlos Mocholí <[email protected]>
2 people authored and lexierule committed Nov 9, 2021
1 parent 7d87986 commit 2685aa6
Showing 3 changed files with 25 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md

@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed

 - Fixed `apply_to_collection(defaultdict)` ([#10316](https://github.com/PyTorchLightning/pytorch-lightning/issues/10316))
+- Fixed failure when `DataLoader(batch_size=None)` is passed ([#10345](https://github.com/PyTorchLightning/pytorch-lightning/issues/10345))


 ## [1.5.0] - 2021-11-02
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/data_loading.py

@@ -184,7 +184,7 @@ def _dataloader_init_kwargs_resolve_sampler(
     batch_sampler = getattr(dataloader, "batch_sampler")
     is_predicting = mode == RunningStage.PREDICTING
     # checking the batch sampler type is different than PyTorch default.
-    if (batch_sampler is not None and type(batch_sampler) is not BatchSampler) or is_predicting:
+    if batch_sampler is not None and (type(batch_sampler) is not BatchSampler or is_predicting):
         batch_sampler = type(batch_sampler)(
             sampler,
             batch_size=batch_sampler.batch_size,
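Why the regrouping matters: with `DataLoader(batch_size=None)`, PyTorch disables automatic batching and the loader's `batch_sampler` is `None`, yet the old condition still entered the rewrap branch whenever `is_predicting` was true and then dereferenced `batch_sampler.batch_size` on `None`. A minimal sketch of the two predicates (plain Python; the standalone `old_condition`/`new_condition` helpers are illustrative, not Lightning API):

```python
from torch.utils.data import BatchSampler

def old_condition(batch_sampler, is_predicting):
    # `or is_predicting` sits outside the parentheses, so prediction
    # entered the branch even with no batch sampler at all.
    return (batch_sampler is not None and type(batch_sampler) is not BatchSampler) or is_predicting

def new_condition(batch_sampler, is_predicting):
    # A missing batch sampler now short-circuits the whole check.
    return batch_sampler is not None and (type(batch_sampler) is not BatchSampler or is_predicting)

# DataLoader(batch_size=None) leaves dataloader.batch_sampler as None:
assert old_condition(None, is_predicting=True)      # True -> batch_sampler.batch_size raised AttributeError
assert not new_condition(None, is_predicting=True)  # False -> the loader passes through untouched
```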
38 changes: 23 additions & 15 deletions tests/trainer/test_data_loading.py

@@ -283,25 +283,26 @@ class CustomSampler(Sampler):
         trainer.prepare_dataloader(dataloader, shuffle=True)


-def test_loader_detaching():
-    """Checks that the loader has been reset after the entrypoint."""
+class LoaderTestModel(BoringModel):
+    def training_step(self, batch, batch_idx):
+        assert len(self.trainer.train_dataloader.loaders) == 10
+        return super().training_step(batch, batch_idx)

-    class LoaderTestModel(BoringModel):
-        def training_step(self, batch, batch_idx):
-            assert len(self.trainer.train_dataloader.loaders) == 10
-            return super().training_step(batch, batch_idx)
+    def validation_step(self, batch, batch_idx):
+        assert len(self.trainer.val_dataloaders[0]) == 10
+        return super().validation_step(batch, batch_idx)

-        def validation_step(self, batch, batch_idx):
-            assert len(self.trainer.val_dataloaders[0]) == 10
-            return super().validation_step(batch, batch_idx)
+    def test_step(self, batch, batch_idx):
+        assert len(self.trainer.test_dataloaders[0]) == 10
+        return super().test_step(batch, batch_idx)

-        def test_step(self, batch, batch_idx):
-            assert len(self.trainer.test_dataloaders[0]) == 10
-            return super().test_step(batch, batch_idx)
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        assert len(self.trainer.predict_dataloaders[0]) == 10
+        return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)

-        def predict_step(self, batch, batch_idx, dataloader_idx=None):
-            assert len(self.trainer.predict_dataloaders[0]) == 10
-            return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
+
+def test_loader_detaching():
+    """Checks that the loader has been reset after the entrypoint."""

     loader = DataLoader(RandomDataset(32, 10), batch_size=1)

@@ -340,3 +341,10 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
     assert len(model.val_dataloader()) == 64
     assert len(model.predict_dataloader()) == 64
     assert len(model.test_dataloader()) == 64
+
+
+def test_pre_made_batches():
+    """Check that loader works with pre-made batches."""
+    loader = DataLoader(RandomDataset(32, 10), batch_size=None)
+    trainer = Trainer(fast_dev_run=1)
+    trainer.predict(LoaderTestModel(), loader)

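For background, `batch_size=None` turns off PyTorch's automatic batching: each dataset item is treated as a ready-made batch and `DataLoader.batch_sampler` stays `None`, which is exactly the case the fixed condition now short-circuits on. A standalone sketch in plain PyTorch (the `PreBatchedDataset` helper here is made up for illustration):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class PreBatchedDataset(Dataset):
    """Every item is already a full batch of shape (batch, features)."""

    def __init__(self, num_batches=10, batch_size=4, num_features=32):
        self.batches = [torch.randn(batch_size, num_features) for _ in range(num_batches)]

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, idx):
        return self.batches[idx]

loader = DataLoader(PreBatchedDataset(), batch_size=None)  # automatic batching disabled
assert loader.batch_sampler is None  # the attribute the old condition ended up dereferencing
for batch in loader:
    assert batch.shape == (4, 32)  # batches pass through without an extra batch dimension
```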