diff --git a/docs/source/dpo_trainer.md b/docs/source/dpo_trainer.md index 8e9f0ac41b5..9e524c6a940 100644 --- a/docs/source/dpo_trainer.md +++ b/docs/source/dpo_trainer.md @@ -253,7 +253,7 @@ model = AutoModelForCausalLM.from_pretrained( "mistralai/mixtral-8x7b-v0.1", load_in_4bit=True, quantization_config=bnb_config, - attn_implementation="flash_attention_2", + attn_implementation="kernels-community/flash-attn2", dtype=torch.bfloat16, device_map="auto", ) diff --git a/docs/source/gkd_trainer.md b/docs/source/gkd_trainer.md index 73be330c637..fd7f3c0c2eb 100644 --- a/docs/source/gkd_trainer.md +++ b/docs/source/gkd_trainer.md @@ -28,7 +28,7 @@ The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a The authors find that on-policy data (high `lmbda`) performs better and the optimal `beta` varied depending on the task and evaluation method. > [!WARNING] -> Make sure that `attn_implementation="flash_attention_2"` when training [Gemma models](https://huggingface.co/models?other=gemma2). Otherwise you will encounter NaNs in the logits due to the [soft capping technique](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) adopted by this architecture. +> Make sure that `attn_implementation="kernels-community/flash-attn2"` when training [Gemma models](https://huggingface.co/models?other=gemma2). Otherwise you will encounter NaNs in the logits due to the [soft capping technique](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) adopted by this architecture. The basic API is as follows: diff --git a/docs/source/reducing_memory_usage.md b/docs/source/reducing_memory_usage.md index de027a15eb0..d9a30c336ea 100644 --- a/docs/source/reducing_memory_usage.md +++ b/docs/source/reducing_memory_usage.md @@ -188,7 +188,7 @@ Padding-free batching is an alternative approach for reducing memory usage. In t ```python from trl import DPOConfig -training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"}) +training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"}) ``` @@ -197,7 +197,7 @@ training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_imple ```python from trl import SFTConfig -training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"}) +training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"}) ``` diff --git a/pyproject.toml b/pyproject.toml index 84e3f13debf..1b84ff50d7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,9 @@ judges = [ "openai>=1.23.2", "llm-blender>=0.0.2" ] +kernels = [ + "kernels" +] liger = [ "liger-kernel>=0.6.2" ] @@ -98,6 +101,8 @@ dev = [ # judges "openai>=1.23.2", "llm-blender>=0.0.2", + # kernels + "kernels", # liger "liger-kernel>=0.6.2", # peft diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index a3bf80a641b..b3844a399c1 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -41,8 +41,9 @@ from .testing_utils import ( TrlTestCase, + require_ampere_or_newer, require_bitsandbytes, - require_flash_attn, + require_kernels, require_liger_kernel, require_peft, require_torch_accelerator, @@ -1987,7 +1988,8 @@ def test_training_with_transformers_paged(self, model_name): "HuggingFaceTB/SmolVLM-Instruct", # Only test the smaller model to avoid OOM ], ) - @require_flash_attn + @require_kernels + @require_ampere_or_newer # Flash attention 2 requires Ampere or newer GPUs @require_bitsandbytes @require_peft def test_vlm_training(self, model_name): @@ -2040,7 +2042,7 @@ def data_gen(num_samples): ) model = AutoModelForImageTextToText.from_pretrained( model_name, - attn_implementation="flash_attention_2", + attn_implementation="kernels-community/flash-attn2", dtype="bfloat16", device_map=get_kbit_device_map(), quantization_config=quantization_config, diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index abc6faa2d03..874d5304f2f 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -33,8 +33,9 @@ from .testing_utils import ( TrlTestCase, ignore_warnings, + require_ampere_or_newer, require_bitsandbytes, - require_flash_attn, + require_kernels, require_liger_kernel, require_peft, require_torch_accelerator, @@ -870,7 +871,8 @@ def test_train_with_iterable_dataset(self): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" - @require_flash_attn + @require_kernels + @require_ampere_or_newer # Flash attention 2 requires Ampere or newer GPUs def test_train_padding_free(self): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") @@ -879,7 +881,7 @@ def test_train_padding_free(self): training_args = SFTConfig( output_dir=self.tmp_dir, padding_free=True, - model_init_kwargs={"attn_implementation": "flash_attention_2"}, + model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"}, bf16=True, # flash_attention_2 only supports bf16 and fp16 report_to="none", ) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index bf86b7ab703..8f558bd5491 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -24,7 +24,6 @@ from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available from transformers.testing_utils import backend_device_count, torch_device from transformers.utils import ( - is_flash_attn_2_available, is_kernels_available, is_peft_available, is_rich_available, @@ -45,6 +44,7 @@ require_bitsandbytes = pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes") require_comet = pytest.mark.skipif(not is_comet_available(), reason="test requires comet_ml") +require_kernels = pytest.mark.skipif(not is_kernels_available(), reason="test requires kernels") require_liger_kernel = pytest.mark.skipif(not is_liger_kernel_available(), reason="test requires liger-kernel") require_llm_blender = pytest.mark.skipif(not is_llm_blender_available(), reason="test requires llm-blender") require_math_latex = pytest.mark.skipif(not is_math_verify_available(), reason="test requires math_verify") @@ -85,21 +85,16 @@ def is_bitsandbytes_multi_backend_available() -> bool: ) -def is_flash_attn_available(): - flash_attn_available = is_flash_attn_2_available() - kernels_available = is_kernels_available() - try: - from kernels import get_kernel - - get_kernel("kernels-community/flash-attn") - except Exception: - kernels_available = False +def is_ampere_or_newer(device_index=0): + if not torch.cuda.is_available(): + return False - return kernels_available or flash_attn_available + major, minor = torch.cuda.get_device_capability(device_index) + # Ampere starts at compute capability 8.0 (e.g., A100 = 8.0, RTX 30xx = 8.6) + return (major, minor) >= (8, 0) -# Function ported from transformers.testing_utils -require_flash_attn = pytest.mark.skipif(not is_flash_attn_available(), reason="test requires Flash Attention") +require_ampere_or_newer = pytest.mark.skipif(not is_ampere_or_newer(), reason="test requires Ampere or newer GPU") class RandomBinaryJudge(BaseBinaryJudge):