Merged
Changes from all commits (32 commits)
297e7d1
Start implementation
muellerzr May 3, 2022
3d78372
Fix partial
muellerzr May 3, 2022
c78ed02
change where logic happens
muellerzr May 3, 2022
a933e18
Change how decorator is called
muellerzr May 3, 2022
0ad8a77
Should fix issue
muellerzr May 3, 2022
c5d5d7b
Style
muellerzr May 3, 2022
8da6c7e
Proper style
muellerzr May 3, 2022
f759535
Add tests
muellerzr May 3, 2022
f97111a
Add to deps and remove guard
muellerzr May 3, 2022
253f0cb
Try now?
muellerzr May 3, 2022
4f16986
Dep tables
muellerzr May 3, 2022
5624504
Rm import
muellerzr May 3, 2022
6dcb9df
Add test
muellerzr May 3, 2022
48753df
Update setup.py to use latest
muellerzr May 3, 2022
ab12e99
Fixup dep style
muellerzr May 4, 2022
81fe4d6
Move auto_find_batch_size to a TrainingArgument
muellerzr May 4, 2022
3a777fb
Use requires_backends, add decorators to tests
muellerzr May 4, 2022
56b884d
move import
muellerzr May 4, 2022
b12ab6a
Rework test to use glue
muellerzr May 4, 2022
4ce0683
Fix import errors
muellerzr May 4, 2022
7752813
Working trainer test, still need to fix utils
muellerzr May 4, 2022
f8e8927
Add accelerate to depslist
muellerzr May 4, 2022
88d20c1
Update config with install accelerate
muellerzr May 4, 2022
bdc4692
Restructure inner loop
muellerzr May 4, 2022
97b575e
Slow
muellerzr May 4, 2022
ff6caca
Restyle
muellerzr May 4, 2022
62d9a22
Fix logging import
muellerzr May 5, 2022
5c6c4b2
Add stas fix to stages
muellerzr May 5, 2022
ef92961
Style
muellerzr May 5, 2022
475e164
Finally clean
muellerzr May 5, 2022
00d316b
Merge branch 'main' into muellerzr-memory-decorator
muellerzr May 9, 2022
ada7bd1
Clean
muellerzr May 9, 2022
6 changes: 6 additions & 0 deletions .circleci/config.yml
@@ -84,6 +84,7 @@ jobs:
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
- run: pip install tensorflow_probability
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
- run: pip install git+https://github.com/huggingface/accelerate
- save_cache:
key: v0.4-{{ checksum "setup.py" }}
paths:
@@ -122,6 +123,7 @@ jobs:
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
- run: pip install tensorflow_probability
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
- run: pip install git+https://github.com/huggingface/accelerate
- save_cache:
key: v0.4-{{ checksum "setup.py" }}
paths:
@@ -154,6 +156,7 @@ jobs:
- run: pip install .[sklearn,flax,torch,testing,sentencepiece,torch-speech,vision]
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
- run: pip install git+https://github.com/huggingface/accelerate
- save_cache:
key: v0.4-{{ checksum "setup.py" }}
paths:
@@ -191,6 +194,7 @@ jobs:
- run: pip install .[sklearn,flax,torch,testing,sentencepiece,torch-speech,vision]
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
- run: pip install git+https://github.com/huggingface/accelerate
- save_cache:
key: v0.4-{{ checksum "setup.py" }}
paths:
@@ -222,6 +226,7 @@ jobs:
- run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm]
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
- run: pip install git+https://github.com/huggingface/accelerate
- save_cache:
key: v0.4-torch-{{ checksum "setup.py" }}
paths:
@@ -258,6 +263,7 @@ jobs:
- run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm]
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
- run: pip install https://github.com/kpu/kenlm/archive/master.zip
- run: pip install git+https://github.com/huggingface/accelerate
- save_cache:
key: v0.4-torch-{{ checksum "setup.py" }}
paths:
1 change: 1 addition & 0 deletions setup.py
@@ -96,6 +96,7 @@
# 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
_deps = [
"Pillow",
"accelerate>=0.7.1",
"black~=22.0",
"codecarbon==1.2.0",
"cookiecutter==1.7.3",
1 change: 1 addition & 0 deletions src/transformers/dependency_versions_table.py
@@ -3,6 +3,7 @@
# 2. run `make deps_table_update`
deps = {
"Pillow": "Pillow",
"accelerate": "accelerate>=0.7.1",
"black": "black~=22.0",
"codecarbon": "codecarbon==1.2.0",
"cookiecutter": "cookiecutter==1.7.3",
8 changes: 8 additions & 0 deletions src/transformers/testing_utils.py
@@ -40,6 +40,7 @@
is_wandb_available,
)
from .utils import (
is_accelerate_available,
is_apex_available,
is_bitsandbytes_available,
is_detectron2_available,
@@ -238,6 +239,13 @@ def require_git_lfs(test_case):
return unittest.skipUnless(_run_git_lfs_tests, "test of git lfs workflow")(test_case)


def require_accelerate(test_case):
"""
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
"""
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)


def require_rjieba(test_case):
"""
Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
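For orientation, the new decorator is used just like the other `require_*` helpers in this file; a small illustrative sketch (the test class and method names here are made up, not part of this diff):

```python
import unittest

from transformers.testing_utils import require_accelerate


class ExampleAccelerateTests(unittest.TestCase):
    @require_accelerate
    def test_something_with_accelerate(self):
        # Skipped with "test requires accelerate" when the package is not installed.
        import accelerate  # noqa: F401
```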
22 changes: 20 additions & 2 deletions src/transformers/trainer.py
@@ -115,6 +115,7 @@
default_compute_objective,
default_hp_space,
denumpify_detensorize,
find_executable_batch_size,
get_last_checkpoint,
has_length,
number_of_arguments,
@@ -548,6 +549,9 @@ def __init__(
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

# Internal variable to keep track of the original batch size
self._train_batch_size = args.train_batch_size

# very last
self._memory_tracker.stop_and_update_metrics()

@@ -718,7 +722,7 @@ def get_train_dataloader(self) -> DataLoader:
if self.args.world_size > 1:
train_dataset = IterableDatasetShard(
train_dataset,
batch_size=self.args.train_batch_size,
batch_size=self._train_batch_size,
drop_last=self.args.dataloader_drop_last,
num_processes=self.args.world_size,
process_index=self.args.process_index,
@@ -736,7 +740,7 @@ def get_train_dataloader(self) -> DataLoader:

return DataLoader(
train_dataset,
batch_size=self.args.train_batch_size,
batch_size=self._train_batch_size,
sampler=train_sampler,
collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last,
@@ -1267,6 +1271,20 @@ def train(
self._move_model_to_device(self.model, args.device)
self.model_wrapped = self.model

inner_training_loop = find_executable_batch_size(
self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
)
return inner_training_loop(
args=args,
resume_from_checkpoint=resume_from_checkpoint,
trial=trial,
ignore_keys_for_eval=ignore_keys_for_eval,
)

def _inner_training_loop(
self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
):
self._train_batch_size = batch_size
# Data loader and number of training steps
train_dataloader = self.get_train_dataloader()

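To make the control flow above easier to follow: `train()` no longer runs the loop directly; it wraps `_inner_training_loop` with `find_executable_batch_size` (added in `trainer_utils.py` below) and calls the wrapper, so a CUDA OOM inside the loop can trigger a retry at half the batch size when `args.auto_find_batch_size` is set. A condensed sketch of that pattern, with everything unrelated trimmed (not the full `Trainer` implementation):

```python
from transformers.trainer_utils import find_executable_batch_size


class TrainerSketch:
    """Illustration only: mirrors the wiring added to `Trainer` in this diff."""

    def __init__(self, args):
        self.args = args
        # Remember the starting batch size so dataloaders follow any reduced value.
        self._train_batch_size = args.train_batch_size

    def train(self, **kwargs):
        inner_training_loop = find_executable_batch_size(
            self._inner_training_loop, self._train_batch_size, self.args.auto_find_batch_size
        )
        return inner_training_loop(**kwargs)

    def _inner_training_loop(self, batch_size=None, **kwargs):
        # Dataloaders built from self._train_batch_size now see the retried size.
        self._train_batch_size = batch_size
        ...
```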
33 changes: 33 additions & 0 deletions src/transformers/trainer_utils.py
@@ -36,6 +36,7 @@
is_torch_available,
is_torch_cuda_available,
is_torch_tpu_available,
requires_backends,
)


@@ -355,6 +356,7 @@ class TrainerMemoryTracker:
stages = {
"__init__": "init",
"train": "train",
"_inner_training_loop": "train",
"evaluate": "eval",
"predict": "test",
}
@@ -584,6 +586,37 @@ class ShardedDDPOption(ExplicitEnum):
AUTO_WRAP = "auto_wrap"


def find_executable_batch_size(
function: callable = None, starting_batch_size: int = 128, auto_find_batch_size: bool = False
):
"""
Args:
A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
CUDNN, the batch size is cut in half and passed to `function` `function` must take in a `batch_size` parameter as
its first argument.
function (`callable`, *optional*)
A function to wrap
starting_batch_size (`int`, *optional*)
The batch size to try and fit into memory
auto_find_batch_size (`bool`, *optional*)
If False, will just execute `function`
"""
if function is None:
return functools.partial(
find_executable_batch_size,
starting_batch_size=starting_batch_size,
auto_find_batch_size=auto_find_batch_size,
)

if auto_find_batch_size:
requires_backends(find_executable_batch_size, "accelerate")
import accelerate.memory_utils as mem_utils

return mem_utils.find_executable_batch_size(function=function, starting_batch_size=starting_batch_size)

return functools.partial(function, batch_size=starting_batch_size)


class FSDPOption(ExplicitEnum):
FULL_SHARD = "full_shard"
SHARD_GRAD_OP = "shard_grad_op"
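A minimal usage sketch of the helper above, assuming accelerate is installed; the wrapped function and the size threshold are made up for illustration and mirror the tests added later in this PR:

```python
from transformers.trainer_utils import find_executable_batch_size

sizes_tried = []


@find_executable_batch_size(starting_batch_size=256, auto_find_batch_size=True)
def training_function(batch_size):
    sizes_tried.append(batch_size)
    # Pretend anything above 64 exhausts GPU memory; the accelerate-backed wrapper
    # is expected to catch the OOM-style error, halve the batch size, and retry.
    if batch_size > 64:
        raise RuntimeError("CUDA out of memory.")


training_function()
print(sizes_tried)  # expected: [256, 128, 64]
```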
10 changes: 10 additions & 0 deletions src/transformers/training_args.py
@@ -443,6 +443,9 @@ class TrainingArguments:
include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
that need inputs, predictions and references for scoring calculation in Metric class.
auto_find_batch_size (`bool`, *optional*, defaults to `False`):
Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding
CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`).
"""

output_dir: str = field(
@@ -803,6 +806,13 @@ class TrainingArguments:
metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in Trainer"},
)

auto_find_batch_size: bool = field(
default=False,
metadata={
"help": "Whether to automatically decrease the batch size in half and rerun the training loop again each time a CUDA Out-of-Memory was reached"
},
)

def __post_init__(self):
# Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
# This needs to happen before any call to self.device or self.n_gpu.
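Taken together with the `Trainer` change, the new argument is meant to be used like any other flag; a hedged sketch of the user-facing call (the model, dataset, and output path are placeholders, not part of this diff):

```python
from transformers import Trainer, TrainingArguments

# auto_find_batch_size=True requires `pip install accelerate`. If the first attempt
# raises a CUDA out-of-memory error, the training loop is rerun at half the batch size.
args = TrainingArguments(
    output_dir="out",                 # placeholder
    per_device_train_batch_size=256,  # optimistic starting point
    num_train_epochs=1,
    auto_find_batch_size=True,
)

trainer = Trainer(model=model, args=args, train_dataset=train_dataset)  # model/dataset defined elsewhere
trainer.train()
```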
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -85,6 +85,7 @@
DummyObject,
OptionalDependencyNotAvailable,
_LazyModule,
is_accelerate_available,
is_apex_available,
is_bitsandbytes_available,
is_coloredlogs_available,
11 changes: 11 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -428,6 +428,10 @@ def is_protobuf_available():
return importlib.util.find_spec("google.protobuf") is not None


def is_accelerate_available():
return importlib.util.find_spec("accelerate") is not None


def is_tokenizers_available():
return importlib.util.find_spec("tokenizers") is not None

@@ -725,6 +729,12 @@ def wrapper(*args, **kwargs):
`pip install pyctcdecode`
"""

# docstyle-ignore
ACCELERATE_IMPORT_ERROR = """
{0} requires the accelerate library but it was not found in your environment. You can install it with pip:
`pip install accelerate`
"""


BACKENDS_MAPPING = OrderedDict(
[
@@ -750,6 +760,7 @@ def wrapper(*args, **kwargs):
("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
("vision", (is_vision_available, VISION_IMPORT_ERROR)),
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
]
)

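With the mapping entry in place, call sites can gate on accelerate the same way the other soft dependencies are gated. A small sketch, assuming `requires_backends` keeps its usual behavior of raising an `ImportError` built from the registered message (the `needs_accelerate` helper is illustrative):

```python
from transformers.utils import is_accelerate_available, requires_backends


def needs_accelerate():
    # No-op when accelerate can be imported; otherwise raises ImportError
    # carrying the ACCELERATE_IMPORT_ERROR text registered above.
    requires_backends(needs_accelerate, "accelerate")


print(is_accelerate_available())  # True only if `pip install accelerate` has been run
```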
37 changes: 37 additions & 0 deletions tests/trainer/test_trainer.py
@@ -21,6 +21,7 @@
import random
import re
import subprocess
import sys
import tempfile
import time
import unittest
@@ -58,6 +59,7 @@
require_torch_bf16,
require_torch_gpu,
require_torch_multi_gpu,
require_torch_non_multi_gpu,
require_torch_tf32,
require_torch_up_to_2_gpus,
require_wandb,
@@ -1075,6 +1077,41 @@ def test_resume_training_with_randomness(self):
self.assertAlmostEqual(a, a1, delta=1e-8)
self.assertAlmostEqual(b, b1, delta=1e-8)

@slow
@require_torch_non_multi_gpu
def test_auto_batch_size_finder(self):

if torch.cuda.is_available():
torch.backends.cudnn.deterministic = True

SRC_DIR = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "..", "examples", "pytorch", "text-classification")
)
sys.path.append(SRC_DIR)
import run_glue

with tempfile.TemporaryDirectory() as tmpdir:
testargs = f"""
run_glue.py
--model_name_or_path distilbert-base-uncased
--task_name mrpc
--do_train
--do_eval
--max_seq_len 128
--per_device_train_batch_size 4096
--learning_rate 2e-5
--num_train_epochs 1
--output_dir {tmpdir}
--auto_find_batch_size 0
""".split()
with self.assertRaises(RuntimeError):
with patch.object(sys, "argv", testargs):
run_glue.main()

testargs[-1] = "1"
with patch.object(sys, "argv", testargs):
run_glue.main()

# regression for this issue: https://github.com/huggingface/transformers/issues/12970
def test_training_with_resume_from_checkpoint_false(self):
train_dataset = RegressionDataset(length=128)
39 changes: 38 additions & 1 deletion tests/trainer/test_trainer_utils.py
@@ -18,7 +18,8 @@

import numpy as np

from transformers.testing_utils import require_torch
from transformers.testing_utils import require_accelerate, require_torch
from transformers.trainer_utils import find_executable_batch_size
from transformers.utils import is_torch_available


@@ -420,3 +421,39 @@ def test_shard_sampler(self):

self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=3)
self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=3)

@require_accelerate
def test_executable_batch_size(self):
batch_sizes = []

@find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=True)
def mock_training_loop_function(batch_size):
nonlocal batch_sizes
batch_sizes.append(batch_size)
if batch_size > 16:
raise RuntimeError("CUDA out of memory.")

mock_training_loop_function()
self.assertEqual(batch_sizes, [64, 32, 16])

@require_accelerate
def test_executable_batch_size_no_search(self):
batch_sizes = []

@find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=False)
def mock_training_loop_function(batch_size):
nonlocal batch_sizes
batch_sizes.append(batch_size)

mock_training_loop_function()
self.assertEqual(batch_sizes, [64])

@require_accelerate
def test_executable_batch_size_with_error(self):
@find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=False)
def mock_training_loop_function(batch_size):
raise RuntimeError("CUDA out of memory.")

with self.assertRaises(RuntimeError) as cm:
mock_training_loop_function()
self.assertEqual("CUDA out of memory", cm.args[0])