opt_idx cleanup after optimizer loop changes #16597

Merged (5 commits) on Feb 2, 2023

Changes from 2 commits
2 changes: 1 addition & 1 deletion src/pytorch_lightning/core/module.py
@@ -144,7 +144,7 @@ def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS:
         if self._fabric:
             opts: MODULE_OPTIMIZERS = self._fabric_optimizers
         elif use_pl_optimizer:
-            opts = list(self.trainer.strategy._lightning_optimizers.values())
+            opts = self.trainer.strategy._lightning_optimizers
         else:
             opts = self.trainer.optimizers

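With the wrappers now stored as a list, `optimizers()` can return the strategy's `_lightning_optimizers` directly instead of rebuilding it via `list(...values())`. A minimal sketch of how this surfaces in user code; the model below is hypothetical, not part of this PR:

import torch
import pytorch_lightning as pl


class TwoOptimizerModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization with two optimizers
        self.gen = torch.nn.Linear(4, 4)
        self.dis = torch.nn.Linear(4, 1)

    def configure_optimizers(self):
        return (
            torch.optim.Adam(self.gen.parameters(), lr=1e-3),
            torch.optim.Adam(self.dis.parameters(), lr=1e-3),
        )

    def training_step(self, batch, batch_idx):
        # the list of LightningOptimizers, in configure_optimizers order
        opt_gen, opt_dis = self.optimizers()
        ...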
19 changes: 5 additions & 14 deletions src/pytorch_lightning/core/optimizer.py
@@ -44,7 +44,6 @@ def __init__(self, optimizer: Optimizer):

         self._optimizer = optimizer
         self._strategy: Optional[pl.strategies.Strategy] = None
-        self._optimizer_idx = 0
         # to inject logic around the optimizer step, particularly useful with manual optimization
         self._on_before_step = do_nothing_closure
         self._on_after_step = do_nothing_closure
@@ -55,7 +54,7 @@ def optimizer(self) -> Optimizer:

     @classmethod
     def _to_lightning_optimizer(
-        cls, optimizer: Union[Optimizer, "LightningOptimizer"], strategy: "pl.strategies.Strategy", opt_idx: int
+        cls, optimizer: Union[Optimizer, "LightningOptimizer"], strategy: "pl.strategies.Strategy"
     ) -> "LightningOptimizer":
         if isinstance(optimizer, LightningOptimizer):
             # the user could return a `LightningOptimizer` from `configure_optimizers`, see test:
@@ -64,7 +63,6 @@ def _to_lightning_optimizer(
         else:
             lightning_optimizer = cls(optimizer)
         lightning_optimizer._strategy = proxy(strategy)
-        lightning_optimizer._optimizer_idx = opt_idx
         return lightning_optimizer

     @contextmanager
@@ -102,7 +100,7 @@ def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> Any:
         Example::

             # Scenario for a GAN using manual optimization
-            def training_step(...):
+            def training_step(self, batch, batch_idx):
                 opt_gen, opt_dis = self.optimizers()

                 ...
@@ -124,7 +122,7 @@ def training_step(...):

             # A more advanced example
-            def training_step(self, batch, batch_idx, ...):
+            def training_step(self, batch, batch_idx):
                 opt_gen, opt_dis = self.optimizers()

                 ...
@@ -218,16 +216,9 @@ def _configure_optimizers(
         for opt_dict in optim_conf:
             _validate_optim_conf(opt_dict)
         optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
-        scheduler_dict = (
-            lambda scheduler, opt_idx: dict(scheduler, opt_idx=opt_idx)
-            if isinstance(scheduler, dict)
-            else {"scheduler": scheduler, "opt_idx": opt_idx}
-        )
-
+        scheduler_dict = lambda scheduler: dict(scheduler) if isinstance(scheduler, dict) else {"scheduler": scheduler}
         lr_schedulers = [
-            scheduler_dict(opt_dict["lr_scheduler"], opt_idx)
-            for opt_idx, opt_dict in enumerate(optim_conf)
-            if "lr_scheduler" in opt_dict
+            scheduler_dict(opt_dict["lr_scheduler"]) for opt_dict in optim_conf if "lr_scheduler" in opt_dict
         ]
     # single list or tuple, multiple optimizer
     elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizable) for opt in optim_conf):
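The simplified helper no longer threads an `opt_idx` through the scheduler configuration. A small sketch of what it now produces, where `sched` stands in for a hypothetical scheduler instance:

scheduler_dict = lambda scheduler: dict(scheduler) if isinstance(scheduler, dict) else {"scheduler": scheduler}

# a bare scheduler is wrapped:
scheduler_dict(sched)
# -> {"scheduler": sched}

# a user-supplied dict is shallow-copied unchanged; no "opt_idx" key is injected:
scheduler_dict({"scheduler": sched, "interval": "step"})
# -> {"scheduler": sched, "interval": "step"}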
4 changes: 2 additions & 2 deletions src/pytorch_lightning/loops/optimization/manual_loop.py
@@ -92,7 +92,7 @@ def run(self, kwargs: OrderedDict) -> _OUTPUTS_TYPE:

     def on_run_start(self) -> None:
         # inject logic around the optimizer step
-        for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items():
+        for lightning_optimizer in self.trainer.strategy._lightning_optimizers:
             lightning_optimizer._on_before_step = self._on_before_step
             lightning_optimizer._on_after_step = self._on_after_step

@@ -119,7 +117,7 @@ def on_run_end(self) -> _OUTPUTS_TYPE:
         """Returns the result of this loop, i.e., the post-processed outputs from the training step."""
         output, self._output = self._output, {}  # free memory
         # reset logic around the optimizer step
-        for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items():
+        for lightning_optimizer in self.trainer.strategy._lightning_optimizers:
             lightning_optimizer._on_before_step = do_nothing_closure
             lightning_optimizer._on_after_step = do_nothing_closure
         return output
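The container changed from a dict keyed by optimizer index to a plain list, so consumers drop the `.items()` call; the index it yielded was unused here. A before/after sketch, with `strategy` standing in for any `pl.strategies.Strategy` instance and `fn` for the hook being injected:

# before this PR: dict keyed by optimizer index
for i, lightning_optimizer in strategy._lightning_optimizers.items():
    lightning_optimizer._on_before_step = fn

# after this PR: plain list, same order as `strategy.optimizers`
for lightning_optimizer in strategy._lightning_optimizers:
    lightning_optimizer._on_before_step = fn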
8 changes: 3 additions & 5 deletions src/pytorch_lightning/strategies/strategy.py
@@ -65,7 +65,7 @@ def __init__(
         self._model: Optional[Module] = None
         self._launcher: Optional[_Launcher] = None
         self._optimizers: List[Optimizer] = []
-        self._lightning_optimizers: Dict[int, LightningOptimizer] = {}
+        self._lightning_optimizers: List[LightningOptimizer] = []
         self.lr_scheduler_configs: List[LRSchedulerConfig] = []

     @property
@@ -108,9 +108,7 @@ def optimizers(self) -> List[Optimizer]:
     @optimizers.setter
     def optimizers(self, optimizers: List[Optimizer]) -> None:
         self._optimizers = optimizers
-        self._lightning_optimizers = {
-            idx: LightningOptimizer._to_lightning_optimizer(opt, self, idx) for idx, opt in enumerate(self.optimizers)
-        }
+        self._lightning_optimizers = [LightningOptimizer._to_lightning_optimizer(opt, self) for opt in optimizers]

     def connect(self, model: "pl.LightningModule") -> None:
         """Called by the accelerator to connect the accelerator and the model with this plugin."""
@@ -537,7 +535,7 @@ def dispatch(self, trainer: "pl.Trainer") -> None:
     def __getstate__(self) -> Dict:
         # `LightningOptimizer` overrides `self.__class__` so they cannot be pickled
         state = dict(vars(self))  # copy
-        state["_lightning_optimizers"] = {}
+        state["_lightning_optimizers"] = []
         return state

     def __setstate__(self, state: Dict) -> None:
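The setter still wraps every raw optimizer, just into a list now. A sketch of the resulting container shape, assuming `strategy` is a `pl.strategies.Strategy` and `opt_a`, `opt_b` are hypothetical torch optimizers:

strategy.optimizers = [opt_a, opt_b]

# before this PR: {0: LightningOptimizer(opt_a), 1: LightningOptimizer(opt_b)}
# after this PR:  [LightningOptimizer(opt_a), LightningOptimizer(opt_b)]
# order matches `strategy.optimizers`, and `__getstate__` now empties the
# list (rather than a dict) since the wrappers cannot be pickled
print(strategy._lightning_optimizers)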
6 changes: 1 addition & 5 deletions tests/tests_pytorch/core/test_lightning_optimizer.py
@@ -153,8 +153,7 @@ def test_state():
     lightning_dict = {
         k: v
         for k, v in lightning_optimizer.__dict__.items()
-        if k
-        not in {"_optimizer", "_optimizer_idx", "_strategy", "_lightning_module", "_on_before_step", "_on_after_step"}
+        if k not in {"_optimizer", "_strategy", "_lightning_module", "_on_before_step", "_on_after_step"}
     }

     assert lightning_dict == optimizer.__dict__
@@ -192,9 +191,6 @@ def test_lightning_optimizer_automatic_optimization_optimizer_step(tmpdir):
     """Test overriding step works in automatic_optimization."""

     class TestModel(BoringModel):
-        def training_step(self, batch, batch_idx, optimizer_idx=None):
-            return super().training_step(batch, batch_idx)
-
         def training_epoch_end(self, outputs):
             ...
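Since the optimizer-loop changes this PR cleans up after, `training_step` is no longer called with an `optimizer_idx` argument, so the override that only forwarded it is dropped and `BoringModel`'s own implementation is used. A sketch of the signature change the test now relies on:

# before: the index was passed for multi-optimizer automatic optimization
def training_step(self, batch, batch_idx, optimizer_idx=None):
    return super().training_step(batch, batch_idx)

# after: one uniform signature
def training_step(self, batch, batch_idx):
    return super().training_step(batch, batch_idx)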