Add Falcon support + benchmark results #2

Merged: 2 commits, Oct 3, 2023
5 changes: 4 additions & 1 deletion .gitignore
@@ -192,4 +192,7 @@ notes.txt
demo.py

# Ignore arrow cache files
cache-*.arrow
cache-*.arrow

# Ignore my personal benchmarking shell scripts
benchmark_*.sh
6 changes: 4 additions & 2 deletions README.md
@@ -1,7 +1,9 @@

# Attention Sinks in Transformers for Infinite-length LLMs

![llama_2_7b_ppl_vram](https://github.com/tomaarsen/attention_sinks/assets/37621491/1b99f29e-8d8d-4677-bef6-6a6e041776f6)
| Llama 2 7B | Falcon-7B |
| ------------- | ------------- |
| ![llama_2_7b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/8d2e5b88-7158-41ac-8b3a-5a7abe38020d) | ![falcon_7b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/1be07370-6de7-4a7e-b5ab-3092a5ecb412) |

## Overview

@@ -12,7 +14,7 @@

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
```
Currently, only Llama-based models are supported. Support for other models will come soon.
* Support for Llama and Falcon models.
* New parameters to `AutoModel....from_pretrained`:
* `attention_sink_size`, int, defaults to 4: The number of initial tokens to use as the attention sink. These tokens are always included in the Attention Sink KV Cache.
* `attention_sink_window_size`, int, defaults to 1020: The size of the sliding window, i.e. the number of "recent tokens" to include in the Attention Sink KV Cache.
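For illustration, a minimal sketch (not part of the README diff) of passing these two parameters; the checkpoint name is taken from the README example above and the values simply restate the documented defaults:

```python
from attention_sinks import AutoModelForCausalLM

# Both kwargs are optional; the values below restate the defaults (4 and 1020).
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    device_map="auto",
    attention_sink_size=4,            # initial tokens always kept in the cache
    attention_sink_window_size=1020,  # sliding window of recent tokens
)
```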
11 changes: 10 additions & 1 deletion attention_sinks/__init__.py
@@ -1,11 +1,20 @@
__version__ = "0.0.1"
__version__ = "0.1.0.dev"

from transformers import AutoTokenizer

from .attention_sink_kv_cache import AttentionSinkKVCache
from .models import (
AutoModel,
AutoModelForCausalLM,
AutoModelForQuestionAnswering,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
FalconForCausalLM,
FalconForQuestionAnswering,
FalconForSequenceClassification,
FalconForTokenClassification,
FalconModel,
FalconPreTrainedModel,
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaModel,
16 changes: 15 additions & 1 deletion attention_sinks/models/__init__.py
@@ -1,2 +1,16 @@
from .auto import AutoModel, AutoModelForCausalLM
from .auto import (
AutoModel,
AutoModelForCausalLM,
AutoModelForQuestionAnswering,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
)
from .falcon import (
FalconForCausalLM,
FalconForQuestionAnswering,
FalconForSequenceClassification,
FalconForTokenClassification,
FalconModel,
FalconPreTrainedModel,
)
from .llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
8 changes: 7 additions & 1 deletion attention_sinks/models/auto/__init__.py
@@ -1 +1,7 @@
from .modeling_auto import AutoModel, AutoModelForCausalLM
from .modeling_auto import (
AutoModel,
AutoModelForCausalLM,
AutoModelForQuestionAnswering,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
)
36 changes: 32 additions & 4 deletions attention_sinks/models/auto/modeling_auto.py
@@ -4,21 +4,41 @@
from transformers import (
AutoModelForCausalLM as TAutoModelForCausalLM,
)
from transformers import (
AutoModelForQuestionAnswering as TAutoModelForQuestionAnswering,
)
from transformers import (
AutoModelForSequenceClassification as TAutoModelForSequenceClassification,
)
from transformers import LlamaConfig
from transformers import (
AutoModelForTokenClassification as TAutoModelForTokenClassification,
)
from transformers import FalconConfig, LlamaConfig
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_MAPPING,
)

from ..falcon import (
FalconForCausalLM,
FalconForQuestionAnswering,
FalconForSequenceClassification,
FalconForTokenClassification,
FalconModel,
)
from ..llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel

MODEL_MAPPING._extra_content = {LlamaConfig: LlamaModel}
MODEL_FOR_CAUSAL_LM_MAPPING._extra_content = {LlamaConfig: LlamaForCausalLM}
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING._extra_content = {LlamaConfig: LlamaForSequenceClassification}
MODEL_MAPPING._extra_content = {LlamaConfig: LlamaModel, FalconConfig: FalconModel}
MODEL_FOR_CAUSAL_LM_MAPPING._extra_content = {LlamaConfig: LlamaForCausalLM, FalconConfig: FalconForCausalLM}
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING._extra_content = {
LlamaConfig: LlamaForSequenceClassification,
FalconConfig: FalconForSequenceClassification,
}
MODEL_FOR_QUESTION_ANSWERING_MAPPING._extra_content = {FalconConfig: FalconForQuestionAnswering}
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING._extra_content = {FalconConfig: FalconForTokenClassification}


class AutoModel(TAutoModel):
@@ -31,3 +51,11 @@ class AutoModelForCausalLM(TAutoModelForCausalLM):

class AutoModelForSequenceClassification(TAutoModelForSequenceClassification):
_model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


class AutoModelForQuestionAnswering(TAutoModelForQuestionAnswering):
_model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


class AutoModelForTokenClassification(TAutoModelForTokenClassification):
_model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
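As a hedged aside (not part of the diff): `_extra_content` is the override dictionary that transformers' lazy Auto mappings consult before their built-in registry, so after the assignments above a Falcon or Llama config should resolve to the patched attention_sinks classes. A quick check, assuming the import paths introduced in this PR:

```python
from transformers import FalconConfig, LlamaConfig

from attention_sinks import FalconForCausalLM, LlamaForCausalLM
from attention_sinks.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING

# The lazy mapping checks _extra_content first, so both config classes
# should now map to the attention_sinks subclasses rather than the stock ones.
assert MODEL_FOR_CAUSAL_LM_MAPPING[FalconConfig] is FalconForCausalLM
assert MODEL_FOR_CAUSAL_LM_MAPPING[LlamaConfig] is LlamaForCausalLM
```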
8 changes: 8 additions & 0 deletions attention_sinks/models/falcon/__init__.py
@@ -0,0 +1,8 @@
from .modeling_falcon import (
FalconForCausalLM,
FalconForQuestionAnswering,
FalconForSequenceClassification,
FalconForTokenClassification,
FalconModel,
FalconPreTrainedModel,
)
133 changes: 133 additions & 0 deletions attention_sinks/models/falcon/modeling_falcon.py
@@ -0,0 +1,133 @@
import os
import types
from typing import List, Optional, Tuple, Union

from transformers import (
FalconForCausalLM as TFalconForCausalLM,
)
from transformers import (
FalconForQuestionAnswering as TFalconForQuestionAnswering,
)
from transformers import (
FalconForSequenceClassification as TFalconForSequenceClassification,
)
from transformers import (
FalconForTokenClassification as TFalconForTokenClassification,
)
from transformers import (
FalconModel as TFalconModel,
)
from transformers import (
FalconPreTrainedModel as TFalconPreTrainedModel,
)
from transformers import (
PretrainedConfig,
)
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.falcon.modeling_falcon import (
_CHECKPOINT_FOR_DOC,
_CONFIG_FOR_DOC,
FALCON_INPUTS_DOCSTRING,
)
from transformers.utils import (
add_code_sample_docstrings,
add_start_docstrings_to_model_forward,
)

from attention_sinks.attention_sink_kv_cache import AttentionSinkKVCache
from attention_sinks.models.falcon.pos_shift import enable_falcon_pos_shift_attention


class FalconPreTrainedModel(TFalconPreTrainedModel):
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
*model_args,
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
ignore_mismatched_sizes: bool = False,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
use_safetensors: bool = None,
**kwargs,
) -> None:
# Separate Attention Sink kwargs from regular kwargs
attention_sink_kwargs = {key: value for key, value in kwargs.items() if key.startswith("attention_sink")}
for key in attention_sink_kwargs:
kwargs.pop(key)

model = super(FalconPreTrainedModel, cls).from_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
cache_dir=cache_dir,
ignore_mismatched_sizes=ignore_mismatched_sizes,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
use_safetensors=use_safetensors,
**kwargs,
)
# Enable position shifting attention for Falcon
enable_falcon_pos_shift_attention(model)

# Hackishly attach the Attention Sink KV Cache to the model
cls._attach_attention_sink_kv_cache(model, **attention_sink_kwargs)

return model

@classmethod
def _attach_attention_sink_kv_cache(cls, module, **attention_sink_kwargs):
if isinstance(module, TFalconModel):
# Create the new cache
module.attention_sink_kv_cache = AttentionSinkKVCache(
**attention_sink_kwargs,
k_seq_dim=2,
v_seq_dim=2,
)

# Keep track of the old forward method, we need it in the wrapped one
old_forward = module.forward

# Wrap the forward by overriding the past_key_values using the cache
@add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPastAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def wrapped_forward(self, *args, **kwargs) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
outputs = old_forward(*args, **kwargs)
outputs.past_key_values = self.attention_sink_kv_cache(outputs.past_key_values)
return outputs

module.forward = types.MethodType(wrapped_forward, module)

# Recursively call this to find all FalconModels
for module in reversed(module._modules.values()):
if len(list(module.children())) > 0:
cls._attach_attention_sink_kv_cache(module, **attention_sink_kwargs)


class FalconModel(FalconPreTrainedModel, TFalconModel):
pass


class FalconForCausalLM(FalconPreTrainedModel, TFalconForCausalLM):
pass


class FalconForSequenceClassification(FalconPreTrainedModel, TFalconForSequenceClassification):
pass


class FalconForTokenClassification(FalconPreTrainedModel, TFalconForTokenClassification):
pass


class FalconForQuestionAnswering(FalconPreTrainedModel, TFalconForQuestionAnswering):
pass
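To round this off, a minimal end-to-end sketch of how the wrapped `from_pretrained` above would be used (not part of the diff; the checkpoint, dtype, prompt, and generation settings are illustrative). The `attention_sink_*` kwargs are the ones stripped out and forwarded to `AttentionSinkKVCache`, so the past key/values should stay bounded to roughly `attention_sink_size + attention_sink_window_size` tokens during generation.

```python
# Illustrative usage only, not part of the PR.
import torch
from transformers import AutoTokenizer

from attention_sinks import FalconForCausalLM

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = FalconForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
    attention_sink_size=4,            # always-kept initial tokens (the "sink")
    attention_sink_window_size=1020,  # recent tokens kept in the sliding window
)

inputs = tokenizer("Attention sinks let the model", return_tensors="pt").to(model.device)
# The wrapped FalconModel.forward passes past_key_values through the
# AttentionSinkKVCache on every step, keeping the cache size bounded.
output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```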