diff --git a/src/transformers/models/mixtral/configuration_mixtral.py b/src/transformers/models/mixtral/configuration_mixtral.py
index 164988b4dc52..686c214ef25c 100644
--- a/src/transformers/models/mixtral/configuration_mixtral.py
+++ b/src/transformers/models/mixtral/configuration_mixtral.py
@@ -53,6 +53,8 @@ class MixtralConfig(PretrainedConfig):
             converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
             by meanpooling all the original heads within that group. For more details checkout [this
             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
+        head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
+            The attention head dimension.
         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
             The non-linear activation function (function or string) in the decoder.
         max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
@@ -116,6 +118,7 @@ def __init__(
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=8,
+        head_dim=None,
         hidden_act="silu",
         max_position_embeddings=4096 * 32,
         initializer_range=0.02,
@@ -154,6 +157,7 @@ def __init__(
         self.use_cache = use_cache
         self.rope_theta = rope_theta
         self.attention_dropout = attention_dropout
+        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
 
         self.num_experts_per_tok = num_experts_per_tok
         self.num_local_experts = num_local_experts
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 192b7801af05..d189e94d875b 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -283,7 +283,7 @@ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
 
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
+        self.head_dim = config.head_dim
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
@@ -291,11 +291,6 @@ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
         self.is_causal = True
         self.attention_dropout = config.attention_dropout
 
-        if (self.head_dim * self.num_heads) != self.hidden_size:
-            raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                f" and `num_heads`: {self.num_heads})."
-            )
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
@@ -374,7 +369,7 @@ def forward(
             )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.reshape(bsz, q_len, -1)
 
         attn_output = self.o_proj(attn_output)
@@ -481,7 +476,7 @@ def forward(
             is_causal=self.is_causal,
         )
 
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
@@ -575,7 +570,7 @@ def forward(
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.view(bsz, q_len, -1)
 
         attn_output = self.o_proj(attn_output)
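
For reviewers, a minimal sketch of the resulting config behavior. This is illustrative only, not part of the diff, and assumes the stock Mixtral defaults (`hidden_size=4096`, `num_attention_heads=32`):

```python
from transformers import MixtralConfig

# Left unset, head_dim falls back to hidden_size // num_attention_heads
# (4096 // 32 = 128 with the stock Mixtral defaults).
config = MixtralConfig()
assert config.head_dim == config.hidden_size // config.num_attention_heads

# Set explicitly, head_dim is decoupled from hidden_size: num_heads * head_dim
# no longer has to equal hidden_size, which is why the divisibility check in
# MixtralAttention.__init__ is removed above.
custom = MixtralConfig(head_dim=96)
assert custom.head_dim == 96
```

Note the matching change in the three attention forward passes: reshaping the output with `-1` instead of `self.hidden_size` yields `num_heads * head_dim` features, so the output shape no longer assumes `head_dim * num_heads == hidden_size`.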