From 298d6b7630193f17b2e80deff0068fb3e92e8f15 Mon Sep 17 00:00:00 2001 From: Mayerowitz Alexandre Date: Thu, 4 Nov 2021 13:54:15 +0100 Subject: [PATCH 1/8] Remove deprecated datamodule lifecycle properties --- pytorch_lightning/core/datamodule.py | 133 --------------------- tests/core/test_datamodules.py | 152 ++++++++++++------------ tests/deprecated_api/test_remove_1-6.py | 22 ---- 3 files changed, 76 insertions(+), 231 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index f3a5c855fe07a..3f1c99ae7e9c9 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -188,139 +188,6 @@ def size(self, dim=None) -> Union[Tuple, List[Tuple]]: return self.dims - @property - def has_prepared_data(self) -> bool: - """Return bool letting you know if ``datamodule.prepare_data()`` has been called or not. - - Returns: - bool: True if ``datamodule.prepare_data()`` has been called. False by default. - - .. deprecated:: v1.4 - Will be removed in v1.6.0. - """ - rank_zero_deprecation( - "DataModule property `has_prepared_data` was deprecated in v1.4 and will be removed in v1.6." - ) - return self._has_prepared_data - - @property - def has_setup_fit(self) -> bool: - """Return bool letting you know if ``datamodule.setup(stage='fit')`` has been called or not. - - Returns: - bool: True ``if datamodule.setup(stage='fit')`` has been called. False by default. - - .. deprecated:: v1.4 - Will be removed in v1.6.0. - """ - rank_zero_deprecation("DataModule property `has_setup_fit` was deprecated in v1.4 and will be removed in v1.6.") - return self._has_setup_fit - - @property - def has_setup_validate(self) -> bool: - """Return bool letting you know if ``datamodule.setup(stage='validate')`` has been called or not. - - Returns: - bool: True if ``datamodule.setup(stage='validate')`` has been called. False by default. - - .. deprecated:: v1.4 - Will be removed in v1.6.0. - """ - rank_zero_deprecation( - "DataModule property `has_setup_validate` was deprecated in v1.4 and will be removed in v1.6." - ) - return self._has_setup_validate - - @property - def has_setup_test(self) -> bool: - """Return bool letting you know if ``datamodule.setup(stage='test')`` has been called or not. - - Returns: - bool: True if ``datamodule.setup(stage='test')`` has been called. False by default. - - .. deprecated:: v1.4 - Will be removed in v1.6.0. - """ - rank_zero_deprecation( - "DataModule property `has_setup_test` was deprecated in v1.4 and will be removed in v1.6." - ) - return self._has_setup_test - - @property - def has_setup_predict(self) -> bool: - """Return bool letting you know if ``datamodule.setup(stage='predict')`` has been called or not. - - Returns: - bool: True if ``datamodule.setup(stage='predict')`` has been called. False by default. - - .. deprecated:: v1.4 - Will be removed in v1.6.0. - """ - rank_zero_deprecation( - "DataModule property `has_setup_predict` was deprecated in v1.4 and will be removed in v1.6." - ) - return self._has_setup_predict - - @property - def has_teardown_fit(self) -> bool: - """Return bool letting you know if ``datamodule.teardown(stage='fit')`` has been called or not. - - Returns: - bool: True ``if datamodule.teardown(stage='fit')`` has been called. False by default. - - .. deprecated:: v1.4 - Will be removed in v1.6.0. - """ - rank_zero_deprecation( - "DataModule property `has_teardown_fit` was deprecated in v1.4 and will be removed in v1.6." - ) - return self._has_teardown_fit - - @property - def has_teardown_validate(self) -> bool: - """Return bool letting you know if ``datamodule.teardown(stage='validate')`` has been called or not. - - Returns: - bool: True if ``datamodule.teardown(stage='validate')`` has been called. False by default. - - .. deprecated:: v1.4 - Will be removed in v1.6.0. - """ - rank_zero_deprecation( - "DataModule property `has_teardown_validate` was deprecated in v1.4 and will be removed in v1.6." - ) - return self._has_teardown_validate - - @property - def has_teardown_test(self) -> bool: - """Return bool letting you know if ``datamodule.teardown(stage='test')`` has been called or not. - - Returns: - bool: True if ``datamodule.teardown(stage='test')`` has been called. False by default. - - .. deprecated:: v1.4 - Will be removed in v1.6.0. - """ - rank_zero_deprecation( - "DataModule property `has_teardown_test` was deprecated in v1.4 and will be removed in v1.6." - ) - return self._has_teardown_test - - @property - def has_teardown_predict(self) -> bool: - """Return bool letting you know if ``datamodule.teardown(stage='predict')`` has been called or not. - - Returns: - bool: True if ``datamodule.teardown(stage='predict')`` has been called. False by default. - - .. deprecated:: v1.4 - Will be removed in v1.6.0. - """ - rank_zero_deprecation( - "DataModule property `has_teardown_predict` was deprecated in v1.4 and will be removed in v1.6." - ) - return self._has_teardown_predict - @classmethod def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser: """Extends existing argparse by default `LightningDataModule` attributes.""" diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 51b51bfbd011a..8fc400f95fc2f 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -140,110 +140,110 @@ def test_helper_boringdatamodule_with_verbose_setup(): def test_data_hooks_called(): dm = BoringDataModule() - assert not dm.has_prepared_data - assert not dm.has_setup_fit - assert not dm.has_setup_test - assert not dm.has_setup_validate - assert not dm.has_setup_predict - assert not dm.has_teardown_fit - assert not dm.has_teardown_test - assert not dm.has_teardown_validate - assert not dm.has_teardown_predict + assert not dm._has_prepared_data + assert not dm._has_setup_fit + assert not dm._has_setup_test + assert not dm._has_setup_validate + assert not dm._has_setup_predict + assert not dm._has_teardown_fit + assert not dm._has_teardown_test + assert not dm._has_teardown_validate + assert not dm._has_teardown_predict dm.prepare_data() - assert dm.has_prepared_data - assert not dm.has_setup_fit - assert not dm.has_setup_test - assert not dm.has_setup_validate - assert not dm.has_setup_predict - assert not dm.has_teardown_fit - assert not dm.has_teardown_test - assert not dm.has_teardown_validate - assert not dm.has_teardown_predict + assert dm._has_prepared_data + assert not dm._has_setup_fit + assert not dm._has_setup_test + assert not dm._has_setup_validate + assert not dm._has_setup_predict + assert not dm._has_teardown_fit + assert not dm._has_teardown_test + assert not dm._has_teardown_validate + assert not dm._has_teardown_predict dm.setup() - assert dm.has_prepared_data - assert dm.has_setup_fit - assert dm.has_setup_test - assert dm.has_setup_validate - assert not dm.has_setup_predict - assert not dm.has_teardown_fit - assert not dm.has_teardown_test - assert not dm.has_teardown_validate - assert not dm.has_teardown_predict + assert dm._has_prepared_data + assert dm._has_setup_fit + assert dm._has_setup_test + assert dm._has_setup_validate + assert not dm._has_setup_predict + assert not dm._has_teardown_fit + assert not dm._has_teardown_test + assert not dm._has_teardown_validate + assert not dm._has_teardown_predict dm.teardown() - assert dm.has_prepared_data - assert dm.has_setup_fit - assert dm.has_setup_test - assert dm.has_setup_validate - assert not dm.has_setup_predict - assert dm.has_teardown_fit - assert dm.has_teardown_test - assert dm.has_teardown_validate - assert not dm.has_teardown_predict + assert dm._has_prepared_data + assert dm._has_setup_fit + assert dm._has_setup_test + assert dm._has_setup_validate + assert not dm._has_setup_predict + assert dm._has_teardown_fit + assert dm._has_teardown_test + assert dm._has_teardown_validate + assert not dm._has_teardown_predict @pytest.mark.parametrize("use_kwarg", (False, True)) def test_data_hooks_called_verbose(use_kwarg): dm = BoringDataModule() dm.prepare_data() - assert not dm.has_setup_fit - assert not dm.has_setup_test - assert not dm.has_setup_validate - assert not dm.has_setup_predict - assert not dm.has_teardown_fit - assert not dm.has_teardown_test - assert not dm.has_teardown_validate - assert not dm.has_teardown_predict + assert not dm._has_setup_fit + assert not dm._has_setup_test + assert not dm._has_setup_validate + assert not dm._has_setup_predict + assert not dm._has_teardown_fit + assert not dm._has_teardown_test + assert not dm._has_teardown_validate + assert not dm._has_teardown_predict dm.setup(stage="fit") if use_kwarg else dm.setup("fit") - assert dm.has_setup_fit - assert not dm.has_setup_validate - assert not dm.has_setup_test - assert not dm.has_setup_predict + assert dm._has_setup_fit + assert not dm._has_setup_validate + assert not dm._has_setup_test + assert not dm._has_setup_predict dm.setup(stage="validate") if use_kwarg else dm.setup("validate") - assert dm.has_setup_fit - assert dm.has_setup_validate - assert not dm.has_setup_test - assert not dm.has_setup_predict + assert dm._has_setup_fit + assert dm._has_setup_validate + assert not dm._has_setup_test + assert not dm._has_setup_predict dm.setup(stage="test") if use_kwarg else dm.setup("test") - assert dm.has_setup_fit - assert dm.has_setup_validate - assert dm.has_setup_test - assert not dm.has_setup_predict + assert dm._has_setup_fit + assert dm._has_setup_validate + assert dm._has_setup_test + assert not dm._has_setup_predict dm.setup(stage="predict") if use_kwarg else dm.setup("predict") - assert dm.has_setup_fit - assert dm.has_setup_validate - assert dm.has_setup_test - assert dm.has_setup_predict + assert dm._has_setup_fit + assert dm._has_setup_validate + assert dm._has_setup_test + assert dm._has_setup_predict dm.teardown(stage="fit") if use_kwarg else dm.teardown("fit") - assert dm.has_teardown_fit - assert not dm.has_teardown_validate - assert not dm.has_teardown_test - assert not dm.has_teardown_predict + assert dm._has_teardown_fit + assert not dm._has_teardown_validate + assert not dm._has_teardown_test + assert not dm._has_teardown_predict dm.teardown(stage="validate") if use_kwarg else dm.teardown("validate") - assert dm.has_teardown_fit - assert dm.has_teardown_validate - assert not dm.has_teardown_test - assert not dm.has_teardown_predict + assert dm._has_teardown_fit + assert dm._has_teardown_validate + assert not dm._has_teardown_test + assert not dm._has_teardown_predict dm.teardown(stage="test") if use_kwarg else dm.teardown("test") - assert dm.has_teardown_fit - assert dm.has_teardown_validate - assert dm.has_teardown_test - assert not dm.has_teardown_predict + assert dm._has_teardown_fit + assert dm._has_teardown_validate + assert dm._has_teardown_test + assert not dm._has_teardown_predict dm.teardown(stage="predict") if use_kwarg else dm.teardown("predict") - assert dm.has_teardown_fit - assert dm.has_teardown_validate - assert dm.has_teardown_test - assert dm.has_teardown_predict + assert dm._has_teardown_fit + assert dm._has_teardown_validate + assert dm._has_teardown_test + assert dm._has_teardown_predict def test_dm_add_argparse_args(tmpdir): diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index 2b55a08e52036..ae8637ad1ef16 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -128,28 +128,6 @@ def training_step(self, *args): trainer.fit(TestModel()) -def test_v1_6_0_datamodule_lifecycle_properties(tmpdir): - dm = BoringDataModule() - with pytest.deprecated_call(match=r"DataModule property `has_prepared_data` was deprecated in v1.4"): - dm.has_prepared_data - with pytest.deprecated_call(match=r"DataModule property `has_setup_fit` was deprecated in v1.4"): - dm.has_setup_fit - with pytest.deprecated_call(match=r"DataModule property `has_setup_validate` was deprecated in v1.4"): - dm.has_setup_validate - with pytest.deprecated_call(match=r"DataModule property `has_setup_test` was deprecated in v1.4"): - dm.has_setup_test - with pytest.deprecated_call(match=r"DataModule property `has_setup_predict` was deprecated in v1.4"): - dm.has_setup_predict - with pytest.deprecated_call(match=r"DataModule property `has_teardown_fit` was deprecated in v1.4"): - dm.has_teardown_fit - with pytest.deprecated_call(match=r"DataModule property `has_teardown_validate` was deprecated in v1.4"): - dm.has_teardown_validate - with pytest.deprecated_call(match=r"DataModule property `has_teardown_test` was deprecated in v1.4"): - dm.has_teardown_test - with pytest.deprecated_call(match=r"DataModule property `has_teardown_predict` was deprecated in v1.4"): - dm.has_teardown_predict - - def test_v1_6_0_datamodule_hooks_calls(tmpdir): """Test that repeated calls to DataHooks' hooks show a warning about the coming API change.""" From 0a89c95239b4825bfcd23101edd396fab533c9e8 Mon Sep 17 00:00:00 2001 From: Mayerowitz Alexandre Date: Thu, 4 Nov 2021 14:03:27 +0100 Subject: [PATCH 2/8] Update CHANGELOG --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 82cd8e115213b..af0b4aa855727 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -67,7 +67,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). -- +- Removed deprecated `has_prepared_data`, `has_setup_fit`, `has_setup_validate`, `has_setup_test`, `has_setup_predict`, `has_teardown_fit`, `has_teardown_validate`, `has_teardown_test` and `has_teardown_predict` datamodule lifecycle properties ([#10350](https://github.com/PyTorchLightning/pytorch-lightning/pull/10350)) - From c198544c827108e985ffb47ed80d48d15327f239 Mon Sep 17 00:00:00 2001 From: Alexandre Mayerowitz Date: Thu, 4 Nov 2021 22:37:19 +0100 Subject: [PATCH 3/8] Remove datamodule hook tests --- tests/core/test_datamodules.py | 108 --------------------------------- 1 file changed, 108 deletions(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 8fc400f95fc2f..e4925b604ac30 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -138,114 +138,6 @@ def test_helper_boringdatamodule_with_verbose_setup(): dm.setup("test") -def test_data_hooks_called(): - dm = BoringDataModule() - assert not dm._has_prepared_data - assert not dm._has_setup_fit - assert not dm._has_setup_test - assert not dm._has_setup_validate - assert not dm._has_setup_predict - assert not dm._has_teardown_fit - assert not dm._has_teardown_test - assert not dm._has_teardown_validate - assert not dm._has_teardown_predict - - dm.prepare_data() - assert dm._has_prepared_data - assert not dm._has_setup_fit - assert not dm._has_setup_test - assert not dm._has_setup_validate - assert not dm._has_setup_predict - assert not dm._has_teardown_fit - assert not dm._has_teardown_test - assert not dm._has_teardown_validate - assert not dm._has_teardown_predict - - dm.setup() - assert dm._has_prepared_data - assert dm._has_setup_fit - assert dm._has_setup_test - assert dm._has_setup_validate - assert not dm._has_setup_predict - assert not dm._has_teardown_fit - assert not dm._has_teardown_test - assert not dm._has_teardown_validate - assert not dm._has_teardown_predict - - dm.teardown() - assert dm._has_prepared_data - assert dm._has_setup_fit - assert dm._has_setup_test - assert dm._has_setup_validate - assert not dm._has_setup_predict - assert dm._has_teardown_fit - assert dm._has_teardown_test - assert dm._has_teardown_validate - assert not dm._has_teardown_predict - - -@pytest.mark.parametrize("use_kwarg", (False, True)) -def test_data_hooks_called_verbose(use_kwarg): - dm = BoringDataModule() - dm.prepare_data() - assert not dm._has_setup_fit - assert not dm._has_setup_test - assert not dm._has_setup_validate - assert not dm._has_setup_predict - assert not dm._has_teardown_fit - assert not dm._has_teardown_test - assert not dm._has_teardown_validate - assert not dm._has_teardown_predict - - dm.setup(stage="fit") if use_kwarg else dm.setup("fit") - assert dm._has_setup_fit - assert not dm._has_setup_validate - assert not dm._has_setup_test - assert not dm._has_setup_predict - - dm.setup(stage="validate") if use_kwarg else dm.setup("validate") - assert dm._has_setup_fit - assert dm._has_setup_validate - assert not dm._has_setup_test - assert not dm._has_setup_predict - - dm.setup(stage="test") if use_kwarg else dm.setup("test") - assert dm._has_setup_fit - assert dm._has_setup_validate - assert dm._has_setup_test - assert not dm._has_setup_predict - - dm.setup(stage="predict") if use_kwarg else dm.setup("predict") - assert dm._has_setup_fit - assert dm._has_setup_validate - assert dm._has_setup_test - assert dm._has_setup_predict - - dm.teardown(stage="fit") if use_kwarg else dm.teardown("fit") - assert dm._has_teardown_fit - assert not dm._has_teardown_validate - assert not dm._has_teardown_test - assert not dm._has_teardown_predict - - dm.teardown(stage="validate") if use_kwarg else dm.teardown("validate") - assert dm._has_teardown_fit - assert dm._has_teardown_validate - assert not dm._has_teardown_test - assert not dm._has_teardown_predict - - dm.teardown(stage="test") if use_kwarg else dm.teardown("test") - assert dm._has_teardown_fit - assert dm._has_teardown_validate - assert dm._has_teardown_test - assert not dm._has_teardown_predict - - dm.teardown(stage="predict") if use_kwarg else dm.teardown("predict") - assert dm._has_teardown_fit - assert dm._has_teardown_validate - assert dm._has_teardown_test - assert dm._has_teardown_predict - - def test_dm_add_argparse_args(tmpdir): parser = ArgumentParser() parser = BoringDataModule.add_argparse_args(parser) From 89ffd5789a3b6266a9965be3f8652ba0efb85a9d Mon Sep 17 00:00:00 2001 From: Alexandre Mayerowitz Date: Thu, 4 Nov 2021 22:48:58 +0100 Subject: [PATCH 4/8] Remove private datamodule lifecycle properties --- docs/source/extensions/datamodules.rst | 3 +- pytorch_lightning/core/datamodule.py | 89 ------------------- .../trainer/connectors/data_connector.py | 2 +- tests/core/test_datamodules.py | 36 -------- 4 files changed, 2 insertions(+), 128 deletions(-) diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index b9860be79790a..79d8ede164a96 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -237,8 +237,7 @@ we assume all stages have been set-up. .. note:: ``{setup,teardown,prepare_data}`` call will be only called once for a specific stage. If the stage was ``None`` then we assume ``{fit,validate,test}`` have been called. For example, this means that - any duplicate ``dm.setup('fit')`` calls will be a no-op. To avoid this, you can overwrite - ``dm._has_setup_fit = False`` + any duplicate ``dm.setup('fit')`` calls will be a no-op. train_dataloader diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 3f1c99ae7e9c9..ee8658a30c0b1 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -91,19 +91,6 @@ def __init__(self, train_transforms=None, val_transforms=None, test_transforms=N # Pointer to the trainer object self.trainer = None - # Private attrs to keep track of whether or not data hooks have been called yet - self._has_prepared_data = False - - self._has_setup_fit = False - self._has_setup_validate = False - self._has_setup_test = False - self._has_setup_predict = False - - self._has_teardown_fit = False - self._has_teardown_validate = False - self._has_teardown_test = False - self._has_teardown_predict = False - @property def train_transforms(self): """Optional transforms (or collection of transforms) you can apply to train dataset. @@ -272,79 +259,3 @@ def test_dataloader(): if test_dataset is not None: datamodule.test_dataloader = test_dataloader return datamodule - - def __new__(cls, *args: Any, **kwargs: Any) -> "LightningDataModule": - obj = super().__new__(cls) - # track `DataHooks` calls - obj.prepare_data = cls._track_data_hook_calls(obj, obj.prepare_data) - obj.setup = cls._track_data_hook_calls(obj, obj.setup) - obj.teardown = cls._track_data_hook_calls(obj, obj.teardown) - - # calling this to ensure the `LightningDataModule` is initialized for all cases of inheritance, - # even if `super().__init__` hasn't been explicitly called in the class - LightningDataModule.__init__(obj) - return obj - - @staticmethod - def _track_data_hook_calls(obj: "LightningDataModule", fn: callable) -> callable: - """A decorator that checks if prepare_data/setup/teardown has been called. - - - When ``dm.prepare_data()`` is called, ``dm._has_prepared_data`` gets set to True - - When ``dm.setup()``, ``dm._has_setup_{fit,validate,test}`` get set to True - - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``. - Its corresponding `dm_has_setup_{stage}` attribute gets set to True - - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup`` - - Args: - obj: Object whose function will be tracked - fn: Function that will be tracked to see if it has been called. - - Returns: - Decorated function that tracks its call status and saves it to private attrs in its obj instance. - """ - - @functools.wraps(fn) - def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any: - name = fn.__name__ - has_run = False - - # If calling setup, we check the stage and assign stage-specific bool args - if name in ("setup", "teardown"): - - # Get stage either by grabbing from args or checking kwargs. - # If not provided, set call status of 'fit', 'validate', and 'test' to True. - # We do this so __attach_datamodule in trainer.py doesn't mistakenly call - # setup('test') on trainer.test() - stage = args[0] if len(args) else kwargs.get("stage", None) - - if stage is None: - has_run = True - for s in ("fit", "validate", "test"): - attr = f"_has_{name}_{s}" - has_run &= getattr(obj, attr) - setattr(obj, attr, True) - else: - attr = f"_has_{name}_{stage}" - has_run = getattr(obj, attr) - setattr(obj, attr, True) - - elif name == "prepare_data": - has_run = obj._has_prepared_data - obj._has_prepared_data = True - - if has_run: - rank_zero_deprecation( - f"DataModule.{name} has already been called, so it will not be called again. " - f"In v1.6 this behavior will change to always call DataModule.{name}." - ) - else: - fn(*args, **kwargs) - - return wrapped_fn - - def __getstate__(self) -> dict: - # avoids _pickle.PicklingError: Can't pickle <...>: it's not the same object as <...> - d = self.__dict__.copy() - for fn in ("prepare_data", "setup", "teardown"): - del d[fn] - return d diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 8f286964940d2..90c398087578d 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -140,7 +140,7 @@ def prepare_data(self) -> None: lightning_module = self.trainer.lightning_module # handle datamodule prepare data: # check for prepare_data_per_node & datamodule lifecycle properties before calling datamodule.prepare_data - if datamodule is not None and not datamodule._has_prepared_data: + if datamodule is not None: dm_prepare_data_per_node = datamodule.prepare_data_per_node dm_eq_prepare_data = datamodule.prepare_data_per_node == self.trainer.prepare_data_per_node if self.trainer.prepare_data_per_node is not None and not dm_eq_prepare_data: diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index e4925b604ac30..76e3834627860 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -46,7 +46,6 @@ def test_can_prepare_data(local_rank, node_rank): # prepare_data_per_node = True # local rank = 0 (True) dm.random_full = None - dm._has_prepared_data = False local_rank.return_value = 0 assert trainer.local_rank == 0 @@ -55,7 +54,6 @@ def test_can_prepare_data(local_rank, node_rank): # local rank = 1 (False) dm.random_full = None - dm._has_prepared_data = False local_rank.return_value = 1 assert trainer.local_rank == 1 @@ -65,7 +63,6 @@ def test_can_prepare_data(local_rank, node_rank): # prepare_data_per_node = False (prepare across all nodes) # global rank = 0 (True) dm.random_full = None - dm._has_prepared_data = False dm.prepare_data_per_node = False node_rank.return_value = 0 local_rank.return_value = 0 @@ -75,7 +72,6 @@ def test_can_prepare_data(local_rank, node_rank): # global rank = 1 (False) dm.random_full = None - dm._has_prepared_data = False node_rank.return_value = 1 local_rank.return_value = 0 @@ -98,13 +94,11 @@ def test_can_prepare_data(local_rank, node_rank): # is_overridden prepare data = True # has been called # False - dm._has_prepared_data = True trainer._data_connector.prepare_data() dm_mock.assert_not_called() # has not been called # True - dm._has_prepared_data = False trainer._data_connector.prepare_data() dm_mock.assert_called_once() @@ -485,36 +479,6 @@ class BoringDataModule2(LightningDataModule): assert BoringDataModule2(batch_size=32).prepare_data() is None assert BoringDataModule2(batch_size=32) == BoringDataModule2(batch_size=32) - # checking for all the different multilevel inhertiance scenarios, for init call on LightningDataModule - @dataclass - class BoringModuleBase1(LightningDataModule): - num_features: int - - class BoringModuleBase2(LightningDataModule): - def __init__(self, num_features: int): - self.num_features = num_features - - @dataclass - class BoringModuleDerived1(BoringModuleBase1): - ... - - class BoringModuleDerived2(BoringModuleBase1): - def __init__(self): - ... - - @dataclass - class BoringModuleDerived3(BoringModuleBase2): - ... - - class BoringModuleDerived4(BoringModuleBase2): - def __init__(self): - ... - - assert hasattr(BoringModuleDerived1(num_features=2), "_has_prepared_data") - assert hasattr(BoringModuleDerived2(), "_has_prepared_data") - assert hasattr(BoringModuleDerived3(), "_has_prepared_data") - assert hasattr(BoringModuleDerived4(), "_has_prepared_data") - def test_inconsistent_prepare_data_per_node(tmpdir): with pytest.raises(MisconfigurationException, match="Inconsistent settings found for `prepare_data_per_node`."): From 4585a640dc1c2287161884f614e218848b855581 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 4 Nov 2021 21:50:21 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/extensions/datamodules.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index 79d8ede164a96..b0e76d10f93c9 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -237,7 +237,7 @@ we assume all stages have been set-up. .. note:: ``{setup,teardown,prepare_data}`` call will be only called once for a specific stage. If the stage was ``None`` then we assume ``{fit,validate,test}`` have been called. For example, this means that - any duplicate ``dm.setup('fit')`` calls will be a no-op. + any duplicate ``dm.setup('fit')`` calls will be a no-op. train_dataloader From 2d2272ef36cf09cef7a091006e83a76aa8f6a460 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 5 Nov 2021 04:29:15 +0100 Subject: [PATCH 6/8] Missing docs and tests removal --- docs/source/extensions/datamodules.rst | 4 ---- tests/core/test_datamodules.py | 7 ------- 2 files changed, 11 deletions(-) diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index b0e76d10f93c9..7b8e2cd3754e6 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -234,10 +234,6 @@ we assume all stages have been set-up. .. note:: ``setup`` is called from every process. Setting state here is okay. .. note:: ``teardown`` can be used to clean up the state. It is also called from every process -.. note:: - ``{setup,teardown,prepare_data}`` call will be only called once for a specific stage. - If the stage was ``None`` then we assume ``{fit,validate,test}`` have been called. For example, this means that - any duplicate ``dm.setup('fit')`` calls will be a no-op. train_dataloader diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 76e3834627860..7fe3032058e2d 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -92,13 +92,6 @@ def test_can_prepare_data(local_rank, node_rank): with mock.patch.object(trainer.datamodule, "prepare_data") as dm_mock: # is_overridden prepare data = True - # has been called - # False - trainer._data_connector.prepare_data() - dm_mock.assert_not_called() - - # has not been called - # True trainer._data_connector.prepare_data() dm_mock.assert_called_once() From 1505295d129d52e749857970030f32646d907469 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 5 Nov 2021 04:31:11 +0100 Subject: [PATCH 7/8] Remove deprecation test --- tests/deprecated_api/test_remove_1-6.py | 53 +------------------------ 1 file changed, 1 insertion(+), 52 deletions(-) diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index ae8637ad1ef16..8905fe4209c54 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -33,7 +33,7 @@ from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.model_summary import ModelSummary from tests.deprecated_api import _soft_unimport_module -from tests.helpers import BoringDataModule, BoringModel +from tests.helpers import BoringModel def test_old_transfer_batch_to_device_hook(tmpdir): @@ -128,57 +128,6 @@ def training_step(self, *args): trainer.fit(TestModel()) -def test_v1_6_0_datamodule_hooks_calls(tmpdir): - """Test that repeated calls to DataHooks' hooks show a warning about the coming API change.""" - - class TestDataModule(BoringDataModule): - setup_calls = [] - teardown_calls = [] - prepare_data_calls = 0 - - def setup(self, stage=None): - super().setup(stage=stage) - self.setup_calls.append(stage) - - def teardown(self, stage=None): - super().teardown(stage=stage) - self.teardown_calls.append(stage) - - def prepare_data(self): - super().prepare_data() - self.prepare_data_calls += 1 - - dm = TestDataModule() - dm.prepare_data() - dm.prepare_data() - dm.setup("fit") - with pytest.deprecated_call( - match=r"DataModule.setup has already been called, so it will not be called again. " - "In v1.6 this behavior will change to always call DataModule.setup" - ): - dm.setup("fit") - dm.setup() - dm.setup() - dm.teardown("validate") - with pytest.deprecated_call( - match=r"DataModule.teardown has already been called, so it will not be called again. " - "In v1.6 this behavior will change to always call DataModule.teardown" - ): - dm.teardown("validate") - - assert dm.prepare_data_calls == 1 - assert dm.setup_calls == ["fit", None] - assert dm.teardown_calls == ["validate"] - - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) - trainer.test(BoringModel(), datamodule=dm) - - # same number of calls - assert dm.prepare_data_calls == 1 - assert dm.setup_calls == ["fit", None] - assert dm.teardown_calls == ["validate", "test"] - - def test_v1_6_0_is_overridden_model(): model = BoringModel() with pytest.deprecated_call(match="and will be removed in v1.6"): From ee3cc14045ac05912f838af604bcd48b4f671794 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 5 Nov 2021 05:28:40 +0100 Subject: [PATCH 8/8] Fix pre-commit --- pytorch_lightning/core/datamodule.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index ee8658a30c0b1..98e14b128b2ff 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """LightningDataModule for loading DataLoaders with ease.""" - -import functools from argparse import ArgumentParser, Namespace from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union