Rename optimization loops (#16598)
Co-authored-by: Jirka <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
3 people authored Feb 2, 2023
1 parent 6f7276b commit 6b70d17
Showing 19 changed files with 80 additions and 62 deletions.
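
This commit renames the training epoch loop's sub-loop attributes: `optimizer_loop` becomes `automatic_optimization` and `manual_loop` becomes `manual_optimization`, with the backing modules moving from `optimizer_loop.py`/`manual_loop.py` to `automatic.py`/`manual.py`. Code that reaches into the Trainer's loops, for example a custom callback reading progress counters, has to follow the rename. A minimal sketch under those assumptions; the callback itself is hypothetical and not part of this commit:

from lightning.pytorch import Callback, Trainer


class LoopProgressInspector(Callback):
    """Hypothetical callback that reads the renamed sub-loop progress counters."""

    def on_train_epoch_end(self, trainer: Trainer, pl_module) -> None:
        epoch_loop = trainer.fit_loop.epoch_loop
        # Before this commit these attributes were `epoch_loop.optimizer_loop`
        # and `epoch_loop.manual_loop`.
        if pl_module.automatic_optimization:
            steps = epoch_loop.automatic_optimization.optim_progress.optimizer_steps
        else:
            steps = epoch_loop.manual_optimization.optim_step_progress.total.completed
        print(f"optimizer steps completed so far: {steps}")
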
2 changes: 1 addition & 1 deletion src/lightning/pytorch/loops/__init__.py
@@ -15,4 +15,4 @@
from lightning.pytorch.loops.dataloader import _DataLoaderLoop, _EvaluationLoop, _PredictionLoop # noqa: F401
from lightning.pytorch.loops.epoch import _EvaluationEpochLoop, _PredictionEpochLoop, _TrainingEpochLoop # noqa: F401
from lightning.pytorch.loops.fit_loop import _FitLoop # noqa: F401
from lightning.pytorch.loops.optimization import _ManualOptimization, _OptimizerLoop # noqa: F401
from lightning.pytorch.loops.optimization import _AutomaticOptimization, _ManualOptimization # noqa: F401
22 changes: 11 additions & 11 deletions src/lightning/pytorch/loops/epoch/training_epoch_loop.py
@@ -18,9 +18,9 @@
import torch

from lightning.pytorch import loops # import as loops to avoid circular imports
from lightning.pytorch.loops.optimization import _ManualOptimization, _OptimizerLoop
from lightning.pytorch.loops.optimization.manual_loop import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE
from lightning.pytorch.loops.optimization.optimizer_loop import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE
from lightning.pytorch.loops.optimization import _AutomaticOptimization, _ManualOptimization
from lightning.pytorch.loops.optimization.automatic import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE
from lightning.pytorch.loops.optimization.manual import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE
from lightning.pytorch.loops.progress import BatchProgress, SchedulerProgress
from lightning.pytorch.loops.utilities import _is_max_limit_reached
from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
@@ -68,8 +68,8 @@ def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None
self.batch_progress = BatchProgress()
self.scheduler_progress = SchedulerProgress()

self.optimizer_loop = _OptimizerLoop()
self.manual_loop = _ManualOptimization()
self.automatic_optimization = _AutomaticOptimization()
self.manual_optimization = _ManualOptimization()

self.val_loop = loops._EvaluationLoop(verbose=False)

@@ -96,8 +96,8 @@ def batch_idx(self) -> int:
def global_step(self) -> int:
lightning_module = self.trainer.lightning_module
if lightning_module is None or lightning_module.automatic_optimization:
return self.optimizer_loop.optim_progress.optimizer_steps
return self.manual_loop.optim_step_progress.total.completed
return self.automatic_optimization.optim_progress.optimizer_steps
return self.manual_optimization.optim_step_progress.total.completed

@property
def _is_training_done(self) -> bool:
@@ -146,7 +146,7 @@ def reset(self) -> None:
if self.restarting:
self.batch_progress.reset_on_restart()
self.scheduler_progress.reset_on_restart()
self.optimizer_loop.optim_progress.reset_on_restart()
self.automatic_optimization.optim_progress.reset_on_restart()

trainer = self.trainer
if trainer.num_training_batches != float("inf"):
@@ -159,7 +159,7 @@ def reset(self) -> None:
else:
self.batch_progress.reset_on_run()
self.scheduler_progress.reset_on_run()
self.optimizer_loop.optim_progress.reset_on_run()
self.automatic_optimization.optim_progress.reset_on_run()
# when the epoch starts, the total val batch progress should be reset as it's supposed to count the batches
# seen per epoch, this is useful for tracking when validation is run multiple times per epoch
self.val_loop.epoch_loop.batch_progress.total.reset()
@@ -222,9 +222,9 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:
with self.trainer.profiler.profile("run_training_batch"):
if self.trainer.lightning_module.automatic_optimization:
# in automatic optimization, there can only be one optimizer
batch_output = self.optimizer_loop.run(self.trainer.optimizers[0], kwargs)
batch_output = self.automatic_optimization.run(self.trainer.optimizers[0], kwargs)
else:
batch_output = self.manual_loop.run(kwargs)
batch_output = self.manual_optimization.run(kwargs)

self.batch_progress.increment_processed()

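For context on the branch in `advance()` above: which renamed sub-loop runs is controlled by the `LightningModule.automatic_optimization` flag, so the epoch loop calls `automatic_optimization.run(optimizer, kwargs)` or `manual_optimization.run(kwargs)` accordingly. A minimal sketch of a module that opts into manual optimization and therefore exercises the `manual_optimization` loop; the model itself is an illustrative stand-in:

import torch
from lightning.pytorch import LightningModule


class ManualOptimizationModel(LightningModule):
    """Hypothetical module driving the renamed `manual_optimization` sub-loop."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # With this flag off, the training epoch loop takes the manual branch.
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
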
4 changes: 2 additions & 2 deletions src/lightning/pytorch/loops/fit_loop.py
@@ -131,12 +131,12 @@ def prefetch_batches(self) -> int:
@property
def _skip_backward(self) -> bool:
"""Determines whether the loop will skip backward during automatic optimization."""
return self.epoch_loop.optimizer_loop._skip_backward
return self.epoch_loop.automatic_optimization._skip_backward

@_skip_backward.setter
def _skip_backward(self, value: bool) -> None:
"""Determines whether the loop will skip backward during automatic optimization."""
self.epoch_loop.optimizer_loop._skip_backward = value
self.epoch_loop.automatic_optimization._skip_backward = value

@property
def _results(self) -> _ResultCollection:
4 changes: 2 additions & 2 deletions src/lightning/pytorch/loops/optimization/__init__.py
@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from lightning.pytorch.loops.optimization.manual_loop import _ManualOptimization # noqa: F401
from lightning.pytorch.loops.optimization.optimizer_loop import _OptimizerLoop # noqa: F401
from lightning.pytorch.loops.optimization.automatic import _AutomaticOptimization # noqa: F401
from lightning.pytorch.loops.optimization.manual import _ManualOptimization # noqa: F401
src/lightning/pytorch/loops/optimization/{optimizer_loop.py → automatic.py} (renamed)
@@ -144,7 +144,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
_OUTPUTS_TYPE = Dict[str, Any]


class _OptimizerLoop(_Loop):
class _AutomaticOptimization(_Loop):
"""Performs automatic optimization (forward, zero grad, backward, optimizer step)"""

output_result_cls = ClosureResult
src/lightning/pytorch/loops/optimization/{manual_loop.py → manual.py} (renamed)
@@ -28,7 +28,7 @@

@dataclass
class ManualResult(OutputResult):
"""A container to hold the result returned by the ``ManualLoop``.
"""A container to hold the result returned by ``_ManualOptimization``.
It is created from the output of :meth:`~lightning.pytorch.core.module.LightningModule.training_step`.
4 changes: 2 additions & 2 deletions src/lightning/pytorch/tuner/batch_size_scaling.py
@@ -359,8 +359,8 @@ def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:

def _reset_progress(trainer: "pl.Trainer") -> None:
if trainer.lightning_module.automatic_optimization:
trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.reset()
trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.reset()
else:
trainer.fit_loop.epoch_loop.manual_loop.optim_step_progress.reset()
trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.reset()

trainer.fit_loop.epoch_progress.reset()
21 changes: 17 additions & 4 deletions src/lightning/pytorch/utilities/migration/migration.py
@@ -268,13 +268,26 @@ def _migrate_loop_structure_after_optimizer_loop_removal(checkpoint: _CHECKPOINT
resuming the loop.
Version: 2.0.0
Commit: TBD
PR: TBD
Commit: 6a56586
PR: #16539, #16598
"""
if "loops" not in checkpoint:
return checkpoint

# TODO: Complete this migration function when optimizer loop gets flattened out and keys need to be remapped
fit_loop = checkpoint["loops"]["fit_loop"]
fit_loop["epoch_loop.optimizer_loop.optim_progress"].pop("optimizer_position", None)
# optimizer_position is no longer used
if "epoch_loop.optimizer_loop.optim_progress" in fit_loop:
fit_loop["epoch_loop.optimizer_loop.optim_progress"].pop("optimizer_position", None)

# the subloop attribute names have changed
if "epoch_loop.optimizer_loop.state_dict" in fit_loop:
fit_loop["epoch_loop.automatic_optimization.state_dict"] = fit_loop.pop("epoch_loop.optimizer_loop.state_dict")
fit_loop["epoch_loop.automatic_optimization.optim_progress"] = fit_loop.pop(
"epoch_loop.optimizer_loop.optim_progress"
)
if "epoch_loop.manual_loop.state_dict" in fit_loop:
fit_loop["epoch_loop.manual_optimization.state_dict"] = fit_loop.pop("epoch_loop.manual_loop.state_dict")
fit_loop["epoch_loop.manual_optimization.optim_step_progress"] = fit_loop.pop(
"epoch_loop.manual_loop.optim_step_progress"
)
return checkpoint
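
The effect of this migration is easiest to see on a concrete checkpoint fragment. A minimal sketch, assuming a pared-down `loops` section containing only the keys the migration touches; the private helper is imported here purely for illustration, since in normal use the checkpoint upgrade utilities apply it when an older checkpoint is loaded:

from lightning.pytorch.utilities.migration.migration import (
    _migrate_loop_structure_after_optimizer_loop_removal,
)

# Checkpoint fragment written with the old sub-loop attribute names.
checkpoint = {
    "loops": {
        "fit_loop": {
            "epoch_loop.optimizer_loop.state_dict": {},
            "epoch_loop.optimizer_loop.optim_progress": {
                "optimizer_position": 0,  # dropped: no longer used
                "optimizer": {"step": {"total": {"ready": 4, "completed": 4}}},
            },
            "epoch_loop.manual_loop.state_dict": {},
            "epoch_loop.manual_loop.optim_step_progress": {"total": {"ready": 0, "completed": 0}},
        }
    }
}

checkpoint = _migrate_loop_structure_after_optimizer_loop_removal(checkpoint)
fit_loop = checkpoint["loops"]["fit_loop"]

# The same state now lives under the new attribute names, without `optimizer_position`.
assert "epoch_loop.automatic_optimization.optim_progress" in fit_loop
assert "epoch_loop.manual_optimization.optim_step_progress" in fit_loop
assert "optimizer_position" not in fit_loop["epoch_loop.automatic_optimization.optim_progress"]
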
2 changes: 1 addition & 1 deletion tests/tests_pytorch/core/test_lightning_optimizer.py
@@ -20,7 +20,7 @@
from lightning.pytorch import Trainer
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loops.optimization.optimizer_loop import Closure
from lightning.pytorch.loops.optimization.automatic import Closure
from lightning.pytorch.tuner.tuning import Tuner


@@ -66,7 +66,7 @@ def test_should_stop_early_stopping_conditions_not_met(
trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0)
trainer.num_training_batches = 10
trainer.should_stop = True
trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = global_step
trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = global_step
trainer.fit_loop.epoch_loop.batch_progress.current.ready = global_step
trainer.fit_loop.epoch_progress.current.completed = current_epoch - 1

2 changes: 1 addition & 1 deletion tests/tests_pytorch/loops/optimization/test_manual_loop.py
@@ -16,7 +16,7 @@

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loops.optimization.manual_loop import ManualResult
from lightning.pytorch.loops.optimization.manual import ManualResult
from lightning.pytorch.utilities.exceptions import MisconfigurationException


@@ -17,7 +17,7 @@

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loops.optimization.optimizer_loop import ClosureResult
from lightning.pytorch.loops.optimization.automatic import ClosureResult
from lightning.pytorch.utilities.exceptions import MisconfigurationException


8 changes: 4 additions & 4 deletions tests/tests_pytorch/loops/test_evaluation_loop_flow.py
@@ -65,13 +65,13 @@ def backward(self, loss):
# simulate training manually
trainer.state.stage = RunningStage.TRAINING
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs)
train_step_out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs)

assert isinstance(train_step_out["loss"], Tensor)
assert train_step_out["loss"].item() == 171

# make sure the optimizer closure returns the correct things
opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
opt_closure = trainer.fit_loop.epoch_loop.automatic_optimization._make_closure(kwargs, trainer.optimizers[0])
opt_closure_result = opt_closure()
assert opt_closure_result.item() == 171

@@ -124,13 +124,13 @@ def backward(self, loss):
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs)
train_step_out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs)

assert isinstance(train_step_out["loss"], Tensor)
assert train_step_out["loss"].item() == 171

# make sure the optimizer closure returns the correct things
opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
opt_closure = trainer.fit_loop.epoch_loop.automatic_optimization._make_closure(kwargs, trainer.optimizers[0])
opt_closure_result = opt_closure()
assert opt_closure_result.item() == 171

8 changes: 4 additions & 4 deletions tests/tests_pytorch/loops/test_loop_state_dict.py
@@ -49,13 +49,13 @@ def test_loops_state_dict_structure():
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.manual_loop.state_dict": {},
"epoch_loop.manual_loop.optim_step_progress": {
"epoch_loop.manual_optimization.state_dict": {},
"epoch_loop.manual_optimization.optim_step_progress": {
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.optimizer_loop.state_dict": {},
"epoch_loop.optimizer_loop.optim_progress": {
"epoch_loop.automatic_optimization.state_dict": {},
"epoch_loop.automatic_optimization.optim_progress": {
"optimizer": {
"step": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
"zero_grad": {
22 changes: 11 additions & 11 deletions tests/tests_pytorch/loops/test_loops.py
@@ -300,7 +300,7 @@ def training_step(self, batch, batch_idx):
assert os.path.exists(ckpt_path)
checkpoint = torch.load(ckpt_path)

optim_progress = trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress
optim_progress = trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress
sch_progress = trainer.fit_loop.epoch_loop.scheduler_progress

# `nbe_`: non-breaking epoch, as in, no exception will be raised. `be_`: breaking epoch
@@ -365,13 +365,13 @@ def training_step(self, batch, batch_idx):
"total": {"ready": nbe_sch_steps + be_sch_steps, "completed": nbe_sch_steps + be_sch_steps},
"current": {"ready": be_sch_steps, "completed": be_sch_steps},
},
"epoch_loop.manual_loop.state_dict": ANY,
"epoch_loop.manual_loop.optim_step_progress": {
"epoch_loop.manual_optimization.state_dict": ANY,
"epoch_loop.manual_optimization.optim_step_progress": {
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.optimizer_loop.state_dict": {},
"epoch_loop.optimizer_loop.optim_progress": {
"epoch_loop.automatic_optimization.state_dict": {},
"epoch_loop.automatic_optimization.optim_progress": {
"optimizer": {
"step": {
"total": {
@@ -423,7 +423,7 @@ def training_step(self, batch, batch_idx):
assert batch_progress.current.ready == be_batches_completed
assert batch_progress.current.completed == be_batches_completed

optim_progress = trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress
optim_progress = trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress
assert optim_progress.optimizer.step.current.ready == be_total_opt_steps
assert optim_progress.optimizer.step.current.completed == be_total_opt_steps
assert optim_progress.optimizer.zero_grad.current.ready == be_total_zero_grad
@@ -503,13 +503,13 @@ def train_dataloader(self):
"total": {"ready": n_sch_steps_total, "completed": n_sch_steps_total},
"current": {"ready": n_sch_steps_current, "completed": n_sch_steps_current},
},
"epoch_loop.manual_loop.state_dict": ANY,
"epoch_loop.manual_loop.optim_step_progress": {
"epoch_loop.manual_optimization.state_dict": ANY,
"epoch_loop.manual_optimization.optim_step_progress": {
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.optimizer_loop.state_dict": {},
"epoch_loop.optimizer_loop.optim_progress": {
"epoch_loop.automatic_optimization.state_dict": {},
"epoch_loop.automatic_optimization.optim_progress": {
"optimizer": {
"step": {
"total": {
@@ -568,7 +568,7 @@ def test_fit_loop_reset(tmpdir):
mid_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=2.ckpt"))
fit_loop = trainer.fit_loop
epoch_loop = fit_loop.epoch_loop
optimizer_loop = epoch_loop.optimizer_loop
optimizer_loop = epoch_loop.automatic_optimization
assert not fit_loop.restarting
assert not epoch_loop.restarting
assert not optimizer_loop.restarting
2 changes: 1 addition & 1 deletion tests/tests_pytorch/loops/test_training_loop.py
@@ -217,7 +217,7 @@ def test_should_stop_early_stopping_conditions_met(
trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0, max_epochs=100)
trainer.num_training_batches = 10
trainer.should_stop = True
trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = (
trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = (
current_epoch * trainer.num_training_batches
)
trainer.fit_loop.epoch_loop.batch_progress.current.ready = 10
12 changes: 6 additions & 6 deletions tests/tests_pytorch/loops/test_training_loop_flow_scalar.py
@@ -20,7 +20,7 @@
from lightning.pytorch import Trainer
from lightning.pytorch.core.module import LightningModule
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.loops.optimization.optimizer_loop import Closure
from lightning.pytorch.loops.optimization.automatic import Closure
from lightning.pytorch.trainer.states import RunningStage
from tests_pytorch.helpers.deterministic_model import DeterministicModel

@@ -147,13 +147,13 @@ def backward(self, loss):
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs)
train_step_out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs)

assert isinstance(train_step_out["loss"], Tensor)
assert train_step_out["loss"].item() == 171

# make sure the optimizer closure returns the correct things
opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
opt_closure = trainer.fit_loop.epoch_loop.automatic_optimization._make_closure(kwargs, trainer.optimizers[0])
opt_closure_result = opt_closure()
assert opt_closure_result.item() == 171

@@ -215,13 +215,13 @@ def backward(self, loss):
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs)
train_step_out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs)

assert isinstance(train_step_out["loss"], Tensor)
assert train_step_out["loss"].item() == 171

# make sure the optimizer closure returns the correct things
opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
opt_closure = trainer.fit_loop.epoch_loop.automatic_optimization._make_closure(kwargs, trainer.optimizers[0])
opt_closure_result = opt_closure()
assert opt_closure_result.item() == 171

@@ -298,7 +298,7 @@ def training_step(self, batch, batch_idx):
# manually check a few batches
for batch_idx, batch in enumerate(model.train_dataloader()):
kwargs = {"batch": batch, "batch_idx": batch_idx}
out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs)
out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs)
if not batch_idx % 2:
assert out == {}

2 changes: 1 addition & 1 deletion tests/tests_pytorch/trainer/test_trainer.py
@@ -339,7 +339,7 @@ def mock_save_function(filepath, *args):
# emulate callback's calls during the training
for i, loss in enumerate(losses, 1):
# sets `trainer.global_step`
trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = i
trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = i
trainer.callback_metrics.update({"checkpoint_on": torch.tensor(loss)})
checkpoint_callback.on_validation_end(trainer, trainer.lightning_module)
trainer.fit_loop.epoch_progress.current.completed = i # sets `trainer.current_epoch`