17 changes: 17 additions & 0 deletions src/transformers/commands/add_new_model_like.py
@@ -1189,6 +1189,16 @@ def create_new_model_like(
        if "tokenization" not in str(f) and "processor" not in str(f) and "feature_extraction" not in str(f)
    ]

    def disable_fx_test(filename: Path) -> bool:
        with open(filename) as fp:
            content = fp.read()
        new_content = re.sub(r"fx_compatible\s*=\s*True", "fx_compatible = False", content)
        with open(filename, "w") as fp:
            fp.write(new_content)
        return content != new_content

    disabled_fx_test = False

    for test_file in files_to_adapt:
        new_test_file_name = test_file.name.replace(
            old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased

@@ -1201,6 +1211,13 @@ def create_new_model_like(
            dest_file=dest_file,
            add_copied_from=False,
        )
        disabled_fx_test = disabled_fx_test | disable_fx_test(dest_file)

    if disabled_fx_test:
        print(
            "The tests for symbolic tracing with torch.fx were disabled, you can add those once symbolic tracing works "
            "for your new model."
        )
Comment on lines +1217 to +1220

Member: Ideally this would use the logger.

Member (Author): I followed what was done in the script, but can definitely change that to logger if needed.
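For reference, a minimal sketch of what the reviewer suggests, assuming the standard transformers.utils.logging helper is available in the command module (the exact wiring in the script may differ):

from transformers.utils import logging

logger = logging.get_logger(__name__)

if disabled_fx_test:
    # Emit the notice through the library logger instead of a bare print().
    logger.warning(
        "The tests for symbolic tracing with torch.fx were disabled, you can add those once "
        "symbolic tracing works for your new model."
    )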


    # 4. Add model to auto classes
    add_model_to_auto_classes(old_model_patterns, new_model_patterns, model_classes)
2 changes: 1 addition & 1 deletion src/transformers/file_utils.py
@@ -322,7 +322,7 @@
HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"

# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
-TORCH_FX_REQUIRED_VERSION = version.parse("1.9")
+TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
Member: Out of curiosity, is it possible to support many different versions, or are there breaking changes in torch.fx that mean we have to support one version at a time?

Member (Author): I can check for torch 1.9; the plan from now on is to support torch 1.10+, as fx became stable starting at this version (still need to validate that with the PyTorch team).

Member: Sure, sounds good to me.

Comment: And you probably need to change this line from == to >=.
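For context on the == versus >= remark, a minimal sketch of the kind of availability gate being discussed; the helper name and exact logic here are illustrative assumptions, not a copy of file_utils.py:

import importlib.metadata

from packaging import version

TORCH_FX_REQUIRED_VERSION = version.parse("1.10")


def is_torch_fx_available() -> bool:
    # Illustrative gate: compare the installed torch version against the minimum,
    # using >= rather than == so that newer torch releases keep the feature enabled.
    try:
        torch_version = version.parse(importlib.metadata.version("torch"))
    except importlib.metadata.PackageNotFoundError:
        return False
    return torch_version >= TORCH_FX_REQUIRED_VERSION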

TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8")

_is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False
46 changes: 25 additions & 21 deletions src/transformers/modeling_utils.py
@@ -247,6 +247,27 @@ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:

        return encoder_extended_attention_mask

    def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask, device):
        batch_size, seq_length = input_shape
        seq_ids = torch.arange(seq_length, device=device)
        causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
Comment on lines +252 to +253

Contributor (suggested change):
-        seq_ids = torch.arange(seq_length, device=device)
-        causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+        causal_mask = torch.tril(torch.ones(batch_size, seq_length, seq_length, dtype=torch.bool, device=device))

Unrelated to this PR, but constructing a triangular matrix should be a bit simpler IMO (unless I'm missing something).

Contributor: Would be nice to keep the code as is for now, to make sure we don't accidentally break anything here. Could you also run T5's and Bart's SLOW tests to be sure nothing is broken with the attention mask?

        # in case past_key_values are used we need to add a prefix ones mask to the causal mask
        # causal and attention masks must have same type with pytorch version < 1.3
        causal_mask = causal_mask.to(attention_mask.dtype)

        if causal_mask.shape[1] < attention_mask.shape[1]:
            prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
            causal_mask = torch.cat(
                [
                    torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
                    causal_mask,
                ],
                axis=-1,
            )

        extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
        return extended_attention_mask
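As a side note on the torch.tril suggestion in the thread above, a quick standalone check (not part of the PR) that both constructions yield the same boolean causal mask:

import torch

batch_size, seq_length, device = 2, 5, "cpu"

# arange-based construction kept in the PR
seq_ids = torch.arange(seq_length, device=device)
mask_arange = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]

# torch.tril construction from the review suggestion
mask_tril = torch.tril(torch.ones(batch_size, seq_length, seq_length, dtype=torch.bool, device=device))

assert torch.equal(mask_arange, mask_tril)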

    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

@@ -271,26 +292,9 @@ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder:
-                batch_size, seq_length = input_shape
-                seq_ids = torch.arange(seq_length, device=device)
-                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
-                # in case past_key_values are used we need to add a prefix ones mask to the causal mask
-                # causal and attention masks must have same type with pytorch version < 1.3
-                causal_mask = causal_mask.to(attention_mask.dtype)
-
-                if causal_mask.shape[1] < attention_mask.shape[1]:
-                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
-                    causal_mask = torch.cat(
-                        [
-                            torch.ones(
-                                (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype
-                            ),
-                            causal_mask,
-                        ],
-                        axis=-1,
-                    )
-
-                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+                extended_attention_mask = self.create_extended_attention_mask_for_decoder(
+                    input_shape, attention_mask, device
+                )
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
@@ -1835,7 +1839,7 @@ def __init__(self, nf, nx):
    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
-        x = x.view(*size_out)
+        x = x.view(size_out)
        return x
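On the recurring view(*shape) to view(shape) change in this PR: Tensor.view accepts a single sequence of sizes as well as unpacked integers, and passing the sequence directly avoids star-unpacking a traced value, which torch.fx symbolic tracing generally cannot handle. A small sanity check on concrete tensors (standalone, not part of the diff):

import torch

x = torch.randn(2, 3, 8)
size_out = x.size()[:-1] + (4, 2)  # a sequence of sizes: (2, 3, 4, 2)

# Both calls reshape to the same result on eager tensors; only the second form
# stays tracing-friendly, since size() returns a proxy under torch.fx.
assert torch.equal(x.view(*size_out), x.view(size_out))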


2 changes: 1 addition & 1 deletion src/transformers/models/albert/modeling_albert.py
@@ -293,7 +293,7 @@ def __init__(self, config):
    # Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores
    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def prune_heads(self, heads):
4 changes: 2 additions & 2 deletions src/transformers/models/bert/modeling_bert.py
@@ -252,7 +252,7 @@ def __init__(self, config, position_embedding_type=None):

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(

@@ -341,7 +341,7 @@ def forward(

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
4 changes: 2 additions & 2 deletions src/transformers/models/electra/modeling_electra.py
@@ -245,7 +245,7 @@ def __init__(self, config, position_embedding_type=None):

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(

@@ -334,7 +334,7 @@ def forward(

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
8 changes: 4 additions & 4 deletions src/transformers/models/gpt2/modeling_gpt2.py
@@ -193,7 +193,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
-            attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
+            attn_weights = attn_weights / (value.size(-1) ** 0.5)
Member: Is this backwards compatible?

Contributor: In my opinion, this doesn't cause any problems. When we do tracing, Python values cause several problems, and I don't think there is any reason to turn this into a Python value.

Contributor (@JingyaHuang, Jul 4, 2022): This change seems to cause the failure in mixed-precision training of GPT-2 with the ONNX Runtime backend. Link to the reported issue: #11279.
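To make the backwards-compatibility question concrete, a standalone check (not from the PR) that the two scalings are numerically identical in eager mode, where value.size(-1) is a plain Python int; the difference only matters under torch.fx tracing, where size() returns a proxy that should not be forced into a concrete float:

import torch

value = torch.randn(1, 4, 7, 64)
attn_weights = torch.randn(1, 4, 7, 7)

# Old scaling: cast the head dimension to a Python float before taking the square root.
old = attn_weights / (float(value.size(-1)) ** 0.5)
# New scaling: use the int returned by size() directly; ** 0.5 still yields a float.
new = attn_weights / (value.size(-1) ** 0.5)

assert torch.equal(old, new)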


        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:

@@ -281,7 +281,7 @@ def _split_heads(self, tensor, num_heads, attn_head_size):
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
-        tensor = tensor.view(*new_shape)
+        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def _merge_heads(self, tensor, num_heads, attn_head_size):

@@ -915,7 +915,7 @@ def custom_forward(*inputs):

        hidden_states = self.ln_f(hidden_states)

-        hidden_states = hidden_states.view(*output_shape)
+        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

@@ -1410,7 +1410,7 @@ def forward(
                f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

-        pooled_logits = logits[range(batch_size), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]

        loss = None
        if labels is not None:
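On the pooled_logits change above, indexing with a torch.arange tensor selects one position per example just like the previous Python range, while staying traceable with torch.fx and keeping the index tensor on the model's device. A standalone illustration (using logits.device in place of the module's self.device):

import torch

logits = torch.randn(3, 7, 5)               # (batch_size, seq_length, num_labels)
sequence_lengths = torch.tensor([6, 2, 4])  # last non-padding position per example

batch_size = logits.shape[0]

# Advanced indexing picks one sequence position per batch row,
# matching what logits[range(batch_size), sequence_lengths] did before.
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

assert pooled_logits.shape == (3, 5)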
6 changes: 3 additions & 3 deletions src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -173,7 +173,7 @@ def _split_heads(self, tensor, num_heads, attn_head_size):
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
-        tensor = tensor.view(*new_shape)
+        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def _merge_heads(self, tensor, num_heads, attn_head_size):

@@ -637,7 +637,7 @@ def custom_forward(*inputs):

        hidden_states = self.ln_f(hidden_states)

-        hidden_states = hidden_states.view(*output_shape)
+        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

@@ -891,7 +891,7 @@ def forward(
                f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

-        pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]

        loss = None
        if labels is not None:
6 changes: 3 additions & 3 deletions src/transformers/models/gptj/modeling_gptj.py
@@ -107,7 +107,7 @@ def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary):
        Splits hidden dim into attn_head_size and num_attention_heads
        """
        new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
-        tensor = tensor.view(*new_shape)
+        tensor = tensor.view(new_shape)
        if rotary:
            return tensor
        if len(tensor.shape) == 5:

@@ -665,7 +665,7 @@ def custom_forward(*inputs):

        hidden_states = self.ln_f(hidden_states)

-        hidden_states = hidden_states.view(*output_shape)
+        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

@@ -945,7 +945,7 @@ def forward(
                f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

-        pooled_logits = logits[range(batch_size), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]

        loss = None
        if labels is not None:
4 changes: 2 additions & 2 deletions src/transformers/models/layoutlm/modeling_layoutlm.py
@@ -160,7 +160,7 @@ def __init__(self, config, position_embedding_type=None):

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(

@@ -249,7 +249,7 @@ def forward(

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
@@ -223,7 +223,7 @@ def __init__(self, config, position_embedding_type=None):

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(

@@ -312,7 +312,7 @@ def forward(

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
4 changes: 2 additions & 2 deletions src/transformers/models/mobilebert/modeling_mobilebert.py
@@ -237,7 +237,7 @@ def __init__(self, config):

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(

@@ -274,7 +274,7 @@ def forward(
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        return outputs
4 changes: 2 additions & 2 deletions src/transformers/models/realm/modeling_realm.py
@@ -260,7 +260,7 @@ def __init__(self, config, position_embedding_type=None):

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(

@@ -349,7 +349,7 @@ def forward(

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
4 changes: 2 additions & 2 deletions src/transformers/models/roberta/modeling_roberta.py
@@ -187,7 +187,7 @@ def __init__(self, config, position_embedding_type=None):

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(

@@ -276,7 +276,7 @@ def forward(

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
4 changes: 2 additions & 2 deletions src/transformers/models/splinter/modeling_splinter.py
@@ -127,7 +127,7 @@ def __init__(self, config, position_embedding_type=None):

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(

@@ -216,7 +216,7 @@ def forward(

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
@@ -181,7 +181,7 @@ def __init__(self, config, position_embedding_type=None):

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(

@@ -270,7 +270,7 @@ def forward(

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)