Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix docs for auto_lr_find #3883

Merged
merged 2 commits into from
Oct 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 18 additions & 26 deletions docs/source/lr_finder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,13 @@ initial lr.
Using Lightning's built-in LR finder
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the most basic use case, this feature can be enabled during trainer construction
with ``Trainer(auto_lr_find=True)``. When ``.fit(model)`` is called, the LR finder
will automatically run before any training is done. The ``lr`` that is found
and used will be written to the console and logged together with all other
hyperparameters of the model.

.. testcode::

# default: no automatic learning rate finder
trainer = Trainer(auto_lr_find=False)

This flag sets your learning rate which can be accessed via ``self.lr`` or ``self.learning_rate``.
To enable the learning rate finder, your :class:`~pytorch_lightning.core.LightningModule` needs to have a ``learning_rate`` or ``lr`` property.
Then, set ``Trainer(auto_lr_find=True)`` during trainer construction,
and then call ``trainer.tune(model)`` to run the LR finder. The suggested ``learning_rate``
will be written to the console and will be automatically set to your :class:`~pytorch_lightning.core.LightningModule`,
which can be accessed via ``self.learning_rate`` or ``self.lr``.

.. testcode::
.. code-block:: python

class LitModel(LightningModule):

Expand All @@ -51,31 +44,30 @@ This flag sets your learning rate which can be accessed via ``self.lr`` or ``sel

def configure_optimizers(self):
return Adam(self.parameters(), lr=(self.lr or self.learning_rate))

model = LitModel()

# finds learning rate automatically
# sets hparams.lr or hparams.learning_rate to that learning rate
trainer = Trainer(auto_lr_find=True)

To use an arbitrary value set it as auto_lr_find
trainer.tune(model)

If your model is using an arbitrary value instead of ``self.lr`` or ``self.learning_rate``, set that value as auto_lr_find

.. testcode::
.. code-block:: python

model = LitModel()

# to set to your own hparams.my_value
trainer = Trainer(auto_lr_find='my_value')

Under the hood, when you call fit it runs the learning rate finder before actually calling fit.
trainer.tune(model)

.. code-block:: python

# when you call .fit() this happens
# 1. find learning rate
# 2. actually run fit
trainer.fit(model)

If you want to inspect the results of the learning rate finder before doing any
actual training or just play around with the parameters of the algorithm, this
can be done by invoking the ``lr_find`` method of the trainer. A typical example
of this would look like
If you want to inspect the results of the learning rate finder or just play around
with the parameters of the algorithm, this can be done by invoking the ``lr_find``
method of the trainer. A typical example of this would look like

.. code-block:: python

Expand Down
11 changes: 8 additions & 3 deletions pytorch_lightning/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,17 +249,22 @@ def forward(self, x):
# default used by the Trainer (no learning rate finder)
trainer = Trainer(auto_lr_find=False)

# call tune to find the lr
trainer.tune(model)

Example::

# run learning rate finder, results override hparams.learning_rate
trainer = Trainer(auto_lr_find=True)

# call tune to find the lr
trainer.tune(model)

Example::

# run learning rate finder, results override hparams.my_lr_arg
trainer = Trainer(auto_lr_find='my_lr_arg')

# call tune to find the lr
trainer.tune(model)

.. note::
See the :ref:`learning rate finder guide <lr_finder>`.

Expand Down
23 changes: 18 additions & 5 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,10 @@ def __init__(

amp_level: The optimization level to use (O1, O2, etc...).

auto_lr_find: If set to True, will `initially` run a learning rate finder,
trying to optimize initial learning for faster convergence. Sets learning
rate in self.lr or self.learning_rate in the LightningModule.
To use a different key, set a string instead of True with the key name.
auto_lr_find: If set to True, will make trainer.tune() run a learning rate finder,
trying to optimize initial learning for faster convergence. trainer.tune() method will
set the suggested learning rate in self.lr or self.learning_rate in the LightningModule.
To use a different key set a string instead of True with the key name.

auto_scale_batch_size: If set to True, will `initially` run a batch size
finder trying to find the largest batch size that fits into memory.
Expand Down Expand Up @@ -377,8 +377,21 @@ def tune(
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
):
# TODO: temporary, need to decide if tune or separate object
r"""
Runs routines to tune hyperparameters before training.

Args:
datamodule: A instance of :class:`LightningDataModule`.

model: Model to tune.

train_dataloader: A Pytorch DataLoader with training samples. If the model has
a predefined train_dataloader method this will be skipped.

val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
If the model has a predefined val_dataloaders method this will be skipped

"""
# setup data, etc...
self.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

Expand Down