Merge pull request #4 from tomaarsen/model/pythia
Add GPT-NeoX/Pythia support + benchmark results
tomaarsen authored Oct 3, 2023
2 parents 1e8a1b0 + c081faf commit 4be3831
Showing 12 changed files with 20,703 additions and 10 deletions.
32 changes: 25 additions & 7 deletions README.md
@@ -1,20 +1,25 @@

# Attention Sinks in Transformers for Infinite-length LLMs

| Llama 2 7B | Falcon 7B | MPT 7B |
| ------------- | ------------- | ---- |
| ![llama_2_7b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/8d2e5b88-7158-41ac-8b3a-5a7abe38020d) | ![falcon_7b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/1be07370-6de7-4a7e-b5ab-3092a5ecb412) | ![mpt_7b_ppl_vram_plotted](https://github.com/mit-han-lab/streaming-llm/assets/37621491/c96cff66-92a3-43ab-bc21-40232f2740a0) |
| Llama 2 7B | Falcon 7B |
|:-------------:|:-------------:|
| ![llama_2_7b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/8d2e5b88-7158-41ac-8b3a-5a7abe38020d) | ![falcon_7b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/1be07370-6de7-4a7e-b5ab-3092a5ecb412) |
| **MPT 7B** | **Pythia 6.9B** |
| ![mpt_7b_ppl_vram_plotted](https://github.com/mit-han-lab/streaming-llm/assets/37621491/c96cff66-92a3-43ab-bc21-40232f2740a0) | ![pythia_6 8b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/b0fee168-fa5a-457d-9e27-8395eb6dfb38) |


## Overview

* Extend existing LLMs (e.g. Llama 2) to infinite length without sacrificing efficiency and performance, without any retraining.
* Model perplexities were stable even after 4 million tokens!
* Unlike with regular `transformers`, there is no linear memory increase and no extremely slow inference due to memory issues at higher sequence lengths.
* The `attention_sinks` API allows for a drop-in replacement of the `transformers` API:
```python
from attention_sinks import AutoModel

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
```
* Support for Llama and Falcon models.
* Support for Llama, Falcon, MPT and GPTNeoX (Pythia) models.
* New parameters to `AutoModel....from_pretrained` (see the example after this list):
* `attention_sink_size`, int, defaults to 4: The number of initial tokens to use as the attention sink. These tokens are always included in the Attention Sink KV Cache.
* `attention_sink_window_size`, int, defaults to 1020: The size of the sliding window, i.e. the number of "recent tokens" to include in the Attention Sink KV Cache.
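
For illustration, a minimal sketch of passing these parameters (assuming `AutoModelForCausalLM` is exposed by `attention_sinks` just like `AutoModel` above; the values shown are simply the documented defaults):
```python
from attention_sinks import AutoModelForCausalLM

# Attention Sink KV Cache with 4 sink tokens + a 1020-token sliding window (the defaults)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    device_map="auto",
    attention_sink_size=4,            # initial tokens always kept as the attention sink
    attention_sink_window_size=1020,  # number of most recent tokens kept in the cache
)
```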
@@ -26,7 +31,14 @@ pip install attention_sinks
```

## Benchmarks
You can run a few benchmarks to compute the perplexity of various models over time using the provided [perplexity.py](benchmark/perplexity.py) benchmarking script. For example:

### Pre-prepared benchmarks
See [benchmark/scripts](benchmark/scripts) for a collection of ready-to-go scripts for various model architectures such as Llama 2, Falcon, MPT and GPT-NeoX (Pythia). Each of these scripts runs the benchmarking and plotting tools described below for pure [`transformers`](https://github.com/huggingface/transformers), [`attention_sinks`](https://github.com/tomaarsen/attention_sinks) and a third alternative, `windowed`, which uses plain windowed attention with a window size of 1024 tokens. Upon completion, each script plots the figures that you see at the top of this README.

### Benchmarking tool
You can run a few benchmarks to compute the perplexity of various models over time using the provided [perplexity.py](benchmark/perplexity.py) benchmarking script. This is done by computing the negative log-likelihood loss of the chosen model when it is fed a full book of 60k+ tokens. By default, the script stops after 8192 tokens, but this can be modified. An ideal solution maintains a low log perplexity and constant CUDA VRAM usage throughout (a rough sketch of this measurement follows the example command below).

To use the script, you can run:
```
python benchmark/perplexity.py --experiment attention_sinks
```
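
For reference, here is a minimal sketch of what such a streaming perplexity measurement looks like. This is not the `perplexity.py` implementation itself; the checkpoint and the `book.txt` input are placeholders. Tokens are fed one at a time while carrying the KV cache, the negative log-likelihood of each next token is accumulated, and the perplexity is `exp(mean NLL)`:
```python
import torch
from torch.nn.functional import log_softmax
from transformers import AutoTokenizer

from attention_sinks import AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

text = open("book.txt").read()  # placeholder for a 60k+ token book
input_ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)

past_key_values = None
nlls = []
with torch.no_grad():
    for idx in range(input_ids.size(1) - 1):
        # Feed a single token, reusing the (attention-sink-truncated) KV cache
        outputs = model(input_ids[:, idx : idx + 1], past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values
        # Negative log-likelihood of the true next token under the model's prediction
        log_probs = log_softmax(outputs.logits[0, -1], dim=-1)
        nlls.append(-log_probs[input_ids[0, idx + 1]])

print("Perplexity:", torch.exp(torch.stack(nlls).mean()).item())
```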
@@ -57,7 +69,13 @@ options:

This script will create a CSV file in the output directory (`"benchmarks/outputs"` by default) for that experiment, containing the measured perplexities, CUDA VRAM usage and latencies.

This information can be plotted using the [plot_perplexity.py](benchmark\plot_perplexity.py) script. For example:
### Plotting tool
The information from the benchmarking tool can be plotted using the [plot_perplexity.py](benchmark/plot_perplexity.py) script. In particular, you can plot any combination of the following features:
* `perplexity`,
* `vram`, i.e. CUDA VRAM usage,
* `latency`.

For example:
```
python benchmark/plot_perplexity.py --features perplexity latency --title "Log perplexity & latency of Llama 2 7B as a function of input lengths"
```
@@ -90,7 +108,7 @@ Clear as day:
2. `windowed`: VRAM usage is constant due to the windowing at 1024 tokens. However, it fails as soon as the first tokens leave the window.
3. `attention_sinks`: Constant VRAM usage due to windowing with 4 attention sink tokens + the 1020 most recent tokens. This approach never fails despite the constant VRAM usage; a rough sketch of this eviction policy is shown below.
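
As a rough illustration of that third approach, here is a minimal sketch of such a cache-truncation step. It is not the library's `AttentionSinkKVCache` itself, and it assumes the per-layer key/value tensors carry the sequence length on dimension 2 (matching the `k_seq_dim=2, v_seq_dim=2` arguments visible later in this diff):
```python
import torch

def truncate_kv_cache(past_key_values, attention_sink_size: int = 4, attention_sink_window_size: int = 1020):
    """Keep the first `attention_sink_size` tokens plus the most recent `attention_sink_window_size` tokens."""
    truncated = []
    for key, value in past_key_values:
        seq_len = key.size(2)
        if seq_len <= attention_sink_size + attention_sink_window_size:
            truncated.append((key, value))  # nothing to evict yet
            continue
        truncated.append((
            torch.cat([key[:, :, :attention_sink_size], key[:, :, -attention_sink_window_size:]], dim=2),
            torch.cat([value[:, :, :attention_sink_size], value[:, :, -attention_sink_window_size:]], dim=2),
        ))
    return tuple(truncated)
```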

I've uploaded [benchmark/outputs_llama_2_7b](benchmark/outputs_llama_2_7b) so you can reproduce this graph using the former command.
I've uploaded the outputs of various benchmarks to [benchmark](benchmark) so you can reproduce this graph using the command above.

## Changelog

6 changes: 6 additions & 0 deletions attention_sinks/__init__.py
@@ -15,6 +15,12 @@
FalconForTokenClassification,
FalconModel,
FalconPreTrainedModel,
GPTNeoXForCausalLM,
GPTNeoXForQuestionAnswering,
GPTNeoXForSequenceClassification,
GPTNeoXForTokenClassification,
GPTNeoXModel,
GPTNeoXPreTrainedModel,
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaModel,
8 changes: 8 additions & 0 deletions attention_sinks/models/__init__.py
@@ -13,6 +13,14 @@
FalconModel,
FalconPreTrainedModel,
)
from .gpt_neox import (
GPTNeoXForCausalLM,
GPTNeoXForQuestionAnswering,
GPTNeoXForSequenceClassification,
GPTNeoXForTokenClassification,
GPTNeoXModel,
GPTNeoXPreTrainedModel,
)
from .llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
from .mpt import (
MptForCausalLM,
20 changes: 18 additions & 2 deletions attention_sinks/models/auto/modeling_auto.py
@@ -13,7 +13,7 @@
from transformers import (
AutoModelForTokenClassification as TAutoModelForTokenClassification,
)
from transformers import FalconConfig, LlamaConfig, MptConfig
from transformers import FalconConfig, GPTNeoXConfig, LlamaConfig, MptConfig
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
@@ -29,6 +29,13 @@
FalconForTokenClassification,
FalconModel,
)
from ..gpt_neox import (
GPTNeoXForCausalLM,
GPTNeoXForQuestionAnswering,
GPTNeoXForSequenceClassification,
GPTNeoXForTokenClassification,
GPTNeoXModel,
)
from ..llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
from ..mpt import (
MptForCausalLM,
@@ -38,24 +45,33 @@
MptModel,
)

MODEL_MAPPING._extra_content = {LlamaConfig: LlamaModel, FalconConfig: FalconModel, MptConfig: MptModel}
MODEL_MAPPING._extra_content = {
LlamaConfig: LlamaModel,
FalconConfig: FalconModel,
MptConfig: MptModel,
GPTNeoXConfig: GPTNeoXModel,
}
MODEL_FOR_CAUSAL_LM_MAPPING._extra_content = {
LlamaConfig: LlamaForCausalLM,
FalconConfig: FalconForCausalLM,
MptConfig: MptForCausalLM,
GPTNeoXConfig: GPTNeoXForCausalLM,
}
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING._extra_content = {
LlamaConfig: LlamaForSequenceClassification,
FalconConfig: FalconForSequenceClassification,
MptConfig: MptForSequenceClassification,
GPTNeoXConfig: GPTNeoXForSequenceClassification,
}
MODEL_FOR_QUESTION_ANSWERING_MAPPING._extra_content = {
FalconConfig: FalconForQuestionAnswering,
MptConfig: MptForQuestionAnswering,
GPTNeoXConfig: GPTNeoXForQuestionAnswering,
}
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING._extra_content = {
FalconConfig: FalconForTokenClassification,
MptConfig: MptForTokenClassification,
GPTNeoXConfig: GPTNeoXForTokenClassification,
}


1 change: 0 additions & 1 deletion attention_sinks/models/falcon/pos_shift.py
@@ -1,4 +1,3 @@
import math
import types
from typing import Optional, Tuple

8 changes: 8 additions & 0 deletions attention_sinks/models/gpt_neox/__init__.py
@@ -0,0 +1,8 @@
from .modeling_gpt_neox import (
GPTNeoXForCausalLM,
GPTNeoXForQuestionAnswering,
GPTNeoXForSequenceClassification,
GPTNeoXForTokenClassification,
GPTNeoXModel,
GPTNeoXPreTrainedModel,
)
137 changes: 137 additions & 0 deletions attention_sinks/models/gpt_neox/modeling_gpt_neox.py
@@ -0,0 +1,137 @@
import os
import types
from typing import List, Optional, Tuple, Union

from transformers import (
GPTNeoXForCausalLM as TGPTNeoXForCausalLM,
)
from transformers import (
GPTNeoXForQuestionAnswering as TGPTNeoXForQuestionAnswering,
)
from transformers import (
GPTNeoXForSequenceClassification as TGPTNeoXForSequenceClassification,
)
from transformers import (
GPTNeoXForTokenClassification as TGPTNeoXForTokenClassification,
)
from transformers import (
GPTNeoXModel as TGPTNeoXModel,
)
from transformers import (
GPTNeoXPreTrainedModel as TGPTNeoXPreTrainedModel,
)
from transformers import (
PretrainedConfig,
)
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.gpt_neox.modeling_gpt_neox import (
_CHECKPOINT_FOR_DOC,
_CONFIG_FOR_DOC,
_REAL_CHECKPOINT_FOR_DOC,
GPT_NEOX_INPUTS_DOCSTRING,
)
from transformers.utils import (
add_code_sample_docstrings,
add_start_docstrings_to_model_forward,
)

from attention_sinks.attention_sink_kv_cache import AttentionSinkKVCache
from attention_sinks.models.gpt_neox.pos_shift import (
enable_gpt_neox_pos_shift_attention,
)


class GPTNeoXPreTrainedModel(TGPTNeoXPreTrainedModel):
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
*model_args,
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
ignore_mismatched_sizes: bool = False,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
use_safetensors: Optional[bool] = None,
**kwargs,
) -> "GPTNeoXPreTrainedModel":
# Separate Attention Sink kwargs from regular kwargs
attention_sink_kwargs = {key: value for key, value in kwargs.items() if key.startswith("attention_sink")}
for key in attention_sink_kwargs:
kwargs.pop(key)

model = super(GPTNeoXPreTrainedModel, cls).from_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
cache_dir=cache_dir,
ignore_mismatched_sizes=ignore_mismatched_sizes,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
use_safetensors=use_safetensors,
**kwargs,
)
# Enable position shifting attention for GPTNeoX
enable_gpt_neox_pos_shift_attention(model)

# Hackishly attach the Attention Sink KV Cache to the model
cls._attach_attention_sink_kv_cache(model, **attention_sink_kwargs)

return model

@classmethod
def _attach_attention_sink_kv_cache(cls, module, **attention_sink_kwargs):
if isinstance(module, TGPTNeoXModel):
# Create the new cache
module.attention_sink_kv_cache = AttentionSinkKVCache(
**attention_sink_kwargs,
k_seq_dim=2,
v_seq_dim=2,
)

# Keep track of the old forward method, we need it in the wrapped one
old_forward = module.forward

# Wrap the forward by overriding the past_key_values using the cache
@add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def wrapped_forward(self, *args, **kwargs) -> Union[Tuple, BaseModelOutputWithPast]:
outputs = old_forward(*args, **kwargs)
outputs.past_key_values = self.attention_sink_kv_cache(outputs.past_key_values)
return outputs

module.forward = types.MethodType(wrapped_forward, module)

# Recursively call this to find all GPTNeoXModels
for module in reversed(module._modules.values()):
if len(list(module.children())) > 0:
cls._attach_attention_sink_kv_cache(module, **attention_sink_kwargs)


class GPTNeoXModel(GPTNeoXPreTrainedModel, TGPTNeoXModel):
pass


class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, TGPTNeoXForCausalLM):
pass


class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel, TGPTNeoXForSequenceClassification):
pass


class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel, TGPTNeoXForTokenClassification):
pass


class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel, TGPTNeoXForQuestionAnswering):
pass
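
For context, the classes above make GPT-NeoX checkpoints loadable through the same drop-in API as the other architectures. A minimal usage sketch (the Pythia checkpoint name is an assumption based on the 6.9B model benchmarked in the README; any GPT-NeoX model should resolve to the new `GPTNeoXForCausalLM` via its `GPTNeoXConfig`):
```python
from attention_sinks import AutoModelForCausalLM

# GPTNeoXConfig is mapped to the attention-sink-enabled GPTNeoXForCausalLM above
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-6.9b",  # assumed checkpoint id
    device_map="auto",
    attention_sink_size=4,
    attention_sink_window_size=1020,
)
```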