Rename optimization loops #16598

Merged (16 commits) on Feb 2, 2023

Changes from 11 commits
2 changes: 1 addition & 1 deletion src/lightning/pytorch/loops/__init__.py
@@ -15,4 +15,4 @@
from lightning.pytorch.loops.dataloader import _DataLoaderLoop, _EvaluationLoop, _PredictionLoop # noqa: F401
from lightning.pytorch.loops.epoch import _EvaluationEpochLoop, _PredictionEpochLoop, _TrainingEpochLoop # noqa: F401
from lightning.pytorch.loops.fit_loop import _FitLoop # noqa: F401
from lightning.pytorch.loops.optimization import _ManualOptimization, _OptimizerLoop # noqa: F401
from lightning.pytorch.loops.optimization import _AutomaticOptimization, _ManualOptimization # noqa: F401
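For downstream code that reached into these private modules, the rename is a one-line import change. A minimal sketch; the old paths are taken from the lines deleted in this PR, and since these are underscore-prefixed internals this mainly affects tests and code that hooks into the loops:

    # old private paths (removed in this PR)
    # from lightning.pytorch.loops.optimization.optimizer_loop import _OptimizerLoop
    # from lightning.pytorch.loops.optimization.manual_loop import _ManualOptimization

    # new module and class names introduced here
    from lightning.pytorch.loops.optimization.automatic import _AutomaticOptimization
    from lightning.pytorch.loops.optimization.manual import _ManualOptimization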
8 changes: 4 additions & 4 deletions src/lightning/pytorch/loops/epoch/training_epoch_loop.py
@@ -18,9 +18,9 @@
import torch

from lightning.pytorch import loops # import as loops to avoid circular imports
from lightning.pytorch.loops.optimization import _ManualOptimization, _OptimizerLoop
from lightning.pytorch.loops.optimization.manual_loop import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE
from lightning.pytorch.loops.optimization.optimizer_loop import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE
from lightning.pytorch.loops.optimization import _AutomaticOptimization, _ManualOptimization
from lightning.pytorch.loops.optimization.automatic import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE
from lightning.pytorch.loops.optimization.manual import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE
from lightning.pytorch.loops.progress import BatchProgress, SchedulerProgress
from lightning.pytorch.loops.utilities import _is_max_limit_reached
from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
@@ -68,7 +68,7 @@ def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None
self.batch_progress = BatchProgress()
self.scheduler_progress = SchedulerProgress()

self.optimizer_loop = _OptimizerLoop()
self.optimizer_loop = _AutomaticOptimization()
self.manual_loop = _ManualOptimization()

self.val_loop = loops._EvaluationLoop(verbose=False)
4 changes: 2 additions & 2 deletions src/lightning/pytorch/loops/fit_loop.py
@@ -131,12 +131,12 @@ def prefetch_batches(self) -> int:
@property
def _skip_backward(self) -> bool:
"""Determines whether the loop will skip backward during automatic optimization."""
return self.epoch_loop.optimizer_loop._skip_backward
return self.epoch_loop.automatic_optimization._skip_backward

@_skip_backward.setter
def _skip_backward(self, value: bool) -> None:
"""Determines whether the loop will skip backward during automatic optimization."""
self.epoch_loop.optimizer_loop._skip_backward = value
self.epoch_loop.automatic_optimization._skip_backward = value

@property
def _results(self) -> _ResultCollection:
4 changes: 2 additions & 2 deletions src/lightning/pytorch/loops/optimization/__init__.py
@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from lightning.pytorch.loops.optimization.manual_loop import _ManualOptimization # noqa: F401
from lightning.pytorch.loops.optimization.optimizer_loop import _OptimizerLoop # noqa: F401
from lightning.pytorch.loops.optimization.automatic import _AutomaticOptimization # noqa: F401
from lightning.pytorch.loops.optimization.manual import _ManualOptimization # noqa: F401
src/lightning/pytorch/loops/optimization/{optimizer_loop.py → automatic.py}
@@ -144,7 +144,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
_OUTPUTS_TYPE = Dict[str, Any]


class _OptimizerLoop(_Loop):
class _AutomaticOptimization(_Loop):
"""Performs automatic optimization (forward, zero grad, backward, optimizer step)"""

output_result_cls = ClosureResult
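The renamed class keeps the behaviour described in its docstring. As orientation only, the pattern it encapsulates looks roughly like this in plain PyTorch, written as a closure so that optimizers which re-evaluate the loss (such as LBFGS) also work. This is an illustration of the pattern, not Lightning's implementation:

    import torch
    import torch.nn.functional as F

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    batch, target = torch.randn(8, 4), torch.randn(8, 1)

    def closure():
        optimizer.zero_grad()                    # zero grad
        loss = F.mse_loss(model(batch), target)  # forward
        loss.backward()                          # backward
        return loss

    optimizer.step(closure)                      # optimizer step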
src/lightning/pytorch/loops/optimization/{manual_loop.py → manual.py}
@@ -28,7 +28,7 @@

@dataclass
class ManualResult(OutputResult):
"""A container to hold the result returned by the ``ManualLoop``.
"""A container to hold the result returned by ``_ManualOptimization``.

It is created from the output of :meth:`~lightning.pytorch.core.module.LightningModule.training_step`.

4 changes: 2 additions & 2 deletions src/lightning/pytorch/tuner/batch_size_scaling.py
@@ -359,8 +359,8 @@ def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:

def _reset_progress(trainer: "pl.Trainer") -> None:
if trainer.lightning_module.automatic_optimization:
trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.reset()
trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.reset()
else:
trainer.fit_loop.epoch_loop.manual_loop.optim_step_progress.reset()
trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.reset()

trainer.fit_loop.epoch_progress.reset()
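Any code outside the Trainer that tracked progress through the old attribute names needs the same rename. A hypothetical helper, sketched under the assumption of a constructed Trainer and the attribute paths used in the diffs above (the manual branch mirrors the optim_step_progress structure shown in the loop state-dict test further down):

    def global_optimizer_steps(trainer) -> int:
        # read the completed optimizer-step count via the renamed sub-loop attributes
        epoch_loop = trainer.fit_loop.epoch_loop
        if trainer.lightning_module.automatic_optimization:
            return epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed
        return epoch_loop.manual_optimization.optim_step_progress.total.completed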
15 changes: 12 additions & 3 deletions src/lightning/pytorch/utilities/migration/migration.py
@@ -268,13 +268,22 @@ def _migrate_loop_structure_after_optimizer_loop_removal(checkpoint: _CHECKPOINT
resuming the loop.

Version: 2.0.0
Commit: TBD
PR: TBD
Commit: 6a56586
PR: #16539, #16598
"""
if "loops" not in checkpoint:
return checkpoint

# TODO: Complete this migration function when optimizer loop gets flattened out and keys need to be remapped
fit_loop = checkpoint["loops"]["fit_loop"]
# optimizer_position is no longer used
fit_loop["epoch_loop.optimizer_loop.optim_progress"].pop("optimizer_position", None)
# the subloop attribute names have changed
fit_loop["epoch_loop.automatic_optimization.state_dict"] = fit_loop.pop("epoch_loop.optimizer_loop.state_dict")
fit_loop["epoch_loop.automatic_optimization.optim_progress"] = fit_loop.pop(
"epoch_loop.optimizer_loop.optim_progress"
)
fit_loop["epoch_loop.manual_optimization.state_dict"] = fit_loop.pop("epoch_loop.manual_loop.state_dict")
fit_loop["epoch_loop.manual_optimization.optim_step_progress"] = fit_loop.pop(
"epoch_loop.manual_loop.optim_step_progress"
)
return checkpoint
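A small self-contained check of what this migration does, assuming the function is importable from the module shown above; the dictionary below is a minimal stand-in for the loops section of a pre-2.0 checkpoint, not a real checkpoint:

    from lightning.pytorch.utilities.migration.migration import (
        _migrate_loop_structure_after_optimizer_loop_removal,
    )

    checkpoint = {
        "loops": {
            "fit_loop": {
                "epoch_loop.optimizer_loop.state_dict": {},
                "epoch_loop.optimizer_loop.optim_progress": {"optimizer": {}, "optimizer_position": 33},
                "epoch_loop.manual_loop.state_dict": {},
                "epoch_loop.manual_loop.optim_step_progress": {},
            }
        }
    }
    checkpoint = _migrate_loop_structure_after_optimizer_loop_removal(checkpoint)
    fit_loop = checkpoint["loops"]["fit_loop"]

    assert "epoch_loop.optimizer_loop.optim_progress" not in fit_loop  # old keys are gone
    assert "optimizer_position" not in fit_loop["epoch_loop.automatic_optimization.optim_progress"]
    assert "epoch_loop.manual_optimization.optim_step_progress" in fit_loop  # renamed keys present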
2 changes: 1 addition & 1 deletion tests/tests_pytorch/core/test_lightning_optimizer.py
@@ -20,8 +20,8 @@
from lightning.pytorch import Trainer
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loops.optimization.optimizer_loop import Closure
from lightning.pytorch.tuner.tuning import Tuner
from pytorch_lightning.loops.optimization.automatic import Closure


@pytest.mark.parametrize("auto", (True, False))
@@ -66,7 +66,7 @@ def test_should_stop_early_stopping_conditions_not_met(
trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0)
trainer.num_training_batches = 10
trainer.should_stop = True
trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = global_step
trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = global_step
trainer.fit_loop.epoch_loop.batch_progress.current.ready = global_step
trainer.fit_loop.epoch_progress.current.completed = current_epoch - 1

@@ -16,8 +16,8 @@

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loops.optimization.manual_loop import ManualResult
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from pytorch_lightning.loops.optimization.manual import ManualResult


def test_manual_result():
@@ -17,8 +17,8 @@

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loops.optimization.optimizer_loop import ClosureResult
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from pytorch_lightning.loops.optimization.automatic import ClosureResult


def test_closure_result_deepcopy():
8 changes: 4 additions & 4 deletions tests/tests_pytorch/loops/test_evaluation_loop_flow.py
@@ -65,13 +65,13 @@ def backward(self, loss):
# simulate training manually
trainer.state.stage = RunningStage.TRAINING
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs)
train_step_out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs)

assert isinstance(train_step_out["loss"], Tensor)
assert train_step_out["loss"].item() == 171

# make sure the optimizer closure returns the correct things
opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
opt_closure = trainer.fit_loop.epoch_loop.automatic_optimization._make_closure(kwargs, trainer.optimizers[0])
opt_closure_result = opt_closure()
assert opt_closure_result.item() == 171

@@ -124,13 +124,13 @@ def backward(self, loss):
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs)
train_step_out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs)

assert isinstance(train_step_out["loss"], Tensor)
assert train_step_out["loss"].item() == 171

# make sure the optimizer closure returns the correct things
opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
opt_closure = trainer.fit_loop.epoch_loop.automatic_optimization._make_closure(kwargs, trainer.optimizers[0])
opt_closure_result = opt_closure()
assert opt_closure_result.item() == 171

8 changes: 4 additions & 4 deletions tests/tests_pytorch/loops/test_loop_state_dict.py
@@ -49,13 +49,13 @@ def test_loops_state_dict_structure():
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.manual_loop.state_dict": {},
"epoch_loop.manual_loop.optim_step_progress": {
"epoch_loop.manual_optimization.state_dict": {},
"epoch_loop.manual_optimization.optim_step_progress": {
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.optimizer_loop.state_dict": {},
"epoch_loop.optimizer_loop.optim_progress": {
"epoch_loop.automatic_optimization.state_dict": {},
"epoch_loop.automatic_optimization.optim_progress": {
"optimizer": {
"step": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
"zero_grad": {
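The dotted key names in this expected structure come from how nested loop state is flattened: each sub-loop's entries are prefixed with the attribute name they hang off of, which is why renaming the attributes renames the checkpoint keys and why the migration above is needed. A rough illustration of the naming scheme, not Lightning's actual code:

    def flatten(prefix: str, name: str, state: dict) -> dict:
        # build "parent.attribute.key" entries, matching the layout of the expected dict above
        return {f"{prefix}{name}.{key}": value for key, value in state.items()}

    print(flatten("epoch_loop.", "automatic_optimization", {"state_dict": {}, "optim_progress": {}}))
    # -> {'epoch_loop.automatic_optimization.state_dict': {}, 'epoch_loop.automatic_optimization.optim_progress': {}}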
22 changes: 11 additions & 11 deletions tests/tests_pytorch/loops/test_loops.py
@@ -300,7 +300,7 @@ def training_step(self, batch, batch_idx):
assert os.path.exists(ckpt_path)
checkpoint = torch.load(ckpt_path)

optim_progress = trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress
optim_progress = trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress
sch_progress = trainer.fit_loop.epoch_loop.scheduler_progress

# `nbe_`: non-breaking epoch, as in, no exception will be raised. `be_`: breaking epoch
@@ -365,13 +365,13 @@ def training_step(self, batch, batch_idx):
"total": {"ready": nbe_sch_steps + be_sch_steps, "completed": nbe_sch_steps + be_sch_steps},
"current": {"ready": be_sch_steps, "completed": be_sch_steps},
},
"epoch_loop.manual_loop.state_dict": ANY,
"epoch_loop.manual_loop.optim_step_progress": {
"epoch_loop.manual_optimization.state_dict": ANY,
"epoch_loop.manual_optimization.optim_step_progress": {
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.optimizer_loop.state_dict": {},
"epoch_loop.optimizer_loop.optim_progress": {
"epoch_loop.automatic_optimization.state_dict": {},
"epoch_loop.automatic_optimization.optim_progress": {
"optimizer": {
"step": {
"total": {
@@ -423,7 +423,7 @@ def training_step(self, batch, batch_idx):
assert batch_progress.current.ready == be_batches_completed
assert batch_progress.current.completed == be_batches_completed

optim_progress = trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress
optim_progress = trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress
assert optim_progress.optimizer.step.current.ready == be_total_opt_steps
assert optim_progress.optimizer.step.current.completed == be_total_opt_steps
assert optim_progress.optimizer.zero_grad.current.ready == be_total_zero_grad
@@ -503,13 +503,13 @@ def train_dataloader(self):
"total": {"ready": n_sch_steps_total, "completed": n_sch_steps_total},
"current": {"ready": n_sch_steps_current, "completed": n_sch_steps_current},
},
"epoch_loop.manual_loop.state_dict": ANY,
"epoch_loop.manual_loop.optim_step_progress": {
"epoch_loop.manual_optimization.state_dict": ANY,
"epoch_loop.manual_optimization.optim_step_progress": {
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.optimizer_loop.state_dict": {},
"epoch_loop.optimizer_loop.optim_progress": {
"epoch_loop.automatic_optimization.state_dict": {},
"epoch_loop.automatic_optimization.optim_progress": {
"optimizer": {
"step": {
"total": {
@@ -568,7 +568,7 @@ def test_fit_loop_reset(tmpdir):
mid_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=2.ckpt"))
fit_loop = trainer.fit_loop
epoch_loop = fit_loop.epoch_loop
optimizer_loop = epoch_loop.optimizer_loop
optimizer_loop = epoch_loop.automatic_optimization
assert not fit_loop.restarting
assert not epoch_loop.restarting
assert not optimizer_loop.restarting
2 changes: 1 addition & 1 deletion tests/tests_pytorch/loops/test_training_loop.py
@@ -217,7 +217,7 @@ def test_should_stop_early_stopping_conditions_met(
trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0, max_epochs=100)
trainer.num_training_batches = 10
trainer.should_stop = True
trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = (
trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = (
current_epoch * trainer.num_training_batches
)
trainer.fit_loop.epoch_loop.batch_progress.current.ready = 10
12 changes: 6 additions & 6 deletions tests/tests_pytorch/loops/test_training_loop_flow_scalar.py
@@ -20,8 +20,8 @@
from lightning.pytorch import Trainer
from lightning.pytorch.core.module import LightningModule
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.loops.optimization.optimizer_loop import Closure
from lightning.pytorch.trainer.states import RunningStage
from pytorch_lightning.loops.optimization.automatic import Closure
from tests_pytorch.helpers.deterministic_model import DeterministicModel


@@ -147,13 +147,13 @@ def backward(self, loss):
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs)
train_step_out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs)

assert isinstance(train_step_out["loss"], Tensor)
assert train_step_out["loss"].item() == 171

# make sure the optimizer closure returns the correct things
opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
opt_closure = trainer.fit_loop.epoch_loop.automatic_optimization._make_closure(kwargs, trainer.optimizers[0])
opt_closure_result = opt_closure()
assert opt_closure_result.item() == 171

@@ -215,13 +215,13 @@ def backward(self, loss):
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs)
train_step_out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs)

assert isinstance(train_step_out["loss"], Tensor)
assert train_step_out["loss"].item() == 171

# make sure the optimizer closure returns the correct things
opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
opt_closure = trainer.fit_loop.epoch_loop.automatic_optimization._make_closure(kwargs, trainer.optimizers[0])
opt_closure_result = opt_closure()
assert opt_closure_result.item() == 171

@@ -298,7 +298,7 @@ def training_step(self, batch, batch_idx):
# manually check a few batches
for batch_idx, batch in enumerate(model.train_dataloader()):
kwargs = {"batch": batch, "batch_idx": batch_idx}
out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs)
out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs)
if not batch_idx % 2:
assert out == {}

2 changes: 1 addition & 1 deletion tests/tests_pytorch/trainer/test_trainer.py
@@ -339,7 +339,7 @@ def mock_save_function(filepath, *args):
# emulate callback's calls during the training
for i, loss in enumerate(losses, 1):
# sets `trainer.global_step`
trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = i
trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = i
trainer.callback_metrics.update({"checkpoint_on": torch.tensor(loss)})
checkpoint_callback.on_validation_end(trainer, trainer.lightning_module)
trainer.fit_loop.epoch_progress.current.completed = i # sets `trainer.current_epoch`
19 changes: 12 additions & 7 deletions tests/tests_pytorch/utilities/migration/test_migration.py
@@ -186,10 +186,10 @@ def test_migrate_loop_structure_after_tbptt_removal():
assert updated_checkpoint["loops"] == {
"fit_loop": {
"epoch_loop.state_dict": {"any": "state", "old_batch_loop_state_dict": old_batch_loop_state},
"epoch_loop.optimizer_loop.state_dict": state_automatic,
"epoch_loop.optimizer_loop.optim_progress": optim_progress_automatic,
"epoch_loop.manual_loop.state_dict": state_manual,
"epoch_loop.manual_loop.optim_step_progress": optim_progress_manual,
"epoch_loop.automatic_optimization.state_dict": state_automatic,
"epoch_loop.automatic_optimization.optim_progress": optim_progress_automatic,
"epoch_loop.manual_optimization.state_dict": state_manual,
"epoch_loop.manual_optimization.optim_step_progress": optim_progress_manual,
}
}

@@ -198,17 +198,21 @@ def test_migrate_loop_structure_after_optimizer_loop_removal():
"""Test the loop state migration after multiple optimizer support in automatic optimization was removed in
2.0.0."""
state_automatic = MagicMock()
state_manual = MagicMock()
optim_progress_automatic = {
"optimizer": MagicMock(),
"optimizer_position": 33,
}
optim_progress_manual = MagicMock()
old_checkpoint = {
"loops": {
"fit_loop": {
"epoch_loop.state_dict": {"any": "state"},
"epoch_loop.batch_loop.state_dict": MagicMock(),
"epoch_loop.batch_loop.optimizer_loop.state_dict": state_automatic,
"epoch_loop.batch_loop.optimizer_loop.optim_progress": optim_progress_automatic,
"epoch_loop.batch_loop.manual_loop.state_dict": state_manual,
"epoch_loop.batch_loop.manual_loop.optim_step_progress": optim_progress_manual,
}
}
}
@@ -217,8 +221,9 @@ def test_migrate_loop_structure_after_optimizer_loop_removal():
assert updated_checkpoint["loops"] == {
"fit_loop": {
"epoch_loop.state_dict": ANY,
"epoch_loop.optimizer_loop.state_dict": state_automatic,
# optimizer_position gets dropped:
"epoch_loop.optimizer_loop.optim_progress": {"optimizer": ANY},
"epoch_loop.automatic_optimization.state_dict": state_automatic,
"epoch_loop.automatic_optimization.optim_progress": {"optimizer": ANY}, # optimizer_position gets dropped
"epoch_loop.manual_optimization.state_dict": state_manual,
"epoch_loop.manual_optimization.optim_step_progress": optim_progress_manual,
}
}