Skip replacing dataloader sampler if it's already a distributed sampler #4273

Merged: 14 commits, Oct 23, 2020
CHANGELOG.md: 3 additions & 2 deletions
@@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587))

- Allow changing the logged step value in `validation_step` ([#4130](https://github.com/PyTorchLightning/pytorch-lightning/pull/4130))
- Allow setting `replace_sampler_ddp=True` with a distributed sampler already added ([#4273](https://github.com/PyTorchLightning/pytorch-lightning/pull/4273))

### Deprecated

@@ -122,7 +123,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed

- Fixed `current_epoch` property update to reflect true epoch number inside `LightningDataModule`, when `reload_dataloaders_every_epoch=True`. ([#3974](https://github.com/PyTorchLightning/pytorch-lightning/pull/3974))
- Fixed to print scaler value in progress bar ([#4053](https://github.com/PyTorchLightning/pytorch-lightning/pull/4053))
- Fixed mismatch between docstring and code regarding when `on_load_checkpoint` hook is called ([#3996](https://github.com/PyTorchLightning/pytorch-lightning/pull/3996))


@@ -467,7 +468,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed adding val step argument to metrics ([#2986](https://github.com/PyTorchLightning/pytorch-lightning/pull/2986))
- Fixed an issue that caused `Trainer.test()` to stall in ddp mode ([#2997](https://github.com/PyTorchLightning/pytorch-lightning/pull/2997))
- Fixed gathering of results with tensors of varying shape ([#3020](https://github.com/PyTorchLightning/pytorch-lightning/pull/3020))
- Fixed batch size auto-scaling feature to set the new value on the correct model attribute ([#3043](https://github.com/PyTorchLightning/pytorch-lightning/pull/3043))
- Fixed automatic batch scaling not working with half precision ([#3045](https://github.com/PyTorchLightning/pytorch-lightning/pull/3045))
- Fixed setting device to root gpu ([#3042](https://github.com/PyTorchLightning/pytorch-lightning/pull/3042))

pytorch_lightning/trainer/__init__.py: 2 additions & 0 deletions
@@ -1278,6 +1278,8 @@ def world_size(self):
Enables auto adding of distributed sampler. By default it will add ``shuffle=True``
for train sampler and ``shuffle=False`` for val/test sampler. If you want to customize
it, you can set ``replace_sampler_ddp=False`` and add your own distributed sampler.
If ``replace_sampler_ddp=True`` and a distributed sampler was already added,
Lightning will not replace the existing one.

.. testcode::

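To illustrate the documented behavior, here is a minimal sketch of a user-provided distributed sampler (the `LitModel` class and its dataset are hypothetical; the training logic is omitted):

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


class LitModel(pl.LightningModule):
    # Hypothetical module: only the dataloader hook is shown.
    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(64, 3))
        # Attach a DistributedSampler manually, e.g. to control shuffling yourself.
        # Lightning requests this dataloader after the process group is initialized,
        # so the sampler can infer the world size and rank on its own.
        sampler = DistributedSampler(dataset, shuffle=True)
        return DataLoader(dataset, batch_size=8, sampler=sampler, shuffle=False)
```

With `replace_sampler_ddp=True` (the default), this pre-attached sampler is now kept as-is; before this change the same setup raised a `MisconfigurationException` asking for `replace_sampler_ddp=False`.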
pytorch_lightning/trainer/data_loading.py: 9 additions & 7 deletions
@@ -103,7 +103,7 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
f' (try {num_cpus} which is the number of cpus on this machine)'
' in the `DataLoader` init to improve performance.')

def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:

# don't do anything if it's not a dataloader
is_dataloader = isinstance(dataloader, DataLoader)
@@ -112,8 +112,9 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:

if not is_dataloader or is_iterable_ds:
return dataloader
need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu)

is_in_dist = self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu
need_dist_sampler = is_in_dist and not isinstance(dataloader.sampler, DistributedSampler)
if self.replace_sampler_ddp and need_dist_sampler:
if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
raise MisconfigurationException(
@@ -123,7 +124,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
' `replace_sampler_ddp`=False if you want to use your custom sampler.')

# replace with distributed sampler
sampler = self._get_distributed_sampler(dataloader, train)
sampler = self._get_distributed_sampler(dataloader, shuffle)
dataloader = self.replace_sampler(dataloader, sampler)

return dataloader
@@ -136,10 +137,11 @@ def replace_sampler(self, dataloader, sampler):
}

dl_args['sampler'] = sampler
dl_args['shuffle'] = False
dataloader = type(dataloader)(**dl_args)
return dataloader

def _get_distributed_sampler(self, dataloader, train):
def _get_distributed_sampler(self, dataloader, shuffle):
if self.use_tpu:
kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
elif self.use_horovod:
@@ -154,7 +156,7 @@ def _get_distributed_sampler(self, dataloader, train):
assert self.distributed_backend is not None
kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)

kwargs['shuffle'] = train and not self.overfit_batches
kwargs['shuffle'] = shuffle and not self.overfit_batches
sampler = DistributedSampler(dataloader.dataset, **kwargs)
return sampler

@@ -179,7 +181,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
self.num_training_batches = 0

# automatically add samplers
self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)
self.train_dataloader = self.auto_add_sampler(self.train_dataloader, shuffle=True)

self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')
self._worker_check(self.train_dataloader, 'train dataloader')
@@ -267,7 +269,7 @@ def _reset_eval_dataloader(
rank_zero_warn("One of given dataloaders is None and it will be skipped.")

# add samplers
dataloaders = [self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl is not None]
dataloaders = [self.auto_add_sampler(dl, shuffle=False) for dl in dataloaders if dl is not None]

loader_num_batches = []

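For readers skimming the diff, a simplified standalone sketch of the logic above; this is not the exact library code, and the helper name plus the explicit `num_replicas`/`rank` arguments are illustrative:

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def add_distributed_sampler(dataloader: DataLoader, shuffle: bool,
                            num_replicas: int, rank: int) -> DataLoader:
    # New behavior: skip entirely if a DistributedSampler is already attached.
    if isinstance(dataloader.sampler, DistributedSampler):
        return dataloader

    sampler = DistributedSampler(dataloader.dataset, num_replicas=num_replicas,
                                 rank=rank, shuffle=shuffle)
    # Rebuild the DataLoader with its original arguments, swapping in the new sampler.
    # `shuffle` must be False here because shuffling is delegated to the sampler.
    return DataLoader(
        dataloader.dataset,
        batch_size=dataloader.batch_size,
        num_workers=dataloader.num_workers,
        collate_fn=dataloader.collate_fn,
        pin_memory=dataloader.pin_memory,
        drop_last=dataloader.drop_last,
        sampler=sampler,
        shuffle=False,
    )
```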
tests/trainer/test_dataloaders.py: 37 additions & 4 deletions
@@ -686,17 +686,17 @@ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
class CustomDummyObj:
sampler = None

result = trainer.auto_add_sampler(CustomDummyObj(), train=True)
result = trainer.auto_add_sampler(CustomDummyObj(), shuffle=True)
assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader"

dataset = list(range(1000))
result = trainer.auto_add_sampler(CustomDataLoader(dataset), train=True)
result = trainer.auto_add_sampler(CustomDataLoader(dataset), shuffle=True)
assert isinstance(result, torch.utils.data.DataLoader)
assert isinstance(result, CustomDataLoader)
assert hasattr(result, 'dummy_kwarg')

# Shuffled DataLoader should also work
result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), train=True)
result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), shuffle=True)
assert isinstance(result, torch.utils.data.DataLoader)
assert isinstance(result, CustomDataLoader)
assert hasattr(result, 'dummy_kwarg')
@@ -707,7 +707,7 @@ class CustomSampler(torch.utils.data.Sampler):
# Should raise an error if existing sampler is being replaced
with pytest.raises(MisconfigurationException, match='DistributedSampler'):
trainer.auto_add_sampler(
CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), train=True)
CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), shuffle=True)


class DistribSamplerCallback(Callback):
@@ -746,6 +746,39 @@ def test_dataloader_distributed_sampler(tmpdir):
trainer.test(ckpt_path=None)


class ModelWithDataLoaderDistributedSampler(EvalModelTemplate):

def train_dataloader(self):
dataloader = super().train_dataloader()
dist_sampler = DistributedSampler(dataloader.dataset, shuffle=True)
return DataLoader(
dataloader.dataset,
batch_size=self.batch_size,
drop_last=False,
sampler=dist_sampler,
shuffle=False
)


@pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.')
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
def test_dataloader_distributed_sampler_already_attached(tmpdir):
""" Test DistributedSampler and it's arguments for DDP backend when DistSampler already included on dataloader """

model = ModelWithDataLoaderDistributedSampler()
trainer = Trainer(
gpus=[0, 1],
num_nodes=1,
distributed_backend='ddp_spawn',
default_root_dir=tmpdir,
max_steps=100,
callbacks=[DistribSamplerCallback()],
replace_sampler_ddp=True,
)
result = trainer.fit(model)
assert result == 1, "DDP Training failed"


@pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs')
def test_batch_size_smaller_than_num_gpus(tmpdir):
# we need at least 3 gpus for this test