diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index 08b1a7481d91..210e06481b2f 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -387,3 +387,17 @@ class FlashAttentionKwargs(TypedDict, total=False):
     cu_seq_lens_k: Optional[torch.LongTensor]
     max_length_q: Optional[int]
     max_length_k: Optional[int]
+
+
+def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
+    if cu_seq_lens.ndim != 1:
+        raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.")
+    pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32)
+    seq_lens = cu_seq_lens.diff(dim=-1)
+    max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device)
+    start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32)
+    for s in seq_lens:
+        pos_ids[start : start + s] = max_arange[:s]
+        start += s
+
+    return pos_ids[None]
diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py
index c55d1feb6d9f..bf659a524a0e 100644
--- a/src/transformers/models/aria/modeling_aria.py
+++ b/src/transformers/models/aria/modeling_aria.py
@@ -25,7 +25,7 @@
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -939,7 +939,10 @@ def forward(
             )
 
         if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
+            if "cu_seq_lens_q" in flash_attn_kwargs:
+                position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"])
+            else:
+                position_ids = cache_position.unsqueeze(0)
 
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index 7337ae6acf49..de9e86cbb5ad 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -36,7 +36,7 @@
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -589,7 +589,10 @@ def forward(
             )
 
         if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
+            if "cu_seq_lens_q" in flash_attn_kwargs:
+                position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"])
+            else:
+                position_ids = cache_position.unsqueeze(0)
 
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
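Not part of the patch, for orientation: the new helper converts cumulative sequence lengths into position ids that restart at every packed-sequence boundary and carry a leading batch dimension. A minimal sketch of the expected behavior, calling the helper exactly as added above:

```python
import torch

from transformers.modeling_flash_attention_utils import get_position_ids_from_cu_seq_lens

# Two packed sequences of lengths 3 and 2 -> cumulative boundaries [0, 3, 5].
cu_seq_lens = torch.tensor([0, 3, 5], dtype=torch.int32)

# Positions restart at each boundary; the result has shape (1, total_tokens).
print(get_position_ids_from_cu_seq_lens(cu_seq_lens))
# tensor([[0, 1, 2, 0, 1]], dtype=torch.int32)
```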
diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py
index c262340aacf9..e9512b4dad36 100644
--- a/src/transformers/models/diffllama/modeling_diffllama.py
+++ b/src/transformers/models/diffllama/modeling_diffllama.py
@@ -31,7 +31,11 @@
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
+from ...modeling_flash_attention_utils import (
+    FlashAttentionKwargs,
+    _flash_attention_forward,
+    get_position_ids_from_cu_seq_lens,
+)
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -828,7 +832,10 @@ def forward(
             )
 
         if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
+            if "cu_seq_lens_q" in flash_attn_kwargs:
+                position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"])
+            else:
+                position_ids = cache_position.unsqueeze(0)
 
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py
index b31e14910a9b..fa81824abcfb 100644
--- a/src/transformers/models/emu3/modeling_emu3.py
+++ b/src/transformers/models/emu3/modeling_emu3.py
@@ -32,7 +32,7 @@
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -1407,7 +1407,10 @@ def forward(
             )
 
         if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
+            if "cu_seq_lens_q" in flash_attn_kwargs:
+                position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"])
+            else:
+                position_ids = cache_position.unsqueeze(0)
 
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py
index a3461ffd71cb..dc98bac5d32b 100644
--- a/src/transformers/models/glm/modeling_glm.py
+++ b/src/transformers/models/glm/modeling_glm.py
@@ -28,7 +28,7 @@
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -570,7 +570,10 @@ def forward(
             )
 
         if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
+            if "cu_seq_lens_q" in flash_attn_kwargs:
+                position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"])
+            else:
+                position_ids = cache_position.unsqueeze(0)
 
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
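The model-side change is identical in every file: when `cu_seq_lens_q` arrives through `**flash_attn_kwargs` and no `position_ids` were passed, positions are derived from the packed-sequence boundaries instead of `cache_position`. A hedged sketch (not part of the patch; tensor names are illustrative) of how those boundaries can be built from a right-padded batch, mirroring the construction used in the new Llama test further down:

```python
import torch

# Illustrative right-padded batch: 1 = real token, 0 = padding.
input_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

# Pack the real tokens into a single row of shape (1, total_tokens).
packed_input_ids = torch.cat(
    [ids[mask.bool()] for ids, mask in zip(input_ids, attention_mask)], dim=-1
)[None]

# Per-sequence lengths and their cumulative boundaries, as the varlen flash-attention kernels expect.
seq_lens = attention_mask.sum(dim=-1).to(torch.int32)
cu_seq_lens = torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(dim=-1, dtype=torch.int32)])

# packed_input_ids -> tensor([[5, 6, 7, 8, 9]]); cu_seq_lens -> tensor([0, 3, 5], dtype=torch.int32)
```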
diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py
index 71518c4a9aa8..efc1ad9e55d0 100644
--- a/src/transformers/models/helium/modeling_helium.py
+++ b/src/transformers/models/helium/modeling_helium.py
@@ -29,7 +29,7 @@
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -557,7 +557,10 @@ def forward(
             )
 
         if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
+            if "cu_seq_lens_q" in flash_attn_kwargs:
+                position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"])
+            else:
+                position_ids = cache_position.unsqueeze(0)
 
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 361ae15c3127..ec05404495df 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -27,7 +27,7 @@
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -559,7 +559,10 @@ def forward(
             )
 
         if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
+            if "cu_seq_lens_q" in flash_attn_kwargs:
+                position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"])
+            else:
+                position_ids = cache_position.unsqueeze(0)
 
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index cc62d378ebae..da170cd57e2c 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -13,7 +13,7 @@
 from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -531,7 +531,10 @@ def forward(
             )
 
         if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
+            if "cu_seq_lens_q" in flash_attn_kwargs:
+                position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"])
+            else:
+                position_ids = cache_position.unsqueeze(0)
 
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
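Putting the pieces together, a padding-free forward pass only needs the packed ids plus `FlashAttentionKwargs`; with this patch the model derives `position_ids` itself. A minimal sketch, not part of the patch — the checkpoint name is a placeholder, and, as the new test notes, the padding-free path is exercised with `attn_implementation="flash_attention_2"` in training mode:

```python
import torch

from transformers import AutoModelForCausalLM
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs

model = AutoModelForCausalLM.from_pretrained(
    "your-decoder-only-model",  # placeholder; any architecture touched by this patch
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
).to("cuda")
model.train()  # mirrors the new test: padding-free is used for training-style forwards

# Two sequences of lengths 3 and 2, packed into one row.
packed_input_ids = torch.tensor([[5, 6, 7, 8, 9]], device="cuda")
cu_seq_lens = torch.tensor([0, 3, 5], dtype=torch.int32, device="cuda")
seq_lens = cu_seq_lens.diff()

flash_attn_kwargs = FlashAttentionKwargs(
    cu_seq_lens_q=cu_seq_lens,
    cu_seq_lens_k=cu_seq_lens,
    max_length_q=int(seq_lens.max()),
    max_length_k=int(seq_lens.max()),
)

# No position_ids needed: they are recovered from cu_seq_lens_q and restart per sequence.
logits = model(input_ids=packed_input_ids, **flash_attn_kwargs).logits
```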
diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index c2e1ae15b4b5..a5960d1f8c64 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -14,7 +14,7 @@
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -535,7 +535,10 @@ def forward(
             )
 
         if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
+            if "cu_seq_lens_q" in flash_attn_kwargs:
+                position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"])
+            else:
+                position_ids = cache_position.unsqueeze(0)
 
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py
index 163956d61a22..916c98709000 100644
--- a/src/transformers/models/olmo2/modeling_olmo2.py
+++ b/src/transformers/models/olmo2/modeling_olmo2.py
@@ -13,7 +13,7 @@
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -536,7 +536,10 @@ def forward(
             )
 
         if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
+            if "cu_seq_lens_q" in flash_attn_kwargs:
+                position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"])
+            else:
+                position_ids = cache_position.unsqueeze(0)
 
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index e86e028b4027..d7c640517e79 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -29,7 +29,7 @@
 from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -601,7 +601,10 @@ def forward(
             )
 
         if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
+            if "cu_seq_lens_q" in flash_attn_kwargs:
+                position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"])
+            else:
+                position_ids = cache_position.unsqueeze(0)
 
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
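Not part of the patch, but perhaps useful context for reviewers: the per-sequence loop in `get_position_ids_from_cu_seq_lens` can also be written without Python-level iteration. A hedged sketch of an equivalent, loop-free formulation (the function name here is hypothetical):

```python
import torch


def position_ids_from_cu_seq_lens_vectorized(cu_seq_lens: torch.Tensor) -> torch.Tensor:
    # Subtract each token's segment start from its absolute index so positions restart per segment.
    seq_lens = cu_seq_lens.diff().long()
    starts = cu_seq_lens[:-1].repeat_interleave(seq_lens)
    positions = torch.arange(int(cu_seq_lens[-1]), device=cu_seq_lens.device, dtype=torch.int32) - starts
    return positions[None]


cu_seq_lens = torch.tensor([0, 3, 5], dtype=torch.int32)
assert torch.equal(
    position_ids_from_cu_seq_lens_vectorized(cu_seq_lens),
    torch.tensor([[0, 1, 2, 0, 1]], dtype=torch.int32),
)
```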
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index 96cd6a6aa32e..2d0c3957f740 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -13,7 +13,7 @@
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -544,7 +544,10 @@ def forward(
             )
 
         if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
+            if "cu_seq_lens_q" in flash_attn_kwargs:
+                position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"])
+            else:
+                position_ids = cache_position.unsqueeze(0)
 
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index c2abf19b2241..f598df1aba9a 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -18,14 +18,18 @@
 
 from packaging import version
 from parameterized import parameterized
+from pytest import mark
 
 from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_available, set_seed
 from transformers.generation.configuration_utils import GenerationConfig
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens
 from transformers.testing_utils import (
     cleanup,
+    require_flash_attn,
     require_read_token,
     require_torch,
     require_torch_accelerator,
+    require_torch_gpu,
     slow,
     torch_device,
 )
@@ -539,6 +543,87 @@ def _reinitialize_config(base_config, new_kwargs):
         with self.assertRaises(KeyError):
             config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}})  # missing "factor"
 
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    def test_attn_mask_position_ids_flash_attn_equality(self):
+        r"""
+        Verify that the logits agree when using an attention mask, position_ids, or
+        FlashAttentionKwargs.
+        """
+        torch.manual_seed(42)
+        decoder_only_classes = []
+        for model_class in self.all_generative_model_classes:
+            config, *_ = self.model_tester.prepare_config_and_inputs()
+            if config.is_encoder_decoder:
+                continue
+            else:
+                decoder_only_classes.append(model_class)
+        if len(decoder_only_classes) == 0:
+            self.skipTest(reason="No decoder-only architecture available for this model.")
+
+        # - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't
+        #   added support for it yet. We skip these models for now.
+        has_encoder_attributes = any(
+            attr_name
+            for attr_name in config.to_dict().keys()
+            if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size"
+        )
+        if has_encoder_attributes:
+            self.skipTest(
+                reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding."
+            )
+
+        for model_class in decoder_only_classes:
+            config, input_ids, _, input_mask, *_ = self.model_tester.prepare_config_and_inputs()
+            # Padding-free requires training = True and attn_implementation="flash_attention_2"
+            model = (
+                model_class._from_config(config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16)
+                .to(torch_device)
+                .train()
+            )
+
+            non_padding_free_inputs = {"input_ids": input_ids, "attention_mask": input_mask}
+            attn_mask_logits = model(**non_padding_free_inputs).logits
+
+            # Build up padding-free tensors
+            padding_free_input_ids = torch.cat(
+                [batch[mask.bool()] for batch, mask in zip(input_ids, input_mask)], dim=-1
+            )[None]
+            position_ids_list = [
+                torch.arange(mask.sum(), device=mask.device, dtype=torch.int32) for mask in input_mask
+            ]
+            position_ids = torch.cat(position_ids_list, dim=-1)[None]
+            seq_lens = torch.cat(
+                [torch.tensor([t.numel()], device=input_mask.device, dtype=torch.int32) for t in position_ids_list],
+                dim=-1,
+            )
+            cu_seq_lens = torch.cat(
+                [
+                    torch.tensor([0], device=input_mask.device, dtype=torch.int32),
+                    seq_lens.cumsum(dim=-1, dtype=torch.int32),
+                ],
+                dim=-1,
+            )
+
+            position_ids_inputs = {"input_ids": padding_free_input_ids, "position_ids": position_ids}
+            position_ids_logits = model(**position_ids_inputs).logits
+
+            flash_attn_kwargs = FlashAttentionKwargs(
+                cu_seq_lens_q=cu_seq_lens,
+                cu_seq_lens_k=cu_seq_lens,
+                max_length_q=seq_lens.max(),
+                max_length_k=seq_lens.max(),
+            )
+            flash_attn_kwargs_logits = model(input_ids=padding_free_input_ids, **flash_attn_kwargs).logits
+
+            attn_mask_logits_reshaped = torch.cat(
+                [batch[mask.bool()] for batch, mask in zip(attn_mask_logits, input_mask)], dim=0
+            )[None]
+
+            torch.testing.assert_close(position_ids_logits, attn_mask_logits_reshaped)
+            torch.testing.assert_close(position_ids_logits, flash_attn_kwargs_logits)
+
 
 @require_torch_accelerator
 class LlamaIntegrationTest(unittest.TestCase):
@@ -1052,3 +1137,24 @@ def test_partial_stacked_causal_mask_static_cache(self):
             ]
         ]
         self.assertEqual(decoded, decoded_1b)
+
+
+def test_pos_ids_from_cu_seq_lens() -> None:
+    n_chunks = 5
+    max_chunk_len = 64
+
+    seq_lens = torch.randint(1, max_chunk_len, size=(n_chunks,))
+    cu_seq_lens = torch.cat([torch.tensor([0]), seq_lens.cumsum(dim=-1)], dim=-1)
+    pos_ids = torch.cat(
+        [
+            torch.arange(
+                s,
+                dtype=torch.int32,
+                device=cu_seq_lens.device,
+            )
+            for s in cu_seq_lens.diff(dim=-1)
+        ],
+        dim=-1,
+    )[None]
+    pos_ids_pred = get_position_ids_from_cu_seq_lens(cu_seq_lens)
+    assert torch.allclose(pos_ids_pred, pos_ids)
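One more property worth noting (not part of the patch): for a single packed sequence, the derived positions coincide with the existing `cache_position`-based fallback for a prefill of length `T`, so the new branch is a strict generalization of the old behavior. A quick CPU-only sanity check:

```python
import torch

from transformers.modeling_flash_attention_utils import get_position_ids_from_cu_seq_lens

T = 6
cache_position = torch.arange(T)
cu_seq_lens = torch.tensor([0, T], dtype=torch.int32)

# Single sequence: positions 0..T-1 with a leading batch dim, same as cache_position.unsqueeze(0).
assert torch.equal(
    get_position_ids_from_cu_seq_lens(cu_seq_lens),
    cache_position.unsqueeze(0).to(torch.int32),
)
```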