Batch scaler docs #3894

Merged
merged 2 commits into from
Oct 6, 2020
46 changes: 25 additions & 21 deletions docs/source/training_tricks.rst
@@ -54,34 +54,38 @@ longer training time. Inspired by https://github.com/BlackHC/toma.
# DEFAULT (ie: don't scale batch size automatically)
trainer = Trainer(auto_scale_batch_size=None)
# Autoscale batch size
trainer = Trainer(auto_scale_batch_size=None|'power'|'binsearch')
# find the batch size
trainer.tune(model)
Currently, this feature supports two modes, `'power'` scaling and `'binsearch'`
scaling. In `'power'` scaling, starting from a batch size of 1, the algorithm keeps doubling
the batch size until an out-of-memory (OOM) error is encountered. Setting the
argument to `'binsearch'` will initially also keep doubling the batch size until
it encounters an OOM, after which it does a binary search to finetune the
batch size. Note that the batch size scaler cannot search for batch sizes
larger than the size of the training dataset.
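
For intuition, the following is a small, self-contained sketch of the two strategies (not Lightning's actual internals); `trial_fits` is a hypothetical placeholder for running a few training steps at a given batch size and reporting whether they finished without an OOM error:

.. code-block:: python

    # Conceptual sketch only -- not the actual Lightning implementation.
    def power_scale(trial_fits, init_val=2, max_trials=25):
        """Keep doubling until a trial fails, then fall back to the last size that fit."""
        batch_size = init_val
        for _ in range(max_trials):
            if not trial_fits(batch_size):
                return max(batch_size // 2, 1)
            batch_size *= 2
        return batch_size

    def binsearch_scale(trial_fits, init_val=2, max_trials=25):
        """Double until failure, then binary-search between the last success and the failure."""
        low, high, batch_size, trials = 0, None, init_val, 0
        while trials < max_trials:
            trials += 1
            if trial_fits(batch_size):
                low = batch_size
                if high is None:
                    batch_size *= 2                 # still in the doubling phase
                elif high - low <= 1:
                    break                           # converged
                else:
                    batch_size = (low + high) // 2  # binary-search phase
            else:
                high = batch_size
                if high - low <= 1:
                    break
                batch_size = (low + high) // 2
        return max(low, 1)
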

.. note::

    This feature expects that a `batch_size` field exists either as a model attribute,
    i.e. `model.batch_size`, or as a field in your `hparams`, i.e. `model.hparams.batch_size`.
    This field will be overridden by the result of the algorithm.
    Additionally, your `train_dataloader()` method should depend on this field
    for this feature to work, i.e.

.. code-block:: python

    def train_dataloader(self):
        return DataLoader(train_dataset, batch_size=self.batch_size|self.hparams.batch_size)
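
For reference, a minimal model that satisfies these constraints could look roughly like the one below; the dataset, layer sizes and optimizer are arbitrary placeholders for illustration only:

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):

        def __init__(self, batch_size: int = 32):
            super().__init__()
            # makes the argument available as `self.hparams.batch_size`
            self.save_hyperparameters()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.cross_entropy(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

        def train_dataloader(self):
            train_dataset = TensorDataset(torch.randn(640, 32), torch.randint(0, 2, (640,)))
            # the dataloader depends on the field that the batch size finder overrides
            return DataLoader(train_dataset, batch_size=self.hparams.batch_size)
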
.. warning::

    Due to these constraints, this feature does *NOT* work when passing dataloaders directly
    to `.fit()`.

The scaling algorithm has a number of parameters that the user can control by
invoking the tuner method `.scale_batch_size` themselves (see description below).
@@ -93,29 +97,29 @@ invoking the trainer method `.scale_batch_size` themself (see description below)
tuner = Tuner(trainer)
# Invoke method
new_batch_size = tuner.scale_batch_size(model, *extra_parameters_here)
# Override old batch size
model.hparams.batch_size = new_batch_size
# Fit as normal
trainer.fit(model)
In short, the algorithm works as follows (a rough sketch is given after the list):

1. Dumping the current state of the model and trainer
2. Iteratively until convergence or maximum number of tries `max_trials` (default 25) has been reached:
    - Call `fit()` method of trainer. This evaluates `steps_per_trial` (default 3) number of
      training steps. Each training step can trigger an OOM error if the tensors
      (training batch, weights, gradients, etc.) allocated during the steps have a
      too large memory footprint.
    - If an OOM error is encountered, decrease the batch size, else increase it.
      How much the batch size is increased/decreased is determined by the chosen
      strategy.
3. The found batch size is saved to either `model.batch_size` or `model.hparams.batch_size`
4. Restore the initial state of model and trainer
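
The outer loop could be sketched roughly as follows; `search_fn` is a hypothetical stand-in for the doubling/binary-search trials described above, and trainer state handling is omitted for brevity:

.. code-block:: python

    from copy import deepcopy

    def run_batch_size_finder(model, search_fn):
        # 1. dump the current state of the model (trainer state omitted here)
        checkpoint = deepcopy(model.state_dict())
        # 2. run the trials with the chosen strategy (power or binsearch)
        new_size = search_fn(model)
        # 3. save the found batch size to whichever field exists
        if hasattr(model, "batch_size"):
            model.batch_size = new_size
        elif hasattr(model, "hparams") and "batch_size" in model.hparams:
            model.hparams.batch_size = new_size
        # 4. restore the initial state of the model
        model.load_state_dict(checkpoint)
        return new_size
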

.. autoclass:: pytorch_lightning.tuner.tuning.Tuner
    :noindex:
    :members: scale_batch_size

.. warning:: Batch size finder is not supported for DDP yet; it is coming soon.
36 changes: 36 additions & 0 deletions pytorch_lightning/tuner/tuning.py
@@ -37,6 +37,42 @@ def scale_batch_size(self,
max_trials: int = 25,
batch_arg_name: str = 'batch_size',
**fit_kwargs):
r"""
Will iteratively try to find the largest batch size for a given model
that does not give an out of memory (OOM) error.
Args:
model: Model to fit.
mode: string setting the search mode. Either `power` or `binsearch`.
If mode is `power`, we keep multiplying the batch size by 2 until
we get an OOM error. If mode is `binsearch`, we will initially
also keep multiplying by 2 and after encountering an OOM error
do a binary search between the last successful batch size and the
batch size that failed.
steps_per_trial: number of steps to run with a given batch size.
Ideally 1 should be enough to test if an OOM error occurs,
however in practice a few are needed
init_val: initial batch size to start the search with
max_trials: max number of increases in batch size done before
the algorithm is terminated
batch_arg_name: name of the attribute that stores the batch size.
It is expected that the user has provided a model or datamodule that has a hyperparameter
with that name. We will look for this attribute name in the following places
- `model`
- `model.hparams`
- `model.datamodule`
- `trainer.datamodule` (the datamodule passed to the tune method)
**fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
or datamodule.
"""
return scale_batch_size(
self.trainer, model, mode, steps_per_trial, init_val, max_trials, batch_arg_name, **fit_kwargs
)
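# Example (illustrative sketch, not part of this diff): with a `Trainer` and a
# compatible model in scope, the documented arguments can be combined as:
#
#     tuner = Tuner(trainer)
#     new_batch_size = tuner.scale_batch_size(
#         model,
#         mode='binsearch',
#         steps_per_trial=3,
#         init_val=2,
#         max_trials=25,
#         batch_arg_name='batch_size',
#     )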