Rename optimization loops #16598

Merged · 16 commits · Feb 2, 2023
2 changes: 1 addition & 1 deletion src/lightning/pytorch/loops/__init__.py
@@ -15,4 +15,4 @@
 from lightning.pytorch.loops.dataloader import _DataLoaderLoop, _EvaluationLoop, _PredictionLoop # noqa: F401
 from lightning.pytorch.loops.epoch import _EvaluationEpochLoop, _PredictionEpochLoop, _TrainingEpochLoop # noqa: F401
 from lightning.pytorch.loops.fit_loop import _FitLoop # noqa: F401
-from lightning.pytorch.loops.optimization import _ManualOptimization, _OptimizerLoop # noqa: F401
+from lightning.pytorch.loops.optimization import _AutomaticOptimization, _ManualOptimization # noqa: F401
22 changes: 11 additions & 11 deletions src/lightning/pytorch/loops/epoch/training_epoch_loop.py
@@ -18,9 +18,9 @@
 import torch

 from lightning.pytorch import loops # import as loops to avoid circular imports
-from lightning.pytorch.loops.optimization import _ManualOptimization, _OptimizerLoop
-from lightning.pytorch.loops.optimization.manual_loop import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE
-from lightning.pytorch.loops.optimization.optimizer_loop import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE
+from lightning.pytorch.loops.optimization import _AutomaticOptimization, _ManualOptimization
+from lightning.pytorch.loops.optimization.automatic import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE
+from lightning.pytorch.loops.optimization.manual import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE
 from lightning.pytorch.loops.progress import BatchProgress, SchedulerProgress
 from lightning.pytorch.loops.utilities import _is_max_limit_reached
 from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
@@ -68,8 +68,8 @@ def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None
 self.batch_progress = BatchProgress()
 self.scheduler_progress = SchedulerProgress()

-self.optimizer_loop = _OptimizerLoop()
-self.manual_loop = _ManualOptimization()
+self.automatic_optimization = _AutomaticOptimization()
+self.manual_optimization = _ManualOptimization()

 self.val_loop = loops._EvaluationLoop(verbose=False)

@@ -96,8 +96,8 @@ def batch_idx(self) -> int:
 def global_step(self) -> int:
 lightning_module = self.trainer.lightning_module
 if lightning_module is None or lightning_module.automatic_optimization:
-return self.optimizer_loop.optim_progress.optimizer_steps
-return self.manual_loop.optim_step_progress.total.completed
+return self.automatic_optimization.optim_progress.optimizer_steps
+return self.manual_optimization.optim_step_progress.total.completed

 @property
 def _is_training_done(self) -> bool:
@@ -146,7 +146,7 @@ def reset(self) -> None:
 if self.restarting:
 self.batch_progress.reset_on_restart()
 self.scheduler_progress.reset_on_restart()
-self.optimizer_loop.optim_progress.reset_on_restart()
+self.automatic_optimization.optim_progress.reset_on_restart()

 trainer = self.trainer
 if trainer.num_training_batches != float("inf"):
@@ -159,7 +159,7 @@ def reset(self) -> None:
 else:
 self.batch_progress.reset_on_run()
 self.scheduler_progress.reset_on_run()
-self.optimizer_loop.optim_progress.reset_on_run()
+self.automatic_optimization.optim_progress.reset_on_run()
 # when the epoch starts, the total val batch progress should be reset as it's supposed to count the batches
 # seen per epoch, this is useful for tracking when validation is run multiple times per epoch
 self.val_loop.epoch_loop.batch_progress.total.reset()
@@ -222,9 +222,9 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:
 with self.trainer.profiler.profile("run_training_batch"):
 if self.trainer.lightning_module.automatic_optimization:
 # in automatic optimization, there can only be one optimizer
-batch_output = self.optimizer_loop.run(self.trainer.optimizers[0], kwargs)
+batch_output = self.automatic_optimization.run(self.trainer.optimizers[0], kwargs)
 else:
-batch_output = self.manual_loop.run(kwargs)
+batch_output = self.manual_optimization.run(kwargs)

 self.batch_progress.increment_processed()
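For code that inspects the loop internals at runtime, the rename means `epoch_loop.optimizer_loop` becomes `epoch_loop.automatic_optimization` and `epoch_loop.manual_loop` becomes `epoch_loop.manual_optimization`. A minimal sketch of accessing the renamed sub-loops after a short fit (assuming a default `Trainer`/`BoringModel` setup; the printed values depend on the run):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

# Train for a couple of batches just to populate the progress counters.
trainer = Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=0, logger=False, enable_checkpointing=False)
trainer.fit(BoringModel())

epoch_loop = trainer.fit_loop.epoch_loop
# Previously `epoch_loop.optimizer_loop` / `epoch_loop.manual_loop`:
print(epoch_loop.automatic_optimization.optim_progress.optimizer_steps)    # optimizer steps taken
print(epoch_loop.manual_optimization.optim_step_progress.total.completed)  # 0 with automatic optimization
```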
4 changes: 2 additions & 2 deletions src/lightning/pytorch/loops/fit_loop.py
@@ -131,12 +131,12 @@ def prefetch_batches(self) -> int:
 @property
 def _skip_backward(self) -> bool:
 """Determines whether the loop will skip backward during automatic optimization."""
-return self.epoch_loop.optimizer_loop._skip_backward
+return self.epoch_loop.automatic_optimization._skip_backward

 @_skip_backward.setter
 def _skip_backward(self, value: bool) -> None:
 """Determines whether the loop will skip backward during automatic optimization."""
-self.epoch_loop.optimizer_loop._skip_backward = value
+self.epoch_loop.automatic_optimization._skip_backward = value

 @property
 def _results(self) -> _ResultCollection:
4 changes: 2 additions & 2 deletions src/lightning/pytorch/loops/optimization/__init__.py
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from lightning.pytorch.loops.optimization.manual_loop import _ManualOptimization # noqa: F401
-from lightning.pytorch.loops.optimization.optimizer_loop import _OptimizerLoop # noqa: F401
+from lightning.pytorch.loops.optimization.automatic import _AutomaticOptimization # noqa: F401
+from lightning.pytorch.loops.optimization.manual import _ManualOptimization # noqa: F401
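For downstream code importing from the old module paths, the corresponding updates look roughly like this (old paths taken from the left-hand side of this diff; combining the names into single import statements is only for brevity):

```python
# Before this PR:
# from lightning.pytorch.loops.optimization.optimizer_loop import _OptimizerLoop, Closure, ClosureResult
# from lightning.pytorch.loops.optimization.manual_loop import _ManualOptimization, ManualResult

# After this PR:
from lightning.pytorch.loops.optimization.automatic import _AutomaticOptimization, Closure, ClosureResult
from lightning.pytorch.loops.optimization.manual import _ManualOptimization, ManualResult
```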
src/lightning/pytorch/loops/optimization/optimizer_loop.py → src/lightning/pytorch/loops/optimization/automatic.py
@@ -144,7 +144,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
 _OUTPUTS_TYPE = Dict[str, Any]


-class _OptimizerLoop(_Loop):
+class _AutomaticOptimization(_Loop):
 """Performs automatic optimization (forward, zero grad, backward, optimizer step)"""

 output_result_cls = ClosureResult
src/lightning/pytorch/loops/optimization/manual_loop.py → src/lightning/pytorch/loops/optimization/manual.py
@@ -28,7 +28,7 @@

 @dataclass
 class ManualResult(OutputResult):
-"""A container to hold the result returned by the ``ManualLoop``.
+"""A container to hold the result returned by ``_ManualOptimization``.

 It is created from the output of :meth:`~lightning.pytorch.core.module.LightningModule.training_step`.

4 changes: 2 additions & 2 deletions src/lightning/pytorch/tuner/batch_size_scaling.py
@@ -359,8 +359,8 @@ def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:

 def _reset_progress(trainer: "pl.Trainer") -> None:
 if trainer.lightning_module.automatic_optimization:
-trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.reset()
+trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.reset()
 else:
-trainer.fit_loop.epoch_loop.manual_loop.optim_step_progress.reset()
+trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.reset()

 trainer.fit_loop.epoch_progress.reset()
21 changes: 17 additions & 4 deletions src/lightning/pytorch/utilities/migration/migration.py
@@ -268,13 +268,26 @@ def _migrate_loop_structure_after_optimizer_loop_removal(checkpoint: _CHECKPOINT
 resuming the loop.

 Version: 2.0.0
-Commit: TBD
-PR: TBD
+Commit: 6a56586
+PR: #16539, #16598
 """
 if "loops" not in checkpoint:
 return checkpoint

-# TODO: Complete this migration function when optimizer loop gets flattened out and keys need to be remapped
 fit_loop = checkpoint["loops"]["fit_loop"]
-fit_loop["epoch_loop.optimizer_loop.optim_progress"].pop("optimizer_position", None)
+# optimizer_position is no longer used
+if "epoch_loop.optimizer_loop.optim_progress" in fit_loop:
+fit_loop["epoch_loop.optimizer_loop.optim_progress"].pop("optimizer_position", None)
+
+# the subloop attribute names have changed
+if "epoch_loop.optimizer_loop.state_dict" in fit_loop:
+fit_loop["epoch_loop.automatic_optimization.state_dict"] = fit_loop.pop("epoch_loop.optimizer_loop.state_dict")
+fit_loop["epoch_loop.automatic_optimization.optim_progress"] = fit_loop.pop(
+"epoch_loop.optimizer_loop.optim_progress"
+)
+if "epoch_loop.manual_loop.state_dict" in fit_loop:
+fit_loop["epoch_loop.manual_optimization.state_dict"] = fit_loop.pop("epoch_loop.manual_loop.state_dict")
+fit_loop["epoch_loop.manual_optimization.optim_step_progress"] = fit_loop.pop(
+"epoch_loop.manual_loop.optim_step_progress"
+)
 return checkpoint
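A minimal sketch of what the added migration does to an old checkpoint's loop state. The dictionary below is hypothetical (only the key names matter, the values are placeholders); the function name and module path are taken from this file:

```python
from lightning.pytorch.utilities.migration.migration import (
    _migrate_loop_structure_after_optimizer_loop_removal,
)

# Hypothetical pre-rename checkpoint fragment.
checkpoint = {
    "loops": {
        "fit_loop": {
            "epoch_loop.optimizer_loop.state_dict": {},
            "epoch_loop.optimizer_loop.optim_progress": {"optimizer_position": 0},
            "epoch_loop.manual_loop.state_dict": {},
            "epoch_loop.manual_loop.optim_step_progress": {},
        }
    }
}

migrated = _migrate_loop_structure_after_optimizer_loop_removal(checkpoint)
fit_loop = migrated["loops"]["fit_loop"]
assert "epoch_loop.automatic_optimization.state_dict" in fit_loop
assert "epoch_loop.manual_optimization.optim_step_progress" in fit_loop
# the unused optimizer_position entry is dropped along the way
assert "optimizer_position" not in fit_loop["epoch_loop.automatic_optimization.optim_progress"]
```

In normal use this is applied by the checkpoint migration machinery when an old checkpoint is loaded, not called directly.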
2 changes: 1 addition & 1 deletion tests/tests_pytorch/core/test_lightning_optimizer.py
@@ -20,7 +20,7 @@
 from lightning.pytorch import Trainer
 from lightning.pytorch.core.optimizer import LightningOptimizer
 from lightning.pytorch.demos.boring_classes import BoringModel
-from lightning.pytorch.loops.optimization.optimizer_loop import Closure
+from lightning.pytorch.loops.optimization.automatic import Closure
 from lightning.pytorch.tuner.tuning import Tuner


@@ -66,7 +66,7 @@ def test_should_stop_early_stopping_conditions_not_met(
 trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0)
 trainer.num_training_batches = 10
 trainer.should_stop = True
-trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = global_step
+trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = global_step
 trainer.fit_loop.epoch_loop.batch_progress.current.ready = global_step
 trainer.fit_loop.epoch_progress.current.completed = current_epoch - 1

@@ -16,7 +16,7 @@

 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
-from lightning.pytorch.loops.optimization.manual_loop import ManualResult
+from lightning.pytorch.loops.optimization.manual import ManualResult
 from lightning.pytorch.utilities.exceptions import MisconfigurationException


@@ -17,7 +17,7 @@

 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
-from lightning.pytorch.loops.optimization.optimizer_loop import ClosureResult
+from lightning.pytorch.loops.optimization.automatic import ClosureResult
 from lightning.pytorch.utilities.exceptions import MisconfigurationException


8 changes: 4 additions & 4 deletions tests/tests_pytorch/loops/test_evaluation_loop_flow.py
@@ -65,13 +65,13 @@ def backward(self, loss):
 # simulate training manually
 trainer.state.stage = RunningStage.TRAINING
 kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
-train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs)
+train_step_out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs)

 assert isinstance(train_step_out["loss"], Tensor)
 assert train_step_out["loss"].item() == 171

 # make sure the optimizer closure returns the correct things
-opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
+opt_closure = trainer.fit_loop.epoch_loop.automatic_optimization._make_closure(kwargs, trainer.optimizers[0])
 opt_closure_result = opt_closure()
 assert opt_closure_result.item() == 171

@@ -124,13 +124,13 @@ def backward(self, loss):
 trainer.state.stage = RunningStage.TRAINING
 # make sure training outputs what is expected
 kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
-train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs)
+train_step_out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs)

 assert isinstance(train_step_out["loss"], Tensor)
 assert train_step_out["loss"].item() == 171

 # make sure the optimizer closure returns the correct things
-opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
+opt_closure = trainer.fit_loop.epoch_loop.automatic_optimization._make_closure(kwargs, trainer.optimizers[0])
 opt_closure_result = opt_closure()
 assert opt_closure_result.item() == 171

8 changes: 4 additions & 4 deletions tests/tests_pytorch/loops/test_loop_state_dict.py
@@ -49,13 +49,13 @@ def test_loops_state_dict_structure():
 "total": {"ready": 0, "completed": 0},
 "current": {"ready": 0, "completed": 0},
 },
-"epoch_loop.manual_loop.state_dict": {},
-"epoch_loop.manual_loop.optim_step_progress": {
+"epoch_loop.manual_optimization.state_dict": {},
+"epoch_loop.manual_optimization.optim_step_progress": {
 "total": {"ready": 0, "completed": 0},
 "current": {"ready": 0, "completed": 0},
 },
-"epoch_loop.optimizer_loop.state_dict": {},
-"epoch_loop.optimizer_loop.optim_progress": {
+"epoch_loop.automatic_optimization.state_dict": {},
+"epoch_loop.automatic_optimization.optim_progress": {
 "optimizer": {
 "step": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
 "zero_grad": {
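To see the renamed keys end to end, a quick sketch assuming a default `Trainer` and that `fit_loop.state_dict()` behaves as this test expects (the commented output lists the keys from the structure above, assuming no other keys contain the substring):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=0, logger=False, enable_checkpointing=False)
trainer.fit(BoringModel())

state = trainer.fit_loop.state_dict()
print(sorted(k for k in state if "optimization" in k))
# ['epoch_loop.automatic_optimization.optim_progress',
#  'epoch_loop.automatic_optimization.state_dict',
#  'epoch_loop.manual_optimization.optim_step_progress',
#  'epoch_loop.manual_optimization.state_dict']
```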
22 changes: 11 additions & 11 deletions tests/tests_pytorch/loops/test_loops.py
@@ -300,7 +300,7 @@ def training_step(self, batch, batch_idx):
 assert os.path.exists(ckpt_path)
 checkpoint = torch.load(ckpt_path)

-optim_progress = trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress
+optim_progress = trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress
 sch_progress = trainer.fit_loop.epoch_loop.scheduler_progress

 # `nbe_`: non-breaking epoch, as in, no exception will be raised. `be_`: breaking epoch
@@ -365,13 +365,13 @@ def training_step(self, batch, batch_idx):
 "total": {"ready": nbe_sch_steps + be_sch_steps, "completed": nbe_sch_steps + be_sch_steps},
 "current": {"ready": be_sch_steps, "completed": be_sch_steps},
 },
-"epoch_loop.manual_loop.state_dict": ANY,
-"epoch_loop.manual_loop.optim_step_progress": {
+"epoch_loop.manual_optimization.state_dict": ANY,
+"epoch_loop.manual_optimization.optim_step_progress": {
 "total": {"ready": 0, "completed": 0},
 "current": {"ready": 0, "completed": 0},
 },
-"epoch_loop.optimizer_loop.state_dict": {},
-"epoch_loop.optimizer_loop.optim_progress": {
+"epoch_loop.automatic_optimization.state_dict": {},
+"epoch_loop.automatic_optimization.optim_progress": {
 "optimizer": {
 "step": {
 "total": {
@@ -423,7 +423,7 @@ def training_step(self, batch, batch_idx):
 assert batch_progress.current.ready == be_batches_completed
 assert batch_progress.current.completed == be_batches_completed

-optim_progress = trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress
+optim_progress = trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress
 assert optim_progress.optimizer.step.current.ready == be_total_opt_steps
 assert optim_progress.optimizer.step.current.completed == be_total_opt_steps
 assert optim_progress.optimizer.zero_grad.current.ready == be_total_zero_grad
@@ -503,13 +503,13 @@ def train_dataloader(self):
 "total": {"ready": n_sch_steps_total, "completed": n_sch_steps_total},
 "current": {"ready": n_sch_steps_current, "completed": n_sch_steps_current},
 },
-"epoch_loop.manual_loop.state_dict": ANY,
-"epoch_loop.manual_loop.optim_step_progress": {
+"epoch_loop.manual_optimization.state_dict": ANY,
+"epoch_loop.manual_optimization.optim_step_progress": {
 "total": {"ready": 0, "completed": 0},
 "current": {"ready": 0, "completed": 0},
 },
-"epoch_loop.optimizer_loop.state_dict": {},
-"epoch_loop.optimizer_loop.optim_progress": {
+"epoch_loop.automatic_optimization.state_dict": {},
+"epoch_loop.automatic_optimization.optim_progress": {
 "optimizer": {
 "step": {
 "total": {
@@ -568,7 +568,7 @@ def test_fit_loop_reset(tmpdir):
 mid_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=2.ckpt"))
 fit_loop = trainer.fit_loop
 epoch_loop = fit_loop.epoch_loop
-optimizer_loop = epoch_loop.optimizer_loop
+optimizer_loop = epoch_loop.automatic_optimization
 assert not fit_loop.restarting
 assert not epoch_loop.restarting
 assert not optimizer_loop.restarting
2 changes: 1 addition & 1 deletion tests/tests_pytorch/loops/test_training_loop.py
@@ -217,7 +217,7 @@ def test_should_stop_early_stopping_conditions_met(
 trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0, max_epochs=100)
 trainer.num_training_batches = 10
 trainer.should_stop = True
-trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = (
+trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = (
 current_epoch * trainer.num_training_batches
 )
 trainer.fit_loop.epoch_loop.batch_progress.current.ready = 10
12 changes: 6 additions & 6 deletions tests/tests_pytorch/loops/test_training_loop_flow_scalar.py
@@ -20,7 +20,7 @@
 from lightning.pytorch import Trainer
 from lightning.pytorch.core.module import LightningModule
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
-from lightning.pytorch.loops.optimization.optimizer_loop import Closure
+from lightning.pytorch.loops.optimization.automatic import Closure
 from lightning.pytorch.trainer.states import RunningStage
 from tests_pytorch.helpers.deterministic_model import DeterministicModel

@@ -147,13 +147,13 @@ def backward(self, loss):
 trainer.state.stage = RunningStage.TRAINING
 # make sure training outputs what is expected
 kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
-train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs)
+train_step_out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs)

 assert isinstance(train_step_out["loss"], Tensor)
 assert train_step_out["loss"].item() == 171

 # make sure the optimizer closure returns the correct things
-opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
+opt_closure = trainer.fit_loop.epoch_loop.automatic_optimization._make_closure(kwargs, trainer.optimizers[0])
 opt_closure_result = opt_closure()
 assert opt_closure_result.item() == 171

@@ -215,13 +215,13 @@ def backward(self, loss):
 trainer.state.stage = RunningStage.TRAINING
 # make sure training outputs what is expected
 kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
-train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs)
+train_step_out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs)

 assert isinstance(train_step_out["loss"], Tensor)
 assert train_step_out["loss"].item() == 171

 # make sure the optimizer closure returns the correct things
-opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
+opt_closure = trainer.fit_loop.epoch_loop.automatic_optimization._make_closure(kwargs, trainer.optimizers[0])
 opt_closure_result = opt_closure()
 assert opt_closure_result.item() == 171

@@ -298,7 +298,7 @@ def training_step(self, batch, batch_idx):
 # manually check a few batches
 for batch_idx, batch in enumerate(model.train_dataloader()):
 kwargs = {"batch": batch, "batch_idx": batch_idx}
-out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs)
+out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs)
 if not batch_idx % 2:
 assert out == {}

2 changes: 1 addition & 1 deletion tests/tests_pytorch/trainer/test_trainer.py
@@ -339,7 +339,7 @@ def mock_save_function(filepath, *args):
 # emulate callback's calls during the training
 for i, loss in enumerate(losses, 1):
 # sets `trainer.global_step`
-trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = i
+trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = i
 trainer.callback_metrics.update({"checkpoint_on": torch.tensor(loss)})
 checkpoint_callback.on_validation_end(trainer, trainer.lightning_module)
 trainer.fit_loop.epoch_progress.current.completed = i # sets `trainer.current_epoch`