From 212aa3a8d685ce603dc349b57867b59db31cb167 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Tue, 25 Jul 2023 13:36:52 +0200 Subject: [PATCH 1/8] support training --- optimum/bettertransformer/models/__init__.py | 28 - optimum/bettertransformer/models/base.py | 11 - .../models/decoder_models.py | 24 - .../models/encoder_models.py | 1110 ++++++++++------- optimum/bettertransformer/transformation.py | 12 +- tests/bettertransformer/test_encoder.py | 63 +- tests/bettertransformer/testing_utils.py | 75 +- 7 files changed, 770 insertions(+), 553 deletions(-) diff --git a/optimum/bettertransformer/models/__init__.py b/optimum/bettertransformer/models/__init__.py index 6029b0e31d..3083c2e517 100644 --- a/optimum/bettertransformer/models/__init__.py +++ b/optimum/bettertransformer/models/__init__.py @@ -148,23 +148,6 @@ class BetterTransformerManager: "t5", } - REQUIRES_TORCH_20 = { - "blenderbot", - "bart", - "codegen", - "gpt2", - "gptj", - "gpt_neo", - "gpt_neox", - "llama", - "m2m_100", - "marian", - "mbart", - "opt", - "pegasus", - "t5", - } - @staticmethod def cannot_support(model_type: str) -> bool: """ @@ -209,17 +192,6 @@ def requires_strict_validation(model_type: str) -> bool: """ return model_type not in BetterTransformerManager.NOT_REQUIRES_STRICT_VALIDATION - @staticmethod - def requires_torch_20(model_type: str) -> bool: - """ - Returns True if the architecture requires PyTorch 2.0 to be used with BetterTransformer. - - Args: - model_type (`str`): - The model type to check. - """ - return model_type in BetterTransformerManager.REQUIRES_TORCH_20 - class warn_uncompatible_save(object): def __init__(self, callback): diff --git a/optimum/bettertransformer/models/base.py b/optimum/bettertransformer/models/base.py index d2e5bb4bba..fdd7fc5eb9 100644 --- a/optimum/bettertransformer/models/base.py +++ b/optimum/bettertransformer/models/base.py @@ -55,7 +55,6 @@ def __init__( self.num_layers = None self.original_layers_mapping = {} self.module_mapping = None - self.supports_training = False # Some models does not have some attributes thus needs to be ignored # e.g. whisper does not have self_attn.k_proj.bias but has self_attn.v_proj.bias & self_attn.q_proj.bias self.keys_to_ignore = [] @@ -127,16 +126,6 @@ def validate_bettertransformer(self): f" Number of heads must be even." 
) - def forward_checker(self, *args, **kwargs): - if torch.is_autocast_enabled() or torch.is_autocast_cpu_enabled(): - raise ValueError("Autocast is not supported for `BetterTransformer` integration.") - - if self.training and not self.supports_training: - raise ValueError( - "Training is not supported for `BetterTransformer` integration.", - " Please use `model.eval()` before running the model.", - ) - def _revert(self, module: torch.nn.Module) -> torch.nn.Module: if self.module_mapping is not None: if "" in self.module_mapping.values(): diff --git a/optimum/bettertransformer/models/decoder_models.py b/optimum/bettertransformer/models/decoder_models.py index bb8f890227..bbc5a13135 100644 --- a/optimum/bettertransformer/models/decoder_models.py +++ b/optimum/bettertransformer/models/decoder_models.py @@ -65,12 +65,10 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"): setattr(self, "q_attn", getattr(layer, "q_attn")) self.original_layers_mapping["q_attn"] = "q_attn" - self.supports_training = True self.downcast_qk = False self.dropout_prob_attn = config.attn_pdrop def forward(self, *args, **kwargs): - super().forward_checker() return super().forward(*args, **kwargs) @@ -105,11 +103,9 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"): self.original_layers_mapping = {submodule: submodule for submodule in submodules} self.downcast_qk = True - self.supports_training = True self.dropout_prob_attn = config.attn_pdrop def forward(self, *args, **kwargs): - super().forward_checker() return super().forward(*args, **kwargs) @@ -129,11 +125,9 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"): self.original_layers_mapping = {submodule: submodule for submodule in submodules} self.downcast_qk = True - self.supports_training = True self.dropout_prob_attn = 0.0 # no dropout for gpt-neox def forward(self, *args, **kwargs): - super().forward_checker() return super().forward(*args, **kwargs) @@ -159,11 +153,9 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"): self.original_layers_mapping = {submodule: submodule for submodule in submodules} self.scale = torch.sqrt(torch.tensor(layer.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()) - self.supports_training = True self.dropout_prob_attn = float(config.attention_dropout) def forward(self, *args, **kwargs): - super().forward_checker() return super().forward(*args, **kwargs) @@ -188,11 +180,9 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"): self.original_layers_mapping = {submodule: submodule for submodule in submodules} - self.supports_training = True self.dropout_prob_attn = config.attn_pdrop def forward(self, *args, **kwargs): - super().forward_checker() return super().forward(*args, **kwargs) @@ -218,10 +208,7 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"): self.original_layers_mapping = {submodule: submodule for submodule in submodules} - self.supports_training = True - def forward(self, *args, **kwargs): - super().forward_checker() return opt_forward(self, *args, **kwargs) @@ -249,11 +236,9 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"): self.module_mapping = None - self.supports_training = True self.is_decoder = layer.is_decoder def forward(self, *args, **kwargs): - super().forward_checker() return t5_forward(self, *args, **kwargs) @@ -274,7 +259,6 @@ def bart_bettertransformer_init(self, layer: "nn.Module", config: "PretrainedCon self.original_layers_mapping = {submodule: submodule for submodule in 
submodules} - self.supports_training = True self.is_decoder = layer.is_decoder @@ -284,7 +268,6 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"): bart_bettertransformer_init(self, layer, config) def forward(self, *args, **kwargs): - super().forward_checker() return bart_forward(self, *args, **kwargs) @@ -294,7 +277,6 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"): bart_bettertransformer_init(self, layer, config) def forward(self, *args, **kwargs): - super().forward_checker() return bart_forward(self, *args, **kwargs) @@ -304,7 +286,6 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"): bart_bettertransformer_init(self, layer, config) def forward(self, *args, **kwargs): - super().forward_checker() return bart_forward(self, *args, **kwargs) @@ -314,7 +295,6 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"): bart_bettertransformer_init(self, layer, config) def forward(self, *args, **kwargs): - super().forward_checker() return bart_forward(self, *args, **kwargs) @@ -323,7 +303,6 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"): bart_bettertransformer_init(self, layer, config) def forward(self, *args, **kwargs): - super().forward_checker() return bart_forward(self, *args, **kwargs) @@ -339,8 +318,5 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"): self.original_layers_mapping = {submodule: submodule for submodule in submodules} - self.supports_training = True - def forward(self, *args, **kwargs): - super().forward_checker() return llama_forward(self, *args, **kwargs) diff --git a/optimum/bettertransformer/models/encoder_models.py b/optimum/bettertransformer/models/encoder_models.py index 3913830640..9cda2ee925 100644 --- a/optimum/bettertransformer/models/encoder_models.py +++ b/optimum/bettertransformer/models/encoder_models.py @@ -15,6 +15,8 @@ import torch import torch.nn as nn +import torch.nn.functional as F +from transformers.activations import ACT2FN from .base import BetterTransformerBaseLayer @@ -99,50 +101,109 @@ def __init__(self, albert_layer, config): "norm2_weight": "full_layer_layer_norm.weight", "norm2_bias": "full_layer_layer_norm.bias", } + self.attention_head_size = config.hidden_size // config.num_attention_heads + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + self.hidden_dropout_prob = config.hidden_dropout_prob self.validate_bettertransformer() def forward(self, hidden_states, attention_mask, *_): - r""" - This is just a wrapper around the forward function proposed in: - https://github.com/huggingface/transformers/pull/19553 - """ - super().forward_checker() + if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): + if hidden_states.is_nested: + attention_mask = None + + if attention_mask is not None: + # attention mask comes in with values 0 and -inf. 
we convert to torch.nn.TransformerEncoder style bool mask + # 0->false->keep this token -inf->true->mask this token + attention_mask = attention_mask.bool() + attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1])) + hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask) + attention_mask = None + + hidden_states = torch._transformer_encoder_layer_fwd( + hidden_states, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj_weight, + self.out_proj_bias, + self.use_gelu, + self.norm_first, + self.norm1_eps, + self.norm1_weight, + self.norm1_bias, + self.norm2_weight, + self.norm2_bias, + self.linear1_weight, + self.linear1_bias, + self.linear2_weight, + self.linear2_bias, + attention_mask, + ) + if hidden_states.is_nested and self.is_last_layer: + hidden_states = hidden_states.to_padded_tensor(0.0) + else: + # TODO: check dropout + qkv = F.linear(hidden_states, weight=self.in_proj_weight, bias=self.in_proj_bias) - if hidden_states.is_nested: - attention_mask = None + qkv = qkv.view(qkv.size()[:-1] + (3, self.num_heads, self.attention_head_size)).permute(2, 0, 3, 1, 4) + query, key, value = qkv[0], qkv[1], qkv[2] - if attention_mask is not None: - # attention mask comes in with values 0 and -inf. we convert to torch.nn.TransformerEncoder style bool mask - # 0->false->keep this token -inf->true->mask this token - attention_mask = attention_mask.bool() - attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1])) - hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask) - attention_mask = None + # TODO: pass scale argument in PyTorch 2.1 release + query = ( + query + * torch.sqrt(torch.tensor(query.shape[-1], dtype=torch.float32)).to(torch.get_default_dtype()) + / 8 + ) + + # NOTE: In PyTorch 2.0, passing an attention_mask will automatically dispatch + # to the "math" path and will NOT use flash attention / memory-efficient attention. + # We should support xformers / Hazy-flash / rocm-flash directly and stop relying on PyTorch to do the work. 
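                # A minimal sketch (assuming a CUDA device and PyTorch 2.0) of how the dispatch
                # behavior described in the NOTE above can be observed: with the math fallback
                # disabled, SDPA is expected to fail with "No available kernel" as soon as a dense
                # attn_mask is passed, since the fused kernels do not accept one. Shapes are
                # arbitrary (batch, heads, seq_len, head_dim).
                q = k = v = torch.randn(2, 12, 128, 64, device="cuda", dtype=torch.float16)
                mask = torch.zeros(2, 1, 128, 128, device="cuda", dtype=torch.float16)
                with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True, enable_math=False):
                    F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # expected to error: only the math path takes this mask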
+ attention_out = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + is_causal=False, + dropout_p=self.attention_probs_dropout_prob if self.training else 0.0, + ) + + attention_out = attention_out.permute(0, 2, 1, 3).contiguous() + new_attention_out_shape = attention_out.size()[:-2] + (self.num_heads * self.attention_head_size,) + attention_out = attention_out.view(new_attention_out_shape) + + # BertSelfOutput + attention_out = F.layer_norm( + F.dropout( + F.linear(attention_out, self.out_proj_weight, self.out_proj_bias), + p=self.hidden_dropout_prob, + training=self.training, + ) + + hidden_states, + normalized_shape=self.norm1_weight.shape, # TODO: stateful + weight=self.norm1_weight, + bias=self.norm1_bias, + ) + + # BertIntermediate + # TODO: stateful + act_fn = ACT2FN[self.act_fn] + x = act_fn(F.linear(attention_out, self.linear1_weight, self.linear1_bias)) + + # BertOutput + hidden_states = F.layer_norm( + attention_out + + F.dropout( + F.linear(x, self.linear2_weight, self.linear2_bias), + p=self.hidden_dropout_prob, + training=self.training, + ), + normalized_shape=self.norm2_weight.shape, # TODO: stateful + weight=self.norm2_weight, + bias=self.norm2_bias, + ) - hidden_states = torch._transformer_encoder_layer_fwd( - hidden_states, - self.embed_dim, - self.num_heads, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj_weight, - self.out_proj_bias, - self.use_gelu, - self.norm_first, - self.norm1_eps, - self.norm1_weight, - self.norm1_bias, - self.norm2_weight, - self.norm2_bias, - self.linear1_weight, - self.linear1_bias, - self.linear2_weight, - self.linear2_bias, - attention_mask, - ) - if hidden_states.is_nested and self.is_last_layer: - hidden_states = hidden_states.to_padded_tensor(0.0) return (hidden_states,) @@ -227,49 +288,109 @@ def __init__(self, bert_layer, config): "norm2_bias": "output.LayerNorm.bias", } + # TODO: cleaner solution + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.hidden_dropout_prob = config.hidden_dropout_prob + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + self.validate_bettertransformer() def forward(self, hidden_states, attention_mask, *_): - r""" - This is just a wrapper around the forward function proposed in: - https://github.com/huggingface/transformers/pull/19553 - """ - super().forward_checker() + if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): + if hidden_states.is_nested: + attention_mask = None + + if attention_mask is not None: + # attention mask comes in with values 0 and -inf. 
we convert to torch.nn.TransformerEncoder style bool mask + # 0->false->keep this token -inf->true->mask this token + attention_mask = attention_mask.bool() + attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1])) + hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask) + attention_mask = None + + hidden_states = torch._transformer_encoder_layer_fwd( + hidden_states, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj_weight, + self.out_proj_bias, + self.use_gelu, + self.norm_first, + self.norm1_eps, + self.norm1_weight, + self.norm1_bias, + self.norm2_weight, + self.norm2_bias, + self.linear1_weight, + self.linear1_bias, + self.linear2_weight, + self.linear2_bias, + attention_mask, + ) + if hidden_states.is_nested and self.is_last_layer: + hidden_states = hidden_states.to_padded_tensor(0.0) + else: + qkv = F.linear(hidden_states, weight=self.in_proj_weight, bias=self.in_proj_bias) - if hidden_states.is_nested: - attention_mask = None + qkv = qkv.view(qkv.size()[:-1] + (3, self.num_heads, self.attention_head_size)).permute(2, 0, 3, 1, 4) + query, key, value = qkv[0], qkv[1], qkv[2] - if attention_mask is not None: - # attention mask comes in with values 0 and -inf. we convert to torch.nn.TransformerEncoder style bool mask - # 0->false->keep this token -inf->true->mask this token - attention_mask = attention_mask.bool() - attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1])) - hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask) - attention_mask = None + # TODO: pass scale argument in PyTorch 2.1 release + query = ( + query + * torch.sqrt(torch.tensor(query.shape[-1], dtype=torch.float32)).to(torch.get_default_dtype()) + / 8 + ) + + # NOTE: In PyTorch 2.0, passing an attention_mask will automatically dispatch + # to the "math" path and will NOT use flash attention / memory-efficient attention. + # We should support xformers / Hazy-flash / rocm-flash directly and stop relying on PyTorch to do the work. 
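                # For reference, a minimal end-to-end sketch of what this eager training path enables
                # once a model has been converted (the checkpoint name is only an example; any
                # supported encoder works):
                from transformers import AutoModel, AutoTokenizer
                from optimum.bettertransformer import BetterTransformer

                model = BetterTransformer.transform(AutoModel.from_pretrained("bert-base-uncased"))
                model.train()
                tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
                batch = tokenizer(["a first sequence", "and a padded second one"], padding=True, return_tensors="pt")
                out = model(**batch)
                out.last_hidden_state.pow(2).mean().backward()  # gradients now flow through the SDPA-based path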
+ attention_out = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + is_causal=False, + dropout_p=self.attention_probs_dropout_prob if self.training else 0.0, + ) + + attention_out = attention_out.permute(0, 2, 1, 3).contiguous() + new_attention_out_shape = attention_out.size()[:-2] + (self.num_heads * self.attention_head_size,) + attention_out = attention_out.view(new_attention_out_shape) + + # BertSelfOutput + attention_out = F.layer_norm( + F.dropout( + F.linear(attention_out, self.out_proj_weight, self.out_proj_bias), + p=self.hidden_dropout_prob, + training=self.training, + ) + + hidden_states, + normalized_shape=self.norm1_weight.shape, # TODO: stateful + weight=self.norm1_weight, + bias=self.norm1_bias, + ) + + # BertIntermediate + # TODO: stateful + act_fn = ACT2FN[self.act_fn] + x = act_fn(F.linear(attention_out, self.linear1_weight, self.linear1_bias)) + + # BertOutput + hidden_states = F.layer_norm( + attention_out + + F.dropout( + F.linear(x, self.linear2_weight, self.linear2_bias), + p=self.hidden_dropout_prob, + training=self.training, + ), + normalized_shape=self.norm2_weight.shape, # TODO: stateful + weight=self.norm2_weight, + bias=self.norm2_bias, + ) - hidden_states = torch._transformer_encoder_layer_fwd( - hidden_states, - self.embed_dim, - self.num_heads, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj_weight, - self.out_proj_bias, - self.use_gelu, - self.norm_first, - self.norm1_eps, - self.norm1_weight, - self.norm1_bias, - self.norm2_weight, - self.norm2_bias, - self.linear1_weight, - self.linear1_bias, - self.linear2_weight, - self.linear2_bias, - attention_mask, - ) - if hidden_states.is_nested and self.is_last_layer: - hidden_states = hidden_states.to_padded_tensor(0.0) return (hidden_states,) @@ -354,56 +475,53 @@ def __init__(self, bart_layer, config): self.validate_bettertransformer() def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__): - r""" - This is just a wrapper around the forward function proposed in: - https://github.com/huggingface/transformers/pull/19553 - """ - super().forward_checker() + if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): + if not hasattr(hidden_states, "original_shape"): + original_shape = hidden_states.shape + else: + original_shape = hidden_states.original_shape + + if hidden_states.is_nested: + attention_mask = None + + if attention_mask is not None: + # attention mask comes in with values 0 and -inf. 
we convert to torch.nn.TransformerEncoder style bool mask + # 0->false->keep this token -inf->true->mask this token + if len(attention_mask.shape) == 4: + attention_mask = attention_mask.squeeze(1)[:, 0] + attention_mask = attention_mask.bool() + attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1])) + hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask) + attention_mask = None + + hidden_states = torch._transformer_encoder_layer_fwd( + hidden_states, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj_weight, + self.out_proj_bias, + self.use_gelu, + self.norm_first, + self.norm1_eps, + self.norm1_weight, + self.norm1_bias, + self.norm2_weight, + self.norm2_bias, + self.linear1_weight, + self.linear1_bias, + self.linear2_weight, + self.linear2_bias, + attention_mask, + ) - if not hasattr(hidden_states, "original_shape"): - original_shape = hidden_states.shape + if not self.is_last_layer: + hidden_states.original_shape = original_shape + elif hidden_states.is_nested and self.is_last_layer: + hidden_states = hidden_states.to_padded_tensor(0.0, original_shape) else: - original_shape = hidden_states.original_shape - - if hidden_states.is_nested: - attention_mask = None - - if attention_mask is not None: - # attention mask comes in with values 0 and -inf. we convert to torch.nn.TransformerEncoder style bool mask - # 0->false->keep this token -inf->true->mask this token - if len(attention_mask.shape) == 4: - attention_mask = attention_mask.squeeze(1)[:, 0] - attention_mask = attention_mask.bool() - attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1])) - hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask) - attention_mask = None - - hidden_states = torch._transformer_encoder_layer_fwd( - hidden_states, - self.embed_dim, - self.num_heads, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj_weight, - self.out_proj_bias, - self.use_gelu, - self.norm_first, - self.norm1_eps, - self.norm1_weight, - self.norm1_bias, - self.norm2_weight, - self.norm2_bias, - self.linear1_weight, - self.linear1_bias, - self.linear2_weight, - self.linear2_bias, - attention_mask, - ) - - if not self.is_last_layer: - hidden_states.original_shape = original_shape - elif hidden_states.is_nested and self.is_last_layer: - hidden_states = hidden_states.to_padded_tensor(0.0, original_shape) + raise NotImplementedError("TODO") return (hidden_states,) @@ -492,56 +610,53 @@ def __init__(self, mbart_layer, config): self.validate_bettertransformer() def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__): - r""" - This is just a wrapper around the forward function proposed in: - https://github.com/huggingface/transformers/pull/19553 - """ - super().forward_checker() + if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): + if not hasattr(hidden_states, "original_shape"): + original_shape = hidden_states.shape + else: + original_shape = hidden_states.original_shape + + if hidden_states.is_nested: + attention_mask = None + + if attention_mask is not None: + # attention mask comes in with values 0 and -inf. 
we convert to torch.nn.TransformerEncoder style bool mask + # 0->false->keep this token -inf->true->mask this token + if len(attention_mask.shape) == 4: + attention_mask = attention_mask.squeeze(1)[:, 0] + attention_mask = attention_mask.bool() + attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1])) + hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask) + attention_mask = None + + hidden_states = torch._transformer_encoder_layer_fwd( + hidden_states, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj_weight, + self.out_proj_bias, + self.use_gelu, + self.norm_first, + self.norm1_eps, + self.norm1_weight, + self.norm1_bias, + self.norm2_weight, + self.norm2_bias, + self.linear1_weight, + self.linear1_bias, + self.linear2_weight, + self.linear2_bias, + attention_mask, + ) - if not hasattr(hidden_states, "original_shape"): - original_shape = hidden_states.shape + if not self.is_last_layer: + hidden_states.original_shape = original_shape + elif hidden_states.is_nested and self.is_last_layer: + hidden_states = hidden_states.to_padded_tensor(0.0, original_shape) else: - original_shape = hidden_states.original_shape - - if hidden_states.is_nested: - attention_mask = None - - if attention_mask is not None: - # attention mask comes in with values 0 and -inf. we convert to torch.nn.TransformerEncoder style bool mask - # 0->false->keep this token -inf->true->mask this token - if len(attention_mask.shape) == 4: - attention_mask = attention_mask.squeeze(1)[:, 0] - attention_mask = attention_mask.bool() - attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1])) - hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask) - attention_mask = None - - hidden_states = torch._transformer_encoder_layer_fwd( - hidden_states, - self.embed_dim, - self.num_heads, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj_weight, - self.out_proj_bias, - self.use_gelu, - self.norm_first, - self.norm1_eps, - self.norm1_weight, - self.norm1_bias, - self.norm2_weight, - self.norm2_bias, - self.linear1_weight, - self.linear1_bias, - self.linear2_weight, - self.linear2_bias, - attention_mask, - ) - - if not self.is_last_layer: - hidden_states.original_shape = original_shape - elif hidden_states.is_nested and self.is_last_layer: - hidden_states = hidden_states.to_padded_tensor(0.0, original_shape) + raise NotImplementedError("TODO") return (hidden_states,) @@ -619,54 +734,114 @@ def __init__(self, bert_layer, config): "norm2_weight": "output_layer_norm.weight", "norm2_bias": "output_layer_norm.bias", } + self.attention_dropout = config.attention_dropout + self.dropout = config.dropout + self.attention_head_size = config.dim // config.n_heads self.validate_bettertransformer() - def forward(self, x, attn_mask, head_mask=None, output_attentions=None, *_): - r""" - This is just a wrapper around the forward function proposed in: - https://github.com/huggingface/transformers/pull/19553 - """ - super().forward_checker() - - if x.is_nested: - attn_mask = None - - if attn_mask is not None: - # attention mask comes in with values 0 and -inf. 
we convert to torch.nn.TransformerEncoder style bool mask - # 0->false->keep this token -inf->true->mask this token - attn_mask = attn_mask.bool() - attn_mask = torch.reshape(attn_mask, (attn_mask.shape[0], attn_mask.shape[-1])) - seqlen = attn_mask.shape[1] - lengths = torch.sum(~attn_mask, 1) - if not all(l == seqlen for l in lengths): - x = torch._nested_tensor_from_mask(x, attn_mask) - attn_mask = None - - x = torch._transformer_encoder_layer_fwd( - x, - self.embed_dim, - self.num_heads, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj_weight, - self.out_proj_bias, - self.use_gelu, - self.norm_first, - self.norm1_eps, - self.norm1_weight, - self.norm1_bias, - self.norm2_weight, - self.norm2_bias, - self.linear1_weight, - self.linear1_bias, - self.linear2_weight, - self.linear2_bias, - attn_mask, - ) - if x.is_nested and self.is_last_layer: - x = x.to_padded_tensor(0.0) - return (x,) + def forward(self, hidden_states, attn_mask, head_mask=None, output_attentions=None, *_): + if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): + if hidden_states.is_nested: + attn_mask = None + + if attn_mask is not None: + # attention mask comes in with values 0 and -inf. we convert to torch.nn.TransformerEncoder style bool mask + # 0->false->keep this token -inf->true->mask this token + attn_mask = attn_mask.bool() + attn_mask = torch.reshape(attn_mask, (attn_mask.shape[0], attn_mask.shape[-1])) + seqlen = attn_mask.shape[1] + lengths = torch.sum(~attn_mask, 1) + if not all(l == seqlen for l in lengths): + hidden_states = torch._nested_tensor_from_mask(hidden_states, attn_mask) + attn_mask = None + + hidden_states = torch._transformer_encoder_layer_fwd( + hidden_states, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj_weight, + self.out_proj_bias, + self.use_gelu, + self.norm_first, + self.norm1_eps, + self.norm1_weight, + self.norm1_bias, + self.norm2_weight, + self.norm2_bias, + self.linear1_weight, + self.linear1_bias, + self.linear2_weight, + self.linear2_bias, + attn_mask, + ) + if hidden_states.is_nested and self.is_last_layer: + hidden_states = hidden_states.to_padded_tensor(0.0) + else: + # TODO: check dropout + qkv = F.linear(hidden_states, weight=self.in_proj_weight, bias=self.in_proj_bias) + + qkv = qkv.view(qkv.size()[:-1] + (3, self.num_heads, self.attention_head_size)).permute(2, 0, 3, 1, 4) + query, key, value = qkv[0], qkv[1], qkv[2] + + # TODO: pass scale argument in PyTorch 2.1 release + query = ( + query + * torch.sqrt(torch.tensor(query.shape[-1], dtype=torch.float32)).to(torch.get_default_dtype()) + / 8 + ) + + # TODO: Kind of stupid to do that at each layer, should be fixed in transformers + attn_mask = attn_mask.unsqueeze(1).unsqueeze(2).to(dtype=query.dtype) + attn_mask = (1.0 - attn_mask) * torch.finfo(query.dtype).min + + # NOTE: In PyTorch 2.0, passing an attention_mask will automatically dispatch + # to the "math" path and will NOT use flash attention / memory-efficient attention. + # We should support xformers / Hazy-flash / rocm-flash directly and stop relying on PyTorch to do the work. 
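                # Concretely, the conversion just above turns DistilBERT's {1: keep, 0: pad} padding
                # mask into an additive bias that SDPA adds to the raw attention scores before the
                # softmax; for example:
                example_mask = torch.tensor([[1.0, 1.0, 0.0]])
                bias = (1.0 - example_mask.unsqueeze(1).unsqueeze(2)) * torch.finfo(torch.float32).min
                # bias has shape (1, 1, 1, 3) with values [0, 0, ~-3.4e38], so the padded position
                # receives (numerically) zero attention probability.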
+ attention_out = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + is_causal=False, + dropout_p=self.attention_dropout if self.training else 0.0, + ) + + attention_out = attention_out.permute(0, 2, 1, 3).contiguous() + new_attention_out_shape = attention_out.size()[:-2] + (self.num_heads * self.attention_head_size,) + attention_out = attention_out.view(new_attention_out_shape) + + # BertSelfOutput + attention_out = F.layer_norm( + F.dropout( + F.linear(attention_out, self.out_proj_weight, self.out_proj_bias), + p=self.dropout, + training=self.training, + ) + + hidden_states, + normalized_shape=self.norm1_weight.shape, # TODO: stateful + weight=self.norm1_weight, + bias=self.norm1_bias, + ) + + # BertIntermediate + # TODO: stateful + act_fn = ACT2FN[self.act_fn] + x = act_fn(F.linear(attention_out, self.linear1_weight, self.linear1_bias)) + + # BertOutput + hidden_states = F.layer_norm( + attention_out + + F.dropout( + F.linear(x, self.linear2_weight, self.linear2_bias), p=self.dropout, training=self.training + ), + normalized_shape=self.norm2_weight.shape, # TODO: stateful + weight=self.norm2_weight, + bias=self.norm2_bias, + ) + return (hidden_states,) class WhisperEncoderLayerBetterTransformer(BetterTransformerBaseLayer, nn.Module): @@ -749,36 +924,34 @@ def __init__(self, whisper_layer, config): self.validate_bettertransformer() def forward(self, hidden_states, attention_mask, *_, **__): - r""" - This is just a wrapper around the forward function proposed in: - https://github.com/huggingface/transformers/pull/19553 - """ - super().forward_checker() - attention_mask = None # attention mask seems to be always None: https://github.com/huggingface/transformers/blob/94b3f544a1f5e04b78d87a2ae32a7ac252e22e31/src/transformers/models/whisper/modeling_whisper.py#L690 - - hidden_states = torch._transformer_encoder_layer_fwd( - hidden_states, - self.embed_dim, - self.num_heads, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj_weight, - self.out_proj_bias, - self.use_gelu, - self.norm_first, - self.norm1_eps, - self.norm1_weight, - self.norm1_bias, - self.norm2_weight, - self.norm2_bias, - self.linear1_weight, - self.linear1_bias, - self.linear2_weight, - self.linear2_bias, - attention_mask, - ) - if hidden_states.is_nested and self.is_last_layer: - hidden_states = hidden_states.to_padded_tensor(0.0) + if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): + attention_mask = None # attention mask seems to be always None: https://github.com/huggingface/transformers/blob/94b3f544a1f5e04b78d87a2ae32a7ac252e22e31/src/transformers/models/whisper/modeling_whisper.py#L690 + + hidden_states = torch._transformer_encoder_layer_fwd( + hidden_states, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj_weight, + self.out_proj_bias, + self.use_gelu, + self.norm_first, + self.norm1_eps, + self.norm1_weight, + self.norm1_bias, + self.norm2_weight, + self.norm2_bias, + self.linear1_weight, + self.linear1_bias, + self.linear2_weight, + self.linear2_bias, + attention_mask, + ) + if hidden_states.is_nested and self.is_last_layer: + hidden_states = hidden_states.to_padded_tensor(0.0) + else: + raise NotImplementedError("TODO") return (hidden_states,) @@ -869,36 +1042,34 @@ def __init__(self, vit_layer, config): self.validate_bettertransformer() def forward(self, hidden_states, *_, **__): - r""" - This is just a wrapper around the forward function proposed in: - 
https://github.com/huggingface/transformers/pull/19553 - """ - super().forward_checker() - attention_mask = None - - hidden_states = torch._transformer_encoder_layer_fwd( - hidden_states, - self.embed_dim, - self.num_heads, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj_weight, - self.out_proj_bias, - self.use_gelu, - self.norm_first, - self.norm1_eps, - self.norm1_weight, - self.norm1_bias, - self.norm2_weight, - self.norm2_bias, - self.linear1_weight, - self.linear1_bias, - self.linear2_weight, - self.linear2_bias, - attention_mask, - ) - if hidden_states.is_nested and self.is_last_layer: - hidden_states = hidden_states.to_padded_tensor(0.0) + if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): + attention_mask = None + + hidden_states = torch._transformer_encoder_layer_fwd( + hidden_states, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj_weight, + self.out_proj_bias, + self.use_gelu, + self.norm_first, + self.norm1_eps, + self.norm1_weight, + self.norm1_bias, + self.norm2_weight, + self.norm2_bias, + self.linear1_weight, + self.linear1_bias, + self.linear2_weight, + self.linear2_bias, + attention_mask, + ) + if hidden_states.is_nested and self.is_last_layer: + hidden_states = hidden_states.to_padded_tensor(0.0) + else: + raise NotImplementedError("TODO") return (hidden_states,) @@ -989,36 +1160,34 @@ def __init__(self, vilt_layer, config): self.validate_bettertransformer() def forward(self, hidden_states, *_, **__): - r""" - This is just a wrapper around the forward function proposed in: - https://github.com/huggingface/transformers/pull/19553 - """ - super().forward_checker() - attention_mask = None - - hidden_states = torch._transformer_encoder_layer_fwd( - hidden_states, - self.embed_dim, - self.num_heads, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj_weight, - self.out_proj_bias, - self.use_gelu, - self.norm_first, - self.norm1_eps, - self.norm1_weight, - self.norm1_bias, - self.norm2_weight, - self.norm2_bias, - self.linear1_weight, - self.linear1_bias, - self.linear2_weight, - self.linear2_bias, - attention_mask, - ) - if hidden_states.is_nested and self.is_last_layer: - hidden_states = hidden_states.to_padded_tensor(0.0) + if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): + attention_mask = None + + hidden_states = torch._transformer_encoder_layer_fwd( + hidden_states, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj_weight, + self.out_proj_bias, + self.use_gelu, + self.norm_first, + self.norm1_eps, + self.norm1_weight, + self.norm1_bias, + self.norm2_weight, + self.norm2_bias, + self.linear1_weight, + self.linear1_bias, + self.linear2_weight, + self.linear2_bias, + attention_mask, + ) + if hidden_states.is_nested and self.is_last_layer: + hidden_states = hidden_states.to_padded_tensor(0.0) + else: + raise NotImplementedError("TODO") return (hidden_states,) @@ -1105,47 +1274,45 @@ def __init__(self, wav2vec2_layer, config): self.validate_bettertransformer() def forward(self, hidden_states, attention_mask, **__): - r""" - This is just a wrapper around the forward function proposed in: - https://github.com/huggingface/transformers/pull/19553 - """ - super().forward_checker() - if hidden_states.is_nested: - attention_mask = None - - if attention_mask is not None: - # attention mask comes in with values 0 and -inf. 
we convert to torch.nn.TransformerEncoder style bool mask - # 0->false->keep this token -inf->true->mask this token - attention_mask = attention_mask.bool() - if len(attention_mask.shape) == 4: - attention_mask = attention_mask.squeeze(1)[:, 0] - attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1])) - hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask) - attention_mask = None - - hidden_states = torch._transformer_encoder_layer_fwd( - hidden_states, - self.embed_dim, - self.num_heads, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj_weight, - self.out_proj_bias, - self.use_gelu, - self.norm_first, - self.norm1_eps, - self.norm1_weight, - self.norm1_bias, - self.norm2_weight, - self.norm2_bias, - self.linear1_weight, - self.linear1_bias, - self.linear2_weight, - self.linear2_bias, - attention_mask, - ) - if hidden_states.is_nested and self.is_last_layer: - hidden_states = hidden_states.to_padded_tensor(0.0) + if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): + if hidden_states.is_nested: + attention_mask = None + + if attention_mask is not None: + # attention mask comes in with values 0 and -inf. we convert to torch.nn.TransformerEncoder style bool mask + # 0->false->keep this token -inf->true->mask this token + attention_mask = attention_mask.bool() + if len(attention_mask.shape) == 4: + attention_mask = attention_mask.squeeze(1)[:, 0] + attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1])) + hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask) + attention_mask = None + + hidden_states = torch._transformer_encoder_layer_fwd( + hidden_states, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj_weight, + self.out_proj_bias, + self.use_gelu, + self.norm_first, + self.norm1_eps, + self.norm1_weight, + self.norm1_bias, + self.norm2_weight, + self.norm2_bias, + self.linear1_weight, + self.linear1_bias, + self.linear2_weight, + self.linear2_bias, + attention_mask, + ) + if hidden_states.is_nested and self.is_last_layer: + hidden_states = hidden_states.to_padded_tensor(0.0) + else: + raise NotImplementedError("TODO") return (hidden_states,) @@ -1227,61 +1394,59 @@ def __init__(self, fsmt_layer, config): self.validate_bettertransformer() def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__): - r""" - This is just a wrapper around the forward function proposed in: - https://github.com/huggingface/transformers/pull/19553 - """ - super().forward_checker() - - if not hasattr(hidden_states, "original_shape"): - original_shape = hidden_states.shape - else: - original_shape = hidden_states.original_shape - - if hidden_states.is_nested: - attention_mask = None - - if attention_mask is not None: - # attention mask comes in with values 0 and -inf. 
we convert to torch.nn.TransformerEncoder style bool mask - # 0->false->keep this token -inf->true->mask this token - attention_mask = attention_mask.bool() - attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1])) - - # FSMT swaps the first two axis before calling the encoder stack - # Reference: https://github.com/huggingface/transformers/blob/699e90437f984d69ad3c9b891dd2e9d0fc2cffe4/src/transformers/models/fsmt/modeling_fsmt.py#L508 - if hidden_states.shape[0] != attention_mask.shape[0]: - hidden_states = hidden_states.transpose(1, 0) + if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): + if not hasattr(hidden_states, "original_shape"): original_shape = hidden_states.shape + else: + original_shape = hidden_states.original_shape + + if hidden_states.is_nested: + attention_mask = None + + if attention_mask is not None: + # attention mask comes in with values 0 and -inf. we convert to torch.nn.TransformerEncoder style bool mask + # 0->false->keep this token -inf->true->mask this token + attention_mask = attention_mask.bool() + attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1])) + + # FSMT swaps the first two axis before calling the encoder stack + # Reference: https://github.com/huggingface/transformers/blob/699e90437f984d69ad3c9b891dd2e9d0fc2cffe4/src/transformers/models/fsmt/modeling_fsmt.py#L508 + if hidden_states.shape[0] != attention_mask.shape[0]: + hidden_states = hidden_states.transpose(1, 0) + original_shape = hidden_states.shape + + hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask) + attention_mask = None + + hidden_states = torch._transformer_encoder_layer_fwd( + hidden_states, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj_weight, + self.out_proj_bias, + self.use_gelu, + self.norm_first, + self.norm1_eps, + self.norm1_weight, + self.norm1_bias, + self.norm2_weight, + self.norm2_bias, + self.linear1_weight, + self.linear1_bias, + self.linear2_weight, + self.linear2_bias, + attention_mask, + ) - hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask) - attention_mask = None - - hidden_states = torch._transformer_encoder_layer_fwd( - hidden_states, - self.embed_dim, - self.num_heads, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj_weight, - self.out_proj_bias, - self.use_gelu, - self.norm_first, - self.norm1_eps, - self.norm1_weight, - self.norm1_bias, - self.norm2_weight, - self.norm2_bias, - self.linear1_weight, - self.linear1_bias, - self.linear2_weight, - self.linear2_bias, - attention_mask, - ) + if not self.is_last_layer: + hidden_states.original_shape = original_shape + elif hidden_states.is_nested and self.is_last_layer: + hidden_states = hidden_states.to_padded_tensor(0.0, original_shape) + else: + raise ValueError("Training and Autocast are not supported for BetterTransformer + FSMT.") - if not self.is_last_layer: - hidden_states.original_shape = original_shape - elif hidden_states.is_nested and self.is_last_layer: - hidden_states = hidden_states.to_padded_tensor(0.0, original_shape) return (hidden_states, attention_mask) @@ -1368,54 +1533,52 @@ def __init__(self, prophetnet_layer, config): self.validate_bettertransformer() def forward(self, hidden_states, attention_mask, *_, **__): - r""" - This is just a wrapper around the forward function proposed in: - https://github.com/huggingface/transformers/pull/19553 - """ - 
super().forward_checker() - - if not hasattr(hidden_states, "original_shape"): - original_shape = hidden_states.shape - else: - original_shape = hidden_states.original_shape + if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): + if not hasattr(hidden_states, "original_shape"): + original_shape = hidden_states.shape + else: + original_shape = hidden_states.original_shape - if hidden_states.is_nested: - attention_mask = None + if hidden_states.is_nested: + attention_mask = None - if attention_mask is not None: - # attention mask comes in with values 0 and -inf. we convert to torch.nn.TransformerEncoder style bool mask - # 0->false->keep this token -inf->true->mask this token - attention_mask = attention_mask.squeeze(1)[:, 0] - attention_mask = attention_mask.bool() - attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1])) - hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask) - attention_mask = None + if attention_mask is not None: + # attention mask comes in with values 0 and -inf. we convert to torch.nn.TransformerEncoder style bool mask + # 0->false->keep this token -inf->true->mask this token + attention_mask = attention_mask.squeeze(1)[:, 0] + attention_mask = attention_mask.bool() + attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1])) + hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask) + attention_mask = None + + hidden_states = torch._transformer_encoder_layer_fwd( + hidden_states, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj_weight, + self.out_proj_bias, + self.use_gelu, + self.norm_first, + self.norm1_eps, + self.norm1_weight, + self.norm1_bias, + self.norm2_weight, + self.norm2_bias, + self.linear1_weight, + self.linear1_bias, + self.linear2_weight, + self.linear2_bias, + attention_mask, + ) + if not self.is_last_layer: + hidden_states.original_shape = original_shape + elif hidden_states.is_nested and self.is_last_layer: + hidden_states = hidden_states.to_padded_tensor(0.0, original_shape) + else: + raise ValueError("Training and Autocast are not supported for BetterTransformer + ProphetNet.") - hidden_states = torch._transformer_encoder_layer_fwd( - hidden_states, - self.embed_dim, - self.num_heads, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj_weight, - self.out_proj_bias, - self.use_gelu, - self.norm_first, - self.norm1_eps, - self.norm1_weight, - self.norm1_bias, - self.norm2_weight, - self.norm2_bias, - self.linear1_weight, - self.linear1_bias, - self.linear2_weight, - self.linear2_bias, - attention_mask, - ) - if not self.is_last_layer: - hidden_states.original_shape = original_shape - elif hidden_states.is_nested and self.is_last_layer: - hidden_states = hidden_states.to_padded_tensor(0.0, original_shape) return (hidden_states,) @@ -1502,39 +1665,36 @@ def __init__(self, layer, config): self.validate_bettertransformer() def forward(self, hidden_states, attention_mask, *_, **__): - r""" - This is just a wrapper around the forward function proposed in: - https://github.com/huggingface/transformers/pull/19553 - """ - super().forward_checker() - - # we expect attention_mask to be None in the vision model - if attention_mask is not None: - raise ValueError( - "Please do not use attention masks when using `BetterTransformer` converted vision models" + if not self.training and not torch.is_autocast_enabled() and not 
torch.is_autocast_cpu_enabled(): + # we expect attention_mask to be None in the vision model + if attention_mask is not None: + raise ValueError( + "Please do not use attention masks when using `BetterTransformer` converted vision models" + ) + + hidden_states = torch._transformer_encoder_layer_fwd( + hidden_states, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj_weight, + self.out_proj_bias, + self.use_gelu, + self.norm_first, + self.norm1_eps, + self.norm1_weight, + self.norm1_bias, + self.norm2_weight, + self.norm2_bias, + self.linear1_weight, + self.linear1_bias, + self.linear2_weight, + self.linear2_bias, + attention_mask, ) - - hidden_states = torch._transformer_encoder_layer_fwd( - hidden_states, - self.embed_dim, - self.num_heads, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj_weight, - self.out_proj_bias, - self.use_gelu, - self.norm_first, - self.norm1_eps, - self.norm1_weight, - self.norm1_bias, - self.norm2_weight, - self.norm2_bias, - self.linear1_weight, - self.linear1_bias, - self.linear2_weight, - self.linear2_bias, - attention_mask, - ) + else: + raise NotImplementedError("TODO") return (hidden_states,) diff --git a/optimum/bettertransformer/transformation.py b/optimum/bettertransformer/transformation.py index c0d6c734aa..fcf66558d6 100644 --- a/optimum/bettertransformer/transformation.py +++ b/optimum/bettertransformer/transformation.py @@ -231,12 +231,9 @@ def transform( f" Currently supported models are: {BetterTransformerManager.MODEL_MAPPING.keys()}." ) - # check on 1.14 in case there is any more patch release on 1.13 - if BetterTransformerManager.requires_torch_20(model.config.model_type) and parse(torch.__version__) <= parse( - "1.14" - ): + if parse(torch.__version__) <= parse("1.14"): raise ValueError( - f"BetterTransformer for {model.config.model_type} requires torch>=2.0 but {torch.__version__} is installed. Please upgrade PyTorch." + f"BetterTransformer requires torch>=2.0 but {torch.__version__} is installed. Please upgrade PyTorch." ) hf_config = model.config @@ -290,9 +287,10 @@ def transform( model = dispatch_model(model, hf_device_map, offload_dir=offload_dir) # See: https://github.com/pytorch/pytorch/issues/96099 - if BetterTransformerManager.requires_torch_20(model_fast.config.model_type): + # TODO: show the warning only for decoders (which do not need an attention mask for training) + if BetterTransformerManager.is_decoder(model_fast.config.model_type): logging.warning( - f"For training, the BetterTransformer implementation for {model_fast.config.model_type} " + f"For decoder training, the BetterTransformer implementation for {model_fast.config.model_type} " " architecture currently does not support padding as fused kernels do not support custom" " attention masks. Beware that passing padded batched training data may result in unexpected outputs." 
) diff --git a/tests/bettertransformer/test_encoder.py b/tests/bettertransformer/test_encoder.py index b6859450cd..df51145849 100644 --- a/tests/bettertransformer/test_encoder.py +++ b/tests/bettertransformer/test_encoder.py @@ -20,7 +20,7 @@ import transformers from parameterized import parameterized from testing_utils import MODELS_DICT, BetterTransformersTestMixin -from transformers import AutoModel +from transformers import AutoModel, AutoProcessor, AutoTokenizer from optimum.bettertransformer import BetterTransformer from optimum.utils.testing_utils import grid_parameters, require_accelerate, require_torch_20, require_torch_gpu @@ -64,12 +64,27 @@ class BetterTransformersEncoderTest(BetterTransformersTestMixin): def tearDown(self): gc.collect() - def prepare_inputs_for_class(self, model_id, model_type): - input_dict = { - "input_ids": torch.LongTensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]), - "attention_mask": torch.LongTensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]]), - } - return input_dict + def prepare_inputs_for_class(self, model_id: str, model_type: str, batch_size: int = 2, **preprocessor_kwargs): + # TODO: remove the need for tokenizer + if model_type == "markuplm": + preprocessor = AutoProcessor.from_pretrained(model_id) + else: + preprocessor = AutoTokenizer.from_pretrained(model_id) + if batch_size == 1: + texts = ["a dummy input yeah yeah!"] + else: + texts = ["a dummy input yeah yeah!"] + ["and two"] * (batch_size - 1) + + padding = preprocessor_kwargs.pop("padding", True) + if padding == "max_length": + max_length = 25 + else: + max_length = None + + inputs = preprocessor( + texts, return_tensors="pt", padding=padding, max_length=max_length, **preprocessor_kwargs + ) + return inputs def test_raise_pos_emb(self): r""" @@ -250,6 +265,40 @@ def test_accelerate_compatibility_single_gpu_without_keeping(self): max_memory = {0: "2GB"} self.check_accelerate_compatibility_cpu_gpu(keep_original_model=False, max_memory=max_memory) + @parameterized.expand( + grid_parameters( + { + "model_type": SUPPORTED_ARCH, + "batch_size": [1, 3], + } + ) + ) + def test_logits(self, test_name: str, model_type: str, batch_size: int): + if model_type in ["rocbert", "splinter", "markuplm", "bert-generation"]: + self.skipTest(f"tiny tokenizers are broken on the Hub {model_type}") + if model_type in ["tapas"]: + self.skipTest(f"{model_type} requires dataframe") + + model_id = MODELS_DICT[model_type] + self._test_logits(model_id=model_id, model_type=model_type, batch_size=batch_size) + + @parameterized.expand( + grid_parameters( + { + "model_type": SUPPORTED_ARCH, + "batch_size": [1, 3], + } + ) + ) + def test_logits_backward(self, test_name: str, model_type: str, batch_size: int): + if model_type in ["rocbert", "splinter", "markuplm", "bert-generation"]: + self.skipTest(f"tiny tokenizer is broken on the Hub for {model_type}") + if model_type in ["tapas"]: + self.skipTest(f"{model_type} requires dataframe") + + model_id = MODELS_DICT[model_type] + self._test_logits_backward(model_id=model_id, model_type=model_type, batch_size=batch_size) + @parameterized.expand(grid_parameters(FULL_GRID)) @require_torch_20 def test_invert_modules(self, test_name: str, model_type: str, keep_original_model=False): diff --git a/tests/bettertransformer/testing_utils.py b/tests/bettertransformer/testing_utils.py index 8b908fa905..2c531ebe64 100644 --- a/tests/bettertransformer/testing_utils.py +++ b/tests/bettertransformer/testing_utils.py @@ -57,7 +57,7 @@ "opt": "hf-internal-testing/tiny-random-OPTModel", "pegasus": 
"hf-internal-testing/tiny-random-PegasusModel", "prophetnet": "hirotasoshu/tiny-random-prophetnet", # the other tiny ones have a too small max_position_embeddings - "rembert": "hf-internal-testing/tiny-random-rembert", + "rembert": "hf-internal-testing/tiny-random-RemBertModel", "roberta": "hf-internal-testing/tiny-random-RobertaModel", "rocbert": "hf-internal-testing/tiny-random-RoCBertModel", "roformer": "hf-internal-testing/tiny-random-RoFormerModel", @@ -74,6 +74,26 @@ "yolos": "hf-internal-testing/tiny-random-YolosModel", } +known_dropout_keys = [ + "attention_probs_dropout_prob", + "hidden_dropout_prob", + "classifier_dropout_prob", + "attention_dropout", + "dropout", + "qa_dropout", + "seq_classif_dropout", + "summary_last_dropout", + "classifier_dropout", +] + + +def set_dropout_to_zero(config): + for attr_name in known_dropout_keys: + if hasattr(config, attr_name): + setattr(config, attr_name, 0.0) + + return config + class BetterTransformersTestMixin(unittest.TestCase): r""" @@ -136,6 +156,59 @@ def _test_fp16_inference( f"Maxdiff: {(output_hf - output_bt).abs().max()}", ) + def _test_logits_backward(self, model_id: str, model_type: str, **preprocessor_kwargs): + inputs = self.prepare_inputs_for_class(model_id=model_id, model_type=model_type, **preprocessor_kwargs) + + hf_random_model = AutoModel.from_pretrained(model_id).eval() + random_config = hf_random_model.config + + # I could not obtain reproducible results with `torch.manual_seed` nor with + # `torch.random.set_rng_state`. An alternative could be to make dropout stateful, + # and to replace them with a static pattern for this test. Currently, we use + # functional dropout though. + random_config = set_dropout_to_zero(random_config) + + hf_random_model = hf_random_model.__class__(random_config) + converted_model = copy.deepcopy(hf_random_model) + converted_model = BetterTransformer.transform(converted_model) + + hf_random_model = hf_random_model.train() + converted_model = converted_model.train() + + optimizer_hf = torch.optim.SGD(hf_random_model.parameters(), lr=0.2) + optimizer_bt = torch.optim.SGD(converted_model.parameters(), lr=0.2) + + tol = 2e-3 + + hf_hidden_states = hf_random_model(**inputs)[0] + bt_hidden_states = converted_model(**inputs)[0] + + self.assert_equal( + hf_hidden_states, + bt_hidden_states, + atol=tol, + model_name=hf_random_model.__class__.__name__, + ) + + loss_hf = hf_hidden_states.abs().mean() + loss_bt = bt_hidden_states.abs().mean() + + loss_hf.backward() + loss_bt.backward() + + optimizer_hf.step() + optimizer_bt.step() + + hf_hidden_states = hf_random_model(**inputs)[0] + bt_hidden_states = converted_model(**inputs)[0] + + self.assert_equal( + hf_hidden_states, + bt_hidden_states, + atol=tol, + model_name=hf_random_model.__class__.__name__, + ) + def _test_logits(self, model_id: str, model_type: str, **preprocessor_kwargs): r""" This tests if the converted model produces the same logits From abd892012f63f400ad673b69a677f521da4995af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Wed, 26 Jul 2023 15:19:57 +0200 Subject: [PATCH 2/8] encoders and encoder+decoder all work --- .../models/encoder_models.py | 180 +++++++++++++----- optimum/bettertransformer/transformation.py | 2 +- optimum/utils/testing_utils.py | 6 - tests/bettertransformer/test_audio.py | 10 +- tests/bettertransformer/test_common.py | 13 +- tests/bettertransformer/test_decoder.py | 35 ++-- tests/bettertransformer/test_encoder.py | 10 +- 
.../bettertransformer/test_encoder_decoder.py | 26 ++- tests/bettertransformer/test_vision.py | 5 +- tests/bettertransformer/testing_utils.py | 20 +- 10 files changed, 191 insertions(+), 116 deletions(-) diff --git a/optimum/bettertransformer/models/encoder_models.py b/optimum/bettertransformer/models/encoder_models.py index 9cda2ee925..3eb4b45be6 100644 --- a/optimum/bettertransformer/models/encoder_models.py +++ b/optimum/bettertransformer/models/encoder_models.py @@ -104,6 +104,7 @@ def __init__(self, albert_layer, config): self.attention_head_size = config.hidden_size // config.num_attention_heads self.attention_probs_dropout_prob = config.attention_probs_dropout_prob self.hidden_dropout_prob = config.hidden_dropout_prob + self.act_fn_callable = ACT2FN[self.act_fn] self.validate_bettertransformer() @@ -150,13 +151,6 @@ def forward(self, hidden_states, attention_mask, *_): qkv = qkv.view(qkv.size()[:-1] + (3, self.num_heads, self.attention_head_size)).permute(2, 0, 3, 1, 4) query, key, value = qkv[0], qkv[1], qkv[2] - # TODO: pass scale argument in PyTorch 2.1 release - query = ( - query - * torch.sqrt(torch.tensor(query.shape[-1], dtype=torch.float32)).to(torch.get_default_dtype()) - / 8 - ) - # NOTE: In PyTorch 2.0, passing an attention_mask will automatically dispatch # to the "math" path and will NOT use flash attention / memory-efficient attention. # We should support xformers / Hazy-flash / rocm-flash directly and stop relying on PyTorch to do the work. @@ -181,25 +175,23 @@ def forward(self, hidden_states, attention_mask, *_): training=self.training, ) + hidden_states, - normalized_shape=self.norm1_weight.shape, # TODO: stateful + normalized_shape=self.norm1_weight.shape, weight=self.norm1_weight, bias=self.norm1_bias, ) # BertIntermediate - # TODO: stateful - act_fn = ACT2FN[self.act_fn] - x = act_fn(F.linear(attention_out, self.linear1_weight, self.linear1_bias)) + hidden_states = self.act_fn_callable(F.linear(attention_out, self.linear1_weight, self.linear1_bias)) # BertOutput hidden_states = F.layer_norm( attention_out + F.dropout( - F.linear(x, self.linear2_weight, self.linear2_bias), + F.linear(hidden_states, self.linear2_weight, self.linear2_bias), p=self.hidden_dropout_prob, training=self.training, ), - normalized_shape=self.norm2_weight.shape, # TODO: stateful + normalized_shape=self.norm2_weight.shape, weight=self.norm2_weight, bias=self.norm2_bias, ) @@ -292,6 +284,7 @@ def __init__(self, bert_layer, config): self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.hidden_dropout_prob = config.hidden_dropout_prob self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + self.act_fn_callable = ACT2FN[self.act_fn] self.validate_bettertransformer() @@ -337,13 +330,6 @@ def forward(self, hidden_states, attention_mask, *_): qkv = qkv.view(qkv.size()[:-1] + (3, self.num_heads, self.attention_head_size)).permute(2, 0, 3, 1, 4) query, key, value = qkv[0], qkv[1], qkv[2] - # TODO: pass scale argument in PyTorch 2.1 release - query = ( - query - * torch.sqrt(torch.tensor(query.shape[-1], dtype=torch.float32)).to(torch.get_default_dtype()) - / 8 - ) - # NOTE: In PyTorch 2.0, passing an attention_mask will automatically dispatch # to the "math" path and will NOT use flash attention / memory-efficient attention. # We should support xformers / Hazy-flash / rocm-flash directly and stop relying on PyTorch to do the work. 
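# The manual query rescaling removed above is no longer needed here because
# F.scaled_dot_product_attention already scales the attention scores by 1 / sqrt(head_dim),
# i.e. the softmax(Q K^T / sqrt(d)) V convention used by BERT-style self-attention.
# A quick sanity check of that default scaling on the math path (shapes are arbitrary):
import math
q, k, v = (torch.randn(2, 4, 6, 8) for _ in range(3))
reference = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)), dim=-1) @ v
assert torch.allclose(F.scaled_dot_product_attention(q, k, v), reference, atol=1e-5)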
@@ -368,25 +354,23 @@ def forward(self, hidden_states, attention_mask, *_): training=self.training, ) + hidden_states, - normalized_shape=self.norm1_weight.shape, # TODO: stateful + normalized_shape=self.norm1_weight.shape, weight=self.norm1_weight, bias=self.norm1_bias, ) # BertIntermediate - # TODO: stateful - act_fn = ACT2FN[self.act_fn] - x = act_fn(F.linear(attention_out, self.linear1_weight, self.linear1_bias)) + hidden_states = self.act_fn_callable(F.linear(attention_out, self.linear1_weight, self.linear1_bias)) # BertOutput hidden_states = F.layer_norm( attention_out + F.dropout( - F.linear(x, self.linear2_weight, self.linear2_bias), + F.linear(hidden_states, self.linear2_weight, self.linear2_bias), p=self.hidden_dropout_prob, training=self.training, ), - normalized_shape=self.norm2_weight.shape, # TODO: stateful + normalized_shape=self.norm2_weight.shape, weight=self.norm2_weight, bias=self.norm2_bias, ) @@ -471,6 +455,10 @@ def __init__(self, bart_layer, config): "norm2_weight": "final_layer_norm.weight", "norm2_bias": "final_layer_norm.bias", } + self.dropout = config.attention_dropout + self.activation_dropout = config.activation_dropout + self.attention_head_size = config.d_model // config.encoder_attention_heads + self.act_fn_callable = ACT2FN[self.act_fn] self.validate_bettertransformer() @@ -521,7 +509,58 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__): elif hidden_states.is_nested and self.is_last_layer: hidden_states = hidden_states.to_padded_tensor(0.0, original_shape) else: - raise NotImplementedError("TODO") + qkv = F.linear(hidden_states, weight=self.in_proj_weight, bias=self.in_proj_bias) + + qkv = qkv.view(qkv.size()[:-1] + (3, self.num_heads, self.attention_head_size)).permute(2, 0, 3, 1, 4) + query, key, value = qkv[0], qkv[1], qkv[2] + + # NOTE: In PyTorch 2.0, passing an attention_mask will automatically dispatch + # to the "math" path and will NOT use flash attention / memory-efficient attention. + # We should support xformers / Hazy-flash / rocm-flash directly and stop relying on PyTorch to do the work. 
+ attention_out = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + is_causal=False, + dropout_p=self.dropout if self.training else 0.0, + ) + + attention_out = attention_out.permute(0, 2, 1, 3).contiguous() + new_attention_out_shape = attention_out.size()[:-2] + (self.num_heads * self.attention_head_size,) + attention_out = attention_out.view(new_attention_out_shape) + + # BertSelfOutput + attention_out = F.layer_norm( + F.dropout( + F.linear(attention_out, self.out_proj_weight, self.out_proj_bias), + p=self.dropout, + training=self.training, + ) + + hidden_states, + normalized_shape=self.norm1_weight.shape, + weight=self.norm1_weight, + bias=self.norm1_bias, + ) + + # One additional dropout compared to bert + hidden_states = F.dropout( + self.act_fn_callable(F.linear(attention_out, self.linear1_weight, self.linear1_bias)), + p=self.activation_dropout, + training=self.training, + ) + + hidden_states = F.layer_norm( + attention_out + + F.dropout( + F.linear(hidden_states, self.linear2_weight, self.linear2_bias), + p=self.dropout, + training=self.training, + ), + normalized_shape=self.norm2_weight.shape, + weight=self.norm2_weight, + bias=self.norm2_bias, + ) return (hidden_states,) @@ -606,6 +645,10 @@ def __init__(self, mbart_layer, config): "norm2_bias": "final_layer_norm.bias", "norm2_eps": "final_layer_norm.eps", } + self.dropout = config.attention_dropout + self.activation_dropout = config.activation_dropout + self.attention_head_size = config.d_model // config.encoder_attention_heads + self.act_fn_callable = ACT2FN[self.act_fn] self.validate_bettertransformer() @@ -656,7 +699,60 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__): elif hidden_states.is_nested and self.is_last_layer: hidden_states = hidden_states.to_padded_tensor(0.0, original_shape) else: - raise NotImplementedError("TODO") + residual = hidden_states + hidden_states = F.layer_norm( + hidden_states, + normalized_shape=self.norm1_weight.shape, + weight=self.norm1_weight, + bias=self.norm1_bias, + ) + + qkv = F.linear(hidden_states, weight=self.in_proj_weight, bias=self.in_proj_bias) + qkv = qkv.view(qkv.size()[:-1] + (3, self.num_heads, self.attention_head_size)).permute(2, 0, 3, 1, 4) + query, key, value = qkv[0], qkv[1], qkv[2] + + # NOTE: In PyTorch 2.0, passing an attention_mask will automatically dispatch + # to the "math" path and will NOT use flash attention / memory-efficient attention. + # We should support xformers / Hazy-flash / rocm-flash directly and stop relying on PyTorch to do the work. 
+ attention_out = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + is_causal=False, + dropout_p=self.dropout if self.training else 0.0, + ) + + attention_out = attention_out.permute(0, 2, 1, 3).contiguous() + new_attention_out_shape = attention_out.size()[:-2] + (self.num_heads * self.attention_head_size,) + attention_out = attention_out.view(new_attention_out_shape) + + hidden_states = residual + F.dropout( + F.linear(attention_out, self.out_proj_weight, self.out_proj_bias), + p=self.dropout, + training=self.training, + ) + residual = hidden_states + hidden_states = F.layer_norm( + hidden_states, + normalized_shape=self.norm2_weight.shape, + weight=self.norm2_weight, + bias=self.norm2_bias, + ) + + # One additional dropout compared to bert + hidden_states = F.dropout( + self.act_fn_callable(F.linear(hidden_states, self.linear1_weight, self.linear1_bias)), + p=self.activation_dropout, + training=self.training, + ) + + hidden_states = residual + F.dropout( + F.linear(hidden_states, self.linear2_weight, self.linear2_bias), + p=self.dropout, + training=self.training, + ) + return (hidden_states,) @@ -737,6 +833,7 @@ def __init__(self, bert_layer, config): self.attention_dropout = config.attention_dropout self.dropout = config.dropout self.attention_head_size = config.dim // config.n_heads + self.act_fn_callable = ACT2FN[self.act_fn] self.validate_bettertransformer() @@ -786,13 +883,6 @@ def forward(self, hidden_states, attn_mask, head_mask=None, output_attentions=No qkv = qkv.view(qkv.size()[:-1] + (3, self.num_heads, self.attention_head_size)).permute(2, 0, 3, 1, 4) query, key, value = qkv[0], qkv[1], qkv[2] - # TODO: pass scale argument in PyTorch 2.1 release - query = ( - query - * torch.sqrt(torch.tensor(query.shape[-1], dtype=torch.float32)).to(torch.get_default_dtype()) - / 8 - ) - # TODO: Kind of stupid to do that at each layer, should be fixed in transformers attn_mask = attn_mask.unsqueeze(1).unsqueeze(2).to(dtype=query.dtype) attn_mask = (1.0 - attn_mask) * torch.finfo(query.dtype).min @@ -821,23 +911,23 @@ def forward(self, hidden_states, attn_mask, head_mask=None, output_attentions=No training=self.training, ) + hidden_states, - normalized_shape=self.norm1_weight.shape, # TODO: stateful + normalized_shape=self.norm1_weight.shape, weight=self.norm1_weight, bias=self.norm1_bias, ) # BertIntermediate - # TODO: stateful - act_fn = ACT2FN[self.act_fn] - x = act_fn(F.linear(attention_out, self.linear1_weight, self.linear1_bias)) + hidden_states = self.act_fn_callable(F.linear(attention_out, self.linear1_weight, self.linear1_bias)) # BertOutput hidden_states = F.layer_norm( attention_out + F.dropout( - F.linear(x, self.linear2_weight, self.linear2_bias), p=self.dropout, training=self.training + F.linear(hidden_states, self.linear2_weight, self.linear2_bias), + p=self.dropout, + training=self.training, ), - normalized_shape=self.norm2_weight.shape, # TODO: stateful + normalized_shape=self.norm2_weight.shape, weight=self.norm2_weight, bias=self.norm2_bias, ) @@ -1445,7 +1535,9 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__): elif hidden_states.is_nested and self.is_last_layer: hidden_states = hidden_states.to_padded_tensor(0.0, original_shape) else: - raise ValueError("Training and Autocast are not supported for BetterTransformer + FSMT.") + raise NotImplementedError( + "Training and Autocast are not implemented for BetterTransformer + FSMT. Please open an issue." 
+ ) return (hidden_states, attention_mask) @@ -1577,7 +1669,9 @@ def forward(self, hidden_states, attention_mask, *_, **__): elif hidden_states.is_nested and self.is_last_layer: hidden_states = hidden_states.to_padded_tensor(0.0, original_shape) else: - raise ValueError("Training and Autocast are not supported for BetterTransformer + ProphetNet.") + raise ValueError( + "Training and Autocast are not implemented for BetterTransformer + ProphetNet. Please open an issue." + ) return (hidden_states,) diff --git a/optimum/bettertransformer/transformation.py b/optimum/bettertransformer/transformation.py index fcf66558d6..8520465bc7 100644 --- a/optimum/bettertransformer/transformation.py +++ b/optimum/bettertransformer/transformation.py @@ -288,7 +288,7 @@ def transform( # See: https://github.com/pytorch/pytorch/issues/96099 # TODO: show the warning only for decoders (which do not need an attention mask for training) - if BetterTransformerManager.is_decoder(model_fast.config.model_type): + if False: # BetterTransformerManager.is_decoder(model_fast.config.model_type): logging.warning( f"For decoder training, the BetterTransformer implementation for {model_fast.config.model_type} " " architecture currently does not support padding as fused kernels do not support custom" diff --git a/optimum/utils/testing_utils.py b/optimum/utils/testing_utils.py index 1d1177ae72..e48a128051 100644 --- a/optimum/utils/testing_utils.py +++ b/optimum/utils/testing_utils.py @@ -23,7 +23,6 @@ from typing import Any, Callable, Dict, Iterable, Optional, Tuple import torch -from packaging.version import parse from . import is_accelerate_available, is_diffusers_available @@ -63,11 +62,6 @@ def require_torch_gpu(test_case): return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case) -def require_torch_20(test_case): - """Decorator marking a test that requires torch>=2.0.""" - return unittest.skipUnless(parse(torch.__version__) > parse("1.14"), "test requires torch>=2.0")(test_case) - - def require_hf_token(test_case): """ Decorator marking a test that requires huggingface hub token. 
diff --git a/tests/bettertransformer/test_audio.py b/tests/bettertransformer/test_audio.py index 95f74de491..5d995ce439 100644 --- a/tests/bettertransformer/test_audio.py +++ b/tests/bettertransformer/test_audio.py @@ -21,7 +21,7 @@ from transformers import AutoFeatureExtractor, AutoModel, AutoProcessor from optimum.bettertransformer import BetterTransformer -from optimum.utils.testing_utils import grid_parameters, require_torch_20 +from optimum.utils.testing_utils import grid_parameters ALL_AUDIO_MODELS_TO_TEST = [ @@ -64,21 +64,16 @@ def prepare_inputs_for_class(self, model_id, model_type): return input_dict @parameterized.expand(grid_parameters(FULL_GRID)) - @require_torch_20 def test_invert_modules(self, test_name: str, model_type: str, keep_original_model=False): - self._skip_on_torch_version(model_type) model_id = MODELS_DICT[model_type] self._test_invert_modules(model_id=model_id, keep_original_model=keep_original_model) @parameterized.expand(grid_parameters(FULL_GRID)) - @require_torch_20 def test_save_load_invertible(self, test_name: str, model_type: str, keep_original_model=False): - self._skip_on_torch_version(model_type) model_id = MODELS_DICT[model_type] self._test_save_load_invertible(model_id=model_id, keep_original_model=keep_original_model) @parameterized.expand(grid_parameters(FULL_GRID)) - @require_torch_20 def test_invert_model_logits(self, test_name: str, model_type: str, keep_original_model=False): model_id = MODELS_DICT[model_type] self._test_invert_model_logits( @@ -182,7 +177,6 @@ def test_raise_train(self, model_type: str): self._test_raise_train(model_id, model_type=model_type) @parameterized.expand(grid_parameters(FULL_GRID)) - @require_torch_20 def test_invert_modules(self, test_name: str, model_type: str, keep_original_model=False): if model_type in ["hubert", "wav2vec2"] and keep_original_model is True: self.skipTest(f"{model_type} does not support keep_original_model=True") @@ -194,7 +188,6 @@ def test_invert_modules(self, test_name: str, model_type: str, keep_original_mod self._test_invert_modules(model_id=model_id, keep_original_model=keep_original_model) @parameterized.expand(grid_parameters(FULL_GRID)) - @require_torch_20 def test_save_load_invertible(self, test_name: str, model_type: str, keep_original_model=False): if model_type in ["hubert", "wav2vec2"] and keep_original_model is True: self.skipTest(f"{model_type} does not support keep_original_model=True") @@ -206,7 +199,6 @@ def test_save_load_invertible(self, test_name: str, model_type: str, keep_origin self._test_save_load_invertible(model_id=model_id, keep_original_model=keep_original_model) @parameterized.expand(grid_parameters(FULL_GRID)) - @require_torch_20 def test_invert_model_logits(self, test_name: str, model_type: str, keep_original_model=False): if model_type == "hubert" and keep_original_model is True: self.skipTest("hubert does not support keep_original_model=True") diff --git a/tests/bettertransformer/test_common.py b/tests/bettertransformer/test_common.py index ffcd413f4a..e34923730c 100644 --- a/tests/bettertransformer/test_common.py +++ b/tests/bettertransformer/test_common.py @@ -16,23 +16,17 @@ import unittest from unittest.mock import patch -import torch import transformers -from packaging.version import parse from parameterized import parameterized from testing_utils import MODELS_DICT from transformers import AutoModel from optimum.bettertransformer import BetterTransformer, BetterTransformerManager from optimum.pipelines import pipeline -from optimum.utils.testing_utils 
import grid_parameters, require_torch_20 +from optimum.utils.testing_utils import grid_parameters class BetterTransformerIntegrationTests(unittest.TestCase): - def _skip_on_torch_version(self, model_type: str): - if BetterTransformerManager.requires_torch_20(model_type) and parse(torch.__version__) < parse("1.14"): - self.skipTest(f"The model type {model_type} require PyTorch 2.0 for BetterTransformer") - def test_raise_error_on_double_transform_call(self): model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-BertModel") @@ -60,7 +54,6 @@ def test_raise_on_save(self, model_type: str): r""" Test if the conversion properly raises an error if someone tries to save the model using `save_pretrained`. """ - self._skip_on_torch_version(model_type) model_ids = ( MODELS_DICT[model_type] if isinstance(MODELS_DICT[model_type], tuple) else (MODELS_DICT[model_type],) ) @@ -76,7 +69,6 @@ def test_conversion(self, model_type: str): This tests if the conversion of a slow model to its BetterTransformer version using fastpath has been successful. """ - self._skip_on_torch_version(model_type) model_ids = ( MODELS_DICT[model_type] if isinstance(MODELS_DICT[model_type], tuple) else (MODELS_DICT[model_type],) ) @@ -93,13 +85,11 @@ def test_conversion(self, model_type: str): self.assertTrue(hasattr(converted_model, "generate")) @parameterized.expand(grid_parameters({"model_type": MODELS_DICT.keys(), "keep_original_model": [True, False]})) - @require_torch_20 def test_raise_save_pretrained_error(self, test_name: str, model_type: str, keep_original_model: bool): r""" Test if the converted model raises an error when calling `save_pretrained` but not when the model is reverted """ - self._skip_on_torch_version(model_type) if model_type in ["wav2vec2", "hubert"] and keep_original_model is True: self.skipTest("These architectures do not support deepcopy") @@ -125,7 +115,6 @@ def test_raise_activation_fun(self, model_type: str): A tests that checks if the conversion raises an error if the model contains an activation function that is not supported by `BetterTransformer`. 
Here we need to loop over the config files """ - self._skip_on_torch_version(model_type) if BetterTransformerManager.requires_strict_validation(model_type) is False: self.skipTest("The architecture does not require a specific activation function") diff --git a/tests/bettertransformer/test_decoder.py b/tests/bettertransformer/test_decoder.py index d446a08385..a19e64fcf2 100644 --- a/tests/bettertransformer/test_decoder.py +++ b/tests/bettertransformer/test_decoder.py @@ -23,7 +23,7 @@ from optimum.bettertransformer import BetterTransformer from optimum.utils import DummyPastKeyValuesGenerator, NormalizedConfigManager -from optimum.utils.testing_utils import grid_parameters, require_accelerate, require_torch_20, require_torch_gpu +from optimum.utils.testing_utils import grid_parameters, require_accelerate, require_torch_gpu class BetterTransformersDecoderTest(BetterTransformersTestMixin, unittest.TestCase): @@ -34,7 +34,9 @@ class BetterTransformersDecoderTest(BetterTransformersTestMixin, unittest.TestCa "keep_original_model": [True, False], } - def prepare_inputs_for_class(self, model_id: str, model_type: str, batch_size: int = 2, **preprocessor_kwargs): + def prepare_inputs_for_class( + self, model_id: str, model_type: str, batch_size: int = 2, no_padding: bool = False, **preprocessor_kwargs + ): tokenizer = AutoTokenizer.from_pretrained(model_id) if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None: if tokenizer.eos_token != "": @@ -45,13 +47,12 @@ def prepare_inputs_for_class(self, model_id: str, model_type: str, batch_size: i padding = preprocessor_kwargs.pop("padding", True) if batch_size == 1: texts = ["a dummy input yeah!"] + elif no_padding: + texts = ["a dummy input yeah!"] * batch_size else: texts = ["a dummy input yeah!"] + ["and two"] * (batch_size - 1) inputs = tokenizer(texts, return_tensors="pt", padding=padding, max_length=20, **preprocessor_kwargs) - if model_type == "llama": - del inputs["token_type_ids"] - return inputs @parameterized.expand( @@ -64,13 +65,24 @@ def prepare_inputs_for_class(self, model_id: str, model_type: str, batch_size: i ) ) def test_logits_without_cache(self, test_name: str, model_type: str, padding, batch_size: int): - self._skip_on_torch_version(model_type) if batch_size == 1 and padding == "max_length": self.skipTest("batch_size=1 + padding='max_length' is unsupported") model_id = MODELS_DICT[model_type] self._test_logits(model_id, model_type=model_type, padding=padding, batch_size=batch_size) + @parameterized.expand( + grid_parameters( + { + "model_type": SUPPORTED_ARCH, + "batch_size": [1, 3], + } + ) + ) + def test_logits_backward(self, test_name: str, model_type: str, batch_size: int): + model_id = MODELS_DICT[model_type] + self._test_logits_backward(model_id, model_type=model_type, no_padding=True, batch_size=batch_size) + @parameterized.expand( grid_parameters( { @@ -84,8 +96,6 @@ def test_logits_without_cache(self, test_name: str, model_type: str, padding, ba @require_torch_gpu @pytest.mark.gpu_test def test_fp16_inference(self, test_name: str, model_type: str, use_to_operator: bool, batch_size: int): - self._skip_on_torch_version(model_type) - model_id = MODELS_DICT[model_type] self._test_fp16_inference( model_id, @@ -104,7 +114,6 @@ def test_fp16_inference(self, test_name: str, model_type: str, use_to_operator: ) ) def test_logits_with_cache(self, test_name: str, model_type: str, batch_size: int): - self._skip_on_torch_version(model_type) input_ids = torch.randint(low=1, high=10, size=(batch_size, 1)) seq_length = 12 
attention_mask = torch.ones(batch_size, seq_length + 1, dtype=torch.int32) @@ -138,7 +147,6 @@ def test_logits_with_cache(self, test_name: str, model_type: str, batch_size: in grid_parameters({"model_type": SUPPORTED_ARCH, "batch_size": [1, 3], "padding": [True, "max_length"]}) ) def test_generation(self, test_name: str, model_type: str, batch_size: int, padding: str): - self._skip_on_torch_version(model_type) if batch_size == 1 and padding == "max_length": self.skipTest("batch_size=1 + padding='max_length' is unsupported") @@ -175,33 +183,26 @@ def test_generation(self, test_name: str, model_type: str, batch_size: int, padd @parameterized.expand(SUPPORTED_ARCH) def test_raise_autocast(self, model_type: str): - self._skip_on_torch_version(model_type) model_id = MODELS_DICT[model_type] self._test_raise_autocast(model_id, model_type=model_type) @parameterized.expand(SUPPORTED_ARCH) @pytest.mark.training def test_train(self, model_type: str): - self._skip_on_torch_version(model_type) model_id = MODELS_DICT[model_type] self._test_train_decoder(model_id, model_type=model_type) @parameterized.expand(grid_parameters(FULL_GRID)) - @require_torch_20 def test_invert_modules(self, test_name: str, model_type: str, keep_original_model=False): - self._skip_on_torch_version(model_type) model_id = MODELS_DICT[model_type] self._test_invert_modules(model_id=model_id, keep_original_model=keep_original_model) @parameterized.expand(grid_parameters(FULL_GRID)) - @require_torch_20 def test_save_load_invertible(self, test_name: str, model_type: str, keep_original_model=False): - self._skip_on_torch_version(model_type) model_id = MODELS_DICT[model_type] self._test_save_load_invertible(model_id=model_id, keep_original_model=keep_original_model) @parameterized.expand(grid_parameters(FULL_GRID)) - @require_torch_20 def test_invert_model_logits(self, test_name: str, model_type: str, keep_original_model=False): model_id = MODELS_DICT[model_type] self._test_invert_model_logits( diff --git a/tests/bettertransformer/test_encoder.py b/tests/bettertransformer/test_encoder.py index df51145849..2c44177fe9 100644 --- a/tests/bettertransformer/test_encoder.py +++ b/tests/bettertransformer/test_encoder.py @@ -23,7 +23,8 @@ from transformers import AutoModel, AutoProcessor, AutoTokenizer from optimum.bettertransformer import BetterTransformer -from optimum.utils.testing_utils import grid_parameters, require_accelerate, require_torch_20, require_torch_gpu +from optimum.pipelines import pipeline +from optimum.utils.testing_utils import grid_parameters, require_accelerate, require_torch_gpu class BetterTransformersEncoderTest(BetterTransformersTestMixin): @@ -144,8 +145,6 @@ def test_pipeline_on_cpu(self): r""" This test runs pipeline together with Better Transformers converted models using optimum `pipeline`. """ - from optimum.pipelines import pipeline - model_name = "distilbert-base-uncased" unmasker = pipeline("fill-mask", model_name, accelerator="bettertransformer") @@ -160,8 +159,6 @@ def test_pipeline_on_gpu(self): r""" This test runs pipeline together with Better Transformers converted models using optimum `pipeline`. 
""" - from optimum.pipelines import pipeline - model_name = "distilbert-base-uncased" unmasker = pipeline("fill-mask", model_name, accelerator="bettertransformer", device="cuda:0") @@ -300,19 +297,16 @@ def test_logits_backward(self, test_name: str, model_type: str, batch_size: int) self._test_logits_backward(model_id=model_id, model_type=model_type, batch_size=batch_size) @parameterized.expand(grid_parameters(FULL_GRID)) - @require_torch_20 def test_invert_modules(self, test_name: str, model_type: str, keep_original_model=False): model_id = MODELS_DICT[model_type] self._test_invert_modules(model_id=model_id, keep_original_model=keep_original_model) @parameterized.expand(grid_parameters(FULL_GRID)) - @require_torch_20 def test_save_load_invertible(self, test_name: str, model_type: str, keep_original_model=False): model_id = MODELS_DICT[model_type] self._test_save_load_invertible(model_id=model_id, keep_original_model=keep_original_model) @parameterized.expand(grid_parameters(FULL_GRID)) - @require_torch_20 def test_invert_model_logits(self, test_name: str, model_type: str, keep_original_model=False): model_id = MODELS_DICT[model_type] self._test_invert_model_logits( diff --git a/tests/bettertransformer/test_encoder_decoder.py b/tests/bettertransformer/test_encoder_decoder.py index 44173e9267..f400d16967 100644 --- a/tests/bettertransformer/test_encoder_decoder.py +++ b/tests/bettertransformer/test_encoder_decoder.py @@ -22,7 +22,7 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer from optimum.bettertransformer import BetterTransformer -from optimum.utils.testing_utils import grid_parameters, require_torch_20, require_torch_gpu +from optimum.utils.testing_utils import grid_parameters, require_torch_gpu class BetterTransformersEncoderDecoderTest(BetterTransformersTestMixin, unittest.TestCase): @@ -71,19 +71,31 @@ def prepare_inputs_for_class(self, model_id, model_type, **preprocessor_kwargs): ) ) def test_logits_without_cache(self, test_name: str, model_type: str, padding, max_length=20): - self._skip_on_torch_version(model_type) model_id = MODELS_DICT[model_type] self._test_logits(model_id, model_type=model_type, padding=padding, max_length=max_length) + @parameterized.expand( + grid_parameters( + { + "model_type": SUPPORTED_ARCH, + "padding": ["max_length", True], + } + ) + ) + def test_logits_backward(self, test_name: str, model_type: str, padding, max_length=20): + if model_type in ["fsmt", "prophetnet"]: + self.skipTest(f"Training support not implemented for {model_type}") + + model_id = MODELS_DICT[model_type] + self._test_logits_backward(model_id, model_type=model_type, padding=padding, max_length=max_length) + @parameterized.expand(SUPPORTED_ARCH) def test_raise_autocast(self, model_type: str): - self._skip_on_torch_version(model_type) model_id = MODELS_DICT[model_type] self._test_raise_autocast(model_id, model_type=model_type) @parameterized.expand(SUPPORTED_ARCH) def test_raise_train(self, model_type: str): - self._skip_on_torch_version(model_type) model_id = MODELS_DICT[model_type] if model_type not in ["blenderbot", "pegasus", "t5"]: self._test_raise_train(model_id, model_type=model_type) @@ -91,19 +103,16 @@ def test_raise_train(self, model_type: str): self._test_train_decoder(model_id, model_type=model_type) @parameterized.expand(grid_parameters(FULL_GRID)) - @require_torch_20 def test_invert_modules(self, test_name: str, model_type: str, keep_original_model=False): model_id = MODELS_DICT[model_type] self._test_invert_modules(model_id=model_id, 
keep_original_model=keep_original_model) @parameterized.expand(grid_parameters(FULL_GRID)) - @require_torch_20 def test_save_load_invertible(self, test_name: str, model_type: str, keep_original_model=False): model_id = MODELS_DICT[model_type] self._test_save_load_invertible(model_id=model_id, keep_original_model=keep_original_model) @parameterized.expand(grid_parameters(FULL_GRID)) - @require_torch_20 def test_invert_model_logits(self, test_name: str, model_type: str, keep_original_model=False): model_id = MODELS_DICT[model_type] self._test_invert_model_logits( @@ -122,8 +131,6 @@ def test_invert_model_logits(self, test_name: str, model_type: str, keep_origina @require_torch_gpu @pytest.mark.gpu_test def test_fp16_inference(self, test_name: str, model_type: str, use_to_operator: bool): - self._skip_on_torch_version(model_type) - # TODO: fix in transformers if model_type == "fsmt": self.skipTest("fsmt is broken is transformers when loaded through torch_dtype=torch.float16") @@ -137,7 +144,6 @@ def test_fp16_inference(self, test_name: str, model_type: str, use_to_operator: grid_parameters({"model_type": SUPPORTED_ARCH, "batch_size": [1, 3], "padding": [True, "max_length"]}) ) def test_generation(self, test_name: str, model_type: str, batch_size: int, padding: str): - self._skip_on_torch_version(model_type) if batch_size == 1 and padding == "max_length": self.skipTest("batch_size=1 + padding='max_length' is unsupported") diff --git a/tests/bettertransformer/test_vision.py b/tests/bettertransformer/test_vision.py index 025c539330..48410dff7b 100644 --- a/tests/bettertransformer/test_vision.py +++ b/tests/bettertransformer/test_vision.py @@ -20,7 +20,7 @@ from testing_utils import MODELS_DICT, BetterTransformersTestMixin from transformers import AutoFeatureExtractor, AutoProcessor -from optimum.utils.testing_utils import grid_parameters, require_torch_20 +from optimum.utils.testing_utils import grid_parameters class BetterTransformersVisionTest(BetterTransformersTestMixin, unittest.TestCase): @@ -93,7 +93,6 @@ def test_raise_train(self, model_type: str): } ) ) - @require_torch_20 def test_invert_modules(self, test_name: str, model_type: str, keep_original_model=False): model_id = MODELS_DICT[model_type] self._test_invert_modules(model_id=model_id, keep_original_model=keep_original_model) @@ -106,7 +105,6 @@ def test_invert_modules(self, test_name: str, model_type: str, keep_original_mod } ) ) - @require_torch_20 def test_save_load_invertible(self, test_name: str, model_type: str, keep_original_model=False): model_id = MODELS_DICT[model_type] self._test_save_load_invertible(model_id=model_id, keep_original_model=keep_original_model) @@ -119,7 +117,6 @@ def test_save_load_invertible(self, test_name: str, model_type: str, keep_origin } ) ) - @require_torch_20 def test_invert_model_logits(self, test_name: str, model_type: str, keep_original_model=False): model_id = MODELS_DICT[model_type] self._test_invert_model_logits( diff --git a/tests/bettertransformer/testing_utils.py b/tests/bettertransformer/testing_utils.py index 2c531ebe64..0e8a7c7e06 100644 --- a/tests/bettertransformer/testing_utils.py +++ b/tests/bettertransformer/testing_utils.py @@ -19,10 +19,9 @@ import unittest import torch -from packaging.version import parse from transformers import AutoModel -from optimum.bettertransformer import BetterTransformer, BetterTransformerManager +from optimum.bettertransformer import BetterTransformer from optimum.utils.testing_utils import flatten_dict, require_torch_gpu @@ -84,6 +83,13 @@ 
"seq_classif_dropout", "summary_last_dropout", "classifier_dropout", + "activation_dropout", + "classif_dropout", + "dropout_rate", + "attn_pdrop", + "embd_pdrop", + "resid_pdrop", + "summary_first_dropout", ] @@ -110,10 +116,6 @@ class BetterTransformersTestMixin(unittest.TestCase): def prepare_inputs_for_class(self, model_id=None, model_type=None): raise NotImplementedError - def _skip_on_torch_version(self, model_type: str): - if BetterTransformerManager.requires_torch_20(model_type) and parse(torch.__version__) < parse("1.14"): - self.skipTest(f"The model type {model_type} require PyTorch 2.0 for BetterTransformer") - @require_torch_gpu def _test_fp16_inference( self, model_id: str, model_type: str, automodel_class, use_to_operator=False, **preprocessor_kwargs @@ -168,7 +170,13 @@ def _test_logits_backward(self, model_id: str, model_type: str, **preprocessor_k # functional dropout though. random_config = set_dropout_to_zero(random_config) + # m2m_100 randomly drops layers, which makes testing flaky (see `skip_the_layer` in transformers, some other models use it as well) + if model_type == "m2m_100": + random_config.encoder_layerdrop = 0 + random_config.decoder_layerdrop = 0 + hf_random_model = hf_random_model.__class__(random_config) + converted_model = copy.deepcopy(hf_random_model) converted_model = BetterTransformer.transform(converted_model) From db0c561898bd9e513adf2e0d6130f4e2528247f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Wed, 26 Jul 2023 15:30:23 +0200 Subject: [PATCH 3/8] warning about training decoders with padding --- optimum/bettertransformer/models/__init__.py | 13 +++++++++++++ optimum/bettertransformer/models/decoder_models.py | 1 - optimum/bettertransformer/transformation.py | 11 +++++------ 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/optimum/bettertransformer/models/__init__.py b/optimum/bettertransformer/models/__init__.py index 3083c2e517..56f907125b 100644 --- a/optimum/bettertransformer/models/__init__.py +++ b/optimum/bettertransformer/models/__init__.py @@ -148,6 +148,19 @@ class BetterTransformerManager: "t5", } + DO_NOT_SUPPORT_PADDED_TRAINING = { + "blenderbot", + "codegen", + "gpt2", + "gptj", + "gpt_neo", + "gpt_neox", + "llama", + "opt", + "pegasus", + "t5", + } + @staticmethod def cannot_support(model_type: str) -> bool: """ diff --git a/optimum/bettertransformer/models/decoder_models.py b/optimum/bettertransformer/models/decoder_models.py index bbc5a13135..bfb45ff317 100644 --- a/optimum/bettertransformer/models/decoder_models.py +++ b/optimum/bettertransformer/models/decoder_models.py @@ -72,7 +72,6 @@ def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) -# TODO: validate class GPTJAttentionLayerBetterTransformer(BetterTransformerBaseLayer, GPTJAttention, nn.Module): _attn = gpt2_wrapped_scaled_dot_product diff --git a/optimum/bettertransformer/transformation.py b/optimum/bettertransformer/transformation.py index 8520465bc7..0c724a8f45 100644 --- a/optimum/bettertransformer/transformation.py +++ b/optimum/bettertransformer/transformation.py @@ -287,12 +287,11 @@ def transform( model = dispatch_model(model, hf_device_map, offload_dir=offload_dir) # See: https://github.com/pytorch/pytorch/issues/96099 - # TODO: show the warning only for decoders (which do not need an attention mask for training) - if False: # BetterTransformerManager.is_decoder(model_fast.config.model_type): - logging.warning( - f"For decoder training, the BetterTransformer 
implementation for {model_fast.config.model_type} " - " architecture currently does not support padding as fused kernels do not support custom" - " attention masks. Beware that passing padded batched training data may result in unexpected outputs." + if model_fast.config.model_type in BetterTransformerManager.DO_NOT_SUPPORT_PADDED_TRAINING: + logger.warning( + f"For decoder models (here {model_fast.config.model_type}), the BetterTransformer implementation" + " does not support padding during training, as the fused kernels do not support" + " attention masks. Beware that passing padded batched data during training may result in unexpected outputs." ) # Overwrite the `save_pretrained` method From bac435b8691ab0ba691bfc6601d68fb1be7e423b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Wed, 26 Jul 2023 15:35:44 +0200 Subject: [PATCH 4/8] leave to an other PR the backward for some archs --- .../models/encoder_models.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/optimum/bettertransformer/models/encoder_models.py b/optimum/bettertransformer/models/encoder_models.py index 3eb4b45be6..8d9c88a3ea 100644 --- a/optimum/bettertransformer/models/encoder_models.py +++ b/optimum/bettertransformer/models/encoder_models.py @@ -145,7 +145,6 @@ def forward(self, hidden_states, attention_mask, *_): if hidden_states.is_nested and self.is_last_layer: hidden_states = hidden_states.to_padded_tensor(0.0) else: - # TODO: check dropout qkv = F.linear(hidden_states, weight=self.in_proj_weight, bias=self.in_proj_bias) qkv = qkv.view(qkv.size()[:-1] + (3, self.num_heads, self.attention_head_size)).permute(2, 0, 3, 1, 4) @@ -279,8 +278,6 @@ def __init__(self, bert_layer, config): "norm2_weight": "output.LayerNorm.weight", "norm2_bias": "output.LayerNorm.bias", } - - # TODO: cleaner solution self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.hidden_dropout_prob = config.hidden_dropout_prob self.attention_probs_dropout_prob = config.attention_probs_dropout_prob @@ -877,7 +874,6 @@ def forward(self, hidden_states, attn_mask, head_mask=None, output_attentions=No if hidden_states.is_nested and self.is_last_layer: hidden_states = hidden_states.to_padded_tensor(0.0) else: - # TODO: check dropout qkv = F.linear(hidden_states, weight=self.in_proj_weight, bias=self.in_proj_bias) qkv = qkv.view(qkv.size()[:-1] + (3, self.num_heads, self.attention_head_size)).permute(2, 0, 3, 1, 4) @@ -1041,7 +1037,9 @@ def forward(self, hidden_states, attention_mask, *_, **__): if hidden_states.is_nested and self.is_last_layer: hidden_states = hidden_states.to_padded_tensor(0.0) else: - raise NotImplementedError("TODO") + raise NotImplementedError( + "Training and Autocast are not implemented for BetterTransformer + Whisper. Please open an issue." + ) return (hidden_states,) @@ -1159,7 +1157,9 @@ def forward(self, hidden_states, *_, **__): if hidden_states.is_nested and self.is_last_layer: hidden_states = hidden_states.to_padded_tensor(0.0) else: - raise NotImplementedError("TODO") + raise NotImplementedError( + "Training and Autocast are not implemented for BetterTransformer + ViT. Please open an issue." 
+ ) return (hidden_states,) @@ -1277,7 +1277,9 @@ def forward(self, hidden_states, *_, **__): if hidden_states.is_nested and self.is_last_layer: hidden_states = hidden_states.to_padded_tensor(0.0) else: - raise NotImplementedError("TODO") + raise NotImplementedError( + "Training and Autocast are not implemented for BetterTransformer + Vilt. Please open an issue." + ) return (hidden_states,) @@ -1402,7 +1404,9 @@ def forward(self, hidden_states, attention_mask, **__): if hidden_states.is_nested and self.is_last_layer: hidden_states = hidden_states.to_padded_tensor(0.0) else: - raise NotImplementedError("TODO") + raise NotImplementedError( + "Training and Autocast are not implemented for BetterTransformer + Wav2Vec2. Please open an issue." + ) return (hidden_states,) @@ -1788,7 +1792,9 @@ def forward(self, hidden_states, attention_mask, *_, **__): attention_mask, ) else: - raise NotImplementedError("TODO") + NotImplementedError( + "Training and Autocast are not implemented for BetterTransformer + CLIP. Please open an issue." + ) return (hidden_states,) From d1f160a11af035b9b1361da91dbbe0e34158e5ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Wed, 26 Jul 2023 15:39:18 +0200 Subject: [PATCH 5/8] nit --- optimum/bettertransformer/transformation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/bettertransformer/transformation.py b/optimum/bettertransformer/transformation.py index 0c724a8f45..018c30de4a 100644 --- a/optimum/bettertransformer/transformation.py +++ b/optimum/bettertransformer/transformation.py @@ -255,9 +255,9 @@ def transform( " `keep_original_model=False` and create a new copy of the original" " model somewhere else." ) - model_fast = replace_to_bettertransformer(model_fast, hf_config).eval() + model_fast = replace_to_bettertransformer(model_fast, hf_config) else: - model_fast = replace_to_bettertransformer(model, hf_config).eval() + model_fast = replace_to_bettertransformer(model, hf_config) model = None if BetterTransformerManager.requires_nested_tensor(model_fast.config.model_type): From c70a3dbba823702d3b9d0fb53427590cf990c046 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Wed, 26 Jul 2023 17:25:39 +0200 Subject: [PATCH 6/8] fix tests --- tests/bettertransformer/test_audio.py | 16 --------- tests/bettertransformer/test_decoder.py | 5 --- tests/bettertransformer/test_encoder.py | 15 -------- .../bettertransformer/test_encoder_decoder.py | 13 ------- tests/bettertransformer/test_vision.py | 12 ------- tests/bettertransformer/testing_utils.py | 35 +++---------------- 6 files changed, 5 insertions(+), 91 deletions(-) diff --git a/tests/bettertransformer/test_audio.py b/tests/bettertransformer/test_audio.py index 5d995ce439..595bf6c5a4 100644 --- a/tests/bettertransformer/test_audio.py +++ b/tests/bettertransformer/test_audio.py @@ -160,22 +160,6 @@ def test_logits(self, model_type: str): ), ) - @parameterized.expand(SUPPORTED_ARCH) - def test_raise_autocast(self, model_type: str): - model_ids = ( - MODELS_DICT[model_type] if isinstance(MODELS_DICT[model_type], tuple) else (MODELS_DICT[model_type],) - ) - for model_id in model_ids: - self._test_raise_autocast(model_id, model_type=model_type) - - @parameterized.expand(SUPPORTED_ARCH) - def test_raise_train(self, model_type: str): - model_ids = ( - MODELS_DICT[model_type] if isinstance(MODELS_DICT[model_type], tuple) else (MODELS_DICT[model_type],) - ) - for model_id in 
model_ids: - self._test_raise_train(model_id, model_type=model_type) - @parameterized.expand(grid_parameters(FULL_GRID)) def test_invert_modules(self, test_name: str, model_type: str, keep_original_model=False): if model_type in ["hubert", "wav2vec2"] and keep_original_model is True: diff --git a/tests/bettertransformer/test_decoder.py b/tests/bettertransformer/test_decoder.py index a19e64fcf2..a417216517 100644 --- a/tests/bettertransformer/test_decoder.py +++ b/tests/bettertransformer/test_decoder.py @@ -181,11 +181,6 @@ def test_generation(self, test_name: str, model_type: str, batch_size: int, padd f" Maxdiff: {(result_vanilla - result_bettertransformer).abs().max()}", ) - @parameterized.expand(SUPPORTED_ARCH) - def test_raise_autocast(self, model_type: str): - model_id = MODELS_DICT[model_type] - self._test_raise_autocast(model_id, model_type=model_type) - @parameterized.expand(SUPPORTED_ARCH) @pytest.mark.training def test_train(self, model_type: str): diff --git a/tests/bettertransformer/test_encoder.py b/tests/bettertransformer/test_encoder.py index 2c44177fe9..6a2e520276 100644 --- a/tests/bettertransformer/test_encoder.py +++ b/tests/bettertransformer/test_encoder.py @@ -209,21 +209,6 @@ def check_accelerate_compatibility_cpu_gpu(self, keep_original_model=True, max_m self.assertTrue(torch.allclose(output_bt[0][1, 3:], torch.zeros_like(output_bt[0][1, 3:]))) gc.collect() - @parameterized.expand(SUPPORTED_ARCH) - def test_raise_autocast(self, model_type: str): - if model_type == "rocbert": - self.skipTest( - "unrelated issue with torch.amp.autocast with rocbert (expected scalar type BFloat16 but found Float)" - ) - - model_id = MODELS_DICT[model_type] - self._test_raise_autocast(model_id, model_type) - - @parameterized.expand(SUPPORTED_ARCH) - def test_raise_train(self, model_type: str): - model_id = MODELS_DICT[model_type] - self._test_raise_train(model_id, model_type) - @pytest.mark.gpu_test @pytest.mark.accelerate_test def test_accelerate_compatibility_cpu_gpu(self): diff --git a/tests/bettertransformer/test_encoder_decoder.py b/tests/bettertransformer/test_encoder_decoder.py index f400d16967..df74ed03d2 100644 --- a/tests/bettertransformer/test_encoder_decoder.py +++ b/tests/bettertransformer/test_encoder_decoder.py @@ -89,19 +89,6 @@ def test_logits_backward(self, test_name: str, model_type: str, padding, max_len model_id = MODELS_DICT[model_type] self._test_logits_backward(model_id, model_type=model_type, padding=padding, max_length=max_length) - @parameterized.expand(SUPPORTED_ARCH) - def test_raise_autocast(self, model_type: str): - model_id = MODELS_DICT[model_type] - self._test_raise_autocast(model_id, model_type=model_type) - - @parameterized.expand(SUPPORTED_ARCH) - def test_raise_train(self, model_type: str): - model_id = MODELS_DICT[model_type] - if model_type not in ["blenderbot", "pegasus", "t5"]: - self._test_raise_train(model_id, model_type=model_type) - else: - self._test_train_decoder(model_id, model_type=model_type) - @parameterized.expand(grid_parameters(FULL_GRID)) def test_invert_modules(self, test_name: str, model_type: str, keep_original_model=False): model_id = MODELS_DICT[model_type] diff --git a/tests/bettertransformer/test_vision.py b/tests/bettertransformer/test_vision.py index 48410dff7b..ea04936fab 100644 --- a/tests/bettertransformer/test_vision.py +++ b/tests/bettertransformer/test_vision.py @@ -73,18 +73,6 @@ def test_logits(self, model_type: str): model_id = MODELS_DICT[model_type] self._test_logits(model_id, model_type=model_type) - 
@parameterized.expand(SUPPORTED_ARCH) - def test_raise_autocast(self, model_type: str): - model_id = MODELS_DICT[model_type] - self._test_raise_autocast(model_id, model_type=model_type) - - @parameterized.expand(SUPPORTED_ARCH) - def test_raise_train(self, model_type: str): - if model_type in ["blip-2"]: - self.skipTest("can be trained") - model_id = MODELS_DICT[model_type] - self._test_raise_train(model_id, model_type=model_type) - @parameterized.expand( grid_parameters( { diff --git a/tests/bettertransformer/testing_utils.py b/tests/bettertransformer/testing_utils.py index 0e8a7c7e06..c63d5d241e 100644 --- a/tests/bettertransformer/testing_utils.py +++ b/tests/bettertransformer/testing_utils.py @@ -108,9 +108,6 @@ class BetterTransformersTestMixin(unittest.TestCase): - `test_logits`: This tests if the converted model produces the same logits than the original model. - `test_raise_on_save`: Test if the converion properly raises an error if someone tries to save the model using `save_pretrained`. - - `test_raise_autocast`: A tests that checks if the conversion raises an error if the model is run under - `torch.cuda.amp.autocast`. - - `test_raise_train`: A tests that checks if the conversion raises an error if the model is run in training mode. """ def prepare_inputs_for_class(self, model_id=None, model_type=None): @@ -168,6 +165,7 @@ def _test_logits_backward(self, model_id: str, model_type: str, **preprocessor_k # `torch.random.set_rng_state`. An alternative could be to make dropout stateful, # and to replace them with a static pattern for this test. Currently, we use # functional dropout though. + # We need to be in train mode to take the right path. random_config = set_dropout_to_zero(random_config) # m2m_100 randomly drops layers, which makes testing flaky (see `skip_the_layer` in transformers, some other models use it as well) @@ -229,9 +227,13 @@ def _test_logits(self, model_id: str, model_type: str, **preprocessor_kwargs): hf_random_model = AutoModel.from_pretrained(model_id).eval() random_config = hf_random_model.config + hf_random_model = hf_random_model.eval() + torch.manual_seed(0) converted_model = BetterTransformer.transform(hf_random_model, keep_original_model=True) + self.assertFalse(hf_random_model.training) + self.assertFalse(converted_model.training) self.assertFalse( hasattr(hf_random_model, "use_bettertransformer"), f"The model {hf_random_model.__class__.__name__} has been converted to a `fast` model by mistake.", @@ -290,33 +292,6 @@ def assert_equal(self, tensor1, tensor2, atol: float, model_name: str): f" Maxdiff: {torch.abs(tensor1 - tensor2).max()}", ) - def _test_raise_autocast(self, model_id: str, model_type: str, **kwargs): - r""" - A tests that checks if the conversion raises an error if the model is run under - `torch.cuda.amp.autocast`. - """ - inputs = self.prepare_inputs_for_class(model_id=model_id, model_type=model_type, **kwargs) - hf_random_model = AutoModel.from_pretrained(model_id).eval() - - # Check for the autocast on CPU - with self.assertRaises(ValueError), torch.amp.autocast("cpu"): - bt_model = BetterTransformer.transform(hf_random_model, keep_original_model=True) - _ = bt_model(**inputs) - - def _test_raise_train(self, model_id: str, model_type: str, **kwargs): - r""" - A tests that checks if the conversion raises an error if the model is run under - `model.train()`. 
- """ - inputs = self.prepare_inputs_for_class(model_id=model_id, model_type=model_type, **kwargs) - - hf_random_model = AutoModel.from_pretrained(model_id).eval() - # Check for training mode - with self.assertRaises(ValueError): - bt_model = BetterTransformer.transform(hf_random_model, keep_original_model=True) - bt_model.train() - _ = bt_model(**inputs) - def _test_train_decoder(self, model_id: str, model_type: str, **kwargs): r""" A tests that checks if the training works as expected for decoder models. From dd67595588d16d3b69ddd098dd0a1bca1f603ef9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Wed, 26 Jul 2023 18:05:12 +0200 Subject: [PATCH 7/8] hopefully tests pass --- optimum/bettertransformer/transformation.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/optimum/bettertransformer/transformation.py b/optimum/bettertransformer/transformation.py index 018c30de4a..feba07a172 100644 --- a/optimum/bettertransformer/transformation.py +++ b/optimum/bettertransformer/transformation.py @@ -242,6 +242,8 @@ def transform( # Remove the hooks from the original model to avoid weights being on `meta` device. remove_hook_from_module(model, recurse=True) + training_mode = model.training + if keep_original_model: try: if not check_if_pytorch_greater(2.0, "Please upgrade PyTorch to >=2.0 to use training mode"): @@ -303,6 +305,11 @@ def transform( model_fast.save_pretrained = raise_save_or_push_incompatible model_fast.push_to_hub = raise_save_or_push_incompatible + if training_mode: + model_fast = model_fast.train() + else: + model_fast = model_fast.eval() + return model_fast def reverse(bt_model: "PreTrainedModel") -> "PreTrainedModel": From 0fcdff811bbc07a292a928092853c8e99b518ff1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Wed, 26 Jul 2023 18:23:33 +0200 Subject: [PATCH 8/8] fix --- tests/bettertransformer/test_encoder.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/bettertransformer/test_encoder.py b/tests/bettertransformer/test_encoder.py index 6a2e520276..1a152c7f6e 100644 --- a/tests/bettertransformer/test_encoder.py +++ b/tests/bettertransformer/test_encoder.py @@ -256,6 +256,7 @@ def test_accelerate_compatibility_single_gpu_without_keeping(self): ) ) def test_logits(self, test_name: str, model_type: str, batch_size: int): + # TODO: enable those tests if model_type in ["rocbert", "splinter", "markuplm", "bert-generation"]: self.skipTest(f"tiny tokenizers are broken on the Hub {model_type}") if model_type in ["tapas"]: @@ -273,6 +274,7 @@ def test_logits(self, test_name: str, model_type: str, batch_size: int): ) ) def test_logits_backward(self, test_name: str, model_type: str, batch_size: int): + # TODO: enable those tests if model_type in ["rocbert", "splinter", "markuplm", "bert-generation"]: self.skipTest(f"tiny tokenizer is broken on the Hub for {model_type}") if model_type in ["tapas"]: @@ -293,6 +295,12 @@ def test_save_load_invertible(self, test_name: str, model_type: str, keep_origin @parameterized.expand(grid_parameters(FULL_GRID)) def test_invert_model_logits(self, test_name: str, model_type: str, keep_original_model=False): + # TODO: reenable those tests + if model_type in ["rocbert", "splinter", "markuplm", "bert-generation"]: + self.skipTest(f"tiny tokenizers are broken on the Hub {model_type}") + if model_type in ["tapas"]: + self.skipTest(f"{model_type} requires dataframe") + model_id = MODELS_DICT[model_type] 
self._test_invert_model_logits( model_id=model_id, model_type=model_type, keep_original_model=keep_original_model