Add support for StableLM 3b 4e1t model #20

Merged · 3 commits · Oct 19, 2023
README.md (14 changes: 10 additions & 4 deletions)
@@ -17,8 +17,8 @@ The following figures plot model perplexities under the various different approa
| ![mpt_7b_ppl_vram_plotted](https://github.com/mit-han-lab/streaming-llm/assets/37621491/c96cff66-92a3-43ab-bc21-40232f2740a0) | ![pythia_6 8b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/b0fee168-fa5a-457d-9e27-8395eb6dfb38) |
| **Mistral-7B-v0.1** | **GPT-J-6B** |
| ![mistral_7b_ppl_vram_plotted](https://github.com/microsoft/torchscale/assets/37621491/3a4c5634-cc1b-42d1-a35a-afb376a4f970) | ![gpt_j_6b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/bdca944f-2fd2-46c4-8a88-2e1a8f16f75f) |
| **Qwen-7B** | |
| ![qwen_7b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/ecf8beaf-7f8b-4412-bdcc-1d7f78b265bd) | |
| **Qwen-7B** | **StableLM-3B-4E1T** |
| ![qwen_7b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/ecf8beaf-7f8b-4412-bdcc-1d7f78b265bd) | ![stablelm_3b_4e1t_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/d1170e63-870a-404c-99a1-03eebd62422e) |

The results are clear as day:
1. `transformers`: The VRAM usage is linear as it doesn't do any windowing. The performance heavily falls after the pretraining length.
@@ -61,7 +61,7 @@ This repository is an open-source implementation of the [Efficient Streaming Lan

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
```
* Support for Llama, Mistral, Falcon, MPT, GPTNeoX (Pythia), GPT-J and Qwen models.
* Support for Llama, Mistral, Falcon, MPT, GPTNeoX (Pythia), GPT-J, Qwen and StableLM_epoch models.
* New parameters to `AutoModelForCausalLM.from_pretrained`:
* `attention_sink_size`, `int`, defaults to 4: The number of initial tokens to use as the attention sink. These tokens are always included in the Attention Sink KV Cache.
* `attention_sink_window_size`, `int`, defaults to 1020: The size of the sliding window, i.e. the number of "recent tokens" to include in the Attention Sink KV Cache. A larger window size costs more memory.
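
As a minimal sketch of how these two parameters combine (the values below are simply the documented defaults, and the checkpoint name reuses the Llama-2 example above):

```python
from attention_sinks import AutoModelForCausalLM

# 4 sink tokens + a 1020-token sliding window = at most 1024 cached tokens
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    device_map="auto",
    attention_sink_size=4,            # always keep the first 4 tokens
    attention_sink_window_size=1020,  # plus the 1020 most recent tokens
)
```
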
@@ -75,7 +75,7 @@ pip install attention_sinks
```

### Usage
Loading any Llama, Mistral, Falcon, MPT, GPTNeoX (Pythia), GPT-J and Qwen is as simple as loading it in `transformers`, the only change is that the model class must be imported from `attention_sinks` rather than `transformers`, e.g.:
Loading any Llama, Mistral, Falcon, MPT, GPTNeoX (Pythia), GPT-J, Qwen or StableLM_epoch model is as simple as loading it in `transformers`; the only change is that the model class must be imported from `attention_sinks` rather than `transformers`, e.g.:
```python
from attention_sinks import AutoModelForCausalLM

@@ -265,6 +265,12 @@ See [CHANGELOG.md](CHANGELOG.md) for all release information.

Inspired by, and adapted from [StreamingLLM](https://github.com/mit-han-lab/streaming-llm).

### Model Contributions
A big thanks to the following contributors for extending the model support of `attention_sinks`!

* [@Sanster](https://github.com/Sanster) for adding support for QWen models.
* [@kmn1024](https://github.com/kmn1024) for adding support for StableLM_Epoch models.

### Citation

```bibtex
attention_sinks/inject_mixin.py (5 changes: 5 additions & 0 deletions)
@@ -17,6 +17,7 @@
"gptj": "GPTJModel",
"mistral": "MistralModel",
"qwen": "QWenModel",
"stablelm_epoch": "StableLMEpochModel"
}
ATTENTION_NAME_MAPPING = {
"llama": "LlamaAttention",
@@ -26,6 +27,7 @@
"gptj": "GPTJAttention",
"mistral": "MistralAttention",
"qwen": "QWenAttention",
"stablelm_epoch": "Attention",
}
KV_DIM_MAPPING = {
"llama": (2, 2),
@@ -35,6 +37,7 @@
"gptj": (2, 2),
"mistral": (2, 2),
"qwen": (1, 1),
"stablelm_epoch": (2, 2),
}


@@ -88,6 +91,7 @@ def _inject_pos_shift_attention(cls, model: PreTrainedModel) -> Optional[int]:
llama_pos_shift_attention_forward,
mistral_pos_shift_attention_forward,
qwen_pos_shift_attention_forward,
stablelm_epoch_pos_shift_attention_forward,
)

ATTENTION_FORWARD_MAPPING = {
@@ -98,6 +102,7 @@ def _inject_pos_shift_attention(cls, model: PreTrainedModel) -> Optional[int]:
"gptj": gptj_pos_shift_attention_forward,
"mistral": mistral_pos_shift_attention_forward,
"qwen": qwen_pos_shift_attention_forward,
"stablelm_epoch": stablelm_epoch_pos_shift_attention_forward,
}

# Not all models require updated attention forwards
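
Taken together, these hunks are the whole registration surface a new architecture needs: one entry in each lookup table, keyed by the model type string. The snippet below is an illustrative consolidation of the `stablelm_epoch` entries shown above into one hypothetical dict (the dict name and its keys are not part of the PR; the values are copied from the hunks):

```python
from attention_sinks.models import stablelm_epoch_pos_shift_attention_forward

# One field per lookup table in inject_mixin.py:
stablelm_epoch_registration = {
    "model_class": "StableLMEpochModel",   # MODEL_NAME_MAPPING entry
    "attention_class": "Attention",        # ATTENTION_NAME_MAPPING entry
    "kv_seq_dims": (2, 2),                 # KV_DIM_MAPPING entry
    "pos_shift_forward": stablelm_epoch_pos_shift_attention_forward,  # ATTENTION_FORWARD_MAPPING entry
}
```
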
attention_sinks/models/__init__.py (1 change: 1 addition & 0 deletions)
@@ -47,3 +47,4 @@
MptPreTrainedModel,
)
from .qwen import qwen_pos_shift_attention_forward
from .stablelm_epoch import stablelm_epoch_pos_shift_attention_forward
attention_sinks/models/stablelm_epoch/__init__.py (1 change: 1 addition & 0 deletions)
@@ -0,0 +1 @@
from .pos_shift import stablelm_epoch_pos_shift_attention_forward
attention_sinks/models/stablelm_epoch/pos_shift.py (125 changes: 125 additions & 0 deletions)
@@ -0,0 +1,125 @@
"""
Adapted from https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
"""

from typing import Optional, Tuple
import math

import torch
from torch import nn

__all__ = ["stablelm_epoch_pos_shift_attention_forward"]


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def rotate_half(x: torch.Tensor):
"""Rotates half the hidden dims of the input."""
x1, x2 = torch.chunk(x, 2, dim=-1)
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
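    # Rotates a single tensor at a time (rather than q and k together), since under
    # position shifting queries and keys are rotated with different position ids.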
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
x_embed = (x * cos) + (rotate_half(x) * sin)
return x_embed


def stablelm_epoch_pos_shift_attention_forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: torch.FloatTensor,
position_ids: torch.LongTensor,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
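    # Key/value projections use num_key_value_heads, which may be smaller than
    # num_heads (grouped-query attention); repeat_kv below expands them back.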

query_rot = query_states[..., : self.rotary_ndims]
query_pass = query_states[..., self.rotary_ndims :]
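    # (The split above implements partial rotary embeddings: only the first
    # rotary_ndims channels are rotated; the remaining channels pass through.)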

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states = apply_rotary_pos_emb_single(query_rot, cos, sin, position_ids)
# [batch_size, num_heads, seq_len, head_dim]
query_states = torch.cat((query_states, query_pass), dim=-1)

if past_key_value is not None:
# Reuse k, v, self_attention
key_states = torch.cat((past_key_value[0], key_states), dim=2)
value_states = torch.cat((past_key_value[1], value_states), dim=2)

past_key_value = (key_states, value_states) if use_cache else None

key_rot = key_states[..., : self.rotary_ndims]
key_pass = key_states[..., self.rotary_ndims :]
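    # Position shift: the KV cache stores keys from *before* rotary embedding, so they
    # are re-rotated here with consecutive in-cache positions (0 .. kv_seq_len - 1)
    # rather than absolute text positions. This lets the sliding window evict old
    # tokens without invalidating the rotary phases of the keys that remain.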
key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
key_states = apply_rotary_pos_emb_single(key_rot, cos, sin, key_position_ids)
# [batch_size, num_heads, seq_len, head_dim]
key_states = torch.cat((key_states, key_pass), dim=-1)

# Repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask

# Upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)

# Merge heads
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

# Final linear projection
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value
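
With the forward above in place, loading the new model type through `attention_sinks` should look just like the other supported architectures. A minimal sketch (the checkpoint id is taken from the header of the adapted file; `trust_remote_code=True` is an assumption here, since the StableLM-Epoch modeling code ships with the checkpoint rather than inside `transformers`):

```python
from attention_sinks import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "stabilityai/stablelm-3b-4e1t",
    device_map="auto",
    trust_remote_code=True,           # assumption: modeling code is loaded from the Hub
    attention_sink_size=4,            # defaults shown explicitly for clarity
    attention_sink_window_size=1020,
)
```
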