-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #20 from kmn1024/stablelm
Add support for StableLM 3b 4e1t model
- Loading branch information
Showing
9 changed files
with
24,723 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .pos_shift import stablelm_epoch_pos_shift_attention_forward |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
""" | ||
Adapted from https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py | ||
""" | ||
|
||
from typing import Optional, Tuple | ||
import math | ||
|
||
import torch | ||
from torch import nn | ||
|
||
__all__ = ["stablelm_epoch_pos_shift_attention_forward"] | ||
|
||
|
||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: | ||
""" | ||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, | ||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) | ||
""" | ||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape | ||
if n_rep == 1: | ||
return hidden_states | ||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) | ||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) | ||
|
||
|
||
def rotate_half(x: torch.Tensor): | ||
"""Rotates half the hidden dims of the input.""" | ||
x1, x2 = torch.chunk(x, 2, dim=-1) | ||
return torch.cat((-x2, x1), dim=-1) | ||
|
||
|
||
def apply_rotary_pos_emb_single(x, cos, sin, position_ids): | ||
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them. | ||
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] | ||
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] | ||
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] | ||
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] | ||
x_embed = (x * cos) + (rotate_half(x) * sin) | ||
return x_embed | ||
|
||
|
||
def stablelm_epoch_pos_shift_attention_forward( | ||
self, | ||
hidden_states: torch.FloatTensor, | ||
attention_mask: torch.FloatTensor, | ||
position_ids: torch.LongTensor, | ||
past_key_value: Optional[Tuple[torch.Tensor]] = None, | ||
output_attentions: Optional[bool] = False, | ||
use_cache: Optional[bool] = False, | ||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: | ||
bsz, q_len, _ = hidden_states.size() | ||
|
||
query_states = self.q_proj(hidden_states) | ||
key_states = self.k_proj(hidden_states) | ||
value_states = self.v_proj(hidden_states) | ||
|
||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) | ||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) | ||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) | ||
|
||
query_rot = query_states[..., : self.rotary_ndims] | ||
query_pass = query_states[..., self.rotary_ndims :] | ||
|
||
kv_seq_len = key_states.shape[-2] | ||
if past_key_value is not None: | ||
kv_seq_len += past_key_value[0].shape[-2] | ||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) | ||
query_states = apply_rotary_pos_emb_single(query_rot, cos, sin, position_ids) | ||
# [batch_size, num_heads, seq_len, head_dim] | ||
query_states = torch.cat((query_states, query_pass), dim=-1) | ||
|
||
if past_key_value is not None: | ||
# Reuse k, v, self_attention | ||
key_states = torch.cat((past_key_value[0], key_states), dim=2) | ||
value_states = torch.cat((past_key_value[1], value_states), dim=2) | ||
|
||
past_key_value = (key_states, value_states) if use_cache else None | ||
|
||
key_rot = key_states[..., : self.rotary_ndims] | ||
key_pass = key_states[..., self.rotary_ndims :] | ||
key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0) | ||
key_states = apply_rotary_pos_emb_single(key_rot, cos, sin, key_position_ids) | ||
# [batch_size, num_heads, seq_len, head_dim] | ||
key_states = torch.cat((key_states, key_pass), dim=-1) | ||
|
||
# Repeat k/v heads if n_kv_heads < n_heads | ||
key_states = repeat_kv(key_states, self.num_key_value_groups) | ||
value_states = repeat_kv(value_states, self.num_key_value_groups) | ||
|
||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) | ||
|
||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): | ||
raise ValueError( | ||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" | ||
f" {attn_weights.size()}" | ||
) | ||
|
||
if attention_mask is not None: | ||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): | ||
raise ValueError( | ||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" | ||
) | ||
attn_weights = attn_weights + attention_mask | ||
|
||
# Upcast attention to fp32 | ||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) | ||
attn_output = torch.matmul(attn_weights, value_states) | ||
|
||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): | ||
raise ValueError( | ||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" | ||
f" {attn_output.size()}" | ||
) | ||
|
||
# Merge heads | ||
attn_output = attn_output.transpose(1, 2).contiguous() | ||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) | ||
|
||
# Final linear projection | ||
attn_output = self.o_proj(attn_output) | ||
|
||
if not output_attentions: | ||
attn_weights = None | ||
|
||
return attn_output, attn_weights, past_key_value |
Oops, something went wrong.