diff --git a/.circleci/config.yml b/.circleci/config.yml
index 5cf2fab47635..46b10b59a3d9 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -84,6 +84,7 @@ jobs:
             - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
             - run: pip install tensorflow_probability
             - run: pip install https://github.com/kpu/kenlm/archive/master.zip
+            - run: pip install git+https://github.com/huggingface/accelerate
             - save_cache:
                 key: v0.4-{{ checksum "setup.py" }}
                 paths:
@@ -122,6 +123,7 @@ jobs:
             - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
             - run: pip install tensorflow_probability
             - run: pip install https://github.com/kpu/kenlm/archive/master.zip
+            - run: pip install git+https://github.com/huggingface/accelerate
             - save_cache:
                 key: v0.4-{{ checksum "setup.py" }}
                 paths:
@@ -154,6 +156,7 @@ jobs:
             - run: pip install .[sklearn,flax,torch,testing,sentencepiece,torch-speech,vision]
             - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
             - run: pip install https://github.com/kpu/kenlm/archive/master.zip
+            - run: pip install git+https://github.com/huggingface/accelerate
             - save_cache:
                 key: v0.4-{{ checksum "setup.py" }}
                 paths:
@@ -191,6 +194,7 @@ jobs:
             - run: pip install .[sklearn,flax,torch,testing,sentencepiece,torch-speech,vision]
             - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
             - run: pip install https://github.com/kpu/kenlm/archive/master.zip
+            - run: pip install git+https://github.com/huggingface/accelerate
             - save_cache:
                 key: v0.4-{{ checksum "setup.py" }}
                 paths:
@@ -222,6 +226,7 @@ jobs:
             - run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm]
             - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
             - run: pip install https://github.com/kpu/kenlm/archive/master.zip
+            - run: pip install git+https://github.com/huggingface/accelerate
             - save_cache:
                 key: v0.4-torch-{{ checksum "setup.py" }}
                 paths:
@@ -258,6 +263,7 @@ jobs:
             - run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm]
             - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
             - run: pip install https://github.com/kpu/kenlm/archive/master.zip
+            - run: pip install git+https://github.com/huggingface/accelerate
             - save_cache:
                 key: v0.4-torch-{{ checksum "setup.py" }}
                 paths:
diff --git a/setup.py b/setup.py
index bb3598fda20d..386d1c25c51d 100644
--- a/setup.py
+++ b/setup.py
@@ -96,6 +96,7 @@
 # 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
 _deps = [
     "Pillow",
+    "accelerate>=0.7.1",
     "black~=22.0",
     "codecarbon==1.2.0",
     "cookiecutter==1.7.3",
diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py
index 4b3498e1f8ea..c6ae12ff84e8 100644
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -3,6 +3,7 @@
 # 2. run `make deps_table_update``
 deps = {
     "Pillow": "Pillow",
+    "accelerate": "accelerate>=0.7.1",
     "black": "black~=22.0",
     "codecarbon": "codecarbon==1.2.0",
     "cookiecutter": "cookiecutter==1.7.3",
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 86d3673b7477..fe6130363985 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -40,6 +40,7 @@
     is_wandb_available,
 )
 from .utils import (
+    is_accelerate_available,
     is_apex_available,
     is_bitsandbytes_available,
     is_detectron2_available,
@@ -238,6 +239,13 @@ def require_git_lfs(test_case):
     return unittest.skipUnless(_run_git_lfs_tests, "test of git lfs workflow")(test_case)
 
 
+def require_accelerate(test_case):
+    """
+    Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
+    """
+    return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
+
+
 def require_rjieba(test_case):
     """
     Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index e3e6c0297bef..aa54f2af1bb5 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -115,6 +115,7 @@
     default_compute_objective,
     default_hp_space,
     denumpify_detensorize,
+    find_executable_batch_size,
     get_last_checkpoint,
     has_length,
     number_of_arguments,
@@ -548,6 +549,9 @@ def __init__(
         self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
         self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
 
+        # Internal variables to keep track of the original batch size
+        self._train_batch_size = args.train_batch_size
+
         # very last
         self._memory_tracker.stop_and_update_metrics()
 
@@ -718,7 +722,7 @@ def get_train_dataloader(self) -> DataLoader:
             if self.args.world_size > 1:
                 train_dataset = IterableDatasetShard(
                     train_dataset,
-                    batch_size=self.args.train_batch_size,
+                    batch_size=self._train_batch_size,
                     drop_last=self.args.dataloader_drop_last,
                     num_processes=self.args.world_size,
                     process_index=self.args.process_index,
@@ -736,7 +740,7 @@ def get_train_dataloader(self) -> DataLoader:
 
         return DataLoader(
             train_dataset,
-            batch_size=self.args.train_batch_size,
+            batch_size=self._train_batch_size,
             sampler=train_sampler,
             collate_fn=self.data_collator,
             drop_last=self.args.dataloader_drop_last,
@@ -1267,6 +1271,20 @@ def train(
             self._move_model_to_device(self.model, args.device)
             self.model_wrapped = self.model
 
+        inner_training_loop = find_executable_batch_size(
+            self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
+        )
+        return inner_training_loop(
+            args=args,
+            resume_from_checkpoint=resume_from_checkpoint,
+            trial=trial,
+            ignore_keys_for_eval=ignore_keys_for_eval,
+        )
+
+    def _inner_training_loop(
+        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
+    ):
+        self._train_batch_size = batch_size
         # Data loader and number of training steps
         train_dataloader = self.get_train_dataloader()
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index 5369b2e78023..62cab858b7e1 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -36,6 +36,7 @@
     is_torch_available,
     is_torch_cuda_available,
     is_torch_tpu_available,
+    requires_backends,
 )
 
 
@@ -355,6 +356,7 @@ class TrainerMemoryTracker:
     stages = {
         "__init__": "init",
         "train": "train",
+        "_inner_training_loop": "train",
         "evaluate": "eval",
         "predict": "test",
     }
@@ -584,6 +586,37 @@ class ShardedDDPOption(ExplicitEnum):
     AUTO_WRAP = "auto_wrap"
 
 
+def find_executable_batch_size(
+    function: callable = None, starting_batch_size: int = 128, auto_find_batch_size: bool = False
+):
+    """
+    A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
+    CUDNN, the batch size is cut in half and passed to `function`. `function` must take in a `batch_size` parameter as
+    its first argument.
+
+    Args:
+        function (`callable`, *optional*):
+            A function to wrap.
+        starting_batch_size (`int`, *optional*):
+            The batch size to try and fit into memory.
+        auto_find_batch_size (`bool`, *optional*):
+            If `False`, will just execute `function` with the starting batch size.
+    """
+    if function is None:
+        return functools.partial(
+            find_executable_batch_size,
+            starting_batch_size=starting_batch_size,
+            auto_find_batch_size=auto_find_batch_size,
+        )
+
+    if auto_find_batch_size:
+        requires_backends(find_executable_batch_size, "accelerate")
+        import accelerate.memory_utils as mem_utils
+
+        return mem_utils.find_executable_batch_size(function=function, starting_batch_size=starting_batch_size)
+
+    return functools.partial(function, batch_size=starting_batch_size)
+
+
 class FSDPOption(ExplicitEnum):
     FULL_SHARD = "full_shard"
     SHARD_GRAD_OP = "shard_grad_op"
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index b1c3f8b2558b..631fe0b3ec57 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -443,6 +443,9 @@ class TrainingArguments:
         include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
             Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
             that need inputs, predictions and references for scoring calculation in Metric class.
+        auto_find_batch_size (`bool`, *optional*, defaults to `False`):
+            Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding
+            CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`).
     """
 
     output_dir: str = field(
@@ -803,6 +806,13 @@ class TrainingArguments:
         metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in Trainer"},
     )
 
+    auto_find_batch_size: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to automatically halve the batch size and rerun the training loop each time a CUDA Out-of-Memory error is reached"
+        },
+    )
+
     def __post_init__(self):
         # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
         # This needs to happen before any call to self.device or self.n_gpu.
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index e07ba68bd03b..2c473b389d4e 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -85,6 +85,7 @@
     DummyObject,
     OptionalDependencyNotAvailable,
     _LazyModule,
+    is_accelerate_available,
     is_apex_available,
     is_bitsandbytes_available,
     is_coloredlogs_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 648bb17184d4..3ee89167b255 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -428,6 +428,10 @@ def is_protobuf_available():
     return importlib.util.find_spec("google.protobuf") is not None
 
 
+def is_accelerate_available():
+    return importlib.util.find_spec("accelerate") is not None
+
+
 def is_tokenizers_available():
     return importlib.util.find_spec("tokenizers") is not None
 
@@ -725,6 +729,12 @@ def wrapper(*args, **kwargs):
 `pip install pyctcdecode`
 """
 
+# docstyle-ignore
+ACCELERATE_IMPORT_ERROR = """
+{0} requires the accelerate library but it was not found in your environment. You can install it with pip:
+`pip install accelerate`
+"""
+
 
 BACKENDS_MAPPING = OrderedDict(
     [
@@ -750,6 +760,7 @@ def wrapper(*args, **kwargs):
         ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
         ("vision", (is_vision_available, VISION_IMPORT_ERROR)),
         ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
+        ("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
     ]
 )
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index f9df63c15e34..e5e11fcd213c 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -21,6 +21,7 @@
 import random
 import re
 import subprocess
+import sys
 import tempfile
 import time
 import unittest
@@ -58,6 +59,7 @@
     require_torch_bf16,
     require_torch_gpu,
     require_torch_multi_gpu,
+    require_torch_non_multi_gpu,
     require_torch_tf32,
     require_torch_up_to_2_gpus,
     require_wandb,
@@ -1075,6 +1077,41 @@ def test_resume_training_with_randomness(self):
         self.assertAlmostEqual(a, a1, delta=1e-8)
         self.assertAlmostEqual(b, b1, delta=1e-8)
 
+    @slow
+    @require_torch_non_multi_gpu
+    def test_auto_batch_size_finder(self):
+
+        if torch.cuda.is_available():
+            torch.backends.cudnn.deterministic = True
+
+        SRC_DIR = os.path.abspath(
+            os.path.join(os.path.dirname(__file__), "..", "..", "examples", "pytorch", "text-classification")
+        )
+        sys.path.append(SRC_DIR)
+        import run_glue
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            testargs = f"""
+                run_glue.py
+                --model_name_or_path distilbert-base-uncased
+                --task_name mrpc
+                --do_train
+                --do_eval
+                --max_seq_len 128
+                --per_device_train_batch_size 4096
+                --learning_rate 2e-5
+                --num_train_epochs 1
+                --output_dir {tmpdir}
+                --auto_find_batch_size 0
+                """.split()
+            with self.assertRaises(RuntimeError):
+                with patch.object(sys, "argv", testargs):
+                    run_glue.main()
+
+            testargs[-1] = "1"
+            with patch.object(sys, "argv", testargs):
+                run_glue.main()
+
     # regression for this issue: https://github.com/huggingface/transformers/issues/12970
     def test_training_with_resume_from_checkpoint_false(self):
         train_dataset = RegressionDataset(length=128)
diff --git a/tests/trainer/test_trainer_utils.py b/tests/trainer/test_trainer_utils.py
index 7710892d8d79..41448fdcb403 100644
--- a/tests/trainer/test_trainer_utils.py
+++ b/tests/trainer/test_trainer_utils.py
@@ -18,7 +18,8 @@
 
 import numpy as np
 
-from transformers.testing_utils import require_torch
+from transformers.testing_utils import require_accelerate, require_torch
+from transformers.trainer_utils import find_executable_batch_size
 from transformers.utils import is_torch_available
 
 
@@ -420,3 +421,39 @@ def test_shard_sampler(self):
 
         self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=3)
         self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=3)
+
+    @require_accelerate
+    def test_executable_batch_size(self):
+        batch_sizes = []
+
+        @find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=True)
+        def mock_training_loop_function(batch_size):
+            nonlocal batch_sizes
+            batch_sizes.append(batch_size)
+            if batch_size > 16:
+                raise RuntimeError("CUDA out of memory.")
+
+        mock_training_loop_function()
+        self.assertEqual(batch_sizes, [64, 32, 16])
+
+    @require_accelerate
+    def test_executable_batch_size_no_search(self):
+        batch_sizes = []
+
+        @find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=False)
+        def mock_training_loop_function(batch_size):
+            nonlocal batch_sizes
+            batch_sizes.append(batch_size)
+
+        mock_training_loop_function()
+        self.assertEqual(batch_sizes, [64])
+
+    @require_accelerate
+    def test_executable_batch_size_with_error(self):
+        @find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=False)
+        def mock_training_loop_function(batch_size):
+            raise RuntimeError("CUDA out of memory.")
+
+        with self.assertRaises(RuntimeError) as cm:
+            mock_training_loop_function()
+        self.assertEqual("CUDA out of memory.", cm.exception.args[0])
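As a quick illustration of the retry behavior the new tests exercise (not part of the patch itself): the sketch below assumes `accelerate` is installed, since `find_executable_batch_size` delegates to `accelerate.memory_utils` when `auto_find_batch_size=True`, and the `train_step` function and its prints are made up for the example, mirroring `test_executable_batch_size` above.

```python
from transformers.trainer_utils import find_executable_batch_size


# The wrapped function must take `batch_size` as its first argument,
# just like Trainer._inner_training_loop in the diff above.
@find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=True)
def train_step(batch_size):
    print(f"Trying batch_size={batch_size}")
    if batch_size > 16:
        # Simulated CUDA OOM; accelerate catches it and retries with half the batch size.
        raise RuntimeError("CUDA out of memory.")


train_step()  # prints 64, 32, 16 and then returns
```

With `auto_find_batch_size=False`, the wrapper simply binds `starting_batch_size` and lets any error propagate, which is what `test_executable_batch_size_no_search` and `test_executable_batch_size_with_error` check.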
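At the user level, the feature is switched on through the new `TrainingArguments.auto_find_batch_size` flag. A minimal sketch of the intended wiring follows; `model` and `train_dataset` are placeholders for whatever the caller already has, so this is an illustration rather than a runnable snippet from the PR.

```python
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="output",
    per_device_train_batch_size=4096,  # deliberately larger than fits in memory
    auto_find_batch_size=True,  # new flag in this diff; requires `pip install accelerate`
)
# `model` and `train_dataset` are assumed to exist (placeholders).
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()  # on CUDA OOM, _inner_training_loop is rerun with the batch size halved
```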