
Commit

Automatically check DataModule.has_{setup,teardown,prepare_data} [2/2] (#7238)

* Automatically check `DataModule.has_{setup,teardown,prepare_data}`

* Use variable

* Spacing

* Docs

* Update CHANGELOG

* Remove `_DataModuleWrapper`

* Add test

* Update docs/source/extensions/datamodules.rst

* Bad merge

* add test for invalid name

* Remove ValueError

Co-authored-by: Adrian Wälchli <[email protected]>
carmocca and awaelchli authored May 11, 2021
1 parent 8660d8c commit b65ae79
Showing 8 changed files with 78 additions and 28 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))


- `DataModule`s now avoid duplicate `{setup,teardown,prepare_data}` calls for the same stage ([#7238](https://github.com/PyTorchLightning/pytorch-lightning/pull/7238))


- Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/))


28 changes: 16 additions & 12 deletions docs/source/extensions/datamodules.rst
@@ -168,10 +168,6 @@ Here's a more realistic, complex DataModule that shows how much more reusable th
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)
.. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``.


---------------

LightningDataModule API
@@ -228,7 +224,7 @@ There are also data operations you might want to perform on every GPU. Use setup
def setup(self, stage: Optional[str] = None):
# Assign Train/val split(s) for use in Dataloaders
if stage == 'fit' or stage is None:
if stage in (None, 'fit'):
mnist_full = MNIST(
self.data_dir,
train=True,
Expand All @@ -239,7 +235,7 @@ There are also data operations you might want to perform on every GPU. Use setup
self.dims = self.mnist_train[0][0].shape
# Assign Test split(s) for use in Dataloaders
if stage == 'test' or stage is None:
if stage in (None, 'test'):
self.mnist_test = MNIST(
self.data_dir,
train=False,
Expand All @@ -249,10 +245,17 @@ There are also data operations you might want to perform on every GPU. Use setup
self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)
.. warning:: ``setup`` is called from every process. Setting state here is okay.

:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup` expects a ``stage: Optional[str]`` argument.
It is used to separate setup logic for ``trainer.{fit,validate,test}``. If ``setup`` is called with ``stage = None``,
we assume all stages have been set up.

.. note:: ``setup`` is called from every process. Setting state here is okay.
.. note:: ``teardown`` can be used to clean up the state. It is also called from every process.
.. note::
``{setup,teardown,prepare_data}`` will only be called once for a specific stage.
If the stage was ``None``, we assume ``{fit,validate,test}`` have all been called. For example, this means that
any duplicate ``dm.setup('fit')`` calls will be a no-op. To avoid this, you can overwrite
``dm._has_setup_fit = False``.
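
For illustration, a minimal sketch of this behaviour with the ``MNISTDataModule`` used above:

.. code-block:: python

    dm = MNISTDataModule()
    dm.setup('fit')              # runs the 'fit' setup logic
    dm.setup('fit')              # no-op: 'fit' has already been set up
    dm._has_setup_fit = False    # reset the tracking flag
    dm.setup('fit')              # runs the 'fit' setup logic again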


train_dataloader
@@ -396,11 +399,12 @@ The recommended way to use a DataModule is simply:
dm = MNISTDataModule()
model = Model()
trainer.fit(model, dm)
trainer.test(datamodule=dm)
If you need information from the dataset to build your model, then run `prepare_data` and `setup` manually (Lightning
still ensures the method runs on the correct devices)
If you need information from the dataset to build your model, then run
:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.prepare_data` and
:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup` manually (Lightning ensures
the method runs on the correct devices).

.. code-block:: python
@@ -416,7 +420,7 @@ still ensures the method runs on the correct devices)
----------------

Datamodules without Lightning
DataModules without Lightning
-----------------------------
You can of course use DataModules in plain PyTorch code as well.

2 changes: 0 additions & 2 deletions docs/source/starter/introduction_guide.rst
@@ -295,8 +295,6 @@ When your models need to know about the data, it's best to process the data befo
1. use ``prepare_data()`` to download and process the dataset.
2. use ``setup()`` to do splits, and build your model internals
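
For illustration, a minimal sketch of this split (MNIST is used here only as an example dataset):

.. code-block:: python

    from torch.utils.data import DataLoader, random_split
    from torchvision import transforms
    from torchvision.datasets import MNIST
    from pytorch_lightning import LightningDataModule

    class MNISTDataModule(LightningDataModule):

        def prepare_data(self):
            # download only -- runs on a single process
            MNIST('.', train=True, download=True)

        def setup(self, stage=None):
            # splits and transforms -- runs on every process
            if stage in (None, 'fit'):
                full = MNIST('.', train=True, transform=transforms.ToTensor())
                self.mnist_train, self.mnist_val = random_split(full, [55000, 5000])

        def train_dataloader(self):
            return DataLoader(self.mnist_train, batch_size=32)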

|
An alternative to using a DataModule is to defer initialization of the model's modules to the ``setup`` method of your LightningModule as follows:

.. testcode::
4 changes: 2 additions & 2 deletions docs/source/starter/new-project.rst
@@ -658,10 +658,10 @@ Make your data code reusable by organizing it into a :class:`~pytorch_lightning.
transforms.Normalize((0.1307,), (0.3081,))
])
# split dataset
if stage == 'fit':
if stage in (None, 'fit'):
mnist_train = MNIST(os.getcwd(), train=True, transform=transform)
self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
if stage == 'test':
if stage in (None, 'test'):
self.mnist_test = MNIST(os.getcwd(), train=False, transform=transform)

# return the dataloader for each split
14 changes: 11 additions & 3 deletions pytorch_lightning/core/datamodule.py
@@ -355,6 +355,7 @@ def _track_data_hook_calls(obj: 'LightningDataModule', fn: callable) -> callable
@functools.wraps(fn)
def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any:
name = fn.__name__
has_run = False

# If calling setup, we check the stage and assign stage-specific bool args
if name in ("setup", "teardown"):
@@ -366,15 +367,22 @@ def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any:
stage = args[0] if len(args) else kwargs.get("stage", None)

if stage is None:
has_run = True
for s in ("fit", "validate", "test"):
setattr(obj, f"_has_{name}_{s}", True)
attr = f"_has_{name}_{s}"
has_run &= getattr(obj, attr)
setattr(obj, attr, True)
else:
setattr(obj, f"_has_{name}_{stage}", True)
attr = f"_has_{name}_{stage}"
has_run = getattr(obj, attr)
setattr(obj, attr, True)

elif name == "prepare_data":
has_run = obj._has_prepared_data
obj._has_prepared_data = True

return fn(*args, **kwargs)
if not has_run:
return fn(*args, **kwargs)

return wrapped_fn
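
For reference, a stripped-down sketch of the pattern ``wrapped_fn`` implements (hypothetical helper name,
simplified to the ``setup``/``teardown`` path): mark the hook as run for each requested stage and skip the
call when every requested stage has already run.

    import functools

    def _track_stage_calls(fn):
        """Simplified sketch: skip `fn` if it already ran for the requested stage(s)."""

        @functools.wraps(fn)
        def wrapped(obj, stage=None):
            stages = ("fit", "validate", "test") if stage is None else (stage,)
            attrs = [f"_has_{fn.__name__}_{s}" for s in stages]
            # skip only when every requested stage has already run
            already_ran = all(getattr(obj, attr, False) for attr in attrs)
            for attr in attrs:
                setattr(obj, attr, True)
            if not already_ran:
                return fn(obj, stage)

        return wrapped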

2 changes: 1 addition & 1 deletion pytorch_lightning/core/hooks.py
@@ -394,7 +394,7 @@ def prepare_data(self):

def setup(self, stage: Optional[str] = None) -> None:
"""
Called at the beginning of fit (train + validate), validate, test, predict, or tune.
Called at the beginning of fit (train + validate), validate, test, and predict.
This is a good hook when you need to build models dynamically or adjust something about them.
This hook is called on every process when using DDP.
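
As context for the docstring above, a minimal sketch (not part of this commit) of deferring model construction
to ``setup``:

    import torch
    from pytorch_lightning import LightningModule

    class LazilyBuiltModel(LightningModule):
        def __init__(self, num_classes: int = 10):
            super().__init__()
            self.num_classes = num_classes
            self.classifier = None

        def setup(self, stage=None):
            # runs on every process at the start of fit/validate/test/predict
            if self.classifier is None:
                self.classifier = torch.nn.Linear(28 * 28, self.num_classes)

        def forward(self, x):
            return self.classifier(x.view(x.size(0), -1))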
10 changes: 2 additions & 8 deletions pytorch_lightning/trainer/trainer.py
@@ -1156,10 +1156,7 @@ def call_setup_hook(self, model: LightningModule) -> None:
self.accelerator.barrier("pre_setup")

if self.datamodule is not None:
called = getattr(self.datamodule, f'has_setup_{fn}')
if not called:
self.datamodule.setup(stage=fn)

self.datamodule.setup(stage=fn)
self.setup(model, stage=fn)
model.setup(stage=fn)

@@ -1182,10 +1179,7 @@ def call_teardown_hook(self, model: LightningModule) -> None:
fn = self.state.fn._setup_fn

if self.datamodule is not None:
called = getattr(self.datamodule, f'has_teardown_{fn}')
if not called:
self.datamodule.teardown(stage=fn)

self.datamodule.teardown(stage=fn)
self.profiler.teardown(stage=fn)
self.teardown(stage=fn)
model.teardown(stage=fn)
43 changes: 43 additions & 0 deletions tests/core/test_datamodules.py
@@ -521,3 +521,46 @@ def test_dm_init_from_datasets_dataloaders(iterable):
call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True)
])


def test_datamodule_hooks_calls(tmpdir):
"""Test that repeated calls to DataHooks' hooks have no effect"""

class TestDataModule(BoringDataModule):
setup_calls = []
teardown_calls = []
prepare_data_calls = 0

def setup(self, stage=None):
super().setup(stage=stage)
self.setup_calls.append(stage)

def teardown(self, stage=None):
super().teardown(stage=stage)
self.teardown_calls.append(stage)

def prepare_data(self):
super().prepare_data()
self.prepare_data_calls += 1

dm = TestDataModule()
dm.prepare_data()
dm.prepare_data()
dm.setup('fit')
dm.setup('fit')
dm.setup()
dm.setup()
dm.teardown('validate')
dm.teardown('validate')

assert dm.prepare_data_calls == 1
assert dm.setup_calls == ['fit', None]
assert dm.teardown_calls == ['validate']

trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
trainer.test(BoringModel(), datamodule=dm)

# same number of calls
assert dm.prepare_data_calls == 1
assert dm.setup_calls == ['fit', None]
assert dm.teardown_calls == ['validate', 'test']
