From 08da972cfbae22424256355892468103708d002d Mon Sep 17 00:00:00 2001 From: Garrett Goon Date: Tue, 28 Jan 2025 17:34:22 +0000 Subject: [PATCH 1/5] fix FlashAttentionKwargs RoPE --- .../models/llama/modeling_llama.py | 12 +- tests/models/llama/test_modeling_llama.py | 107 ++++++++++++++++++ 2 files changed, 118 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 361ae15c3127..4baa8beb6f5b 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -559,7 +559,10 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + if "cu_seq_lens_q" in flash_attn_kwargs: + position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"]) + else: + position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions @@ -1151,6 +1154,13 @@ def forward( ) +def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: + pos_ids = torch.cat( + [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 + )[None] + return pos_ids + + __all__ = [ "LlamaForCausalLM", "LlamaModel", diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index c2abf19b2241..b3cd22e3d5ab 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -18,14 +18,19 @@ from packaging import version from parameterized import parameterized +from pytest import mark from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_available, set_seed from transformers.generation.configuration_utils import GenerationConfig +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.models.llama.modeling_llama import get_position_ids_from_cu_seq_lens from transformers.testing_utils import ( cleanup, + require_flash_attn, require_read_token, require_torch, require_torch_accelerator, + require_torch_gpu, slow, torch_device, ) @@ -539,6 +544,87 @@ def _reinitialize_config(base_config, new_kwargs): with self.assertRaises(KeyError): config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor" + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + def test_attn_mask_position_ids_flash_attn_equality(self): + r""" + Verify that the logits agree when using an attention mask, position_ids, or + FlashAttentionKwargs. + """ + torch.manual_seed(42) + decoder_only_classes = [] + for model_class in self.all_generative_model_classes: + config, *_ = self.model_tester.prepare_config_and_inputs() + if config.is_encoder_decoder: + continue + else: + decoder_only_classes.append(model_class) + if len(decoder_only_classes) == 0: + self.skipTest(reason="No decoder-only architecture available for this model.") + + # - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't + # added support for it yet. We skip these models for now. + has_encoder_attributes = any( + attr_name + for attr_name in config.to_dict().keys() + if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size" + ) + if has_encoder_attributes: + self.skipTest( + reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding." 
+ ) + + for model_class in decoder_only_classes: + config, input_ids, _, input_mask, *_ = self.model_tester.prepare_config_and_inputs() + # Padding-free requires training = True and attn_implementation="flash_attention_2" + model = ( + model_class._from_config(config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16) + .to(torch_device) + .train() + ) + + non_padding_free_inputs = {"input_ids": input_ids, "attention_mask": input_mask} + attn_mask_logits = model(**non_padding_free_inputs).logits + + # Build up padding-free tensors + padding_free_input_ids = torch.cat( + [batch[mask.bool()] for batch, mask in zip(input_ids, input_mask)], dim=-1 + )[None] + position_ids_list = [ + torch.arange(mask.sum(), device=mask.device, dtype=torch.int32) for mask in input_mask + ] + position_ids = torch.cat(position_ids_list, dim=-1)[None] + seq_lens = torch.cat( + [torch.tensor([t.numel()], device=input_mask.device, dtype=torch.int32) for t in position_ids_list], + dim=-1, + ) + cu_seq_lens = torch.cat( + [ + torch.tensor([0], device=input_mask.device, dtype=torch.int32), + seq_lens.cumsum(dim=-1, dtype=torch.int32), + ], + dim=-1, + ) + + position_ids_inputs = {"input_ids": padding_free_input_ids, "position_ids": position_ids} + position_ids_logits = model(**position_ids_inputs).logits + + flash_attn_kwargs = FlashAttentionKwargs( + cu_seq_lens_q=cu_seq_lens, + cu_seq_lens_k=cu_seq_lens, + max_length_q=seq_lens.max(), + max_length_k=seq_lens.max(), + ) + flash_attn_kwargs_logits = model(input_ids=padding_free_input_ids, **flash_attn_kwargs).logits + + attn_mask_logits_reshaped = torch.cat( + [batch[mask.bool()] for batch, mask in zip(attn_mask_logits, input_mask)], dim=0 + )[None] + + torch.testing.assert_close(position_ids_logits, attn_mask_logits_reshaped) + torch.testing.assert_close(position_ids_logits, flash_attn_kwargs_logits) + @require_torch_accelerator class LlamaIntegrationTest(unittest.TestCase): @@ -1052,3 +1138,24 @@ def test_partial_stacked_causal_mask_static_cache(self): ] ] self.assertEqual(decoded, decoded_1b) + + +def test_pos_ids_from_cu_seq_lens() -> None: + n_chunks = 5 + max_chunk_len = 64 + + seq_lens = torch.randint(1, max_chunk_len, size=(n_chunks,)) + cu_seq_lens = torch.cat([torch.tensor([0]), seq_lens.cumsum(dim=-1)], dim=-1) + pos_ids = torch.cat( + [ + torch.arange( + s, + dtype=torch.int32, + device=cu_seq_lens.device, + ) + for s in cu_seq_lens.diff(dim=-1) + ], + dim=-1, + )[None] + pos_ids_pred = get_position_ids_from_cu_seq_lens(cu_seq_lens) + assert torch.allclose(pos_ids_pred, pos_ids) From 0804ea594f2ce34908c935d64eed4059a9df9907 Mon Sep 17 00:00:00 2001 From: Garrett Goon Date: Tue, 28 Jan 2025 17:54:42 +0000 Subject: [PATCH 2/5] run modular_model_converter.py --- src/transformers/models/aria/modeling_aria.py | 12 +++++++++++- src/transformers/models/cohere/modeling_cohere.py | 12 +++++++++++- .../models/diffllama/modeling_diffllama.py | 12 +++++++++++- src/transformers/models/emu3/modeling_emu3.py | 12 +++++++++++- src/transformers/models/glm/modeling_glm.py | 12 +++++++++++- src/transformers/models/helium/modeling_helium.py | 12 +++++++++++- src/transformers/models/mistral/modeling_mistral.py | 12 +++++++++++- src/transformers/models/olmo/modeling_olmo.py | 12 +++++++++++- src/transformers/models/olmo2/modeling_olmo2.py | 12 +++++++++++- src/transformers/models/phi3/modeling_phi3.py | 12 +++++++++++- src/transformers/models/qwen2/modeling_qwen2.py | 12 +++++++++++- 11 files changed, 121 insertions(+), 11 deletions(-) diff --git 
a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index c55d1feb6d9f..94348a474f20 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -861,6 +861,13 @@ def forward(self, x, position_ids): """ +def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: + pos_ids = torch.cat( + [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 + )[None] + return pos_ids + + @add_start_docstrings( "The bare AriaText Model outputting raw hidden-states without any specific head on top.", ARIA_TEXT_START_DOCSTRING, @@ -939,7 +946,10 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + if "cu_seq_lens_q" in flash_attn_kwargs: + position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"]) + else: + position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 7337ae6acf49..b050897fc32a 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -511,6 +511,13 @@ def _init_weights(self, module): """ +def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: + pos_ids = torch.cat( + [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 + )[None] + return pos_ids + + @add_start_docstrings( "The bare Cohere Model outputting raw hidden-states without any specific head on top.", COHERE_START_DOCSTRING, @@ -589,7 +596,10 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + if "cu_seq_lens_q" in flash_attn_kwargs: + position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"]) + else: + position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index c262340aacf9..d39643ba01d2 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -750,6 +750,13 @@ def forward(self, x, position_ids): """ +def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: + pos_ids = torch.cat( + [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 + )[None] + return pos_ids + + @add_start_docstrings( "The bare DiffLlama Model outputting raw hidden-states without any specific head on top.", DIFFLLAMA_START_DOCSTRING, @@ -828,7 +835,10 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + if "cu_seq_lens_q" in flash_attn_kwargs: + position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"]) + else: + position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index b31e14910a9b..f83d373cbb77 100644 --- 
a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1258,6 +1258,13 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: + pos_ids = torch.cat( + [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 + )[None] + return pos_ids + + EMU3_TEXT_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -1407,7 +1414,10 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + if "cu_seq_lens_q" in flash_attn_kwargs: + position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"]) + else: + position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index a3461ffd71cb..9327f1dd0d60 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -492,6 +492,13 @@ def _init_weights(self, module): """ +def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: + pos_ids = torch.cat( + [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 + )[None] + return pos_ids + + @add_start_docstrings( "The bare Glm Model outputting raw hidden-states without any specific head on top.", GLM_START_DOCSTRING, @@ -570,7 +577,10 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + if "cu_seq_lens_q" in flash_attn_kwargs: + position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"]) + else: + position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 71518c4a9aa8..6bd96a058b58 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -479,6 +479,13 @@ def _init_weights(self, module): """ +def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: + pos_ids = torch.cat( + [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 + )[None] + return pos_ids + + @add_start_docstrings( "The bare Helium Model outputting raw hidden-states without any specific head on top.", HELIUM_START_DOCSTRING, @@ -557,7 +564,10 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + if "cu_seq_lens_q" in flash_attn_kwargs: + position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"]) + else: + position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index cc62d378ebae..a2160f6b26ba 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -453,6 +453,13 @@ def _init_weights(self, module): """ +def 
get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: + pos_ids = torch.cat( + [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 + )[None] + return pos_ids + + @add_start_docstrings( "The bare Mistral Model outputting raw hidden-states without any specific head on top.", MISTRAL_START_DOCSTRING, @@ -531,7 +538,10 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + if "cu_seq_lens_q" in flash_attn_kwargs: + position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"]) + else: + position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index c2e1ae15b4b5..637a01b36285 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -457,6 +457,13 @@ def _init_weights(self, module): """ +def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: + pos_ids = torch.cat( + [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 + )[None] + return pos_ids + + @add_start_docstrings( "The bare Olmo Model outputting raw hidden-states without any specific head on top.", OLMO_START_DOCSTRING, @@ -535,7 +542,10 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + if "cu_seq_lens_q" in flash_attn_kwargs: + position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"]) + else: + position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 163956d61a22..95bb80b3feb5 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -458,6 +458,13 @@ def _init_weights(self, module): """ +def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: + pos_ids = torch.cat( + [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 + )[None] + return pos_ids + + @add_start_docstrings( "The bare Olmo2 Model outputting raw hidden-states without any specific head on top.", OLMO2_START_DOCSTRING, @@ -536,7 +543,10 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + if "cu_seq_lens_q" in flash_attn_kwargs: + position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"]) + else: + position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index e86e028b4027..413b3e55e6bf 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -523,6 +523,13 @@ def _init_weights(self, module): """ +def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: + pos_ids = torch.cat( + [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 + )[None] + return pos_ids + + 
@add_start_docstrings( "The bare Phi3 Model outputting raw hidden-states without any specific head on top.", PHI3_START_DOCSTRING, @@ -601,7 +608,10 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + if "cu_seq_lens_q" in flash_attn_kwargs: + position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"]) + else: + position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 96cd6a6aa32e..bb03e03eaca8 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -466,6 +466,13 @@ def _init_weights(self, module): """ +def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: + pos_ids = torch.cat( + [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 + )[None] + return pos_ids + + @add_start_docstrings( "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", QWEN2_START_DOCSTRING, @@ -544,7 +551,10 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + if "cu_seq_lens_q" in flash_attn_kwargs: + position_ids = get_position_ids_from_cu_seq_lens(flash_attn_kwargs["cu_seq_lens_q"]) + else: + position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions From 06a268e802d9d520c01e785c83b7762448089f6e Mon Sep 17 00:00:00 2001 From: Garrett Goon Date: Wed, 5 Feb 2025 16:41:05 +0000 Subject: [PATCH 3/5] minimize tensor allocations --- src/transformers/models/aria/modeling_aria.py | 15 +++++++++++---- src/transformers/models/cohere/modeling_cohere.py | 15 +++++++++++---- .../models/diffllama/modeling_diffllama.py | 15 +++++++++++---- src/transformers/models/emu3/modeling_emu3.py | 15 +++++++++++---- src/transformers/models/glm/modeling_glm.py | 15 +++++++++++---- src/transformers/models/helium/modeling_helium.py | 15 +++++++++++---- src/transformers/models/llama/modeling_llama.py | 15 +++++++++++---- .../models/mistral/modeling_mistral.py | 15 +++++++++++---- src/transformers/models/olmo/modeling_olmo.py | 15 +++++++++++---- src/transformers/models/olmo2/modeling_olmo2.py | 15 +++++++++++---- src/transformers/models/phi3/modeling_phi3.py | 15 +++++++++++---- src/transformers/models/qwen2/modeling_qwen2.py | 15 +++++++++++---- 12 files changed, 132 insertions(+), 48 deletions(-) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 94348a474f20..641781580961 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -862,10 +862,17 @@ def forward(self, x, position_ids): def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - pos_ids = torch.cat( - [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 - )[None] - return pos_ids + if cu_seq_lens.ndim != 1: + raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") + pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) + seq_lens = cu_seq_lens.diff(dim=-1) + max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, 
device=cu_seq_lens.device) + start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) + for s in seq_lens: + pos_ids[start : start + s] = max_arange[:s] + start += s + + return pos_ids[None] @add_start_docstrings( diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index b050897fc32a..bc649f856d5b 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -512,10 +512,17 @@ def _init_weights(self, module): def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - pos_ids = torch.cat( - [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 - )[None] - return pos_ids + if cu_seq_lens.ndim != 1: + raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") + pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) + seq_lens = cu_seq_lens.diff(dim=-1) + max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) + start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) + for s in seq_lens: + pos_ids[start : start + s] = max_arange[:s] + start += s + + return pos_ids[None] @add_start_docstrings( diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index d39643ba01d2..f0e150795a77 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -751,10 +751,17 @@ def forward(self, x, position_ids): def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - pos_ids = torch.cat( - [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 - )[None] - return pos_ids + if cu_seq_lens.ndim != 1: + raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") + pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) + seq_lens = cu_seq_lens.diff(dim=-1) + max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) + start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) + for s in seq_lens: + pos_ids[start : start + s] = max_arange[:s] + start += s + + return pos_ids[None] @add_start_docstrings( diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index f83d373cbb77..5dcc28b31672 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1259,10 +1259,17 @@ def forward(self, x, position_ids): def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - pos_ids = torch.cat( - [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 - )[None] - return pos_ids + if cu_seq_lens.ndim != 1: + raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") + pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) + seq_lens = cu_seq_lens.diff(dim=-1) + max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) + start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) + for s in seq_lens: + pos_ids[start : start + s] = max_arange[:s] + start += s + + return pos_ids[None] EMU3_TEXT_INPUTS_DOCSTRING = r""" diff --git 
a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 9327f1dd0d60..adc59bc04c05 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -493,10 +493,17 @@ def _init_weights(self, module): def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - pos_ids = torch.cat( - [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 - )[None] - return pos_ids + if cu_seq_lens.ndim != 1: + raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") + pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) + seq_lens = cu_seq_lens.diff(dim=-1) + max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) + start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) + for s in seq_lens: + pos_ids[start : start + s] = max_arange[:s] + start += s + + return pos_ids[None] @add_start_docstrings( diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 6bd96a058b58..31e206450fac 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -480,10 +480,17 @@ def _init_weights(self, module): def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - pos_ids = torch.cat( - [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 - )[None] - return pos_ids + if cu_seq_lens.ndim != 1: + raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") + pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) + seq_lens = cu_seq_lens.diff(dim=-1) + max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) + start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) + for s in seq_lens: + pos_ids[start : start + s] = max_arange[:s] + start += s + + return pos_ids[None] @add_start_docstrings( diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 4baa8beb6f5b..6d043ea4c0cb 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1155,10 +1155,17 @@ def forward( def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - pos_ids = torch.cat( - [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 - )[None] - return pos_ids + if cu_seq_lens.ndim != 1: + raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") + pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) + seq_lens = cu_seq_lens.diff(dim=-1) + max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) + start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) + for s in seq_lens: + pos_ids[start : start + s] = max_arange[:s] + start += s + + return pos_ids[None] __all__ = [ diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index a2160f6b26ba..c79edebaca98 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -454,10 +454,17 @@ def _init_weights(self, module): def 
get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - pos_ids = torch.cat( - [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 - )[None] - return pos_ids + if cu_seq_lens.ndim != 1: + raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") + pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) + seq_lens = cu_seq_lens.diff(dim=-1) + max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) + start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) + for s in seq_lens: + pos_ids[start : start + s] = max_arange[:s] + start += s + + return pos_ids[None] @add_start_docstrings( diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 637a01b36285..3b668531f2f4 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -458,10 +458,17 @@ def _init_weights(self, module): def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - pos_ids = torch.cat( - [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 - )[None] - return pos_ids + if cu_seq_lens.ndim != 1: + raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") + pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) + seq_lens = cu_seq_lens.diff(dim=-1) + max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) + start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) + for s in seq_lens: + pos_ids[start : start + s] = max_arange[:s] + start += s + + return pos_ids[None] @add_start_docstrings( diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 95bb80b3feb5..cd5584afaa2d 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -459,10 +459,17 @@ def _init_weights(self, module): def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - pos_ids = torch.cat( - [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 - )[None] - return pos_ids + if cu_seq_lens.ndim != 1: + raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") + pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) + seq_lens = cu_seq_lens.diff(dim=-1) + max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) + start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) + for s in seq_lens: + pos_ids[start : start + s] = max_arange[:s] + start += s + + return pos_ids[None] @add_start_docstrings( diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 413b3e55e6bf..42c39d7d63a2 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -524,10 +524,17 @@ def _init_weights(self, module): def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - pos_ids = torch.cat( - [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 - )[None] - return pos_ids + if cu_seq_lens.ndim != 1: + raise ValueError(f"cu_seq_lens must be a 1D tensor, 
received {cu_seq_lens.ndim=}.") + pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) + seq_lens = cu_seq_lens.diff(dim=-1) + max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) + start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) + for s in seq_lens: + pos_ids[start : start + s] = max_arange[:s] + start += s + + return pos_ids[None] @add_start_docstrings( diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index bb03e03eaca8..63f316d61dcd 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -467,10 +467,17 @@ def _init_weights(self, module): def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - pos_ids = torch.cat( - [torch.arange(s, dtype=torch.int32, device=cu_seq_lens.device) for s in cu_seq_lens.diff(dim=-1)], dim=-1 - )[None] - return pos_ids + if cu_seq_lens.ndim != 1: + raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") + pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) + seq_lens = cu_seq_lens.diff(dim=-1) + max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) + start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) + for s in seq_lens: + pos_ids[start : start + s] = max_arange[:s] + start += s + + return pos_ids[None] @add_start_docstrings( From 973d362078f73df46010481c795aa32a41a99cd8 Mon Sep 17 00:00:00 2001 From: Garrett Goon Date: Tue, 11 Feb 2025 16:26:36 +0000 Subject: [PATCH 4/5] mv get_position_ids_from_cu_seq_lens to flash util --- .../modeling_flash_attention_utils.py | 14 +++++++++++++ src/transformers/models/aria/modeling_aria.py | 16 +-------------- .../models/cohere/modeling_cohere.py | 16 +-------------- .../models/diffllama/modeling_diffllama.py | 20 +++++-------------- src/transformers/models/emu3/modeling_emu3.py | 16 +-------------- src/transformers/models/glm/modeling_glm.py | 16 +-------------- .../models/helium/modeling_helium.py | 16 +-------------- .../models/llama/modeling_llama.py | 16 +-------------- .../models/mistral/modeling_mistral.py | 16 +-------------- src/transformers/models/olmo/modeling_olmo.py | 16 +-------------- .../models/olmo2/modeling_olmo2.py | 16 +-------------- src/transformers/models/phi3/modeling_phi3.py | 16 +-------------- .../models/qwen2/modeling_qwen2.py | 16 +-------------- tests/models/llama/test_modeling_llama.py | 2 +- 14 files changed, 31 insertions(+), 181 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 08b1a7481d91..210e06481b2f 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -387,3 +387,17 @@ class FlashAttentionKwargs(TypedDict, total=False): cu_seq_lens_k: Optional[torch.LongTensor] max_length_q: Optional[int] max_length_k: Optional[int] + + +def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: + if cu_seq_lens.ndim != 1: + raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") + pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) + seq_lens = cu_seq_lens.diff(dim=-1) + max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) + start = torch.tensor(0, device=cu_seq_lens.device, 
dtype=torch.int32) + for s in seq_lens: + pos_ids[start : start + s] = max_arange[:s] + start += s + + return pos_ids[None] diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 641781580961..bf659a524a0e 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -25,7 +25,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -861,20 +861,6 @@ def forward(self, x, position_ids): """ -def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - if cu_seq_lens.ndim != 1: - raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") - pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) - seq_lens = cu_seq_lens.diff(dim=-1) - max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) - start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) - for s in seq_lens: - pos_ids[start : start + s] = max_arange[:s] - start += s - - return pos_ids[None] - - @add_start_docstrings( "The bare AriaText Model outputting raw hidden-states without any specific head on top.", ARIA_TEXT_START_DOCSTRING, diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index bc649f856d5b..de9e86cbb5ad 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -36,7 +36,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -511,20 +511,6 @@ def _init_weights(self, module): """ -def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - if cu_seq_lens.ndim != 1: - raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") - pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) - seq_lens = cu_seq_lens.diff(dim=-1) - max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) - start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) - for s in seq_lens: - pos_ids[start : start + s] = max_arange[:s] - start += s - - return pos_ids[None] - - @add_start_docstrings( "The bare Cohere Model outputting raw hidden-states without any specific head on top.", COHERE_START_DOCSTRING, diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index f0e150795a77..e9512b4dad36 100644 --- 
a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -31,7 +31,11 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_flash_attention_utils import ( + FlashAttentionKwargs, + _flash_attention_forward, + get_position_ids_from_cu_seq_lens, +) from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -750,20 +754,6 @@ def forward(self, x, position_ids): """ -def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - if cu_seq_lens.ndim != 1: - raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") - pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) - seq_lens = cu_seq_lens.diff(dim=-1) - max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) - start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) - for s in seq_lens: - pos_ids[start : start + s] = max_arange[:s] - start += s - - return pos_ids[None] - - @add_start_docstrings( "The bare DiffLlama Model outputting raw hidden-states without any specific head on top.", DIFFLLAMA_START_DOCSTRING, diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 5dcc28b31672..fa81824abcfb 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -32,7 +32,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -1258,20 +1258,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - if cu_seq_lens.ndim != 1: - raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") - pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) - seq_lens = cu_seq_lens.diff(dim=-1) - max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) - start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) - for s in seq_lens: - pos_ids[start : start + s] = max_arange[:s] - start += s - - return pos_ids[None] - - EMU3_TEXT_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index adc59bc04c05..dc98bac5d32b 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -28,7 +28,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs +from 
...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -492,20 +492,6 @@ def _init_weights(self, module): """ -def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - if cu_seq_lens.ndim != 1: - raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") - pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) - seq_lens = cu_seq_lens.diff(dim=-1) - max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) - start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) - for s in seq_lens: - pos_ids[start : start + s] = max_arange[:s] - start += s - - return pos_ids[None] - - @add_start_docstrings( "The bare Glm Model outputting raw hidden-states without any specific head on top.", GLM_START_DOCSTRING, diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 31e206450fac..efc1ad9e55d0 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -29,7 +29,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -479,20 +479,6 @@ def _init_weights(self, module): """ -def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - if cu_seq_lens.ndim != 1: - raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") - pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) - seq_lens = cu_seq_lens.diff(dim=-1) - max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) - start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) - for s in seq_lens: - pos_ids[start : start + s] = max_arange[:s] - start += s - - return pos_ids[None] - - @add_start_docstrings( "The bare Helium Model outputting raw hidden-states without any specific head on top.", HELIUM_START_DOCSTRING, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 6d043ea4c0cb..ec05404495df 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -27,7 +27,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -1154,20 +1154,6 @@ def forward( ) -def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - if cu_seq_lens.ndim != 1: - raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") - pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) - seq_lens = cu_seq_lens.diff(dim=-1) - max_arange = torch.arange(seq_lens.max(), 
dtype=torch.int32, device=cu_seq_lens.device) - start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) - for s in seq_lens: - pos_ids[start : start + s] = max_arange[:s] - start += s - - return pos_ids[None] - - __all__ = [ "LlamaForCausalLM", "LlamaModel", diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index c79edebaca98..da170cd57e2c 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -13,7 +13,7 @@ from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -453,20 +453,6 @@ def _init_weights(self, module): """ -def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - if cu_seq_lens.ndim != 1: - raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") - pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) - seq_lens = cu_seq_lens.diff(dim=-1) - max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) - start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) - for s in seq_lens: - pos_ids[start : start + s] = max_arange[:s] - start += s - - return pos_ids[None] - - @add_start_docstrings( "The bare Mistral Model outputting raw hidden-states without any specific head on top.", MISTRAL_START_DOCSTRING, diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 3b668531f2f4..a5960d1f8c64 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -14,7 +14,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -457,20 +457,6 @@ def _init_weights(self, module): """ -def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - if cu_seq_lens.ndim != 1: - raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") - pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) - seq_lens = cu_seq_lens.diff(dim=-1) - max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) - start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) - for s in seq_lens: - pos_ids[start : start + s] = max_arange[:s] - start += s - - return pos_ids[None] - - @add_start_docstrings( "The bare Olmo Model outputting raw hidden-states without any specific head on top.", OLMO_START_DOCSTRING, diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index cd5584afaa2d..916c98709000 100644 --- 
a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -13,7 +13,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -458,20 +458,6 @@ def _init_weights(self, module): """ -def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - if cu_seq_lens.ndim != 1: - raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") - pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) - seq_lens = cu_seq_lens.diff(dim=-1) - max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) - start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) - for s in seq_lens: - pos_ids[start : start + s] = max_arange[:s] - start += s - - return pos_ids[None] - - @add_start_docstrings( "The bare Olmo2 Model outputting raw hidden-states without any specific head on top.", OLMO2_START_DOCSTRING, diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 42c39d7d63a2..d7c640517e79 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -29,7 +29,7 @@ from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -523,20 +523,6 @@ def _init_weights(self, module): """ -def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - if cu_seq_lens.ndim != 1: - raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") - pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) - seq_lens = cu_seq_lens.diff(dim=-1) - max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) - start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) - for s in seq_lens: - pos_ids[start : start + s] = max_arange[:s] - start += s - - return pos_ids[None] - - @add_start_docstrings( "The bare Phi3 Model outputting raw hidden-states without any specific head on top.", PHI3_START_DOCSTRING, diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 63f316d61dcd..2d0c3957f740 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -13,7 +13,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens 
from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -466,20 +466,6 @@ def _init_weights(self, module): """ -def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor: - if cu_seq_lens.ndim != 1: - raise ValueError(f"cu_seq_lens must be a 1D tensor, received {cu_seq_lens.ndim=}.") - pos_ids = torch.empty(cu_seq_lens[-1], device=cu_seq_lens.device, dtype=torch.int32) - seq_lens = cu_seq_lens.diff(dim=-1) - max_arange = torch.arange(seq_lens.max(), dtype=torch.int32, device=cu_seq_lens.device) - start = torch.tensor(0, device=cu_seq_lens.device, dtype=torch.int32) - for s in seq_lens: - pos_ids[start : start + s] = max_arange[:s] - start += s - - return pos_ids[None] - - @add_start_docstrings( "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", QWEN2_START_DOCSTRING, diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index b3cd22e3d5ab..b73c80f10032 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -22,7 +22,7 @@ from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_available, set_seed from transformers.generation.configuration_utils import GenerationConfig -from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens from transformers.models.llama.modeling_llama import get_position_ids_from_cu_seq_lens from transformers.testing_utils import ( cleanup, From 0fb7312ec4f3747da309052e6e39e759c794db5e Mon Sep 17 00:00:00 2001 From: Garrett Goon Date: Tue, 11 Feb 2025 16:28:20 +0000 Subject: [PATCH 5/5] minor import fix --- tests/models/llama/test_modeling_llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index b73c80f10032..f598df1aba9a 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -23,7 +23,6 @@ from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_available, set_seed from transformers.generation.configuration_utils import GenerationConfig from transformers.modeling_flash_attention_utils import FlashAttentionKwargs, get_position_ids_from_cu_seq_lens -from transformers.models.llama.modeling_llama import get_position_ids_from_cu_seq_lens from transformers.testing_utils import ( cleanup, require_flash_attn,
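
Usage sketch: a minimal example of how the padding-free path added by this series is expected to be exercised once the patches are applied. The checkpoint name, sequence lengths, and device below are illustrative assumptions rather than values taken from the patches; FlashAttentionKwargs and get_position_ids_from_cu_seq_lens are imported from transformers.modeling_flash_attention_utils as arranged by PATCH 4/5, and a CUDA device with flash-attn installed is assumed.

import torch

from transformers import AutoModelForCausalLM
from transformers.modeling_flash_attention_utils import (
    FlashAttentionKwargs,
    get_position_ids_from_cu_seq_lens,
)

# Illustrative checkpoint; any Llama-family model that supports flash_attention_2 works the same way.
# Padding-free inputs are a training-time feature, hence .train().
model = (
    AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-1B",
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    )
    .cuda()
    .train()
)

# Two packed sequences of (assumed) lengths 5 and 3, concatenated into a single row of the batch.
seq_lens = torch.tensor([5, 3], dtype=torch.int32, device="cuda")
cu_seq_lens = torch.cat(
    [torch.zeros(1, dtype=torch.int32, device="cuda"), seq_lens.cumsum(-1, dtype=torch.int32)]
)
input_ids = torch.randint(0, model.config.vocab_size, (1, int(seq_lens.sum())), device="cuda")

# Position ids restart at zero for every packed sequence: [[0, 1, 2, 3, 4, 0, 1, 2]].
explicit_position_ids = get_position_ids_from_cu_seq_lens(cu_seq_lens)

flash_attn_kwargs = FlashAttentionKwargs(
    cu_seq_lens_q=cu_seq_lens,
    cu_seq_lens_k=cu_seq_lens,
    max_length_q=int(seq_lens.max()),
    max_length_k=int(seq_lens.max()),
)

# With this series applied, omitting position_ids and passing only FlashAttentionKwargs should
# give the same logits as passing explicit per-sequence position ids: the model now derives
# position ids from cu_seq_lens_q instead of falling back to cache_position, which would apply
# monotonically increasing RoPE positions across the whole packed row.
logits_from_position_ids = model(input_ids=input_ids, position_ids=explicit_position_ids).logits
logits_from_flash_attn_kwargs = model(input_ids=input_ids, **flash_attn_kwargs).logits
torch.testing.assert_close(logits_from_position_ids, logits_from_flash_attn_kwargs)

Deriving position_ids from cu_seq_lens_q keeps RoPE aligned with the sequence boundaries already handed to the FlashAttention varlen kernel, which is the behavior PATCH 1/5 fixes and tests.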