Merge pull request #20 from kmn1024/stablelm
Add support for StableLM 3b 4e1t model
tomaarsen authored Oct 19, 2023
2 parents fc33531 + e2658e4 commit 52d4917
Showing 9 changed files with 24,723 additions and 4 deletions.
14 changes: 10 additions & 4 deletions README.md
@@ -17,8 +17,8 @@ The following figures plot model perplexities under the various different approaches
| ![mpt_7b_ppl_vram_plotted](https://github.com/mit-han-lab/streaming-llm/assets/37621491/c96cff66-92a3-43ab-bc21-40232f2740a0) | ![pythia_6 8b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/b0fee168-fa5a-457d-9e27-8395eb6dfb38) |
| **Mistral-7B-v0.1** | **GPT-J-6B** |
| ![mistral_7b_ppl_vram_plotted](https://github.com/microsoft/torchscale/assets/37621491/3a4c5634-cc1b-42d1-a35a-afb376a4f970) | ![gpt_j_6b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/bdca944f-2fd2-46c4-8a88-2e1a8f16f75f) |
| **Qwen-7B** | |
| ![qwen_7b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/ecf8beaf-7f8b-4412-bdcc-1d7f78b265bd) | |
| **Qwen-7B** | **StableLM-3B-4E1T** |
| ![qwen_7b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/ecf8beaf-7f8b-4412-bdcc-1d7f78b265bd) | ![stablelm_3b_4e1t_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/d1170e63-870a-404c-99a1-03eebd62422e) |

The results are clear as day:
1. `transformers`: The VRAM usage is linear, as it doesn't do any windowing. Performance falls sharply beyond the pretraining length.
@@ -61,7 +61,7 @@ This repository is an open-source implementation of the [Efficient Streaming Lan

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
```
* Support for Llama, Mistral, Falcon, MPT, GPTNeoX (Pythia), GPT-J and Qwen models.
* Support for Llama, Mistral, Falcon, MPT, GPTNeoX (Pythia), GPT-J, Qwen and StableLM_epoch models.
* New parameters to `AutoModelForCausalLM.from_pretrained`:
* `attention_sink_size`, `int`, defaults to 4: The number of initial tokens to use as the attention sink. These tokens are always included in the Attention Sink KV Cache.
* `attention_sink_window_size`, `int`, defaults to 1020: The size of the sliding window, i.e. the number of "recent tokens" to include in the Attention Sink KV Cache. A larger window size costs more memory. (A usage sketch for both parameters follows below.)
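
For instance, here is a hypothetical sketch of how these two parameters might be passed for the newly supported model; the model id `stabilityai/stablelm-3b-4e1t` and the `trust_remote_code` flag are assumptions for illustration, not part of this diff:

```python
from attention_sinks import AutoModelForCausalLM

# Hypothetical example: 4 sink tokens + a 1020-token sliding window
# = a 1024-token Attention Sink KV Cache.
model = AutoModelForCausalLM.from_pretrained(
    "stabilityai/stablelm-3b-4e1t",
    device_map="auto",
    attention_sink_size=4,
    attention_sink_window_size=1020,
    trust_remote_code=True,  # assumption: the StableLM-Epoch checkpoint ships custom modeling code
)
```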
@@ -75,7 +75,7 @@ pip install attention_sinks
```

### Usage
Loading any Llama, Mistral, Falcon, MPT, GPTNeoX (Pythia), GPT-J and Qwen is as simple as loading it in `transformers`, the only change is that the model class must be imported from `attention_sinks` rather than `transformers`, e.g.:
Loading any Llama, Mistral, Falcon, MPT, GPTNeoX (Pythia), GPT-J, Qwen or StableLM_epoch model is as simple as loading it in `transformers`; the only change is that the model class must be imported from `attention_sinks` rather than `transformers`, e.g.:
```python
from attention_sinks import AutoModelForCausalLM

@@ -265,6 +265,12 @@ See [CHANGELOG.md](CHANGELOG.md) for all release information.

Inspired by, and adapted from [StreamingLLM](https://github.com/mit-han-lab/streaming-llm).

### Model Contributions
A big thanks to the following contributors for extending the model support of `attention_sinks`!

* [@Sanster](https://github.com/Sanster) for adding support for QWen models.
* [@kmn1024](https://github.com/kmn1024) for adding support for StableLM_Epoch models.

### Citation

```bibtex
5 changes: 5 additions & 0 deletions attention_sinks/inject_mixin.py
@@ -17,6 +17,7 @@
"gptj": "GPTJModel",
"mistral": "MistralModel",
"qwen": "QWenModel",
"stablelm_epoch": "StableLMEpochModel"
}
ATTENTION_NAME_MAPPING = {
"llama": "LlamaAttention",
@@ -26,6 +27,7 @@
"gptj": "GPTJAttention",
"mistral": "MistralAttention",
"qwen": "QWenAttention",
"stablelm_epoch": "Attention",
}
KV_DIM_MAPPING = {
"llama": (2, 2),
@@ -35,6 +37,7 @@
"gptj": (2, 2),
"mistral": (2, 2),
"qwen": (1, 1),
"stablelm_epoch": (2, 2),
}
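
For context, these `(key, value)` pairs appear to record which axis of the cached key/value tensors is the sequence axis, so that cache eviction can slice along the right dimension for each architecture. A minimal, hypothetical sketch of such an eviction step (the function name and signature are illustrative, not the library's API):

```python
import torch

def evict_along(cache: torch.Tensor, seq_dim: int, sink_size: int = 4, window_size: int = 1020) -> torch.Tensor:
    """Keep the first `sink_size` and the last `window_size` entries along `seq_dim`."""
    seq_len = cache.size(seq_dim)
    if seq_len <= sink_size + window_size:
        return cache
    sinks = cache.narrow(seq_dim, 0, sink_size)                       # attention-sink tokens
    recent = cache.narrow(seq_dim, seq_len - window_size, window_size)  # sliding window
    return torch.cat([sinks, recent], dim=seq_dim)
```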


@@ -88,6 +91,7 @@ def _inject_pos_shift_attention(cls, model: PreTrainedModel) -> Optional[int]:
llama_pos_shift_attention_forward,
mistral_pos_shift_attention_forward,
qwen_pos_shift_attention_forward,
stablelm_epoch_pos_shift_attention_forward,
)

ATTENTION_FORWARD_MAPPING = {
@@ -98,6 +102,7 @@ def _inject_pos_shift_attention(cls, model: PreTrainedModel) -> Optional[int]:
"gptj": gptj_pos_shift_attention_forward,
"mistral": mistral_pos_shift_attention_forward,
"qwen": qwen_pos_shift_attention_forward,
"stablelm_epoch": stablelm_epoch_pos_shift_attention_forward,
}

# Not all models require updated attention forwards
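
The four mappings above key everything off the model-type string (`"stablelm_epoch"` in this PR), so adding an architecture amounts to new dictionary entries plus a position-shifting forward. The real entry point shown in this diff is `_inject_pos_shift_attention`; below is only a hypothetical sketch of the dispatch pattern these mappings enable (names like `inject_pos_shift` are illustrative, not the library's API):

```python
from types import MethodType

import torch.nn as nn


def inject_pos_shift(model: nn.Module, model_type: str,
                     attention_name_mapping: dict, attention_forward_mapping: dict) -> nn.Module:
    """Illustrative only: swap each matching attention module's forward
    for the position-shifting variant registered for this model type."""
    attention_cls_name = attention_name_mapping[model_type]  # e.g. "Attention" for "stablelm_epoch"
    shifted_forward = attention_forward_mapping[model_type]  # e.g. stablelm_epoch_pos_shift_attention_forward
    for module in model.modules():
        if module.__class__.__name__ == attention_cls_name:
            # Bind the replacement forward to this module instance.
            module.forward = MethodType(shifted_forward, module)
    return model
```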
1 change: 1 addition & 0 deletions attention_sinks/models/__init__.py
@@ -47,3 +47,4 @@
MptPreTrainedModel,
)
from .qwen import qwen_pos_shift_attention_forward
from .stablelm_epoch import stablelm_epoch_pos_shift_attention_forward
1 change: 1 addition & 0 deletions attention_sinks/models/stablelm_epoch/__init__.py
@@ -0,0 +1 @@
from .pos_shift import stablelm_epoch_pos_shift_attention_forward
125 changes: 125 additions & 0 deletions attention_sinks/models/stablelm_epoch/pos_shift.py
@@ -0,0 +1,125 @@
"""
Adapted from https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
"""

from typing import Optional, Tuple
import math

import torch
from torch import nn

__all__ = ["stablelm_epoch_pos_shift_attention_forward"]


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
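

# Editorial illustration, not part of this diff: with grouped-query attention,
# e.g. 8 KV heads repeated n_rep=4 times, a key tensor of shape (1, 8, 16, 64)
# becomes (1, 32, 16, 64) so it lines up with 32 query heads.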


def rotate_half(x: torch.Tensor):
    """Rotates half the hidden dims of the input."""
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed


def stablelm_epoch_pos_shift_attention_forward(
    self,
    hidden_states: torch.FloatTensor,
    attention_mask: torch.FloatTensor,
    position_ids: torch.LongTensor,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    query_rot = query_states[..., : self.rotary_ndims]
    query_pass = query_states[..., self.rotary_ndims :]

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states = apply_rotary_pos_emb_single(query_rot, cos, sin, position_ids)
    # [batch_size, num_heads, seq_len, head_dim]
    query_states = torch.cat((query_states, query_pass), dim=-1)

    if past_key_value is not None:
        # Reuse k, v, self_attention
        key_states = torch.cat((past_key_value[0], key_states), dim=2)
        value_states = torch.cat((past_key_value[1], value_states), dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    key_rot = key_states[..., : self.rotary_ndims]
    key_pass = key_states[..., self.rotary_ndims :]
    key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
    key_states = apply_rotary_pos_emb_single(key_rot, cos, sin, key_position_ids)
    # [batch_size, num_heads, seq_len, head_dim]
    key_states = torch.cat((key_states, key_pass), dim=-1)

    # Repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
            f" {attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        attn_weights = attn_weights + attention_mask

    # Upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    # Merge heads
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    # Final linear projection
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
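
The key trick in the forward above is that the keys are re-rotated with their index in the KV cache (`torch.arange(kv_seq_len)`) rather than their original absolute positions, so the RoPE indices never grow beyond the cache size however long generation runs. A tiny, self-contained illustration of that bookkeeping (the sink and window sizes are the README defaults; the rest is hypothetical):

```python
import torch

# README defaults: 4 attention-sink tokens plus a 1020-token sliding window.
attention_sink_size, window_size = 4, 1020
cache_capacity = attention_sink_size + window_size  # 1024 cached key/value pairs

# No matter how many tokens have been generated, the cached sequence length
# (and therefore the largest RoPE index the keys ever see) stays bounded.
tokens_generated = 50_000
kv_seq_len = min(tokens_generated, cache_capacity)
key_position_ids = torch.arange(kv_seq_len).unsqueeze(0)  # shape [1, kv_seq_len]
print(kv_seq_len, key_position_ids.max().item())  # 1024 1023
```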