Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use callable object for patching dataloaders #971

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,30 +1002,21 @@ def __set_fit_dataloaders(self, model, train_dataloader, val_dataloaders, test_d
m = 'You called .fit() with a train_dataloader but did not define training_step()'
raise MisconfigurationException(m)

def patch_train_dataloader():
return train_dataloader

model.train_dataloader = patch_train_dataloader
model.train_dataloader = _PatchDataLoader(train_dataloader)

if val_dataloaders is not None:
if not self.is_overriden('validation_step', model):
m = 'You called .fit() with a val_dataloaders but did not define validation_step()'
raise MisconfigurationException(m)

def patch_val_dataloader():
return val_dataloaders

model.val_dataloader = patch_val_dataloader
model.val_dataloader = _PatchDataLoader(val_dataloaders)

if test_dataloaders is not None:
if not self.is_overriden('test_step', model):
m = 'You called .fit() with a test_dataloaders but did not define test_step()'
raise MisconfigurationException(m)

def patch_test_dataloader():
return test_dataloaders

model.test_dataloader = patch_test_dataloader
model.test_dataloader = _PatchDataLoader(test_dataloaders)

def init_optimizers(
self,
Expand Down Expand Up @@ -1189,6 +1180,21 @@ def test(self, model: Optional[LightningModule] = None):
self.run_evaluation(test_mode=True)


class _PatchDataLoader(object):
r'''
Callable object for patching dataloaders passed into trainer.fit().
Use this class to override model.*_dataloader() and be pickle-compatible.

Args:
dataloader: Dataloader object to return when called.
'''
def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
self.dataloader = dataloader

def __call__(self) -> Union[List[DataLoader], DataLoader]:
return self.dataloader


def _set_dataloader(model, dataloader, attribute):
r'''
Check dataloaders passed to .fit() method if they are pytorch DataLoader
Expand Down
25 changes: 25 additions & 0 deletions tests/test_gpu_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,31 @@ def test_multi_gpu_model_ddp(tmpdir):
tutils.run_model_test(trainer_options, model)


def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
"""Make sure DDP works with dataloaders passed to fit()"""
if not tutils.can_run_gpu_test():
return

tutils.reset_seed()
tutils.set_random_master_port()

model, hparams = tutils.get_model()
trainer_options = dict(default_save_path=tmpdir,
show_progress_bar=False,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
gpus=[0, 1],
distributed_backend='ddp')

fit_options = dict(train_dataloader=model.train_dataloader(),
val_dataloaders=model.val_dataloader())

trainer = Trainer(**trainer_options)
result = trainer.fit(model, **fit_options)
assert result == 1, "DDP doesn't work with dataloaders passed to fit()."


def test_optimizer_return_options():
tutils.reset_seed()

Expand Down