Add GPT-NeoX/Pythia support + benchmark results #4

Merged · 2 commits · Oct 3, 2023
32 changes: 25 additions & 7 deletions README.md
@@ -1,20 +1,25 @@

# Attention Sinks in Transformers for Infinite-length LLMs

| Llama 2 7B | Falcon 7B | MPT 7B |
| ------------- | ------------- | ---- |
| ![llama_2_7b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/8d2e5b88-7158-41ac-8b3a-5a7abe38020d) | ![falcon_7b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/1be07370-6de7-4a7e-b5ab-3092a5ecb412) | ![mpt_7b_ppl_vram_plotted](https://github.com/mit-han-lab/streaming-llm/assets/37621491/c96cff66-92a3-43ab-bc21-40232f2740a0) |
| Llama 2 7B | Falcon 7B |
|:-------------:|:-------------:|
| ![llama_2_7b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/8d2e5b88-7158-41ac-8b3a-5a7abe38020d) | ![falcon_7b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/1be07370-6de7-4a7e-b5ab-3092a5ecb412) |
| **MPT 7B** | **Pythia 6.9B** |
| ![mpt_7b_ppl_vram_plotted](https://github.com/mit-han-lab/streaming-llm/assets/37621491/c96cff66-92a3-43ab-bc21-40232f2740a0) | ![pythia_6 8b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/b0fee168-fa5a-457d-9e27-8395eb6dfb38) |


## Overview

* Extend existing LLMs (e.g. Llama 2) to infinite sequence lengths without sacrificing efficiency or performance, and without any retraining.
* Model perplexities were stable even after 4 million tokens!
* Unlike regular `transformers`, memory usage does not grow linearly with sequence length, so inference does not become extremely slow due to memory pressure at longer sequence lengths.
* The `attention_sinks` API allows for a drop-in replacement of the `transformers` API:
```python
from attention_sinks import AutoModel

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
```
* Support for Llama and Falcon models.
* Support for Llama, Falcon, MPT and GPTNeoX (Pythia) models.
* New parameters to `AutoModel....from_pretrained`:
* `attention_sink_size`, int, defaults to 4: The number of initial tokens to use as the attention sink. These tokens are always included in the Attention Sink KV Cache.
* `attention_sink_window_size`, int, defaults to 1020: The size of the sliding window, i.e. the number of "recent tokens" to include in the Attention Sink KV Cache.
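For example, a minimal sketch of passing these parameters with their documented defaults (mirroring the `AutoModel` example above):
```python
from attention_sinks import AutoModel

# Documented defaults: 4 sink tokens + a 1020-token sliding window,
# so the Attention Sink KV Cache holds at most 1024 tokens.
model = AutoModel.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    device_map="auto",
    attention_sink_size=4,
    attention_sink_window_size=1020,
)
```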
@@ -26,7 +31,14 @@ pip install attention_sinks
```

## Benchmarks
You can run a few benchmarks to compute the perplexity of various models over time using the provided [perplexity.py](benchmark/perplexity.py) benchmarking script. For example:

### Pre-prepared benchmarks
See [benchmark/scripts](benchmark/scripts) for a collection of ready-to-go scripts for various model architectures like Llama 2, Falcon, MPT and GPT-NeoX (Pythia). Each of these scripts runs the benchmarking and plotting tools described below for pure [`transformers`](https://github.com/huggingface/transformers), [`attention_sinks`](https://github.com/tomaarsen/attention_sinks) and a third alternative: `windowed`, which involves simple windowed attention at a window size of 1024 tokens. Upon completion, the script will plot the figures that you see at the top of this README.

### Benchmarking tool
You can run a few benchmarks to compute the perplexity of various models over time using the provided [perplexity.py](benchmark/perplexity.py) benchmarking script. It computes the negative log-likelihood loss of the chosen model on a full book of 60k+ tokens. By default, the script stops after 8192 tokens, but this can be modified. An ideal solution maintains a low log perplexity and constant CUDA VRAM usage throughout.
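The gist of that computation, as a rough sketch (this is not the repository's `perplexity.py`, just the idea): feed the text one token at a time, reuse `past_key_values`, and record the negative log-likelihood of each next token.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def streaming_nll(model, input_ids):
    """Per-token negative log-likelihood while streaming one token at a time."""
    past_key_values = None
    nlls = []
    for idx in range(input_ids.size(1) - 1):
        outputs = model(
            input_ids[:, idx : idx + 1],
            past_key_values=past_key_values,
            use_cache=True,
        )
        # With attention_sinks, this cache is trimmed to the sink tokens + recent window.
        past_key_values = outputs.past_key_values
        logits = outputs.logits[:, -1, :]  # prediction for the next token
        target = input_ids[:, idx + 1]     # the actual next token
        nlls.append(F.cross_entropy(logits, target))
    # The mean of the first t entries is the log perplexity after t tokens.
    return torch.stack(nlls)
```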

To use the script, you can run:
```
python benchmark/perplexity.py --experiment attention_sinks
```
@@ -57,7 +69,13 @@ options:

This script will create a `csv` file in the output directory (`"benchmarks/outputs"` by default) for that experiment, with information about perplexities, CUDA VRAM usage and latencies.

This information can be plotted using the [plot_perplexity.py](benchmark\plot_perplexity.py) script. For example:
### Plotting tool
The information from the benchmarking tool can be plotted using the [plot_perplexity.py](benchmark/plot_perplexity.py) script. In particular, you can plot any combination of the following features:
* `perplexity`,
* `vram`, i.e. CUDA VRAM usage,
* `latency`.

For example:
```
python benchmark/plot_perplexity.py --features perplexity latency --title "Log perplexity & latency of Llama 2 7B as a function of input lengths"
```
@@ -90,7 +108,7 @@ Clear as day:
2. `windowed`: VRAM usage is constant due to the windowing at 1024 tokens. However, it fails as soon as the first tokens leave the window.
3. `attention_sinks`: Constant VRAM usage due to windowing with 4 attention sink tokens + the 1020 most recent tokens. This approach never fails despite the constant VRAM usage.

I've uploaded [benchmark/outputs_llama_2_7b](benchmark/outputs_llama_2_7b) so you can reproduce this graph using the former command.
I've uploaded outputs of various benchmarks in [benchmark](benchmark) so you can reproduce this graph using the command above.

## Changelog

6 changes: 6 additions & 0 deletions attention_sinks/__init__.py
@@ -15,6 +15,12 @@
FalconForTokenClassification,
FalconModel,
FalconPreTrainedModel,
GPTNeoXForCausalLM,
GPTNeoXForQuestionAnswering,
GPTNeoXForSequenceClassification,
GPTNeoXForTokenClassification,
GPTNeoXModel,
GPTNeoXPreTrainedModel,
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaModel,
8 changes: 8 additions & 0 deletions attention_sinks/models/__init__.py
@@ -13,6 +13,14 @@
FalconModel,
FalconPreTrainedModel,
)
from .gpt_neox import (
GPTNeoXForCausalLM,
GPTNeoXForQuestionAnswering,
GPTNeoXForSequenceClassification,
GPTNeoXForTokenClassification,
GPTNeoXModel,
GPTNeoXPreTrainedModel,
)
from .llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
from .mpt import (
MptForCausalLM,
20 changes: 18 additions & 2 deletions attention_sinks/models/auto/modeling_auto.py
@@ -13,7 +13,7 @@
from transformers import (
AutoModelForTokenClassification as TAutoModelForTokenClassification,
)
from transformers import FalconConfig, LlamaConfig, MptConfig
from transformers import FalconConfig, GPTNeoXConfig, LlamaConfig, MptConfig
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
@@ -29,6 +29,13 @@
FalconForTokenClassification,
FalconModel,
)
from ..gpt_neox import (
GPTNeoXForCausalLM,
GPTNeoXForQuestionAnswering,
GPTNeoXForSequenceClassification,
GPTNeoXForTokenClassification,
GPTNeoXModel,
)
from ..llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
from ..mpt import (
MptForCausalLM,
@@ -38,24 +45,33 @@
MptModel,
)

MODEL_MAPPING._extra_content = {LlamaConfig: LlamaModel, FalconConfig: FalconModel, MptConfig: MptModel}
MODEL_MAPPING._extra_content = {
LlamaConfig: LlamaModel,
FalconConfig: FalconModel,
MptConfig: MptModel,
GPTNeoXConfig: GPTNeoXModel,
}
MODEL_FOR_CAUSAL_LM_MAPPING._extra_content = {
LlamaConfig: LlamaForCausalLM,
FalconConfig: FalconForCausalLM,
MptConfig: MptForCausalLM,
GPTNeoXConfig: GPTNeoXForCausalLM,
}
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING._extra_content = {
LlamaConfig: LlamaForSequenceClassification,
FalconConfig: FalconForSequenceClassification,
MptConfig: MptForSequenceClassification,
GPTNeoXConfig: GPTNeoXForSequenceClassification,
}
MODEL_FOR_QUESTION_ANSWERING_MAPPING._extra_content = {
FalconConfig: FalconForQuestionAnswering,
MptConfig: MptForQuestionAnswering,
GPTNeoXConfig: GPTNeoXForQuestionAnswering,
}
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING._extra_content = {
FalconConfig: FalconForTokenClassification,
MptConfig: MptForTokenClassification,
GPTNeoXConfig: GPTNeoXForTokenClassification,
}


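As a sketch of what the registrations above buy: with `GPTNeoXConfig` added to the `_extra_content` mappings, the `Auto` classes resolve GPT-NeoX/Pythia checkpoints to the attention-sink variants. The small checkpoint below is purely illustrative:

```python
from attention_sinks import AutoModel

# GPTNeoXConfig -> GPTNeoXModel is registered in MODEL_MAPPING._extra_content above,
# so AutoModel should dispatch Pythia checkpoints to the attention-sink GPTNeoXModel.
model = AutoModel.from_pretrained("EleutherAI/pythia-160m")
print(type(model))  # expected: attention_sinks.models.gpt_neox.modeling_gpt_neox.GPTNeoXModel
```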
1 change: 0 additions & 1 deletion attention_sinks/models/falcon/pos_shift.py
@@ -1,4 +1,3 @@
import math
import types
from typing import Optional, Tuple

8 changes: 8 additions & 0 deletions attention_sinks/models/gpt_neox/__init__.py
@@ -0,0 +1,8 @@
from .modeling_gpt_neox import (
GPTNeoXForCausalLM,
GPTNeoXForQuestionAnswering,
GPTNeoXForSequenceClassification,
GPTNeoXForTokenClassification,
GPTNeoXModel,
GPTNeoXPreTrainedModel,
)
137 changes: 137 additions & 0 deletions attention_sinks/models/gpt_neox/modeling_gpt_neox.py
@@ -0,0 +1,137 @@
import os
import types
from typing import List, Optional, Tuple, Union

from transformers import (
GPTNeoXForCausalLM as TGPTNeoXForCausalLM,
)
from transformers import (
GPTNeoXForQuestionAnswering as TGPTNeoXForQuestionAnswering,
)
from transformers import (
GPTNeoXForSequenceClassification as TGPTNeoXForSequenceClassification,
)
from transformers import (
GPTNeoXForTokenClassification as TGPTNeoXForTokenClassification,
)
from transformers import (
GPTNeoXModel as TGPTNeoXModel,
)
from transformers import (
GPTNeoXPreTrainedModel as TGPTNeoXPreTrainedModel,
)
from transformers import (
PretrainedConfig,
)
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.gpt_neox.modeling_gpt_neox import (
_CHECKPOINT_FOR_DOC,
_CONFIG_FOR_DOC,
_REAL_CHECKPOINT_FOR_DOC,
GPT_NEOX_INPUTS_DOCSTRING,
)
from transformers.utils import (
add_code_sample_docstrings,
add_start_docstrings_to_model_forward,
)

from attention_sinks.attention_sink_kv_cache import AttentionSinkKVCache
from attention_sinks.models.gpt_neox.pos_shift import (
enable_gpt_neox_pos_shift_attention,
)


class GPTNeoXPreTrainedModel(TGPTNeoXPreTrainedModel):
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
*model_args,
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
ignore_mismatched_sizes: bool = False,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
use_safetensors: bool = None,
**kwargs,
) -> None:
# Separate Attention Sink kwargs from regular kwargs
attention_sink_kwargs = {key: value for key, value in kwargs.items() if key.startswith("attention_sink")}
for key in attention_sink_kwargs:
kwargs.pop(key)

model = super(GPTNeoXPreTrainedModel, cls).from_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
cache_dir=cache_dir,
ignore_mismatched_sizes=ignore_mismatched_sizes,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
use_safetensors=use_safetensors,
**kwargs,
)
# Enable position shifting attention for GPTNeoX
enable_gpt_neox_pos_shift_attention(model)

# Hackishly attach the Attention Sink KV Cache to the model
cls._attach_attention_sink_kv_cache(model, **attention_sink_kwargs)

return model

@classmethod
def _attach_attention_sink_kv_cache(cls, module, **attention_sink_kwargs):
if isinstance(module, TGPTNeoXModel):
# Create the new cache
module.attention_sink_kv_cache = AttentionSinkKVCache(
**attention_sink_kwargs,
k_seq_dim=2,
v_seq_dim=2,
)

# Keep track of the old forward method, we need it in the wrapped one
old_forward = module.forward

# Wrap the forward by overriding the past_key_values using the cache
@add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def wrapped_forward(self, *args, **kwargs) -> Union[Tuple, BaseModelOutputWithPast]:
outputs = old_forward(*args, **kwargs)
outputs.past_key_values = self.attention_sink_kv_cache(outputs.past_key_values)
return outputs

module.forward = types.MethodType(wrapped_forward, module)

# Recursively call this to find all GPTNeoXModels
for module in reversed(module._modules.values()):
if len(list(module.children())) > 0:
cls._attach_attention_sink_kv_cache(module, **attention_sink_kwargs)


class GPTNeoXModel(GPTNeoXPreTrainedModel, TGPTNeoXModel):
pass


class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, TGPTNeoXForCausalLM):
pass


class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel, TGPTNeoXForSequenceClassification):
pass


class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel, TGPTNeoXForTokenClassification):
pass


class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel, TGPTNeoXForQuestionAnswering):
pass
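A quick usage sketch of the new GPT-NeoX support (`GPTNeoXForCausalLM` is exported from `attention_sinks` per the `__init__.py` change above; the Pythia checkpoint name is an assumption based on the 6.9B model benchmarked in the README):

```python
from transformers import AutoTokenizer

from attention_sinks import GPTNeoXForCausalLM

# Assumed checkpoint: the Pythia 6.9B model benchmarked in the README.
model_name = "EleutherAI/pythia-6.9b"
model = GPTNeoXForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# from_pretrained above enables position-shifted attention and attaches the
# Attention Sink KV Cache, so the cache stays bounded past the window size.
inputs = tokenizer("Attention sinks keep the first few tokens in the KV cache,", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```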