From 5039d0d18b4fbc1f123d9ae2c7a388ccc215d7a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 12 Nov 2025 17:27:07 +0000 Subject: [PATCH 1/4] Rename `flash-attn` to `flash-attn2` --- docs/source/kernels_hub.md | 8 ++++---- tests/testing_utils.py | 2 +- trl/trainer/sft_trainer.py | 1 + 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/source/kernels_hub.md b/docs/source/kernels_hub.md index a4c4a651557..f3d7ee124ba 100644 --- a/docs/source/kernels_hub.md +++ b/docs/source/kernels_hub.md @@ -27,20 +27,20 @@ from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained( "your-model-name", - attn_implementation="kernels-community/flash-attn" # other options: kernels-community/vllm-flash-attn3, kernels-community/paged-attention + attn_implementation="kernels-community/flash-attn2" # other options: kernels-community/vllm-flash-attn3, kernels-community/paged-attention ) ``` Or when running a TRL training script: ```bash -python sft.py ... --attn_implementation kernels-community/flash-attn +python sft.py ... --attn_implementation kernels-community/flash-attn2 ``` Or using the TRL CLI: ```bash -trl sft ... --attn_implementation kernels-community/flash-attn +trl sft ... --attn_implementation kernels-community/flash-attn2 ``` > [!TIP] @@ -84,7 +84,7 @@ from trl import SFTConfig model = AutoModelForCausalLM.from_pretrained( "your-model-name", - attn_implementation="kernels-community/flash-attn" # choose the desired FlashAttention variant + attn_implementation="kernels-community/flash-attn2" # choose the desired FlashAttention variant ) training_args = SFTConfig( diff --git a/tests/testing_utils.py b/tests/testing_utils.py index bf86b7ab703..2ec18b2f9cb 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -91,7 +91,7 @@ def is_flash_attn_available(): try: from kernels import get_kernel - get_kernel("kernels-community/flash-attn") + get_kernel("kernels-community/flash-attn2") except Exception: kernels_available = False diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index f69a5f3eddb..a3ddc49f2d0 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -73,6 +73,7 @@ "flash_attention_2", "flash_attention_3", "kernels-community/flash-attn", + "kernels-community/vllm-flash-attn2", "kernels-community/vllm-flash-attn3", "kernels-community/flash-attn3", } From fd46a6258a2f5afc4ed87d88bc4b6b3dab947d87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 12 Nov 2025 17:27:55 +0000 Subject: [PATCH 2/4] typo --- trl/trainer/sft_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index a3ddc49f2d0..8b27a42b732 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -73,9 +73,9 @@ "flash_attention_2", "flash_attention_3", "kernels-community/flash-attn", - "kernels-community/vllm-flash-attn2", - "kernels-community/vllm-flash-attn3", + "kernels-community/flash-attn2", "kernels-community/flash-attn3", + "kernels-community/vllm-flash-attn3", } From e9ed9ccfb2537f10959d2e76b298e296d1ea6f20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 13 Nov 2025 04:39:17 +0000 Subject: [PATCH 3/4] update model config --- trl/trainer/model_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/trainer/model_config.py b/trl/trainer/model_config.py index 9e3d5fe8021..a7556b47646 100644 --- a/trl/trainer/model_config.py +++ b/trl/trainer/model_config.py @@ -43,8 +43,8 @@ class ModelConfig: be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine. attn_implementation (`str`, *optional*): - Which attention implementation to use. You can run `--attn_implementation=flash_attention_2`, in which case - you must install this manually by running `pip install flash-attn --no-build-isolation`. + Which attention implementation to use. More information in the [Kernels Hub Integrations + Guide](kernels_hub). use_peft (`bool`, *optional*, defaults to `False`): Whether to use PEFT for training. lora_r (`int`, *optional*, defaults to `16`): From fe5611ad86183d6d45e7d432b95177e58a51d12a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Thu, 13 Nov 2025 12:07:32 -0700 Subject: [PATCH 4/4] Update trl/trainer/sft_trainer.py Co-authored-by: Sergio Paniego Blanco --- trl/trainer/sft_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 8b27a42b732..f446efdc63c 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -72,7 +72,6 @@ FLASH_ATTENTION_VARIANTS = { "flash_attention_2", "flash_attention_3", - "kernels-community/flash-attn", "kernels-community/flash-attn2", "kernels-community/flash-attn3", "kernels-community/vllm-flash-attn3",