1 change: 1 addition & 0 deletions docs/source/bettertransformer/overview.mdx
@@ -41,6 +41,7 @@ The list of supported model below:
- [BERT](https://arxiv.org/abs/1810.04805)
- [BERT-generation](https://arxiv.org/abs/1907.12461)
- [BLIP-2](https://arxiv.org/abs/2301.12597)
- [BLOOM](https://arxiv.org/abs/2211.05100)
- [CamemBERT](https://arxiv.org/abs/1911.03894)
- [CLIP](https://arxiv.org/abs/2103.00020)
- [CodeGen](https://arxiv.org/abs/2203.13474)
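With BLOOM added to the supported-models list, conversion follows the standard BetterTransformer flow. A minimal sketch, assuming the `bigscience/bloom-560m` checkpoint as an example model (not part of this PR):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.bettertransformer import BetterTransformer

# Load a BLOOM checkpoint (bigscience/bloom-560m is used here purely as an example).
model_id = "bigscience/bloom-560m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Swap the BloomAttention modules for their BetterTransformer counterparts.
model = BetterTransformer.transform(model)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```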
4 changes: 4 additions & 0 deletions optimum/bettertransformer/models/__init__.py
@@ -18,6 +18,7 @@
BarkAttentionLayerBetterTransformer,
BartAttentionLayerBetterTransformer,
BlenderbotAttentionLayerBetterTransformer,
BloomAttentionLayerBetterTransformer,
CodegenAttentionLayerBetterTransformer,
GPT2AttentionLayerBetterTransformer,
GPTBigCodeAttentionLayerBetterTransformer,
@@ -58,6 +59,7 @@ class BetterTransformerManager:
"bert": {"BertLayer": BertLayerBetterTransformer},
"bert-generation": {"BertGenerationLayer": BertLayerBetterTransformer},
"blenderbot": {"BlenderbotAttention": BlenderbotAttentionLayerBetterTransformer},
"bloom": {"BloomAttention": BloomAttentionLayerBetterTransformer},
"camembert": {"CamembertLayer": BertLayerBetterTransformer},
"blip-2": {"T5Attention": T5AttentionLayerBetterTransformer},
"clip": {"CLIPEncoderLayer": CLIPLayerBetterTransformer},
@@ -130,6 +132,7 @@ class BetterTransformerManager:
NOT_REQUIRES_NESTED_TENSOR = {
"bark",
"blenderbot",
"bloom",
"codegen",
"gpt2",
"gpt_bigcode",
@@ -145,6 +148,7 @@ class BetterTransformerManager:
NOT_REQUIRES_STRICT_VALIDATION = {
"blenderbot",
"blip-2",
"bloom",
"codegen",
"gpt2",
"gpt_bigcode",
86 changes: 86 additions & 0 deletions optimum/bettertransformer/models/attention.py
@@ -804,3 +804,89 @@ def gpt_bigcode_forward(
outputs = (attn_output, present)

return outputs


# Adapted from transformers.models.bloom.modeling_bloom.BloomAttention.forward
def bloom_forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
raise_on_head_mask(head_mask)

if output_attentions is True:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]

# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

batch_size, q_length, _, _ = query_layer.shape

# Permute to [batch_size, num_heads, seq_length, head_dim]
query_layer = query_layer.transpose(1, 2)

if layer_past is not None:
past_key, past_value = layer_past
past_key = past_key.transpose(1, 2)

key_layer = key_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)

# concatenate along seq_length dimension
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)

# untangle batch_size from self.num_heads
key_layer = key_layer.reshape(batch_size, self.num_heads, *key_layer.shape[1:])
value_layer = value_layer.reshape(batch_size, self.num_heads, *value_layer.shape[1:])
else:
key_layer = key_layer.transpose(1, 2)
value_layer = value_layer.transpose(1, 2)

alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
alibi = torch.masked_fill(alibi, attention_mask, torch.finfo(alibi.dtype).min)

context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=alibi,
dropout_p=self.dropout_prob_attn if self.training else 0.0,
)

# Transform [batch_size, num_heads, seq_length, head_dim] to [batch_size, seq_length, num_heads * head_dim]
context_layer = context_layer.transpose(1, 2)
context_layer = context_layer.reshape(*context_layer.shape[:2], -1)

# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = self.hidden_size / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + torch.nn.functional.linear(
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)

output_tensor = torch.nn.functional.dropout(output_tensor, p=self.hidden_dropout, training=self.training)
output_tensor = residual + output_tensor

if use_cache is True:
present = (
key_layer.reshape(-1, *key_layer.shape[2:]).transpose(1, 2),
value_layer.reshape(-1, *value_layer.shape[2:]),
)
else:
present = None

return (output_tensor, present)
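The key step in `bloom_forward` above is that BLOOM's additive ALiBi bias and the boolean attention mask are folded into a single floating-point `attn_mask` that `torch.nn.functional.scaled_dot_product_attention` can consume directly. A minimal, self-contained sketch of that pattern with illustrative shapes and names (not taken from the PR):

```python
import torch

batch_size, num_heads, q_len, kv_len, head_dim = 2, 4, 5, 5, 8

query = torch.randn(batch_size, num_heads, q_len, head_dim)
key = torch.randn(batch_size, num_heads, kv_len, head_dim)
value = torch.randn(batch_size, num_heads, kv_len, head_dim)

# ALiBi-style additive bias: one value per head and key position,
# broadcast over the query dimension.
alibi = torch.randn(batch_size, num_heads, 1, kv_len)

# Boolean mask: True marks positions that must not be attended to
# (here, pretend the last key position is padding).
padding_mask = torch.zeros(batch_size, 1, 1, kv_len, dtype=torch.bool)
padding_mask[:, :, :, -1] = True

# Fold the boolean mask into the additive bias so SDPA receives a single
# float attn_mask, exactly as the masked_fill in bloom_forward does.
attn_bias = alibi.masked_fill(padding_mask, torch.finfo(alibi.dtype).min)

context = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attn_mask=attn_bias, dropout_p=0.0
)
print(context.shape)  # torch.Size([2, 4, 5, 8])
```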
22 changes: 22 additions & 0 deletions optimum/bettertransformer/models/decoder_models.py
@@ -18,6 +18,7 @@
from transformers.models.bark.modeling_bark import BarkSelfAttention
from transformers.models.bart.modeling_bart import BartAttention
from transformers.models.blenderbot.modeling_blenderbot import BlenderbotAttention
from transformers.models.bloom.modeling_bloom import BloomAttention
from transformers.models.codegen.modeling_codegen import CodeGenAttention
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeAttention
@@ -34,6 +35,7 @@
from .attention import (
bark_wrapped_scaled_dot_product,
bart_forward,
bloom_forward,
codegen_wrapped_scaled_dot_product,
gpt2_wrapped_scaled_dot_product,
gpt_bigcode_forward,
@@ -202,6 +204,26 @@ def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)


class BloomAttentionLayerBetterTransformer(BetterTransformerBaseLayer, BloomAttention, nn.Module):
def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
super().__init__(config)

with torch.device("meta"):
super(BetterTransformerBaseLayer, self).__init__(config)

self.dropout_prob_attn = config.attention_dropout

self.module_mapping = None
submodules = ["query_key_value", "dense", "attention_dropout"]
for attr in submodules:
setattr(self, attr, getattr(layer, attr))

self.original_layers_mapping = {submodule: submodule for submodule in submodules}

def forward(self, *args, **kwargs):
return bloom_forward(self, *args, **kwargs)


class CodegenAttentionLayerBetterTransformer(BetterTransformerBaseLayer, CodeGenAttention, nn.Module):
_attn = codegen_wrapped_scaled_dot_product

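The `with torch.device("meta")` block in `BloomAttentionLayerBetterTransformer.__init__` re-runs `BloomAttention.__init__` without allocating real weight tensors; the trained parameters are then re-attached from the existing layer. A rough sketch of that pattern in isolation, using placeholder class names rather than the PR's code:

```python
import torch
from torch import nn


class OriginalAttention(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.query_key_value = nn.Linear(hidden_size, 3 * hidden_size)
        self.dense = nn.Linear(hidden_size, hidden_size)


class WrappedAttention(OriginalAttention):
    def __init__(self, layer: OriginalAttention, hidden_size: int):
        # Initialize the parent on the meta device: shapes only, no real memory.
        with torch.device("meta"):
            super().__init__(hidden_size)

        # Re-point the submodules at the already-trained modules of `layer`,
        # so no weights are copied or re-initialized.
        for name in ("query_key_value", "dense"):
            setattr(self, name, getattr(layer, name))


layer = OriginalAttention(16)
wrapped = WrappedAttention(layer, 16)
assert wrapped.dense.weight.data_ptr() == layer.dense.weight.data_ptr()
```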
22 changes: 1 addition & 21 deletions optimum/exporters/onnx/model_configs.py
@@ -21,6 +21,7 @@

from ...utils import (
DEFAULT_DUMMY_SHAPES,
BloomDummyPastKeyValuesGenerator,
DummyAudioInputGenerator,
DummyDecoderTextInputGenerator,
DummyPastKeyValuesGenerator,
@@ -217,27 +218,6 @@ class LlamaOnnxConfig(TextDecoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


class BloomDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def generate(self, input_name: str, framework: str = "pt"):
past_key_shape = (
self.batch_size * self.num_attention_heads,
self.hidden_size // self.num_attention_heads,
self.sequence_length,
)
past_value_shape = (
self.batch_size * self.num_attention_heads,
self.sequence_length,
self.hidden_size // self.num_attention_heads,
)
return [
(
self.random_float_tensor(past_key_shape, framework=framework),
self.random_float_tensor(past_value_shape, framework=framework),
)
for _ in range(self.num_layers)
]


class BloomOnnxConfig(TextDecoderOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
BloomDummyPastKeyValuesGenerator,
1 change: 1 addition & 0 deletions optimum/utils/__init__.py
@@ -42,6 +42,7 @@
)
from .input_generators import (
DEFAULT_DUMMY_SHAPES,
BloomDummyPastKeyValuesGenerator,
DummyAudioInputGenerator,
DummyBboxInputGenerator,
DummyDecoderTextInputGenerator,
21 changes: 21 additions & 0 deletions optimum/utils/input_generators.py
@@ -739,3 +739,24 @@ def generate(self, input_name: str, framework: str = "pt"):
self.hidden_size // self.num_attention_heads * 2,
)
return [self.random_float_tensor(past_key_value_shape, framework=framework) for _ in range(self.num_layers)]


class BloomDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def generate(self, input_name: str, framework: str = "pt"):
past_key_shape = (
self.batch_size * self.num_attention_heads,
self.hidden_size // self.num_attention_heads,
self.sequence_length,
)
past_value_shape = (
self.batch_size * self.num_attention_heads,
self.sequence_length,
self.hidden_size // self.num_attention_heads,
)
return [
(
self.random_float_tensor(past_key_shape, framework=framework),
self.random_float_tensor(past_value_shape, framework=framework),
)
for _ in range(self.num_layers)
]
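`BloomDummyPastKeyValuesGenerator` exists because BLOOM's cache layout differs from the default: keys are stored as `(batch * num_heads, head_dim, seq_len)` and values as `(batch * num_heads, seq_len, head_dim)`. A small sketch with illustrative sizes (not taken from the PR) showing the per-layer shapes this generator produces:

```python
import torch

# Illustrative sizes: batch=2, heads=4, hidden=32 -> head_dim=8, past length=3.
batch_size, num_heads, hidden_size, seq_len = 2, 4, 32, 3
head_dim = hidden_size // num_heads

# BLOOM past key: heads folded into the batch dimension, head_dim before seq_len.
past_key = torch.rand(batch_size * num_heads, head_dim, seq_len)
# BLOOM past value: same folding, but seq_len before head_dim.
past_value = torch.rand(batch_size * num_heads, seq_len, head_dim)

print(past_key.shape)    # torch.Size([8, 8, 3])
print(past_value.shape)  # torch.Size([8, 3, 8])
```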
11 changes: 9 additions & 2 deletions tests/bettertransformer/test_decoder.py
@@ -22,12 +22,17 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.bettertransformer import BetterTransformer
from optimum.utils import DummyPastKeyValuesGenerator, GPTBigCodeDummyPastKeyValuesGenerator, NormalizedConfigManager
from optimum.utils import (
BloomDummyPastKeyValuesGenerator,
DummyPastKeyValuesGenerator,
GPTBigCodeDummyPastKeyValuesGenerator,
NormalizedConfigManager,
)
from optimum.utils.testing_utils import grid_parameters, require_accelerate, require_torch_gpu


class BetterTransformersDecoderTest(BetterTransformersTestMixin, unittest.TestCase):
SUPPORTED_ARCH = ["codegen", "gpt2", "gpt_bigcode", "gptj", "gpt_neo", "gpt_neox", "llama", "opt"]
SUPPORTED_ARCH = ["bloom", "codegen", "gpt2", "gpt_bigcode", "gptj", "gpt_neo", "gpt_neox", "llama", "opt"]

FULL_GRID = {
"model_type": SUPPORTED_ARCH,
@@ -126,6 +131,8 @@ def test_logits_with_cache(self, test_name: str, model_type: str, batch_size: in

if model_type == "gpt_bigcode":
pkv_generator_class = GPTBigCodeDummyPastKeyValuesGenerator
elif model_type == "bloom":
pkv_generator_class = BloomDummyPastKeyValuesGenerator
else:
pkv_generator_class = DummyPastKeyValuesGenerator

4 changes: 3 additions & 1 deletion tests/bettertransformer/testing_utils.py
@@ -33,6 +33,7 @@
"bert-generation": "ybelkada/random-tiny-BertGenerationModel",
"blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel",
"blip-2": "hf-internal-testing/tiny-random-Blip2Model",
"bloom": "hf-internal-testing/tiny-random-BloomModel",
"camembert": "hf-internal-testing/tiny-random-camembert",
"clip_text_model": "hf-internal-testing/tiny-random-clip-zero-shot-image-classification", # with quick_gelu
"clip": "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", # with gelu
@@ -82,6 +83,7 @@
"hidden_dropout_prob",
"classifier_dropout_prob",
"attention_dropout",
"hidden_dropout",
"dropout",
"qa_dropout",
"seq_classif_dropout",
@@ -222,7 +224,7 @@ def _test_logits_backward(self, model_id: str, model_type: str, **preprocessor_k
def _test_logits(self, model_id: str, model_type: str, **preprocessor_kwargs):
r"""
This tests if the converted model produces the same logits
than the original model.
as the original model.
"""
# The first row of the attention mask needs to be all ones -> check: https://github.com/pytorch/pytorch/blob/19171a21ee8a9cc1a811ac46d3abd975f0b6fc3b/test/test_nn.py#L5283
inputs = self.prepare_inputs_for_class(model_id=model_id, model_type=model_type, **preprocessor_kwargs)