Merged
Changes from 4 commits
8 changes: 4 additions & 4 deletions docs/source/kernels_hub.md
@@ -27,20 +27,20 @@ from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
"your-model-name",
-attn_implementation="kernels-community/flash-attn" # other options: kernels-community/vllm-flash-attn3, kernels-community/paged-attention
+attn_implementation="kernels-community/flash-attn2" # other options: kernels-community/vllm-flash-attn3, kernels-community/paged-attention
)
```

Or when running a TRL training script:

```bash
-python sft.py ... --attn_implementation kernels-community/flash-attn
+python sft.py ... --attn_implementation kernels-community/flash-attn2
```

Or using the TRL CLI:

```bash
-trl sft ... --attn_implementation kernels-community/flash-attn
+trl sft ... --attn_implementation kernels-community/flash-attn2
```

> [!TIP]
@@ -84,7 +84,7 @@ from trl import SFTConfig

model = AutoModelForCausalLM.from_pretrained(
"your-model-name",
-attn_implementation="kernels-community/flash-attn" # choose the desired FlashAttention variant
+attn_implementation="kernels-community/flash-attn2" # choose the desired FlashAttention variant
)

training_args = SFTConfig(
4 changes: 2 additions & 2 deletions trl/trainer/model_config.py
@@ -43,8 +43,8 @@ class ModelConfig:
be set to `True` for repositories you trust and in which you have read the code, as it will execute code
present on the Hub on your local machine.
attn_implementation (`str`, *optional*):
-Which attention implementation to use. You can run `--attn_implementation=flash_attention_2`, in which case
-you must install this manually by running `pip install flash-attn --no-build-isolation`.
+Which attention implementation to use. More information in the [Kernels Hub Integrations
+Guide](kernels_hub).
use_peft (`bool`, *optional*, defaults to `False`):
Whether to use PEFT for training.
lora_r (`int`, *optional*, defaults to `16`):
3 changes: 2 additions & 1 deletion trl/trainer/sft_trainer.py
@@ -73,8 +73,9 @@
"flash_attention_2",
"flash_attention_3",
"kernels-community/flash-attn",
"kernels-community/vllm-flash-attn3",
"kernels-community/flash-attn2",
"kernels-community/flash-attn3",
"kernels-community/vllm-flash-attn3",
}
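
Reviewer note: the hunk above only shows the members of the set, not how it is used. As a minimal sketch, assuming the set is consulted with a plain membership test on the `attn_implementation` string, the check might look like the following. The names `SUPPORTED_FLASH_VARIANTS` and `is_flash_variant` are hypothetical and do not come from this diff; the strings are the ones visible in the hunk.

```python
# Illustrative sketch only: the names below are hypothetical, not TRL's code.
# The strings are the ones that appear in the hunk above.
SUPPORTED_FLASH_VARIANTS = {
    "flash_attention_2",
    "flash_attention_3",
    "kernels-community/flash-attn",
    "kernels-community/flash-attn2",
    "kernels-community/flash-attn3",
    "kernels-community/vllm-flash-attn3",
}


def is_flash_variant(attn_implementation):
    """Return True if the string names one of the listed FlashAttention variants."""
    # A value of None (no implementation requested) is simply not in the set.
    return attn_implementation in SUPPORTED_FLASH_VARIANTS


if __name__ == "__main__":
    for name in ("kernels-community/flash-attn2", "sdpa", None):
        print(name, is_flash_variant(name))
```

A set keeps the lookup exact and case-sensitive, so a string that is missing from it, such as a newly published kernel name, is treated as a non-FlashAttention implementation until it is added.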

