diff --git a/docs/source/bettertransformer/overview.mdx b/docs/source/bettertransformer/overview.mdx index 415c569547..6de6544dc9 100644 --- a/docs/source/bettertransformer/overview.mdx +++ b/docs/source/bettertransformer/overview.mdx @@ -41,6 +41,7 @@ The list of supported model below: - [BERT](https://arxiv.org/abs/1810.04805) - [BERT-generation](https://arxiv.org/abs/1907.12461) - [BLIP-2](https://arxiv.org/abs/2301.12597) +- [BLOOM](https://arxiv.org/abs/2211.05100) - [CamemBERT](https://arxiv.org/abs/1911.03894) - [CLIP](https://arxiv.org/abs/2103.00020) - [CodeGen](https://arxiv.org/abs/2203.13474) diff --git a/optimum/bettertransformer/models/__init__.py b/optimum/bettertransformer/models/__init__.py index b21ff5da45..35d106b9f5 100644 --- a/optimum/bettertransformer/models/__init__.py +++ b/optimum/bettertransformer/models/__init__.py @@ -18,6 +18,7 @@ BarkAttentionLayerBetterTransformer, BartAttentionLayerBetterTransformer, BlenderbotAttentionLayerBetterTransformer, + BloomAttentionLayerBetterTransformer, CodegenAttentionLayerBetterTransformer, GPT2AttentionLayerBetterTransformer, GPTBigCodeAttentionLayerBetterTransformer, @@ -58,6 +59,7 @@ class BetterTransformerManager: "bert": {"BertLayer": BertLayerBetterTransformer}, "bert-generation": {"BertGenerationLayer": BertLayerBetterTransformer}, "blenderbot": {"BlenderbotAttention": BlenderbotAttentionLayerBetterTransformer}, + "bloom": {"BloomAttention": BloomAttentionLayerBetterTransformer}, "camembert": {"CamembertLayer": BertLayerBetterTransformer}, "blip-2": {"T5Attention": T5AttentionLayerBetterTransformer}, "clip": {"CLIPEncoderLayer": CLIPLayerBetterTransformer}, @@ -130,6 +132,7 @@ class BetterTransformerManager: NOT_REQUIRES_NESTED_TENSOR = { "bark", "blenderbot", + "bloom", "codegen", "gpt2", "gpt_bigcode", @@ -145,6 +148,7 @@ class BetterTransformerManager: NOT_REQUIRES_STRICT_VALIDATION = { "blenderbot", "blip-2", + "bloom", "codegen", "gpt2", "gpt_bigcode", diff --git a/optimum/bettertransformer/models/attention.py b/optimum/bettertransformer/models/attention.py index b0251d3178..462f1b46a8 100644 --- a/optimum/bettertransformer/models/attention.py +++ b/optimum/bettertransformer/models/attention.py @@ -804,3 +804,89 @@ def gpt_bigcode_forward( outputs = (attn_output, present) return outputs + + +# Adapted from transformers.models.bloom.modeling_bloom.BloomAttention.forward +def bloom_forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, +): + raise_on_head_mask(head_mask) + + if output_attentions is True: + raise ValueError("output_attentions=True can not be supported with BetterTransformer.") + + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + batch_size, q_length, _, _ = query_layer.shape + + # Permute to [batch_size, num_heads, seq_length, head_dim] + query_layer = query_layer.transpose(1, 2) + + if layer_past is not None: + past_key, past_value = layer_past + past_key = past_key.transpose(1, 2) + + key_layer = key_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) + value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) 
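+        # Note: BloomAttention caches keys as [batch_size * num_heads, head_dim, kv_length] and
+        # values as [batch_size * num_heads, kv_length, head_dim] (see BloomDummyPastKeyValuesGenerator
+        # below). After the past_key transpose above, past and current key/value all share the
+        # [batch_size * num_heads, seq_length, head_dim] layout and can be concatenated along dim=1.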
+ + # concatenate along seq_length dimension + key_layer = torch.cat((past_key, key_layer), dim=1) + value_layer = torch.cat((past_value, value_layer), dim=1) + + # untangle batch_size from self.num_heads + key_layer = key_layer.reshape(batch_size, self.num_heads, *key_layer.shape[1:]) + value_layer = value_layer.reshape(batch_size, self.num_heads, *value_layer.shape[1:]) + else: + key_layer = key_layer.transpose(1, 2) + value_layer = value_layer.transpose(1, 2) + + alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:]) + alibi = torch.masked_fill(alibi, attention_mask, torch.finfo(alibi.dtype).min) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=alibi, + dropout_p=self.dropout_prob_attn if self.training else 0.0, + ) + + # Transform [batch_size, num_heads, seq_length, head_dim] to [batch_size, seq_length, num_heads * head_dim] + context_layer = context_layer.transpose(1, 2) + context_layer = context_layer.reshape(*context_layer.shape[:2], -1) + + # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + torch.nn.functional.linear( + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) + + output_tensor = torch.nn.functional.dropout(output_tensor, p=self.hidden_dropout, training=self.training) + output_tensor = residual + output_tensor + + if use_cache is True: + present = ( + key_layer.reshape(-1, *key_layer.shape[2:]).transpose(1, 2), + value_layer.reshape(-1, *value_layer.shape[2:]), + ) + else: + present = None + + return (output_tensor, present) diff --git a/optimum/bettertransformer/models/decoder_models.py b/optimum/bettertransformer/models/decoder_models.py index b52bb4d7bd..ab09d6af8c 100644 --- a/optimum/bettertransformer/models/decoder_models.py +++ b/optimum/bettertransformer/models/decoder_models.py @@ -18,6 +18,7 @@ from transformers.models.bark.modeling_bark import BarkSelfAttention from transformers.models.bart.modeling_bart import BartAttention from transformers.models.blenderbot.modeling_blenderbot import BlenderbotAttention +from transformers.models.bloom.modeling_bloom import BloomAttention from transformers.models.codegen.modeling_codegen import CodeGenAttention from transformers.models.gpt2.modeling_gpt2 import GPT2Attention from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeAttention @@ -34,6 +35,7 @@ from .attention import ( bark_wrapped_scaled_dot_product, bart_forward, + bloom_forward, codegen_wrapped_scaled_dot_product, gpt2_wrapped_scaled_dot_product, gpt_bigcode_forward, @@ -202,6 +204,26 @@ def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) +class BloomAttentionLayerBetterTransformer(BetterTransformerBaseLayer, BloomAttention, nn.Module): + def __init__(self, layer: "nn.Module", config: "PretrainedConfig"): + super().__init__(config) + + with torch.device("meta"): + super(BetterTransformerBaseLayer, self).__init__(config) + + self.dropout_prob_attn = config.attention_dropout + + self.module_mapping = None + submodules = ["query_key_value", "dense", "attention_dropout"] + for attr in submodules: + setattr(self, attr, getattr(layer, attr)) + + 
self.original_layers_mapping = {submodule: submodule for submodule in submodules} + + def forward(self, *args, **kwargs): + return bloom_forward(self, *args, **kwargs) + + class CodegenAttentionLayerBetterTransformer(BetterTransformerBaseLayer, CodeGenAttention, nn.Module): _attn = codegen_wrapped_scaled_dot_product diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 0d4202fd74..78a52def53 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -21,6 +21,7 @@ from ...utils import ( DEFAULT_DUMMY_SHAPES, + BloomDummyPastKeyValuesGenerator, DummyAudioInputGenerator, DummyDecoderTextInputGenerator, DummyPastKeyValuesGenerator, @@ -217,27 +218,6 @@ class LlamaOnnxConfig(TextDecoderOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig -class BloomDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): - def generate(self, input_name: str, framework: str = "pt"): - past_key_shape = ( - self.batch_size * self.num_attention_heads, - self.hidden_size // self.num_attention_heads, - self.sequence_length, - ) - past_value_shape = ( - self.batch_size * self.num_attention_heads, - self.sequence_length, - self.hidden_size // self.num_attention_heads, - ) - return [ - ( - self.random_float_tensor(past_key_shape, framework=framework), - self.random_float_tensor(past_value_shape, framework=framework), - ) - for _ in range(self.num_layers) - ] - - class BloomOnnxConfig(TextDecoderOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = ( BloomDummyPastKeyValuesGenerator, diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py index 153499497d..6117fe1895 100644 --- a/optimum/utils/__init__.py +++ b/optimum/utils/__init__.py @@ -42,6 +42,7 @@ ) from .input_generators import ( DEFAULT_DUMMY_SHAPES, + BloomDummyPastKeyValuesGenerator, DummyAudioInputGenerator, DummyBboxInputGenerator, DummyDecoderTextInputGenerator, diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 0ece5543d1..1ff2e4bf2f 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -739,3 +739,24 @@ def generate(self, input_name: str, framework: str = "pt"): self.hidden_size // self.num_attention_heads * 2, ) return [self.random_float_tensor(past_key_value_shape, framework=framework) for _ in range(self.num_layers)] + + +class BloomDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): + def generate(self, input_name: str, framework: str = "pt"): + past_key_shape = ( + self.batch_size * self.num_attention_heads, + self.hidden_size // self.num_attention_heads, + self.sequence_length, + ) + past_value_shape = ( + self.batch_size * self.num_attention_heads, + self.sequence_length, + self.hidden_size // self.num_attention_heads, + ) + return [ + ( + self.random_float_tensor(past_key_shape, framework=framework), + self.random_float_tensor(past_value_shape, framework=framework), + ) + for _ in range(self.num_layers) + ] diff --git a/tests/bettertransformer/test_decoder.py b/tests/bettertransformer/test_decoder.py index aad05b54b9..b06ebf5879 100644 --- a/tests/bettertransformer/test_decoder.py +++ b/tests/bettertransformer/test_decoder.py @@ -22,12 +22,17 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from optimum.bettertransformer import BetterTransformer -from optimum.utils import DummyPastKeyValuesGenerator, GPTBigCodeDummyPastKeyValuesGenerator, NormalizedConfigManager +from optimum.utils import ( + BloomDummyPastKeyValuesGenerator, + 
DummyPastKeyValuesGenerator, + GPTBigCodeDummyPastKeyValuesGenerator, + NormalizedConfigManager, +) from optimum.utils.testing_utils import grid_parameters, require_accelerate, require_torch_gpu class BetterTransformersDecoderTest(BetterTransformersTestMixin, unittest.TestCase): - SUPPORTED_ARCH = ["codegen", "gpt2", "gpt_bigcode", "gptj", "gpt_neo", "gpt_neox", "llama", "opt"] + SUPPORTED_ARCH = ["bloom", "codegen", "gpt2", "gpt_bigcode", "gptj", "gpt_neo", "gpt_neox", "llama", "opt"] FULL_GRID = { "model_type": SUPPORTED_ARCH, @@ -126,6 +131,8 @@ def test_logits_with_cache(self, test_name: str, model_type: str, batch_size: in if model_type == "gpt_bigcode": pkv_generator_class = GPTBigCodeDummyPastKeyValuesGenerator + elif model_type == "bloom": + pkv_generator_class = BloomDummyPastKeyValuesGenerator else: pkv_generator_class = DummyPastKeyValuesGenerator diff --git a/tests/bettertransformer/testing_utils.py b/tests/bettertransformer/testing_utils.py index e926ab4afc..e5699d0e5b 100644 --- a/tests/bettertransformer/testing_utils.py +++ b/tests/bettertransformer/testing_utils.py @@ -33,6 +33,7 @@ "bert-generation": "ybelkada/random-tiny-BertGenerationModel", "blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel", "blip-2": "hf-internal-testing/tiny-random-Blip2Model", + "bloom": "hf-internal-testing/tiny-random-BloomModel", "camembert": "hf-internal-testing/tiny-random-camembert", "clip_text_model": "hf-internal-testing/tiny-random-clip-zero-shot-image-classification", # with quick_gelu "clip": "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", # with gelu @@ -82,6 +83,7 @@ "hidden_dropout_prob", "classifier_dropout_prob", "attention_dropout", + "hidden_dropout", "dropout", "qa_dropout", "seq_classif_dropout", @@ -222,7 +224,7 @@ def _test_logits_backward(self, model_id: str, model_type: str, **preprocessor_k def _test_logits(self, model_id: str, model_type: str, **preprocessor_kwargs): r""" This tests if the converted model produces the same logits - than the original model. + as the original model. """ # The first row of the attention mask needs to be all ones -> check: https://github.com/pytorch/pytorch/blob/19171a21ee8a9cc1a811ac46d3abd975f0b6fc3b/test/test_nn.py#L5283 inputs = self.prepare_inputs_for_class(model_id=model_id, model_type=model_type, **preprocessor_kwargs)
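
For reference, the new BLOOM path can be exercised end to end through the public `BetterTransformer.transform` API. The snippet below is an illustrative sketch rather than part of the patch; it assumes the tiny test checkpoint `hf-internal-testing/tiny-random-BloomModel` registered in `testing_utils.py` above and a PyTorch build that provides `scaled_dot_product_attention` (2.0+).

```python
# Illustrative sketch (not part of this diff): exercising the new BLOOM BetterTransformer path.
# The checkpoint below is the tiny test model used in the test suite; any BLOOM checkpoint works.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.bettertransformer import BetterTransformer

model_id = "hf-internal-testing/tiny-random-BloomModel"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Replaces each BloomAttention with BloomAttentionLayerBetterTransformer (SDPA-based attention).
model = BetterTransformer.transform(model)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
with torch.no_grad():
    # Generation with use_cache=True goes through the layer_past branch of bloom_forward above.
    outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```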