From bd34e458e30bd089b9d46b2f13afac2347bbe6cd Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 21 Jul 2025 11:05:22 +0200 Subject: [PATCH 1/3] update docs --- docs/source/en/attention_interface.md | 28 ++++++++++++++++++++++++ docs/source/en/cache_explanation.md | 31 ++++++++++++++++++++++++++- docs/source/en/llm_optims.md | 4 ++-- docs/source/en/perf_infer_gpu_one.md | 4 ++-- 4 files changed, 62 insertions(+), 5 deletions(-) diff --git a/docs/source/en/attention_interface.md b/docs/source/en/attention_interface.md index d78e21413e0e..44cde32f6d24 100644 --- a/docs/source/en/attention_interface.md +++ b/docs/source/en/attention_interface.md @@ -72,6 +72,34 @@ model(torch.ones(1, 5, dtype=int)) and it will stop printing the statements, as it now uses the `sdpa` attention. This allows to quickly change an attention function, without needing to reload the model! +## Different attention per backbone in multimodal models + +For multimodal models, you may want to load different backbones with different attention functions. For example, some vision backbones perform better with full precision only and are incompatible with Flash Attention. If you want to take advantage of Flash Attention while keeping your vision encoder in fp32, you can configure different attention implementations per backbone. + +```python +from transformers import AutoModelForImageTextToText + +model_id = "facebook/chameleon-7b" + +attention_implementation_per_backbone = {"vision_config": "sdpa", "text_config": "flash_attention_2"} +model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation=attention_implementation_per_backbone) + +# NOTE: keys in the attention implementation have to be the same as the sub-config names +for key in attention_implementation_per_backbone: + assert key in model.config.sub_configs, f"Invalid key in `attention_implementation`" + +# You can omit certain backbones - the default attention function (SDPA) will be used +# This is equivalent to the previous example +model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation={"text_config": "flash_attention_2"}) + + +# Set the same attention implementation for all backbones with single string, same as in non-multimodal models +model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager") + +# Alternatively use a dict with an empty key for global configuration +model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation={"": "eager"}) +``` + ## What about new args needed in my custom attention function? But indeed, what if the new function requires a new arg to be properly used? It's no issue! Models supporting the diff --git a/docs/source/en/cache_explanation.md b/docs/source/en/cache_explanation.md index 6c31035234bb..6ef564abe8af 100644 --- a/docs/source/en/cache_explanation.md +++ b/docs/source/en/cache_explanation.md @@ -132,6 +132,35 @@ for _ in range(max_new_tokens): print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]) "[INST] Hello, what's your name. [/INST] Hello! My name is LLaMA," ``` + +### Cache Position Usage + +The cache position tracks where the new tokens should be inserted in the attention cache. It represents the *absolute* position of each token in the context, independent of padding or batch structure. Suppose that you have already cached `N` tokens and you are now processing `K` new tokens. Then the cache position for the new tokens will range from `N` to `N + K - 1`. In other words, you're processing tokens at positions - `[N, N + 1, N + 2, ..., N + K - 1]` + +Cache position is used internally for two purposes: + +1. Selecting new tokens to process in the input sequence and ensuring only tokens that haven’t been cached yet are passed to the model's `forward`. +2. Storing key/value pairs at the correct positions in the cache. Especially important for fixed-size caches like [`StaticCache`] that preallocate certain length for cache in advance. + +Usually you don't have to worry about cache position as the generation loop takes care of it. Yet, in cases where you are writing your own custom generation method, it is important that cache positions are accurate since they are used to write and read key/value states into fixed slots. + +There is an example usage with simple custom generation loop above, so let's create a more complex example here. + +```py +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache + +model_id = "meta-llama/Llama-2-7b-chat-hf" +model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda:0") +tokenizer = AutoTokenizer.from_pretrained(model_id) + +messages = [{"role": "user", "content": "You are a helpful assistant."}] +inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda:0") +generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=10) + +``` + + ## Legacy cache format Before the [`Cache`] class, the cache used to be stored as a tuple of tuples of tensors. This format is dynamic because it grows as text is generated, similar to [`DynamicCache`]. @@ -157,4 +186,4 @@ generation_outputs = model.generate(**inputs, return_dict_in_generate=True, retu cache = DynamicCache.from_legacy_cache(generation_outputs.past_key_values) legacy_format_cache = cache.to_legacy_cache() -``` \ No newline at end of file +``` diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md index e8e20dab5db6..926ffc34dcd1 100644 --- a/docs/source/en/llm_optims.md +++ b/docs/source/en/llm_optims.md @@ -341,7 +341,7 @@ A known issue with transformer models is that the self-attention mechanism grows FlashAttention and [FlashAttention-2](./perf_infer_gpu_one#flashattention-2) break up the attention computation into smaller chunks and reduces the number of intermediate read/write operations to the GPU memory to speed up inference. FlashAttention-2 improves on the original FlashAttention algorithm by also parallelizing over sequence length dimension and better partitioning work on the hardware to reduce synchronization and communication overhead. -To use FlashAttention-2, set [attn_implementation](https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.PreTrainedModel.from_pretrained.attn_implementation) to `"flash_attention_2"` in [`~PreTrainedModel.from_pretrained`]. +To use FlashAttention-2, set [attn_implementation](https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.PreTrainedModel.from_pretrained.attn_implementation) to `"flash_attention_2"` in [`~PreTrainedModel.from_pretrained`] or set with `model.set_attention_implementation("flash_attention_2")` to dynamically update the [attention interface](./attention_interface) after the model is loaded. ```py from transformers import AutoModelForCausalLM, BitsAndBytesConfig @@ -360,7 +360,7 @@ model = AutoModelForCausalLM.from_pretrained( Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and it supports FlashAttention, xFormers, and PyTorch's C++ implementation. SDPA chooses the most performant attention algorithm if you're using a CUDA backend. For other backends, SDPA defaults to the PyTorch C++ implementation. > [!TIP] -> SDPA automaticallysupports FlashAttention-2 as long as you have the latest PyTorch version installed. +> SDPA automatically supports FlashAttention-2 as long as you have the latest PyTorch version installed. Use the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager to explicitly enable or disable any of the four attention algorithms. For example, use `SDPBackend.FLASH_ATTENTION` to enable FlashAttention. diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index c3a7ddc8d8af..7c60cb189dee 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -175,7 +175,7 @@ There are three supported implementations available. - [xFormers](https://github.com/facebookresearch/xformers) or Memory-Efficient Attention is able to support models with the fp32 torch type. - C++ implementation of scaled dot product attention -SDPA is used by default for PyTorch v2.1.1. and greater when an implementation is available. You could explicitly enable SDPA by setting `attn_implementation="sdpa"` in [`~PreTrainedModel.from_pretrained`] though. Certain attention parameters, such as `head_mask` and `output_attentions=True`, are unsupported and returns a warning that Transformers will fall back to the (slower) eager implementation. +SDPA is used by default for PyTorch v2.1.1. and greater when an implementation is available. You could explicitly enable SDPA by setting `attn_implementation="sdpa"` in [`~PreTrainedModel.from_pretrained`] though. Certain attention parameters, such as `head_mask` and `output_attentions=True`, are unsupported and returns a warning that Transformers will fall back to the (slower) eager implementation. You can also change the attention implementation after loading the model for most architectures. For details, check out the [attention interfaces documentation](./attention_interface). ```py from transformers import AutoModelForCausalLM @@ -234,7 +234,7 @@ FlashAttention2 support is currently limited to Instinct MI210, Instinct MI250 a -Enable FlashAttention2 by setting `attn_implementation="flash_attention_2"` in [`~PreTrainedModel.from_pretrained`]. FlashAttention2 is only supported for models with the fp16 or bf16 torch type. Make sure to cast your model to the appropriate data type first. +Enable FlashAttention2 by setting `attn_implementation="flash_attention_2"` in [`~PreTrainedModel.from_pretrained`] or by setting `model.set_attention_implementation("flash_attention_2")` to dynamically update the [attention interface](./attention_interface). FlashAttention2 is only supported for models with the fp16 or bf16 torch type. Make sure to cast your model to the appropriate data type first. ```py from transformers import AutoModelForCausalLM From 5c99f69e431341d3bd69ef9ad0c9db0dc214c717 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Tue, 22 Jul 2025 09:36:43 +0200 Subject: [PATCH 2/3] Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/cache_explanation.md | 9 ++++----- docs/source/en/perf_infer_gpu_one.md | 4 +++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/source/en/cache_explanation.md b/docs/source/en/cache_explanation.md index 6ef564abe8af..c0c3a86c55ac 100644 --- a/docs/source/en/cache_explanation.md +++ b/docs/source/en/cache_explanation.md @@ -133,18 +133,17 @@ print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]) "[INST] Hello, what's your name. [/INST] Hello! My name is LLaMA," ``` -### Cache Position Usage +## Cache position -The cache position tracks where the new tokens should be inserted in the attention cache. It represents the *absolute* position of each token in the context, independent of padding or batch structure. Suppose that you have already cached `N` tokens and you are now processing `K` new tokens. Then the cache position for the new tokens will range from `N` to `N + K - 1`. In other words, you're processing tokens at positions - `[N, N + 1, N + 2, ..., N + K - 1]` +The cache position tracks where to insert new tokens in the attention cache. It represents the *absolute* position of each token in the context, independent of padding or batch structure. Suppose you already cached `N` tokens and are now processing `K` new tokens. The cache position for the new tokens will range from `N` to `N + K - 1`. In other words, you're processing tokens at positions - `[N, N + 1, N + 2, ..., N + K - 1]`. Cache position is used internally for two purposes: 1. Selecting new tokens to process in the input sequence and ensuring only tokens that haven’t been cached yet are passed to the model's `forward`. -2. Storing key/value pairs at the correct positions in the cache. Especially important for fixed-size caches like [`StaticCache`] that preallocate certain length for cache in advance. +2. Storing key/value pairs at the correct positions in the cache. This is especially important for fixed-size caches, like [`StaticCache`], that pre-allocates a specific cache length. -Usually you don't have to worry about cache position as the generation loop takes care of it. Yet, in cases where you are writing your own custom generation method, it is important that cache positions are accurate since they are used to write and read key/value states into fixed slots. +The generation loop usually takes care of the cache position, but if you're writing a custom generation method, it is important that cache positions are accurate since they are used to write and read key/value states into fixed slots. -There is an example usage with simple custom generation loop above, so let's create a more complex example here. ```py import torch diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 7c60cb189dee..66b5bb6dcf24 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -175,7 +175,9 @@ There are three supported implementations available. - [xFormers](https://github.com/facebookresearch/xformers) or Memory-Efficient Attention is able to support models with the fp32 torch type. - C++ implementation of scaled dot product attention -SDPA is used by default for PyTorch v2.1.1. and greater when an implementation is available. You could explicitly enable SDPA by setting `attn_implementation="sdpa"` in [`~PreTrainedModel.from_pretrained`] though. Certain attention parameters, such as `head_mask` and `output_attentions=True`, are unsupported and returns a warning that Transformers will fall back to the (slower) eager implementation. You can also change the attention implementation after loading the model for most architectures. For details, check out the [attention interfaces documentation](./attention_interface). +SDPA is used by default for PyTorch v2.1.1. and greater when an implementation is available. You could explicitly enable SDPA by setting `attn_implementation="sdpa"` in [`~PreTrainedModel.from_pretrained`] though. Certain attention parameters, such as `head_mask` and `output_attentions=True`, are unsupported and returns a warning that Transformers will fall back to the (slower) eager implementation. + +Refer to the [AttentionInterface](./attention_interface) guide to learn how to change the attention implementation after loading a model. ```py from transformers import AutoModelForCausalLM From c14bd48a0b053db991e047ac603c8437de9c1b4f Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 22 Jul 2025 09:39:46 +0200 Subject: [PATCH 3/3] applu suggestions --- docs/source/en/attention_interface.md | 2 +- docs/source/en/llm_optims.md | 8 ++++++++ docs/source/en/perf_infer_gpu_one.md | 4 ++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/docs/source/en/attention_interface.md b/docs/source/en/attention_interface.md index 44cde32f6d24..42c43973e79a 100644 --- a/docs/source/en/attention_interface.md +++ b/docs/source/en/attention_interface.md @@ -74,7 +74,7 @@ This allows to quickly change an attention function, without needing to reload t ## Different attention per backbone in multimodal models -For multimodal models, you may want to load different backbones with different attention functions. For example, some vision backbones perform better with full precision only and are incompatible with Flash Attention. If you want to take advantage of Flash Attention while keeping your vision encoder in fp32, you can configure different attention implementations per backbone. +For multimodal models different attention functions may work better for each backbone module. For example, some vision backbones perform better in fp32, but are incompatible with FlashAttention. To continue using FlashAttention while keeping the vision encoder in fp32, create a dict and map each config to an attention implementation as shown below. ```python from transformers import AutoModelForImageTextToText diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md index 926ffc34dcd1..0295a5bf1b34 100644 --- a/docs/source/en/llm_optims.md +++ b/docs/source/en/llm_optims.md @@ -353,6 +353,14 @@ model = AutoModelForCausalLM.from_pretrained( torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", ) + +# Change the model's attention dynamically after loading +model = AutoModelForCausalLM.from_pretrained( + "google/gemma-2b", + quantization_config=quant_config, + torch_dtype=torch.bfloat16 +) +model.set_attention_implementation("flash_attention_2") ``` ### PyTorch scaled dot product attention diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 66b5bb6dcf24..fa726e1f98b4 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -183,6 +183,10 @@ Refer to the [AttentionInterface](./attention_interface) guide to learn how to c from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", device_map="auto", attn_implementation="sdpa") + +# Change the model's attention dynamically after loading it +model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", device_map="auto") +model.set_attention_implementation("sdpa") ``` SDPA selects the most performant implementation available, but you can also explicitly select an implementation with [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager. The example below shows how to enable the FlashAttention2 implementation with `enable_flash=True`.