Skip to content
Merged
2 changes: 1 addition & 1 deletion docs/source/dpo_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ model = AutoModelForCausalLM.from_pretrained(
"mistralai/mixtral-8x7b-v0.1",
load_in_4bit=True,
quantization_config=bnb_config,
attn_implementation="flash_attention_2",
attn_implementation="kernels-community/flash-attn2",
dtype=torch.bfloat16,
device_map="auto",
)
Expand Down
2 changes: 1 addition & 1 deletion docs/source/gkd_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a
The authors find that on-policy data (high `lmbda`) performs better and the optimal `beta` varied depending on the task and evaluation method.

> [!WARNING]
> Make sure that `attn_implementation="flash_attention_2"` when training [Gemma models](https://huggingface.co/models?other=gemma2). Otherwise you will encounter NaNs in the logits due to the [soft capping technique](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) adopted by this architecture.
> Make sure that `attn_implementation="kernels-community/flash-attn2"` when training [Gemma models](https://huggingface.co/models?other=gemma2). Otherwise you will encounter NaNs in the logits due to the [soft capping technique](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) adopted by this architecture.

The basic API is as follows:

Expand Down
4 changes: 2 additions & 2 deletions docs/source/reducing_memory_usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ Padding-free batching is an alternative approach for reducing memory usage. In t
```python
from trl import DPOConfig

training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"})
training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"})
```

</hfoption>
Expand All @@ -197,7 +197,7 @@ training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_imple
```python
from trl import SFTConfig

training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"})
training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"})
```

</hfoption>
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ judges = [
"openai>=1.23.2",
"llm-blender>=0.0.2"
]
kernels = [
"kernels"
]
liger = [
"liger-kernel>=0.6.2"
]
Expand Down Expand Up @@ -98,6 +101,8 @@ dev = [
# judges
"openai>=1.23.2",
"llm-blender>=0.0.2",
# kernels
"kernels",
# liger
"liger-kernel>=0.6.2",
# peft
Expand Down
8 changes: 5 additions & 3 deletions tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@

from .testing_utils import (
TrlTestCase,
require_ampere_or_newer,
require_bitsandbytes,
require_flash_attn,
require_kernels,
require_liger_kernel,
require_peft,
require_torch_accelerator,
Expand Down Expand Up @@ -1987,7 +1988,8 @@ def test_training_with_transformers_paged(self, model_name):
"HuggingFaceTB/SmolVLM-Instruct", # Only test the smaller model to avoid OOM
],
)
@require_flash_attn
@require_kernels
@require_ampere_or_newer # Flash attention 2 requires Ampere or newer GPUs
@require_bitsandbytes
@require_peft
def test_vlm_training(self, model_name):
Expand Down Expand Up @@ -2040,7 +2042,7 @@ def data_gen(num_samples):
)
model = AutoModelForImageTextToText.from_pretrained(
model_name,
attn_implementation="flash_attention_2",
attn_implementation="kernels-community/flash-attn2",
dtype="bfloat16",
device_map=get_kbit_device_map(),
quantization_config=quantization_config,
Expand Down
8 changes: 5 additions & 3 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@
from .testing_utils import (
TrlTestCase,
ignore_warnings,
require_ampere_or_newer,
require_bitsandbytes,
require_flash_attn,
require_kernels,
require_liger_kernel,
require_peft,
require_torch_accelerator,
Expand Down Expand Up @@ -870,7 +871,8 @@ def test_train_with_iterable_dataset(self):
new_param = trainer.model.get_parameter(n)
assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"

@require_flash_attn
@require_kernels
@require_ampere_or_newer # Flash attention 2 requires Ampere or newer GPUs
def test_train_padding_free(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
Expand All @@ -879,7 +881,7 @@ def test_train_padding_free(self):
training_args = SFTConfig(
output_dir=self.tmp_dir,
padding_free=True,
model_init_kwargs={"attn_implementation": "flash_attention_2"},
model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"},
bf16=True, # flash_attention_2 only supports bf16 and fp16
report_to="none",
)
Expand Down
21 changes: 8 additions & 13 deletions tests/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available
from transformers.testing_utils import backend_device_count, torch_device
from transformers.utils import (
is_flash_attn_2_available,
is_kernels_available,
is_peft_available,
is_rich_available,
Expand All @@ -45,6 +44,7 @@

require_bitsandbytes = pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes")
require_comet = pytest.mark.skipif(not is_comet_available(), reason="test requires comet_ml")
require_kernels = pytest.mark.skipif(not is_kernels_available(), reason="test requires kernels")
require_liger_kernel = pytest.mark.skipif(not is_liger_kernel_available(), reason="test requires liger-kernel")
require_llm_blender = pytest.mark.skipif(not is_llm_blender_available(), reason="test requires llm-blender")
require_math_latex = pytest.mark.skipif(not is_math_verify_available(), reason="test requires math_verify")
Expand Down Expand Up @@ -85,21 +85,16 @@ def is_bitsandbytes_multi_backend_available() -> bool:
)


def is_flash_attn_available():
flash_attn_available = is_flash_attn_2_available()
kernels_available = is_kernels_available()
try:
from kernels import get_kernel

get_kernel("kernels-community/flash-attn")
except Exception:
kernels_available = False
def is_ampere_or_newer(device_index=0):
if not torch.cuda.is_available():
return False

return kernels_available or flash_attn_available
major, minor = torch.cuda.get_device_capability(device_index)
# Ampere starts at compute capability 8.0 (e.g., A100 = 8.0, RTX 30xx = 8.6)
return (major, minor) >= (8, 0)


# Function ported from transformers.testing_utils
require_flash_attn = pytest.mark.skipif(not is_flash_attn_available(), reason="test requires Flash Attention")
require_ampere_or_newer = pytest.mark.skipif(not is_ampere_or_newer(), reason="test requires Ampere or newer GPU")


class RandomBinaryJudge(BaseBinaryJudge):
Expand Down
Loading