Fix replace_sampler missing the batch size under specific conditions (#9367)
carmocca authored and lexierule committed Sep 15, 2021
1 parent 00c6640 commit 564a3f5
Showing 3 changed files with 27 additions and 14 deletions.
CHANGELOG.md: 3 additions & 0 deletions
@@ -25,6 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed signature of `Timer.on_train_epoch_end` and `StochasticWeightAveraging.on_train_epoch_end` to prevent unwanted deprecation warnings ([#9347](https://github.com/PyTorchLightning/pytorch-lightning/pull/9347))


- Fixed `replace_sampler` missing the batch size under specific conditions ([#9367](https://github.com/PyTorchLightning/pytorch-lightning/pull/9367))


## [1.4.5] - 2021-08-31

- Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))
pytorch_lightning/trainer/data_loading.py: 5 additions & 1 deletion
@@ -177,6 +177,11 @@ def replace_sampler(self, dataloader: DataLoader, sampler, mode: Optional[Runnin

# get the dataloader instance `__init__` parameters
params = dict(inspect.signature(dataloader.__init__).parameters)
has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
if has_variadic_kwargs:
# if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
params.update(inspect.signature(DataLoader.__init__).parameters)
del params["self"]

# keep only the params whose default is different to the current attr value
non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
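
The hunk above is the core of the fix: when the dataloader subclass only exposes `**kwargs`, the parameters of the base `DataLoader.__init__` are merged into the inspected signature, so attributes such as `batch_size` are recognized again when the dataloader is rebuilt around the new sampler. Below is a minimal sketch of the before/after behaviour outside of Lightning; `KwargsLoader` and the `rebuild` helper are illustrative stand-ins, and the real implementation additionally filters arguments against their signature defaults:

import inspect
from torch.utils.data import DataLoader

class KwargsLoader(DataLoader):
    # hypothetical subclass: forwards everything through **kwargs, so its own
    # signature never mentions `batch_size`
    def __init__(self, dataset, **kwargs):
        super().__init__(dataset, **kwargs)

def rebuild(dataloader, merge_base_params):
    # simplified re-instantiation: keep only the attributes the signature knows about
    attrs = {"dataset": dataloader.dataset, "batch_size": dataloader.batch_size}
    params = dict(inspect.signature(dataloader.__init__).parameters)
    if merge_base_params:
        # the behaviour this commit adds when **kwargs is present
        params.update(inspect.signature(DataLoader.__init__).parameters)
        params.pop("self", None)
    dl_kwargs = {name: attrs[name] for name in params if name in attrs}
    return type(dataloader)(**dl_kwargs)

loader = KwargsLoader(list(range(32)), batch_size=4)
print(rebuild(loader, merge_base_params=False).batch_size)  # 1 -> the batch size is dropped
print(rebuild(loader, merge_base_params=True).batch_size)   # 4 -> preserved with the fix

Without the merge, `batch_size` never appears in the subclass signature, so the rebuilt loader silently falls back to the default of 1, which is the regression this commit fixes.
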
@@ -207,7 +212,6 @@ def replace_sampler(self, dataloader: DataLoader, sampler, mode: Optional[Runnin
f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
)

has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
if not has_variadic_kwargs:
# the dataloader signature does not allow keyword arguments that need to be passed
missing_kwargs = dl_kwargs.keys() - params.keys()
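
The second hunk only relocates the `has_variadic_kwargs` computation (it is now done right after the parameters are collected, as shown above) and keeps the existing guard for subclasses without `**kwargs`: any keyword argument the signature cannot accept is reported to the user with the error message partially visible above. A rough sketch of that guard, with a hypothetical `FixedSignatureLoader` standing in for such a subclass:

import inspect
from torch.utils.data import DataLoader

class FixedSignatureLoader(DataLoader):
    # hypothetical subclass with a closed signature: no `sampler`, no **kwargs
    def __init__(self, dataset, batch_size=1):
        super().__init__(dataset, batch_size=batch_size)

# keyword arguments Lightning would need to pass when swapping in a DistributedSampler
dl_kwargs = {"dataset": list(range(8)), "batch_size": 2, "sampler": None}

params = dict(inspect.signature(FixedSignatureLoader.__init__).parameters)
has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
if not has_variadic_kwargs:
    missing_kwargs = dl_kwargs.keys() - params.keys()
    print(missing_kwargs)  # {'sampler'} -> Lightning raises here with the message shown above
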
tests/trainer/test_data_loading.py: 19 additions & 13 deletions
@@ -27,7 +27,7 @@
@pytest.mark.skipif(
sys.platform == "win32" and not _TORCH_GREATER_EQUAL_1_7, reason="Bad `torch.distributed` support on Windows"
)
@pytest.mark.parametrize("mode", (1, 2))
@pytest.mark.parametrize("mode", (1, 2, 3))
def test_replace_distributed_sampler(tmpdir, mode):
class IndexedRandomDataset(RandomDataset):
def __getitem__(self, index):
@@ -57,25 +57,31 @@ def test_step(self, batch, batch_idx, dataloader_idx=None):
def on_test_start(self) -> None:
dataloader = self.trainer.test_dataloaders[0]
assert isinstance(dataloader, CustomDataLoader)
assert dataloader.batch_size is None

batch_sampler = dataloader.batch_sampler
assert isinstance(batch_sampler, CustomBatchSampler)
assert batch_sampler.batch_size == 1
if self._mode == 2:
assert isinstance(batch_sampler, CustomBatchSampler)
# the batch_size is set on the batch sampler
assert dataloader.batch_size is None
elif self._mode == 3:
assert type(batch_sampler) is BatchSampler
assert dataloader.batch_size == self._mode
assert batch_sampler.batch_size == self._mode
assert batch_sampler.drop_last
# the sampler has been replaced
assert isinstance(batch_sampler.sampler, DistributedSampler)

def create_dataset(self):
dataset = IndexedRandomDataset(32, 64)
batch_sampler = None
batch_size = 2
if self._mode == 1:
# this case will raise an error
return FailureCustomDataLoader(32, dataset)
if self._mode == 2:
batch_size = 1
batch_sampler = CustomBatchSampler(SequentialSampler(dataset), batch_size=batch_size, drop_last=True)
dataloader_cls = CustomDataLoader
else:
dataloader_cls = FailureCustomDataLoader
return dataloader_cls(32, dataset, batch_size=batch_size, batch_sampler=batch_sampler)
# with a custom batch sampler
batch_sampler = CustomBatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=True)
return CustomDataLoader(32, dataset, batch_sampler=batch_sampler)
elif self._mode == 3:
# with no batch sampler provided
return CustomDataLoader(32, dataset, batch_size=3, drop_last=True)

def test_dataloader(self):
return [self.create_dataset()] * self._numbers_test_dataloaders
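
The updated test adds a third mode covering exactly the fixed scenario: a `DataLoader` subclass that receives a plain `batch_size` and no custom batch sampler, which must survive the sampler replacement. The helper classes are not shown in the hunk, so the sketch below uses minimal analogues assumed here (`RandomDataset`, `CustomDataLoader`) to illustrate what the mode-3 assertions check before Lightning swaps in the `DistributedSampler`:

import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset

class RandomDataset(Dataset):
    # minimal stand-in for the test's IndexedRandomDataset
    def __init__(self, size, length):
        self.data = torch.randn(length, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

class CustomDataLoader(DataLoader):
    # assumed analogue of the test helper: an extra positional argument plus pass-through **kwargs
    def __init__(self, num_features, dataset, **kwargs):
        self.num_features = num_features
        super().__init__(dataset, **kwargs)

# mode 3: no batch sampler supplied, so batch_size/drop_last live on the dataloader itself
loader = CustomDataLoader(32, RandomDataset(32, 64), batch_size=3, drop_last=True)
assert loader.batch_size == 3
assert type(loader.batch_sampler) is BatchSampler
assert loader.batch_sampler.batch_size == 3 and loader.batch_sampler.drop_last
# after Lightning rebuilds the loader for DDP, the same values must still hold,
# with loader.batch_sampler.sampler replaced by a DistributedSampler

Mode 1 still exercises the error path and mode 2 the custom-batch-sampler path, so the three modes together cover the guard, the pre-existing behaviour, and the newly fixed case.
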
