Commit
Merge pull request #3 from tomaarsen/model/mpt
Add MPT support + benchmark results
tomaarsen authored Oct 3, 2023
2 parents 746f511 + 16557ae commit 1e8a1b0
Showing 14 changed files with 20,621 additions and 11 deletions.
5 changes: 3 additions & 2 deletions .gitignore
@@ -194,5 +194,6 @@ demo.py
 # Ignore arrow cache files
 cache-*.arrow
 
-# Ignore my personal benchmarking shell scripts
-benchmark_*.sh
+# Ignore my personal benchmarking shell scripts, but allow the ones in benchmark/scripts
+benchmark_*.sh
+!benchmark/scripts/*.sh
6 changes: 3 additions & 3 deletions README.md
@@ -1,9 +1,9 @@
 
 # Attention Sinks in Transformers for Infinite-length LLMs
 
-| Llama 2 7B | Falcon-7B |
-| ------------- | ------------- |
-| ![llama_2_7b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/8d2e5b88-7158-41ac-8b3a-5a7abe38020d) | ![falcon_7b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/1be07370-6de7-4a7e-b5ab-3092a5ecb412) |
+| Llama 2 7B | Falcon 7B | MPT 7B |
+| ------------- | ------------- | ---- |
+| ![llama_2_7b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/8d2e5b88-7158-41ac-8b3a-5a7abe38020d) | ![falcon_7b_ppl_vram_plotted](https://github.com/tomaarsen/attention_sinks/assets/37621491/1be07370-6de7-4a7e-b5ab-3092a5ecb412) | ![mpt_7b_ppl_vram_plotted](https://github.com/mit-han-lab/streaming-llm/assets/37621491/c96cff66-92a3-43ab-bc21-40232f2740a0) |
 
 ## Overview
 
6 changes: 6 additions & 0 deletions attention_sinks/__init__.py
@@ -18,4 +18,10 @@
     LlamaForCausalLM,
     LlamaForSequenceClassification,
     LlamaModel,
+    MptForCausalLM,
+    MptForQuestionAnswering,
+    MptForSequenceClassification,
+    MptForTokenClassification,
+    MptModel,
+    MptPreTrainedModel,
 )
8 changes: 8 additions & 0 deletions attention_sinks/models/__init__.py
@@ -14,3 +14,11 @@
     FalconPreTrainedModel,
 )
 from .llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
+from .mpt import (
+    MptForCausalLM,
+    MptForQuestionAnswering,
+    MptForSequenceClassification,
+    MptForTokenClassification,
+    MptModel,
+    MptPreTrainedModel,
+)
28 changes: 23 additions & 5 deletions attention_sinks/models/auto/modeling_auto.py
@@ -13,7 +13,7 @@
 from transformers import (
     AutoModelForTokenClassification as TAutoModelForTokenClassification,
 )
-from transformers import FalconConfig, LlamaConfig
+from transformers import FalconConfig, LlamaConfig, MptConfig
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING,
     MODEL_FOR_QUESTION_ANSWERING_MAPPING,
@@ -30,15 +30,33 @@
     FalconModel,
 )
 from ..llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
+from ..mpt import (
+    MptForCausalLM,
+    MptForQuestionAnswering,
+    MptForSequenceClassification,
+    MptForTokenClassification,
+    MptModel,
+)
 
-MODEL_MAPPING._extra_content = {LlamaConfig: LlamaModel, FalconConfig: FalconModel}
-MODEL_FOR_CAUSAL_LM_MAPPING._extra_content = {LlamaConfig: LlamaForCausalLM, FalconConfig: FalconForCausalLM}
+MODEL_MAPPING._extra_content = {LlamaConfig: LlamaModel, FalconConfig: FalconModel, MptConfig: MptModel}
+MODEL_FOR_CAUSAL_LM_MAPPING._extra_content = {
+    LlamaConfig: LlamaForCausalLM,
+    FalconConfig: FalconForCausalLM,
+    MptConfig: MptForCausalLM,
+}
 MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING._extra_content = {
     LlamaConfig: LlamaForSequenceClassification,
     FalconConfig: FalconForSequenceClassification,
+    MptConfig: MptForSequenceClassification,
 }
+MODEL_FOR_QUESTION_ANSWERING_MAPPING._extra_content = {
+    FalconConfig: FalconForQuestionAnswering,
+    MptConfig: MptForQuestionAnswering,
+}
+MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING._extra_content = {
+    FalconConfig: FalconForTokenClassification,
+    MptConfig: MptForTokenClassification,
+}
-MODEL_FOR_QUESTION_ANSWERING_MAPPING._extra_content = {FalconConfig: FalconForQuestionAnswering}
-MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING._extra_content = {FalconConfig: FalconForTokenClassification}
 
 
 class AutoModel(TAutoModel):
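With MptConfig registered in these _extra_content mappings, the attention_sinks Auto classes can dispatch MPT checkpoints to the patched model classes added in this commit. A minimal usage sketch, not part of this diff; the checkpoint name and the attention_sink_* keyword names are illustrative assumptions, mirroring the kwargs that from_pretrained filters out in the modeling file below:

from attention_sinks import AutoModelForCausalLM

# Kwargs starting with "attention_sink" are routed to the AttentionSinkKVCache;
# everything else is forwarded to the regular transformers from_pretrained.
model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b",                # assumed MPT checkpoint
    device_map="auto",
    attention_sink_size=4,            # assumed kwarg: initial "sink" tokens to keep
    attention_sink_window_size=1020,  # assumed kwarg: sliding window of recent tokens
)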
8 changes: 8 additions & 0 deletions attention_sinks/models/mpt/__init__.py
@@ -0,0 +1,8 @@
from .modeling_mpt import (
    MptForCausalLM,
    MptForQuestionAnswering,
    MptForSequenceClassification,
    MptForTokenClassification,
    MptModel,
    MptPreTrainedModel,
)
132 changes: 132 additions & 0 deletions attention_sinks/models/mpt/modeling_mpt.py
@@ -0,0 +1,132 @@
import os
import types
from typing import List, Optional, Tuple, Union

import torch
from transformers import (
    MptForCausalLM as TMptForCausalLM,
)
from transformers import (
    MptForQuestionAnswering as TMptForQuestionAnswering,
)
from transformers import (
    MptForSequenceClassification as TMptForSequenceClassification,
)
from transformers import (
    MptForTokenClassification as TMptForTokenClassification,
)
from transformers import (
    MptModel as TMptModel,
)
from transformers import (
    MptPreTrainedModel as TMptPreTrainedModel,
)
from transformers import (
    PretrainedConfig,
)
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.mpt.modeling_mpt import (
    _CHECKPOINT_FOR_DOC,
    _CONFIG_FOR_DOC,
    MPT_INPUTS_DOCSTRING,
)
from transformers.utils import (
    add_code_sample_docstrings,
    add_start_docstrings_to_model_forward,
)

from attention_sinks.attention_sink_kv_cache import AttentionSinkKVCache


class MptPreTrainedModel(TMptPreTrainedModel):
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: bool = None,
        **kwargs,
    ) -> None:
        # Separate Attention Sink kwargs from regular kwargs
        attention_sink_kwargs = {key: value for key, value in kwargs.items() if key.startswith("attention_sink")}
        for key in attention_sink_kwargs:
            kwargs.pop(key)

        model = super(MptPreTrainedModel, cls).from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            **kwargs,
        )
        # MPT already has windowed attention

        # Hackishly attach the Attention Sink KV Cache to the model
        cls._attach_attention_sink_kv_cache(model, **attention_sink_kwargs)

        return model

    @classmethod
    def _attach_attention_sink_kv_cache(cls, module, **attention_sink_kwargs):
        if isinstance(module, TMptModel):
            # Create the new cache
            module.attention_sink_kv_cache = AttentionSinkKVCache(
                **attention_sink_kwargs,
                k_seq_dim=2,
                v_seq_dim=2,
            )

            # Keep track of the old forward method, we need it in the wrapped one
            old_forward = module.forward

            # Wrap the forward by overriding the past_key_values using the cache
            @add_start_docstrings_to_model_forward(MPT_INPUTS_DOCSTRING)
            @add_code_sample_docstrings(
                checkpoint=_CHECKPOINT_FOR_DOC,
                output_type=BaseModelOutputWithPastAndCrossAttentions,
                config_class=_CONFIG_FOR_DOC,
            )
            def wrapped_forward(self, *args, **kwargs) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
                outputs = old_forward(*args, **kwargs)
                outputs.past_key_values = self.attention_sink_kv_cache(outputs.past_key_values)
                return outputs

            module.forward = types.MethodType(wrapped_forward, module)

        # Recursively call this to find all MptModels
        for module in reversed(module._modules.values()):
            if len(list(module.children())) > 0:
                cls._attach_attention_sink_kv_cache(module, **attention_sink_kwargs)


class MptModel(MptPreTrainedModel, TMptModel):
    pass


class MptForCausalLM(MptPreTrainedModel, TMptForCausalLM):
    pass


class MptForSequenceClassification(MptPreTrainedModel, TMptForSequenceClassification):
    pass


class MptForTokenClassification(MptPreTrainedModel, TMptForTokenClassification):
    pass


class MptForQuestionAnswering(MptPreTrainedModel, TMptForQuestionAnswering):
    pass
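For orientation, a rough usage sketch (not part of this commit) of what the wrapper above does: wrapped_forward passes outputs.past_key_values through AttentionSinkKVCache with k_seq_dim=2 and v_seq_dim=2, so each layer's cached keys and values are trimmed along their sequence dimension and the cache stays bounded during long generation. The checkpoint name and the attention_sink_* keyword names are assumptions, mirroring the kwargs filtered in from_pretrained:

import torch
from transformers import AutoTokenizer
from attention_sinks import MptForCausalLM

tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b")  # assumed checkpoint
model = MptForCausalLM.from_pretrained(
    "mosaicml/mpt-7b",
    device_map="auto",
    torch_dtype=torch.float16,
    attention_sink_size=4,            # assumed kwarg
    attention_sink_window_size=1020,  # assumed kwarg
)

inputs = tokenizer("Attention sinks keep the first tokens in the KV cache.", return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**inputs, use_cache=True)

# The returned cache is already trimmed: along dim 2 (the sequence dimension) it can
# never grow past attention_sink_size + attention_sink_window_size positions.
print(outputs.past_key_values[0][0].shape)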