Skip to content

Commit

Permalink
Fix docs for auto_lr_find (#3883)
Browse files Browse the repository at this point in the history
* Fix docs for auto_lr_find

* change testcode to codeblock

we are not showing a complete example here
  • Loading branch information
edenlightning authored Oct 6, 2020
1 parent 0823cdd commit 2119184
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 34 deletions.
44 changes: 18 additions & 26 deletions docs/source/lr_finder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,13 @@ initial lr.
Using Lightning's built-in LR finder
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the most basic use case, this feature can be enabled during trainer construction
with ``Trainer(auto_lr_find=True)``. When ``.fit(model)`` is called, the LR finder
will automatically run before any training is done. The ``lr`` that is found
and used will be written to the console and logged together with all other
hyperparameters of the model.

.. testcode::

# default: no automatic learning rate finder
trainer = Trainer(auto_lr_find=False)

This flag sets your learning rate which can be accessed via ``self.lr`` or ``self.learning_rate``.
To enable the learning rate finder, your :class:`~pytorch_lightning.core.LightningModule` needs to have a ``learning_rate`` or ``lr`` property.
Then, set ``Trainer(auto_lr_find=True)`` during trainer construction,
and then call ``trainer.tune(model)`` to run the LR finder. The suggested ``learning_rate``
will be written to the console and will be automatically set to your :class:`~pytorch_lightning.core.LightningModule`,
which can be accessed via ``self.learning_rate`` or ``self.lr``.

.. testcode::
.. code-block:: python
class LitModel(LightningModule):
Expand All @@ -51,31 +44,30 @@ This flag sets your learning rate which can be accessed via ``self.lr`` or ``sel
def configure_optimizers(self):
return Adam(self.parameters(), lr=(self.lr or self.learning_rate))
model = LitModel()
# finds learning rate automatically
# sets hparams.lr or hparams.learning_rate to that learning rate
trainer = Trainer(auto_lr_find=True)
To use an arbitrary value set it as auto_lr_find
trainer.tune(model)
If your model is using an arbitrary value instead of ``self.lr`` or ``self.learning_rate``, set that value as auto_lr_find

.. testcode::
.. code-block:: python
model = LitModel()
# to set to your own hparams.my_value
trainer = Trainer(auto_lr_find='my_value')
Under the hood, when you call fit it runs the learning rate finder before actually calling fit.
trainer.tune(model)
.. code-block:: python
# when you call .fit() this happens
# 1. find learning rate
# 2. actually run fit
trainer.fit(model)
If you want to inspect the results of the learning rate finder before doing any
actual training or just play around with the parameters of the algorithm, this
can be done by invoking the ``lr_find`` method of the trainer. A typical example
of this would look like
If you want to inspect the results of the learning rate finder or just play around
with the parameters of the algorithm, this can be done by invoking the ``lr_find``
method of the trainer. A typical example of this would look like

.. code-block:: python
Expand Down
11 changes: 8 additions & 3 deletions pytorch_lightning/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,17 +249,22 @@ def forward(self, x):
# default used by the Trainer (no learning rate finder)
trainer = Trainer(auto_lr_find=False)
# call tune to find the lr
trainer.tune(model)
Example::
# run learning rate finder, results override hparams.learning_rate
trainer = Trainer(auto_lr_find=True)
# call tune to find the lr
trainer.tune(model)
Example::
# run learning rate finder, results override hparams.my_lr_arg
trainer = Trainer(auto_lr_find='my_lr_arg')
# call tune to find the lr
trainer.tune(model)
.. note::
See the :ref:`learning rate finder guide <lr_finder>`.
Expand Down
23 changes: 18 additions & 5 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,10 @@ def __init__(
amp_level: The optimization level to use (O1, O2, etc...).
auto_lr_find: If set to True, will `initially` run a learning rate finder,
trying to optimize initial learning for faster convergence. Sets learning
rate in self.lr or self.learning_rate in the LightningModule.
To use a different key, set a string instead of True with the key name.
auto_lr_find: If set to True, will make trainer.tune() run a learning rate finder,
trying to optimize initial learning for faster convergence. trainer.tune() method will
set the suggested learning rate in self.lr or self.learning_rate in the LightningModule.
To use a different key set a string instead of True with the key name.
auto_scale_batch_size: If set to True, will `initially` run a batch size
finder trying to find the largest batch size that fits into memory.
Expand Down Expand Up @@ -377,8 +377,21 @@ def tune(
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
):
# TODO: temporary, need to decide if tune or separate object
r"""
Runs routines to tune hyperparameters before training.
Args:
datamodule: A instance of :class:`LightningDataModule`.
model: Model to tune.
train_dataloader: A Pytorch DataLoader with training samples. If the model has
a predefined train_dataloader method this will be skipped.
val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
If the model has a predefined val_dataloaders method this will be skipped
"""
# setup data, etc...
self.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

Expand Down

0 comments on commit 2119184

Please sign in to comment.