17 changes: 1 addition & 16 deletions optimum/bettertransformer/models/__init__.py
@@ -148,18 +148,14 @@ class BetterTransformerManager:
"t5",
}

REQUIRES_TORCH_20 = {
DO_NOT_SUPPORT_PADDED_TRAINING = {
"blenderbot",
"bart",
"codegen",
"gpt2",
"gptj",
"gpt_neo",
"gpt_neox",
"llama",
"m2m_100",
"marian",
"mbart",
"opt",
"pegasus",
"t5",
@@ -209,17 +205,6 @@ def requires_strict_validation(model_type: str) -> bool:
"""
return model_type not in BetterTransformerManager.NOT_REQUIRES_STRICT_VALIDATION

@staticmethod
def requires_torch_20(model_type: str) -> bool:
"""
Returns True if the architecture requires PyTorch 2.0 to be used with BetterTransformer.

Args:
model_type (`str`):
The model type to check.
"""
return model_type in BetterTransformerManager.REQUIRES_TORCH_20


class warn_uncompatible_save(object):
def __init__(self, callback):
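For context, a minimal sketch of how downstream code might query architecture support now that `requires_torch_20` is gone: callers can check the remaining class-level sets directly. The helper name `supports_padded_training` below is hypothetical and not part of this diff.

# Hypothetical helper, not part of this PR: shows a direct membership check
# against the remaining DO_NOT_SUPPORT_PADDED_TRAINING set, replacing the
# removed requires_torch_20() query.
from optimum.bettertransformer.models import BetterTransformerManager

def supports_padded_training(model_type: str) -> bool:
    """Return True if the architecture can be trained on padded (attention-masked) inputs."""
    return model_type not in BetterTransformerManager.DO_NOT_SUPPORT_PADDED_TRAINING

# Example usage; the result depends on the contents of the set above.
print(supports_padded_training("gpt2"))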
11 changes: 0 additions & 11 deletions optimum/bettertransformer/models/base.py
@@ -55,7 +55,6 @@ def __init__(
self.num_layers = None
self.original_layers_mapping = {}
self.module_mapping = None
self.supports_training = False
# Some models do not have certain attributes, so those keys need to be ignored
# e.g. whisper does not have self_attn.k_proj.bias but has self_attn.v_proj.bias & self_attn.q_proj.bias
self.keys_to_ignore = []
@@ -127,16 +126,6 @@ def validate_bettertransformer(self):
f" Number of heads must be even."
)

def forward_checker(self, *args, **kwargs):
if torch.is_autocast_enabled() or torch.is_autocast_cpu_enabled():
raise ValueError("Autocast is not supported for `BetterTransformer` integration.")

if self.training and not self.supports_training:
raise ValueError(
"Training is not supported for `BetterTransformer` integration.",
" Please use `model.eval()` before running the model.",
)

def _revert(self, module: torch.nn.Module) -> torch.nn.Module:
if self.module_mapping is not None:
if "" in self.module_mapping.values():
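For reference, the removed `forward_checker` guard amounted to the checks below. This is a standalone restatement for readers of the diff; `legacy_forward_checker` is not a function that exists in the codebase.

import torch

def legacy_forward_checker(module: "torch.nn.Module", supports_training: bool) -> None:
    # Restatement of the deleted guard: it rejected autocast contexts outright,
    # and rejected training mode unless the layer opted in via the
    # supports_training flag (also removed in this PR).
    if torch.is_autocast_enabled() or torch.is_autocast_cpu_enabled():
        raise ValueError("Autocast is not supported for `BetterTransformer` integration.")
    if module.training and not supports_training:
        raise ValueError(
            "Training is not supported for `BetterTransformer` integration. "
            "Please use `model.eval()` before running the model."
        )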
25 changes: 0 additions & 25 deletions optimum/bettertransformer/models/decoder_models.py
@@ -65,16 +65,13 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
setattr(self, "q_attn", getattr(layer, "q_attn"))
self.original_layers_mapping["q_attn"] = "q_attn"

self.supports_training = True
self.downcast_qk = False
self.dropout_prob_attn = config.attn_pdrop

def forward(self, *args, **kwargs):
super().forward_checker()
return super().forward(*args, **kwargs)


# TODO: validate
class GPTJAttentionLayerBetterTransformer(BetterTransformerBaseLayer, GPTJAttention, nn.Module):
_attn = gpt2_wrapped_scaled_dot_product

@@ -105,11 +102,9 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
self.original_layers_mapping = {submodule: submodule for submodule in submodules}

self.downcast_qk = True
self.supports_training = True
self.dropout_prob_attn = config.attn_pdrop

def forward(self, *args, **kwargs):
super().forward_checker()
return super().forward(*args, **kwargs)


@@ -129,11 +124,9 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
self.original_layers_mapping = {submodule: submodule for submodule in submodules}

self.downcast_qk = True
self.supports_training = True
self.dropout_prob_attn = 0.0 # no dropout for gpt-neox

def forward(self, *args, **kwargs):
super().forward_checker()
return super().forward(*args, **kwargs)


@@ -159,11 +152,9 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
self.original_layers_mapping = {submodule: submodule for submodule in submodules}

self.scale = torch.sqrt(torch.tensor(layer.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
self.supports_training = True
self.dropout_prob_attn = float(config.attention_dropout)

def forward(self, *args, **kwargs):
super().forward_checker()
return super().forward(*args, **kwargs)


@@ -188,11 +179,9 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):

self.original_layers_mapping = {submodule: submodule for submodule in submodules}

self.supports_training = True
self.dropout_prob_attn = config.attn_pdrop

def forward(self, *args, **kwargs):
super().forward_checker()
return super().forward(*args, **kwargs)


@@ -218,10 +207,7 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):

self.original_layers_mapping = {submodule: submodule for submodule in submodules}

self.supports_training = True

def forward(self, *args, **kwargs):
super().forward_checker()
return opt_forward(self, *args, **kwargs)


@@ -249,11 +235,9 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):

self.module_mapping = None

self.supports_training = True
self.is_decoder = layer.is_decoder

def forward(self, *args, **kwargs):
super().forward_checker()
return t5_forward(self, *args, **kwargs)


@@ -274,7 +258,6 @@ def bart_bettertransformer_init(self, layer: "nn.Module", config: "PretrainedConfig"):

self.original_layers_mapping = {submodule: submodule for submodule in submodules}

self.supports_training = True
self.is_decoder = layer.is_decoder


@@ -284,7 +267,6 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
bart_bettertransformer_init(self, layer, config)

def forward(self, *args, **kwargs):
super().forward_checker()
return bart_forward(self, *args, **kwargs)


@@ -294,7 +276,6 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
bart_bettertransformer_init(self, layer, config)

def forward(self, *args, **kwargs):
super().forward_checker()
return bart_forward(self, *args, **kwargs)


@@ -304,7 +285,6 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
bart_bettertransformer_init(self, layer, config)

def forward(self, *args, **kwargs):
super().forward_checker()
return bart_forward(self, *args, **kwargs)


@@ -314,7 +294,6 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
bart_bettertransformer_init(self, layer, config)

def forward(self, *args, **kwargs):
super().forward_checker()
return bart_forward(self, *args, **kwargs)


@@ -323,7 +302,6 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
bart_bettertransformer_init(self, layer, config)

def forward(self, *args, **kwargs):
super().forward_checker()
return bart_forward(self, *args, **kwargs)


@@ -339,8 +317,5 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):

self.original_layers_mapping = {submodule: submodule for submodule in submodules}

self.supports_training = True

def forward(self, *args, **kwargs):
super().forward_checker()
return llama_forward(self, *args, **kwargs)
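Taken together, dropping `supports_training` and the per-forward `forward_checker` call means a converted decoder model can be trained without tripping a guard. A hedged usage sketch, assuming the public `BetterTransformer.transform` API and using gpt2 purely as an illustrative checkpoint:

import torch
from transformers import AutoModelForCausalLM
from optimum.bettertransformer import BetterTransformer

# Sketch only: convert a causal LM and run one training step. After this change
# the converted attention layers no longer raise "Training is not supported".
model = AutoModelForCausalLM.from_pretrained("gpt2")
model = BetterTransformer.transform(model)
model.train()

input_ids = torch.randint(0, model.config.vocab_size, (1, 16))
loss = model(input_ids, labels=input_ids).loss
loss.backward()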