Merge pull request #15 from Sanster/main
Add QWen model + benchmark results
tomaarsen authored Oct 17, 2023
2 parents 48bb293 + 6120a47 commit fc33531
Showing 12 changed files with 41,100 additions and 8 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,16 @@ Types of changes
* "Security" in case of vulnerabilities.
-->

## [Unreleased]

### Added

- Added support for Qwen models. ([#15](https://github.com/tomaarsen/attention_sinks/pull/15))

### Changed

- Changed how Attention Sinks are injected into models; this allows `attention_sinks` to be integrated with architectures that aren't in `transformers`. ([#16](https://github.com/tomaarsen/attention_sinks/pull/16))

## [0.2.3] - 2023-10-10

### Added
17 changes: 9 additions & 8 deletions README.md
@@ -17,10 +17,12 @@ The following figures plot model perplexities under the various different approa
| ![mpt_7b_ppl_vram_plotted](https://github.com/mit-han-lab/streaming-llm/assets/37621491/c96cff66-92a3-43ab-bc21-40232f2740a0) | ![pythia_6 8b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/b0fee168-fa5a-457d-9e27-8395eb6dfb38) |
| **Mistral-7B-v0.1** | **GPT-J-6B** |
| ![mistral_7b_ppl_vram_plotted](https://github.com/microsoft/torchscale/assets/37621491/3a4c5634-cc1b-42d1-a35a-afb376a4f970) | ![gpt_j_6b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/bdca944f-2fd2-46c4-8a88-2e1a8f16f75f) |
| **Qwen-7B** | |
| ![qwen_7b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/ecf8beaf-7f8b-4412-bdcc-1d7f78b265bd) | |

The results are clear as day:
-1. `transformers`: The VRAM usage is linear as it doesn't do any windowing. The performance falls heavily after ~4096 tokens.
-2. `windowed`: The VRAM usage is constant due to the windowing at 1024 tokens. However, it fails as soon as the first tokens leave the window.
+1. `transformers`: The VRAM usage is linear as it doesn't do any windowing. The performance falls heavily once the input exceeds the pretraining length.
+2. `windowed`: The VRAM usage is constant due to the windowing at 1024 tokens. However, the performance falls as soon as the first tokens leave the window.
3. `attention_sinks`: Constant VRAM usage due to windowing with 4 attention sink tokens + the 1020 most recent tokens. This approach never fails despite the constant VRAM usage.
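
For orientation, a minimal sketch of the kind of token-by-token measurement behind plots like these: feed a long text one token at a time, reuse the KV cache, and record each new token's negative log-likelihood. This is an illustration only, not the repository's benchmark script (which is not shown in this diff); the function name and loop details are assumptions.

```python
import torch

@torch.no_grad()
def streaming_nlls(model, input_ids):
    """Negative log-likelihood of each next token, predicted one step at a time."""
    past_key_values = None
    nlls = []
    for i in range(input_ids.size(1) - 1):
        out = model(input_ids[:, i : i + 1], past_key_values=past_key_values, use_cache=True)
        past_key_values = out.past_key_values  # with `attention_sinks`, this cache stays bounded
        log_probs = torch.log_softmax(out.logits[:, -1, :], dim=-1)
        nlls.append(-log_probs[0, input_ids[0, i + 1]].item())
    return nlls  # perplexity over any span is exp(mean of its NLLs)
```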

### Fluency during endless generation
@@ -59,9 +61,8 @@ This repository is an open-source implementation of the [Efficient Streaming Lan

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
```
-* Support for Llama, Mistral, Falcon, MPT, GPTNeoX (Pythia) and GPT-J models.
-  * Note: All of these models must be loaded **without** `trust_remote_code=True`.
-* New parameters to `AutoModel....from_pretrained`:
+* Support for Llama, Mistral, Falcon, MPT, GPTNeoX (Pythia), GPT-J and Qwen models.
+* New parameters to `AutoModelForCausalLM.from_pretrained`:
* `attention_sink_size`, `int`, defaults to 4: The number of initial tokens to use as the attention sink. These tokens are always included in the Attention Sink KV Cache.
* `attention_sink_window_size`, `int`, defaults to 1020: The size of the sliding window, i.e. the number of "recent tokens" to include in the Attention Sink KV Cache. A larger window size costs more memory.
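
As a concrete illustration of these two parameters, a minimal sketch (the checkpoint name is only an example; the keyword arguments are the ones documented above, shown at their default values):

```python
from attention_sinks import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    device_map="auto",
    attention_sink_size=4,            # initial tokens that always stay in the cache
    attention_sink_window_size=1020,  # most recent tokens kept; total cache = 4 + 1020
)
```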

@@ -74,11 +75,11 @@ pip install attention_sinks
```

### Usage
-Loading any Llama, Mistral, Falcon, MPT, GPTNeoX (Pythia) or GPT-J model is as simple as loading it in `transformers`; the only change is that the model class must be imported from `attention_sinks` rather than `transformers`, e.g.:
+Loading any Llama, Mistral, Falcon, MPT, GPTNeoX (Pythia), GPT-J or Qwen model is as simple as loading it in `transformers`; the only change is that the model class must be imported from `attention_sinks` rather than `transformers`, e.g.:
```python
-from attention_sinks import AutoModel
+from attention_sinks import AutoModelForCausalLM

-model = AutoModel.from_pretrained("mosaicml/mpt-7b", device_map="auto")
+model = AutoModelForCausalLM.from_pretrained("mosaicml/mpt-7b", device_map="auto")
```

Generation can be done like you would expect from `transformers`, e.g. like so:
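
The example itself is collapsed out of this diff. A minimal sketch of what such a call might look like (the tokenizer setup, prompt, and generation arguments are assumptions, not the repository's exact example):

```python
from transformers import AutoTokenizer
from attention_sinks import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mosaicml/mpt-7b", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b")

input_ids = tokenizer("Attention sinks let windowed attention stay fluent because", return_tensors="pt").input_ids.to(model.device)
# Standard `transformers` generation; the attention sink KV cache is applied internally.
output_ids = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```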
5 changes: 5 additions & 0 deletions attention_sinks/inject_mixin.py
@@ -16,6 +16,7 @@
"gpt_neox": "GPTNeoXModel",
"gptj": "GPTJModel",
"mistral": "MistralModel",
"qwen": "QWenModel",
}
ATTENTION_NAME_MAPPING = {
"llama": "LlamaAttention",
@@ -24,6 +25,7 @@
"gpt_neox": "GPTNeoXAttention",
"gptj": "GPTJAttention",
"mistral": "MistralAttention",
"qwen": "QWenAttention",
}
KV_DIM_MAPPING = {
"llama": (2, 2),
@@ -32,6 +34,7 @@
"gpt_neox": (2, 2),
"gptj": (2, 2),
"mistral": (2, 2),
"qwen": (1, 1),
}
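
To make these `(key_seq_dim, value_seq_dim)` pairs concrete, here is a small sketch of how they could be used to window a cached key/value pair, keeping a few initial "sink" tokens plus the most recent tokens. The helper name and logic are illustrative assumptions, not the repository's actual KV cache implementation:

```python
import torch

def window_kv(key, value, k_seq_dim, v_seq_dim, sink_size=4, window_size=1020):
    # For Qwen the pair is (1, 1) because its cache is laid out as
    # (batch, seq_len, num_heads, head_dim); the other architectures above use (2, 2).
    seq_len = key.size(k_seq_dim)
    if seq_len <= sink_size + window_size:
        return key, value
    keep = torch.cat([
        torch.arange(sink_size, device=key.device),                        # attention sink tokens
        torch.arange(seq_len - window_size, seq_len, device=key.device),   # most recent tokens
    ])
    return key.index_select(k_seq_dim, keep), value.index_select(v_seq_dim, keep)
```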


@@ -84,6 +87,7 @@ def _inject_pos_shift_attention(cls, model: PreTrainedModel) -> Optional[int]:
            gptj_pos_shift_attention_forward,
            llama_pos_shift_attention_forward,
            mistral_pos_shift_attention_forward,
            qwen_pos_shift_attention_forward,
        )

        ATTENTION_FORWARD_MAPPING = {
@@ -93,6 +97,7 @@ def _inject_pos_shift_attention(cls, model: PreTrainedModel) -> Optional[int]:
"gpt_neox": gpt_neox_pos_shift_attention_forward,
"gptj": gptj_pos_shift_attention_forward,
"mistral": mistral_pos_shift_attention_forward,
"qwen": qwen_pos_shift_attention_forward,
}

# Not all models require updated attention forwards
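
For context, a rough sketch of how a mapping like `ATTENTION_FORWARD_MAPPING` can be consumed: find the attention modules named in `ATTENTION_NAME_MAPPING` and rebind their `forward`. This mirrors what `_inject_pos_shift_attention` is responsible for, but the helper below is an illustration, not the repository's actual code:

```python
import types

from attention_sinks.models import qwen_pos_shift_attention_forward

def inject_qwen_pos_shift(model) -> None:
    """Illustrative only: rebind the forward of every QWenAttention module in `model`."""
    for module in model.modules():
        # "QWenAttention" corresponds to ATTENTION_NAME_MAPPING["qwen"] above.
        if module.__class__.__name__ == "QWenAttention":
            module.forward = types.MethodType(qwen_pos_shift_attention_forward, module)
```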
1 change: 1 addition & 0 deletions attention_sinks/models/__init__.py
@@ -46,3 +46,4 @@
    MptModel,
    MptPreTrainedModel,
)
from .qwen import qwen_pos_shift_attention_forward
1 change: 1 addition & 0 deletions attention_sinks/models/qwen/__init__.py
@@ -0,0 +1 @@
from .pos_shift import qwen_pos_shift_attention_forward
108 changes: 108 additions & 0 deletions attention_sinks/models/qwen/pos_shift.py
@@ -0,0 +1,108 @@
from typing import List, Optional, Tuple

import torch

__all__ = ["qwen_pos_shift_attention_forward"]


def _rotate_half(x):
    from einops import rearrange

    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(t, freqs):
    cos, sin = freqs
    rot_dim = freqs[0].shape[-1]
    t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
    t_ = t_.float()
    t_pass_ = t_pass_.float()
    t_ = (t_ * cos) + (_rotate_half(t_) * sin)
    return torch.cat((t_, t_pass_), dim=-1).type_as(t)


def qwen_pos_shift_attention_forward(
    self,
    hidden_states: Optional[Tuple[torch.FloatTensor]],
    rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
    registered_causal_mask: Optional[torch.Tensor] = None,
    layer_past: Optional[Tuple[torch.Tensor]] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
):
    mixed_x_layer = self.c_attn(hidden_states)

    query, key, value = mixed_x_layer.split(self.split_size, dim=2)

    # Shapes after splitting: (batch, seq_len, num_heads, head_dim)
    query = self._split_heads(query, self.num_heads, self.head_dim)
    key = self._split_heads(key, self.num_heads, self.head_dim)
    value = self._split_heads(value, self.num_heads, self.head_dim)

    if rotary_pos_emb_list is not None:
        cur_len = query.shape[1]
        if len(rotary_pos_emb_list) == 1:
            rotary_pos_emb = rotary_pos_emb_list[0]
            rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
            rotary_pos_emb = (rotary_pos_emb,) * 2
            q_pos_emb, k_pos_emb = rotary_pos_emb
            # Slice the pos emb for current inference
            query = apply_rotary_pos_emb(query, q_pos_emb)
            # Unlike the stock Qwen attention, the key is *not* rotated here; keys are cached
            # without rotary embeddings and rotated below using cache-relative positions.
            # key = apply_rotary_pos_emb(key, k_pos_emb)
        else:
            # TODO: modify batch infer
            query_list = []
            key_list = []
            for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
                rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
                rotary_pos_emb = (rotary_pos_emb,) * 2
                q_pos_emb, k_pos_emb = rotary_pos_emb
                # Slice the pos emb for current inference
                query_list += [apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb)]
                key_list += [apply_rotary_pos_emb(key[i : i + 1, :, :], k_pos_emb)]
            query = torch.cat(query_list, dim=0)
            key = torch.cat(key_list, dim=0)

    if layer_past is not None:
        # Qwen caches keys/values with the sequence length on dim 1, cf. KV_DIM_MAPPING["qwen"] == (1, 1)
        past_key, past_value = layer_past[0], layer_past[1]
        key = torch.cat((past_key, key), dim=1)
        value = torch.cat((past_value, value), dim=1)

    if use_cache:
        present = (key, value)
    else:
        present = None

    ### Shift pos ###
    # Rotate every cached key by its position *within the cache* (0 .. kv_seq_len - 1)
    # rather than its absolute position in the text.
    kv_seq_len = key.size(1)
    key_shifted_position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=key.device)
    key_rotary_pos_emb = [it[:, key_shifted_position_ids, :, :] for it in rotary_pos_emb_list[0]]
    key = apply_rotary_pos_emb(key, key_rotary_pos_emb)
    #######

    if self.use_logn_attn and not self.training:
        seq_start = key.size(1) - query.size(1)
        seq_end = key.size(1)
        logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
        query = query * logn_tensor.expand_as(query)

    # (batch, seq_len, num_heads, head_dim) -> (batch, num_heads, seq_len, head_dim)
    query = query.permute(0, 2, 1, 3)
    key = key.permute(0, 2, 1, 3)
    value = value.permute(0, 2, 1, 3)
    attn_output, attn_weight = self._attn(
        query, key, value, registered_causal_mask, attention_mask=None, head_mask=head_mask
    )
    context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)

    attn_output = self.c_proj(context_layer)

    outputs = (attn_output, present)
    if output_attentions:
        outputs += (attn_weight,)

    return outputs