Add property to skip restoring optimizers and schedulers via plugin (#…
Sean Naren authored Jul 31, 2021
1 parent 1f01db8 commit 7a1e972
Showing 3 changed files with 47 additions and 2 deletions.
@@ -246,6 +246,14 @@ def restore_checkpoint_after_pre_dispatch(self) -> bool:
"""
return False

@property
def lightning_restore_optimizer_and_schedulers(self) -> bool:
"""
Override to disable Lightning restoring optimizers/schedulers.
This is useful for plugins which manage restoring optimizers/schedulers.
"""
return True

def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
"""
Provide a hook to count optimizer step calls.
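A usage sketch (not part of this commit's diff): a plugin that manages its own checkpoint loading can override the new property to opt out of Lightning's restore. SingleDevicePlugin and the property name come from this change; the plugin class below is hypothetical.

from pytorch_lightning.plugins import SingleDevicePlugin


class SelfRestoringPlugin(SingleDevicePlugin):
    """Hypothetical plugin that restores optimizer/scheduler state through its own checkpoint loading."""

    @property
    def lightning_restore_optimizer_and_schedulers(self) -> bool:
        # Returning False tells the checkpoint connector to skip
        # restore_optimizers_and_schedulers(); restoring that state
        # is then the plugin's responsibility.
        return False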
5 changes: 4 additions & 1 deletion pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -221,7 +221,10 @@ def restore_loops(self) -> None:

    def restore_optimizers_and_schedulers(self) -> None:
        """Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint."""
-       if not self._loaded_checkpoint:
+       if (
+           not self._loaded_checkpoint
+           or not self.trainer.training_type_plugin.lightning_restore_optimizer_and_schedulers
+       ):
            return

        # validation
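Seen from the Trainer side, a quick check (again a sketch, reusing the hypothetical SelfRestoringPlugin above; Trainer.training_type_plugin is the same accessor the new guard uses):

import torch

from pytorch_lightning import Trainer

trainer = Trainer(plugins=SelfRestoringPlugin(torch.device("cpu")), fast_dev_run=True)

# This is the flag the new guard consults: with the override above,
# restore_optimizers_and_schedulers() returns early on resume and leaves
# optimizer/scheduler state to the plugin.
assert trainer.training_type_plugin.lightning_restore_optimizer_and_schedulers is False

# Stock plugins keep the default (True), so Lightning restores as before.
assert Trainer(fast_dev_run=True).training_type_plugin.lightning_restore_optimizer_and_schedulers is True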
36 changes: 35 additions & 1 deletion tests/plugins/test_custom_plugin.py
@@ -11,8 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Mapping

import pytest
import torch

from pytorch_lightning import Trainer
-from pytorch_lightning.plugins import DDPPlugin
+from pytorch_lightning.plugins import DDPPlugin, SingleDevicePlugin
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf

@@ -33,3 +39,31 @@ def test_sync_batchnorm_set(tmpdir):
    trainer = Trainer(max_epochs=1, plugins=[plugin], default_root_dir=tmpdir, sync_batchnorm=True)
    trainer.fit(model)
    assert plugin.sync_batchnorm is True


@pytest.mark.parametrize("restore_optimizer_and_schedulers", [True, False])
def test_plugin_lightning_restore_optimizer_and_schedulers(tmpdir, restore_optimizer_and_schedulers):
    class TestPlugin(SingleDevicePlugin):
        load_optimizer_state_dict_called = False

        @property
        def lightning_restore_optimizer_and_schedulers(self) -> bool:
            return restore_optimizer_and_schedulers

        def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
            self.load_optimizer_state_dict_called = True

    # create ckpt to resume from
    checkpoint_path = os.path.join(tmpdir, "model.ckpt")
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model)
    trainer.save_checkpoint(checkpoint_path)

    model = BoringModel()
    plugin = TestPlugin(torch.device("cpu"))
    trainer = Trainer(
        default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin, resume_from_checkpoint=checkpoint_path
    )
    trainer.fit(model)
    assert plugin.load_optimizer_state_dict_called == restore_optimizer_and_schedulers
