diff --git a/src/transformers/commands/add_new_model_like.py b/src/transformers/commands/add_new_model_like.py index e443a235c42e..3ba5d71099de 100644 --- a/src/transformers/commands/add_new_model_like.py +++ b/src/transformers/commands/add_new_model_like.py @@ -1189,6 +1189,16 @@ def create_new_model_like( if "tokenization" not in str(f) and "processor" not in str(f) and "feature_extraction" not in str(f) ] + def disable_fx_test(filename: Path) -> bool: + with open(filename) as fp: + content = fp.read() + new_content = re.sub(r"fx_compatible\s*=\s*True", "fx_compatible = False", content) + with open(filename, "w") as fp: + fp.write(new_content) + return content != new_content + + disabled_fx_test = False + for test_file in files_to_adapt: new_test_file_name = test_file.name.replace( old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased @@ -1201,6 +1211,13 @@ def create_new_model_like( dest_file=dest_file, add_copied_from=False, ) + disabled_fx_test = disabled_fx_test | disable_fx_test(dest_file) + + if disabled_fx_test: + print( + "The tests for symbolic tracing with torch.fx were disabled, you can add those once symbolic tracing works " + "for your new model." + ) # 4. Add model to auto classes add_model_to_auto_classes(old_model_patterns, new_model_patterns, model_classes) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 36d4a005a8d8..d2a4e260dbe5 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -322,7 +322,7 @@ HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}" # This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs. -TORCH_FX_REQUIRED_VERSION = version.parse("1.9") +TORCH_FX_REQUIRED_VERSION = version.parse("1.10") TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8") _is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 9a1a1ccfaf09..d02062a210a4 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -247,6 +247,27 @@ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor: return encoder_extended_attention_mask + def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask, device): + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + return extended_attention_mask + def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor: """ Makes broadcastable attention and causal masks so that future and masked tokens are ignored. 
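Factoring the causal-mask construction out of `get_extended_attention_mask` into `create_extended_attention_mask_for_decoder` makes it possible to register that helper as a traceable leaf (see `_FUNCTIONS_TO_AUTOWRAP` further down in `utils/fx.py`). A minimal sketch of what the helper computes, re-implemented standalone on a toy input and ignoring the `past_key_values` prefix branch (the function name here is illustrative, not part of the patch):

```python
import torch


def causal_extended_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # Same math as ModuleUtilsMixin.create_extended_attention_mask_for_decoder,
    # without the past_key_values prefix handling.
    batch_size, seq_length = attention_mask.shape
    seq_ids = torch.arange(seq_length)
    # [batch, seq, seq] lower-triangular pattern: position i may attend to j only if j <= i
    causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
    causal_mask = causal_mask.to(attention_mask.dtype)
    # Combine with the padding mask and add the broadcastable head dimension: [batch, 1, seq, seq]
    return causal_mask[:, None, :, :] * attention_mask[:, None, None, :]


mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding
print(causal_extended_mask(mask)[0, 0])
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 0]])
```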
@@ -271,26 +292,9 @@ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple # - if the model is a decoder, apply a causal mask in addition to the padding mask # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] if self.config.is_decoder: - batch_size, seq_length = input_shape - seq_ids = torch.arange(seq_length, device=device) - causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] - # in case past_key_values are used we need to add a prefix ones mask to the causal mask - # causal and attention masks must have same type with pytorch version < 1.3 - causal_mask = causal_mask.to(attention_mask.dtype) - - if causal_mask.shape[1] < attention_mask.shape[1]: - prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] - causal_mask = torch.cat( - [ - torch.ones( - (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype - ), - causal_mask, - ], - axis=-1, - ) - - extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + extended_attention_mask = self.create_extended_attention_mask_for_decoder( + input_shape, attention_mask, device + ) else: extended_attention_mask = attention_mask[:, None, None, :] else: @@ -1835,7 +1839,7 @@ def __init__(self, nf, nx): def forward(self, x): size_out = x.size()[:-1] + (self.nf,) x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) - x = x.view(*size_out) + x = x.view(size_out) return x diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index 6f443fb4f8bc..a54a3874adf2 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -293,7 +293,7 @@ def __init__(self, config): # Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def prune_heads(self, heads): diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 23dfbcee63a9..26c629f78f16 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -252,7 +252,7 @@ def __init__(self, config, position_embedding_type=None): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -341,7 +341,7 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 054eff4be016..a61045f9c0aa 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -245,7 +245,7 @@ def __init__(self, config, position_embedding_type=None): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, 
self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -334,7 +334,7 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index b1988d7edf8e..59df99e8ab91 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -193,7 +193,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): attn_weights = torch.matmul(query, key.transpose(-1, -2)) if self.scale_attn_weights: - attn_weights = attn_weights / (float(value.size(-1)) ** 0.5) + attn_weights = attn_weights / (value.size(-1) ** 0.5) # Layer-wise attention scaling if self.scale_attn_by_inverse_layer_idx: @@ -281,7 +281,7 @@ def _split_heads(self, tensor, num_heads, attn_head_size): Splits hidden_size dim into attn_head_size and num_heads """ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - tensor = tensor.view(*new_shape) + tensor = tensor.view(new_shape) return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) def _merge_heads(self, tensor, num_heads, attn_head_size): @@ -915,7 +915,7 @@ def custom_forward(*inputs): hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(*output_shape) + hidden_states = hidden_states.view(output_shape) # Add last hidden state if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -1410,7 +1410,7 @@ def forward( f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) - pooled_logits = logits[range(batch_size), sequence_lengths] + pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths] loss = None if labels is not None: diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 7176cfa790b2..c516ca57a1e5 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -173,7 +173,7 @@ def _split_heads(self, tensor, num_heads, attn_head_size): Splits hidden_size dim into attn_head_size and num_heads """ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - tensor = tensor.view(*new_shape) + tensor = tensor.view(new_shape) return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) def _merge_heads(self, tensor, num_heads, attn_head_size): @@ -637,7 +637,7 @@ def custom_forward(*inputs): hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(*output_shape) + hidden_states = hidden_states.view(output_shape) # Add last hidden state if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -891,7 +891,7 @@ def forward( f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) - pooled_logits = logits[torch.arange(batch_size), sequence_lengths] + pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths] loss = None if labels is not None: diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 869014bee626..66163ad49fd0 
100755 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -107,7 +107,7 @@ def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary): Splits hidden dim into attn_head_size and num_attention_heads """ new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) - tensor = tensor.view(*new_shape) + tensor = tensor.view(new_shape) if rotary: return tensor if len(tensor.shape) == 5: @@ -665,7 +665,7 @@ def custom_forward(*inputs): hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(*output_shape) + hidden_states = hidden_states.view(output_shape) # Add last hidden state if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -945,7 +945,7 @@ def forward( f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) - pooled_logits = logits[range(batch_size), sequence_lengths] + pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths] loss = None if labels is not None: diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index d595fc8b517a..bbdfeaac83fc 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -160,7 +160,7 @@ def __init__(self, config, position_embedding_type=None): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -249,7 +249,7 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index dbfb76cb5d52..292b920bf5d5 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -223,7 +223,7 @@ def __init__(self, config, position_embedding_type=None): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -312,7 +312,7 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py index 2a90f1d92aff..acf9607a7367 100644 --- a/src/transformers/models/mobilebert/modeling_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_mobilebert.py @@ -237,7 +237,7 @@ def __init__(self, config): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 
2, 1, 3) def forward( @@ -274,7 +274,7 @@ def forward( context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) return outputs diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py index 165e62c0ef6c..118916413863 100644 --- a/src/transformers/models/realm/modeling_realm.py +++ b/src/transformers/models/realm/modeling_realm.py @@ -260,7 +260,7 @@ def __init__(self, config, position_embedding_type=None): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -349,7 +349,7 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 426095e03ed6..88f0aa8d29ec 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -187,7 +187,7 @@ def __init__(self, config, position_embedding_type=None): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -276,7 +276,7 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index b982a38b62f4..3d15cc682556 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -127,7 +127,7 @@ def __init__(self, config, position_embedding_type=None): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -216,7 +216,7 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index cdea06ac57b6..cfeb788ec62e 100644 --- 
a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -181,7 +181,7 @@ def __init__(self, config, position_embedding_type=None): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -270,7 +270,7 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 23a2eb4c1fab..f9cdc407aeb6 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -1,8 +1,24 @@ -import copy +# coding=utf-8 +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import functools import inspect +import math import random -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from types import ModuleType +from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union import torch from packaging import version @@ -26,17 +42,11 @@ GPT2DoubleHeadsModel, PretrainedConfig, PreTrainedModel, + XLNetForQuestionAnswering, logging, ) from ..file_utils import TORCH_FX_REQUIRED_VERSION, importlib_metadata, is_torch_fx_available from ..models.auto import get_values -from .fx_transformations import ( - _cache_attributes, - _patch_arguments_, - _restore_attributes_, - transform_to_dynamic_input_, - transformation, -) logger = logging.get_logger(__name__) @@ -46,6 +56,7 @@ def _generate_supported_model_classes( model_name: Type[PretrainedConfig], supported_tasks: Optional[Union[str, List[str]]] = None, ) -> List[Type[PreTrainedModel]]: + model_config_class = CONFIG_MAPPING[model_name] task_mapping = { "default": MODEL_MAPPING, @@ -86,15 +97,10 @@ def _generate_supported_model_classes( "gptj", "gpt_neo", "t5", -] - -_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS_FOR_DYNAMIC_AXES = [ - "albert", - "bert", - "distilbert", - "mobilebert", - "electra", - "megatron-bert", + "roberta", + # TODO: add support for them as it should be quite easy to do so (small blocking issues). + # "layoutlm", + # "xlnet", ] _REGULAR_SUPPORTED_MODELS = [] @@ -106,21 +112,11 @@ def _generate_supported_model_classes( _SPECIAL_SUPPORTED_MODELS = [ GPT2DoubleHeadsModel, + # TODO: add support for them as it should be quite easy to do so (small blocking issues). 
+ # XLNetForQuestionAnswering, ] _SUPPORTED_MODELS = tuple(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS) -_REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = [] -for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS_FOR_DYNAMIC_AXES: - if isinstance(item, dict): - _REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES.extend(_generate_supported_model_classes(**item)) - else: - _REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES.extend(_generate_supported_model_classes(item)) - -_SPECIAL_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = [] -_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = tuple( - _REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES + _SPECIAL_SUPPORTED_MODELS_FOR_DYNAMIC_AXES -) - class HFProxy(Proxy): """ @@ -134,6 +130,7 @@ def __init__(self, node: Node, tracer: Optional[Tracer] = None): if hasattr(self, "tracer") and self.tracer is not None: self.device = self.tracer.root.device self.dtype = next(self.tracer.root.parameters()).dtype + self.cache = None @property def shape(self): @@ -145,42 +142,54 @@ def __setitem__(self, key, value): def __contains__(self, key): return False + def __eq__(self, other): + if self.cache is not None: + return self.cache == other + elif isinstance(other, HFProxy): + return True + else: + return super().__eq__(other) -def _wrap_method_for_model_recording(model, method_name, cache_name): - """Helper function that wraps a torch.Tensor method to record its outputs during forward pass.""" - method = getattr(torch.Tensor, method_name) + def __ne__(self, other): + return not self == other - @functools.wraps(method) - def wrapped(*args, **kwargs): - if not hasattr(model, cache_name): - setattr(model, cache_name, []) - cache = getattr(model, cache_name) - res = method(*args, **kwargs) - cache.append(res) - return res + def __len__(self): + if self.cache is not None: + if isinstance(self.cache, int): + return self.cache + elif isinstance(self.cache, (torch.Size, list, tuple)): + return len(self.cache) + else: + return super().__len__(self) + return super().__len__(self) - return wrapped + def __torch_function__(self, orig_method, types, args=None, kwargs=None): + proxy = super().__torch_function__(orig_method, types, args=args, kwargs=kwargs) + proxy.cache = self.cache + return proxy -def _create_recorded_proxy_method(proxy, method_name, cache_name): - """ - Helper function that sets a recorded torch.Tensor method as a HFProxy method that will use the recorded values - during symbolic tracing. 
- """ +def _function_to_leaf(func: Callable[..., Any]) -> Callable[..., Any]: + """Wrapper that marks func as a leaf function, meaning that it will not be traced through by HFTracer.""" - def method(self, *args, **kwargs): - cache = getattr(self.tracer.root, cache_name) - res = cache.pop(0) - return res + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper - method.__name__ = method_name - bound_method = method.__get__(proxy, proxy.__class__) - setattr(proxy, method_name, bound_method) +def _function_leaf_getter(func_name: str, mapping: Dict[str, Callable[..., Any]]) -> Callable[..., Any]: + @functools.wraps(mapping[func_name]) + def wrapper(*args, **kwargs): + return mapping[func_name](*args, **kwargs) -def _wrap_method_for_model_tracing(model, method_name, cache_name): + return wrapper + + +def _create_recorded_proxy_method(proxy: HFProxy, method_name: str, cache_name: str, return_proxy: bool): """ - Helper function that sets a recorded torch.Tensor method as a torch.Tensor method that will use the recorded values + Helper function that sets a recorded torch.Tensor method as a HFProxy method that will use the recorded values during symbolic tracing. """ @@ -188,55 +197,69 @@ def _wrap_method_for_model_tracing(model, method_name, cache_name): @functools.wraps(original_method) def method(*args, **kwargs): - cache = getattr(model, cache_name) + cache = getattr(args[0].tracer.root, cache_name) res = cache.pop(0) + if return_proxy: + proxy = args[0].__torch_function__( + original_method, + None, + args=args, + kwargs=kwargs, + ) + proxy.cache = res + return proxy return res - setattr(torch.Tensor, method_name, method) - - if method_name == "size": - setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name))) - - -def _monkey_patch_tensor_methods_for_model_recording(model, method_names): - """ - Helper function that patches torch.Tensor methods (specified by the method_names list) to record model inference - before symbolic tracing. - """ - cache_names = dict() - original_methods = dict() - for method_name in method_names: - cache_name = f"cache_{method_name}" - cache_names[method_name] = cache_name - if not hasattr(torch.Tensor, method_name): - logger.info(f"torch.Tensor has no method called {method_name}, skipping patching.") - continue - original_methods[method_name] = getattr(torch.Tensor, method_name) - setattr(torch.Tensor, method_name, _wrap_method_for_model_recording(model, method_name, cache_name)) - - if method_name == "size": - original_methods["shape"] = torch.Tensor.shape - setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name))) - - return cache_names, original_methods + method.__name__ = method_name + bound_method = method.__get__(proxy, proxy.__class__) + setattr(proxy, method_name, bound_method) -def _reset_tensor_methods(original_methods): +def _reset_tensor_methods(original_methods: Dict[str, Callable[..., Any]]): """Helper function that resets the monkey patched torch.Tensor methods to their original values.""" for name, method in original_methods.items(): setattr(torch.Tensor, name, method) +def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None): + if forbidden_values is None: + forbidden_values = [] + value = random.randint(low, high) + while value in forbidden_values: + value = random.randint(low, high) + return value + + class HFTracer(Tracer): """ Tracer that is able to symbolically trace models from the library. 
To do that, it uses the HFProxy instead of the regular PyTorch torch.fx.Proxy. """ - default_methods_to_record = {"__bool__", "size", "dim"} + _DEFAULT_METHODS_TO_RECORD = {"__bool__": False, "size": True, "dim": False} + from transformers import modeling_utils + + _FUNCTIONS_TO_AUTOWRAP = { + torch: {"arange", "zeros", "ones", "full_like", "eye"}, + modeling_utils.ModuleUtilsMixin: {"create_extended_attention_mask_for_decoder"}, + } + + def __init__(self, autowrap_modules=(math,), autowrap_functions=(), enable_cpatching=False): + + # Loading the leaf functions register + self._leaf_functions_register = {} + for module, names in self._FUNCTIONS_TO_AUTOWRAP.items(): + for name in names: + self._register_leaf_function(module, name) + + # TODO: adapt the way leaf function are wrapped with the "autowrap function" feature from Tracer. + # autowrap_functions = autowrap_functions + tuple( + # patched for (_, _, patched) in self._leaf_functions_register.values() + # ) - def __init__(self, batch_size=1, sequence_length=[128, 128], num_choices=-1): - super().__init__() + super().__init__( + autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions, enable_cpatching=enable_cpatching + ) if not is_torch_fx_available(): torch_version = version.parse(importlib_metadata.version("torch")) @@ -245,40 +268,107 @@ def __init__(self, batch_size=1, sequence_length=[128, 128], num_choices=-1): f"{TORCH_FX_REQUIRED_VERSION} is supported." ) - encoder_sequence_length = sequence_length[0] if isinstance(sequence_length, (list, tuple)) else sequence_length - decoder_sequence_length = ( - sequence_length[1] if isinstance(sequence_length, (list, tuple)) else encoder_sequence_length - ) - self.encoder_shape = [batch_size, encoder_sequence_length] - self.decoder_shape = ( - [batch_size, decoder_sequence_length] if decoder_sequence_length > 0 else list(self.encoder_shape) - ) - self.num_choices = num_choices - if self.num_choices > 0: - self.encoder_shape = [batch_size, self.num_choices, encoder_sequence_length] - self.decoder_shape = [batch_size, self.num_choices, decoder_sequence_length] - self.prev_module = None self.recorded_methods = None - def proxy(self, node: Node): - p = HFProxy(node, self) - if self.recorded_methods: - for method_name, cache_name in self.recorded_methods.items(): - _create_recorded_proxy_method(p, method_name, cache_name) - return p + def _register_leaf_function(self, module: ModuleType, name: str): + """Registers the function called name in module as a leaf function.""" + orig_func = getattr(module, name) + patched_func = _function_to_leaf(orig_func) + patched_func.__module__ = __name__ + self._leaf_functions_register[name] = (module, orig_func, patched_func) + + def _patch_leaf_functions_for_root(self, root: PreTrainedModel, restore: bool = False): + """Patches leaf functions specifically for root.""" + for name in self._leaf_functions_register: + module, orig_func, patched_func = self._leaf_functions_register[name] + if restore: + root.__class__.forward.__globals__.pop(name) + setattr(module, name, orig_func) + else: + root.__class__.forward.__globals__[name] = patched_func + leaf_getter = _function_leaf_getter(name, root.__class__.forward.__globals__) + leaf_getter.__module__ = __name__ + setattr(module, name, leaf_getter) + + def _method_is_called_in_leaf_module(self, module_ids: List[int]) -> bool: + """ + Finds out if the method (that is being recorded) is called inside a leaf module, this allows to not record + outputs that will not be encountered by the tracer. 
+ """ + + currentframe = inspect.currentframe() + while currentframe: + if currentframe is None: + return False + module = currentframe.f_locals.get("self", None) + if id(module) in module_ids and self.is_leaf_module(module, "Not used anyway"): + return True + currentframe = currentframe.f_back + return False + + def _wrap_method_for_model_recording( + self, model: PreTrainedModel, method_name: str, cache_name: str, module_ids: List[int] + ): + """Helper function that wraps a torch.Tensor method to record its outputs during forward pass.""" + method = getattr(torch.Tensor, method_name) + + @functools.wraps(method) + def wrapped(*args, **kwargs): + if self._method_is_called_in_leaf_module(module_ids): + return method(*args, **kwargs) + if not hasattr(model, cache_name): + setattr(model, cache_name, []) + cache = getattr(model, cache_name) + res = method(*args, **kwargs) + cache.append(res) + return res + + return wrapped + + def _monkey_patch_tensor_methods_for_model_recording(self, model: PreTrainedModel, method_names: Iterable[str]): + """ + Helper function that patches torch.Tensor methods (specified by the method_names list) to record model + inference before symbolic tracing. + """ + cache_names = {} + original_methods = {} + module_ids = set(id(mod) for mod in model.modules()) + for method_name in method_names: + cache_name = f"cache_{method_name}" + cache_names[method_name] = cache_name + if not hasattr(torch.Tensor, method_name): + logger.info(f"torch.Tensor has no method called {method_name}, skipping patching.") + continue + original_methods[method_name] = getattr(torch.Tensor, method_name) + setattr( + torch.Tensor, + method_name, + self._wrap_method_for_model_recording(model, method_name, cache_name, module_ids), + ) - def _generate_dummy_input(self, model, input_name): + if method_name == "size": + original_methods["shape"] = torch.Tensor.shape + setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name))) + + return cache_names, original_methods + + def _generate_dummy_input( + self, model: PreTrainedModel, input_name: str, shape: List[int] + ) -> Dict[str, torch.Tensor]: """Generates dummy input for model inference recording.""" model_class = model.__class__ device = model.device - inputs_dict = dict() + inputs_dict = {} if input_name in ["labels", "start_positions", "end_positions"]: - batch_size = self.encoder_shape[0] + batch_size = shape[0] if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING): - inputs_dict["labels"] = torch.ones(batch_size, dtype=torch.long, device=device) - elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING): + inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device) + elif model_class in [ + *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING), + XLNetForQuestionAnswering, + ]: inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device) inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device) elif model_class in [ @@ -288,59 +378,56 @@ def _generate_dummy_input(self, model, input_name): ]: inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device) elif model_class in [ + *get_values(MODEL_FOR_PRETRAINING_MAPPING), *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING), *get_values(MODEL_FOR_CAUSAL_LM_MAPPING), *get_values(MODEL_FOR_MASKED_LM_MAPPING), *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING), GPT2DoubleHeadsModel, ]: - inputs_dict["labels"] = torch.zeros(self.decoder_shape, dtype=torch.long, 
device=device) - elif model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING): - inputs_dict["labels"] = torch.zeros(self.encoder_shape, dtype=torch.long, device=device) + inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device) else: raise NotImplementedError(f"{model_class} not supported yet.") elif "mask" in input_name or "ids" in input_name: - shape = self.encoder_shape if "decoder" not in input_name else self.decoder_shape - inputs_dict[input_name] = torch.ones(shape, dtype=torch.long, device=device) + inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device) else: - shape = self.encoder_shape if "decoder" not in input_name else self.decoder_shape - shape += [model.config.hidden_size] - inputs_dict[input_name] = torch.ones(shape, dtype=torch.float, device=device) + shape_with_hidden_size = shape + [model.config.hidden_size] + inputs_dict[input_name] = torch.zeros(shape_with_hidden_size, dtype=torch.float, device=device) return inputs_dict - def record(self, model, input_names, method_names=None): + def record(self, model: PreTrainedModel, input_names: List[str], method_names: Optional[Iterable[str]] = None): """ - Records torch.Tensor method outputs (specified by the method_names list) that will then be used during symbolic - tracing. + Records torch.Tensor method outputs (specified by method_names) that will then be used during symbolic tracing. """ if method_names is None: - method_names = self.default_methods_to_record + method_names = self._DEFAULT_METHODS_TO_RECORD + + # Creating a random input shape to generate dummy inputs. + batch_size = _generate_random_int() + sequence_length = _generate_random_int() + shape = [batch_size, sequence_length] + + if model.__class__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING): + num_choices = _generate_random_int(low=2, high=5) + shape.insert(1, num_choices) inputs = {} for input_name in input_names: - inputs.update(self._generate_dummy_input(model, input_name)) + inputs.update(self._generate_dummy_input(model, input_name, shape)) - clone = copy.deepcopy(model) - cache_names, original_methods = _monkey_patch_tensor_methods_for_model_recording(clone, method_names) + cache_names, original_methods = self._monkey_patch_tensor_methods_for_model_recording(model, method_names) self.original_methods = original_methods - clone(**inputs) - - # Useful because sometime the config is changed at inference time, for instance for - # classification tasks where config.problem_type can be set. 
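The reworked `record` method no longer deep-copies the model; it patches a few `torch.Tensor` methods in place, runs one real forward pass on randomly shaped dummy inputs so the concrete outputs of those methods get cached, then restores the originals so the cached values can be replayed through `HFProxy` during tracing. A much-simplified, self-contained sketch of that record-then-replay pattern (the helper name is made up for illustration):

```python
import functools

import torch


def record_tensor_method(method_name: str, cache: list):
    """Patch a torch.Tensor method so every call also appends its concrete result to `cache`."""
    original = getattr(torch.Tensor, method_name)

    @functools.wraps(original)
    def wrapped(*args, **kwargs):
        result = original(*args, **kwargs)
        cache.append(result)
        return result

    setattr(torch.Tensor, method_name, wrapped)
    return original


cache = []
original_dim = record_tensor_method("dim", cache)
torch.ones(2, 3).dim()                       # recording pass: cache now holds [2]
setattr(torch.Tensor, "dim", original_dim)   # restore, as _reset_tensor_methods does

# During symbolic tracing the proxy method does not recompute anything,
# it just pops the recorded values back out in call order:
print(cache.pop(0))  # 2
```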
- model.config = clone.config + model(**inputs) _reset_tensor_methods(original_methods) self.recorded_methods = { - method_name: cache_name for method_name, cache_name in cache_names.items() if hasattr(clone, cache_name) + method_name: cache_name for method_name, cache_name in cache_names.items() if hasattr(model, cache_name) } - for cache_name in self.recorded_methods.values(): - setattr(model, cache_name, getattr(clone, cache_name)) - def _module_getattr(self, attr, attr_val, parameter_proxy_cache): if isinstance(attr_val, torch.nn.Parameter): for n, p in self.root.named_parameters(): @@ -357,7 +444,20 @@ def _module_getattr(self, attr, attr_val, parameter_proxy_cache): return parameter_proxy_cache[n] return attr_val - def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = None, method_names=None) -> Graph: + def proxy(self, node: Node): + p = HFProxy(node, self) + if self.recorded_methods: + for method_name, cache_name in self.recorded_methods.items(): + return_proxy = self._DEFAULT_METHODS_TO_RECORD[method_name] + _create_recorded_proxy_method(p, method_name, cache_name, return_proxy) + return p + + def trace( + self, + root: PreTrainedModel, + concrete_args: Optional[Dict[str, Any]] = None, + method_names: Optional[Iterable[str]] = None, + ) -> Graph: if concrete_args is None: concrete_args = {} @@ -366,11 +466,16 @@ def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = self.record(root, input_names, method_names=method_names) - for method_name, cache_name in self.recorded_methods.items(): - _wrap_method_for_model_tracing(root, method_name, cache_name) + # TODO: adapt the way leaf function are wrapped with the "autowrap function" feature from Tracer. + autowrap_functions = [patched for (_, _, patched) in self._leaf_functions_register.values()] + self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions])) + + self._patch_leaf_functions_for_root(root) graph = super().trace(root, concrete_args=concrete_args) + self._patch_leaf_functions_for_root(root, restore=True) + _reset_tensor_methods(self.original_methods) # TODO: keep this until necessary. @@ -388,7 +493,7 @@ def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = return graph - def _insert_module_as_submodule(self, mod): + def _insert_module_as_submodule(self, mod: nn.Module) -> str: """ Helper method which tries to insert a module that was not declared as submodule. """ @@ -434,72 +539,19 @@ def path_of_module(self, mod: nn.Module) -> str: self.prev_module = path return path + def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool: + is_loss_module = m.__module__.startswith("torch.nn.modules.loss") + return (not is_loss_module) and super().is_leaf_module(m, module_qualified_name) + def create_arg(self, a: Any) -> Argument: if isinstance(a, range): return super().create_arg(list(a)) return super().create_arg(a) -@transformation -def prepare_for_retracing(gm: GraphModule) -> Tuple[GraphModule, Dict[str, Any]]: - """ - Prepares a GraphModule produced by symbolic_trace for retracing by: - - - Caching all the attributes specific to the way the model was initially traced - - Patching back the model to a "static input shapes" version if it was traced to accept dynamic input shapes - For instance, the need to retrace a GraphModule can happen when applying quantization. 
- """ - attributes = _cache_attributes(gm) - _patch_arguments_(gm, gm.dynamic2static) - - return gm, attributes - - -def restore_after_retracing_(gm: GraphModule, attributes: Dict[str, Any]): - """Restores a GraphModule that was retraced to its initial state in terms of static / dynamic input shapes.""" - _restore_attributes_(gm, attributes) - # transform_to_dynamic_input_ will override the static2dynamic and dynamic2static dictionaries which is the desired - # behaviour as the previously restored dictionaries contain nodes from the original GraphModule as values. - transform_to_dynamic_input_(gm, is_retracing=True) - _patch_arguments_(gm, gm.static2dynamic) - return gm - - -def retrace_graph_with( - gm: GraphModule, tracer: Tracer = None, func: Callable[[GraphModule], GraphModule] = None -) -> GraphModule: - """ - Retraces a GraphModule by either using a tracer or a function using a tracer (for instance - torch.quantization.quantize_fx.prepare_fx). It takes care of preparing the model for retracing, retracing it and - restoring anything necessary after the retrace. - """ - if tracer is None and func is None: - raise ValueError("Either a tracer or a function using a tracer must be provided.") - elif tracer is not None and func is not None: - raise ValueError("Either provide a tracer or a function using a tracer, but not both.") - else: - gm, attributes = prepare_for_retracing(gm) - tracing_func = tracer.trace if tracer else func - traced = tracing_func(gm) - restore_after_retracing_(traced, attributes) - return traced - - -def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None): - if forbidden_values is None: - forbidden_values = [] - value = random.randint(low, high) - while value in forbidden_values: - value = random.randint(low, high) - return value - - def symbolic_trace( model: PreTrainedModel, input_names: Optional[List[str]] = None, - batch_size: int = 1, - sequence_length: Union[int, List[int], Tuple[int]] = (128, 128), - num_choices: int = -1, ) -> GraphModule: """ @@ -510,89 +562,33 @@ def symbolic_trace( The model to trace. input_names (`List[str]`, *optional*): The names of the inputs of the traced model. If unset, model.dummy_inputs().keys() are used instead. - batch_size (`int`, *optional*, defaults to 1): - The batch size of the traced model inputs. - sequence_length (`int` or `List[int]]`): - The sequence length of the traced model inputs. For sequence-to-sequence models with different sequence - lengths between the encoder and the decoder inputs, this must be `[encoder_sequence_length, - decoder_sequence_length]`. - num_choices (`int`, *optional*, defaults to -1): - The number of possible choices for a multiple choice task. Returns: `torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model. Example: - ```python - from transformers.utils.fx import symbolic_trace + ```python + from transformers.utils.fx import symbolic_trace - traced_model = symbolic_trace( - model, - input_names=["input_ids", "attention_mask", "token_type_ids"], - batch_size=1, - sequence_length=128, - ) - ```""" + traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"]) + ``` + """ if input_names is None: input_names = model.dummy_inputs.keys() sig = inspect.signature(model.forward) concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} - # Preparing HFTracer batch_size and sequence_lenght values for potential dynamic axes. 
- use_dynamic_batch_size = batch_size <= 0 - if isinstance(sequence_length, (list, tuple)): - use_dynamic_sequence_length = sequence_length[0] <= 0 or sequence_length[1] <= 0 - else: - use_dynamic_sequence_length = sequence_length <= 0 - - if use_dynamic_batch_size or use_dynamic_sequence_length: - forbidden_values = [ - model.config.num_attention_heads, - model.config.hidden_size, - model.config.hidden_size // model.config.num_attention_heads, - ] - if use_dynamic_batch_size: - batch_size = _generate_random_int(forbidden_values=forbidden_values) - forbidden_values.append(batch_size) - if use_dynamic_sequence_length: - encoder_sequence_length = _generate_random_int(forbidden_values=forbidden_values) - forbidden_values.append(encoder_sequence_length) - decoder_sequence_length = _generate_random_int(forbidden_values=forbidden_values) - sequence_length = [encoder_sequence_length, decoder_sequence_length] - if not isinstance(model, _SUPPORTED_MODELS): supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS)) raise NotImplementedError( f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}" ) - if (use_dynamic_batch_size or use_dynamic_sequence_length) and not isinstance( - model, _SUPPORTED_MODELS_FOR_DYNAMIC_AXES - ): - supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS_FOR_DYNAMIC_AXES)) - raise NotImplementedError( - f"Dynamic axes are not supported for {model.__class__.__name__} yet, supported models: {supported_model_names}" - ) # Tracing. - tracer = HFTracer(batch_size=batch_size, sequence_length=sequence_length, num_choices=num_choices) - + tracer = HFTracer() traced_graph = tracer.trace(model, concrete_args=concrete_args) traced = torch.fx.GraphModule(model, traced_graph) - traced.config = copy.deepcopy(model.config) - traced.num_choices = num_choices - traced.dummy_inputs = {} - - for name in input_names: - traced.dummy_inputs.update(tracer._generate_dummy_input(model, name)) - - traced.use_dynamic_batch_size = use_dynamic_batch_size - traced.use_dynamic_sequence_length = use_dynamic_sequence_length - traced.static_batch_size = batch_size - traced.static_sequence_length = sequence_length - - transform_to_dynamic_input_(traced) - return traced diff --git a/src/transformers/utils/fx_transformations.py b/src/transformers/utils/fx_transformations.py deleted file mode 100644 index 3e181617af10..000000000000 --- a/src/transformers/utils/fx_transformations.py +++ /dev/null @@ -1,321 +0,0 @@ -import copy -import functools -import operator -from inspect import signature -from typing import Any, Callable, Dict, Optional, Union - -import torch -from torch.fx import Graph, GraphModule, Node - - -# Torch FX transformation convention: -# - transformations that are supposed to act on a copy of the original GraphModule are decorated with @transformation -# - transformations that are inplace have a name ending with "_" - - -def _cache_attributes(gm: GraphModule) -> Dict[str, Any]: - attributes_to_keep = [ - "config", - "num_choices", - "dummy_inputs", - "use_dynamic_batch_size", - "use_dynamic_sequence_length", - "static_batch_size", - "static_sequence_length", - "static2dynamic", - "dynamic2static", - ] - attributes = {k: getattr(gm, k, None) for k in attributes_to_keep} - return attributes - - -def _restore_attributes_(gm: GraphModule, attributes: Dict[str, Any]): - for name, attr in attributes.items(): - setattr(gm, name, attr) - - -def deepcopy_graph(gm: GraphModule) -> GraphModule: - """ - Performs a 
deepcopy of the GraphModule while also copying the relevant attributes to know whether the model was - traced with dynamic axes, and what were the values if that is the case. - """ - - # First, create a copy of the module without the graph. - graph = gm.__dict__.pop("_graph") - fake_mod = torch.nn.Module() - fake_mod.__dict__ = copy.deepcopy(gm.__dict__) - gm.__dict__["_graph"] = graph - - # Then, copy the graph. - val_map = {} - graph_clone = Graph() - output_val = graph_clone.graph_copy(graph, val_map=val_map) - graph_clone.output(output_val) - - # Finally create a new GraphModule (or a subclass of GraphModule) from the module and the graph copies. - # gm.__class__ is used to take into account that gm can be an instance of a subclass of GraphModule. - clone = gm.__class__(fake_mod, graph_clone) - - # Restore the dynamic axes related attributes to the clone. - attributes = _cache_attributes(gm) - attributes["dynamic2static"] = {val_map.get(k, k): v for k, v in attributes["dynamic2static"].items()} - attributes["static2dynamic"] = {v: k for k, v in attributes["dynamic2static"].items()} - _restore_attributes_(clone, attributes) - - return clone - - -def transformation(func): - """ - Decorator that wraps a torch.fx transformation by feeding it a copy of the GraphModule to transform instead of the - original. - """ - - def map_fn(arg): - if isinstance(arg, GraphModule): - return deepcopy_graph(arg) - return arg - - @functools.wraps(func) - def wrapper(*args, **kwargs): - new_args = tuple(map_fn(arg) for arg in args) - new_kwargs = {k: map_fn(v) for k, v in kwargs.items()} - return func(*new_args, **new_kwargs) - - wrapper._is_transformation = True - - return wrapper - - -def compose_transformations( - *args: Callable[[GraphModule], Optional[GraphModule]], inplace: bool = False -) -> GraphModule: - """ - Allows to compose transformations together and takes of: - - 1. Performing the transformations on a copy of the GraphModule if inplace is set to False, transformations that - are decorated with @transformation (which means that they are not modifying the original GraphModule) are - unwrapped to make them inplace. - 2. Linting and recompiling only at the end of the composition for performance purposes. - """ - args = list(args) - if not inplace: - args.insert(0, deepcopy_graph) - - for i, transformation in enumerate(args[:-1]): - sig = signature(transformation) - - # Unwrapping @transformation decorated transformations as performing the transformations inplace or on a copy is - # already handled by this function. - if getattr(transformation, "_is_transformation", False): - transformation = transformation.__wrapped__ - - # Linting and recompiling only after the last transformation applied to make composition efficient. 
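The deleted helpers deferred `graph.lint()` / `gm.recompile()` until after the last composed pass for efficiency. The same dead-node cleanup pattern expressed in plain torch.fx, on a toy module that is not part of the patch, looks like this:

```python
import torch
import torch.fx
from torch import nn


class Tiny(nn.Module):
    def forward(self, x):
        unused = x + 1  # dead value, never returned
        return x * 2


gm = torch.fx.symbolic_trace(Tiny())
for node in list(gm.graph.nodes):
    if not node.users and node.op not in ("placeholder", "output"):
        gm.graph.erase_node(node)

# Lint and recompile once at the end, as the removed compose_transformations helper did.
gm.graph.lint()
gm.recompile()
print(gm.code)
```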
- if "lint_and_recompile" in sig.parameters: - args[i] = functools.partial(transformation, lint_and_recompile=False) - - def reduce_func(f, g): - def compose_f_and_g(gm): - output_g = g(gm) - if output_g is None: - output_g = gm - output_f = f(output_g) - if output_f is None: - output_f = gm - return output_f - - return compose_f_and_g - - return functools.reduce(reduce_func, reversed(args), lambda x: x) - - -def remove_unused_nodes_(gm: GraphModule, lint_and_recompile: bool = True): - """Removes all the unused nodes in a GraphModule.""" - graph = gm.graph - for node in graph.nodes: - if not node.users and node.op not in ["placeholder", "output"]: - graph.erase_node(node) - - if lint_and_recompile: - graph.lint() - gm.recompile() - - -def _insert_batch_size_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node: - """Inserts a node that retrieves the batch size dynamically from the input of the model.""" - graph = gm.graph - input_names = set(gm.dummy_inputs.keys()) - batch_size_node = None - for node in graph.nodes: - if node.op == "placeholder" and node.name in input_names: - with graph.inserting_after(node): - batch_size_node = graph.call_method("size", args=(node, 0)) - - if batch_size_node is None: - raise ValueError("Could not insert the node that computes the batch size") - - if lint_and_recompile: - graph.lint() - gm.recompile() - - # Useful when retracing for quantization. - if hasattr(gm, "_qconfig_map"): - gm._qconfig_map[batch_size_node.name] = None - - return batch_size_node - - -def _insert_encoder_sequence_length_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node: - """Inserts a node that retrieves the encoder sequence length dynamically from the input of the model.""" - graph = gm.graph - input_names = set(gm.dummy_inputs.keys()) - encoder_sequence_length_node = None - for node in graph.nodes: - if node.op == "placeholder" and node.name in input_names and "decoder" not in node.name: - with graph.inserting_after(node): - # There are two cases to handle: - # 1. num_choices < 0, meaning that the model is not performing a "multiple choice" task, in this case the - # input shapes is [batch_size, sequence_length] => index 1 - # 2. num_choices > 0, meaning the model is performing a "multiple choice" task, in this case the input - # shape is [batch_size, num_choices, sequence_length] => index 2 - encoder_sequence_length_node = graph.call_method("size", args=(node, 1 if gm.num_choices < 0 else 2)) - - if encoder_sequence_length_node is None: - raise ValueError("Could not insert the node that computes the encoder sequence length") - - if lint_and_recompile: - graph.lint() - gm.recompile() - - # Useful when retracing for quantization. - if hasattr(gm, "_qconfig_map"): - gm._qconfig_map[encoder_sequence_length_node.name] = None - - return encoder_sequence_length_node - - -def _change_view_methods_( - gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True -): - """ - Changes arguments of view ops that refer to static batch size / sequence lengths to make them refer to the - batch_size / sequence_length nodes. 
- """ - graph = gm.graph - for node in graph.nodes: - if node.op == "call_method" and node.target == "view": - if isinstance(node.args[1], tuple): - node.args = (node.args[0], *node.args[1]) - node.args = tuple((mapping.get(arg, arg) for arg in node.args)) - - if lint_and_recompile: - graph.lint() - gm.recompile() - - -def _patch_getitem_( - gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True -): - """Patches getitem nodes by replacing current arguments to their corresponding values in mapping.""" - # TODO: combine this with the patch_argument function which seems to do almost the same thing. - graph = gm.graph - for node in graph.nodes: - if node.op == "call_function" and node.target == operator.getitem: - indices = node.args[1] - if isinstance(indices, tuple): - new_indices = [] - for idx in indices: - if isinstance(idx, slice): - new_indices.append( - slice( - mapping.get(idx.start, idx.start), - mapping.get(idx.stop, idx.stop), - mapping.get(idx.step, idx.step), - ) - ) - elif isinstance(idx, int): - new_indices.append(mapping.get(idx, idx)) - else: - new_indices.append(idx) - - node.args = (node.args[0], tuple(new_indices)) - else: - node.args = (node.args[0], mapping.get(node.args[1], node.args[1])) - - if lint_and_recompile: - graph.lint() - gm.recompile() - - -def _patch_arguments_( - gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True -): - """ - Patches node by replacing their argument to their corresponding values in mapping (supports regular types, tuples - and slices). - """ - - def _patch_slice(s, mapping): - return slice(mapping.get(s.start, s.start), mapping.get(s.stop, s.stop), mapping.get(s.step, s.step)) - - graph = gm.graph - supported_types = (Node, str, int, float) - for node in graph.nodes: - new_args = [] - for arg in node.args: - if isinstance(arg, tuple): - new_arg = [] - for a in arg: - if isinstance(a, slice): - new_arg.append(_patch_slice(a, mapping)) - else: - new_arg.append(mapping.get(a, a)) - new_args.append(tuple(new_arg)) - elif isinstance(arg, slice): - new_args.append(_patch_slice(arg, mapping)) - elif isinstance(arg, supported_types): - new_args.append(mapping.get(arg, arg)) - else: - new_args.append(arg) - node.args = tuple(new_args) - - if lint_and_recompile: - graph.lint() - gm.recompile() - - -def transform_to_dynamic_input_(gm: GraphModule, is_retracing: bool = False): - """Transformation that enables traced models to perform inference on dynamic input shapes.""" - graph = gm.graph - static2dynamic = {} - - # Inserting the nodes that will fetch the batch size and sequence lengths dynamically. - if gm.use_dynamic_batch_size: - batch_size_node = _insert_batch_size_node_(gm, lint_and_recompile=False) - static2dynamic[gm.static_batch_size] = batch_size_node - if gm.num_choices > 0: - with graph.inserting_after(batch_size_node): - static2dynamic[gm.static_batch_size * gm.num_choices] = graph.call_function( - operator.mul, args=(batch_size_node, gm.num_choices) - ) - # Useful when retracing for quantization. - if hasattr(gm, "_qconfig_map"): - gm._qconfig_map[static2dynamic[gm.static_batch_size * gm.num_choices]] = None - - if gm.use_dynamic_sequence_length: - encoder_sequence_length_node = _insert_encoder_sequence_length_node_(gm, lint_and_recompile=False) - static2dynamic[gm.static_sequence_length[0]] = encoder_sequence_length_node - - # TODO: do the same for the decoder. 
-        pass
-
-    _change_view_methods_(gm, static2dynamic, lint_and_recompile=False)
-    _patch_getitem_(gm, static2dynamic, lint_and_recompile=False)
-
-    remove_unused_nodes_(gm, lint_and_recompile=False)
-
-    graph.lint()
-    gm.recompile()
-
-    gm.static2dynamic = static2dynamic
-    gm.dynamic2static = {v: k for (k, v) in static2dynamic.items()}
diff --git a/tests/test_modeling_albert.py b/tests/test_modeling_albert.py
index d16dcadd5e6a..ab5595f4b6f8 100644
--- a/tests/test_modeling_albert.py
+++ b/tests/test_modeling_albert.py
@@ -231,8 +231,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
-    fx_ready_model_classes = all_model_classes
-    fx_dynamic_ready_model_classes = all_model_classes
+    fx_compatible = True

     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
diff --git a/tests/test_modeling_bert.py b/tests/test_modeling_bert.py
index 7a6628509799..7b8738fd60f3 100755
--- a/tests/test_modeling_bert.py
+++ b/tests/test_modeling_bert.py
@@ -444,8 +444,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         else ()
     )
     all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
-    fx_ready_model_classes = all_model_classes
-    fx_dynamic_ready_model_classes = all_model_classes
+    fx_compatible = True

     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 1df5e9e0f061..2ca59c3f0c9e 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -108,8 +108,7 @@ class ModelTesterMixin:
     model_tester = None
     all_model_classes = ()
     all_generative_model_classes = ()
-    fx_ready_model_classes = ()
-    fx_dynamic_ready_model_classes = ()
+    fx_compatible = False
     test_torchscript = True
     test_pruning = True
     test_resize_embeddings = True
@@ -658,19 +657,14 @@ def test_torch_fx_output_loss(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True)

-    def test_torch_fx_dynamic_axes(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        self._create_and_check_torch_fx_tracing(config, inputs_dict, dynamic_axes=True)
-
-    def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False, dynamic_axes=False):
-        if not is_torch_fx_available():
+    def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
+        if not is_torch_fx_available() or not self.fx_compatible:
             return

         configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
         configs_no_init.return_dict = False

-        model_classes = self.fx_ready_model_classes if not dynamic_axes else self.fx_dynamic_ready_model_classes
-        for model_class in model_classes:
+        for model_class in self.all_model_classes:
             model = model_class(config=configs_no_init)
             model.to(torch_device)
             model.eval()
@@ -679,8 +673,6 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
             try:
                 if model.config.is_encoder_decoder:
                     model.config.use_cache = False  # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
-                    input_ids = inputs["input_ids"]
-                    decoder_attention_mask = inputs["decoder_attention_mask"]
                     labels = inputs.get("labels", None)
                     input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
                     if labels is not None:
@@ -689,17 +681,7 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa

                     model_output = model(**filtered_inputs)

-                    batch_size = input_ids.shape[0]
-                    encoder_sequence_length = input_ids.shape[1]
-                    decoder_sequence_length = decoder_attention_mask.shape[1]
-
-                    traced_model = symbolic_trace(
-                        model,
-                        input_names,
-                        batch_size=batch_size if not dynamic_axes else -1,
-                        sequence_length=[encoder_sequence_length, decoder_sequence_length] if not dynamic_axes else -1,
-                    )
-
+                    traced_model = symbolic_trace(model, input_names)
                     traced_output = traced_model(**filtered_inputs)
                 else:
                     input_names = ["input_ids", "attention_mask", "token_type_ids"]
@@ -721,23 +703,12 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
                     model_output = model(**filtered_inputs)

                     rank = len(input_ids.shape)
-                    if rank == 2:
-                        batch_size, sequence_length = input_ids.shape
-                        num_choices = -1
-                    elif rank == 3:
-                        batch_size, num_choices, sequence_length = input_ids.shape
-                    else:
+                    if rank not in [2, 3]:
                         raise NotImplementedError(
                             f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}."
                         )

-                    traced_model = symbolic_trace(
-                        model,
-                        input_names,
-                        batch_size=batch_size if not dynamic_axes else -1,
-                        sequence_length=sequence_length if not dynamic_axes else -1,
-                        num_choices=num_choices,
-                    )
+                    traced_model = symbolic_trace(model, input_names)

                     traced_output = traced_model(**filtered_inputs)
             except RuntimeError:
diff --git a/tests/test_modeling_distilbert.py b/tests/test_modeling_distilbert.py
index ee8a8cbd3dd9..b81e42bcf175 100644
--- a/tests/test_modeling_distilbert.py
+++ b/tests/test_modeling_distilbert.py
@@ -209,8 +209,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else None
     )
-    fx_ready_model_classes = all_model_classes
-    fx_dynamic_ready_model_classes = all_model_classes
+    fx_compatible = True
     test_pruning = True
     test_torchscript = True
     test_resize_embeddings = True
diff --git a/tests/test_modeling_electra.py b/tests/test_modeling_electra.py
index be19f8d610db..065d59682693 100644
--- a/tests/test_modeling_electra.py
+++ b/tests/test_modeling_electra.py
@@ -369,10 +369,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
-    all_generative_model_classes = (ElectraForCausalLM,) if is_torch_available() else ()
-
-    fx_ready_model_classes = all_model_classes
-    fx_dynamic_ready_model_classes = all_model_classes
+    fx_compatible = True

     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py
index ef51c815e455..cd13be27bbc3 100644
--- a/tests/test_modeling_gpt2.py
+++ b/tests/test_modeling_gpt2.py
@@ -433,7 +433,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     )
     all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
     all_parallelizable_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
-    fx_ready_model_classes = all_model_classes
+    fx_compatible = True
     test_missing_keys = False
     test_model_parallel = True
diff --git a/tests/test_modeling_gpt_neo.py b/tests/test_modeling_gpt_neo.py
index a8e5b4babc57..b8f942ef1786 100644
--- a/tests/test_modeling_gpt_neo.py
+++ b/tests/test_modeling_gpt_neo.py
@@ -372,7 +372,7 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
         (GPTNeoModel, GPTNeoForCausalLM, GPTNeoForSequenceClassification) if is_torch_available() else ()
     )
     all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else ()
-    fx_ready_model_classes = all_model_classes
+    fx_compatible = True
     test_missing_keys = False
     test_pruning = False
     test_model_parallel = False
diff --git a/tests/test_modeling_gptj.py b/tests/test_modeling_gptj.py
index dd743b80d76a..d6b9f9292621 100644
--- a/tests/test_modeling_gptj.py
+++ b/tests/test_modeling_gptj.py
@@ -363,7 +363,7 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         else ()
     )
     all_generative_model_classes = (GPTJForCausalLM,) if is_torch_available() else ()
-    fx_ready_model_classes = all_model_classes
+    fx_compatible = True
     test_pruning = False
     test_missing_keys = False
     test_model_parallel = False
diff --git a/tests/test_modeling_megatron_bert.py b/tests/test_modeling_megatron_bert.py
index a7f47ddea322..7ac507988fe0 100644
--- a/tests/test_modeling_megatron_bert.py
+++ b/tests/test_modeling_megatron_bert.py
@@ -283,9 +283,7 @@ class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
-    fx_ready_model_classes = all_model_classes
-    fx_dynamic_ready_model_classes = all_model_classes
-
+    fx_compatible = True

     # test_resize_embeddings = False
     test_head_masking = False
diff --git a/tests/test_modeling_mobilebert.py b/tests/test_modeling_mobilebert.py
index 716714157a76..6ca14526a6dc 100644
--- a/tests/test_modeling_mobilebert.py
+++ b/tests/test_modeling_mobilebert.py
@@ -269,8 +269,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
-    fx_ready_model_classes = all_model_classes
-    fx_dynamic_ready_model_classes = all_model_classes
+    fx_compatible = True

     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
diff --git a/tests/test_modeling_roberta.py b/tests/test_modeling_roberta.py
index d0a8aab6b78e..1a55fda15292 100644
--- a/tests/test_modeling_roberta.py
+++ b/tests/test_modeling_roberta.py
@@ -356,6 +356,7 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
         else ()
     )
     all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else ()
+    fx_compatible = True

     def setUp(self):
         self.model_tester = RobertaModelTester(self)
diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py
index 575850aa9014..c0b5739bca7f 100644
--- a/tests/test_modeling_t5.py
+++ b/tests/test_modeling_t5.py
@@ -509,7 +509,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):

     all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
     all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
-    fx_ready_model_classes = all_model_classes
+    fx_compatible = True
     all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
     test_pruning = False
     test_torchscript = True
diff --git a/tests/test_modeling_xlnet.py b/tests/test_modeling_xlnet.py
index 5516b28e17e1..f4e90fbe7749 100644
--- a/tests/test_modeling_xlnet.py
+++ b/tests/test_modeling_xlnet.py
@@ -526,6 +526,7 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
     all_generative_model_classes = (
         (XLNetLMHeadModel,) if is_torch_available() else ()
     )  # TODO (PVP): Check other models whether language generation is also applicable
+    test_pruning = False

     # XLNet has 2 QA models -> need to manually set the correct labels for one of them here