limit auto scaling batch size to the size of the training dataset #3271

Merged · 11 commits · Sep 9, 2020

Changes from all commits
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -41,6 +41,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed Horovod distributed backend compatibility with native AMP ([#3404](https://github.com/PyTorchLightning/pytorch-lightning/pull/3404))

- Fixed batch size auto scaling exceeding the size of the dataset ([#3271](https://github.com/PyTorchLightning/pytorch-lightning/pull/3271))

## [0.9.0] - YYYY-MM-DD

### Added
8 changes: 0 additions & 8 deletions pytorch_lightning/trainer/training_tricks.py
@@ -12,22 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import Tensor

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_getattr, lightning_setattr

try:
from apex import amp
52 changes: 38 additions & 14 deletions pytorch_lightning/tuner/batch_size_scaling.py
@@ -13,13 +13,14 @@
# limitations under the License
import os
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_getattr, lightning_setattr
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning import _logger as log
from typing import Optional
from typing import Optional, Tuple


def scale_batch_size(trainer,
@@ -55,6 +56,13 @@ def scale_batch_size(trainer,
algorithm is terminated

batch_arg_name: name of the attribute that stores the batch size.
It is expected that the user has provided a model or datamodule that has a hyperparameter
with that name. We will look for this attribute name in the following places

- `model`
- `model.hparams`
- `model.datamodule`
- `trainer.datamodule` (the datamodule passed to the tune method)

**fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
or datamodule.
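
To make the lookup order described in the docstring above concrete, here is a minimal sketch (not part of this PR) of a model that exposes the attribute via `self.hparams`; the class name, layer sizes, and 64-sample dataset are illustrative assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    """Illustrative model whose `batch_size` lives in `self.hparams`."""

    def __init__(self, batch_size: int = 2):
        super().__init__()
        # The tuner's attribute lookup finds `batch_size` here because
        # `save_hyperparameters()` registers it under `self.hparams`.
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # A 64-sample dataset: with this PR the scaled batch size is capped at 64.
        dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(dataset, batch_size=self.hparams.batch_size)
```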
Expand Down Expand Up @@ -165,16 +173,19 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **f
# Try fit
trainer.fit(model, **fit_kwargs)
# Double in size
new_size = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
except RuntimeError as exception:
# Only these errors should trigger an adjustment
if is_oom_error(exception):
# If we fail in power mode, half the size and return
garbage_collection_cuda()
new_size = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc='failed')
new_size, _ = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc='failed')
break
else:
raise # some other error not memory related

if not changed:
break
return new_size


@@ -199,39 +210,40 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials,
if high - low <= 1:
break
midval = (high + low) // 2
new_size = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='succeeded')
new_size, changed = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='succeeded')
else:
new_size = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')

if not changed:
break

except RuntimeError as exception:
# Only these errors should trigger an adjustment
if is_oom_error(exception):
# If we fail in power mode, half the size and return
garbage_collection_cuda()
high = new_size
midval = (high + low) // 2
new_size = _adjust_batch_size(trainer, value=midval, desc='failed')
new_size, _ = _adjust_batch_size(trainer, value=midval, desc='failed')
if high - low <= 1:
break
else:
raise # some other error not memory related

return new_size


def _adjust_batch_size(trainer,
batch_arg_name: str = 'batch_size',
factor: float = 1.0,
value: Optional[int] = None,
desc: str = None):
""" Function for adjusting the batch size. It is expected that the user
has provided a model that has a hparam field called `batch_size` i.e.
`model.hparams.batch_size` should exist. Additionally there can be a
datamodule attached to either Trainer or model, in that case the attribute
also gets updated when present.
desc: str = None) -> Tuple[int, bool]:
""" Helper function for adjusting the batch size.

Args:
trainer: instance of pytorch_lightning.Trainer

batch_arg_name: field where batch_size is stored in `model.hparams`
batch_arg_name: name of the field where batch_size is stored.

factor: value which the old batch size is multiplied by to get the
new batch size
@@ -241,11 +253,23 @@ def _adjust_batch_size(trainer,

desc: either `succeeded` or `failed`. Used purely for logging

Returns:
The new batch size for the next trial and a bool that signals whether the
new value is different than the previous batch size.
"""
model = trainer.get_model()
batch_size = lightning_getattr(model, batch_arg_name)
new_size = value if value is not None else int(batch_size * factor)
if desc:
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')

if not _is_valid_batch_size(new_size, trainer.train_dataloader):
new_size = min(new_size, len(trainer.train_dataloader.dataset))

changed = new_size != batch_size
lightning_setattr(model, batch_arg_name, new_size)
return new_size
return new_size, changed


def _is_valid_batch_size(current_size, dataloader):
return not has_len(dataloader) or current_size <= len(dataloader)
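
As a standalone illustration of the new behaviour (not the PR's own code), the sketch below mirrors the clamping in `_adjust_batch_size` together with the `changed` early exit used by the scaling loops; the dataset length of 100 and the trial count are made-up values.

```python
def _clamped_double(batch_size: int, dataset_len: int):
    """Mirror of the adjusted power-scaling step: double, then clamp to the dataset size."""
    new_size = min(batch_size * 2, dataset_len)  # what the new min(...) clamp enforces
    return new_size, new_size != batch_size      # (new_size, changed), as _adjust_batch_size now returns


size, trace = 4, []
for _ in range(10):  # stand-in for max_trials
    size, changed = _clamped_double(size, dataset_len=100)
    trace.append(size)
    if not changed:  # same early exit as the updated scaling loops
        break

print(trace)  # [8, 16, 32, 64, 100, 100] -> scaling stops once the size can no longer grow
```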
3 changes: 2 additions & 1 deletion tests/trainer/test_trainer_tricks.py
@@ -238,9 +238,10 @@ def dataloader(self, *args, **kwargs):

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
trainer.tune(model, datamodule_fit)
assert trainer.datamodule == datamodule_fit
after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size
assert trainer.datamodule == datamodule_fit
assert before_batch_size != after_batch_size
assert after_batch_size <= len(trainer.train_dataloader.dataset)
assert datamodule_fit.batch_size == after_batch_size
# should be left unchanged, since it was not passed to .tune()
assert datamodule_model.batch_size == 111
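
For completeness, an end-to-end usage sketch in the spirit of the updated test; `TinyModel` is the illustrative module from the earlier sketch, and the final assertion states the guarantee this PR adds.

```python
import pytorch_lightning as pl

model = TinyModel(batch_size=2)

trainer = pl.Trainer(max_epochs=1, auto_scale_batch_size=True)
trainer.tune(model)  # runs the batch-size finder before any training

# With this PR, the tuned batch size never exceeds the training dataset length.
assert model.hparams.batch_size <= len(trainer.train_dataloader.dataset)
```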