FX tracing improvement #14321
Changes from all commits
0788780
f730608
a146048
aef4d00
f7a69eb
b0e3d96
636099b
505333d
74f74d7
e171ed2
6faf263
83aedfc
ee75d02
bdb1ef8
4f22de3
e566f16
a90d319
ae60baf
9ef5813
```diff
@@ -322,7 +322,7 @@
 HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
 
 # This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
-TORCH_FX_REQUIRED_VERSION = version.parse("1.9")
+TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
```
Member: Out of curiosity, is it possible to support many different versions, or are there breaking changes in torch.fx that mean we have to support one version at a time?

Member (Author): I can check for torch 1.9; the plan from now on is to support torch 1.10+, since FX became stable starting with that version (still need to validate that with the PyTorch team).

Member: Sure, sounds good to me.

And you probably need to change this line from …
```diff
 
 TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8")
 
 _is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False
```
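The discussion above is about gating FX support on the installed torch version. As a rough illustration only, a check along these lines could compare the installed version against the constant bumped in this hunk; the helper name `is_torch_fx_available` and the exact comparison are assumptions, not code taken from the diff.

```python
# Hedged sketch of a version gate for torch.fx support; the helper name and the
# exact comparison are illustrative assumptions, not code from this PR.
import importlib.util

from packaging import version

TORCH_FX_REQUIRED_VERSION = version.parse("1.10")


def is_torch_fx_available() -> bool:
    """Return True when the installed torch is recent enough for FX tracing."""
    if importlib.util.find_spec("torch") is None:
        return False
    import torch

    # Strip local build tags such as "+cu113" and compare only major.minor,
    # so that e.g. 1.10.1 still satisfies the 1.10 requirement.
    installed = version.parse(torch.__version__.split("+")[0])
    return (installed.major, installed.minor) >= (
        TORCH_FX_REQUIRED_VERSION.major,
        TORCH_FX_REQUIRED_VERSION.minor,
    )
```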
```diff
@@ -247,6 +247,27 @@ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
 
         return encoder_extended_attention_mask
 
+    def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask, device):
+        batch_size, seq_length = input_shape
+        seq_ids = torch.arange(seq_length, device=device)
+        causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
```

Suggested change (from a reviewer):

```diff
-        seq_ids = torch.arange(seq_length, device=device)
-        causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+        causal_mask = torch.tril(torch.ones(batch_size, seq_length, seq_length, dtype=torch.bool, device=device))
```
Unrelated to this PR, but constructing a triangular matrix should be a bit simpler IMO (unless I'm missing something) ...

Would be nice to keep the code as is for now, to make sure we don't accidentally break anything here. Could you also run T5's and Bart's SLOW tests to be sure nothing is broken with the attention mask?
michaelbenayoun marked this conversation as resolved.
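For context on the suggestion above, here is a small self-contained check (with made-up shapes) that the repeat/compare construction in the PR and the reviewer's torch.tril one-liner produce the same boolean causal mask:

```python
# Quick equivalence check between the PR's causal-mask construction and the
# reviewer's torch.tril suggestion; batch_size and seq_length are arbitrary.
import torch

batch_size, seq_length = 2, 5
device = torch.device("cpu")

seq_ids = torch.arange(seq_length, device=device)
mask_repeat = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]

mask_tril = torch.tril(torch.ones(batch_size, seq_length, seq_length, dtype=torch.bool, device=device))

assert torch.equal(mask_repeat, mask_tril)  # both are (batch, seq, seq) lower-triangular masks
```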
```diff
@@ -193,7 +193,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
         if self.scale_attn_weights:
-            attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
+            attn_weights = attn_weights / (value.size(-1) ** 0.5)
```
Member: Is this backwards compatible?

Contributor: In my opinion, this doesn't cause any problems. When we do tracing, Python values cause several problems.

Contributor: This change seems to make mixed-precision training of GPT-2 with the ONNX Runtime backend fail. Link to the reported issue: #11279.
```diff
 
         # Layer-wise attention scaling
         if self.scale_attn_by_inverse_layer_idx:
```
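On the backwards-compatibility question above: in eager mode the two scalings are numerically identical, since an integer raised to 0.5 already yields a Python float. A minimal check with made-up shapes is sketched below; whether the traced or exported graph behaves identically, e.g. under ONNX Runtime mixed precision, is a separate matter, as the follow-up comment notes.

```python
# Eager-mode check that dropping float() does not change the attention scaling;
# the tensor shapes here are arbitrary illustration values.
import torch

value = torch.randn(1, 4, 8, 64)          # (batch, heads, seq, head_dim)
attn_weights = torch.randn(1, 4, 8, 8)    # (batch, heads, seq, seq)

old = attn_weights / (float(value.size(-1)) ** 0.5)
new = attn_weights / (value.size(-1) ** 0.5)

assert torch.equal(old, new)  # int ** 0.5 already produces a Python float
```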
```diff
@@ -281,7 +281,7 @@ def _split_heads(self, tensor, num_heads, attn_head_size):
         Splits hidden_size dim into attn_head_size and num_heads
         """
         new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
-        tensor = tensor.view(*new_shape)
+        tensor = tensor.view(new_shape)
         return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
 
     def _merge_heads(self, tensor, num_heads, attn_head_size):
```
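Tensor.view accepts either unpacked integers or a single tuple/size, so passing new_shape directly is equivalent in eager mode; presumably the point for tracing is that it avoids star-unpacking (iterating) the shape object. A small check with made-up shapes:

```python
# Check that passing the shape tuple directly to view() matches star-unpacking it;
# the tensor and head sizes are arbitrary illustration values.
import torch

tensor = torch.randn(2, 8, 96)              # (batch, seq, hidden)
new_shape = tensor.size()[:-1] + (4, 24)    # split hidden into (heads, head_dim)

assert torch.equal(tensor.view(*new_shape), tensor.view(new_shape))
```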
```diff
@@ -915,7 +915,7 @@ def custom_forward(*inputs):
 
         hidden_states = self.ln_f(hidden_states)
 
-        hidden_states = hidden_states.view(*output_shape)
+        hidden_states = hidden_states.view(output_shape)
         # Add last hidden state
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
```
```diff
@@ -1410,7 +1410,7 @@ def forward(
                 f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
             )
 
-        pooled_logits = logits[range(batch_size), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
 
         loss = None
         if labels is not None:
```
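In eager mode, indexing with a Python range and with torch.arange selects the same rows; presumably the point of the change is that torch.arange is recorded as an op in the traced graph (and lands on the right device), whereas a bare Python range is not. A quick check with made-up sizes:

```python
# Check that range() and torch.arange() give the same pooled logits in eager mode;
# batch size, sequence length and vocab size are arbitrary illustration values.
import torch

batch_size, seq_len, vocab_size = 3, 7, 11
logits = torch.randn(batch_size, seq_len, vocab_size)
sequence_lengths = torch.tensor([2, 5, 6])  # last non-padding position per example

old = logits[range(batch_size), sequence_lengths]
new = logits[torch.arange(batch_size), sequence_lengths]

assert torch.equal(old, new)  # both have shape (batch_size, vocab_size)
```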
Ideally this would use the logger.

I followed what was done in the script, but can definitely change that to logger if needed.
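For reference, the usual pattern in the library's scripts for switching a bare print to the logger looks roughly like the sketch below (assuming the transformers logging utilities; the exact call site being discussed is not shown in this excerpt):

```python
# Hedged sketch of the logger pattern typically used in transformers scripts;
# the message is a placeholder, not taken from this PR.
from transformers.utils import logging

logger = logging.get_logger(__name__)
logger.info("This message would previously have been print()-ed.")
```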