diff --git a/CHANGELOG.md b/CHANGELOG.md
index b39e4dab4a6f9..a624f44ecd5f7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -315,6 +315,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed incorrect main progress bar indicator when resuming training mid-epoch ([#9310](https://github.com/PyTorchLightning/pytorch-lightning/pull/9310))
 
+- Fixed `replace_sampler` missing the batch size under specific conditions ([#9367](https://github.com/PyTorchLightning/pytorch-lightning/pull/9367))
+
+
 ## [1.4.5] - 2021-08-31
 
 - Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 1dd5c1c6fec36..969404e68c498 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -201,6 +201,11 @@ def _get_dataloader_init_kwargs(
 
         # get the dataloader instance `__init__` parameters
         params = dict(inspect.signature(dataloader.__init__).parameters)
+        has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
+        if has_variadic_kwargs:
+            # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
+            params.update(inspect.signature(DataLoader.__init__).parameters)
+            del params["self"]
 
         # keep only the params whose default is different to the current attr value
         non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
@@ -231,7 +236,6 @@ def _get_dataloader_init_kwargs(
                 f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
             )
 
-        has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
         if not has_variadic_kwargs:
             # the dataloader signature does not allow keyword arguments that need to be passed
             missing_kwargs = dl_kwargs.keys() - params.keys()
diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py
index 437acd86a9024..c98c5f098e056 100644
--- a/tests/trainer/test_data_loading.py
+++ b/tests/trainer/test_data_loading.py
@@ -27,7 +27,7 @@
 @pytest.mark.skipif(
     sys.platform == "win32" and not _TORCH_GREATER_EQUAL_1_7, reason="Bad `torch.distributed` support on Windows"
 )
-@pytest.mark.parametrize("mode", (1, 2))
+@pytest.mark.parametrize("mode", (1, 2, 3))
 def test_replace_distributed_sampler(tmpdir, mode):
     class IndexedRandomDataset(RandomDataset):
         def __getitem__(self, index):
@@ -57,25 +57,31 @@ def test_step(self, batch, batch_idx, dataloader_idx=None):
         def on_test_start(self) -> None:
             dataloader = self.trainer.test_dataloaders[0]
             assert isinstance(dataloader, CustomDataLoader)
-            assert dataloader.batch_size is None
-
             batch_sampler = dataloader.batch_sampler
-            assert isinstance(batch_sampler, CustomBatchSampler)
-            assert batch_sampler.batch_size == 1
+            if self._mode == 2:
+                assert isinstance(batch_sampler, CustomBatchSampler)
+                # the batch_size is set on the batch sampler
+                assert dataloader.batch_size is None
+            elif self._mode == 3:
+                assert type(batch_sampler) is BatchSampler
+                assert dataloader.batch_size == self._mode
+            assert batch_sampler.batch_size == self._mode
             assert batch_sampler.drop_last
+            # the sampler has been replaced
             assert isinstance(batch_sampler.sampler, DistributedSampler)
 
         def create_dataset(self):
             dataset = IndexedRandomDataset(32, 64)
-            batch_sampler = None
-            batch_size = 2
+            if self._mode == 1:
+                # this case will raise an error
+                return FailureCustomDataLoader(32, dataset)
             if self._mode == 2:
-                batch_size = 1
-                batch_sampler = CustomBatchSampler(SequentialSampler(dataset), batch_size=batch_size, drop_last=True)
-                dataloader_cls = CustomDataLoader
-            else:
-                dataloader_cls = FailureCustomDataLoader
-            return dataloader_cls(32, dataset, batch_size=batch_size, batch_sampler=batch_sampler)
+                # with a custom batch sampler
+                batch_sampler = CustomBatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=True)
+                return CustomDataLoader(32, dataset, batch_sampler=batch_sampler)
+            elif self._mode == 3:
+                # with no batch sampler provided
+                return CustomDataLoader(32, dataset, batch_size=3, drop_last=True)
 
         def test_dataloader(self):
             return [self.create_dataset()] * self._numbers_test_dataloaders
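
For context, a minimal sketch (not part of the patch; `ArgsOnlyDataLoader` and `_ToyDataset` are made-up names) of the situation the new `**kwargs` handling addresses: when a `DataLoader` subclass only exposes `*args`/`**kwargs` in its `__init__`, inspecting that signature alone hides `batch_size` and the other keyword arguments that are forwarded to `super().__init__`, so re-instantiating the loader with a replaced sampler would silently drop them. Merging in `DataLoader.__init__`'s parameters, as the patch does, makes those keywords visible again.

```python
import inspect

from torch.utils.data import DataLoader, Dataset


class _ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, index):
        return index


class ArgsOnlyDataLoader(DataLoader):
    # hypothetical subclass: every DataLoader keyword is forwarded blindly
    def __init__(self, extra_arg, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.extra_arg = extra_arg


loader = ArgsOnlyDataLoader("extra", _ToyDataset(), batch_size=3, drop_last=True)

# inspecting the subclass signature alone does not reveal `batch_size`
params = dict(inspect.signature(loader.__init__).parameters)
assert "batch_size" not in params

# the idea of the fix: if the signature takes **kwargs, merge in the parameters
# of `DataLoader.__init__` so `batch_size`, `drop_last`, etc. can be re-passed
# when the sampler is replaced
if any(p.kind is p.VAR_KEYWORD for p in params.values()):
    params.update(inspect.signature(DataLoader.__init__).parameters)
    params.pop("self", None)

assert "batch_size" in params and "drop_last" in params
```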