Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically check DataModule.has_{setup,teardown,prepare_data} [2/2] #7238

Merged
merged 16 commits into from
May 11, 2021
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Changed


- `DataModules` now avoid duplicate {setup,teardown,prepare_data} calls for the same stage ([#7238](https://github.com/PyTorchLightning/pytorch-lightning/pull/7238))


### Deprecated


Expand Down
28 changes: 16 additions & 12 deletions docs/source/extensions/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,6 @@ Here's a more realistic, complex DataModule that shows how much more reusable th
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)


.. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``.


---------------

LightningDataModule API
Expand Down Expand Up @@ -228,7 +224,7 @@ There are also data operations you might want to perform on every GPU. Use setup
def setup(self, stage: Optional[str] = None):

# Assign Train/val split(s) for use in Dataloaders
if stage == 'fit' or stage is None:
if stage in (None, 'fit'):
mnist_full = MNIST(
self.data_dir,
train=True,
Expand All @@ -239,7 +235,7 @@ There are also data operations you might want to perform on every GPU. Use setup
self.dims = self.mnist_train[0][0].shape

# Assign Test split(s) for use in Dataloaders
if stage == 'test' or stage is None:
if stage in (None, 'test'):
self.mnist_test = MNIST(
self.data_dir,
train=False,
Expand All @@ -249,10 +245,17 @@ There are also data operations you might want to perform on every GPU. Use setup
self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)


.. warning:: ``setup`` is called from every process. Setting state here is okay.

:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup` expects an ``stage: Optional[str]`` argument.
It is used to separate setup logic for ``trainer.{fit,validate,test}``. If ``setup`` is called with ``stage = None``,
we assume all stages have been set-up.

.. note:: ``setup`` is called from every process. Setting state here is okay.
.. note:: ``teardown`` can be used to clean up the state. It is also called from every process
.. note::
``{setup,teardown,prepare_data}`` call will be only called once for a specific stage.
If the stage was ``None`` then we assume ``{fit,validate,test}`` have been called. For example, this means that
any duplicate ``dm.setup('fit')`` calls will be a no-op. To avoid this, you can overwrite
``dm._has_setup_fit = False``


train_dataloader
Expand Down Expand Up @@ -396,11 +399,12 @@ The recommended way to use a DataModule is simply:
dm = MNISTDataModule()
model = Model()
trainer.fit(model, dm)

trainer.test(datamodule=dm)

If you need information from the dataset to build your model, then run `prepare_data` and `setup` manually (Lightning
still ensures the method runs on the correct devices)
If you need information from the dataset to build your model, then run
:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.prepare_data` and
:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup` manually (Lightning ensures
the method runs on the correct devices).

.. code-block:: python

Expand All @@ -416,7 +420,7 @@ still ensures the method runs on the correct devices)

----------------

Datamodules without Lightning
DataModules without Lightning
-----------------------------
You can of course use DataModules in plain PyTorch code as well.

Expand Down
2 changes: 0 additions & 2 deletions docs/source/starter/introduction_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,6 @@ When your models need to know about the data, it's best to process the data befo
1. use ``prepare_data()`` to download and process the dataset.
2. use ``setup()`` to do splits, and build your model internals

|

An alternative to using a DataModule is to defer initialization of the models modules to the ``setup`` method of your LightningModule as follows:

.. testcode::
Expand Down
4 changes: 2 additions & 2 deletions docs/source/starter/new-project.rst
Original file line number Diff line number Diff line change
Expand Up @@ -658,10 +658,10 @@ Make your data code reusable by organizing it into a :class:`~pytorch_lightning.
transforms.Normalize((0.1307,), (0.3081,))
])
# split dataset
if stage == 'fit':
if stage in (None, 'fit'):
mnist_train = MNIST(os.getcwd(), train=True, transform=transform)
self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
if stage == 'test':
if stage == (None, 'test'):
self.mnist_test = MNIST(os.getcwd(), train=False, transform=transform)

# return the dataloader for each split
Expand Down
16 changes: 13 additions & 3 deletions pytorch_lightning/core/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,15 +366,25 @@ def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any:
stage = args[0] if len(args) else kwargs.get("stage", None)

if stage is None:
has_run = True
for s in ("fit", "validate", "test"):
setattr(obj, f"_has_{name}_{s}", True)
attr = f"_has_{name}_{s}"
has_run &= getattr(obj, attr)
setattr(obj, attr, True)
else:
setattr(obj, f"_has_{name}_{stage}", True)
attr = f"_has_{name}_{stage}"
has_run = getattr(obj, attr)
setattr(obj, attr, True)

elif name == "prepare_data":
has_run = obj._has_prepared_data
obj._has_prepared_data = True

return fn(*args, **kwargs)
else:
raise ValueError(name)
carmocca marked this conversation as resolved.
Show resolved Hide resolved

if not has_run:
return fn(*args, **kwargs)

return wrapped_fn

Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def prepare_data(self):

def setup(self, stage: Optional[str] = None) -> None:
"""
Called at the beginning of fit (train + validate), validate, test, predict, or tune.
Called at the beginning of fit (train + validate), validate, test, and predict.
This is a good hook when you need to build models dynamically or adjust something about them.
This hook is called on every process when using DDP.

Expand Down
10 changes: 2 additions & 8 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,10 +1156,7 @@ def call_setup_hook(self, model: LightningModule) -> None:
self.accelerator.barrier("pre_setup")

if self.datamodule is not None:
called = getattr(self.datamodule, f'has_setup_{fn}')
if not called:
self.datamodule.setup(stage=fn)

self.datamodule.setup(stage=fn)
self.setup(model, stage=fn)
model.setup(stage=fn)

Expand All @@ -1182,10 +1179,7 @@ def call_teardown_hook(self, model: LightningModule) -> None:
fn = self.state.fn._setup_fn

if self.datamodule is not None:
called = getattr(self.datamodule, f'has_teardown_{fn}')
if not called:
self.datamodule.teardown(stage=fn)

self.datamodule.teardown(stage=fn)
self.profiler.teardown(stage=fn)
self.teardown(stage=fn)
model.teardown(stage=fn)
Expand Down
43 changes: 43 additions & 0 deletions tests/core/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,3 +521,46 @@ def test_dm_init_from_datasets_dataloaders(iterable):
call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True)
])


def test_datamodule_hooks_calls(tmpdir):
"""Test that repeated calls to DataHooks' hooks have no effect"""

class TestDataModule(BoringDataModule):
setup_calls = []
teardown_calls = []
prepare_data_calls = 0

def setup(self, stage=None):
super().setup(stage=stage)
self.setup_calls.append(stage)

def teardown(self, stage=None):
super().teardown(stage=stage)
self.teardown_calls.append(stage)

def prepare_data(self):
super().prepare_data()
self.prepare_data_calls += 1

dm = TestDataModule()
dm.prepare_data()
dm.prepare_data()
dm.setup('fit')
dm.setup('fit')
dm.setup()
dm.setup()
dm.teardown('validate')
dm.teardown('validate')

assert dm.prepare_data_calls == 1
assert dm.setup_calls == ['fit', None]
assert dm.teardown_calls == ['validate']

trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
trainer.test(BoringModel(), datamodule=dm)

# same number of calls
assert dm.prepare_data_calls == 1
assert dm.setup_calls == ['fit', None]
assert dm.teardown_calls == ['validate', 'test']