
Commit 01b152f

Update docs for multiple optimizers in 2.0 (#16588)
1 parent fda354a commit 01b152f

File tree

6 files changed: +51 -171 lines

docs/source-pytorch/common/lightning_module.rst

+1 -2

@@ -1155,9 +1155,8 @@ See :ref:`manual optimization <common/optimization:Manual optimization>` for details
         self.manual_backward(loss)
         opt.step()

-This is recommended only if using 2+ optimizers AND if you know how to perform the optimization procedure properly. Note
-that automatic optimization can still be used with multiple optimizers by relying on the ``optimizer_idx`` parameter.
 Manual optimization is most useful for research topics like reinforcement learning, sparse coding, and GAN research.
+It is required when you are using 2+ optimizers because with automatic optimization, you can only use one optimizer.

 .. code-block:: python
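For context, here is a minimal self-contained sketch of the pattern the new wording points to: two optimizers driven entirely by manual optimization. The module, layer sizes, and optimizer choices are illustrative assumptions, not part of the commit.

.. code-block:: python

    import torch
    import pytorch_lightning as pl


    class TwoOptimizerModule(pl.LightningModule):  # hypothetical example module
        def __init__(self):
            super().__init__()
            self.automatic_optimization = False  # opt in to manual optimization
            self.model_a = torch.nn.Linear(8, 1)
            self.model_b = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            opt_a, opt_b = self.optimizers()

            # first optimizer: zero_grad -> manual_backward -> step
            loss_a = self.model_a(batch).mean()
            opt_a.zero_grad()
            self.manual_backward(loss_a)
            opt_a.step()

            # second optimizer, handled independently in the same step
            loss_b = self.model_b(batch).mean()
            opt_b.zero_grad()
            self.manual_backward(loss_b)
            opt_b.step()

        def configure_optimizers(self):
            return (
                torch.optim.Adam(self.model_a.parameters(), lr=1e-3),
                torch.optim.SGD(self.model_b.parameters(), lr=1e-2),
            )
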
docs/source-pytorch/common/optimization.rst

+14 -158

@@ -14,7 +14,7 @@ Lightning offers two modes for managing the optimization process:
 For the majority of research cases, **automatic optimization** will do the right thing for you and it is what most
 users should use.

-For advanced/expert users who want to do esoteric optimization schedules or techniques, use **manual optimization**.
+For more advanced use cases like multiple optimizers, esoteric optimization schedules or techniques, use **manual optimization**.

 .. _manual_optimization:

@@ -39,7 +39,7 @@ Under the hood, Lightning does the following:
     for batch in data:

         def closure():
-            loss = model.training_step(batch, batch_idx, ...)
+            loss = model.training_step(batch, batch_idx)
             optimizer.zero_grad()
             loss.backward()
             return loss
@@ -48,33 +48,13 @@ Under the hood, Lightning does the following:

     lr_scheduler.step()

-In the case of multiple optimizers, Lightning does the following:
-
-.. code-block:: python
-
-    for epoch in epochs:
-        for batch in data:
-            for opt in optimizers:
-
-                def closure():
-                    loss = model.training_step(batch, batch_idx, optimizer_idx)
-                    opt.zero_grad()
-                    loss.backward()
-                    return loss
-
-                opt.step(closure)
-
-        for lr_scheduler in lr_schedulers:
-            lr_scheduler.step()
-
 As can be seen in the code snippet above, Lightning defines a closure with ``training_step()``, ``optimizer.zero_grad()``
 and ``loss.backward()`` for the optimization. This mechanism is in place to support optimizers which operate on the
 output of the closure (e.g. the loss) or need to call the closure several times (e.g. :class:`~torch.optim.LBFGS`).

-.. warning::
-
-    Before v1.2.2, Lightning internally calls ``backward``, ``step`` and ``zero_grad`` in the order.
-    From v1.2.2, the order is changed to ``zero_grad``, ``backward`` and ``step``.
+Should you still require the flexibility of calling ``.zero_grad()``, ``.backward()``, or ``.step()`` yourself, you can
+always switch to :ref:`manual optimization <manual_optimization>`.
+Manual optimization is required if you wish to work with multiple optimizers.


 Gradient Accumulation
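
The closure mechanism mirrors plain PyTorch. As a point of reference, here is a self-contained sketch of how :class:`~torch.optim.LBFGS` consumes a closure (toy data, illustrative only):

.. code-block:: python

    import torch

    model = torch.nn.Linear(4, 1)
    x, y = torch.randn(16, 4), torch.randn(16, 1)

    # LBFGS may re-evaluate the closure several times per step(),
    # which is why Lightning bundles training_step/zero_grad/backward into one.
    optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)


    def closure():
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss


    optimizer.step(closure)
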
@@ -83,113 +63,6 @@ Gradient Accumulation
 .. include:: ../common/gradient_accumulation.rst


-Use Multiple Optimizers (like GANs)
-===================================
-
-To use multiple optimizers (optionally with learning rate schedulers), return two or more optimizers from
-:meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers`.
-
-.. testcode:: python
-
-    # two optimizers, no schedulers
-    def configure_optimizers(self):
-        return Adam(...), SGD(...)
-
-
-    # two optimizers, one scheduler for adam only
-    def configure_optimizers(self):
-        opt1 = Adam(...)
-        opt2 = SGD(...)
-        optimizers = [opt1, opt2]
-        lr_schedulers = {"scheduler": ReduceLROnPlateau(opt1, ...), "monitor": "metric_to_track"}
-        return optimizers, lr_schedulers
-
-
-    # two optimizers, two schedulers
-    def configure_optimizers(self):
-        opt1 = Adam(...)
-        opt2 = SGD(...)
-        return [opt1, opt2], [StepLR(opt1, ...), OneCycleLR(opt2, ...)]
-
-Under the hood, Lightning will call each optimizer sequentially:
-
-.. code-block:: python
-
-    for epoch in epochs:
-        for batch in data:
-            for opt in optimizers:
-                loss = train_step(batch, batch_idx, optimizer_idx)
-                opt.zero_grad()
-                loss.backward()
-                opt.step()
-
-        for lr_scheduler in lr_schedulers:
-            lr_scheduler.step()
-
-
-Step Optimizers at Arbitrary Intervals
-=======================================
-
-To do more interesting things with your optimizers such as learning rate warm-up or odd scheduling,
-override the :meth:`~pytorch_lightning.core.module.LightningModule.optimizer_step` function.
-
-.. warning::
-    If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter to
-    ``optimizer.step()`` function as shown in the examples because ``training_step()``, ``optimizer.zero_grad()``,
-    ``loss.backward()`` are called in the closure function.
-
-For example, here step optimizer A every batch and optimizer B every 2 batches.
-
-.. testcode:: python
-
-    # Alternating schedule for optimizer steps (e.g. GANs)
-    def optimizer_step(
-        self,
-        epoch,
-        batch_idx,
-        optimizer,
-        optimizer_idx,
-        optimizer_closure,
-    ):
-        # update generator every step
-        if optimizer_idx == 0:
-            optimizer.step(closure=optimizer_closure)
-
-        # update discriminator every 2 steps
-        if optimizer_idx == 1:
-            if (batch_idx + 1) % 2 == 0:
-                # the closure (which includes the `training_step`) will be executed by `optimizer.step`
-                optimizer.step(closure=optimizer_closure)
-            else:
-                # call the closure by itself to run `training_step` + `backward` without an optimizer step
-                optimizer_closure()
-
-    # ...
-    # add as many optimizers as you want
-
-Here we add a manual learning rate warm-up without an lr scheduler.
-
-.. testcode:: python
-
-    # learning rate warm-up
-    def optimizer_step(
-        self,
-        epoch,
-        batch_idx,
-        optimizer,
-        optimizer_idx,
-        optimizer_closure,
-    ):
-        # update params
-        optimizer.step(closure=optimizer_closure)
-
-        # skip the first 500 steps
-        if self.trainer.global_step < 500:
-            lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
-            for pg in optimizer.param_groups:
-                pg["lr"] = lr_scale * self.hparams.learning_rate
-
-
 Access your Own Optimizer
 =========================

@@ -206,7 +79,6 @@ to perform a step, Lightning won't be able to support accelerators, precision and
         epoch,
         batch_idx,
         optimizer,
-        optimizer_idx,
         optimizer_closure,
     ):
         optimizer.step(closure=optimizer_closure)
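
Note that the learning-rate warm-up recipe removed above still translates directly to the new hook signature. A self-contained sketch, where the module, layer size, and ``learning_rate`` hyperparameter are assumptions for illustration:

.. code-block:: python

    import torch
    import pytorch_lightning as pl


    class WarmupModule(pl.LightningModule):  # hypothetical example module
        def __init__(self, learning_rate=1e-3):
            super().__init__()
            self.save_hyperparameters()
            self.layer = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            return self.layer(batch).mean()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=self.hparams.learning_rate)

        # 2.0 signature: no more ``optimizer_idx``
        def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
            # the closure runs training_step + zero_grad + backward
            optimizer.step(closure=optimizer_closure)

            # linear warm-up over the first 500 global steps
            if self.trainer.global_step < 500:
                lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
                for pg in optimizer.param_groups:
                    pg["lr"] = lr_scale * self.hparams.learning_rate
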
@@ -220,7 +92,6 @@ to perform a step, Lightning won't be able to support accelerators, precision and
         epoch,
         batch_idx,
         optimizer,
-        optimizer_idx,
         optimizer_closure,
     ):
         optimizer = optimizer.optimizer
@@ -248,7 +119,7 @@ If you are using native PyTorch schedulers, there is no need to override this hook
         return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]


-    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
+    def lr_scheduler_step(self, scheduler, metric):
         scheduler.step(epoch=self.current_epoch)  # timm's schedulers need the epoch value
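
The hunk above shows the 2.0 hook in isolation. For reference, a fuller self-contained sketch, assuming ``timm`` is installed; the scheduler choice and module are illustrative:

.. code-block:: python

    import torch
    import pytorch_lightning as pl
    from timm.scheduler import CosineLRScheduler  # non-native scheduler, assumed installed


    class TimmSchedulerModule(pl.LightningModule):  # hypothetical example module
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            return self.layer(batch).mean()

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            scheduler = CosineLRScheduler(optimizer, t_initial=10)
            return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]

        # 2.0 signature: ``optimizer_idx`` is gone
        def lr_scheduler_step(self, scheduler, metric):
            scheduler.step(epoch=self.current_epoch)  # timm schedulers take the epoch value
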
@@ -259,7 +130,7 @@ Configure Gradient Clipping

 To configure custom gradient clipping, consider overriding
 the :meth:`~pytorch_lightning.core.module.LightningModule.configure_gradient_clipping` method.
-Attributes ``gradient_clip_val`` and ``gradient_clip_algorithm`` from Trainer will be passed in the
+The attributes ``gradient_clip_val`` and ``gradient_clip_algorithm`` from Trainer will be passed in the
 respective arguments here and Lightning will handle gradient clipping for you. In case you want to set
 different values for your arguments of your choice and let Lightning handle the gradient clipping, you can
 use the inbuilt :meth:`~pytorch_lightning.core.module.LightningModule.clip_gradients` method and pass
@@ -270,31 +141,16 @@ the arguments along with your optimizer.
     method. If you want to customize gradient clipping, consider using
     :meth:`~pytorch_lightning.core.module.LightningModule.configure_gradient_clipping` method.

-For example, here we will apply gradient clipping only to the gradients associated with optimizer A.
+For example, here we will apply a stronger gradient clipping after a certain number of epochs:

 .. testcode:: python

-    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
-        if optimizer_idx == 0:
-            # Lightning will handle the gradient clipping
-            self.clip_gradients(
-                optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
-            )
-
-Here we configure gradient clipping differently for optimizer B.
-
-.. testcode:: python
+    def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_clip_algorithm):
+        if self.current_epoch > 5:
+            gradient_clip_val = gradient_clip_val * 2

-    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
-        if optimizer_idx == 0:
-            # Lightning will handle the gradient clipping
-            self.clip_gradients(
-                optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
-            )
-        elif optimizer_idx == 1:
-            self.clip_gradients(
-                optimizer, gradient_clip_val=gradient_clip_val * 2, gradient_clip_algorithm=gradient_clip_algorithm
-            )
+        # Lightning will handle the gradient clipping
+        self.clip_gradients(optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm)


 Total Stepping Batches
@@ -312,4 +168,4 @@ distributed setting into consideration so you don't have to derive it manually.
         scheduler = torch.optim.lr_scheduler.OneCycleLR(
             optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
         )
-        return [optimizer], [scheduler]
+        return optimizer, scheduler
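
Relatedly, if the scheduler should advance once per optimizer step rather than once per epoch (``OneCycleLR`` counts individual steps in ``total_steps``), the dict form still works. A hedged sketch with a hypothetical module:

.. code-block:: python

    import torch
    import pytorch_lightning as pl


    class OneCycleModule(pl.LightningModule):  # hypothetical example module
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            return self.layer(batch).mean()

        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.parameters(), lr=1e-3)
            # estimated_stepping_batches already accounts for accumulation and distributed settings
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
            )
            return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
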

docs/source-pytorch/guides/speed.rst

+1 -1

@@ -431,7 +431,7 @@ This is enabled by default on ``torch>=2.0.0``.
 .. testcode::

     class Model(LightningModule):
-        def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
+        def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
             optimizer.zero_grad(set_to_none=True)

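For reference, ``set_to_none=True`` frees the gradient tensors instead of overwriting them with zeros, which is where the memory and speed benefit comes from. A tiny plain-PyTorch check, illustrative only:

.. code-block:: python

    import torch

    layer = torch.nn.Linear(2, 2)
    layer(torch.randn(3, 2)).sum().backward()
    assert layer.weight.grad is not None  # a real gradient tensor exists

    layer.zero_grad(set_to_none=True)
    assert layer.weight.grad is None  # grads are freed, not zeroed in place
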
docs/source-pytorch/model/build_model_advanced.rst

+1 -1

@@ -17,7 +17,7 @@ Inject custom code anywhere in the Training loop using any of the 20+ methods (:
 .. testcode::

     class LitModel(pl.LightningModule):
-        def backward(self, loss, optimizer, optimizer_idx):
+        def backward(self, loss):
             loss.backward()

 ----

docs/source-pytorch/model/manual_optimization.rst

+33 -8

@@ -3,11 +3,10 @@ Manual Optimization
 *******************

 For advanced research topics like reinforcement learning, sparse coding, or GAN research, it may be desirable to
-manually manage the optimization process.
+manually manage the optimization process, especially when dealing with multiple optimizers at the same time.

-This is only recommended for experts who need ultimate flexibility.
-Lightning will handle only accelerator, precision and strategy logic.
-The users are left with ``optimizer.zero_grad()``, gradient accumulation, model toggling, etc..
+In this mode, Lightning will handle only accelerator, precision and strategy logic.
+The users are left with ``optimizer.zero_grad()``, gradient accumulation, optimizer toggling, etc.

 To manually optimize, do the following:

@@ -18,6 +17,7 @@ To manually optimize, do the following:
 * ``optimizer.zero_grad()`` to clear the gradients from the previous training step
 * ``self.manual_backward(loss)`` instead of ``loss.backward()``
 * ``optimizer.step()`` to update your model parameters
+* ``self.toggle_optimizer()`` and ``self.untoggle_optimizer()`` if needed

 Here is a minimal example of manual optimization.

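A hedged sketch of where the new toggling bullet fits: a toy GAN-style module, assuming the 2.0 signatures this commit documents. Layer sizes and losses are illustrative.

.. code-block:: python

    import torch
    import pytorch_lightning as pl


    class ToyGAN(pl.LightningModule):  # hypothetical example module
        def __init__(self):
            super().__init__()
            self.automatic_optimization = False
            self.generator = torch.nn.Linear(4, 4)
            self.discriminator = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            g_opt, d_opt = self.optimizers()

            # train generator: toggling sets requires_grad=False on all params
            # not owned by g_opt, so the discriminator weights stay frozen
            self.toggle_optimizer(g_opt)
            g_loss = -self.discriminator(self.generator(batch)).mean()
            g_opt.zero_grad()
            self.manual_backward(g_loss)
            g_opt.step()
            self.untoggle_optimizer(g_opt)

            # train discriminator on detached generator output
            self.toggle_optimizer(d_opt)
            d_loss = self.discriminator(self.generator(batch).detach()).mean()
            d_opt.zero_grad()
            self.manual_backward(d_loss)
            d_opt.step()
            self.untoggle_optimizer(d_opt)

        def configure_optimizers(self):
            g_opt = torch.optim.Adam(self.generator.parameters(), lr=1e-4)
            d_opt = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)
            return g_opt, d_opt
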
@@ -39,10 +39,6 @@ Here is a minimal example of manual optimization.
         self.manual_backward(loss)
         opt.step()

-.. warning::
-    Before 1.2, ``optimizer.step()`` was calling ``optimizer.zero_grad()`` internally.
-    From 1.2, it is left to the user's expertise.
-
 .. tip::
     Be careful where you call ``optimizer.zero_grad()``, or your model won't converge.
     It is good practice to call ``optimizer.zero_grad()`` before ``self.manual_backward(loss)``.
@@ -132,6 +128,7 @@ To perform gradient clipping with one optimizer with manual optimization, you can
 .. warning::
     * Note that ``configure_gradient_clipping()`` won't be called in Manual Optimization. Instead consider using ``self.clip_gradients()`` manually like in the example above.

+
 Use Multiple Optimizers (like GANs)
 ===================================
@@ -285,6 +282,34 @@ If you want to call schedulers that require a metric value after each epoch, consider
         if isinstance(sch, torch.optim.lr_scheduler.ReduceLROnPlateau):
             sch.step(self.trainer.callback_metrics["loss"])

+
+Optimizer Steps at Different Frequencies
+========================================
+
+In manual optimization, you are free to ``step()`` one optimizer more often than another one.
+For example, here we step the optimizer for the *generator* weights on every batch, and the optimizer for the
+*discriminator* weights only on every other batch.
+
+.. testcode:: python
+
+    # Alternating schedule for optimizer steps (e.g. GANs)
+    def training_step(self, batch, batch_idx):
+        g_opt, d_opt = self.optimizers()
+        ...
+
+        # update discriminator every other step
+        d_opt.zero_grad()
+        self.manual_backward(errD)
+        if (batch_idx + 1) % 2 == 0:
+            d_opt.step()
+
+        ...
+
+        # update generator every step
+        g_opt.zero_grad()
+        self.manual_backward(errG)
+        g_opt.step()
+
+
 Use Closure for LBFGS-like Optimizers
 =====================================

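The same idea generalizes to any step ratio. A sketch of a WGAN-style schedule, where ``N_CRITIC`` and the loss helpers are hypothetical names rather than Lightning API:

.. code-block:: python

    N_CRITIC = 5  # illustrative constant: critic steps per generator step


    def training_step(self, batch, batch_idx):
        g_opt, d_opt = self.optimizers()

        # critic/discriminator steps on every batch
        d_loss = self.compute_d_loss(batch)  # hypothetical helper
        d_opt.zero_grad()
        self.manual_backward(d_loss)
        d_opt.step()

        # generator steps only once every N_CRITIC batches
        if (batch_idx + 1) % N_CRITIC == 0:
            g_loss = self.compute_g_loss(batch)  # hypothetical helper
            g_opt.zero_grad()
            self.manual_backward(g_loss)
            g_opt.step()
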
docs/source-pytorch/starter/introduction.rst

+1 -1

@@ -282,7 +282,7 @@ Inject custom code anywhere in the Training loop using any of the 20+ methods (:
 .. testcode::

     class LitAutoEncoder(pl.LightningModule):
-        def backward(self, loss, optimizer, optimizer_idx):
+        def backward(self, loss):
             loss.backward()

 ----
