Fix replace_sampler missing the batch size under specific conditions #9367

Merged: 2 commits, Sep 8, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -315,6 +315,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed incorrect main progress bar indicator when resuming training mid-epoch ([#9310](https://github.com/PyTorchLightning/pytorch-lightning/pull/9310))


- Fixed `replace_sampler` missing the batch size under specific conditions ([#9367](https://github.com/PyTorchLightning/pytorch-lightning/pull/9367))


## [1.4.5] - 2021-08-31

- Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))
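For context on the "specific conditions": the fix below targets `DataLoader` subclasses whose `__init__` accepts `**kwargs` and forwards them to `super().__init__()`, so `batch_size` never shows up in the subclass signature that the sampler-replacement logic inspects. A minimal sketch of such a loader (the class and dataset names are illustrative, not taken from this PR):

```python
import inspect

from torch.utils.data import DataLoader, Dataset


class TinyDataset(Dataset):
    # placeholder dataset, only needed to construct a loader
    def __len__(self):
        return 64

    def __getitem__(self, index):
        return index


class KwargsOnlyDataLoader(DataLoader):
    # hypothetical subclass: every keyword, including `batch_size`,
    # travels through **kwargs down to DataLoader.__init__
    def __init__(self, dataset, **kwargs):
        super().__init__(dataset, **kwargs)


loader = KwargsOnlyDataLoader(TinyDataset(), batch_size=4)
print(loader.batch_size)  # 4, stored by the base DataLoader
# ...but `batch_size` is not a named parameter of the subclass signature:
print("batch_size" in inspect.signature(KwargsOnlyDataLoader.__init__).parameters)  # False
```

Before this patch, the keyword arguments reconstructed for such a loader were missing `batch_size`, which is the gap the diff below addresses.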
6 changes: 5 additions & 1 deletion pytorch_lightning/trainer/data_loading.py
@@ -201,6 +201,11 @@ def _get_dataloader_init_kwargs(

        # get the dataloader instance `__init__` parameters
        params = dict(inspect.signature(dataloader.__init__).parameters)
        has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
        if has_variadic_kwargs:
            # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
            params.update(inspect.signature(DataLoader.__init__).parameters)
            del params["self"]

        # keep only the params whose default is different to the current attr value
        non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
@@ -231,7 +236,6 @@ def _get_dataloader_init_kwargs(
                f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
            )

        has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
        if not has_variadic_kwargs:
            # the dataloader signature does not allow keyword arguments that need to be passed
            missing_kwargs = dl_kwargs.keys() - params.keys()
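The new introspection can be exercised on its own. A rough standalone sketch of the same pattern, using a hypothetical `KwargsLoader` subclass rather than anything from the PR:

```python
import inspect

from torch.utils.data import DataLoader


class KwargsLoader(DataLoader):
    # hypothetical subclass with a variadic signature
    def __init__(self, dataset, **kwargs):
        super().__init__(dataset, **kwargs)


params = dict(inspect.signature(KwargsLoader.__init__).parameters)
has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
print(has_variadic_kwargs)     # True: the signature ends in **kwargs
print("batch_size" in params)  # False: the subclass signature hides it

if has_variadic_kwargs:
    # same assumption as the patch: **kwargs are forwarded to DataLoader.__init__,
    # so its parameters (batch_size, shuffle, drop_last, ...) become valid keywords too
    params.update(inspect.signature(DataLoader.__init__).parameters)
    del params["self"]

print("batch_size" in params)  # True: batch_size can be passed when re-creating the loader
```

The line removed in the second hunk is the same `has_variadic_kwargs` check, previously computed later in the function; it now runs earlier so the merged parameter list is available to the rest of the function.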
32 changes: 19 additions & 13 deletions tests/trainer/test_data_loading.py
@@ -27,7 +27,7 @@
@pytest.mark.skipif(
    sys.platform == "win32" and not _TORCH_GREATER_EQUAL_1_7, reason="Bad `torch.distributed` support on Windows"
)
@pytest.mark.parametrize("mode", (1, 2))
@pytest.mark.parametrize("mode", (1, 2, 3))
def test_replace_distributed_sampler(tmpdir, mode):
    class IndexedRandomDataset(RandomDataset):
        def __getitem__(self, index):
@@ -57,25 +57,31 @@ def test_step(self, batch, batch_idx, dataloader_idx=None):
        def on_test_start(self) -> None:
            dataloader = self.trainer.test_dataloaders[0]
            assert isinstance(dataloader, CustomDataLoader)
            assert dataloader.batch_size is None

            batch_sampler = dataloader.batch_sampler
            assert isinstance(batch_sampler, CustomBatchSampler)
            assert batch_sampler.batch_size == 1
            if self._mode == 2:
                assert isinstance(batch_sampler, CustomBatchSampler)
                # the batch_size is set on the batch sampler
                assert dataloader.batch_size is None
            elif self._mode == 3:
                assert type(batch_sampler) is BatchSampler
                assert dataloader.batch_size == self._mode
            assert batch_sampler.batch_size == self._mode
            assert batch_sampler.drop_last
            # the sampler has been replaced
            assert isinstance(batch_sampler.sampler, DistributedSampler)

        def create_dataset(self):
            dataset = IndexedRandomDataset(32, 64)
            batch_sampler = None
            batch_size = 2
            if self._mode == 1:
                # this case will raise an error
                return FailureCustomDataLoader(32, dataset)
            if self._mode == 2:
                batch_size = 1
                batch_sampler = CustomBatchSampler(SequentialSampler(dataset), batch_size=batch_size, drop_last=True)
                dataloader_cls = CustomDataLoader
            else:
                dataloader_cls = FailureCustomDataLoader
            return dataloader_cls(32, dataset, batch_size=batch_size, batch_sampler=batch_sampler)
                # with a custom batch sampler
                batch_sampler = CustomBatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=True)
                return CustomDataLoader(32, dataset, batch_sampler=batch_sampler)
            elif self._mode == 3:
                # with no batch sampler provided
                return CustomDataLoader(32, dataset, batch_size=3, drop_last=True)

        def test_dataloader(self):
            return [self.create_dataset()] * self._numbers_test_dataloaders
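The mode 2 and mode 3 assertions mirror stock PyTorch behaviour for `batch_sampler` versus `batch_size`. A plain-torch sketch of that behaviour, with arbitrarily chosen sizes and no Lightning involved:

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.randn(32, 64))

# mode-3 shape: only batch_size/drop_last are given, so DataLoader wraps its
# sampler in a plain BatchSampler internally
loader = DataLoader(dataset, batch_size=3, drop_last=True)
assert type(loader.batch_sampler) is BatchSampler
assert loader.batch_sampler.batch_size == 3
assert loader.batch_sampler.drop_last

# mode-2 shape: an explicit batch_sampler is passed, so the loader's own
# batch_size attribute stays None and the size lives on the batch sampler
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=True)
loader = DataLoader(dataset, batch_sampler=batch_sampler)
assert loader.batch_size is None
assert loader.batch_sampler.batch_size == 2
```

What the test checks on top of this is that, after Lightning replaces the sampler, the batch sampler wraps a `DistributedSampler` and the original batch size survives the replacement.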