From 85d549a78db98ba8f83a3018f50fc756a8aa5111 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 28 Jun 2024 16:41:16 +0200 Subject: [PATCH 01/52] softcapping --- .../models/gemma2/configuration_gemma2.py | 13 ++++++++++++- src/transformers/models/gemma2/modeling_gemma2.py | 5 +++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py index 74976bdd340f..0febfb7a591b 100644 --- a/src/transformers/models/gemma2/configuration_gemma2.py +++ b/src/transformers/models/gemma2/configuration_gemma2.py @@ -121,6 +121,8 @@ def __init__( rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, + final_logit_softcapping=30.0, + attn_logit_softcapping=50.00, query_pre_attn_scalar=224, sliding_window=4096, final_logit_softcapping=30.0, @@ -149,7 +151,16 @@ def __init__( self.rope_theta = rope_theta self.attention_bias = attention_bias self.attention_dropout = attention_dropout - self.hidden_activation = hidden_activation + self.attn_logit_softcapping = attn_logit_softcapping + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.final_logit_softcapping = final_logit_softcapping self.query_pre_attn_scalar = query_pre_attn_scalar self.sliding_window = sliding_window self.final_logit_softcapping = final_logit_softcapping diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 28f5f5da7ba0..c8fc930e96c7 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -247,6 +247,11 @@ def forward( causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask + if self.config.attn_logit_softcapping is not None: + attn_weights = attn_weights / self.config.attn_logit_softcapping + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * self.config.attn_logit_softcapping + # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) From eba51917476ca0458c8601097db29ec557585ed2 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 28 Jun 2024 16:43:59 +0200 Subject: [PATCH 02/52] soft cap before the mask --- src/transformers/models/gemma2/modeling_gemma2.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index c8fc930e96c7..8e48f4af5595 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -243,15 +243,11 @@ def forward( attn_weights = attn_weights / self.config.attn_logit_softcapping attn_weights = torch.tanh(attn_weights) attn_weights = attn_weights * self.config.attn_logit_softcapping + if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - if self.config.attn_logit_softcapping is not None: - attn_weights = attn_weights / self.config.attn_logit_softcapping - attn_weights = torch.tanh(attn_weights) - attn_weights = attn_weights * self.config.attn_logit_softcapping - # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) From b9e4a54c35e8fe55b03aeb903f3fca5b48ff5759 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 28 Jun 2024 17:02:14 +0200 Subject: [PATCH 03/52] style --- src/transformers/models/gemma2/configuration_gemma2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py index 0febfb7a591b..23bc85dd7d4f 100644 --- a/src/transformers/models/gemma2/configuration_gemma2.py +++ b/src/transformers/models/gemma2/configuration_gemma2.py @@ -80,6 +80,8 @@ class Gemma2Config(PretrainedConfig): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. + final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits. + attn_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the attention scores. query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the size of the sliding window. From 514a839450f5a52c1edc7233d033297509038c27 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 28 Jun 2024 17:08:08 +0200 Subject: [PATCH 04/52] ... --- src/transformers/models/gemma2/configuration_gemma2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py index 23bc85dd7d4f..064a4cc9861d 100644 --- a/src/transformers/models/gemma2/configuration_gemma2.py +++ b/src/transformers/models/gemma2/configuration_gemma2.py @@ -81,7 +81,7 @@ class Gemma2Config(PretrainedConfig): attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits. - attn_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the attention scores. + attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores. query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the size of the sliding window. From 7544febb4700cf4b7efb2655084846d00e73c040 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 28 Jun 2024 17:09:50 +0200 Subject: [PATCH 05/52] super nit --- src/transformers/models/gemma2/configuration_gemma2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py index 064a4cc9861d..a6b74bdd0ac8 100644 --- a/src/transformers/models/gemma2/configuration_gemma2.py +++ b/src/transformers/models/gemma2/configuration_gemma2.py @@ -124,7 +124,7 @@ def __init__( attention_bias=False, attention_dropout=0.0, final_logit_softcapping=30.0, - attn_logit_softcapping=50.00, + attn_logit_softcapping=50.0, query_pre_attn_scalar=224, sliding_window=4096, final_logit_softcapping=30.0, From be1b8c38b39a387c5993cb038b9e3a80e8675642 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 21 Oct 2024 15:02:15 +0200 Subject: [PATCH 06/52] update --- .../models/gemma2/modeling_gemma2.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 8e48f4af5595..c20f828b448e 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -24,6 +24,7 @@ import torch import torch.nn as nn import torch.utils.checkpoint +from torch.nn.attention.flex_attention import flex_attention from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache @@ -460,6 +461,27 @@ def forward( # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. is_causal = True if causal_mask is None and q_len > 1 else False + + def tanh_softcap(score, b, h, q_idx, kv_idx): + soft_cap = self.config.attn_logit_softcapping + return soft_cap * torch.tanh(score / soft_cap) + + # def causal_mask(b, h, q_idx, kv_idx): + # return q_idx >= kv_idx + + attn_output = flex_attention( + query_states, + key_states, + value_states, + # attn_mask=causal_mask, + block_mask=causal_mask, + score_mod=tanh_softcap, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + scale=self.scaling, + ) + + attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, From 0e0511f8e2ef41e67dd163419afb64f9b89a22f4 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 21 Oct 2024 15:50:46 +0200 Subject: [PATCH 07/52] fixes --- .../models/gemma2/configuration_gemma2.py | 15 +-------- .../models/gemma2/modeling_gemma2.py | 33 +------------------ 2 files changed, 2 insertions(+), 46 deletions(-) diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py index a6b74bdd0ac8..74976bdd340f 100644 --- a/src/transformers/models/gemma2/configuration_gemma2.py +++ b/src/transformers/models/gemma2/configuration_gemma2.py @@ -80,8 +80,6 @@ class Gemma2Config(PretrainedConfig): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. - final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits. - attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores. query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the size of the sliding window. @@ -123,8 +121,6 @@ def __init__( rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, - final_logit_softcapping=30.0, - attn_logit_softcapping=50.0, query_pre_attn_scalar=224, sliding_window=4096, final_logit_softcapping=30.0, @@ -153,16 +149,7 @@ def __init__( self.rope_theta = rope_theta self.attention_bias = attention_bias self.attention_dropout = attention_dropout - self.attn_logit_softcapping = attn_logit_softcapping - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - self.final_logit_softcapping = final_logit_softcapping + self.hidden_activation = hidden_activation self.query_pre_attn_scalar = query_pre_attn_scalar self.sliding_window = sliding_window self.final_logit_softcapping = final_logit_softcapping diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index c20f828b448e..2ba31b07eee2 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -443,52 +443,21 @@ def forward( } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - causal_mask = attention_mask if attention_mask is not None: causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - def tanh_softcap(score, b, h, q_idx, kv_idx): soft_cap = self.config.attn_logit_softcapping return soft_cap * torch.tanh(score / soft_cap) - # def causal_mask(b, h, q_idx, kv_idx): - # return q_idx >= kv_idx - attn_output = flex_attention( query_states, key_states, value_states, - # attn_mask=causal_mask, block_mask=causal_mask, score_mod=tanh_softcap, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - scale=self.scaling, - ) - - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, + enable_gqa=True, scale=self.scaling, ) From 03ccc224d2e99d606faf6605e1cd46f4128b13a5 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 21 Oct 2024 15:57:38 +0200 Subject: [PATCH 08/52] update --- .../models/gemma2/modular_gemma2.py | 49 ++++++------------- src/transformers/utils/import_utils.py | 8 +++ 2 files changed, 23 insertions(+), 34 deletions(-) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 9d7f047e1a84..e05152ace4d7 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -30,6 +30,7 @@ is_flash_attn_2_available, is_flash_attn_greater_or_equal, is_flash_attn_greater_or_equal_2_10, + is_torch_greater_or_equal, logging, ) from ..gemma.modeling_gemma import ( @@ -49,6 +50,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_greater_or_equal("2.5"): + from torch.nn.attention.flex_attention import flex_attention + logger = logging.get_logger(__name__) @@ -414,22 +418,6 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Gemma2Model is using Gemma2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -453,40 +441,33 @@ def forward( } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - causal_mask = attention_mask if attention_mask is not None: causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False + def tanh_softcap(score, b, h, q_idx, kv_idx): + soft_cap = self.config.attn_logit_softcapping + return soft_cap * torch.tanh(score / soft_cap) - attn_output = torch.nn.functional.scaled_dot_product_attention( + attn_output = flex_attention( query_states, key_states, value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, + block_mask=causal_mask, + score_mod=tanh_softcap, + enable_gqa=True, scale=self.scaling, + return_lse=output_attentions, ) + if output_attentions: + attn_output, attention_scores = attn_output attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, attention_scores, past_key_value class Gemma2DecoderLayer(GemmaDecoderLayer): diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 173aee9b1ac7..6306efa2face 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -929,6 +929,14 @@ def is_flash_attn_greater_or_equal(library_version: str): return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version) +@lru_cache() +def is_torch_greater_or_equal(library_version: str): + if not _is_package_available("torch"): + return False + + return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version) + + def is_torchdistx_available(): return _torchdistx_available From bdda7245ce686e78a336561359666ddd1b7b759f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 21 Oct 2024 16:08:55 +0200 Subject: [PATCH 09/52] small issue with modular --- .../models/gemma2/configuration_gemma2.py | 11 ++++ .../models/gemma2/modeling_gemma2.py | 50 +++++++++---------- utils/modular_model_converter.py | 4 +- 3 files changed, 39 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py index 74976bdd340f..f286f92316ad 100644 --- a/src/transformers/models/gemma2/configuration_gemma2.py +++ b/src/transformers/models/gemma2/configuration_gemma2.py @@ -22,6 +22,17 @@ from ...configuration_utils import PretrainedConfig +from ...utils import ( + is_flash_attn_2_available, + is_torch_greater_or_equal, +) + + +if is_flash_attn_2_available(): + pass + +if is_torch_greater_or_equal("2.5"): + pass class Gemma2Config(PretrainedConfig): diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 2ba31b07eee2..1c16dfb96b8f 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -19,20 +19,37 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn as nn import torch.utils.checkpoint -from torch.nn.attention.flex_attention import flex_attention from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache -from ...generation import GenerationMixin -from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, +) +from ...utils import ( + is_flash_attn_2_available, + is_flash_attn_greater_or_equal, + is_flash_attn_greater_or_equal_2_10, + is_torch_greater_or_equal, + logging, +) + + +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + +if is_torch_greater_or_equal("2.5"): + from torch.nn.attention.flex_attention import flex_attention +from typing import List + +from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_outputs import ( SequenceClassifierOutputWithPast, TokenClassifierOutput, ) @@ -40,9 +57,6 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal, - is_flash_attn_greater_or_equal_2_10, - logging, replace_return_docstrings, ) from .configuration_gemma2 import Gemma2Config @@ -244,7 +258,6 @@ def forward( attn_weights = attn_weights / self.config.attn_logit_softcapping attn_weights = torch.tanh(attn_weights) attn_weights = attn_weights * self.config.attn_logit_softcapping - if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask @@ -404,22 +417,6 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Gemma2Model is using Gemma2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -459,14 +456,17 @@ def tanh_softcap(score, b, h, q_idx, kv_idx): score_mod=tanh_softcap, enable_gqa=True, scale=self.scaling, + return_lse=output_attentions, ) + if output_attentions: + attn_output, attention_scores = attn_output attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, attention_scores, past_key_value GEMMA2_ATTENTION_CLASSES = { diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index c107a4831862..5091922ca0b3 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -1034,6 +1034,7 @@ def leave_If(self, original_node, node): if re.search(r"[\s\S]*is_.*available", full_statement): self.all_safe_imports.append(node) elif full_statement not in self.all_imports: + self.all_safe_imports.append(node) logger.warning(f"one import is protected with `if`. Hard guess where it's used {full_statement}") return node @@ -1102,6 +1103,7 @@ def _recursively_add_all_new_needed_functions_in_files(self): ) def leave_Module(self, original_node: cst.Module, node): + self.all_imports.extend(self.all_safe_imports) imports = {self.python_module.code_for_node(k): k for k in self.all_imports} dependency_imports = {file_type: imports.copy() for file_type in self.files} for super_file_name, visiter in self.visited_module.items(): @@ -1180,7 +1182,7 @@ def save_modeling_file(modular_file, converted_file): parser = argparse.ArgumentParser() parser.add_argument( "--files_to_parse", - default=["src/transformers/models/roberta/modular_roberta.py"], + default=["src/transformers/models/gemma2/modular_gemma2.py"], nargs="+", help="A list of `modular_xxxx` files that should be converted to single model file", ) From a2b6b12a3e0450c182a49003d278e2c1c095f953 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 21 Oct 2024 17:17:58 +0200 Subject: [PATCH 10/52] fix modular imports --- .../models/gemma2/configuration_gemma2.py | 13 ---- .../models/gemma2/modeling_gemma2.py | 14 ++--- utils/modular_model_converter.py | 59 ++++++++++++------- 3 files changed, 43 insertions(+), 43 deletions(-) diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py index f286f92316ad..45006b8ca2f5 100644 --- a/src/transformers/models/gemma2/configuration_gemma2.py +++ b/src/transformers/models/gemma2/configuration_gemma2.py @@ -19,20 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - from ...configuration_utils import PretrainedConfig -from ...utils import ( - is_flash_attn_2_available, - is_torch_greater_or_equal, -) - - -if is_flash_attn_2_available(): - pass - -if is_torch_greater_or_equal("2.5"): - pass class Gemma2Config(PretrainedConfig): diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 1c16dfb96b8f..11430bc2b067 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -19,13 +19,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union - -import torch -import torch.nn as nn -import torch.utils.checkpoint - -from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -45,8 +38,13 @@ if is_torch_greater_or_equal("2.5"): from torch.nn.attention.flex_attention import flex_attention -from typing import List +from typing import List, Optional, Tuple, Union +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 5091922ca0b3..6235c9564a1f 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -663,8 +663,10 @@ def get_new_part(class_name, base_class): def find_all_dependencies(function: str, dependency_mapping: Dict[str, set]): """Return all the dependencies of the given top-level function. Given the following structure in the `modular_xxx.py` file: ``` + from time import now + def foo1(): - pass + return now() def foo2(): pass @@ -682,12 +684,12 @@ def forward(...): ``` and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get: ``` - dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}} + dependency_mapping = {'foo1': {'now'}, 'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}} find_all_dependencies('foobar', dependency_mapping) >>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')] ``` That is, all the functions needed (and their immediate parent) so that the function to be added in MyLayer (`foobar`) can - work correctly. + work correctly. Plus the nodes that import the dependencies. """ all_dependencies = deque(dependency_mapping[function]) all_dependencies_with_parent = [(dep, function) for dep in dependency_mapping[function]] @@ -707,8 +709,8 @@ def forward(...): class PostModularConverterCleaner(CSTTransformer): - """Allow simple cleaning after conversion. Remove top-level functions/classes without any calls (they may arise due - to dependency mapping, even if code parts with those functions/classes were overwritten)""" + """Allow simple cleaning after conversion. Removes top-level functions/classes that are defined, but not called. + (this may happen due to dependency mapping, even if code parts with those functions/classes were overwritten)""" METADATA_DEPENDENCIES = (ParentNodeProvider,) @@ -769,8 +771,7 @@ def __init__(self, python_module, new_name, given_old_name=None, given_new_name= self.imported_mapping = {} # stores the name of the imported classes, with their source {"LlamaModel":"transformers.model.llama.modeling_llama"} self.visited_module = {} # modules visited like "transformers.models.llama.modeling_llama" self.inserted_deps = [] # nodes inserted via super dependency - self.all_imports = [] # just stores all of the imports - self.all_safe_imports = [] # stores the import under simple statements + self.all_imports = {} # just stores all of the imports self.global_scope_index = 0 # fmt: on self.files = { # mapping for different component bodies @@ -828,8 +829,10 @@ def leave_SimpleStatementLine(self, original_node, updated_node): parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) if m.matches(parent_node, m.Module()): if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])): - if updated_node not in self.all_imports: - self.all_imports.append(updated_node) + for k in updated_node.body[0].names: + if k not in self.all_imports: + import_name = self.python_module.code_for_node(k.name) + self.all_imports[import_name] = updated_node return updated_node elif m.matches(updated_node, m.SimpleStatementLine(body=[m.ImportFrom()])): full_statement = self.python_module.code_for_node(updated_node.body[0].module) @@ -837,8 +840,10 @@ def leave_SimpleStatementLine(self, original_node, updated_node): rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", full_statement ): # OR MATCH ..llama.modeling_llama return cst.RemoveFromParent() - if updated_node not in self.all_imports: - self.all_imports.append(updated_node) + for k in updated_node.body[0].names: + if k not in self.all_imports: + import_name = self.python_module.code_for_node(k.name) + self.all_imports[import_name] = updated_node return updated_node elif m.matches(original_node, m.SimpleStatementLine(body=[m.Assign()])): if original_node.body[0].targets[0].target.value in ASSIGNMENTS_TO_KEEP.keys(): @@ -853,6 +858,9 @@ def leave_SimpleStatementLine(self, original_node, updated_node): def visit_ClassDef(self, node: cst.ClassDef): """Used to keep track of current class""" self.current_class = node.name.value + for k in node.bases: + if isinstance(k.value, cst.Name): + self.function_call_class_mapping[k.value.value].add(self.current_class) def leave_ClassDef(self, original_node, updated_node): """ @@ -1030,12 +1038,9 @@ def visit_Assign(self, node: cst.Assign) -> None: def leave_If(self, original_node, node): parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) if m.matches(parent_node, m.Module()): - full_statement = self.python_module.code_for_node(original_node.test) - if re.search(r"[\s\S]*is_.*available", full_statement): - self.all_safe_imports.append(node) - elif full_statement not in self.all_imports: - self.all_safe_imports.append(node) - logger.warning(f"one import is protected with `if`. Hard guess where it's used {full_statement}") + for k in node.body.body[0].body[0].names: + import_name = self.python_module.code_for_node(k.name) + self.all_imports[import_name] = node return node def visit_Call(self, node: cst.Call): @@ -1081,9 +1086,12 @@ def _maybe_add_function_to_body( return False def _recursively_add_all_new_needed_functions_in_files(self): - """For all top-level functions which were newly defined in the `modular_xxx.py`, check if they are used in a class in + r"""For all top-level functions which were newly defined in the `modular_xxx.py`, check if they are used in a class in the different files, and add them to the file if it is the case (also recursively adding all other functions that - may be needed in that function body).""" + may be needed in that function body). + + Also takes care of sorting which imports are needed for this file. + """ # At this point, `self.all_definitions` only contains newly defined top-level functions in the `modualr_xxx.py` for top_level_function, function_node in self.all_definitions.items(): calling_entities = self.function_call_class_mapping[top_level_function] @@ -1102,10 +1110,17 @@ def _recursively_add_all_new_needed_functions_in_files(self): dependency, body, self.all_definitions[dependency], parent=parent ) + def _filter_imports_for_file(self, file_name, imports): + _dict = {} + for key, value in imports.items(): + if key in self.function_call_class_mapping: + node = self.function_call_class_mapping[key] + if len(node) == 1 and node.copy().pop() in self.files[file_name]: + _dict[key] = value + return _dict + def leave_Module(self, original_node: cst.Module, node): - self.all_imports.extend(self.all_safe_imports) - imports = {self.python_module.code_for_node(k): k for k in self.all_imports} - dependency_imports = {file_type: imports.copy() for file_type in self.files} + dependency_imports = {file_type: self._filter_imports_for_file(file_type, self.all_imports.copy()) for file_type in self.files} for super_file_name, visiter in self.visited_module.items(): file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0] dependency_imports[file_type].update( From 9365c1b71de0e127ffcfdf6f7990bed5815314c3 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 21 Oct 2024 17:33:11 +0200 Subject: [PATCH 11/52] update --- .../models/gemma/configuration_gemma.py | 2 -- src/transformers/models/glm/modeling_glm.py | 2 +- .../configuration_instructblipvideo.py | 1 + .../modeling_instructblipvideo.py | 1 + utils/modular_model_converter.py | 27 ++++++++++--------- 5 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py index e170803cccab..75d0096d4811 100644 --- a/src/transformers/models/gemma/configuration_gemma.py +++ b/src/transformers/models/gemma/configuration_gemma.py @@ -19,8 +19,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - from ...configuration_utils import PretrainedConfig diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index a458c02a6fed..c0d3767164c3 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -23,8 +23,8 @@ from typing import List, Optional, Tuple, Union import torch -import torch.nn as nn import torch.utils.checkpoint +from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache diff --git a/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py b/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py index e7c8eeccef98..2a0f8d8a647f 100644 --- a/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py @@ -19,6 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import os from typing import Union diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index a300268ed713..9877a079b8d8 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -19,6 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import math from dataclasses import dataclass from typing import Any, Optional, Tuple, Union diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 6235c9564a1f..54855a6e066b 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -328,9 +328,9 @@ def __init__(self, all_bases: Set[str]): def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode: # Handle ClassB.call_to_method if ( - isinstance(original_node.value, cst.Name) + m.matches(original_node.value, m.Name()) and original_node.value.value in self.all_bases - and isinstance(original_node.attr, cst.Name) + and m.matches(original_node.attr, m.Name()) ): # Replace with super().call_to_method return updated_node.with_changes( @@ -338,10 +338,10 @@ def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attrib ) # Handle ClassB().call_to_method elif ( - isinstance(original_node.value, cst.Call) - and isinstance(original_node.value.func, cst.Name) + m.matches(original_node.value, m.Call()) + and m.matches(original_node.value.func, m.Name()) and original_node.value.func.value in self.all_bases - and isinstance(original_node.attr, cst.Name) + and m.matches(original_node.attr, m.Name()) ): # Replace with super().call_to_method return updated_node.with_changes(func=cst.Attribute(value=cst.Call(func=cst.Name("super")))) @@ -349,16 +349,16 @@ def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attrib def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode: # Check if the function being called is of the form ClassB().func_a or ClassB.func_a - if isinstance(original_node.func, cst.Attribute) and ( + if m.matches(original_node.func, m.Attribute()) and ( # Match ClassB().func_a(...) ( - isinstance(original_node.func.value, cst.Call) - and isinstance(original_node.func.value.func, cst.Name) + m.matches(original_node.func.value, m.Call()) + and m.matches(original_node.func.value.func, m.Name()) and original_node.func.value.func.value in self.all_bases ) or # Match ClassB.func_a(...) - (isinstance(original_node.func.value, cst.Name) and original_node.func.value.value in self.all_bases) + (m.matches(original_node.func.value, m.Name()) and original_node.func.value.value in self.all_bases) ): # Check if the first argument is 'self', and remove it if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")): @@ -1050,11 +1050,14 @@ def visit_Call(self, node: cst.Call): # Only map function calls if we're inside a class (i.e., current_class is set) if self.current_class is not None: # Simple function calls such as foo() - if isinstance(node.func, cst.Name): + if m.matches(node.func, m.Name()): self.function_call_class_mapping[node.func.value].add(self.current_class) + if m.matches(node.func, m.Attribute()|m.Subscript()): + _code = self.python_module.code_for_node(node.func.value) + self.function_call_class_mapping[_code].add(self.current_class) elif self.current_top_level_function is not None: # Simple function calls such as foo() - if isinstance(node.func, cst.Name): + if m.matches(node.func, m.Name()): self.function_call_dependency_mapping[self.current_top_level_function].add(node.func.value) def _maybe_add_function_to_body( @@ -1197,7 +1200,7 @@ def save_modeling_file(modular_file, converted_file): parser = argparse.ArgumentParser() parser.add_argument( "--files_to_parse", - default=["src/transformers/models/gemma2/modular_gemma2.py"], + default=["src/transformers/models/llava_next_video/modular_llava_next_video.py"], nargs="+", help="A list of `modular_xxxx` files that should be converted to single model file", ) From 2108ee3ae4d409441470f472fc05ceba4e04204c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 21 Oct 2024 17:34:05 +0200 Subject: [PATCH 12/52] fixup --- src/transformers/utils/__init__.py | 1 + utils/modular_model_converter.py | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index a781389c2fbd..712f5e487f44 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -208,6 +208,7 @@ is_torch_fp16_available_on_device, is_torch_fx_available, is_torch_fx_proxy, + is_torch_greater_or_equal, is_torch_mlu_available, is_torch_mps_available, is_torch_musa_available, diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 54855a6e066b..c53e20b3f735 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -1052,7 +1052,7 @@ def visit_Call(self, node: cst.Call): # Simple function calls such as foo() if m.matches(node.func, m.Name()): self.function_call_class_mapping[node.func.value].add(self.current_class) - if m.matches(node.func, m.Attribute()|m.Subscript()): + if m.matches(node.func, m.Attribute() | m.Subscript()): _code = self.python_module.code_for_node(node.func.value) self.function_call_class_mapping[_code].add(self.current_class) elif self.current_top_level_function is not None: @@ -1118,12 +1118,14 @@ def _filter_imports_for_file(self, file_name, imports): for key, value in imports.items(): if key in self.function_call_class_mapping: node = self.function_call_class_mapping[key] - if len(node) == 1 and node.copy().pop() in self.files[file_name]: + if len(node) == 1 and node.copy().pop() in self.files[file_name]: _dict[key] = value return _dict def leave_Module(self, original_node: cst.Module, node): - dependency_imports = {file_type: self._filter_imports_for_file(file_type, self.all_imports.copy()) for file_type in self.files} + dependency_imports = { + file_type: self._filter_imports_for_file(file_type, self.all_imports.copy()) for file_type in self.files + } for super_file_name, visiter in self.visited_module.items(): file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0] dependency_imports[file_type].update( From 520120a12d18cb2a5e828d357bacaf6c14750eaa Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 21 Oct 2024 17:54:56 +0200 Subject: [PATCH 13/52] simplify a hell lot --- utils/modular_model_converter.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index c53e20b3f735..cbfe027b84e0 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -788,9 +788,9 @@ def __init__(self, python_module, new_name, given_old_name=None, given_new_name= self.current_class = None # keep track of current top-level class during visit self.current_top_level_function = None # keep track of current top-level function during visit # Mapping from top-level functions to classes using them - self.function_call_class_mapping = defaultdict(lambda: set()) + self.function_call_class_mapping = defaultdict(set) # Mapping from top-level functions to other top-level functions dependencies - self.function_call_dependency_mapping = defaultdict(lambda: set()) + self.function_call_dependency_mapping = defaultdict(set) self.added_dependencies = set() def visit_ImportFrom(self, node: cst.ImportFrom) -> None: @@ -1113,19 +1113,18 @@ def _recursively_add_all_new_needed_functions_in_files(self): dependency, body, self.all_definitions[dependency], parent=parent ) - def _filter_imports_for_file(self, file_name, imports): - _dict = {} + def _filter_imports_for_file(self, imports): + _dict = defaultdict(dict) for key, value in imports.items(): if key in self.function_call_class_mapping: node = self.function_call_class_mapping[key] - if len(node) == 1 and node.copy().pop() in self.files[file_name]: - _dict[key] = value + if len(node) == 1: + file_name = self.class_to_file_type[node.copy().pop()] + _dict[file_name][key] = value return _dict def leave_Module(self, original_node: cst.Module, node): - dependency_imports = { - file_type: self._filter_imports_for_file(file_type, self.all_imports.copy()) for file_type in self.files - } + dependency_imports = self._filter_imports_for_file(self.all_imports.copy()) for super_file_name, visiter in self.visited_module.items(): file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0] dependency_imports[file_type].update( From 314ed1f4b6e9ccb47891762a3aef29acd35d43b3 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 22 Oct 2024 09:26:42 +0200 Subject: [PATCH 14/52] simplify cleaning imports --- .../models/gemma2/modeling_gemma2.py | 29 +++--- utils/modular_model_converter.py | 89 +++++++++++++------ 2 files changed, 71 insertions(+), 47 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 11430bc2b067..afa5301a5968 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -19,11 +19,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...utils import ( is_flash_attn_2_available, is_flash_attn_greater_or_equal, @@ -38,25 +41,13 @@ if is_torch_greater_or_equal("2.5"): from torch.nn.attention.flex_attention import flex_attention -from typing import List, Optional, Tuple, Union +from typing import List -import torch -import torch.utils.checkpoint -from torch import nn - -from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_flash_attention_utils import _flash_attention_forward -from ...modeling_outputs import ( - SequenceClassifierOutputWithPast, - TokenClassifierOutput, -) +from ...modeling_outputs import SequenceClassifierOutputWithPast, TokenClassifierOutput from ...modeling_utils import PreTrainedModel -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings from .configuration_gemma2 import Gemma2Config diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index cbfe027b84e0..95a21affbd7b 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -18,7 +18,7 @@ import os import re from collections import defaultdict, deque -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Set, Union import libcst as cst from check_copies import run_ruff @@ -708,17 +708,18 @@ def forward(...): return all_dependencies_with_parent -class PostModularConverterCleaner(CSTTransformer): +class PostModularConverterCleaner(m.MatcherDecoratableTransformer): """Allow simple cleaning after conversion. Removes top-level functions/classes that are defined, but not called. (this may happen due to dependency mapping, even if code parts with those functions/classes were overwritten)""" METADATA_DEPENDENCIES = (ParentNodeProvider,) - def __init__(self, added_dependencies: set): + def __init__(self, added_dependencies: set, unused_imports:Dict[Union[cst.Import, cst.ImportFrom], Set[str]]): super().__init__() self.top_level_functions_or_classes = {} self.all_used_functions_or_classes = set() self.added_dependencies = added_dependencies + self.unused_imports = unused_imports def visit_FunctionDef(self, node): parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) @@ -750,11 +751,52 @@ def leave_Module(self, original_node: cst.Module, node): nodes_to_remove = [ self.top_level_functions_or_classes[name] for name in unused if name in self.top_level_functions_or_classes ] - new_body = [node_ for node_ in original_node.body if node_ not in nodes_to_remove] + new_body = [node_ for node_ in node.body if node_ not in nodes_to_remove] # Return a new module with the updated body return node.with_changes(body=new_body) + def leave_If(self,original_node: cst.If,updated_node: cst.If): + for stmt in original_node.body.body: + if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])): + if len(updated_node.body.body) == 0: + return cst.RemoveFromParent() + return updated_node + + @m.leave(m.Import() | m.ImportFrom()) + def leave_import_alike(self, original_node, updated_node): + names_to_keep = [] + for name in updated_node.names: + name_value = name.evaluated_name + if name_value not in self.unused_imports: + names_to_keep.append(name.with_changes(comma=cst.MaybeSentinel.DEFAULT)) + if len(names_to_keep) == 0: + return cst.RemoveFromParent() + else: + return updated_node.with_changes(names=names_to_keep) + + + +def get_unused_imports(source): + wrapper = cst.metadata.MetadataWrapper(source) + scopes = set(wrapper.resolve(cst.metadata.ScopeProvider).values()) + unused_imports: Dict[Union[cst.Import, cst.ImportFrom], Set[str]] = defaultdict(set) + ranges = wrapper.resolve(cst.metadata.PositionProvider) + for scope in scopes: + for assignment in scope.assignments: + node = assignment.node + if isinstance(assignment, cst.metadata.Assignment) and isinstance( + node, (cst.Import, cst.ImportFrom) + ): + if len(assignment.references) == 0: + unused_imports[assignment.name].add(node) + location = ranges[node].start + print( + f"Warning on line {location.line:2d}, column {location.column:2d}: Imported name `{assignment.name}` is unused." + ) + return unused_imports + + class ModularConverterTransformer(CSTTransformer): METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider) @@ -788,7 +830,7 @@ def __init__(self, python_module, new_name, given_old_name=None, given_new_name= self.current_class = None # keep track of current top-level class during visit self.current_top_level_function = None # keep track of current top-level function during visit # Mapping from top-level functions to classes using them - self.function_call_class_mapping = defaultdict(set) + self.callable_dependency_mapping = defaultdict(set) # Mapping from top-level functions to other top-level functions dependencies self.function_call_dependency_mapping = defaultdict(set) self.added_dependencies = set() @@ -830,9 +872,8 @@ def leave_SimpleStatementLine(self, original_node, updated_node): if m.matches(parent_node, m.Module()): if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])): for k in updated_node.body[0].names: - if k not in self.all_imports: - import_name = self.python_module.code_for_node(k.name) - self.all_imports[import_name] = updated_node + import_name = self.python_module.code_for_node(k.name) + self.all_imports[import_name] = updated_node return updated_node elif m.matches(updated_node, m.SimpleStatementLine(body=[m.ImportFrom()])): full_statement = self.python_module.code_for_node(updated_node.body[0].module) @@ -841,9 +882,8 @@ def leave_SimpleStatementLine(self, original_node, updated_node): ): # OR MATCH ..llama.modeling_llama return cst.RemoveFromParent() for k in updated_node.body[0].names: - if k not in self.all_imports: - import_name = self.python_module.code_for_node(k.name) - self.all_imports[import_name] = updated_node + import_name = self.python_module.code_for_node(k.name) + self.all_imports[import_name] = updated_node return updated_node elif m.matches(original_node, m.SimpleStatementLine(body=[m.Assign()])): if original_node.body[0].targets[0].target.value in ASSIGNMENTS_TO_KEEP.keys(): @@ -860,7 +900,7 @@ def visit_ClassDef(self, node: cst.ClassDef): self.current_class = node.name.value for k in node.bases: if isinstance(k.value, cst.Name): - self.function_call_class_mapping[k.value.value].add(self.current_class) + self.callable_dependency_mapping[k.value.value].add(self.current_class) def leave_ClassDef(self, original_node, updated_node): """ @@ -1051,10 +1091,10 @@ def visit_Call(self, node: cst.Call): if self.current_class is not None: # Simple function calls such as foo() if m.matches(node.func, m.Name()): - self.function_call_class_mapping[node.func.value].add(self.current_class) + self.callable_dependency_mapping[node.func.value].add(self.current_class) if m.matches(node.func, m.Attribute() | m.Subscript()): _code = self.python_module.code_for_node(node.func.value) - self.function_call_class_mapping[_code].add(self.current_class) + self.callable_dependency_mapping[_code].add(self.current_class) elif self.current_top_level_function is not None: # Simple function calls such as foo() if m.matches(node.func, m.Name()): @@ -1097,7 +1137,7 @@ def _recursively_add_all_new_needed_functions_in_files(self): """ # At this point, `self.all_definitions` only contains newly defined top-level functions in the `modualr_xxx.py` for top_level_function, function_node in self.all_definitions.items(): - calling_entities = self.function_call_class_mapping[top_level_function] + calling_entities = self.callable_dependency_mapping[top_level_function] # The function may be needed in different files, we need to iterate on them for file, body in self.files.items(): file_elements = set(body.keys()) @@ -1113,18 +1153,9 @@ def _recursively_add_all_new_needed_functions_in_files(self): dependency, body, self.all_definitions[dependency], parent=parent ) - def _filter_imports_for_file(self, imports): - _dict = defaultdict(dict) - for key, value in imports.items(): - if key in self.function_call_class_mapping: - node = self.function_call_class_mapping[key] - if len(node) == 1: - file_name = self.class_to_file_type[node.copy().pop()] - _dict[file_name][key] = value - return _dict - def leave_Module(self, original_node: cst.Module, node): - dependency_imports = self._filter_imports_for_file(self.all_imports.copy()) + imports = {self.python_module.code_for_node(k): k for k in self.all_imports.values()} + dependency_imports = {file_type: imports.copy() for file_type in self.files} for super_file_name, visiter in self.visited_module.items(): file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0] dependency_imports[file_type].update( @@ -1142,7 +1173,9 @@ def leave_Module(self, original_node: cst.Module, node): new_body = list(dependency_imports[file].values()) + new_body new_module = cst.Module(body=[*new_body], header=node.header) # Final cleanup - new_module = MetadataWrapper(new_module).visit(PostModularConverterCleaner(self.added_dependencies)) + unused_imports = get_unused_imports(new_module) + cleaner = PostModularConverterCleaner(self.added_dependencies, unused_imports) + new_module = MetadataWrapper(new_module).visit(cleaner) self.files[file] = new_module return node @@ -1201,7 +1234,7 @@ def save_modeling_file(modular_file, converted_file): parser = argparse.ArgumentParser() parser.add_argument( "--files_to_parse", - default=["src/transformers/models/llava_next_video/modular_llava_next_video.py"], + default=["src/transformers/models/gemma2/modular_gemma2.py"], nargs="+", help="A list of `modular_xxxx` files that should be converted to single model file", ) From 88304738944154c62cdbb58a5c3c8761f2234781 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 22 Oct 2024 13:45:19 +0200 Subject: [PATCH 15/52] finish fixing --- .../models/gemma/configuration_gemma.py | 1 + .../models/gemma/modeling_gemma.py | 1 - .../models/gemma2/modeling_gemma2.py | 24 +++++++++++-------- src/transformers/models/glm/modeling_glm.py | 3 +-- .../configuration_instructblipvideo.py | 1 - .../modeling_instructblipvideo.py | 2 -- .../modeling_llava_next_video.py | 8 +------ 7 files changed, 17 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py index 75d0096d4811..346f386ba698 100644 --- a/src/transformers/models/gemma/configuration_gemma.py +++ b/src/transformers/models/gemma/configuration_gemma.py @@ -19,6 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from ...configuration_utils import PretrainedConfig diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 43882e7f8c05..d4e6872ece41 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -23,7 +23,6 @@ from typing import List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index afa5301a5968..9dfb2619587a 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -19,21 +19,33 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal, is_flash_attn_greater_or_equal_2_10, is_torch_greater_or_equal, logging, + replace_return_docstrings, ) +from .configuration_gemma2 import Gemma2Config if is_flash_attn_2_available(): @@ -41,14 +53,6 @@ if is_torch_greater_or_equal("2.5"): from torch.nn.attention.flex_attention import flex_attention -from typing import List - -from ...generation import GenerationMixin -from ...modeling_flash_attention_utils import _flash_attention_forward -from ...modeling_outputs import SequenceClassifierOutputWithPast, TokenClassifierOutput -from ...modeling_utils import PreTrainedModel -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings -from .configuration_gemma2 import Gemma2Config class Gemma2RMSNorm(nn.Module): diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index c0d3767164c3..484b16d314d5 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -23,8 +23,7 @@ from typing import List, Optional, Tuple, Union import torch -import torch.utils.checkpoint -from torch import nn +import torch.nn as nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache diff --git a/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py b/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py index 2a0f8d8a647f..e7c8eeccef98 100644 --- a/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py @@ -19,7 +19,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import os from typing import Union diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index 9877a079b8d8..19e96c54230e 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -19,13 +19,11 @@ # See the License for the specific language governing permissions and # limitations under the License. - import math from dataclasses import dataclass from typing import Any, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 3fd6bb47fc76..fbfd37291c88 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -25,7 +25,6 @@ import numpy as np import torch -import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN @@ -33,12 +32,7 @@ from ...image_processing_utils import select_best_resolution from ...modeling_outputs import ModelOutput from ...modeling_utils import PreTrainedModel -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ..auto import AutoModel, AutoModelForCausalLM from .configuration_llava_next_video import LlavaNextVideoConfig From e4c19d7a6c6e41ed1beb9290d610b9c92222058a Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 22 Oct 2024 14:29:17 +0200 Subject: [PATCH 16/52] update our design --- .../models/gemma2/modular_gemma2.py | 386 ++++++++---------- utils/modular_model_converter.py | 61 +-- 2 files changed, 199 insertions(+), 248 deletions(-) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index e05152ace4d7..0192e39bea4e 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -29,7 +29,6 @@ from ...utils import ( is_flash_attn_2_available, is_flash_attn_greater_or_equal, - is_flash_attn_greater_or_equal_2_10, is_torch_greater_or_equal, logging, ) @@ -209,118 +208,183 @@ def forward(self, x): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) -class Gemma2Attention(GemmaAttention): +def eager_attention_forward(config, query, key, value, mask): + key_states = repeat_kv(key, config.num_key_value_groups) + value_states = repeat_kv(value, config.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling + + if config.attn_logit_softcapping is not None: + attn_weights = attn_weights / config.attn_logit_softcapping + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * config.attn_logit_softcapping + if mask is not None: # no matter the length, we just slice it + causal_mask = mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) + attn_output = torch.matmul(attn_weights, value_states) + return attn_output + + + +def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16): + if mask is not None: + seq_len = mask.shape[1] + query = query[:, :, :seq_len] + value = value[:, :, :seq_len] + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding + query_states = query.transpose(1, 2) + key_states = key.transpose(1, 2) + value_states = value.transpose(1, 2) + + dropout_rate = config.attention_dropout if config.training else 0.0 + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + mask, + seq_len, + dropout=dropout_rate, + softmax_scale=config.scaling, + is_causal=config.is_causal, + sliding_window=config.sliding_window, + use_top_left_mask=config._flash_attn_uses_top_left_mask, + softcap=config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, + ) + + return attn_output + + +def flex_attention_forward(config, query, key, value, mask, output_attentions=False, target_dtype=torch.float16): + causal_mask = mask + if mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + def tanh_softcap(score, b, h, q_idx, kv_idx): + soft_cap = config.attn_logit_softcapping + return soft_cap * torch.tanh(score / soft_cap) + + attn_output = flex_attention( + query, + key, + value, + block_mask=causal_mask, + score_mod=tanh_softcap, + enable_gqa=True, + scale=config.scaling, + return_lse=output_attentions, + ) + return attn_output + + +def sdpa_attention_forward(config, query, key, value, mask, output_attentions=False, target_dtype=torch.float16): + key = repeat_kv(key, config.num_key_value_groups) + value = repeat_kv(value, config.num_key_value_groups) + + causal_mask = mask + if mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query.device.type == "cuda" and causal_mask is not None: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and query.shape[1] > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=config.attention_dropout if config.training else 0.0, + is_causal=is_causal, + scale=config.scaling, + ) + return attn_output + + +GEMMA_ATTENTION_FUNCTION = { + "flash_attention": flash_attention_forward, + "flex_attention": flex_attention_forward, + "eager": eager_attention_forward, + "sdpa": sdpa_attention_forward, +} + + +class Gemma2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): - super().__init__(config, layer_idx) - self.scaling = config.query_pre_attn_scalar**-0.5 - self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "sliding_window": self.sliding_window, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + super().__init__() + self.config = config + self.layer_idx = layer_idx - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.scaling = 1 / math.sqrt(config.head_dim) - if self.config.attn_logit_softcapping is not None: - attn_weights = attn_weights / self.config.attn_logit_softcapping - attn_weights = torch.tanh(attn_weights) - attn_weights = attn_weights * self.config.attn_logit_softcapping - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask + self.scaling = config.query_pre_attn_scalar**-0.5 + self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None + self.attention_type = config.attn_implementation + self.attention_function = GEMMA_ATTENTION_FUNCTION[config.attn_implementation] - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + if self.hidden_size % self.num_heads != 0: raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." ) - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Gemma2FlashAttention2(Gemma2Attention): - """ - Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self.rotary_emb = Gemma2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -338,57 +402,8 @@ def forward( } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if attention_mask is not None: - seq_len = attention_mask.shape[1] - key_states = key_states[:, :, :seq_len] - value_states = value_states[:, :, :seq_len] - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (Gemma2RMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - softmax_scale=self.scaling, - is_causal=self.is_causal, - sliding_window=self.sliding_window, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, + attn_output = self.attention_function( + self, query_states, key_states, value_states, attention_mask, self.config ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() @@ -399,83 +414,18 @@ def forward( return attn_output, attn_weights, past_key_value - -class Gemma2SdpaAttention(Gemma2Attention): - """ - Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Gemma2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Gemma2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "sliding_window": self.sliding_window, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - def tanh_softcap(score, b, h, q_idx, kv_idx): - soft_cap = self.config.attn_logit_softcapping - return soft_cap * torch.tanh(score / soft_cap) - - attn_output = flex_attention( - query_states, - key_states, - value_states, - block_mask=causal_mask, - score_mod=tanh_softcap, - enable_gqa=True, - scale=self.scaling, - return_lse=output_attentions, - ) - if output_attentions: - attn_output, attention_scores = attn_output - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, attention_scores, past_key_value - - -class Gemma2DecoderLayer(GemmaDecoderLayer): +class Gemma2DecoderLayer(nn.Module): def __init__(self, config: Gemma2Config, layer_idx: int): - super().__init__(config, layer_idx) + super().__init__() + self.hidden_size = config.hidden_size self.config = config self.is_sliding = not bool(layer_idx % 2) + self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx) self.mlp = Gemma2MLP(config) + self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.sliding_window = config.sliding_window @@ -541,20 +491,6 @@ def forward( class Gemma2PreTrainedModel(GemmaPreTrainedModel): _supports_quantized_cache = False - @classmethod - def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False): - """ - Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models. - SDPA reduces the model performance on Gemma2 because of the logits softcapping. - """ - config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only) - - # if using the default path -> swap sdpa by eager - if not hard_check_only and config._attn_implementation == "sdpa": - config._attn_implementation = "eager" - - return config - class Gemma2Model(GemmaModel, Gemma2PreTrainedModel): def __init__(self, config: Gemma2Config): diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 95a21affbd7b..e226041f3891 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -714,7 +714,7 @@ class PostModularConverterCleaner(m.MatcherDecoratableTransformer): METADATA_DEPENDENCIES = (ParentNodeProvider,) - def __init__(self, added_dependencies: set, unused_imports:Dict[Union[cst.Import, cst.ImportFrom], Set[str]]): + def __init__(self, added_dependencies: set, unused_imports: Dict[Union[cst.Import, cst.ImportFrom], Set[str]]): super().__init__() self.top_level_functions_or_classes = {} self.all_used_functions_or_classes = set() @@ -755,8 +755,7 @@ def leave_Module(self, original_node: cst.Module, node): # Return a new module with the updated body return node.with_changes(body=new_body) - - def leave_If(self,original_node: cst.If,updated_node: cst.If): + def leave_If(self, original_node: cst.If, updated_node: cst.If): for stmt in original_node.body.body: if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])): if len(updated_node.body.body) == 0: @@ -776,8 +775,10 @@ def leave_import_alike(self, original_node, updated_node): return updated_node.with_changes(names=names_to_keep) - def get_unused_imports(source): + r""" + You have to use `isinstance` on assignements, m.matches apparently does not work here yet! + """ wrapper = cst.metadata.MetadataWrapper(source) scopes = set(wrapper.resolve(cst.metadata.ScopeProvider).values()) unused_imports: Dict[Union[cst.Import, cst.ImportFrom], Set[str]] = defaultdict(set) @@ -785,15 +786,18 @@ def get_unused_imports(source): for scope in scopes: for assignment in scope.assignments: node = assignment.node - if isinstance(assignment, cst.metadata.Assignment) and isinstance( - node, (cst.Import, cst.ImportFrom) - ): + if isinstance(assignment, cst.metadata.Assignment) and isinstance(node, (cst.Import, cst.ImportFrom)): if len(assignment.references) == 0: unused_imports[assignment.name].add(node) location = ranges[node].start print( f"Warning on line {location.line:2d}, column {location.column:2d}: Imported name `{assignment.name}` is unused." ) + if isinstance(scope, cst.metadata.GlobalScope): + for assignment in scope.assignments: + node = assignment.node + if assignment.references == 0: + print(f"Warning, {assignment.name} is never referenced") return unused_imports @@ -814,6 +818,7 @@ def __init__(self, python_module, new_name, given_old_name=None, given_new_name= self.visited_module = {} # modules visited like "transformers.models.llama.modeling_llama" self.inserted_deps = [] # nodes inserted via super dependency self.all_imports = {} # just stores all of the imports + self.all_safe_imports = {} # stores the safe imports to place them at the end self.global_scope_index = 0 # fmt: on self.files = { # mapping for different component bodies @@ -834,6 +839,7 @@ def __init__(self, python_module, new_name, given_old_name=None, given_new_name= # Mapping from top-level functions to other top-level functions dependencies self.function_call_dependency_mapping = defaultdict(set) self.added_dependencies = set() + self.original_nodes: Dict[str, cst.ClassDef] = {} # Stores the original class def nodes def visit_ImportFrom(self, node: cst.ImportFrom) -> None: """When visiting imports from `transformers.models.xxx` we need to: @@ -867,6 +873,14 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: f"You are importing from {import_statement} directly using global imports. Import from the correct local path" ) + def leave_If(self, original_node, node): + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) + if m.matches(parent_node, m.Module()): + for k in node.body.body[0].body[0].names: + import_name = self.python_module.code_for_node(k.name) + self.all_safe_imports[import_name] = node + return node + def leave_SimpleStatementLine(self, original_node, updated_node): parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) if m.matches(parent_node, m.Module()): @@ -914,6 +928,7 @@ def leave_ClassDef(self, original_node, updated_node): 3. Replace the calls to `super().xxxx` merging parent code """ class_name = original_node.name.value + self.original_nodes[class_name] = original_node bases = [k.value.value for k in original_node.bases if k.value.value in self.imported_mapping] all_bases = [k.value.value for k in original_node.bases] self.global_scope_index += 100 @@ -997,7 +1012,6 @@ def leave_ClassDef(self, original_node, updated_node): list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True) start_insert_idx = self.global_scope_index file_to_update = self.files[file_type] - is_empty_node = self.python_module.code_for_node(original_node.body) == "pass\n" for dependency, _ in list_dependencies: # we can write to the correct body, using the source of the parent class node = class_finder.global_nodes.get(dependency, None) @@ -1008,6 +1022,8 @@ def leave_ClassDef(self, original_node, updated_node): file_to_update[dependency] = {"insert_idx": start_insert_idx, "node": node} self.added_dependencies.add(dependency) elif dependency not in self.inserted_deps: + # if the dependency is defined in the modular file, but is just `pass` + is_empty_node = self.python_module.code_for_node(self.original_nodes[dependency].body).strip(" \n") == "pass" # make sure the node is written after its dependencies start_insert_idx = file_to_update[dependency]["insert_idx"] - 1 if ( @@ -1075,14 +1091,6 @@ def visit_Assign(self, node: cst.Assign) -> None: "node": updated_node, } - def leave_If(self, original_node, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) - if m.matches(parent_node, m.Module()): - for k in node.body.body[0].body[0].names: - import_name = self.python_module.code_for_node(k.name) - self.all_imports[import_name] = node - return node - def visit_Call(self, node: cst.Call): """This is used to create a mapping from functions to class calling them, and from top-level functions to functions called inside them. Important note: we only rely on direct Call to the functions here, not indirect mentions (such as assigning a variable with the function, @@ -1154,23 +1162,30 @@ def _recursively_add_all_new_needed_functions_in_files(self): ) def leave_Module(self, original_node: cst.Module, node): - imports = {self.python_module.code_for_node(k): k for k in self.all_imports.values()} - dependency_imports = {file_type: imports.copy() for file_type in self.files} + all_imports = list(self.all_imports.values()) + all_imports_keys = {self.python_module.code_for_node(k) for k in self.all_imports.values()} + dependency_imports = {file_type: all_imports.copy() for file_type in self.files} for super_file_name, visiter in self.visited_module.items(): file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0] - dependency_imports[file_type].update( - {self.python_module.code_for_node(k): k for k in visiter.imports.values()} - ) + dependency_imports[file_type] += [ + k for k in visiter.imports.values() if self.python_module.code_for_node(k) not in all_imports_keys + ] + all_imports_keys.update({self.python_module.code_for_node(k) for k in dependency_imports[file_type]}) + dependency_imports[file_type] += [ + k + for k in self.all_safe_imports.values() + if self.python_module.code_for_node(k) not in all_imports_keys + ] # Check if any new top-level function from the `modular_xxx.py` should be added to the different files # (if it is called in a class in the file, then it will be copy pasted from `modular.py` to that file). - self._recursively_add_all_new_needed_functions_in_files() + # self._recursively_add_all_new_needed_functions_in_files() for file, body in self.files.items(): new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])] if len(new_body) > 0: if file in dependency_imports.keys(): - new_body = list(dependency_imports[file].values()) + new_body + new_body = dependency_imports[file] + new_body new_module = cst.Module(body=[*new_body], header=node.header) # Final cleanup unused_imports = get_unused_imports(new_module) From 7922210a9353bc980c4ba2f6e16a3fbe55100e02 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 22 Oct 2024 17:53:01 +0200 Subject: [PATCH 17/52] nits --- src/transformers/models/gemma2/modular_gemma2.py | 10 ++++++---- utils/modular_model_converter.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 0192e39bea4e..0b88df56230b 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -18,6 +18,7 @@ import torch import torch.nn as nn import torch.utils.checkpoint +import math from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache @@ -33,8 +34,7 @@ logging, ) from ..gemma.modeling_gemma import ( - GemmaAttention, - GemmaDecoderLayer, + GemmaRotaryEmbedding, GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification, @@ -320,13 +320,15 @@ def sdpa_attention_forward(config, query, key, value, mask, output_attentions=Fa return attn_output -GEMMA_ATTENTION_FUNCTION = { +GEMMA2_ATTENTION_FUNCTION = { "flash_attention": flash_attention_forward, "flex_attention": flex_attention_forward, "eager": eager_attention_forward, "sdpa": sdpa_attention_forward, } +class Gemma2RotaryEmbedding(GemmaRotaryEmbedding): + pass class Gemma2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -350,7 +352,7 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): self.scaling = config.query_pre_attn_scalar**-0.5 self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None self.attention_type = config.attn_implementation - self.attention_function = GEMMA_ATTENTION_FUNCTION[config.attn_implementation] + self.attention_function = GEMMA2_ATTENTION_FUNCTION[config.attn_implementation] if self.hidden_size % self.num_heads != 0: diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index e226041f3891..0710232e59ea 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -781,7 +781,7 @@ def get_unused_imports(source): """ wrapper = cst.metadata.MetadataWrapper(source) scopes = set(wrapper.resolve(cst.metadata.ScopeProvider).values()) - unused_imports: Dict[Union[cst.Import, cst.ImportFrom], Set[str]] = defaultdict(set) + unused_imports: Dict[ Set[str], Union[cst.Import, cst.ImportFrom]] = defaultdict(set) ranges = wrapper.resolve(cst.metadata.PositionProvider) for scope in scopes: for assignment in scope.assignments: From 43c68f662c4ddbd880596bd98b07aa5979db6045 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Nov 2024 10:35:38 +0100 Subject: [PATCH 18/52] use a deprecation cycle --- .../models/gemma/configuration_gemma.py | 1 + .../models/gemma2/modular_gemma2.py | 21 +++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py index 346f386ba698..e170803cccab 100644 --- a/src/transformers/models/gemma/configuration_gemma.py +++ b/src/transformers/models/gemma/configuration_gemma.py @@ -20,6 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + from ...configuration_utils import PretrainedConfig diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 3298356a9484..5954b79942e2 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -418,6 +418,27 @@ def forward( return attn_output, attn_weights, past_key_value + +class Gemma2FlashAttention2(Gemma2Attention): + def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.attention_function = GEMMA2_ATTENTION_FUNCTION["flash_attention"] + logger.warning_once( + "The `Gemma2FlashAttention2` class is deprecated in favor of simply modify the `attention_function`" + "attribute of the `GemmaAttention` class! It will be removed in v4.48" + ) + + +class Gemma2SdpaAttention(Gemma2Attention): + def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.attention_function = GEMMA2_ATTENTION_FUNCTION["sdpa"] + logger.warning_once( + "The `Gemma2FlashAttention2` class is deprecated in favor of simply modify the `attention_function`" + "attribute of the `GemmaAttention` class! It will be removed in v4.48" + ) + + class Gemma2DecoderLayer(nn.Module): def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() From 1aec9445e5353d95e13a73a45b614568aa493d62 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Nov 2024 11:22:00 +0100 Subject: [PATCH 19/52] updates --- .../models/gemma2/modeling_gemma2.py | 427 ++++++------------ utils/modular_model_converter.py | 3 +- 2 files changed, 149 insertions(+), 281 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index ffce9fb2db17..31a165bf9d64 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -19,6 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math from typing import List, Optional, Tuple, Union import torch @@ -41,7 +42,6 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal, - is_flash_attn_greater_or_equal_2_10, is_torch_greater_or_equal, logging, replace_return_docstrings, @@ -52,6 +52,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_greater_or_equal("2.5"): + from torch.nn.attention.flex_attention import flex_attention logger = logging.get_logger(__name__) @@ -60,20 +62,6 @@ _CONFIG_FOR_DOC = "Gemma2Config" -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - -logger = logging.get_logger(__name__) - - -_CHECKPOINT_FOR_DOC = "google/gemma2-7b" -_CONFIG_FOR_DOC = "Gemma2Config" - -if is_torch_greater_or_equal("2.5"): - from torch.nn.attention.flex_attention import flex_attention - - class Gemma2RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() @@ -137,6 +125,125 @@ def forward(self, x, position_ids, seq_len=None): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +def eager_attention_forward(config, query, key, value, mask): + key_states = repeat_kv(key, config.num_key_value_groups) + value_states = repeat_kv(value, config.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling + + if config.attn_logit_softcapping is not None: + attn_weights = attn_weights / config.attn_logit_softcapping + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * config.attn_logit_softcapping + if mask is not None: # no matter the length, we just slice it + causal_mask = mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) + attn_output = torch.matmul(attn_weights, value_states) + return attn_output + + +def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16): + if mask is not None: + seq_len = mask.shape[1] + query = query[:, :, :seq_len] + value = value[:, :, :seq_len] + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding + query_states = query.transpose(1, 2) + key_states = key.transpose(1, 2) + value_states = value.transpose(1, 2) + + dropout_rate = config.attention_dropout if config.training else 0.0 + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + mask, + seq_len, + dropout=dropout_rate, + softmax_scale=config.scaling, + is_causal=config.is_causal, + sliding_window=config.sliding_window, + use_top_left_mask=config._flash_attn_uses_top_left_mask, + softcap=config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, + ) + + return attn_output + + +def flex_attention_forward(config, query, key, value, mask, output_attentions=False, target_dtype=torch.float16): + causal_mask = mask + if mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + def tanh_softcap(score, b, h, q_idx, kv_idx): + soft_cap = config.attn_logit_softcapping + return soft_cap * torch.tanh(score / soft_cap) + + attn_output = flex_attention( + query, + key, + value, + block_mask=causal_mask, + score_mod=tanh_softcap, + enable_gqa=True, + scale=config.scaling, + return_lse=output_attentions, + ) + return attn_output + + +def sdpa_attention_forward(config, query, key, value, mask, output_attentions=False, target_dtype=torch.float16): + key = repeat_kv(key, config.num_key_value_groups) + value = repeat_kv(value, config.num_key_value_groups) + + causal_mask = mask + if mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query.device.type == "cuda" and causal_mask is not None: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and query.shape[1] > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=config.attention_dropout if config.training else 0.0, + is_causal=is_causal, + scale=config.scaling, + ) + return attn_output + + +GEMMA2_ATTENTION_FUNCTION = { + "flash_attention": flash_attention_forward, + "flex_attention": flex_attention_forward, + "eager": eager_attention_forward, + "sdpa": sdpa_attention_forward, +} + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -171,18 +278,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - class Gemma2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -190,12 +285,6 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size @@ -206,7 +295,12 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True + self.scaling = 1 / math.sqrt(config.head_dim) + self.scaling = config.query_pre_attn_scalar**-0.5 + self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None + self.attention_type = config.attn_implementation + self.attention_function = GEMMA2_ATTENTION_FUNCTION[config.attn_implementation] if self.hidden_size % self.num_heads != 0: raise ValueError( @@ -223,7 +317,6 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, ) - self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None def forward( self, @@ -258,33 +351,11 @@ def forward( } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling - - if self.config.attn_logit_softcapping is not None: - attn_weights = attn_weights / self.config.attn_logit_softcapping - attn_weights = torch.tanh(attn_weights) - attn_weights = attn_weights * self.config.attn_logit_softcapping - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = self.attention_function( + self, query_states, key_states, value_states, attention_mask, self.config + ) - attn_output = attn_output.view(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) if not output_attentions: @@ -294,207 +365,36 @@ def forward( class Gemma2FlashAttention2(Gemma2Attention): - """ - Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "sliding_window": self.sliding_window, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - if attention_mask is not None: - seq_len = attention_mask.shape[1] - key_states = key_states[:, :, :seq_len] - value_states = value_states[:, :, :seq_len] - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (Gemma2RMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - softmax_scale=self.scaling, - is_causal=self.is_causal, - sliding_window=self.sliding_window, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, + def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.attention_function = GEMMA2_ATTENTION_FUNCTION["flash_attention"] + logger.warning_once( + "The `Gemma2FlashAttention2` class is deprecated in favor of simply modify the `attention_function`" + "attribute of the `GemmaAttention` class! It will be removed in v4.48" ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - class Gemma2SdpaAttention(Gemma2Attention): - """ - Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Gemma2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Gemma2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "sliding_window": self.sliding_window, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - def tanh_softcap(score, b, h, q_idx, kv_idx): - soft_cap = self.config.attn_logit_softcapping - return soft_cap * torch.tanh(score / soft_cap) - - attn_output = flex_attention( - query_states, - key_states, - value_states, - block_mask=causal_mask, - score_mod=tanh_softcap, - enable_gqa=True, - scale=self.scaling, - return_lse=output_attentions, + def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.attention_function = GEMMA2_ATTENTION_FUNCTION["sdpa"] + logger.warning_once( + "The `Gemma2FlashAttention2` class is deprecated in favor of simply modify the `attention_function`" + "attribute of the `GemmaAttention` class! It will be removed in v4.48" ) - if output_attentions: - attn_output, attention_scores = attn_output - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, attention_scores, past_key_value - - -GEMMA2_ATTENTION_CLASSES = { - "eager": Gemma2Attention, - "flash_attention_2": Gemma2FlashAttention2, - "sdpa": Gemma2SdpaAttention, -} class Gemma2DecoderLayer(nn.Module): def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.config = config + self.is_sliding = not bool(layer_idx % 2) + self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx) self.mlp = Gemma2MLP(config) self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.config = config - self.is_sliding = not bool(layer_idx % 2) + self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.sliding_window = config.sliding_window @@ -509,25 +409,6 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding # Flash-attn is a 2D tensor if self.config._attn_implementation == "flash_attention_2": @@ -620,20 +501,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - @classmethod - def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False): - """ - Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models. - SDPA reduces the model performance on Gemma2 because of the logits softcapping. - """ - config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only) - - # if using the default path -> swap sdpa by eager - if not hard_check_only and config._attn_implementation == "sdpa": - config._attn_implementation = "eager" - - return config - GEMMA2_INPUTS_DOCSTRING = r""" Args: diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 93fda8b03428..c26d938b18f2 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -1142,6 +1142,7 @@ def visit_SimpleStatementLine(self, node): if assigned_variable == "__all__": self.all_all_to_add = split_all_assignment(node) else: + self.current_assignment = assigned_variable self.assignments[assigned_variable] = node def leave_Module(self, node): @@ -1485,7 +1486,7 @@ def save_modeling_file(modular_file, converted_file): parser = argparse.ArgumentParser() parser.add_argument( "--files_to_parse", - default=["src/transformers/models/gemma/modular_gemma.py"], + default=["src/transformers/models/gemma2/modular_gemma2.py"], nargs="+", help="A list of `modular_xxxx` files that should be converted to single model file", ) From 93b53efdaff3ee70d3d99b8c545411a7b841f8f7 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 1 Nov 2024 11:53:30 +0100 Subject: [PATCH 20/52] Fix modular (recursive deps need to always be computed after merges!) --- utils/modular_model_converter.py | 37 +++++++++++++++++++------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 93fda8b03428..b9cfd32a0eea 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -632,8 +632,10 @@ def leave_Module(self, node): for id, node in self.global_nodes.items(): self.start_lines[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line - # Since we added every Name as part of `self.object_dependency_mapping`, we now remove those that - # are not part of the recorded objects (i.e. built-in variables, imports, etc) + def _restrict_dependencies_to_known_entities(self): + """Since we added every Name as part of `self.object_dependency_mapping`, we need to remove those that + are not part of the recorded objects in `self.global_nodes` (i.e. built-in variables, imports, etc). + This should be called only after all merging operations have been finalized!!""" global_objects = set(self.global_nodes.keys()) for object_name, dependencies in self.object_dependency_mapping.items(): self.object_dependency_mapping[object_name] = {dep for dep in dependencies if dep in global_objects} @@ -814,6 +816,8 @@ def merge_modular_dependencies(self, classes, functions, assignments, object_map # Correctly re-set the global nodes at this point self.global_nodes.update(self.functions) self.global_nodes.update(self.assignments) + # Restrict the dependency mappings to the know entities to avoid Python's built-ins + self._restrict_dependencies_to_known_entities() # Create the global mapping of recursive dependencies for functions and assignments self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() @@ -1142,22 +1146,20 @@ def visit_SimpleStatementLine(self, node): if assigned_variable == "__all__": self.all_all_to_add = split_all_assignment(node) else: + self.current_assignment = assigned_variable self.assignments[assigned_variable] = node def leave_Module(self, node): """When we leave the modular file, we do the following in order: - 1. compute the nested (recursive) function and assignment dependencies - 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update + 1. for each modeling file found in the imports, rename it with the new model name, visit it, and update its dependency graph with the new function and assignment definitions found in the modular - 3. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files) + 2. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files) + 3. compute the nested (recursive) function and assignment dependencies """ # Takes care of finalizing our visit super().leave_Module(node) - # 1. compute the nested (recursive) function and assignment dependencies - self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() - - # 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies + # 1. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies self.visited_modules = {} self.renamers = {} for file, module in self.model_specific_modules.items(): @@ -1177,10 +1179,13 @@ def leave_Module(self, node): # We record it so that we can rename classes later the exact same way self.renamers[file] = renamer - # 3. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the + # 2. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the # definitions found in the visited files self.merge_model_specific_imports(self.visited_modules) + # 3. compute the nested (recursive) function and assignment dependencies + self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() + # We need to keep track of which objects were imported directly into which modeling file to not add them wrongly later # Note that we may visit several of the same file types, thus we save them per file type, not file self.imported_objects_per_file = defaultdict(set) @@ -1200,9 +1205,9 @@ def merge_model_specific_imports(self, visited_modules): if object_name in visited_module.functions and object_name not in self.functions: self.functions[object_name] = visited_module.functions[object_name] self.added_objects_file_mapping[object_name] = file - dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None) + dependencies = visited_module.object_dependency_mapping.get(object_name, None) if dependencies is not None: - self.object_recursive_dependency_mapping[object_name] = dependencies + self.object_dependency_mapping[object_name] = dependencies for dep in dependencies: if dep not in self.global_nodes: self.added_objects_file_mapping[dep] = file @@ -1212,9 +1217,9 @@ def merge_model_specific_imports(self, visited_modules): elif object_name in visited_module.assignments and object_name not in self.assignments: self.assignments[object_name] = visited_module.assignments[object_name] self.added_objects_file_mapping[object_name] = file - dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None) + dependencies = visited_module.object_dependency_mapping.get(object_name, None) if dependencies is not None: - self.object_recursive_dependency_mapping[object_name] = dependencies + self.object_dependency_mapping[object_name] = dependencies for dep in dependencies: if dep not in self.global_nodes: self.added_objects_file_mapping[dep] = file @@ -1222,6 +1227,8 @@ def merge_model_specific_imports(self, visited_modules): # Do not forget to re-assign all nodes after the merge self.global_nodes = {**self.assignments, **self.classes, **self.functions} + # And restric dependencies to those nodes only + self._restrict_dependencies_to_known_entities() def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: """Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that @@ -1485,7 +1492,7 @@ def save_modeling_file(modular_file, converted_file): parser = argparse.ArgumentParser() parser.add_argument( "--files_to_parse", - default=["src/transformers/models/gemma/modular_gemma.py"], + default=["src/transformers/models/gemma2/modular_gemma2.py"], nargs="+", help="A list of `modular_xxxx` files that should be converted to single model file", ) From a79c4a949770a48f9c037f24b871fb2280671339 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Nov 2024 11:59:20 +0100 Subject: [PATCH 21/52] push --- src/transformers/models/gemma2/modeling_gemma2.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 31a165bf9d64..fae9db359632 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -278,6 +278,18 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + class Gemma2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" From 4c6d2990f9a831a453a2b2ac516734014804f57f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Nov 2024 12:24:02 +0100 Subject: [PATCH 22/52] fix --- src/transformers/models/gemma2/modular_gemma2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 5954b79942e2..3393514b0e12 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -353,8 +353,8 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): self.scaling = config.query_pre_attn_scalar**-0.5 self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None - self.attention_type = config.attn_implementation - self.attention_function = GEMMA2_ATTENTION_FUNCTION[config.attn_implementation] + self.attention_type = config._attn_implementation + self.attention_function = GEMMA2_ATTENTION_FUNCTION[config._attn_implementation] if self.hidden_size % self.num_heads != 0: From 607c45df9b63f40815a8f0e04de81a70d54f89ef Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Nov 2024 12:24:11 +0100 Subject: [PATCH 23/52] update --- src/transformers/models/gemma2/modular_gemma2.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 3393514b0e12..c91863060b08 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -13,12 +13,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math from typing import Optional, Tuple, Union import torch import torch.nn as nn import torch.utils.checkpoint -import math from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache @@ -34,13 +34,13 @@ logging, ) from ..gemma.modeling_gemma import ( - GemmaRotaryEmbedding, GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification, GemmaModel, GemmaPreTrainedModel, GemmaRMSNorm, + GemmaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv, ) @@ -231,7 +231,6 @@ def eager_attention_forward(config, query, key, value, mask): return attn_output - def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16): if mask is not None: seq_len = mask.shape[1] @@ -329,9 +328,11 @@ def sdpa_attention_forward(config, query, key, value, mask, output_attentions=Fa "sdpa": sdpa_attention_forward, } + class Gemma2RotaryEmbedding(GemmaRotaryEmbedding): pass + class Gemma2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -356,7 +357,6 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): self.attention_type = config._attn_implementation self.attention_function = GEMMA2_ATTENTION_FUNCTION[config._attn_implementation] - if self.hidden_size % self.num_heads != 0: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" @@ -450,7 +450,6 @@ def __init__(self, config: Gemma2Config, layer_idx: int): self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.sliding_window = config.sliding_window From 4598bba4212b1dd4b66fab1f20be5a47b307cbf8 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 1 Nov 2024 13:07:16 +0100 Subject: [PATCH 24/52] fix modular order --- utils/modular_model_converter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index b9cfd32a0eea..e5f6e34ece0e 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -1246,10 +1246,11 @@ def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: else: original_dependencies.append(dep) # Sort all lists according to the order in their respective file - all_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) + all_dependencies = [] for file, dependencies in other_files_dependencies.items(): sorted_dependencies = sorted(dependencies, key=lambda x: self.start_lines_file_mapping[file][x]) all_dependencies += sorted_dependencies + all_dependencies += sorted(original_dependencies, key=lambda x: self.start_lines[x]) # Add all original node first, then merged ones (one file at a time) for dep in all_dependencies: From 5727270324b89e06d91fc6aca9cd3b503fd9bd50 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Nov 2024 18:13:37 +0100 Subject: [PATCH 25/52] make fix-copies --- .../models/gemma2/modeling_gemma2.py | 96 +++++++++---------- 1 file changed, 48 insertions(+), 48 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index fae9db359632..cd8e63734546 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -125,6 +125,52 @@ def forward(self, x, position_ids, seq_len=None): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + def eager_attention_forward(config, query, key, value, mask): key_states = repeat_kv(key, config.num_key_value_groups) value_states = repeat_kv(value, config.num_key_value_groups) @@ -244,52 +290,6 @@ def sdpa_attention_forward(config, query, key, value, mask, output_attentions=Fa } -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - class Gemma2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -311,8 +311,8 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): self.scaling = config.query_pre_attn_scalar**-0.5 self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None - self.attention_type = config.attn_implementation - self.attention_function = GEMMA2_ATTENTION_FUNCTION[config.attn_implementation] + self.attention_type = config._attn_implementation + self.attention_function = GEMMA2_ATTENTION_FUNCTION[config._attn_implementation] if self.hidden_size % self.num_heads != 0: raise ValueError( From 198b4c4805755ea3a9755a86cc4c975261b1c762 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Nov 2024 18:30:20 +0100 Subject: [PATCH 26/52] updates --- src/transformers/models/gemma2/modeling_gemma2.py | 12 +++++++++--- src/transformers/models/gemma2/modular_gemma2.py | 12 +++++++++--- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index cd8e63734546..c106000498bf 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -226,7 +226,7 @@ def flash_attention_forward(config, query, key, value, mask, target_dtype=torch. softcap=config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, ) - return attn_output + return attn_output, None def flex_attention_forward(config, query, key, value, mask, output_attentions=False, target_dtype=torch.float16): @@ -279,7 +279,7 @@ def sdpa_attention_forward(config, query, key, value, mask, output_attentions=Fa is_causal=is_causal, scale=config.scaling, ) - return attn_output + return attn_output, None GEMMA2_ATTENTION_FUNCTION = { @@ -363,7 +363,13 @@ def forward( } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - attn_output = self.attention_function( + if output_attentions and self.attention_type in ["sdpa", "flash_attention_2"]: + logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`") + attention_type = "flex_attention" + else: + attention_type = self.attention_type + + attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type]( self, query_states, key_states, value_states, attention_mask, self.config ) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index c91863060b08..e0b29978558e 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -265,7 +265,7 @@ def flash_attention_forward(config, query, key, value, mask, target_dtype=torch. softcap=config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, ) - return attn_output + return attn_output, None def flex_attention_forward(config, query, key, value, mask, output_attentions=False, target_dtype=torch.float16): @@ -318,7 +318,7 @@ def sdpa_attention_forward(config, query, key, value, mask, output_attentions=Fa is_causal=is_causal, scale=config.scaling, ) - return attn_output + return attn_output, None GEMMA2_ATTENTION_FUNCTION = { @@ -406,7 +406,13 @@ def forward( } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - attn_output = self.attention_function( + if output_attentions and self.attention_type in ["sdpa", "flash_attention_2"]: + logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`") + attention_type = "flex_attention" + else: + attention_type = self.attention_type + + attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type]( self, query_states, key_states, value_states, attention_mask, self.config ) From 3d35151f806daa4ce2e8e534b65515e668a727fe Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Nov 2024 18:47:08 +0100 Subject: [PATCH 27/52] update --- src/transformers/models/gemma2/modeling_gemma2.py | 11 ++++++----- src/transformers/models/gemma2/modular_gemma2.py | 11 ++++++----- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index c106000498bf..1f3ca3b6a364 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -53,7 +53,7 @@ from ...modeling_flash_attention_utils import _flash_attention_forward if is_torch_greater_or_equal("2.5"): - from torch.nn.attention.flex_attention import flex_attention + from torch.nn.attention.flex_attention import create_block_mask, flex_attention logger = logging.get_logger(__name__) @@ -230,9 +230,10 @@ def flash_attention_forward(config, query, key, value, mask, target_dtype=torch. def flex_attention_forward(config, query, key, value, mask, output_attentions=False, target_dtype=torch.float16): - causal_mask = mask - if mask is not None: - causal_mask = causal_mask[:, :, :, : key.shape[-2]] + def mask_mod(b, h, q_idx, kv_idx): + if mask is None: + return None + return mask[:, :, :, :kv_idx] def tanh_softcap(score, b, h, q_idx, kv_idx): soft_cap = config.attn_logit_softcapping @@ -242,7 +243,7 @@ def tanh_softcap(score, b, h, q_idx, kv_idx): query, key, value, - block_mask=causal_mask, + block_mask=create_block_mask(mask_mod), score_mod=tanh_softcap, enable_gqa=True, scale=config.scaling, diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index e0b29978558e..11c1e75e5e8f 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -50,7 +50,7 @@ from ...modeling_flash_attention_utils import _flash_attention_forward if is_torch_greater_or_equal("2.5"): - from torch.nn.attention.flex_attention import flex_attention + from torch.nn.attention.flex_attention import flex_attention, create_block_mask _CHECKPOINT_FOR_DOC = "google/gemma2-7b" @@ -269,9 +269,10 @@ def flash_attention_forward(config, query, key, value, mask, target_dtype=torch. def flex_attention_forward(config, query, key, value, mask, output_attentions=False, target_dtype=torch.float16): - causal_mask = mask - if mask is not None: - causal_mask = causal_mask[:, :, :, : key.shape[-2]] + def mask_mod(b, h, q_idx, kv_idx): + if mask is None: + return None + return mask[:, :, :, : kv_idx] def tanh_softcap(score, b, h, q_idx, kv_idx): soft_cap = config.attn_logit_softcapping @@ -281,7 +282,7 @@ def tanh_softcap(score, b, h, q_idx, kv_idx): query, key, value, - block_mask=causal_mask, + block_mask=create_block_mask(mask_mod), score_mod=tanh_softcap, enable_gqa=True, scale=config.scaling, From da050cdd6f72641d7792082f89adc5c8c8b53054 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Nov 2024 18:51:53 +0100 Subject: [PATCH 28/52] ? --- src/transformers/models/gemma2/modeling_gemma2.py | 4 +++- src/transformers/models/gemma2/modular_gemma2.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 1f3ca3b6a364..f81c0fe28766 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -243,7 +243,9 @@ def tanh_softcap(score, b, h, q_idx, kv_idx): query, key, value, - block_mask=create_block_mask(mask_mod), + block_mask=create_block_mask( + mask_mod, query.shape[0], query.shape[2], query.shape[1], key.shape[1], _compile=True + ), score_mod=tanh_softcap, enable_gqa=True, scale=config.scaling, diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 11c1e75e5e8f..0fadad38e4bb 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -282,7 +282,7 @@ def tanh_softcap(score, b, h, q_idx, kv_idx): query, key, value, - block_mask=create_block_mask(mask_mod), + block_mask=create_block_mask(mask_mod, query.shape[0], query.shape[2], query.shape[1], key.shape[1], _compile=True), score_mod=tanh_softcap, enable_gqa=True, scale=config.scaling, From e02078cabe2f8fe14c96125eef78ca4e821b82a1 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Nov 2024 19:15:32 +0100 Subject: [PATCH 29/52] don't compile for now --- src/transformers/models/gemma2/modular_gemma2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 0fadad38e4bb..01ad361ec461 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -282,7 +282,7 @@ def tanh_softcap(score, b, h, q_idx, kv_idx): query, key, value, - block_mask=create_block_mask(mask_mod, query.shape[0], query.shape[2], query.shape[1], key.shape[1], _compile=True), + block_mask=create_block_mask(mask_mod, query.shape[0], query.shape[2], query.shape[1], key.shape[1]), score_mod=tanh_softcap, enable_gqa=True, scale=config.scaling, From 5861bbf561752c71217a50c4cb1edcc33dec5aed Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Nov 2024 08:04:49 +0100 Subject: [PATCH 30/52] ? --- src/transformers/models/gemma2/modular_gemma2.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 01ad361ec461..8d3a656716e3 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -269,10 +269,11 @@ def flash_attention_forward(config, query, key, value, mask, target_dtype=torch. def flex_attention_forward(config, query, key, value, mask, output_attentions=False, target_dtype=torch.float16): + if mask is None: + mask = mask[:, :, :, : key.shape[-2]] + def mask_mod(b, h, q_idx, kv_idx): - if mask is None: - return None - return mask[:, :, :, : kv_idx] + return mask def tanh_softcap(score, b, h, q_idx, kv_idx): soft_cap = config.attn_logit_softcapping From 8c47da2e6b8a66453c6bee8e9f5b82e5dabc1f9a Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Nov 2024 08:48:52 +0100 Subject: [PATCH 31/52] fix some stuff --- .../models/gemma2/modeling_gemma2.py | 33 ++++++++-------- .../models/gemma2/modular_gemma2.py | 38 +++++++++---------- 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index f81c0fe28766..93e7b18e467a 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -53,7 +53,7 @@ from ...modeling_flash_attention_utils import _flash_attention_forward if is_torch_greater_or_equal("2.5"): - from torch.nn.attention.flex_attention import create_block_mask, flex_attention + from torch.nn.attention.flex_attention import flex_attention logger = logging.get_logger(__name__) @@ -171,7 +171,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward(config, query, key, value, mask): +def eager_attention_forward(config, query, key, value, mask, **_kwargs): key_states = repeat_kv(key, config.num_key_value_groups) value_states = repeat_kv(value, config.num_key_value_groups) @@ -189,10 +189,10 @@ def eager_attention_forward(config, query, key, value, mask): attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) attn_output = torch.matmul(attn_weights, value_states) - return attn_output + return attn_output, attn_weights -def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16): +def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs): if mask is not None: seq_len = mask.shape[1] query = query[:, :, :seq_len] @@ -229,32 +229,33 @@ def flash_attention_forward(config, query, key, value, mask, target_dtype=torch. return attn_output, None -def flex_attention_forward(config, query, key, value, mask, output_attentions=False, target_dtype=torch.float16): - def mask_mod(b, h, q_idx, kv_idx): - if mask is None: - return None - return mask[:, :, :, :kv_idx] +# torch._dynamo.config.capture_scalar_outputs = True +def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): + if mask is not None: + mask = mask[0][0] def tanh_softcap(score, b, h, q_idx, kv_idx): soft_cap = config.attn_logit_softcapping - return soft_cap * torch.tanh(score / soft_cap) + score = soft_cap * torch.tanh(score / soft_cap) + if mask is not None: + score = score + mask[q_idx, kv_idx] + return score attn_output = flex_attention( query, key, value, - block_mask=create_block_mask( - mask_mod, query.shape[0], query.shape[2], query.shape[1], key.shape[1], _compile=True - ), score_mod=tanh_softcap, enable_gqa=True, scale=config.scaling, return_lse=output_attentions, ) + if not output_attentions: + return attn_output, None return attn_output -def sdpa_attention_forward(config, query, key, value, mask, output_attentions=False, target_dtype=torch.float16): +def sdpa_attention_forward(config, query, key, value, mask, **_kwargs): key = repeat_kv(key, config.num_key_value_groups) value = repeat_kv(value, config.num_key_value_groups) @@ -316,7 +317,7 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None self.attention_type = config._attn_implementation self.attention_function = GEMMA2_ATTENTION_FUNCTION[config._attn_implementation] - + self.attn_logit_softcapping = config.attn_logit_softcapping if self.hidden_size % self.num_heads != 0: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" @@ -373,7 +374,7 @@ def forward( attention_type = self.attention_type attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type]( - self, query_states, key_states, value_states, attention_mask, self.config + self, query_states, key_states, value_states, attention_mask ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 8d3a656716e3..629392c07199 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -210,7 +210,10 @@ def forward(self, x): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) -def eager_attention_forward(config, query, key, value, mask): +class Gemma2RotaryEmbedding(GemmaRotaryEmbedding): + pass + +def eager_attention_forward(config, query, key, value, mask, **_kwargs): key_states = repeat_kv(key, config.num_key_value_groups) value_states = repeat_kv(value, config.num_key_value_groups) @@ -228,10 +231,10 @@ def eager_attention_forward(config, query, key, value, mask): attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) attn_output = torch.matmul(attn_weights, value_states) - return attn_output + return attn_output, attn_weights -def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16): +def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs): if mask is not None: seq_len = mask.shape[1] query = query[:, :, :seq_len] @@ -267,32 +270,33 @@ def flash_attention_forward(config, query, key, value, mask, target_dtype=torch. return attn_output, None - -def flex_attention_forward(config, query, key, value, mask, output_attentions=False, target_dtype=torch.float16): - if mask is None: - mask = mask[:, :, :, : key.shape[-2]] - - def mask_mod(b, h, q_idx, kv_idx): - return mask +# torch._dynamo.config.capture_scalar_outputs = True +def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): + if mask is not None: + mask = mask[0][0] def tanh_softcap(score, b, h, q_idx, kv_idx): soft_cap = config.attn_logit_softcapping - return soft_cap * torch.tanh(score / soft_cap) + score = soft_cap * torch.tanh(score / soft_cap) + if mask is not None: + score = score + mask[q_idx, kv_idx] + return score attn_output = flex_attention( query, key, value, - block_mask=create_block_mask(mask_mod, query.shape[0], query.shape[2], query.shape[1], key.shape[1]), score_mod=tanh_softcap, enable_gqa=True, scale=config.scaling, return_lse=output_attentions, ) + if not output_attentions: + return attn_output, None return attn_output -def sdpa_attention_forward(config, query, key, value, mask, output_attentions=False, target_dtype=torch.float16): +def sdpa_attention_forward(config, query, key, value, mask, **_kwargs): key = repeat_kv(key, config.num_key_value_groups) value = repeat_kv(value, config.num_key_value_groups) @@ -331,10 +335,6 @@ def sdpa_attention_forward(config, query, key, value, mask, output_attentions=Fa } -class Gemma2RotaryEmbedding(GemmaRotaryEmbedding): - pass - - class Gemma2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -358,7 +358,7 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None self.attention_type = config._attn_implementation self.attention_function = GEMMA2_ATTENTION_FUNCTION[config._attn_implementation] - + self.attn_logit_softcapping = config.attn_logit_softcapping if self.hidden_size % self.num_heads != 0: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" @@ -415,7 +415,7 @@ def forward( attention_type = self.attention_type attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type]( - self, query_states, key_states, value_states, attention_mask, self.config + self, query_states, key_states, value_states, attention_mask ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() From 09a88d938389b82d5617ac9c48b630644d712041 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Nov 2024 09:01:08 +0100 Subject: [PATCH 32/52] donc! --- src/transformers/models/gemma2/modeling_gemma2.py | 5 +---- src/transformers/models/gemma2/modular_gemma2.py | 5 +---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 93e7b18e467a..b92770896a6a 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -231,14 +231,11 @@ def flash_attention_forward(config, query, key, value, mask, target_dtype=torch. # torch._dynamo.config.capture_scalar_outputs = True def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): - if mask is not None: - mask = mask[0][0] - def tanh_softcap(score, b, h, q_idx, kv_idx): soft_cap = config.attn_logit_softcapping score = soft_cap * torch.tanh(score / soft_cap) if mask is not None: - score = score + mask[q_idx, kv_idx] + return score + mask[b][h] return score attn_output = flex_attention( diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 629392c07199..e988625c911b 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -272,14 +272,11 @@ def flash_attention_forward(config, query, key, value, mask, target_dtype=torch. # torch._dynamo.config.capture_scalar_outputs = True def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): - if mask is not None: - mask = mask[0][0] - def tanh_softcap(score, b, h, q_idx, kv_idx): soft_cap = config.attn_logit_softcapping score = soft_cap * torch.tanh(score / soft_cap) if mask is not None: - score = score + mask[q_idx, kv_idx] + return score + mask[b][h] return score attn_output = flex_attention( From c06b5306541bdc79e32c0d49785d9e1ba4bb4b8c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Nov 2024 09:01:50 +0100 Subject: [PATCH 33/52] fix copies --- src/transformers/models/gemma2/modeling_gemma2.py | 1 - src/transformers/models/gemma2/modular_gemma2.py | 5 +++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index b92770896a6a..fe03ad368e5a 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -229,7 +229,6 @@ def flash_attention_forward(config, query, key, value, mask, target_dtype=torch. return attn_output, None -# torch._dynamo.config.capture_scalar_outputs = True def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): def tanh_softcap(score, b, h, q_idx, kv_idx): soft_cap = config.attn_logit_softcapping diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index e988625c911b..64fd540d9250 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -50,7 +50,7 @@ from ...modeling_flash_attention_utils import _flash_attention_forward if is_torch_greater_or_equal("2.5"): - from torch.nn.attention.flex_attention import flex_attention, create_block_mask + from torch.nn.attention.flex_attention import flex_attention _CHECKPOINT_FOR_DOC = "google/gemma2-7b" @@ -213,6 +213,7 @@ def forward(self, x): class Gemma2RotaryEmbedding(GemmaRotaryEmbedding): pass + def eager_attention_forward(config, query, key, value, mask, **_kwargs): key_states = repeat_kv(key, config.num_key_value_groups) value_states = repeat_kv(value, config.num_key_value_groups) @@ -270,7 +271,7 @@ def flash_attention_forward(config, query, key, value, mask, target_dtype=torch. return attn_output, None -# torch._dynamo.config.capture_scalar_outputs = True + def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): def tanh_softcap(score, b, h, q_idx, kv_idx): soft_cap = config.attn_logit_softcapping From 89e6f8593dd0d84cec0ebf6f5da730e925f20327 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Nov 2024 09:03:44 +0100 Subject: [PATCH 34/52] update --- src/transformers/models/gemma2/modeling_gemma2.py | 6 ++---- src/transformers/models/gemma2/modular_gemma2.py | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index fe03ad368e5a..4b49314c6e1b 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -311,8 +311,6 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): self.scaling = config.query_pre_attn_scalar**-0.5 self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None - self.attention_type = config._attn_implementation - self.attention_function = GEMMA2_ATTENTION_FUNCTION[config._attn_implementation] self.attn_logit_softcapping = config.attn_logit_softcapping if self.hidden_size % self.num_heads != 0: raise ValueError( @@ -363,11 +361,11 @@ def forward( } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if output_attentions and self.attention_type in ["sdpa", "flash_attention_2"]: + if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`") attention_type = "flex_attention" else: - attention_type = self.attention_type + attention_type = self.config._attn_implementation attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type]( self, query_states, key_states, value_states, attention_mask diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 64fd540d9250..a28b71564ff3 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -354,8 +354,6 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): self.scaling = config.query_pre_attn_scalar**-0.5 self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None - self.attention_type = config._attn_implementation - self.attention_function = GEMMA2_ATTENTION_FUNCTION[config._attn_implementation] self.attn_logit_softcapping = config.attn_logit_softcapping if self.hidden_size % self.num_heads != 0: raise ValueError( @@ -406,11 +404,11 @@ def forward( } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if output_attentions and self.attention_type in ["sdpa", "flash_attention_2"]: + if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`") attention_type = "flex_attention" else: - attention_type = self.attention_type + attention_type = self.config._attn_implementation attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type]( self, query_states, key_states, value_states, attention_mask From 152e0b77133f088e07c72aa59d42a2f620b6537c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Nov 2024 09:07:53 +0100 Subject: [PATCH 35/52] fixup --- src/transformers/models/gemma2/modeling_gemma2.py | 4 ++-- src/transformers/models/gemma2/modular_gemma2.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 4b49314c6e1b..164de60f35fa 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -383,7 +383,7 @@ def forward( class Gemma2FlashAttention2(Gemma2Attention): def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) - self.attention_function = GEMMA2_ATTENTION_FUNCTION["flash_attention"] + self.config._attn_implementation = GEMMA2_ATTENTION_FUNCTION["flash_attention"] logger.warning_once( "The `Gemma2FlashAttention2` class is deprecated in favor of simply modify the `attention_function`" "attribute of the `GemmaAttention` class! It will be removed in v4.48" @@ -393,7 +393,7 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): class Gemma2SdpaAttention(Gemma2Attention): def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) - self.attention_function = GEMMA2_ATTENTION_FUNCTION["sdpa"] + self.config._attn_implementation = GEMMA2_ATTENTION_FUNCTION["sdpa"] logger.warning_once( "The `Gemma2FlashAttention2` class is deprecated in favor of simply modify the `attention_function`" "attribute of the `GemmaAttention` class! It will be removed in v4.48" diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index a28b71564ff3..828f61a1be49 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -426,7 +426,7 @@ def forward( class Gemma2FlashAttention2(Gemma2Attention): def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) - self.attention_function = GEMMA2_ATTENTION_FUNCTION["flash_attention"] + self.config._attn_implementation = GEMMA2_ATTENTION_FUNCTION["flash_attention"] logger.warning_once( "The `Gemma2FlashAttention2` class is deprecated in favor of simply modify the `attention_function`" "attribute of the `GemmaAttention` class! It will be removed in v4.48" @@ -436,7 +436,7 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): class Gemma2SdpaAttention(Gemma2Attention): def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) - self.attention_function = GEMMA2_ATTENTION_FUNCTION["sdpa"] + self.config._attn_implementation = GEMMA2_ATTENTION_FUNCTION["sdpa"] logger.warning_once( "The `Gemma2FlashAttention2` class is deprecated in favor of simply modify the `attention_function`" "attribute of the `GemmaAttention` class! It will be removed in v4.48" From 006e86931247d7bacf94ad00ada1fb6baf11e94b Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Nov 2024 12:24:31 +0100 Subject: [PATCH 36/52] ? --- src/transformers/models/gemma2/modeling_gemma2.py | 2 +- src/transformers/models/gemma2/modular_gemma2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 164de60f35fa..7fd350308f68 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -234,7 +234,7 @@ def tanh_softcap(score, b, h, q_idx, kv_idx): soft_cap = config.attn_logit_softcapping score = soft_cap * torch.tanh(score / soft_cap) if mask is not None: - return score + mask[b][h] + return score + mask[b][h][q_idx][kv_idx] return score attn_output = flex_attention( diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 828f61a1be49..edd6ad99b60f 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -277,7 +277,7 @@ def tanh_softcap(score, b, h, q_idx, kv_idx): soft_cap = config.attn_logit_softcapping score = soft_cap * torch.tanh(score / soft_cap) if mask is not None: - return score + mask[b][h] + return score + mask[b][h][q_idx][kv_idx] return score attn_output = flex_attention( From 159c65a2134a9378214e525331cc8d2c118d36aa Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Nov 2024 15:53:52 +0100 Subject: [PATCH 37/52] fix two tests --- src/transformers/models/gemma2/modeling_gemma2.py | 2 +- tests/models/gemma2/test_modeling_gemma2.py | 13 ------------- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 7fd350308f68..04fc48205023 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -368,7 +368,7 @@ def forward( attention_type = self.config._attn_implementation attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type]( - self, query_states, key_states, value_states, attention_mask + self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 7bca83f96d73..9728452a54f5 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -199,19 +199,6 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l def test_sdpa_equivalence(self): pass - def test_eager_attention_loaded_by_default(self): - """Gemma 2 + SDPA = inferior results, because of the logit softcapping. Eager is the default.""" - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - - # Usually we enable SDPA by default, but not for Gemma2 - model = Gemma2Model(config) - self.assertTrue(model.config._attn_implementation == "eager") - - # We can still force SDPA - config._attn_implementation = "sdpa" - model = Gemma2Model(config) - self.assertTrue(model.config._attn_implementation == "sdpa") - @slow @require_torch_gpu From 56ea5b9045ea1775734f92bf16d8586740251de7 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Nov 2024 16:21:20 +0100 Subject: [PATCH 38/52] fix? --- src/transformers/models/gemma2/modeling_gemma2.py | 3 ++- src/transformers/models/gemma2/modular_gemma2.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 04fc48205023..c9a0697b0a30 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -248,7 +248,8 @@ def tanh_softcap(score, b, h, q_idx, kv_idx): ) if not output_attentions: return attn_output, None - return attn_output + else: + return attn_output[0], attn_output[1] def sdpa_attention_forward(config, query, key, value, mask, **_kwargs): diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index edd6ad99b60f..f1a7fbe4e100 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -291,7 +291,8 @@ def tanh_softcap(score, b, h, q_idx, kv_idx): ) if not output_attentions: return attn_output, None - return attn_output + else: + return attn_output[0], attn_output[1] def sdpa_attention_forward(config, query, key, value, mask, **_kwargs): @@ -411,7 +412,7 @@ def forward( attention_type = self.config._attn_implementation attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type]( - self, query_states, key_states, value_states, attention_mask + self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() From 4c3deb9d551fd2947c20666af9b6fff1717c4e4e Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Nov 2024 19:04:49 +0100 Subject: [PATCH 39/52] for now, don't use head info --- src/transformers/models/gemma2/modeling_gemma2.py | 2 +- src/transformers/models/gemma2/modular_gemma2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index c9a0697b0a30..270aec271644 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -234,7 +234,7 @@ def tanh_softcap(score, b, h, q_idx, kv_idx): soft_cap = config.attn_logit_softcapping score = soft_cap * torch.tanh(score / soft_cap) if mask is not None: - return score + mask[b][h][q_idx][kv_idx] + return score + mask[b][0][q_idx][kv_idx] return score attn_output = flex_attention( diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index f1a7fbe4e100..b5fa1fec6279 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -277,7 +277,7 @@ def tanh_softcap(score, b, h, q_idx, kv_idx): soft_cap = config.attn_logit_softcapping score = soft_cap * torch.tanh(score / soft_cap) if mask is not None: - return score + mask[b][h][q_idx][kv_idx] + return score + mask[b][0][q_idx][kv_idx] return score attn_output = flex_attention( From 9e3609d03c98c84fcbf7e45867b97644d9232fc0 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Nov 2024 19:11:37 +0100 Subject: [PATCH 40/52] eager when output attentoin and sdpa or flash as it's the simplest behaviour (for our tests as well :)) --- src/transformers/models/gemma2/modular_gemma2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index b5fa1fec6279..3f22ee0e782c 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -407,7 +407,7 @@ def forward( if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`") - attention_type = "flex_attention" + attention_type = "eager" else: attention_type = self.config._attn_implementation From 21edaedc77478fba1abe30312543d78fc4d048d2 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Nov 2024 19:11:41 +0100 Subject: [PATCH 41/52] fix-copies --- src/transformers/models/gemma2/modeling_gemma2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 270aec271644..e0cf7c5edfbe 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -364,7 +364,7 @@ def forward( if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`") - attention_type = "flex_attention" + attention_type = "eager" else: attention_type = self.config._attn_implementation From b5d98194bd7380168e9e819d3b0edf73c7911d31 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Nov 2024 19:16:39 +0100 Subject: [PATCH 42/52] revert sdpa check --- src/transformers/models/gemma2/modeling_gemma2.py | 14 ++++++++++++++ src/transformers/models/gemma2/modular_gemma2.py | 14 ++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index e0cf7c5edfbe..952f44b5c6d5 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -518,6 +518,20 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + @classmethod + def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False): + """ + Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models. + SDPA reduces the model performance on Gemma2 because of the logits softcapping. + """ + config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only) + + # if using the default path -> swap sdpa by eager + if not hard_check_only and config._attn_implementation == "sdpa": + config._attn_implementation = "eager" + + return config + GEMMA2_INPUTS_DOCSTRING = r""" Args: diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 3f22ee0e782c..c3942603cb25 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -520,6 +520,20 @@ def forward( class Gemma2PreTrainedModel(GemmaPreTrainedModel): _supports_quantized_cache = False + @classmethod + def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False): + """ + Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models. + SDPA reduces the model performance on Gemma2 because of the logits softcapping. + """ + config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only) + + # if using the default path -> swap sdpa by eager + if not hard_check_only and config._attn_implementation == "sdpa": + config._attn_implementation = "eager" + + return config + class Gemma2Model(GemmaModel, Gemma2PreTrainedModel): def __init__(self, config: Gemma2Config): From 5a3dadee686c1a98dd799d2f7d82a2e07e52a9be Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Wed, 6 Nov 2024 08:23:48 +0100 Subject: [PATCH 43/52] Apply suggestions from code review Co-authored-by: Cyril Vallez --- src/transformers/models/gemma2/modular_gemma2.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index c3942603cb25..d7a21c34d9ae 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -327,7 +327,7 @@ def sdpa_attention_forward(config, query, key, value, mask, **_kwargs): GEMMA2_ATTENTION_FUNCTION = { - "flash_attention": flash_attention_forward, + "flash_attention_2": flash_attention_forward, "flex_attention": flex_attention_forward, "eager": eager_attention_forward, "sdpa": sdpa_attention_forward, @@ -427,9 +427,9 @@ def forward( class Gemma2FlashAttention2(Gemma2Attention): def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) - self.config._attn_implementation = GEMMA2_ATTENTION_FUNCTION["flash_attention"] + self.config._attn_implementation = "flash_attention_2" logger.warning_once( - "The `Gemma2FlashAttention2` class is deprecated in favor of simply modify the `attention_function`" + "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" "attribute of the `GemmaAttention` class! It will be removed in v4.48" ) @@ -437,9 +437,9 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): class Gemma2SdpaAttention(Gemma2Attention): def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) - self.config._attn_implementation = GEMMA2_ATTENTION_FUNCTION["sdpa"] + self.config._attn_implementation = "sdpa" logger.warning_once( - "The `Gemma2FlashAttention2` class is deprecated in favor of simply modify the `attention_function`" + "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" "attribute of the `GemmaAttention` class! It will be removed in v4.48" ) From 1da75e111bb63810d229bbda567e0a22f83ecf1d Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 6 Nov 2024 08:37:38 +0100 Subject: [PATCH 44/52] rebase, fix-copies and push --- src/transformers/models/gemma2/modeling_gemma2.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 952f44b5c6d5..771151ca9c98 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -284,7 +284,7 @@ def sdpa_attention_forward(config, query, key, value, mask, **_kwargs): GEMMA2_ATTENTION_FUNCTION = { - "flash_attention": flash_attention_forward, + "flash_attention_2": flash_attention_forward, "flex_attention": flex_attention_forward, "eager": eager_attention_forward, "sdpa": sdpa_attention_forward, @@ -384,9 +384,9 @@ def forward( class Gemma2FlashAttention2(Gemma2Attention): def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) - self.config._attn_implementation = GEMMA2_ATTENTION_FUNCTION["flash_attention"] + self.config._attn_implementation = "flash_attention_2" logger.warning_once( - "The `Gemma2FlashAttention2` class is deprecated in favor of simply modify the `attention_function`" + "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" "attribute of the `GemmaAttention` class! It will be removed in v4.48" ) @@ -394,9 +394,9 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): class Gemma2SdpaAttention(Gemma2Attention): def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) - self.config._attn_implementation = GEMMA2_ATTENTION_FUNCTION["sdpa"] + self.config._attn_implementation = "sdpa" logger.warning_once( - "The `Gemma2FlashAttention2` class is deprecated in favor of simply modify the `attention_function`" + "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" "attribute of the `GemmaAttention` class! It will be removed in v4.48" ) From aca9120a6e71b77edd86ff63a6cb0e3a998cb4af Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 6 Nov 2024 08:57:28 +0100 Subject: [PATCH 45/52] add a slow integration test --- tests/models/gemma2/test_modeling_gemma2.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 9728452a54f5..9a4ee1951594 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -352,3 +352,23 @@ def test_export_static_cache(self): ) ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text) + + @require_read_token + def test_model_9b_bf16_flex_attention(self): + model_id = "google/gemma-2-9b" + EXPECTED_TEXTS = [ + "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many", + "Hi today I'm going to be talking about the history of the United States. The United States of America", + ] + + model = AutoModelForCausalLM.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention" + ).to(torch_device) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=False) + + self.assertEqual(output_text, EXPECTED_TEXTS) From 8f1fc5eae4c4b709743b4f00e2402ea1a4547cd3 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 19 Nov 2024 13:01:10 +0100 Subject: [PATCH 46/52] update the test --- tests/generation/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index cbe851e97e9a..6c5a192739f3 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1496,7 +1496,7 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature): next_logits_with_padding = model(**model_kwargs).logits[:, -1, :] # They should result in very similar logits - self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-5)) + torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, atol=1e-4, rtol=1e-4) @pytest.mark.generate def test_past_key_values_format(self): From 5be3babde90b439eff85d45e7b707b9cf0a61789 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 19 Nov 2024 13:10:38 +0100 Subject: [PATCH 47/52] fix left padding issue --- src/transformers/models/gemma2/modeling_gemma2.py | 1 + src/transformers/models/gemma2/modular_gemma2.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 771151ca9c98..4b50e87a0c62 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -189,6 +189,7 @@ def eager_attention_forward(config, query, key, value, mask, **_kwargs): attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index d7a21c34d9ae..59c81cc925b3 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -232,6 +232,7 @@ def eager_attention_forward(config, query, key, value, mask, **_kwargs): attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights From 3e5b87a9ed5897e0fc5aa6e82274ba358c3b48af Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 19 Nov 2024 13:10:57 +0100 Subject: [PATCH 48/52] fix test --- tests/generation/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 6c5a192739f3..faaf2fb38c75 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1496,7 +1496,7 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature): next_logits_with_padding = model(**model_kwargs).logits[:, -1, :] # They should result in very similar logits - torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, atol=1e-5, rtol=1e-5) @pytest.mark.generate def test_past_key_values_format(self): From 0513aff070637a2e07f74ca25e25abbd1a8f538b Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 19 Nov 2024 13:12:14 +0100 Subject: [PATCH 49/52] remove duplicate scaling --- src/transformers/models/gemma2/modeling_gemma2.py | 3 --- src/transformers/models/gemma2/modular_gemma2.py | 2 -- 2 files changed, 5 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 4b50e87a0c62..77a1670d6254 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -19,7 +19,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math from typing import List, Optional, Tuple, Union import torch @@ -309,8 +308,6 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True - self.scaling = 1 / math.sqrt(config.head_dim) - self.scaling = config.query_pre_attn_scalar**-0.5 self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None self.attn_logit_softcapping = config.attn_logit_softcapping diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 59c81cc925b3..f90de7193f6b 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -352,8 +352,6 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True - self.scaling = 1 / math.sqrt(config.head_dim) - self.scaling = config.query_pre_attn_scalar**-0.5 self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None self.attn_logit_softcapping = config.attn_logit_softcapping From 480aff811612dd1db19070c21ea0f533483400af Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 19 Nov 2024 13:15:57 +0100 Subject: [PATCH 50/52] quality --- src/transformers/models/gemma2/modular_gemma2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index f90de7193f6b..52d1ffe7e977 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math from typing import Optional, Tuple, Union import torch From 2a765d6e7288fb861041502640813d06513461ce Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 19 Nov 2024 13:36:46 +0100 Subject: [PATCH 51/52] add a small test and make sure it works --- src/transformers/modeling_utils.py | 1 + tests/models/gemma2/test_modeling_gemma2.py | 27 ++++++++++++++++++--- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 0df59d1db8e0..0bbef5265f0d 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1591,6 +1591,7 @@ def _autoset_attn_implementation( "eager", "sdpa", "flash_attention_2", + "flex_attention", ]: message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' if cls._supports_flash_attn_2: diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 9a4ee1951594..4f42a2e97cfc 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -264,9 +264,30 @@ def test_model_9b_pipeline_bf16(self): "Hi today I'm going to be talking about the history of the United States. The United States of America", ] - model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to( - torch_device - ) + model = AutoModelForCausalLM.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention" + ).to(torch_device) + tokenizer = AutoTokenizer.from_pretrained(model_id) + pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) + + output = pipe(self.input_text, max_new_tokens=20, do_sample=False, padding=True) + + self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0]) + self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1]) + + @require_read_token + def test_model_2b_pipeline_bf16_flex_attention(self): + # See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR + model_id = "google/gemma-2-9b" + # EXPECTED_TEXTS should match the same non-pipeline test, minus the special tokens + EXPECTED_TEXTS = [ + "Hello I am doing a project on the 1960s and I am trying to find out what the average", + "Hi today I'm going to be talking about the 10 best anime of all time.\n\n1", + ] + + model = AutoModelForCausalLM.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention" + ).to(torch_device) tokenizer = AutoTokenizer.from_pretrained(model_id) pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) From 6aba68c0b9508b8c8fce025bd5492151258cd57a Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 19 Nov 2024 13:37:05 +0100 Subject: [PATCH 52/52] 2b --- tests/models/gemma2/test_modeling_gemma2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 4f42a2e97cfc..06116c4dbafb 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -278,7 +278,7 @@ def test_model_9b_pipeline_bf16(self): @require_read_token def test_model_2b_pipeline_bf16_flex_attention(self): # See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR - model_id = "google/gemma-2-9b" + model_id = "google/gemma-2-2b" # EXPECTED_TEXTS should match the same non-pipeline test, minus the special tokens EXPECTED_TEXTS = [ "Hello I am doing a project on the 1960s and I am trying to find out what the average",