Skip to content

Commit

Permalink
Use callable object for patching dataloaders (#971)
Browse files Browse the repository at this point in the history
* Use callable object for patching dataloaders

* Add test for ddp with dataloaders passed to fit()

* Update pytorch_lightning/trainer/trainer.py

Co-Authored-By: Jirka Borovec <[email protected]>

* Update pytorch_lightning/trainer/trainer.py

Co-Authored-By: Jirka Borovec <[email protected]>

Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
shoarora and Borda authored Mar 2, 2020
1 parent 19c1c77 commit a1fb3a4
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 12 deletions.
30 changes: 18 additions & 12 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,30 +1002,21 @@ def __set_fit_dataloaders(self, model, train_dataloader, val_dataloaders, test_d
m = 'You called .fit() with a train_dataloader but did not define training_step()'
raise MisconfigurationException(m)

def patch_train_dataloader():
return train_dataloader

model.train_dataloader = patch_train_dataloader
model.train_dataloader = _PatchDataLoader(train_dataloader)

if val_dataloaders is not None:
if not self.is_overriden('validation_step', model):
m = 'You called .fit() with a val_dataloaders but did not define validation_step()'
raise MisconfigurationException(m)

def patch_val_dataloader():
return val_dataloaders

model.val_dataloader = patch_val_dataloader
model.val_dataloader = _PatchDataLoader(val_dataloaders)

if test_dataloaders is not None:
if not self.is_overriden('test_step', model):
m = 'You called .fit() with a test_dataloaders but did not define test_step()'
raise MisconfigurationException(m)

def patch_test_dataloader():
return test_dataloaders

model.test_dataloader = patch_test_dataloader
model.test_dataloader = _PatchDataLoader(test_dataloaders)

def init_optimizers(
self,
Expand Down Expand Up @@ -1189,6 +1180,21 @@ def test(self, model: Optional[LightningModule] = None):
self.run_evaluation(test_mode=True)


class _PatchDataLoader(object):
r'''
Callable object for patching dataloaders passed into trainer.fit().
Use this class to override model.*_dataloader() and be pickle-compatible.
Args:
dataloader: Dataloader object to return when called.
'''
def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
self.dataloader = dataloader

def __call__(self) -> Union[List[DataLoader], DataLoader]:
return self.dataloader


def _set_dataloader(model, dataloader, attribute):
r'''
Check dataloaders passed to .fit() method if they are pytorch DataLoader
Expand Down
25 changes: 25 additions & 0 deletions tests/test_gpu_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,31 @@ def test_multi_gpu_model_ddp(tmpdir):
tutils.run_model_test(trainer_options, model)


def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
"""Make sure DDP works with dataloaders passed to fit()"""
if not tutils.can_run_gpu_test():
return

tutils.reset_seed()
tutils.set_random_master_port()

model, hparams = tutils.get_model()
trainer_options = dict(default_save_path=tmpdir,
show_progress_bar=False,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
gpus=[0, 1],
distributed_backend='ddp')

fit_options = dict(train_dataloader=model.train_dataloader(),
val_dataloaders=model.val_dataloader())

trainer = Trainer(**trainer_options)
result = trainer.fit(model, **fit_options)
assert result == 1, "DDP doesn't work with dataloaders passed to fit()."


def test_optimizer_return_options():
tutils.reset_seed()

Expand Down

0 comments on commit a1fb3a4

Please sign in to comment.