Skip to content

Commit

Permalink
Remove deprecated datamodule lifecycle properties (#10350)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholi <[email protected]>
  • Loading branch information
mayeroa and carmocca authored Nov 5, 2021
1 parent a501cd3 commit b3c0f12
Show file tree
Hide file tree
Showing 6 changed files with 3 additions and 456 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).



-
- Removed deprecated `has_prepared_data`, `has_setup_fit`, `has_setup_validate`, `has_setup_test`, `has_setup_predict`, `has_teardown_fit`, `has_teardown_validate`, `has_teardown_test` and `has_teardown_predict` datamodule lifecycle properties ([#10350](https://github.com/PyTorchLightning/pytorch-lightning/pull/10350))


-
Expand Down
5 changes: 0 additions & 5 deletions docs/source/extensions/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,6 @@ we assume all stages have been set-up.

.. note:: ``setup`` is called from every process. Setting state here is okay.
.. note:: ``teardown`` can be used to clean up the state. It is also called from every process.
.. note::
Each of ``{setup,teardown,prepare_data}`` will only be called once for a specific stage.
If the stage is ``None``, then we assume ``{fit,validate,test}`` have all been called. For example, this means that
any duplicate ``dm.setup('fit')`` calls will be a no-op. To avoid this, you can overwrite the flag with
``dm._has_setup_fit = False``.


train_dataloader
Expand Down
224 changes: 0 additions & 224 deletions pytorch_lightning/core/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""LightningDataModule for loading DataLoaders with ease."""

import functools
from argparse import ArgumentParser, Namespace
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union

Expand Down Expand Up @@ -91,19 +89,6 @@ def __init__(self, train_transforms=None, val_transforms=None, test_transforms=N
# Pointer to the trainer object
self.trainer = None

# Private attrs to keep track of whether or not data hooks have been called yet
self._has_prepared_data = False

self._has_setup_fit = False
self._has_setup_validate = False
self._has_setup_test = False
self._has_setup_predict = False

self._has_teardown_fit = False
self._has_teardown_validate = False
self._has_teardown_test = False
self._has_teardown_predict = False

@property
def train_transforms(self):
"""Optional transforms (or collection of transforms) you can apply to train dataset.
Expand Down Expand Up @@ -188,139 +173,6 @@ def size(self, dim=None) -> Union[Tuple, List[Tuple]]:

return self.dims

@property
def has_prepared_data(self) -> bool:
    """Report whether ``datamodule.prepare_data()`` has been recorded as called.

    Calls ``rank_zero_deprecation`` with a deprecation message on every access.

    Returns:
        The current value of the private ``_has_prepared_data`` attribute.

    .. deprecated:: v1.4
        Will be removed in v1.6.0.
    """
    message = "DataModule property `has_prepared_data` was deprecated in v1.4 and will be removed in v1.6."
    rank_zero_deprecation(message)
    return self._has_prepared_data

@property
def has_setup_fit(self) -> bool:
    """Report whether ``datamodule.setup(stage='fit')`` has been recorded as called.

    Calls ``rank_zero_deprecation`` with a deprecation message on every access.

    Returns:
        The current value of the private ``_has_setup_fit`` attribute.

    .. deprecated:: v1.4
        Will be removed in v1.6.0.
    """
    message = "DataModule property `has_setup_fit` was deprecated in v1.4 and will be removed in v1.6."
    rank_zero_deprecation(message)
    return self._has_setup_fit

@property
def has_setup_validate(self) -> bool:
    """Report whether ``datamodule.setup(stage='validate')`` has been recorded as called.

    Calls ``rank_zero_deprecation`` with a deprecation message on every access.

    Returns:
        The current value of the private ``_has_setup_validate`` attribute.

    .. deprecated:: v1.4
        Will be removed in v1.6.0.
    """
    message = "DataModule property `has_setup_validate` was deprecated in v1.4 and will be removed in v1.6."
    rank_zero_deprecation(message)
    return self._has_setup_validate

@property
def has_setup_test(self) -> bool:
    """Report whether ``datamodule.setup(stage='test')`` has been recorded as called.

    Calls ``rank_zero_deprecation`` with a deprecation message on every access.

    Returns:
        The current value of the private ``_has_setup_test`` attribute.

    .. deprecated:: v1.4
        Will be removed in v1.6.0.
    """
    message = "DataModule property `has_setup_test` was deprecated in v1.4 and will be removed in v1.6."
    rank_zero_deprecation(message)
    return self._has_setup_test

@property
def has_setup_predict(self) -> bool:
    """Report whether ``datamodule.setup(stage='predict')`` has been recorded as called.

    Calls ``rank_zero_deprecation`` with a deprecation message on every access.

    Returns:
        The current value of the private ``_has_setup_predict`` attribute.

    .. deprecated:: v1.4
        Will be removed in v1.6.0.
    """
    message = "DataModule property `has_setup_predict` was deprecated in v1.4 and will be removed in v1.6."
    rank_zero_deprecation(message)
    return self._has_setup_predict

@property
def has_teardown_fit(self) -> bool:
    """Report whether ``datamodule.teardown(stage='fit')`` has been recorded as called.

    Calls ``rank_zero_deprecation`` with a deprecation message on every access.

    Returns:
        The current value of the private ``_has_teardown_fit`` attribute.

    .. deprecated:: v1.4
        Will be removed in v1.6.0.
    """
    message = "DataModule property `has_teardown_fit` was deprecated in v1.4 and will be removed in v1.6."
    rank_zero_deprecation(message)
    return self._has_teardown_fit

@property
def has_teardown_validate(self) -> bool:
    """Report whether ``datamodule.teardown(stage='validate')`` has been recorded as called.

    Calls ``rank_zero_deprecation`` with a deprecation message on every access.

    Returns:
        The current value of the private ``_has_teardown_validate`` attribute.

    .. deprecated:: v1.4
        Will be removed in v1.6.0.
    """
    message = "DataModule property `has_teardown_validate` was deprecated in v1.4 and will be removed in v1.6."
    rank_zero_deprecation(message)
    return self._has_teardown_validate

@property
def has_teardown_test(self) -> bool:
    """Report whether ``datamodule.teardown(stage='test')`` has been recorded as called.

    Calls ``rank_zero_deprecation`` with a deprecation message on every access.

    Returns:
        The current value of the private ``_has_teardown_test`` attribute.

    .. deprecated:: v1.4
        Will be removed in v1.6.0.
    """
    message = "DataModule property `has_teardown_test` was deprecated in v1.4 and will be removed in v1.6."
    rank_zero_deprecation(message)
    return self._has_teardown_test

@property
def has_teardown_predict(self) -> bool:
    """Report whether ``datamodule.teardown(stage='predict')`` has been recorded as called.

    Calls ``rank_zero_deprecation`` with a deprecation message on every access.

    Returns:
        The current value of the private ``_has_teardown_predict`` attribute.

    .. deprecated:: v1.4
        Will be removed in v1.6.0.
    """
    message = "DataModule property `has_teardown_predict` was deprecated in v1.4 and will be removed in v1.6."
    rank_zero_deprecation(message)
    return self._has_teardown_predict

@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
"""Extends existing argparse by default `LightningDataModule` attributes."""
Expand Down Expand Up @@ -405,79 +257,3 @@ def test_dataloader():
if test_dataset is not None:
datamodule.test_dataloader = test_dataloader
return datamodule

def __new__(cls, *args: Any, **kwargs: Any) -> "LightningDataModule":
    """Create the instance, wrap its data hooks with call tracking, and run the base initializer.

    The positional and keyword arguments are accepted but not used here.

    Returns:
        The new instance, with ``prepare_data``, ``setup`` and ``teardown`` replaced on the
        instance by their tracked versions.
    """
    instance = super().__new__(cls)

    # Track `DataHooks` calls: shadow each hook on the instance with its tracking wrapper.
    for hook_name in ("prepare_data", "setup", "teardown"):
        tracked = cls._track_data_hook_calls(instance, getattr(instance, hook_name))
        setattr(instance, hook_name, tracked)

    # Run the base initializer here so the `LightningDataModule` is initialized for all cases of
    # inheritance, even if `super().__init__` hasn't been explicitly called in the class.
    LightningDataModule.__init__(instance)
    return instance

@staticmethod
def _track_data_hook_calls(obj: "LightningDataModule", fn: callable) -> callable:
    """Wrap a data hook so that a repeated call for the same stage is skipped.

    The wrapper records calls in private attributes on ``obj``:

    - ``prepare_data()`` sets ``obj._has_prepared_data`` to True.
    - ``setup(stage)`` / ``teardown(stage)`` set ``obj._has_{setup,teardown}_{stage}`` to True.
    - ``setup()`` / ``teardown()`` without a stage set the ``fit``, ``validate`` and ``test`` flags to True
      (the ``predict`` flag is not touched). The call counts as already run only if all three flags
      were already set.

    When the relevant flag(s) were already set, ``fn`` is not called and ``rank_zero_deprecation`` is
    called instead. The wrapper always returns ``None``; a value returned by ``fn`` is discarded.

    Args:
        obj: Object whose private ``_has_*`` attributes hold the call status.
        fn: Hook to wrap. Its ``__name__`` selects which attributes are read and written.

    Returns:
        The wrapped function.
    """

    @functools.wraps(fn)
    def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any:
        hook = fn.__name__
        already_called = False

        if hook == "prepare_data":
            already_called = obj._has_prepared_data
            obj._has_prepared_data = True
        elif hook in ("setup", "teardown"):
            # The stage may arrive positionally or through the ``stage`` keyword.
            stage = args[0] if args else kwargs.get("stage")

            if stage is not None:
                flag = f"_has_{hook}_{stage}"
                already_called = getattr(obj, flag)
                setattr(obj, flag, True)
            else:
                # No stage given: mark 'fit', 'validate' and 'test' as called. Per the original
                # comment, this is so __attach_datamodule in trainer.py doesn't mistakenly call
                # setup('test') on trainer.test().
                already_called = True
                for stage_name in ("fit", "validate", "test"):
                    flag = f"_has_{hook}_{stage_name}"
                    # Read the previous value before overwriting it.
                    already_called &= getattr(obj, flag)
                    setattr(obj, flag, True)

        if not already_called:
            fn(*args, **kwargs)
            return

        rank_zero_deprecation(
            f"DataModule.{hook} has already been called, so it will not be called again. "
            f"In v1.6 this behavior will change to always call DataModule.{hook}."
        )

    return wrapped_fn

def __getstate__(self) -> dict:
    """Return the instance state for pickling, without the per-instance hook wrappers.

    The ``prepare_data``, ``setup`` and ``teardown`` entries are removed from a shallow copy of
    ``self.__dict__``. This avoids
    ``_pickle.PicklingError: Can't pickle <...>: it's not the same object as <...>``.

    Returns:
        A shallow copy of ``self.__dict__`` with the three hook entries removed. ``self.__dict__``
        itself is left untouched.
    """
    state = self.__dict__.copy()
    for hook_name in ("prepare_data", "setup", "teardown"):
        # `pop` with a default tolerates an entry that is absent from the instance dict,
        # where `del` would raise KeyError and break pickling.
        state.pop(hook_name, None)
    return state
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def prepare_data(self) -> None:
lightning_module = self.trainer.lightning_module
# handle datamodule prepare data:
# check for prepare_data_per_node & datamodule lifecycle properties before calling datamodule.prepare_data
if datamodule is not None and not datamodule._has_prepared_data:
if datamodule is not None:
dm_prepare_data_per_node = datamodule.prepare_data_per_node
dm_eq_prepare_data = datamodule.prepare_data_per_node == self.trainer.prepare_data_per_node
if self.trainer.prepare_data_per_node is not None and not dm_eq_prepare_data:
Expand Down
Loading

0 comments on commit b3c0f12

Please sign in to comment.