diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index e5889172b6a6a..afe9f119b6089 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -246,6 +246,14 @@ def restore_checkpoint_after_pre_dispatch(self) -> bool:
         """
         return False
 
+    @property
+    def lightning_restore_optimizer_and_schedulers(self) -> bool:
+        """
+        Override to disable Lightning restoring optimizers/schedulers.
+        This is useful for plugins which manage restoring optimizers/schedulers.
+        """
+        return True
+
     def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
         """
         Provide a hook to count optimizer step calls.
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 611946fd53dae..3bb2dbd3ea61e 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -221,7 +221,10 @@ def restore_loops(self) -> None:
 
     def restore_optimizers_and_schedulers(self) -> None:
         """Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint."""
-        if not self._loaded_checkpoint:
+        if (
+            not self._loaded_checkpoint
+            or not self.trainer.training_type_plugin.lightning_restore_optimizer_and_schedulers
+        ):
             return
 
         # validation
diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py
index a8894fe2dd00a..939c05d1b7afe 100644
--- a/tests/plugins/test_custom_plugin.py
+++ b/tests/plugins/test_custom_plugin.py
@@ -11,8 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+from typing import Any, Mapping
+
+import pytest
+import torch
+
 from pytorch_lightning import Trainer
-from pytorch_lightning.plugins import DDPPlugin
+from pytorch_lightning.plugins import DDPPlugin, SingleDevicePlugin
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
 
@@ -33,3 +39,31 @@ def test_sync_batchnorm_set(tmpdir):
     trainer = Trainer(max_epochs=1, plugins=[plugin], default_root_dir=tmpdir, sync_batchnorm=True)
     trainer.fit(model)
     assert plugin.sync_batchnorm is True
+
+
+@pytest.mark.parametrize("restore_optimizer_and_schedulers", [True, False])
+def test_plugin_lightning_restore_optimizer_and_schedulers(tmpdir, restore_optimizer_and_schedulers):
+    class TestPlugin(SingleDevicePlugin):
+        load_optimizer_state_dict_called = False
+
+        @property
+        def lightning_restore_optimizer_and_schedulers(self) -> bool:
+            return restore_optimizer_and_schedulers
+
+        def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
+            self.load_optimizer_state_dict_called = True
+
+    # create ckpt to resume from
+    checkpoint_path = os.path.join(tmpdir, "model.ckpt")
+    model = BoringModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    trainer.fit(model)
+    trainer.save_checkpoint(checkpoint_path)
+
+    model = BoringModel()
+    plugin = TestPlugin(torch.device("cpu"))
+    trainer = Trainer(
+        default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin, resume_from_checkpoint=checkpoint_path
+    )
+    trainer.fit(model)
+    assert plugin.load_optimizer_state_dict_called == restore_optimizer_and_schedulers
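
For reference, a minimal sketch (not part of the diff above) of how a custom plugin that manages its own checkpoint loading could use the new hook. The SelfRestoringPlugin name is hypothetical, and the comments describe assumptions about where such a plugin would perform its own restoration; only the lightning_restore_optimizer_and_schedulers property and load_optimizer_state_dict hook come from the diff.

from typing import Any, Mapping

from pytorch_lightning.plugins import SingleDevicePlugin


class SelfRestoringPlugin(SingleDevicePlugin):  # hypothetical example plugin
    @property
    def lightning_restore_optimizer_and_schedulers(self) -> bool:
        # Opt out: with False, the checkpoint connector's
        # restore_optimizers_and_schedulers() returns early and leaves
        # optimizer/scheduler restoration entirely to the plugin.
        return False

    def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        # Because the property above returns False, Lightning no longer drives
        # this call; the plugin is expected to restore optimizer state through
        # its own checkpoint-loading path instead (e.g. as engines such as
        # DeepSpeed do when loading their own checkpoints).
        pass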