drop opt_idx from to_lightning_optimizer
update docs


update
awaelchli committed Feb 1, 2023
1 parent 6a56586 commit c6ea75f
Showing 5 changed files with 12 additions and 24 deletions.
src/pytorch_lightning/core/module.py (2 changes: 1 addition & 1 deletion)
@@ -144,7 +144,7 @@ def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS:
if self._fabric:
opts: MODULE_OPTIMIZERS = self._fabric_optimizers
elif use_pl_optimizer:
- opts = list(self.trainer.strategy._lightning_optimizers.values())
+ opts = self.trainer.strategy._lightning_optimizers
else:
opts = self.trainer.optimizers

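Net effect of this hunk on the user-facing call: with multiple optimizers and `use_pl_optimizer=True` (the default), `LightningModule.optimizers()` now hands back the strategy's list of `LightningOptimizer` wrappers as-is instead of rebuilding it from an index-keyed dict. A minimal sketch of that call path, assuming a hypothetical manual-optimization module (the `SketchModel` name and the two linear layers are illustrative, not part of this commit):

import torch
import pytorch_lightning as pl

class SketchModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization, as in the GAN docstring example
        self.gen = torch.nn.Linear(2, 2)
        self.dis = torch.nn.Linear(2, 2)

    def configure_optimizers(self):
        # two optimizers, no opt_idx bookkeeping anywhere
        return (
            torch.optim.SGD(self.gen.parameters(), lr=0.1),
            torch.optim.SGD(self.dis.parameters(), lr=0.1),
        )

    def training_step(self, batch, batch_idx):
        # optimizers() returns strategy._lightning_optimizers directly:
        # LightningOptimizer wrappers, in configuration order
        opt_gen, opt_dis = self.optimizers()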
src/pytorch_lightning/core/optimizer.py (17 changes: 5 additions & 12 deletions)
@@ -44,7 +44,6 @@ def __init__(self, optimizer: Optimizer):

self._optimizer = optimizer
self._strategy: Optional[pl.strategies.Strategy] = None
- self._optimizer_idx = 0
# to inject logic around the optimizer step, particularly useful with manual optimization
self._on_before_step = do_nothing_closure
self._on_after_step = do_nothing_closure
@@ -55,7 +54,7 @@ def optimizer(self) -> Optimizer:

@classmethod
def _to_lightning_optimizer(
- cls, optimizer: Union[Optimizer, "LightningOptimizer"], strategy: "pl.strategies.Strategy", opt_idx: int
+ cls, optimizer: Union[Optimizer, "LightningOptimizer"], strategy: "pl.strategies.Strategy"
) -> "LightningOptimizer":
if isinstance(optimizer, LightningOptimizer):
# the user could return a `LightningOptimizer` from `configure_optimizers`, see test:
@@ -64,7 +63,6 @@ def _to_lightning_optimizer(
else:
lightning_optimizer = cls(optimizer)
lightning_optimizer._strategy = proxy(strategy)
- lightning_optimizer._optimizer_idx = opt_idx
return lightning_optimizer

@contextmanager
@@ -102,7 +100,7 @@ def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> Any:
Example::
# Scenario for a GAN using manual optimization
- def training_step(...):
+ def training_step(self, batch, batch_idx):
opt_gen, opt_dis = self.optimizers()
...
@@ -124,7 +122,7 @@ def training_step(...):
# A more advanced example
- def training_step(self, batch, batch_idx, ...):
+ def training_step(self, batch, batch_idx):
opt_gen, opt_dis = self.optimizers()
...
@@ -219,15 +217,10 @@ def _configure_optimizers(
_validate_optim_conf(opt_dict)
optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
scheduler_dict = (
- lambda scheduler, opt_idx: dict(scheduler, opt_idx=opt_idx)
-     if isinstance(scheduler, dict)
-     else {"scheduler": scheduler, "opt_idx": opt_idx}
+ lambda scheduler: dict(scheduler) if isinstance(scheduler, dict) else {"scheduler": scheduler}
)

lr_schedulers = [
- scheduler_dict(opt_dict["lr_scheduler"], opt_idx)
-     for opt_idx, opt_dict in enumerate(optim_conf)
-     if "lr_scheduler" in opt_dict
+ scheduler_dict(opt_dict["lr_scheduler"]) for opt_dict in optim_conf if "lr_scheduler" in opt_dict
]
# single list or tuple, multiple optimizer
elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizable) for opt in optim_conf):
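For the list-of-dicts branch of `_configure_optimizers`, the practical change is that scheduler entries are passed through untouched; no "opt_idx" key is injected any more. A small standalone sketch of what the new `scheduler_dict` helper produces (plain strings stand in for real scheduler objects here):

# Mirrors the new helper: copy a dict as-is, wrap a bare scheduler, add no "opt_idx".
scheduler_dict = (
    lambda scheduler: dict(scheduler) if isinstance(scheduler, dict) else {"scheduler": scheduler}
)

assert scheduler_dict({"scheduler": "sched_a", "interval": "step"}) == {"scheduler": "sched_a", "interval": "step"}
assert scheduler_dict("sched_b") == {"scheduler": "sched_b"}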
src/pytorch_lightning/loops/optimization/manual_loop.py (4 changes: 2 additions & 2 deletions)
@@ -92,7 +92,7 @@ def run(self, kwargs: OrderedDict) -> _OUTPUTS_TYPE:

def on_run_start(self) -> None:
# inject logic around the optimizer step
- for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items():
+ for lightning_optimizer in self.trainer.strategy._lightning_optimizers:
lightning_optimizer._on_before_step = self._on_before_step
lightning_optimizer._on_after_step = self._on_after_step

@@ -119,7 +119,7 @@ def on_run_end(self) -> _OUTPUTS_TYPE:
"""Returns the result of this loop, i.e., the post-processed outputs from the training step."""
output, self._output = self._output, {} # free memory
# reset logic around the optimizer step
- for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items():
+ for lightning_optimizer in self.trainer.strategy._lightning_optimizers:
lightning_optimizer._on_before_step = do_nothing_closure
lightning_optimizer._on_after_step = do_nothing_closure
return output
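The manual loop now wires its step hooks by iterating the strategy's list directly rather than `dict.items()`. A compact sketch of that wiring pattern in isolation (the `_attach_hooks`/`_reset_hooks` helpers and the local no-op are illustrative stand-ins, not the loop's actual methods):

def _do_nothing() -> None:
    # local stand-in for Lightning's `do_nothing_closure`
    return None

def _attach_hooks(lightning_optimizers, before, after):
    # mirrors on_run_start: inject logic around each optimizer's step
    for lightning_optimizer in lightning_optimizers:
        lightning_optimizer._on_before_step = before
        lightning_optimizer._on_after_step = after

def _reset_hooks(lightning_optimizers):
    # mirrors on_run_end: restore the no-op closures
    for lightning_optimizer in lightning_optimizers:
        lightning_optimizer._on_before_step = _do_nothing
        lightning_optimizer._on_after_step = _do_nothing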
src/pytorch_lightning/strategies/strategy.py (8 changes: 3 additions & 5 deletions)
@@ -65,7 +65,7 @@ def __init__(
self._model: Optional[Module] = None
self._launcher: Optional[_Launcher] = None
self._optimizers: List[Optimizer] = []
- self._lightning_optimizers: Dict[int, LightningOptimizer] = {}
+ self._lightning_optimizers: List[LightningOptimizer] = []
self.lr_scheduler_configs: List[LRSchedulerConfig] = []

@property
@@ -108,9 +108,7 @@ def optimizers(self) -> List[Optimizer]:
@optimizers.setter
def optimizers(self, optimizers: List[Optimizer]) -> None:
self._optimizers = optimizers
- self._lightning_optimizers = {
-     idx: LightningOptimizer._to_lightning_optimizer(opt, self, idx) for idx, opt in enumerate(self.optimizers)
- }
+ self._lightning_optimizers = [LightningOptimizer._to_lightning_optimizer(opt, self) for opt in optimizers]

def connect(self, model: "pl.LightningModule") -> None:
"""Called by the accelerator to connect the accelerator and the model with this plugin."""
@@ -537,7 +535,7 @@ def dispatch(self, trainer: "pl.Trainer") -> None:
def __getstate__(self) -> Dict:
# `LightningOptimizer` overrides `self.__class__` so they cannot be pickled
state = dict(vars(self)) # copy
state["_lightning_optimizers"] = {}
state["_lightning_optimizers"] = []
return state

def __setstate__(self, state: Dict) -> None:
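The strategy-side invariant after this commit: `_lightning_optimizers` is a plain list that the `optimizers` setter rebuilds, with no index tracking. A minimal sketch against the post-commit API (the `SingleDeviceStrategy` instance and the toy `SGD` optimizer are illustrative, not code from the commit):

import torch
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.strategies import SingleDeviceStrategy

strategy = SingleDeviceStrategy(device="cpu")
strategy.optimizers = [torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)]

# The setter wraps each optimizer into a LightningOptimizer; order matches the input list.
assert isinstance(strategy._lightning_optimizers, list)
assert isinstance(strategy._lightning_optimizers[0], LightningOptimizer)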
tests/tests_pytorch/core/test_lightning_optimizer.py (5 changes: 1 addition & 4 deletions)
@@ -154,7 +154,7 @@ def test_state():
k: v
for k, v in lightning_optimizer.__dict__.items()
if k
not in {"_optimizer", "_optimizer_idx", "_strategy", "_lightning_module", "_on_before_step", "_on_after_step"}
not in {"_optimizer", "_strategy", "_lightning_module", "_on_before_step", "_on_after_step"}
}

assert lightning_dict == optimizer.__dict__
@@ -192,9 +192,6 @@ def test_lightning_optimizer_automatic_optimization_optimizer_step(tmpdir):
"""Test overriding step works in automatic_optimization."""

class TestModel(BoringModel):
- def training_step(self, batch, batch_idx, optimizer_idx=None):
-     return super().training_step(batch, batch_idx)

def training_epoch_end(self, outputs):
...
