Rename optimization loops #16598

Merged: 16 commits, Feb 2, 2023

Changes from 1 commit:
rename attributes
awaelchli committed Feb 1, 2023
commit eb7160b44780fe6bc5f6581d92ce02f13cdaf779
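In short, this commit renames the training epoch loop's sub-loop attributes: `optimizer_loop` becomes `automatic_optimization` and `manual_loop` becomes `manual_optimization`, with the matching loop state-dict keys renamed as well. A minimal sketch of how the renamed attributes are reached, assuming a Lightning version that already includes this change (the `Trainer` construction is only illustrative):

```python
import pytorch_lightning as pl

# Illustrative setup; any Trainer works, the point is the attribute path.
trainer = pl.Trainer(max_epochs=1)
epoch_loop = trainer.fit_loop.epoch_loop

# Before this commit: epoch_loop.optimizer_loop / epoch_loop.manual_loop
# After this commit:
print(epoch_loop.automatic_optimization.optim_progress.optimizer_steps)
print(epoch_loop.manual_optimization.optim_step_progress.total.completed)
```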
16 changes: 8 additions & 8 deletions src/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -68,8 +68,8 @@ def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None
self.batch_progress = BatchProgress()
self.scheduler_progress = SchedulerProgress()

-self.optimizer_loop = _AutomaticOptimization()
-self.manual_loop = _ManualOptimization()
+self.automatic_optimization = _AutomaticOptimization()
+self.manual_optimization = _ManualOptimization()

self.val_loop = loops._EvaluationLoop(verbose=False)

@@ -96,8 +96,8 @@ def batch_idx(self) -> int:
def global_step(self) -> int:
lightning_module = self.trainer.lightning_module
if lightning_module is None or lightning_module.automatic_optimization:
-return self.optimizer_loop.optim_progress.optimizer_steps
-return self.manual_loop.optim_step_progress.total.completed
+return self.automatic_optimization.optim_progress.optimizer_steps
+return self.manual_optimization.optim_step_progress.total.completed

@property
def _is_training_done(self) -> bool:
@@ -146,7 +146,7 @@ def reset(self) -> None:
if self.restarting:
self.batch_progress.reset_on_restart()
self.scheduler_progress.reset_on_restart()
-self.optimizer_loop.optim_progress.reset_on_restart()
+self.automatic_optimization.optim_progress.reset_on_restart()

trainer = self.trainer
if trainer.num_training_batches != float("inf"):
@@ -159,7 +159,7 @@ def reset(self) -> None:
else:
self.batch_progress.reset_on_run()
self.scheduler_progress.reset_on_run()
-self.optimizer_loop.optim_progress.reset_on_run()
+self.automatic_optimization.optim_progress.reset_on_run()
# when the epoch starts, the total val batch progress should be reset as it's supposed to count the batches
# seen per epoch, this is useful for tracking when validation is run multiple times per epoch
self.val_loop.epoch_loop.batch_progress.total.reset()
@@ -222,9 +222,9 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:
with self.trainer.profiler.profile("run_training_batch"):
if self.trainer.lightning_module.automatic_optimization:
# in automatic optimization, there can only be one optimizer
-batch_output = self.optimizer_loop.run(self.trainer.optimizers[0], kwargs)
+batch_output = self.automatic_optimization.run(self.trainer.optimizers[0], kwargs)
else:
-batch_output = self.manual_loop.run(kwargs)
+batch_output = self.manual_optimization.run(kwargs)

self.batch_progress.increment_processed()

4 changes: 2 additions & 2 deletions src/pytorch_lightning/loops/fit_loop.py
@@ -131,12 +131,12 @@ def prefetch_batches(self) -> int:
@property
def _skip_backward(self) -> bool:
"""Determines whether the loop will skip backward during automatic optimization."""
-return self.epoch_loop.optimizer_loop._skip_backward
+return self.epoch_loop.automatic_optimization._skip_backward

@_skip_backward.setter
def _skip_backward(self, value: bool) -> None:
"""Determines whether the loop will skip backward during automatic optimization."""
-self.epoch_loop.optimizer_loop._skip_backward = value
+self.epoch_loop.automatic_optimization._skip_backward = value

@property
def _results(self) -> _ResultCollection:
4 changes: 2 additions & 2 deletions src/pytorch_lightning/tuner/batch_size_scaling.py
@@ -361,8 +361,8 @@ def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:

def _reset_progress(trainer: "pl.Trainer") -> None:
if trainer.lightning_module.automatic_optimization:
-trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.reset()
+trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.reset()
else:
-trainer.fit_loop.epoch_loop.manual_loop.optim_step_progress.reset()
+trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.reset()

trainer.fit_loop.epoch_progress.reset()
6 changes: 3 additions & 3 deletions src/pytorch_lightning/utilities/migration/migration.py
@@ -91,10 +91,10 @@ def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _
checkpoint.setdefault("loops", {"fit_loop": _get_fit_loop_initial_state_1_6_0()})
checkpoint["loops"].setdefault("fit_loop", _get_fit_loop_initial_state_1_6_0())
# for automatic optimization
-optim_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"]
+optim_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.automatic_optimization.optim_progress"]
optim_progress["optimizer"]["step"]["total"]["completed"] = global_step
# for manual optimization
-optim_step_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_loop.optim_step_progress"]
+optim_step_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_optimization.optim_step_progress"]
optim_step_progress["total"]["completed"] = global_step
return checkpoint

@@ -276,5 +276,5 @@ def _migrate_loop_structure_after_optimizer_loop_removal(checkpoint: _CHECKPOINT

# TODO: Complete this migration function when optimizer loop gets flattened out and keys need to be remapped
fit_loop = checkpoint["loops"]["fit_loop"]
fit_loop["epoch_loop.optimizer_loop.optim_progress"].pop("optimizer_position", None)
fit_loop["epoch_loop.automatic_optimization.optim_progress"].pop("optimizer_position", None)
return checkpoint
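The checkpoint-migration changes above operate on the nested `checkpoint["loops"]["fit_loop"]` mapping, whose keys embed the loop attribute names. Below is a hypothetical sketch (not part of this PR) of the kind of key remapping the rename implies; the `_rename_loop_keys` helper is made up for illustration:

```python
from typing import Any, Dict

# Old -> new attribute names, as introduced by this commit.
_RENAMES = {"optimizer_loop": "automatic_optimization", "manual_loop": "manual_optimization"}


def _rename_loop_keys(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: rewrite fit-loop state keys to use the renamed attributes."""
    fit_loop = checkpoint.get("loops", {}).get("fit_loop", {})
    for old_key in list(fit_loop):
        new_key = old_key
        for old, new in _RENAMES.items():
            new_key = new_key.replace(old, new)
        if new_key != old_key:
            fit_loop[new_key] = fit_loop.pop(old_key)
    return checkpoint
```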
Original file line number Diff line number Diff line change
@@ -66,7 +66,7 @@ def test_should_stop_early_stopping_conditions_not_met(
trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0)
trainer.num_training_batches = 10
trainer.should_stop = True
-trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = global_step
+trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = global_step
trainer.fit_loop.epoch_loop.batch_progress.current.ready = global_step
trainer.fit_loop.epoch_progress.current.completed = current_epoch - 1

8 changes: 4 additions & 4 deletions tests/tests_pytorch/loops/test_evaluation_loop_flow.py
@@ -65,13 +65,13 @@ def backward(self, loss):
# simulate training manually
trainer.state.stage = RunningStage.TRAINING
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
-train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs)
+train_step_out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs)

assert isinstance(train_step_out["loss"], Tensor)
assert train_step_out["loss"].item() == 171

# make sure the optimizer closure returns the correct things
-opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
+opt_closure = trainer.fit_loop.epoch_loop.automatic_optimization._make_closure(kwargs, trainer.optimizers[0])
opt_closure_result = opt_closure()
assert opt_closure_result.item() == 171

@@ -124,13 +124,13 @@ def backward(self, loss):
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
-train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs)
+train_step_out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs)

assert isinstance(train_step_out["loss"], Tensor)
assert train_step_out["loss"].item() == 171

# make sure the optimizer closure returns the correct things
-opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
+opt_closure = trainer.fit_loop.epoch_loop.automatic_optimization._make_closure(kwargs, trainer.optimizers[0])
opt_closure_result = opt_closure()
assert opt_closure_result.item() == 171

8 changes: 4 additions & 4 deletions tests/tests_pytorch/loops/test_loop_state_dict.py
@@ -49,13 +49,13 @@ def test_loops_state_dict_structure():
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.manual_loop.state_dict": {},
"epoch_loop.manual_loop.optim_step_progress": {
"epoch_loop.manual_optimization.state_dict": {},
"epoch_loop.manual_optimization.optim_step_progress": {
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.optimizer_loop.state_dict": {},
"epoch_loop.optimizer_loop.optim_progress": {
"epoch_loop.automatic_optimization.state_dict": {},
"epoch_loop.automatic_optimization.optim_progress": {
"optimizer": {
"step": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
"zero_grad": {
22 changes: 11 additions & 11 deletions tests/tests_pytorch/loops/test_loops.py
@@ -300,7 +300,7 @@ def training_step(self, batch, batch_idx):
assert os.path.exists(ckpt_path)
checkpoint = torch.load(ckpt_path)

-optim_progress = trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress
+optim_progress = trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress
sch_progress = trainer.fit_loop.epoch_loop.scheduler_progress

# `nbe_`: non-breaking epoch, as in, no exception will be raised. `be_`: breaking epoch
@@ -365,13 +365,13 @@ def training_step(self, batch, batch_idx):
"total": {"ready": nbe_sch_steps + be_sch_steps, "completed": nbe_sch_steps + be_sch_steps},
"current": {"ready": be_sch_steps, "completed": be_sch_steps},
},
"epoch_loop.manual_loop.state_dict": ANY,
"epoch_loop.manual_loop.optim_step_progress": {
"epoch_loop.manual_optimization.state_dict": ANY,
"epoch_loop.manual_optimization.optim_step_progress": {
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.optimizer_loop.state_dict": {},
"epoch_loop.optimizer_loop.optim_progress": {
"epoch_loop.automatic_optimization.state_dict": {},
"epoch_loop.automatic_optimization.optim_progress": {
"optimizer": {
"step": {
"total": {
@@ -423,7 +423,7 @@ def training_step(self, batch, batch_idx):
assert batch_progress.current.ready == be_batches_completed
assert batch_progress.current.completed == be_batches_completed

-optim_progress = trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress
+optim_progress = trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress
assert optim_progress.optimizer.step.current.ready == be_total_opt_steps
assert optim_progress.optimizer.step.current.completed == be_total_opt_steps
assert optim_progress.optimizer.zero_grad.current.ready == be_total_zero_grad
@@ -503,13 +503,13 @@ def train_dataloader(self):
"total": {"ready": n_sch_steps_total, "completed": n_sch_steps_total},
"current": {"ready": n_sch_steps_current, "completed": n_sch_steps_current},
},
"epoch_loop.manual_loop.state_dict": ANY,
"epoch_loop.manual_loop.optim_step_progress": {
"epoch_loop.manual_optimization.state_dict": ANY,
"epoch_loop.manual_optimization.optim_step_progress": {
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.optimizer_loop.state_dict": {},
"epoch_loop.optimizer_loop.optim_progress": {
"epoch_loop.automatic_optimization.state_dict": {},
"epoch_loop.automatic_optimization.optim_progress": {
"optimizer": {
"step": {
"total": {
@@ -568,7 +568,7 @@ def test_fit_loop_reset(tmpdir):
mid_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=2.ckpt"))
fit_loop = trainer.fit_loop
epoch_loop = fit_loop.epoch_loop
-optimizer_loop = epoch_loop.optimizer_loop
+optimizer_loop = epoch_loop.automatic_optimization
assert not fit_loop.restarting
assert not epoch_loop.restarting
assert not optimizer_loop.restarting
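As the expected state-dict structures in these tests show, the renamed attributes also surface as checkpoint keys such as `epoch_loop.automatic_optimization.optim_progress` and `epoch_loop.manual_optimization.optim_step_progress`. A small hedged example of inspecting those keys in a saved checkpoint (the file path is illustrative and the checkpoint is assumed to come from a version with this rename):

```python
import torch

# "example.ckpt" is a placeholder for a checkpoint written by Trainer.save_checkpoint.
checkpoint = torch.load("example.ckpt", map_location="cpu")
fit_loop_state = checkpoint["loops"]["fit_loop"]

# Keys now carry the new names instead of optimizer_loop / manual_loop.
print([k for k in fit_loop_state if "automatic_optimization" in k])
print([k for k in fit_loop_state if "manual_optimization" in k])
```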
2 changes: 1 addition & 1 deletion tests/tests_pytorch/loops/test_training_loop.py
@@ -217,7 +217,7 @@ def test_should_stop_early_stopping_conditions_met(
trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0, max_epochs=100)
trainer.num_training_batches = 10
trainer.should_stop = True
-trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = (
+trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = (
current_epoch * trainer.num_training_batches
)
trainer.fit_loop.epoch_loop.batch_progress.current.ready = 10
10 changes: 5 additions & 5 deletions tests/tests_pytorch/loops/test_training_loop_flow_scalar.py
@@ -147,13 +147,13 @@ def backward(self, loss):
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
-train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs)
+train_step_out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs)

assert isinstance(train_step_out["loss"], Tensor)
assert train_step_out["loss"].item() == 171

# make sure the optimizer closure returns the correct things
-opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
+opt_closure = trainer.fit_loop.epoch_loop.automatic_optimization._make_closure(kwargs, trainer.optimizers[0])
opt_closure_result = opt_closure()
assert opt_closure_result.item() == 171

@@ -215,13 +215,13 @@ def backward(self, loss):
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
-train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs)
+train_step_out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs)

assert isinstance(train_step_out["loss"], Tensor)
assert train_step_out["loss"].item() == 171

# make sure the optimizer closure returns the correct things
-opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
+opt_closure = trainer.fit_loop.epoch_loop.automatic_optimization._make_closure(kwargs, trainer.optimizers[0])
opt_closure_result = opt_closure()
assert opt_closure_result.item() == 171

@@ -298,7 +298,7 @@ def training_step(self, batch, batch_idx):
# manually check a few batches
for batch_idx, batch in enumerate(model.train_dataloader()):
kwargs = {"batch": batch, "batch_idx": batch_idx}
-out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs)
+out = trainer.fit_loop.epoch_loop.automatic_optimization.run(trainer.optimizers[0], kwargs)
if not batch_idx % 2:
assert out == {}

2 changes: 1 addition & 1 deletion tests/tests_pytorch/trainer/test_trainer.py
@@ -339,7 +339,7 @@ def mock_save_function(filepath, *args):
# emulate callback's calls during the training
for i, loss in enumerate(losses, 1):
# sets `trainer.global_step`
-trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = i
+trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = i
trainer.callback_metrics.update({"checkpoint_on": torch.tensor(loss)})
checkpoint_callback.on_validation_end(trainer, trainer.lightning_module)
trainer.fit_loop.epoch_progress.current.completed = i # sets `trainer.current_epoch`
Original file line number Diff line number Diff line change
@@ -72,7 +72,7 @@ def test_migrate_loop_global_step_to_progress_tracking():
)
# for manual optimization
assert (
updated_checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_loop.optim_step_progress"]["total"][
updated_checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_optimization.optim_step_progress"]["total"][
"completed"
]
== 15