diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py
index 710809093f38..ad11aea5138a 100644
--- a/src/transformers/models/llama/configuration_llama.py
+++ b/src/transformers/models/llama/configuration_llama.py
@@ -123,6 +123,9 @@ class LlamaConfig(PretrainedConfig):
             The dropout ratio for the attention probabilities.
         mlp_bias (`bool`, *optional*, defaults to `False`):
             Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
+        head_dim (`int`, *optional*):
+            The attention head dimension. If None, it will default to hidden_size // num_heads
+
 
     ```python
     >>> from transformers import LlamaModel, LlamaConfig
@@ -163,6 +166,7 @@ def __init__(
         attention_bias=False,
         attention_dropout=0.0,
         mlp_bias=False,
+        head_dim=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -187,6 +191,10 @@ def __init__(
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
         self.mlp_bias = mlp_bias
+        if head_dim is None:
+            self.head_dim = self.hidden_size // self.num_attention_heads
+        else:
+            self.head_dim = head_dim
 
         # Validate the correctness of rotary position embeddings parameters
         # BC: if there is a 'type' field, move it to 'rope_type'.
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 4a0887629cc6..3e1ca6b8badc 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -340,23 +340,17 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
         self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
+        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.is_causal = True
 
-        if (self.head_dim * self.num_heads) != self.hidden_size:
-            raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                f" and `num_heads`: {self.num_heads})."
-            )
-
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
-        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
 
         # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
         self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
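
For context, here is a minimal sketch (not part of the patch) of what the change enables: a `LlamaConfig` whose `head_dim` is decoupled from `hidden_size // num_attention_heads`, which previously was the only possible value. The concrete sizes below are made up for illustration, and the snippet assumes a transformers build that already includes this diff.

```python
from transformers import LlamaConfig, LlamaModel

# Illustrative sizes only. Before this patch, head_dim was always computed as
# hidden_size // num_attention_heads (here 2048 // 16 = 128) and could not be overridden.
config = LlamaConfig(
    vocab_size=1000,
    hidden_size=2048,
    intermediate_size=4096,
    num_hidden_layers=2,
    num_attention_heads=16,
    num_key_value_heads=4,
    head_dim=256,  # decoupled from hidden_size // num_attention_heads
)

model = LlamaModel(config)
attn = model.layers[0].self_attn

# q_proj maps hidden_size -> num_heads * head_dim (2048 -> 4096) and o_proj maps
# num_heads * head_dim -> hidden_size (4096 -> 2048), matching the updated nn.Linear shapes.
print(attn.q_proj.weight.shape)  # torch.Size([4096, 2048])
print(attn.o_proj.weight.shape)  # torch.Size([2048, 4096])
```

Note that the modeling change reads the new field via `getattr(config, "head_dim", ...)`, so configs serialized before this diff (which have no `head_dim` attribute) keep the old `hidden_size // num_heads` behavior.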