From 9d42c45e5df733ae5788bd75d597cd6d3df8a207 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Fri, 16 Aug 2024 20:55:04 +0200
Subject: [PATCH 1/7] support head dim

---
 src/transformers/models/llama/configuration_llama.py | 5 ++++-
 src/transformers/models/llama/modeling_llama.py      | 2 +-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py
index 710809093f38..87d11b0cc7e0 100644
--- a/src/transformers/models/llama/configuration_llama.py
+++ b/src/transformers/models/llama/configuration_llama.py
@@ -123,6 +123,8 @@ class LlamaConfig(PretrainedConfig):
             The dropout ratio for the attention probabilities.
         mlp_bias (`bool`, *optional*, defaults to `False`):
             Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
+        head_dim (`int`, *optional*, defaults to None):
+            The attention head dimension. If None, it will default to hidden_size // num_heads
 
     ```python
     >>> from transformers import LlamaModel, LlamaConfig
@@ -163,6 +165,7 @@ def __init__(
         attention_bias=False,
         attention_dropout=0.0,
         mlp_bias=False,
+        head_dim=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -187,7 +190,7 @@ def __init__(
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
         self.mlp_bias = mlp_bias
-
+        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
         # Validate the correctness of rotary position embeddings parameters
         # BC: if there is a 'type' field, move it to 'rope_type'.
         if self.rope_scaling is not None and "type" in self.rope_scaling:
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 4a0887629cc6..51fabba909a5 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -340,7 +340,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
         self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
+        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings

From e7a094b103cfccf65140c0d1980d3894ee397047 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Fri, 16 Aug 2024 20:55:40 +0200
Subject: [PATCH 2/7] fix the doc

---
 src/transformers/models/llama/configuration_llama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py
index 87d11b0cc7e0..435f0091e06e 100644
--- a/src/transformers/models/llama/configuration_llama.py
+++ b/src/transformers/models/llama/configuration_llama.py
@@ -123,7 +123,7 @@ class LlamaConfig(PretrainedConfig):
             The dropout ratio for the attention probabilities.
         mlp_bias (`bool`, *optional*, defaults to `False`):
             Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
-        head_dim (`int`, *optional*, defaults to None):
+        head_dim (`int`, *optional*):
             The attention head dimension. If None, it will default to hidden_size // num_heads
 
     ```python

From f623b89c375f2a9d5d5a7166495003cf3d83ccdd Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Fri, 16 Aug 2024 20:55:50 +0200
Subject: [PATCH 3/7] fixup

---
 src/transformers/models/llama/modeling_llama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 51fabba909a5..9dd6053c87c5 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -340,7 +340,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
         self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings

From 7476abdd4ccf8cfff72135d91efb41caf9f5917b Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Mon, 19 Aug 2024 14:54:40 +0200
Subject: [PATCH 4/7] add oproj

Co-authored-by: Suhara
---
 src/transformers/models/llama/modeling_llama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 9dd6053c87c5..3672b3b0d1f2 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -356,7 +356,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
-        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
 
         # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
         self.rotary_emb = LlamaRotaryEmbedding(config=self.config)

From 10f646a52c16027066790c463f49a6abfed9b1d2 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Mon, 19 Aug 2024 14:58:48 +0200
Subject: [PATCH 5/7] update

Co-authored-by: bzantium
---
 src/transformers/models/llama/modeling_llama.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 3672b3b0d1f2..3e1ca6b8badc 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -347,12 +347,6 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
         self.rope_theta = config.rope_theta
         self.is_causal = True
 
-        if (self.head_dim * self.num_heads) != self.hidden_size:
-            raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                f" and `num_heads`: {self.num_heads})."
-            )
-
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)

From 92d14202ec86095c2f1a90f784a9218bcd201852 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Mon, 19 Aug 2024 15:00:17 +0200
Subject: [PATCH 6/7] Co-authored-by: suhara

---
 src/transformers/models/llama/modeling_flax_llama.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/src/transformers/models/llama/modeling_flax_llama.py b/src/transformers/models/llama/modeling_flax_llama.py
index 1c9f1c4adc3e..26a2c2bb09a3 100644
--- a/src/transformers/models/llama/modeling_flax_llama.py
+++ b/src/transformers/models/llama/modeling_flax_llama.py
@@ -214,12 +214,6 @@ def setup(self):
         self.k_proj = dense(self.num_key_value_heads * self.head_dim)
         self.v_proj = dense(self.num_key_value_heads * self.head_dim)
         self.o_proj = dense(self.embed_dim)
-        if (self.head_dim * self.num_heads) != self.embed_dim:
-            raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.embed_dim}"
-                f" and `num_heads`: {self.num_heads})."
-            )
-
         self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
         self.rotary_emb = FlaxLlamaRotaryEmbedding(config, dtype=self.dtype)
 

From 687aee48455d3c0111d4121420b002f61372a884 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Mon, 19 Aug 2024 15:01:51 +0200
Subject: [PATCH 7/7] Update

Co-authored-by: Yoshi Suhara
---
 src/transformers/models/llama/modeling_llama.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 3e1ca6b8badc..8716d27f5481 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -414,7 +414,6 @@ def forward(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
-
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
         if attention_mask is not None:  # no matter the length, we just slice it
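Taken together, the series decouples the attention head dimension from `hidden_size // num_attention_heads`: the config stores `head_dim` (patches 1-2), `LlamaAttention` reads it with a `getattr` fallback (patches 1 and 3), the q/k/v/o projections are sized from it (patch 4), and the old divisibility check is dropped from the PyTorch and Flax models (patches 5-6). The snippet below is not part of the patch series; it is a minimal sketch of how the new option behaves, assuming a `transformers` build that includes these commits, with made-up config sizes chosen only for illustration.

```python
# Minimal usage sketch (not part of the patches): the sizes below are illustrative only.
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(
    vocab_size=32000,
    hidden_size=3072,
    intermediate_size=9216,
    num_hidden_layers=2,   # kept tiny so the model builds quickly
    num_attention_heads=32,
    num_key_value_heads=8,
    head_dim=128,          # new argument: 32 * 128 = 4096 != hidden_size
)

model = LlamaForCausalLM(config)
attn = model.model.layers[0].self_attn

# q/k/v project hidden_size -> num_heads * head_dim; o_proj maps num_heads * head_dim back to hidden_size.
print(attn.q_proj.weight.shape)  # torch.Size([4096, 3072])
print(attn.k_proj.weight.shape)  # torch.Size([1024, 3072])
print(attn.o_proj.weight.shape)  # torch.Size([3072, 4096])
```

Leaving `head_dim` unset keeps the previous behaviour, since the config falls back to `hidden_size // num_attention_heads`.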