Commit

doc update (#3894)

SkafteNicki authored Oct 6, 2020
1 parent 3ab43dd commit f745c4a
Showing 2 changed files with 61 additions and 21 deletions.
46 changes: 25 additions & 21 deletions docs/source/training_tricks.rst
@@ -54,34 +54,38 @@ longer training time. Inspired by https://github.com/BlackHC/toma.
# DEFAULT (ie: don't scale batch size automatically)
trainer = Trainer(auto_scale_batch_size=None)
# Autoscale batch size
trainer = Trainer(auto_scale_batch_size=None|'power'|'binsearch')
# find the batch size
trainer.tune(model)
Currently, this feature supports two modes: `'power'` scaling and `'binsearch'`
scaling. In `'power'` scaling, starting from a batch size of 1, the batch size
is doubled until an out-of-memory (OOM) error is encountered. Setting the
argument to `'binsearch'` will initially also keep doubling the batch size until
it encounters an OOM, after which it will do a binary search that finetunes the
batch size. Additionally, it should be noted that the batch size scaler cannot
search for batch sizes larger than the size of the training dataset.
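
To make the difference between the two modes concrete, here is a small sketch of the search order each mode follows. This is an illustration only, not Lightning's internal implementation; the `fits_in_memory` callback is a hypothetical stand-in for running a few training steps and catching OOM errors.

.. code-block:: python

    # Illustrative only: `fits_in_memory` is a hypothetical callback that
    # returns False when a given batch size would trigger an OOM error.

    def power_scaling(fits_in_memory, init_val=2, max_trials=25):
        """Keep doubling the batch size until an OOM would occur."""
        batch_size = init_val
        for _ in range(max_trials):
            if not fits_in_memory(batch_size * 2):
                break
            batch_size *= 2
        return batch_size

    def binsearch_scaling(fits_in_memory, init_val=2, max_trials=25):
        """Double until an OOM, then binary search between the last
        successful batch size and the one that failed."""
        low, high = init_val, None
        for _ in range(max_trials):
            candidate = low * 2 if high is None else (low + high) // 2
            if candidate == low:
                break
            if fits_in_memory(candidate):
                low = candidate
            else:
                high = candidate
        return low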

.. note::

This feature expects that a `batch_size` field is either located as a model attribute,
i.e. `model.batch_size`, or as a field in your `hparams`, i.e. `model.hparams.batch_size`.
The field should exist and will be overridden by the results of this algorithm.
Additionally, your `train_dataloader()` method should depend on this field
for this feature to work, i.e.

.. code-block:: python

    def train_dataloader(self):
        return DataLoader(train_dataset, batch_size=self.batch_size|self.hparams.batch_size)
.. warning::

Due to these constraints, this feature does *NOT* work when passing dataloaders directly
to `.fit()`.
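
For instance, a minimal module that satisfies these constraints might look like the sketch below. It is only an example: the dummy dataset, layer sizes, and the `LitModel` name are placeholders; the essential parts are the `batch_size` hyperparameter and the `train_dataloader()` that reads it.

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        def __init__(self, batch_size: int = 32):
            super().__init__()
            # stores `batch_size` in self.hparams so the scaler can find and override it
            self.save_hyperparameters()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.cross_entropy(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

        def train_dataloader(self):
            # dummy data; the dataloader depends on the (possibly tuned) batch size
            dataset = TensorDataset(torch.randn(1024, 32), torch.randint(0, 2, (1024,)))
            return DataLoader(dataset, batch_size=self.hparams.batch_size)

    model = LitModel()
    trainer = pl.Trainer(auto_scale_batch_size='binsearch', max_epochs=1)
    trainer.tune(model)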

The scaling algorithm has a number of parameters that the user can control by
invoking the tuner's `.scale_batch_size` method themselves (see description below).
@@ -93,29 +97,29 @@
tuner = Tuner(trainer)
# Invoke method
new_batch_size = tuner.scale_batch_size(model, *extra_parameters_here)
# Override old batch size
model.hparams.batch_size = new_batch_size
# Fit as normal
trainer.fit(model)
The algorithm in short works by (a rough sketch in code follows the list):
1. Dumping the current state of the model and trainer
2. Iteratively until convergence or the maximum number of tries `max_trials` (default 25) has been reached:
- Call the `fit()` method of the trainer. This evaluates `steps_per_trial` (default 3)
training steps. Each training step can trigger an OOM error if the tensors
(training batch, weights, gradients etc.) allocated during the steps have a
too large memory footprint.
- If an OOM error is encountered, decrease the batch size, else increase it.
How much the batch size is increased/decreased is determined by the chosen
strategy.
3. The found batch size is saved to either `model.batch_size` or `model.hparams.batch_size`
4. Restore the initial state of model and trainer
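
As a rough, hedged sketch (not the actual Lightning source), the surrounding bookkeeping could look as follows; for brevity it only shows the `power` doubling strategy, and `run_trial` and `is_oom_error` are hypothetical helpers standing in for running `steps_per_trial` training steps and detecting CUDA OOM errors.

.. code-block:: python

    import copy

    def find_batch_size(model, run_trial, is_oom_error, init_val=2, max_trials=25):
        # 1. dump the current state of the model
        checkpoint = copy.deepcopy(model.state_dict())

        batch_size = init_val
        for _ in range(max_trials):  # 2. iterate until `max_trials` is reached
            try:
                run_trial(model, batch_size)  # a few training steps with this batch size
                batch_size *= 2               # no OOM: increase the batch size
            except RuntimeError as err:
                if not is_oom_error(err):
                    raise
                batch_size //= 2              # OOM: step back down and stop
                break

        # 3. save the found batch size to the model
        model.hparams.batch_size = batch_size
        # 4. restore the initial state of the model
        model.load_state_dict(checkpoint)
        return batch_size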

.. autoclass:: pytorch_lightning.tuner.tuning.Tuner
    :noindex:
    :members: scale_batch_size

.. warning:: Batch size finder is not supported for DDP yet; it is coming soon.
36 changes: 36 additions & 0 deletions pytorch_lightning/tuner/tuning.py
@@ -37,6 +37,42 @@ def scale_batch_size(self,
max_trials: int = 25,
batch_arg_name: str = 'batch_size',
**fit_kwargs):
r"""
Will iteratively try to find the largest batch size for a given model
that does not give an out of memory (OOM) error.
Args:
model: Model to fit.
mode: string setting the search mode. Either `power` or `binsearch`.
If mode is `power` we keep multiplying the batch size by 2, until
we get an OOM error. If mode is 'binsearch', we will initially
also keep multiplying by 2 and after encountering an OOM error
do a binary search between the last successful batch size and the
batch size that failed.
steps_per_trial: number of steps to run with a given batch size.
Ideally 1 should be enough to test if an OOM error occurs,
however in practice a few are needed
init_val: initial batch size to start the search with
max_trials: max number of increases in batch size done before
the algorithm is terminated
batch_arg_name: name of the attribute that stores the batch size.
It is expected that the user has provided a model or datamodule that has a hyperparameter
with that name. We will look for this attribute name in the following places
- `model`
- `model.hparams`
- `model.datamodule`
- `trainer.datamodule` (the datamodule passed to the tune method)
**fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
or datamodule.
"""
return scale_batch_size(
self.trainer, model, mode, steps_per_trial, init_val, max_trials, batch_arg_name, **fit_kwargs
)
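
For reference, a short usage sketch of this method with non-default arguments. The `LitModel` class is assumed to be a LightningModule that exposes a `batch_size` hyperparameter, as described in the docs above; it is a placeholder, not part of the library.

    from pytorch_lightning import Trainer
    from pytorch_lightning.tuner.tuning import Tuner

    model = LitModel(batch_size=2)   # placeholder LightningModule
    trainer = Trainer(max_epochs=1)
    tuner = Tuner(trainer)

    # binary-search refinement, 4 steps per trial, at most 10 batch size increases
    new_batch_size = tuner.scale_batch_size(
        model,
        mode='binsearch',
        steps_per_trial=4,
        init_val=2,
        max_trials=10,
        batch_arg_name='batch_size',
    )
    model.hparams.batch_size = new_batch_size
    trainer.fit(model)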