From 6b70d177b6d3d05004a5a75f61922c9500694f4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 2 Feb 2023 22:29:33 +0100 Subject: [PATCH] Rename optimization loops (#16598) Co-authored-by: Jirka Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/lightning/pytorch/loops/__init__.py | 2 +- .../loops/epoch/training_epoch_loop.py | 22 +++++++++---------- src/lightning/pytorch/loops/fit_loop.py | 4 ++-- .../pytorch/loops/optimization/__init__.py | 4 ++-- .../{optimizer_loop.py => automatic.py} | 2 +- .../{manual_loop.py => manual.py} | 2 +- .../pytorch/tuner/batch_size_scaling.py | 4 ++-- .../pytorch/utilities/migration/migration.py | 21 ++++++++++++++---- .../core/test_lightning_optimizer.py | 2 +- .../loops/epoch/test_training_epoch_loop.py | 2 +- .../loops/optimization/test_manual_loop.py | 2 +- .../loops/optimization/test_optimizer_loop.py | 2 +- .../loops/test_evaluation_loop_flow.py | 8 +++---- .../loops/test_loop_state_dict.py | 8 +++---- tests/tests_pytorch/loops/test_loops.py | 22 +++++++++---------- .../tests_pytorch/loops/test_training_loop.py | 2 +- .../loops/test_training_loop_flow_scalar.py | 12 +++++----- tests/tests_pytorch/trainer/test_trainer.py | 2 +- .../utilities/migration/test_migration.py | 19 ++++++++++------ 19 files changed, 80 insertions(+), 62 deletions(-) rename src/lightning/pytorch/loops/optimization/{optimizer_loop.py => automatic.py} (99%) rename src/lightning/pytorch/loops/optimization/{manual_loop.py => manual.py} (98%) diff --git a/src/lightning/pytorch/loops/__init__.py b/src/lightning/pytorch/loops/__init__.py index 5c4443dbac68b..dd64810d789c2 100644 --- a/src/lightning/pytorch/loops/__init__.py +++ b/src/lightning/pytorch/loops/__init__.py @@ -15,4 +15,4 @@ from lightning.pytorch.loops.dataloader import _DataLoaderLoop, _EvaluationLoop, _PredictionLoop # noqa: F401 from lightning.pytorch.loops.epoch import _EvaluationEpochLoop, _PredictionEpochLoop, _TrainingEpochLoop # noqa: F401 from lightning.pytorch.loops.fit_loop import _FitLoop # noqa: F401 -from lightning.pytorch.loops.optimization import _ManualOptimization, _OptimizerLoop # noqa: F401 +from lightning.pytorch.loops.optimization import _AutomaticOptimization, _ManualOptimization # noqa: F401 diff --git a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py index 80af832451019..d36ad1a22f2f2 100644 --- a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py @@ -18,9 +18,9 @@ import torch from lightning.pytorch import loops # import as loops to avoid circular imports -from lightning.pytorch.loops.optimization import _ManualOptimization, _OptimizerLoop -from lightning.pytorch.loops.optimization.manual_loop import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE -from lightning.pytorch.loops.optimization.optimizer_loop import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE +from lightning.pytorch.loops.optimization import _AutomaticOptimization, _ManualOptimization +from lightning.pytorch.loops.optimization.automatic import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE +from lightning.pytorch.loops.optimization.manual import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE from lightning.pytorch.loops.progress import BatchProgress, SchedulerProgress from lightning.pytorch.loops.utilities import _is_max_limit_reached from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection @@ -68,8 +68,8 @@ def 
__init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None self.batch_progress = BatchProgress() self.scheduler_progress = SchedulerProgress() - self.optimizer_loop = _OptimizerLoop() - self.manual_loop = _ManualOptimization() + self.automatic_optimization = _AutomaticOptimization() + self.manual_optimization = _ManualOptimization() self.val_loop = loops._EvaluationLoop(verbose=False) @@ -96,8 +96,8 @@ def batch_idx(self) -> int: def global_step(self) -> int: lightning_module = self.trainer.lightning_module if lightning_module is None or lightning_module.automatic_optimization: - return self.optimizer_loop.optim_progress.optimizer_steps - return self.manual_loop.optim_step_progress.total.completed + return self.automatic_optimization.optim_progress.optimizer_steps + return self.manual_optimization.optim_step_progress.total.completed @property def _is_training_done(self) -> bool: @@ -146,7 +146,7 @@ def reset(self) -> None: if self.restarting: self.batch_progress.reset_on_restart() self.scheduler_progress.reset_on_restart() - self.optimizer_loop.optim_progress.reset_on_restart() + self.automatic_optimization.optim_progress.reset_on_restart() trainer = self.trainer if trainer.num_training_batches != float("inf"): @@ -159,7 +159,7 @@ def reset(self) -> None: else: self.batch_progress.reset_on_run() self.scheduler_progress.reset_on_run() - self.optimizer_loop.optim_progress.reset_on_run() + self.automatic_optimization.optim_progress.reset_on_run() # when the epoch starts, the total val batch progress should be reset as it's supposed to count the batches # seen per epoch, this is useful for tracking when validation is run multiple times per epoch self.val_loop.epoch_loop.batch_progress.total.reset() @@ -222,9 +222,9 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: with self.trainer.profiler.profile("run_training_batch"): if self.trainer.lightning_module.automatic_optimization: # in automatic optimization, there can only be one optimizer - batch_output = self.optimizer_loop.run(self.trainer.optimizers[0], kwargs) + batch_output = self.automatic_optimization.run(self.trainer.optimizers[0], kwargs) else: - batch_output = self.manual_loop.run(kwargs) + batch_output = self.manual_optimization.run(kwargs) self.batch_progress.increment_processed() diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index 06140dfaa0bd2..d68bb5d5b664a 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -131,12 +131,12 @@ def prefetch_batches(self) -> int: @property def _skip_backward(self) -> bool: """Determines whether the loop will skip backward during automatic optimization.""" - return self.epoch_loop.optimizer_loop._skip_backward + return self.epoch_loop.automatic_optimization._skip_backward @_skip_backward.setter def _skip_backward(self, value: bool) -> None: """Determines whether the loop will skip backward during automatic optimization.""" - self.epoch_loop.optimizer_loop._skip_backward = value + self.epoch_loop.automatic_optimization._skip_backward = value @property def _results(self) -> _ResultCollection: diff --git a/src/lightning/pytorch/loops/optimization/__init__.py b/src/lightning/pytorch/loops/optimization/__init__.py index 5fe6bf52609de..4ea5fdfe7e75c 100644 --- a/src/lightning/pytorch/loops/optimization/__init__.py +++ b/src/lightning/pytorch/loops/optimization/__init__.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations 
under the License. -from lightning.pytorch.loops.optimization.manual_loop import _ManualOptimization # noqa: F401 -from lightning.pytorch.loops.optimization.optimizer_loop import _OptimizerLoop # noqa: F401 +from lightning.pytorch.loops.optimization.automatic import _AutomaticOptimization # noqa: F401 +from lightning.pytorch.loops.optimization.manual import _ManualOptimization # noqa: F401 diff --git a/src/lightning/pytorch/loops/optimization/optimizer_loop.py b/src/lightning/pytorch/loops/optimization/automatic.py similarity index 99% rename from src/lightning/pytorch/loops/optimization/optimizer_loop.py rename to src/lightning/pytorch/loops/optimization/automatic.py index 4cc25f8649892..7610990234b6f 100644 --- a/src/lightning/pytorch/loops/optimization/optimizer_loop.py +++ b/src/lightning/pytorch/loops/optimization/automatic.py @@ -144,7 +144,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]: _OUTPUTS_TYPE = Dict[str, Any] -class _OptimizerLoop(_Loop): +class _AutomaticOptimization(_Loop): """Performs automatic optimization (forward, zero grad, backward, optimizer step)""" output_result_cls = ClosureResult diff --git a/src/lightning/pytorch/loops/optimization/manual_loop.py b/src/lightning/pytorch/loops/optimization/manual.py similarity index 98% rename from src/lightning/pytorch/loops/optimization/manual_loop.py rename to src/lightning/pytorch/loops/optimization/manual.py index baa81ae227e14..7b64dcee6b349 100644 --- a/src/lightning/pytorch/loops/optimization/manual_loop.py +++ b/src/lightning/pytorch/loops/optimization/manual.py @@ -28,7 +28,7 @@ @dataclass class ManualResult(OutputResult): - """A container to hold the result returned by the ``ManualLoop``. + """A container to hold the result returned by ``_ManualOptimization``. It is created from the output of :meth:`~lightning.pytorch.core.module.LightningModule.training_step`. diff --git a/src/lightning/pytorch/tuner/batch_size_scaling.py b/src/lightning/pytorch/tuner/batch_size_scaling.py index 6f206209cd4d3..5386167862202 100644 --- a/src/lightning/pytorch/tuner/batch_size_scaling.py +++ b/src/lightning/pytorch/tuner/batch_size_scaling.py @@ -359,8 +359,8 @@ def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: def _reset_progress(trainer: "pl.Trainer") -> None: if trainer.lightning_module.automatic_optimization: - trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.reset() + trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.reset() else: - trainer.fit_loop.epoch_loop.manual_loop.optim_step_progress.reset() + trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.reset() trainer.fit_loop.epoch_progress.reset() diff --git a/src/lightning/pytorch/utilities/migration/migration.py b/src/lightning/pytorch/utilities/migration/migration.py index 0b43b2d21fc95..e7433f992c76f 100644 --- a/src/lightning/pytorch/utilities/migration/migration.py +++ b/src/lightning/pytorch/utilities/migration/migration.py @@ -268,13 +268,26 @@ def _migrate_loop_structure_after_optimizer_loop_removal(checkpoint: _CHECKPOINT resuming the loop. 
Version: 2.0.0 - Commit: TBD - PR: TBD + Commit: 6a56586 + PR: #16539, #16598 """ if "loops" not in checkpoint: return checkpoint - # TODO: Complete this migration function when optimizer loop gets flattened out and keys need to be remapped fit_loop = checkpoint["loops"]["fit_loop"] - fit_loop["epoch_loop.optimizer_loop.optim_progress"].pop("optimizer_position", None) + # optimizer_position is no longer used + if "epoch_loop.optimizer_loop.optim_progress" in fit_loop: + fit_loop["epoch_loop.optimizer_loop.optim_progress"].pop("optimizer_position", None) + + # the subloop attribute names have changed + if "epoch_loop.optimizer_loop.state_dict" in fit_loop: + fit_loop["epoch_loop.automatic_optimization.state_dict"] = fit_loop.pop("epoch_loop.optimizer_loop.state_dict") + fit_loop["epoch_loop.automatic_optimization.optim_progress"] = fit_loop.pop( + "epoch_loop.optimizer_loop.optim_progress" + ) + if "epoch_loop.manual_loop.state_dict" in fit_loop: + fit_loop["epoch_loop.manual_optimization.state_dict"] = fit_loop.pop("epoch_loop.manual_loop.state_dict") + fit_loop["epoch_loop.manual_optimization.optim_step_progress"] = fit_loop.pop( + "epoch_loop.manual_loop.optim_step_progress" + ) return checkpoint diff --git a/tests/tests_pytorch/core/test_lightning_optimizer.py b/tests/tests_pytorch/core/test_lightning_optimizer.py index a95d2855c032d..901d6191b62e9 100644 --- a/tests/tests_pytorch/core/test_lightning_optimizer.py +++ b/tests/tests_pytorch/core/test_lightning_optimizer.py @@ -20,7 +20,7 @@ from lightning.pytorch import Trainer from lightning.pytorch.core.optimizer import LightningOptimizer from lightning.pytorch.demos.boring_classes import BoringModel -from lightning.pytorch.loops.optimization.optimizer_loop import Closure +from lightning.pytorch.loops.optimization.automatic import Closure from lightning.pytorch.tuner.tuning import Tuner diff --git a/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py b/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py index d9ed81ef8e2ae..91ce9f3921c4e 100644 --- a/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py +++ b/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py @@ -66,7 +66,7 @@ def test_should_stop_early_stopping_conditions_not_met( trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0) trainer.num_training_batches = 10 trainer.should_stop = True - trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = global_step + trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = global_step trainer.fit_loop.epoch_loop.batch_progress.current.ready = global_step trainer.fit_loop.epoch_progress.current.completed = current_epoch - 1 diff --git a/tests/tests_pytorch/loops/optimization/test_manual_loop.py b/tests/tests_pytorch/loops/optimization/test_manual_loop.py index 8ef6826c263be..b994df631bdfe 100644 --- a/tests/tests_pytorch/loops/optimization/test_manual_loop.py +++ b/tests/tests_pytorch/loops/optimization/test_manual_loop.py @@ -16,7 +16,7 @@ from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel -from lightning.pytorch.loops.optimization.manual_loop import ManualResult +from lightning.pytorch.loops.optimization.manual import ManualResult from lightning.pytorch.utilities.exceptions import MisconfigurationException diff --git a/tests/tests_pytorch/loops/optimization/test_optimizer_loop.py b/tests/tests_pytorch/loops/optimization/test_optimizer_loop.py index 
4f7cbdbd0ca5a..de2a34e48e3e9 100644 --- a/tests/tests_pytorch/loops/optimization/test_optimizer_loop.py +++ b/tests/tests_pytorch/loops/optimization/test_optimizer_loop.py @@ -17,7 +17,7 @@ from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel -from lightning.pytorch.loops.optimization.optimizer_loop import ClosureResult +from lightning.pytorch.loops.optimization.automatic import ClosureResult from lightning.pytorch.utilities.exceptions import MisconfigurationException diff --git a/tests/tests_pytorch/loops/test_evaluation_loop_flow.py b/tests/tests_pytorch/loops/test_evaluation_loop_flow.py index 757f7f3306310..1aebb6c30e51e 100644 --- a/tests/tests_pytorch/loops/test_evaluation_loop_flow.py +++ b/tests/tests_pytorch/loops/test_evaluation_loop_flow.py @@ -65,13 +65,13 @@ def backward(self, loss): # simulate training manually trainer.state.stage = RunningStage.TRAINING kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0} - train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs) + train_step_out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs) assert isinstance(train_step_out["loss"], Tensor) assert train_step_out["loss"].item() == 171 # make sure the optimizer closure returns the correct things - opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0]) + opt_closure = trainer.fit_loop.epoch_loop.automatic_optimization._make_closure(kwargs, trainer.optimizers[0]) opt_closure_result = opt_closure() assert opt_closure_result.item() == 171 @@ -124,13 +124,13 @@ def backward(self, loss): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0} - train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs) + train_step_out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs) assert isinstance(train_step_out["loss"], Tensor) assert train_step_out["loss"].item() == 171 # make sure the optimizer closure returns the correct things - opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0]) + opt_closure = trainer.fit_loop.epoch_loop.automatic_optimization._make_closure(kwargs, trainer.optimizers[0]) opt_closure_result = opt_closure() assert opt_closure_result.item() == 171 diff --git a/tests/tests_pytorch/loops/test_loop_state_dict.py b/tests/tests_pytorch/loops/test_loop_state_dict.py index ffebc6cd37cc9..ad2b539931be4 100644 --- a/tests/tests_pytorch/loops/test_loop_state_dict.py +++ b/tests/tests_pytorch/loops/test_loop_state_dict.py @@ -49,13 +49,13 @@ def test_loops_state_dict_structure(): "total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}, }, - "epoch_loop.manual_loop.state_dict": {}, - "epoch_loop.manual_loop.optim_step_progress": { + "epoch_loop.manual_optimization.state_dict": {}, + "epoch_loop.manual_optimization.optim_step_progress": { "total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}, }, - "epoch_loop.optimizer_loop.state_dict": {}, - "epoch_loop.optimizer_loop.optim_progress": { + "epoch_loop.automatic_optimization.state_dict": {}, + "epoch_loop.automatic_optimization.optim_progress": { "optimizer": { "step": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}}, "zero_grad": { diff --git 
a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index ee3ea9f4da6ea..a885d2de9d16b 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -300,7 +300,7 @@ def training_step(self, batch, batch_idx): assert os.path.exists(ckpt_path) checkpoint = torch.load(ckpt_path) - optim_progress = trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress + optim_progress = trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress sch_progress = trainer.fit_loop.epoch_loop.scheduler_progress # `nbe_`: non-breaking epoch, as in, no exception will be raised. `be_`: breaking epoch @@ -365,13 +365,13 @@ def training_step(self, batch, batch_idx): "total": {"ready": nbe_sch_steps + be_sch_steps, "completed": nbe_sch_steps + be_sch_steps}, "current": {"ready": be_sch_steps, "completed": be_sch_steps}, }, - "epoch_loop.manual_loop.state_dict": ANY, - "epoch_loop.manual_loop.optim_step_progress": { + "epoch_loop.manual_optimization.state_dict": ANY, + "epoch_loop.manual_optimization.optim_step_progress": { "total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}, }, - "epoch_loop.optimizer_loop.state_dict": {}, - "epoch_loop.optimizer_loop.optim_progress": { + "epoch_loop.automatic_optimization.state_dict": {}, + "epoch_loop.automatic_optimization.optim_progress": { "optimizer": { "step": { "total": { @@ -423,7 +423,7 @@ def training_step(self, batch, batch_idx): assert batch_progress.current.ready == be_batches_completed assert batch_progress.current.completed == be_batches_completed - optim_progress = trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress + optim_progress = trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress assert optim_progress.optimizer.step.current.ready == be_total_opt_steps assert optim_progress.optimizer.step.current.completed == be_total_opt_steps assert optim_progress.optimizer.zero_grad.current.ready == be_total_zero_grad @@ -503,13 +503,13 @@ def train_dataloader(self): "total": {"ready": n_sch_steps_total, "completed": n_sch_steps_total}, "current": {"ready": n_sch_steps_current, "completed": n_sch_steps_current}, }, - "epoch_loop.manual_loop.state_dict": ANY, - "epoch_loop.manual_loop.optim_step_progress": { + "epoch_loop.manual_optimization.state_dict": ANY, + "epoch_loop.manual_optimization.optim_step_progress": { "total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}, }, - "epoch_loop.optimizer_loop.state_dict": {}, - "epoch_loop.optimizer_loop.optim_progress": { + "epoch_loop.automatic_optimization.state_dict": {}, + "epoch_loop.automatic_optimization.optim_progress": { "optimizer": { "step": { "total": { @@ -568,7 +568,7 @@ def test_fit_loop_reset(tmpdir): mid_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=2.ckpt")) fit_loop = trainer.fit_loop epoch_loop = fit_loop.epoch_loop - optimizer_loop = epoch_loop.optimizer_loop + optimizer_loop = epoch_loop.automatic_optimization assert not fit_loop.restarting assert not epoch_loop.restarting assert not optimizer_loop.restarting diff --git a/tests/tests_pytorch/loops/test_training_loop.py b/tests/tests_pytorch/loops/test_training_loop.py index 98f86fb3dfc61..5b6b6c56229b8 100644 --- a/tests/tests_pytorch/loops/test_training_loop.py +++ b/tests/tests_pytorch/loops/test_training_loop.py @@ -217,7 +217,7 @@ def test_should_stop_early_stopping_conditions_met( trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0, max_epochs=100) 
trainer.num_training_batches = 10 trainer.should_stop = True - trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = ( + trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = ( current_epoch * trainer.num_training_batches ) trainer.fit_loop.epoch_loop.batch_progress.current.ready = 10 diff --git a/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py b/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py index 2f27580ae2015..8381f047d0530 100644 --- a/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py +++ b/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py @@ -20,7 +20,7 @@ from lightning.pytorch import Trainer from lightning.pytorch.core.module import LightningModule from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset -from lightning.pytorch.loops.optimization.optimizer_loop import Closure +from lightning.pytorch.loops.optimization.automatic import Closure from lightning.pytorch.trainer.states import RunningStage from tests_pytorch.helpers.deterministic_model import DeterministicModel @@ -147,13 +147,13 @@ def backward(self, loss): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0} - train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs) + train_step_out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs) assert isinstance(train_step_out["loss"], Tensor) assert train_step_out["loss"].item() == 171 # make sure the optimizer closure returns the correct things - opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0]) + opt_closure = trainer.fit_loop.epoch_loop.automatic_optimization._make_closure(kwargs, trainer.optimizers[0]) opt_closure_result = opt_closure() assert opt_closure_result.item() == 171 @@ -215,13 +215,13 @@ def backward(self, loss): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0} - train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs) + train_step_out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs) assert isinstance(train_step_out["loss"], Tensor) assert train_step_out["loss"].item() == 171 # make sure the optimizer closure returns the correct things - opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0]) + opt_closure = trainer.fit_loop.epoch_loop.automatic_optimization._make_closure(kwargs, trainer.optimizers[0]) opt_closure_result = opt_closure() assert opt_closure_result.item() == 171 @@ -298,7 +298,7 @@ def training_step(self, batch, batch_idx): # manually check a few batches for batch_idx, batch in enumerate(model.train_dataloader()): kwargs = {"batch": batch, "batch_idx": batch_idx} - out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs) + out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs) if not batch_idx % 2: assert out == {} diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 7fb522ef6f990..490edf044f355 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -339,7 +339,7 @@ def 
mock_save_function(filepath, *args): # emulate callback's calls during the training for i, loss in enumerate(losses, 1): # sets `trainer.global_step` - trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = i + trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = i trainer.callback_metrics.update({"checkpoint_on": torch.tensor(loss)}) checkpoint_callback.on_validation_end(trainer, trainer.lightning_module) trainer.fit_loop.epoch_progress.current.completed = i # sets `trainer.current_epoch` diff --git a/tests/tests_pytorch/utilities/migration/test_migration.py b/tests/tests_pytorch/utilities/migration/test_migration.py index f901f4f02842f..68834d328ad8c 100644 --- a/tests/tests_pytorch/utilities/migration/test_migration.py +++ b/tests/tests_pytorch/utilities/migration/test_migration.py @@ -186,10 +186,10 @@ def test_migrate_loop_structure_after_tbptt_removal(): assert updated_checkpoint["loops"] == { "fit_loop": { "epoch_loop.state_dict": {"any": "state", "old_batch_loop_state_dict": old_batch_loop_state}, - "epoch_loop.optimizer_loop.state_dict": state_automatic, - "epoch_loop.optimizer_loop.optim_progress": optim_progress_automatic, - "epoch_loop.manual_loop.state_dict": state_manual, - "epoch_loop.manual_loop.optim_step_progress": optim_progress_manual, + "epoch_loop.automatic_optimization.state_dict": state_automatic, + "epoch_loop.automatic_optimization.optim_progress": optim_progress_automatic, + "epoch_loop.manual_optimization.state_dict": state_manual, + "epoch_loop.manual_optimization.optim_step_progress": optim_progress_manual, } } @@ -198,10 +198,12 @@ def test_migrate_loop_structure_after_optimizer_loop_removal(): """Test the loop state migration after multiple optimizer support in automatic optimization was removed in 2.0.0.""" state_automatic = MagicMock() + state_manual = MagicMock() optim_progress_automatic = { "optimizer": MagicMock(), "optimizer_position": 33, } + optim_progress_manual = MagicMock() old_checkpoint = { "loops": { "fit_loop": { @@ -209,6 +211,8 @@ def test_migrate_loop_structure_after_optimizer_loop_removal(): "epoch_loop.batch_loop.state_dict": MagicMock(), "epoch_loop.batch_loop.optimizer_loop.state_dict": state_automatic, "epoch_loop.batch_loop.optimizer_loop.optim_progress": optim_progress_automatic, + "epoch_loop.batch_loop.manual_loop.state_dict": state_manual, + "epoch_loop.batch_loop.manual_loop.optim_step_progress": optim_progress_manual, } } } @@ -217,8 +221,9 @@ def test_migrate_loop_structure_after_optimizer_loop_removal(): assert updated_checkpoint["loops"] == { "fit_loop": { "epoch_loop.state_dict": ANY, - "epoch_loop.optimizer_loop.state_dict": state_automatic, - # optimizer_position gets dropped: - "epoch_loop.optimizer_loop.optim_progress": {"optimizer": ANY}, + "epoch_loop.automatic_optimization.state_dict": state_automatic, + "epoch_loop.automatic_optimization.optim_progress": {"optimizer": ANY}, # optimizer_position gets dropped + "epoch_loop.manual_optimization.state_dict": state_manual, + "epoch_loop.manual_optimization.optim_step_progress": optim_progress_manual, } }
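
Note for reviewers (not part of the patch): because the subloops are renamed, checkpoints written before this change store their loop state under the old `epoch_loop.optimizer_loop.*` / `epoch_loop.manual_loop.*` keys, and the extended `_migrate_loop_structure_after_optimizer_loop_removal` remaps them on load. Below is a minimal, standalone Python sketch of that remapping for illustration only; the dictionary contents are invented, and only the key names and the drop of `optimizer_position` come from the patch.

    # Illustrative sketch of the checkpoint key migration added in this patch.
    # The values below are placeholders; only the keys mirror the real loop state.
    old_checkpoint = {
        "loops": {
            "fit_loop": {
                "epoch_loop.state_dict": {},
                "epoch_loop.optimizer_loop.state_dict": {"example": "state"},
                "epoch_loop.optimizer_loop.optim_progress": {
                    "optimizer": {"step": {"total": {"ready": 3, "completed": 3}}},
                    "optimizer_position": 0,  # no longer used; dropped by the migration
                },
                "epoch_loop.manual_loop.state_dict": {},
                "epoch_loop.manual_loop.optim_step_progress": {"total": {"ready": 0, "completed": 0}},
            }
        }
    }

    def migrate(checkpoint):
        """Re-implementation of the remapping, for illustration only."""
        if "loops" not in checkpoint:
            return checkpoint
        fit_loop = checkpoint["loops"]["fit_loop"]
        # optimizer_position is no longer used
        if "epoch_loop.optimizer_loop.optim_progress" in fit_loop:
            fit_loop["epoch_loop.optimizer_loop.optim_progress"].pop("optimizer_position", None)
        # the subloop attribute names have changed
        if "epoch_loop.optimizer_loop.state_dict" in fit_loop:
            fit_loop["epoch_loop.automatic_optimization.state_dict"] = fit_loop.pop(
                "epoch_loop.optimizer_loop.state_dict"
            )
            fit_loop["epoch_loop.automatic_optimization.optim_progress"] = fit_loop.pop(
                "epoch_loop.optimizer_loop.optim_progress"
            )
        if "epoch_loop.manual_loop.state_dict" in fit_loop:
            fit_loop["epoch_loop.manual_optimization.state_dict"] = fit_loop.pop("epoch_loop.manual_loop.state_dict")
            fit_loop["epoch_loop.manual_optimization.optim_step_progress"] = fit_loop.pop(
                "epoch_loop.manual_loop.optim_step_progress"
            )
        return checkpoint

    migrated = migrate(old_checkpoint)
    fit_loop = migrated["loops"]["fit_loop"]
    assert "epoch_loop.automatic_optimization.optim_progress" in fit_loop
    assert "optimizer_position" not in fit_loop["epoch_loop.automatic_optimization.optim_progress"]
    assert "epoch_loop.manual_optimization.state_dict" in fit_loop

In Lightning itself this remapping runs through the checkpoint migration utilities rather than being called directly, so older checkpoints keep resuming without manual intervention.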