Refactor Cohere Model #30027
```diff
@@ -76,10 +76,10 @@ def _get_unpad_data(attention_mask):
 class CohereLayerNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-5, bias=False):
+    def __init__(self, hidden_size=None, eps=1e-5, bias=False):
+        """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim"""
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.bias = nn.Parameter(torch.zeros(hidden_size)) if bias else None
         self.variance_epsilon = eps

     def forward(self, hidden_states):
@@ -89,8 +89,6 @@ def forward(self, hidden_states):
         variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
         hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
         hidden_states = self.weight.to(torch.float32) * hidden_states
-        if self.bias is not None:
-            hidden_states = hidden_states + self.bias.to(torch.float32)
         return hidden_states.to(input_dtype)
```
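To make the tuple-valued `hidden_size` concrete, here is a small standalone sketch (PyTorch only; the class name and test values are mine, not part of the PR) that mirrors the bias-free LayerNorm above and applies it per attention head, which is how the QK-norm path uses it:

```python
import torch
from torch import nn


class LayerNormSketch(nn.Module):
    """Mirrors the diff above: no bias, stats over the last dim, weight shaped like `hidden_size`."""

    def __init__(self, hidden_size=None, eps=1e-5):
        super().__init__()
        # hidden_size can be an int (e.g. 8192) or a tuple like (num_heads, head_dim);
        # a tuple-shaped weight broadcasts over the last two dimensions of the input.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        mean = hidden_states.mean(-1, keepdim=True)
        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
        hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = self.weight.to(torch.float32) * hidden_states
        return hidden_states.to(input_dtype)


# QK-norm style usage: per-head normalization of a (batch, seq, num_heads, head_dim) tensor.
q = torch.randn(2, 4, 8, 64, dtype=torch.float16)
print(LayerNormSketch(hidden_size=(8, 64))(q).shape)  # torch.Size([2, 4, 8, 64]), dtype preserved
```

Note that the statistics are still taken over the last dimension only; the tuple changes the weight's shape, not the normalization axis.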
```diff
@@ -122,7 +120,7 @@ def forward(self, x, position_ids):
         emb = torch.repeat_interleave(freqs, 2, dim=-1)
         cos = emb.cos()
         sin = emb.sin()
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+        return cos, sin


 def rotate_half(x):
@@ -133,7 +131,6 @@ def rotate_half(x):
     return rot_x


-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
```
```diff
@@ -154,11 +151,14 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
+    dtype = q.dtype
+    q = q.float()
+    k = k.float()
     cos = cos.unsqueeze(unsqueeze_dim)
     sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
+    return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
```
Comment on lines +154 to +161
Collaborator: if this is done outside …
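The reviewer's point is about where the float32 upcast should live. As a hedged illustration (standalone PyTorch; the interleaved `rotate_half` is inferred from the `repeat_interleave` in the hunk above, not copied from the PR), the upcast-rotate-downcast pattern the hunk introduces looks like this:

```python
import torch


def rotate_half(x):
    # Interleaved rotation: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    dtype = q.dtype
    # Upcast so the rotation is computed in float32 even for fp16/bf16 activations ...
    q, k = q.float(), k.float()
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    # ... then cast back to the caller's dtype on the way out.
    return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)


bsz, heads, seq, dim = 1, 2, 4, 8
q = torch.randn(bsz, heads, seq, dim, dtype=torch.float16)
k = torch.randn(bsz, heads, seq, dim, dtype=torch.float16)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.outer(torch.arange(seq).float(), inv_freq)   # (seq, dim // 2)
emb = torch.repeat_interleave(freqs, 2, dim=-1)[None]      # (1, seq, dim), kept in float32
q_rot, k_rot = apply_rotary_pos_emb(q, k, emb.cos(), emb.sin())
print(q_rot.dtype, k_rot.dtype)  # torch.float16 torch.float16
```

This matches the diff: the rotary embedding now returns float32 `cos`/`sin`, and the cast back to the activation dtype happens inside `apply_rotary_pos_emb`.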
```diff
 # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
@@ -192,7 +192,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-# Copied from transformers.models.llama.modeling_llama.LlamaAttention Llama->Cohere
 class CohereAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
```
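For readers skimming the hunk above, `repeat_kv` is the grouped-query-attention helper whose last line appears as context: each KV head is shared by `n_rep` query heads, so the KV states are expanded before the attention matmul. A self-contained sketch in the usual Llama-style form (reproduced for illustration, not quoted verbatim from this file):

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (batch, num_key_value_heads, seq, head_dim) to
    # (batch, num_key_value_heads * n_rep, seq, head_dim) without copying until reshape.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


kv = torch.randn(1, 2, 5, 4)       # 2 KV heads
print(repeat_kv(kv, 4).shape)      # torch.Size([1, 8, 5, 4]) — expanded to 8 query heads
```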
```diff
@@ -216,13 +215,21 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None):
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.is_causal = True
+        self.use_qk_norm = config.use_qk_norm

         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                 f" and `num_heads`: {self.num_heads})."
             )

+        if self.use_qk_norm:
+            # When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads
+            self.q_norm = CohereLayerNorm(hidden_size=(self.num_heads, self.head_dim), eps=config.layer_norm_eps)
+            self.k_norm = CohereLayerNorm(
+                hidden_size=(self.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps
+            )

         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
```

Review discussion on the Tensor Parallelism comment:

Collaborator: Not sure if this comment is relevant, as the model is not sharded by default.

Contributor (Author): Oh, this is to warn others who port this to other frameworks from HF.
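A quick sanity-check sketch of the new flag, assuming a `transformers` build that already contains this PR (the tiny config values are mine, chosen only to keep the model small):

```python
from transformers import CohereConfig, CohereForCausalLM

config = CohereConfig(
    vocab_size=64,
    hidden_size=128,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    use_qk_norm=True,  # the flag wired up in this hunk
)
model = CohereForCausalLM(config)

# q_norm is built with hidden_size=(num_heads, head_dim), so its weight should be (4, 32) here.
print(model.model.layers[0].self_attn.q_norm.weight.shape)
```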
```diff
@@ -255,8 +262,14 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)

-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        if self.use_qk_norm:
+            query_states = self.q_norm(query_states)
+            key_states = self.k_norm(key_states)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         past_key_value = getattr(self, "past_key_value", past_key_value)
```
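The reordering above matters for shapes: QK-norm runs while queries and keys are still `(bsz, q_len, num_heads, head_dim)`, so the `(num_heads, head_dim)` norm weight lines up with the trailing dimensions; only afterwards are the tensors transposed into the attention layout. A tiny illustration (values are arbitrary):

```python
import torch

bsz, q_len, num_heads, head_dim = 2, 6, 4, 32
query_states = torch.randn(bsz, q_len, num_heads * head_dim)

query_states = query_states.view(bsz, q_len, num_heads, head_dim)  # layout the per-head norm expects
# ... self.q_norm(query_states) would run here when use_qk_norm is enabled ...
query_states = query_states.transpose(1, 2)                        # (bsz, num_heads, q_len, head_dim)
print(query_states.shape)  # torch.Size([2, 4, 6, 32])
```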
```diff
@@ -335,11 +348,14 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)

-        # Flash attention requires the input to have the shape
-        # batch_size x seq_length x head_dim x hidden_dim
-        # therefore we just need to keep the original shape
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        if self.use_qk_norm:
+            query_states = self.q_norm(query_states)
+            key_states = self.k_norm(key_states)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         cos, sin = self.rotary_emb(value_states, position_ids)
```
```diff
@@ -505,7 +521,7 @@ class CohereSdpaAttention(CohereAttention):
     SDPA API.
     """

-    # Adapted from CohereAttention.forward
+    # Ignore copy
     def forward(
         self,
         hidden_states: torch.Tensor,
```
```diff
@@ -538,8 +554,14 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)

-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        if self.use_qk_norm:
+            query_states = self.q_norm(query_states)
+            key_states = self.k_norm(key_states)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         cos, sin = self.rotary_emb(value_states, position_ids)
```
```diff
@@ -599,7 +621,7 @@ def __init__(self, config: CohereConfig, layer_idx: int):
         self.self_attn = COHERE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

         self.mlp = CohereMLP(config)
-        self.input_layernorm = CohereLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)

     def forward(
         self,
```
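Why only `input_layernorm` here? Cohere's decoder layer uses a parallel residual block, so a single pre-norm feeds both the attention and MLP branches. A hedged sketch of that pattern (illustrative modules, not the PR's exact forward):

```python
import torch
from torch import nn


def parallel_residual_block(hidden_states, input_layernorm, self_attn, mlp):
    residual = hidden_states
    normed = input_layernorm(hidden_states)
    # One shared pre-norm; attention and MLP outputs are summed onto the residual.
    return residual + self_attn(normed) + mlp(normed)


d_model = 16
x = torch.randn(2, 5, d_model)
out = parallel_residual_block(x, nn.LayerNorm(d_model), nn.Linear(d_model, d_model), nn.Linear(d_model, d_model))
print(out.shape)  # torch.Size([2, 5, 16])
```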
```diff
@@ -822,7 +844,7 @@ def __init__(self, config: CohereConfig):
         self.layers = nn.ModuleList(
             [CohereDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
-        self.norm = CohereLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.norm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
         self.gradient_checkpointing = False

         # Initialize weights and apply final processing
```