
Learning Rate finder #1347

Merged: 37 commits, Apr 10, 2020

Commits (the diff below shows changes from 24 of the 37 commits)
- 02791fc initial structure (Mar 31, 2020)
- 49e7645 new_trainer_mixing (Apr 2, 2020)
- ba1d7ba rebase (Apr 2, 2020)
- 096e979 incorporate suggestions (Apr 3, 2020)
- a4e25d6 update CHANGELOG.md (Apr 3, 2020)
- d2bf01b initial docs (Apr 3, 2020)
- 7aa2465 Merge remote-tracking branch 'upstream/master' into lr_finder (Apr 3, 2020)
- 992f137 fixes based on reviews (Apr 5, 2020)
- f084e91 added trainer arg (Apr 5, 2020)
- 1d7f85e update docs (Apr 5, 2020)
- c54f48e Merge remote-tracking branch 'upstream/master' into lr_finder (Apr 5, 2020)
- dd20940 added saving/restore of model state (Apr 6, 2020)
- 09c9a76 initial tests (Apr 6, 2020)
- a1dc3f4 fix styling (Apr 6, 2020)
- 832478d added more tests (Apr 6, 2020)
- 1522b88 fix docs, backward compatility and progressbar (Apr 6, 2020)
- 7cb0980 fix styling (Apr 6, 2020)
- 58d01cb docs update (Apr 6, 2020)
- 397ae90 Merge remote-tracking branch 'upstream/master' into lr_finder (Apr 7, 2020)
- c089bbb updates based on review (Apr 7, 2020)
- a17f997 changed saving to standard functions (Apr 7, 2020)
- 30aa383 consistent naming (Apr 7, 2020)
- 9835bc8 fix formatting (Apr 7, 2020)
- b46cb38 improve docs, added support for nested fields, improve codecov (Apr 8, 2020)
- 07236c6 Merge remote-tracking branch 'upstream/master' into lr_finder (Apr 10, 2020)
- c30f714 update CHANGELOG.md (Apr 10, 2020)
- bf8c7c5 Update lr_finder.rst (williamFalcon, Apr 10, 2020)
- 744bee8 Update pytorch_lightning/trainer/trainer.py (williamFalcon, Apr 10, 2020)
- df2092f Update trainer.py (williamFalcon, Apr 10, 2020)
- 23767c4 Merge branch 'master' into lr_finder (williamFalcon, Apr 10, 2020)
- c58fd00 Update CHANGELOG.md (Borda, Apr 10, 2020)
- be0d4f3 Update path (Borda, Apr 10, 2020)
- 43e4843 restoring (Borda, Apr 10, 2020)
- d085e1a test (Borda, Apr 10, 2020)
- 200a35f attribs (Borda, Apr 10, 2020)
- 692b099 docs (Borda, Apr 10, 2020)
- e68fd97 doc typo (Borda, Apr 10, 2020)
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added option to run without an optimizer by returning `None` from `configure_optimizers`. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279))
- Added a warning when the number of data loader workers is small. ([#1378](https://github.com/PyTorchLightning/pytorch-lightning/pull/1378))

- Added learning rate finder ([#1347](https://github.com/PyTorchLightning/pytorch-lightning/pull/1347))

### Changed

- Changed `progress_bar_refresh_rate` trainer flag to disable progress bar when set to 0. ([#1108](https://github.com/PyTorchLightning/pytorch-lightning/pull/1108))
Binary file added docs/source/_images/trainer/lr_finder.png
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -66,6 +66,7 @@ PyTorch Lightning Documentation
fast_training
hooks
hyperparameters
lr_finder
multi_gpu
weights_loading
optimizers
75 changes: 75 additions & 0 deletions docs/source/lr_finder.rst
@@ -0,0 +1,75 @@
Learning Rate Finder
--------------------

For training deep neural networks, selecting a good learning rate is essential
for both better performance and faster convergence. Even optimizers such as
`Adam` that adapt the learning rate during training can benefit from a
well-chosen initial value.

To reduce the amount of guesswork in choosing a good initial learning
rate, a `learning rate finder` can be used. As described in this `paper <https://arxiv.org/abs/1506.01186>`_,
a learning rate finder does a small run where the learning rate is increased
after each processed batch and the corresponding loss is logged. The result of
this is a `lr` vs. `loss` plot that can be used as guidance for choosing an optimal
initial lr.
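
Conceptually, the finder amounts to a short training sweep like the sketch below.
This is only an illustration of the idea from the paper, not Lightning's actual
implementation; the toy model, data and sweep parameters are made up for the example.

.. code-block:: python

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    # toy setup so the sketch runs end-to-end (illustrative only)
    model = nn.Linear(16, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)
    dataset = TensorDataset(torch.randn(512, 16), torch.randn(512, 1))
    train_loader = DataLoader(dataset, batch_size=32)

    lr, end_lr, num_steps = 1e-6, 1.0, 100
    gamma = (end_lr / lr) ** (1.0 / num_steps)  # multiplicative lr increase per batch
    results = {'lr': [], 'loss': []}

    while lr <= end_lr:
        for x, y in train_loader:
            for group in optimizer.param_groups:
                group['lr'] = lr  # set the learning rate for this batch
            loss = nn.functional.mse_loss(model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            results['lr'].append(lr)
            results['loss'].append(loss.item())
            lr *= gamma  # increase the learning rate after every processed batch
            if lr > end_lr:
                break

Plotting ``results['loss']`` against ``results['lr']`` gives the kind of curve
discussed below.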

.. warning:: For the moment, this feature only works with models having a single
optimizer.

Using Lightning's built-in LR finder
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the most basic use case, this feature can be enabled during trainer construction
with ``Trainer(auto_lr_find=True)``. When ``.fit(model)`` is called, the lr finder
will automatically be run before any training is done. The ``lr`` that is found
and used will be written to the console and logged together with all other
hyperparameters of the model.

.. note:: If ``auto_lr_find=True``, the model's ``hparams`` is expected to have
either an ``lr`` or a ``learning_rate`` field that can be overridden.
Alternatively, ``auto_lr_find`` can be set to a string ``s``, in which case the
finder will try to override ``model.hparams.s``. In both cases, an error is
thrown if the respective field is not found.
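
As a minimal sketch of this automatic mode (``MyModelClass`` stands for your own
``LightningModule``, as in the example further below, and ``my_value`` is a
hypothetical custom field on ``hparams``):

.. code-block:: python

    import pytorch_lightning as pl

    model = MyModelClass(hparams)  # hparams contains an 'lr' field

    # the lr finder runs automatically before training and overrides hparams.lr
    trainer = pl.Trainer(auto_lr_find=True)
    trainer.fit(model)

    # or point the finder at a custom field, here hparams.my_value
    trainer = pl.Trainer(auto_lr_find='my_value')
    trainer.fit(model)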

If you want to inspect the results of the learning rate finder before doing any
actual training, or want to experiment with the parameters of the algorithm, you
can invoke the ``lr_find`` method of the trainer directly. A typical example
looks like this:

.. code-block:: python

model = MyModelClass(hparams)
trainer = pl.Trainer()

# Run learning rate finder
lr_finder = trainer.lr_find(model)

# Results can be found in
lr_finder.results

# Plot with
fig = lr_finder.plot(suggest=True)
fig.show()

# Pick point based on plot, or get suggestion
new_lr = lr_finder.suggestion()

# update hparams of the model
model.hparams.lr = new_lr

# Fit model
trainer.fit(model)

The figure produced by ``lr_finder.plot()`` should look something like the figure
below. It is recommended not to pick the learning rate that achieves the lowest
loss, but instead something in the middle of the sharpest downward slope (red point).
This is the point returned by ``lr_finder.suggestion()``.

.. figure:: /_images/trainer/lr_finder.png
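
If you prefer to pick a point from the curve yourself, the raw values recorded
during the sweep are available in ``lr_finder.results``. A minimal sketch,
assuming the results expose parallel lists of learning rates and losses:

.. code-block:: python

    import numpy as np

    # assumed structure: one learning rate and one loss recorded per batch
    lrs = np.array(lr_finder.results['lr'])
    losses = np.array(lr_finder.results['loss'])

    # e.g. pick a learning rate one order of magnitude below the loss minimum
    new_lr = float(lrs[losses.argmin()] / 10)
    model.hparams.lr = new_lr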

The parameters of the algorithm can be seen below.

.. autoclass:: pytorch_lightning.trainer.lr_finder.TrainerLRFinderMixin
:members: lr_find
:noindex:
:exclude-members: _run_lr_finder_internally, save_checkpoint, restore
21 changes: 21 additions & 0 deletions pytorch_lightning/trainer/__init__.py
@@ -135,6 +135,27 @@ def forward(self, x):
# default used by the Trainer
trainer = Trainer(amp_level='O1')

auto_lr_find
^^^^^^^^^^^^
Runs a learning rate finder algorithm (see this `paper <https://arxiv.org/abs/1506.01186>`_)
before any training, to find the optimal initial learning rate.

.. code-block:: python

# default used by the Trainer (no learning rate finder)
trainer = Trainer(auto_lr_find=False)

Example::

# run learning rate finder, results override hparams.learning_rate
trainer = Trainer(auto_lr_find=True)

# run learning rate finder, results override hparams.my_lr_arg
trainer = Trainer(auto_lr_find='my_lr_arg')

.. note::
See the `learning rate finder guide <lr_finder.rst>`_

benchmark
^^^^^^^^^
