From 76d76998430dd73edc63c150e9ed03604de24d91 Mon Sep 17 00:00:00 2001 From: Abdul Fatir Ansari Date: Sun, 21 Apr 2024 16:58:38 +0000 Subject: [PATCH 01/14] Initial commit --- src/transformers/models/t5/modeling_t5.py | 198 +++++++++++++++++++++- 1 file changed, 190 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 224769fdfefd..b9cb9ddc1b4a 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -25,6 +25,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -572,10 +578,157 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): return outputs +class T5SdpaAttention(T5Attention): + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + if output_attentions or layer_head_mask is not None: + logger.warning_once( + "T5Model is using T5SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" + ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + mask=mask, + key_value_states=key_value_states, + position_bias=position_bias, + past_key_value=past_key_value, + layer_head_mask=layer_head_mask, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + if len(past_key_value) != 2: + raise ValueError( + f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" + ) + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def unshape(states): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + elif past_key_value.shape[2] != key_value_states.shape[1]: + # checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), + device=query_states.device, + dtype=query_states.dtype, + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(real_seq_length, key_length, device=query_states.device) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + attn_output = unshape( + torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=position_bias_masked, + dropout_p=self.dropout if self.training else 0.0, + scale=1.0, + ) + ) + attn_output = self.o(attn_output) + + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + return outputs + + +T5_ATTENTION_CLASSES = { + "eager": T5Attention, + "sdpa": T5SdpaAttention, +} + + class T5LayerSelfAttention(nn.Module): def __init__(self, config, has_relative_attention_bias=False): super().__init__() - self.SelfAttention = 
T5Attention(config, has_relative_attention_bias=has_relative_attention_bias) + self.SelfAttention = T5_ATTENTION_CLASSES[config._attn_implementation]( + config, has_relative_attention_bias=has_relative_attention_bias + ) self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -607,7 +760,9 @@ def forward( class T5LayerCrossAttention(nn.Module): def __init__(self, config): super().__init__() - self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False) + self.EncDecAttention = T5_ATTENTION_CLASSES[config._attn_implementation]( + config, has_relative_attention_bias=False + ) self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -795,6 +950,7 @@ class T5PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["T5Block"] _keep_in_fp32_modules = ["wo"] + _supports_sdpa = True @property def dummy_inputs(self): @@ -908,6 +1064,7 @@ def __init__(self, config, embed_tokens=None): ) self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) + self._use_sdpa = config._attn_implementation == "sdpa" # Initialize weights and apply final processing self.post_init() @@ -1012,9 +1169,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) batch_size, seq_length = input_shape - - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if use_cache is True: if not self.is_decoder: @@ -1025,11 +1180,27 @@ def forward( past_key_values = [None] * len(self.block) if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + attention_mask = torch.ones(batch_size, seq_length, device=inputs_embeds.device) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self._use_sdpa and head_mask is None and not output_attentions: + extended_attention_mask = ( + _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + if self.is_decoder + else _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + ) + else: + extended_attention_mask = ( + _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length) + if self.is_decoder + else _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + ) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -1040,7 +1211,18 @@ def forward( encoder_attention_mask = torch.ones( encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long ) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + if self._use_sdpa and cross_attn_head_mask is None and not output_attentions: + encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + else: + encoder_extended_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) else: encoder_extended_attention_mask = None From bf3fb94f0927189947596a94c65331ccb394b7d2 Mon Sep 17 00:00:00 2001 From: Abdul Fatir Ansari Date: Sat, 1 Jun 2024 13:17:51 +0000 Subject: [PATCH 02/14] Use contiguous for attn_mask --- src/transformers/models/t5/modeling_t5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index b9cb9ddc1b4a..6c84b53a14c7 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -704,7 +704,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): query_states, key_states, value_states, - attn_mask=position_bias_masked, + attn_mask=position_bias_masked.contiguous(), dropout_p=self.dropout if self.training else 0.0, scale=1.0, ) From 1cb1af674da5bf359927124d13ec3aa741757078 Mon Sep 17 00:00:00 2001 From: Abdul Fatir Ansari Date: Sat, 1 Jun 2024 13:34:03 +0000 Subject: [PATCH 03/14] Address comment about duplicated code --- src/transformers/models/t5/modeling_t5.py | 142 ++++++++++------------ 1 file changed, 61 insertions(+), 81 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 6c84b53a14c7..09e03aee00e1 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -448,6 +448,41 @@ def compute_bias(self, query_length, key_length, device=None): values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) return values + def _shape(self, states, batch_size): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def _unshape(self, states, batch_size): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def _project(self, hidden_states, proj_layer, key_value_states, past_key_value, batch_size): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, 
dim_per_head) + hidden_states = self._shape(proj_layer(hidden_states), batch_size) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = self._shape(proj_layer(key_value_states), batch_size) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + elif past_key_value.shape[2] != key_value_states.shape[1]: + # checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = self._shape(proj_layer(key_value_states), batch_size) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + def forward( self, hidden_states, @@ -479,49 +514,16 @@ def forward( key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + query_states = self._shape( + self.q(hidden_states), batch_size + ) # (batch_size, n_heads, seq_length, dim_per_head) # get key/value states - key_states = project( + key_states = self._project( hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None ) - value_states = project( + value_states = self._project( hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None ) @@ -567,7 +569,9 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = self._unshape( + torch.matmul(attn_weights, value_states), batch_size + ) # (batch_size, seq_length, dim) attn_output = self.o(attn_output) present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None @@ -626,50 +630,25 @@ def forward( key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - def 
shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + query_states = self._shape( + self.q(hidden_states), batch_size + ) # (batch_size, n_heads, seq_length, dim_per_head) # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + key_states = self._project( + hidden_states, + self.k, + key_value_states, + past_key_value[0] if past_key_value is not None else None, + batch_size, ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + value_states = self._project( + hidden_states, + self.v, + key_value_states, + past_key_value[1] if past_key_value is not None else None, + batch_size, ) if position_bias is None: @@ -699,7 +678,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): else: position_bias_masked = position_bias - attn_output = unshape( + attn_output = self._unshape( torch.nn.functional.scaled_dot_product_attention( query_states, key_states, @@ -707,7 +686,8 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): attn_mask=position_bias_masked.contiguous(), dropout_p=self.dropout if self.training else 0.0, scale=1.0, - ) + ), + batch_size, ) attn_output = self.o(attn_output) From ff5142b4f169fea6cbc3820b40966ed28ab44c9d Mon Sep 17 00:00:00 2001 From: Abdul Fatir Ansari Date: Sat, 1 Jun 2024 13:36:40 +0000 Subject: [PATCH 04/14] Fix --- src/transformers/models/t5/modeling_t5.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 09e03aee00e1..4220a5aa6f0b 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -521,10 +521,10 @@ def forward( # get key/value states key_states = self._project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None, batch_size ) value_states = self._project( - hidden_states, self.v, key_value_states, 
past_key_value[1] if past_key_value is not None else None + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None, batch_size ) # compute scores From 53d2fbb521f2088b5e4533f5d2f4b973623cff6f Mon Sep 17 00:00:00 2001 From: Abdul Fatir Ansari Date: Sat, 1 Jun 2024 13:44:00 +0000 Subject: [PATCH 05/14] Fix style --- src/transformers/models/t5/modeling_t5.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 4220a5aa6f0b..3f9dc896487b 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -521,10 +521,18 @@ def forward( # get key/value states key_states = self._project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None, batch_size + hidden_states, + self.k, + key_value_states, + past_key_value[0] if past_key_value is not None else None, + batch_size, ) value_states = self._project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None, batch_size + hidden_states, + self.v, + key_value_states, + past_key_value[1] if past_key_value is not None else None, + batch_size, ) # compute scores From 151751f98e30ea2f059eeaa921847b47b28ef343 Mon Sep 17 00:00:00 2001 From: Abdul Fatir Ansari Date: Mon, 3 Jun 2024 12:32:26 +0000 Subject: [PATCH 06/14] Fix stride issue --- src/transformers/models/t5/modeling_t5.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 3f9dc896487b..bf0163490f48 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -686,12 +686,16 @@ def forward( else: position_bias_masked = position_bias + # spda kernels require tensors to have stride=1 in the last dimension + # .contiguous() does not behave correctly for tensors with singleton dimensions + # .clone(memory_format=torch.contiguous_format) is a workaround + position_bias_masked = position_bias_masked.clone(memory_format=torch.contiguous_format) attn_output = self._unshape( torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, - attn_mask=position_bias_masked.contiguous(), + attn_mask=position_bias_masked, dropout_p=self.dropout if self.training else 0.0, scale=1.0, ), From a54cea430f6015dc7c6fd8a7a965bbb744f08e1f Mon Sep 17 00:00:00 2001 From: Abdul Fatir Ansari Date: Sat, 15 Jun 2024 11:37:36 +0000 Subject: [PATCH 07/14] Add link to issue --- src/transformers/models/t5/modeling_t5.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index bf0163490f48..40b0f118edff 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -688,8 +688,12 @@ def forward( # spda kernels require tensors to have stride=1 in the last dimension # .contiguous() does not behave correctly for tensors with singleton dimensions - # .clone(memory_format=torch.contiguous_format) is a workaround - position_bias_masked = position_bias_masked.clone(memory_format=torch.contiguous_format) + # the following is a workaround as suggested in https://github.com/pytorch/pytorch/issues/127523 + if position_bias_masked.stride(-1) != 1: + position_bias_masked = torch.empty_like(position_bias_masked, 
memory_format=torch.contiguous_format).copy_( + position_bias_masked + ) + attn_output = self._unshape( torch.nn.functional.scaled_dot_product_attention( query_states, From e7499ace2db23e768d656e0b17dfc4350a46153d Mon Sep 17 00:00:00 2001 From: Abdul Fatir Ansari Date: Sat, 15 Jun 2024 12:01:53 +0000 Subject: [PATCH 08/14] Fix default attention_mask --- src/transformers/models/t5/modeling_t5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 40b0f118edff..60f4d446668f 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1176,7 +1176,7 @@ def forward( past_key_values = [None] * len(self.block) if attention_mask is None: - attention_mask = torch.ones(batch_size, seq_length, device=inputs_embeds.device) + attention_mask = torch.ones(batch_size, past_key_values_length + seq_length, device=inputs_embeds.device) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. From 43e8035ddaa1e4f839a10ca431f02076981c365e Mon Sep 17 00:00:00 2001 From: Abdul Fatir Ansari Date: Sat, 15 Jun 2024 14:10:03 +0000 Subject: [PATCH 09/14] Fix copies --- .../models/longt5/modeling_longt5.py | 96 +++--- src/transformers/models/mt5/modeling_mt5.py | 277 ++++++++++++++---- .../models/pix2struct/modeling_pix2struct.py | 2 + .../models/pop2piano/modeling_pop2piano.py | 234 ++++++++++++--- .../modeling_switch_transformers.py | 96 +++--- src/transformers/models/udop/modeling_udop.py | 96 +++--- 6 files changed, 586 insertions(+), 215 deletions(-) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index b2a6ed11ca57..7f8d02cd60da 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -420,6 +420,41 @@ def compute_bias(self, query_length, key_length, device=None): values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) return values + def _shape(self, states, batch_size): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def _unshape(self, states, batch_size): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def _project(self, hidden_states, proj_layer, key_value_states, past_key_value, batch_size): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = self._shape(proj_layer(hidden_states), batch_size) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = self._shape(proj_layer(key_value_states), batch_size) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + elif past_key_value.shape[2] != key_value_states.shape[1]: + # checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = self._shape(proj_layer(key_value_states), batch_size) + else: + # cross-attn + hidden_states = 
past_key_value + return hidden_states + def forward( self, hidden_states, @@ -451,50 +486,25 @@ def forward( key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + query_states = self._shape( + self.q(hidden_states), batch_size + ) # (batch_size, n_heads, seq_length, dim_per_head) # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + key_states = self._project( + hidden_states, + self.k, + key_value_states, + past_key_value[0] if past_key_value is not None else None, + batch_size, ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + value_states = self._project( + hidden_states, + self.v, + key_value_states, + past_key_value[1] if past_key_value is not None else None, + batch_size, ) # compute scores @@ -539,7 +549,9 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = self._unshape( + torch.matmul(attn_weights, value_states), batch_size + ) # (batch_size, seq_length, dim) attn_output = self.o(attn_output) present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None @@ -1007,6 +1019,7 @@ def unshape(states): # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->LongT5 class LongT5LayerSelfAttention(nn.Module): + # Ignore copy def __init__(self, config, has_relative_attention_bias=False): super().__init__() self.SelfAttention = LongT5Attention(config, has_relative_attention_bias=has_relative_attention_bias) @@ -1104,6 +1117,7 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->LongT5 class LongT5LayerCrossAttention(nn.Module): + # Ignore copy def __init__(self, config): super().__init__() self.EncDecAttention = LongT5Attention(config, has_relative_attention_bias=False) diff --git 
a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 1336b919618f..e2c4a1c10049 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -25,6 +25,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -317,6 +323,41 @@ def compute_bias(self, query_length, key_length, device=None): values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) return values + def _shape(self, states, batch_size): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def _unshape(self, states, batch_size): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def _project(self, hidden_states, proj_layer, key_value_states, past_key_value, batch_size): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = self._shape(proj_layer(hidden_states), batch_size) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = self._shape(proj_layer(key_value_states), batch_size) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + elif past_key_value.shape[2] != key_value_states.shape[1]: + # checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = self._shape(proj_layer(key_value_states), batch_size) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + def forward( self, hidden_states, @@ -348,50 +389,25 @@ def forward( key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, 
dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + query_states = self._shape( + self.q(hidden_states), batch_size + ) # (batch_size, n_heads, seq_length, dim_per_head) # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + key_states = self._project( + hidden_states, + self.k, + key_value_states, + past_key_value[0] if past_key_value is not None else None, + batch_size, ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + value_states = self._project( + hidden_states, + self.v, + key_value_states, + past_key_value[1] if past_key_value is not None else None, + batch_size, ) # compute scores @@ -436,7 +452,9 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = self._unshape( + torch.matmul(attn_weights, value_states), batch_size + ) # (batch_size, seq_length, dim) attn_output = self.o(attn_output) present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None @@ -447,11 +465,143 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): return outputs +# Copied from transformers.models.t5.modeling_t5.T5SdpaAttention with T5->MT5 +class MT5SdpaAttention(MT5Attention): + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + if output_attentions or layer_head_mask is not None: + logger.warning_once( + "MT5Model is using MT5SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" + ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + mask=mask, + key_value_states=key_value_states, + position_bias=position_bias, + past_key_value=past_key_value, + layer_head_mask=layer_head_mask, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + if len(past_key_value) != 2: + raise ValueError( + f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" + ) + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + # get query states + query_states = self._shape( + self.q(hidden_states), batch_size + ) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = self._project( + hidden_states, + self.k, + key_value_states, + past_key_value[0] if past_key_value is not None else None, + batch_size, + ) + value_states = self._project( + hidden_states, + self.v, + key_value_states, + past_key_value[1] if past_key_value is not None else None, + batch_size, + ) + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), + device=query_states.device, + dtype=query_states.dtype, + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(real_seq_length, key_length, device=query_states.device) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + # spda kernels require tensors to have stride=1 in the last dimension + # .contiguous() does not behave correctly for tensors with singleton dimensions + # the following is a workaround as suggested in https://github.com/pytorch/pytorch/issues/127523 + if position_bias_masked.stride(-1) != 1: + position_bias_masked = torch.empty_like(position_bias_masked, memory_format=torch.contiguous_format).copy_( + position_bias_masked + ) + + attn_output = self._unshape( + torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=position_bias_masked, + dropout_p=self.dropout if self.training else 0.0, + scale=1.0, + ), + batch_size, + ) + attn_output = self.o(attn_output) + + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + return outputs + + +MT5_ATTENTION_CLASSES = { + "eager": MT5Attention, + "sdpa": MT5SdpaAttention, +} + + # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->MT5 class MT5LayerSelfAttention(nn.Module): def __init__(self, config, has_relative_attention_bias=False): super().__init__() - self.SelfAttention = MT5Attention(config, has_relative_attention_bias=has_relative_attention_bias) + self.SelfAttention = MT5_ATTENTION_CLASSES[config._attn_implementation]( + config, has_relative_attention_bias=has_relative_attention_bias + ) self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -484,7 +634,9 @@ def forward( class MT5LayerCrossAttention(nn.Module): def __init__(self, config): super().__init__() - self.EncDecAttention = MT5Attention(config, has_relative_attention_bias=False) + self.EncDecAttention = MT5_ATTENTION_CLASSES[config._attn_implementation]( + config, has_relative_attention_bias=False + ) 
self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -781,6 +933,7 @@ class MT5PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MT5Block"] _keep_in_fp32_modules = ["wo"] + _supports_sdpa = True @property def dummy_inputs(self): @@ -895,6 +1048,7 @@ def __init__(self, config, embed_tokens=None): ) self.final_layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) + self._use_sdpa = config._attn_implementation == "sdpa" # Initialize weights and apply final processing self.post_init() @@ -999,9 +1153,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) batch_size, seq_length = input_shape - - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if use_cache is True: if not self.is_decoder: @@ -1012,11 +1164,27 @@ def forward( past_key_values = [None] * len(self.block) if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + attention_mask = torch.ones(batch_size, past_key_values_length + seq_length, device=inputs_embeds.device) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self._use_sdpa and head_mask is None and not output_attentions: + extended_attention_mask = ( + _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + if self.is_decoder + else _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + ) + else: + extended_attention_mask = ( + _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length) + if self.is_decoder + else _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + ) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -1027,7 +1195,18 @@ def forward( encoder_attention_mask = torch.ones( encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long ) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + if self._use_sdpa and cross_attn_head_mask is None and not output_attentions: + encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + else: + encoder_extended_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) else: encoder_extended_attention_mask = None diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 94d882c80566..58e254204d5c 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -892,6 +892,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with 
T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,self.SelfAttention->self.attention,config.d_model->config.hidden_size class Pix2StructTextLayerSelfAttention(nn.Module): + # Ignore copy def __init__(self, config, has_relative_attention_bias=False): super().__init__() self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=has_relative_attention_bias) @@ -925,6 +926,7 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,self.EncDecAttention->self.attention,config.d_model->config.hidden_size class Pix2StructTextLayerCrossAttention(nn.Module): + # Ignore copy def __init__(self, config): super().__init__() self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=False) diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index c769cff3c454..18e0b7ddaab0 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -348,6 +348,41 @@ def compute_bias(self, query_length, key_length, device=None): values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) return values + def _shape(self, states, batch_size): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def _unshape(self, states, batch_size): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def _project(self, hidden_states, proj_layer, key_value_states, past_key_value, batch_size): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = self._shape(proj_layer(hidden_states), batch_size) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = self._shape(proj_layer(key_value_states), batch_size) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + elif past_key_value.shape[2] != key_value_states.shape[1]: + # checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = self._shape(proj_layer(key_value_states), batch_size) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + def forward( self, hidden_states, @@ -379,50 +414,25 @@ def forward( key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = 
shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + query_states = self._shape( + self.q(hidden_states), batch_size + ) # (batch_size, n_heads, seq_length, dim_per_head) # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + key_states = self._project( + hidden_states, + self.k, + key_value_states, + past_key_value[0] if past_key_value is not None else None, + batch_size, ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + value_states = self._project( + hidden_states, + self.v, + key_value_states, + past_key_value[1] if past_key_value is not None else None, + batch_size, ) # compute scores @@ -467,7 +477,9 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = self._unshape( + torch.matmul(attn_weights, value_states), batch_size + ) # (batch_size, seq_length, dim) attn_output = self.o(attn_output) present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None @@ -478,11 +490,143 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): return outputs +# Copied from transformers.models.t5.modeling_t5.T5SdpaAttention with T5->Pop2Piano +class Pop2PianoSdpaAttention(Pop2PianoAttention): + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + if output_attentions or layer_head_mask is not None: + logger.warning_once( + "Pop2PianoModel is using Pop2PianoSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" + ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states, + mask=mask, + key_value_states=key_value_states, + position_bias=position_bias, + past_key_value=past_key_value, + layer_head_mask=layer_head_mask, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + if len(past_key_value) != 2: + raise ValueError( + f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" + ) + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + # get query states + query_states = self._shape( + self.q(hidden_states), batch_size + ) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = self._project( + hidden_states, + self.k, + key_value_states, + past_key_value[0] if past_key_value is not None else None, + batch_size, + ) + value_states = self._project( + hidden_states, + self.v, + key_value_states, + past_key_value[1] if past_key_value is not None else None, + batch_size, + ) + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), + device=query_states.device, + dtype=query_states.dtype, + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(real_seq_length, key_length, device=query_states.device) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + # spda kernels require tensors to have stride=1 in the last dimension + # .contiguous() does not behave correctly for tensors with singleton dimensions + # the following is a workaround as suggested in https://github.com/pytorch/pytorch/issues/127523 + if position_bias_masked.stride(-1) != 1: + position_bias_masked = torch.empty_like(position_bias_masked, memory_format=torch.contiguous_format).copy_( + position_bias_masked + ) + + attn_output = self._unshape( + torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=position_bias_masked, + dropout_p=self.dropout if self.training else 0.0, + scale=1.0, + ), + batch_size, + ) + attn_output = self.o(attn_output) + + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + return outputs + + +Pop2Piano_ATTENTION_CLASSES = { + "eager": Pop2PianoAttention, + "sdpa": Pop2PianoSdpaAttention, +} + + # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Pop2Piano,t5->pop2piano class Pop2PianoLayerSelfAttention(nn.Module): def __init__(self, 
config, has_relative_attention_bias=False): super().__init__() - self.SelfAttention = Pop2PianoAttention(config, has_relative_attention_bias=has_relative_attention_bias) + self.SelfAttention = Pop2Piano_ATTENTION_CLASSES[config._attn_implementation]( + config, has_relative_attention_bias=has_relative_attention_bias + ) self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -515,7 +659,9 @@ def forward( class Pop2PianoLayerCrossAttention(nn.Module): def __init__(self, config): super().__init__() - self.EncDecAttention = Pop2PianoAttention(config, has_relative_attention_bias=False) + self.EncDecAttention = Pop2Piano_ATTENTION_CLASSES[config._attn_implementation]( + config, has_relative_attention_bias=False + ) self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -685,6 +831,7 @@ class Pop2PianoPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Pop2PianoBlock"] _keep_in_fp32_modules = ["wo"] + _supports_sdpa = True def _init_weights(self, module): """Initialize the weights""" @@ -772,6 +919,7 @@ def __init__(self, config, embed_tokens=None): ) self.final_layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) + self._use_sdpa = config._attn_implementation == "sdpa" # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 3701b30a227f..c69e1c199e43 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -450,6 +450,41 @@ def compute_bias(self, query_length, key_length, device=None): values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) return values + def _shape(self, states, batch_size): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def _unshape(self, states, batch_size): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def _project(self, hidden_states, proj_layer, key_value_states, past_key_value, batch_size): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = self._shape(proj_layer(hidden_states), batch_size) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = self._shape(proj_layer(key_value_states), batch_size) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + elif past_key_value.shape[2] != key_value_states.shape[1]: + # checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = self._shape(proj_layer(key_value_states), batch_size) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + def forward( self, hidden_states, @@ -481,50 +516,25 @@ def 
diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
index 3701b30a227f..c69e1c199e43 100644
--- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py
+++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
@@ -450,6 +450,41 @@ def compute_bias(self, query_length, key_length, device=None):
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         return values
 
+    def _shape(self, states, batch_size):
+        """projection"""
+        return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+    def _unshape(self, states, batch_size):
+        """reshape"""
+        return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
+
+    def _project(self, hidden_states, proj_layer, key_value_states, past_key_value, batch_size):
+        """projects hidden states correctly to key/query states"""
+        if key_value_states is None:
+            # self-attn
+            # (batch_size, n_heads, seq_length, dim_per_head)
+            hidden_states = self._shape(proj_layer(hidden_states), batch_size)
+        elif past_key_value is None:
+            # cross-attn
+            # (batch_size, n_heads, seq_length, dim_per_head)
+            hidden_states = self._shape(proj_layer(key_value_states), batch_size)
+
+        if past_key_value is not None:
+            if key_value_states is None:
+                # self-attn
+                # (batch_size, n_heads, key_length, dim_per_head)
+                hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
+            elif past_key_value.shape[2] != key_value_states.shape[1]:
+                # checking that the `sequence_length` of the `past_key_value` is the same as
+                # the provided `key_value_states` to support prefix tuning
+                # cross-attn
+                # (batch_size, n_heads, seq_length, dim_per_head)
+                hidden_states = self._shape(proj_layer(key_value_states), batch_size)
+            else:
+                # cross-attn
+                hidden_states = past_key_value
+        return hidden_states
+
     def forward(
         self,
         hidden_states,
@@ -481,50 +516,25 @@ def forward(
         key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
 
-        def shape(states):
-            """projection"""
-            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
-
-        def unshape(states):
-            """reshape"""
-            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
-
-        def project(hidden_states, proj_layer, key_value_states, past_key_value):
-            """projects hidden states correctly to key/query states"""
-            if key_value_states is None:
-                # self-attn
-                # (batch_size, n_heads, seq_length, dim_per_head)
-                hidden_states = shape(proj_layer(hidden_states))
-            elif past_key_value is None:
-                # cross-attn
-                # (batch_size, n_heads, seq_length, dim_per_head)
-                hidden_states = shape(proj_layer(key_value_states))
-
-            if past_key_value is not None:
-                if key_value_states is None:
-                    # self-attn
-                    # (batch_size, n_heads, key_length, dim_per_head)
-                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
-                elif past_key_value.shape[2] != key_value_states.shape[1]:
-                    # checking that the `sequence_length` of the `past_key_value` is the same as
-                    # the provided `key_value_states` to support prefix tuning
-                    # cross-attn
-                    # (batch_size, n_heads, seq_length, dim_per_head)
-                    hidden_states = shape(proj_layer(key_value_states))
-                else:
-                    # cross-attn
-                    hidden_states = past_key_value
-            return hidden_states
-
         # get query states
-        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)
+        query_states = self._shape(
+            self.q(hidden_states), batch_size
+        )  # (batch_size, n_heads, seq_length, dim_per_head)
 
         # get key/value states
-        key_states = project(
-            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
+        key_states = self._project(
+            hidden_states,
+            self.k,
+            key_value_states,
+            past_key_value[0] if past_key_value is not None else None,
+            batch_size,
         )
-        value_states = project(
-            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
+        value_states = self._project(
+            hidden_states,
+            self.v,
+            key_value_states,
+            past_key_value[1] if past_key_value is not None else None,
+            batch_size,
         )
 
         # compute scores
@@ -569,7 +579,9 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
         if layer_head_mask is not None:
             attn_weights = attn_weights * layer_head_mask
 
-        attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
+        attn_output = self._unshape(
+            torch.matmul(attn_weights, value_states), batch_size
+        )  # (batch_size, seq_length, dim)
         attn_output = self.o(attn_output)
 
         present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
@@ -582,6 +594,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
 
 # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->SwitchTransformers
 class SwitchTransformersLayerSelfAttention(nn.Module):
+    # Ignore copy
     def __init__(self, config, has_relative_attention_bias=False):
         super().__init__()
         self.SelfAttention = SwitchTransformersAttention(
@@ -617,6 +630,7 @@ def forward(
 
 # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->SwitchTransformers
 class SwitchTransformersLayerCrossAttention(nn.Module):
+    # Ignore copy
     def __init__(self, config):
         super().__init__()
         self.EncDecAttention = SwitchTransformersAttention(config, has_relative_attention_bias=False)
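The `_shape`/`_unshape` helpers promoted to methods above are exact inverses of each other for the layouts used here. A small self-contained check of that round trip follows; the sizes are arbitrary illustrative values, not taken from any config.

```python
# Round-trip check for the head-splitting reshape and its inverse.
# n_heads * key_value_proj_dim must equal inner_dim for the view to be valid.
import torch

batch_size, seq_length, n_heads, key_value_proj_dim = 2, 5, 4, 8
inner_dim = n_heads * key_value_proj_dim

states = torch.randn(batch_size, seq_length, inner_dim)

# _shape: (batch, seq, inner_dim) -> (batch, n_heads, seq, dim_per_head)
shaped = states.view(batch_size, -1, n_heads, key_value_proj_dim).transpose(1, 2)

# _unshape: (batch, n_heads, seq, dim_per_head) -> (batch, seq, inner_dim)
unshaped = shaped.transpose(1, 2).contiguous().view(batch_size, -1, inner_dim)

assert torch.equal(states, unshaped)
```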
diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py
index 4b170c1023a6..6acfd2091336 100644
--- a/src/transformers/models/udop/modeling_udop.py
+++ b/src/transformers/models/udop/modeling_udop.py
@@ -701,6 +701,41 @@ def compute_bias(self, query_length, key_length, device=None):
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         return values
 
+    def _shape(self, states, batch_size):
+        """projection"""
+        return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+    def _unshape(self, states, batch_size):
+        """reshape"""
+        return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
+
+    def _project(self, hidden_states, proj_layer, key_value_states, past_key_value, batch_size):
+        """projects hidden states correctly to key/query states"""
+        if key_value_states is None:
+            # self-attn
+            # (batch_size, n_heads, seq_length, dim_per_head)
+            hidden_states = self._shape(proj_layer(hidden_states), batch_size)
+        elif past_key_value is None:
+            # cross-attn
+            # (batch_size, n_heads, seq_length, dim_per_head)
+            hidden_states = self._shape(proj_layer(key_value_states), batch_size)
+
+        if past_key_value is not None:
+            if key_value_states is None:
+                # self-attn
+                # (batch_size, n_heads, key_length, dim_per_head)
+                hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
+            elif past_key_value.shape[2] != key_value_states.shape[1]:
+                # checking that the `sequence_length` of the `past_key_value` is the same as
+                # the provided `key_value_states` to support prefix tuning
+                # cross-attn
+                # (batch_size, n_heads, seq_length, dim_per_head)
+                hidden_states = self._shape(proj_layer(key_value_states), batch_size)
+            else:
+                # cross-attn
+                hidden_states = past_key_value
+        return hidden_states
+
     def forward(
         self,
         hidden_states,
@@ -732,50 +767,25 @@ def forward(
         key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
 
-        def shape(states):
-            """projection"""
-            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
-
-        def unshape(states):
-            """reshape"""
-            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
-
-        def project(hidden_states, proj_layer, key_value_states, past_key_value):
-            """projects hidden states correctly to key/query states"""
-            if key_value_states is None:
-                # self-attn
-                # (batch_size, n_heads, seq_length, dim_per_head)
-                hidden_states = shape(proj_layer(hidden_states))
-            elif past_key_value is None:
-                # cross-attn
-                # (batch_size, n_heads, seq_length, dim_per_head)
-                hidden_states = shape(proj_layer(key_value_states))
-
-            if past_key_value is not None:
-                if key_value_states is None:
-                    # self-attn
-                    # (batch_size, n_heads, key_length, dim_per_head)
-                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
-                elif past_key_value.shape[2] != key_value_states.shape[1]:
-                    # checking that the `sequence_length` of the `past_key_value` is the same as
-                    # the provided `key_value_states` to support prefix tuning
-                    # cross-attn
-                    # (batch_size, n_heads, seq_length, dim_per_head)
-                    hidden_states = shape(proj_layer(key_value_states))
-                else:
-                    # cross-attn
-                    hidden_states = past_key_value
-            return hidden_states
-
         # get query states
-        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)
+        query_states = self._shape(
+            self.q(hidden_states), batch_size
+        )  # (batch_size, n_heads, seq_length, dim_per_head)
 
         # get key/value states
-        key_states = project(
-            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
+        key_states = self._project(
+            hidden_states,
+            self.k,
+            key_value_states,
+            past_key_value[0] if past_key_value is not None else None,
+            batch_size,
         )
-        value_states = project(
-            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
+        value_states = self._project(
+            hidden_states,
+            self.v,
+            key_value_states,
+            past_key_value[1] if past_key_value is not None else None,
+            batch_size,
         )
 
         # compute scores
@@ -820,7 +830,9 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
         if layer_head_mask is not None:
             attn_weights = attn_weights * layer_head_mask
 
-        attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
+        attn_output = self._unshape(
+            torch.matmul(attn_weights, value_states), batch_size
+        )  # (batch_size, seq_length, dim)
         attn_output = self.o(attn_output)
 
         present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
@@ -833,6 +845,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
 
 # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Udop
 class UdopLayerSelfAttention(nn.Module):
+    # Ignore copy
     def __init__(self, config, has_relative_attention_bias=False):
         super().__init__()
         self.SelfAttention = UdopAttention(config, has_relative_attention_bias=has_relative_attention_bias)
@@ -866,6 +879,7 @@ def forward(
 
 # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Udop
 class UdopLayerCrossAttention(nn.Module):
+    # Ignore copy
     def __init__(self, config):
         super().__init__()
         self.EncDecAttention = UdopAttention(config, has_relative_attention_bias=False)
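Taken together with the `_supports_sdpa` flags and the attention-class registries added earlier in the series, these changes let callers request the SDPA path (or force the eager fallback) at load time. A usage sketch follows; the `t5-small` checkpoint and the prompt are used purely as examples.

```python
# Usage sketch: selecting the attention implementation when loading a model.
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained(
    "t5-small",
    attn_implementation="sdpa",  # or "eager" to force the manual attention path
)

inputs = tokenizer("translate English to German: Hello, world!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```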
From da98387836c0d27ff19ad37ffc87ae94729b1086 Mon Sep 17 00:00:00 2001
From: Abdul Fatir Ansari
Date: Sat, 15 Jun 2024 14:10:19 +0000
Subject: [PATCH 10/14] Update docs

---
 docs/source/en/perf_infer_gpu_one.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index a81fc4938138..23fec7f45bcc 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -208,12 +208,15 @@ For now, Transformers supports SDPA inference and training for the following arc
 * [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
 * [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration)
 * [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
+* [Pop2Piano](https://huggingface.co/docs/transformers/model_doc/pop2piano#transformers.Pop2PianoForConditionalGeneration)
 * [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
 * [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
 * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
 * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
+* [MT5](https://huggingface.co/docs/transformers/model_doc/mt5#transformers.MT5Model)
 * [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
 * [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
+* [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Model)
 * [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
 * [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
 * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)

From 4628c89db8d6deee447fcbce2f6f528425ea0d8c Mon Sep 17 00:00:00 2001
From: Abdul Fatir Ansari
Date: Wed, 19 Jun 2024 12:24:42 +0000
Subject: [PATCH 11/14] Use clone instead of copy_

---
 src/transformers/models/mt5/modeling_mt5.py             | 4 +---
 src/transformers/models/pop2piano/modeling_pop2piano.py | 4 +---
 src/transformers/models/t5/modeling_t5.py               | 4 +---
 3 files changed, 3 insertions(+), 9 deletions(-)

diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py
index e2c4a1c10049..316cbc44ee78 100644
--- a/src/transformers/models/mt5/modeling_mt5.py
+++ b/src/transformers/models/mt5/modeling_mt5.py
@@ -566,9 +566,7 @@ def forward(
         # .contiguous() does not behave correctly for tensors with singleton dimensions
         # the following is a workaround as suggested in https://github.com/pytorch/pytorch/issues/127523
         if position_bias_masked.stride(-1) != 1:
-            position_bias_masked = torch.empty_like(position_bias_masked, memory_format=torch.contiguous_format).copy_(
-                position_bias_masked
-            )
+            position_bias_masked = torch.clone(position_bias_masked, memory_format=torch.contiguous_format)
 
         attn_output = self._unshape(
             torch.nn.functional.scaled_dot_product_attention(
diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py
index 18e0b7ddaab0..47a7d77e56a4 100644
--- a/src/transformers/models/pop2piano/modeling_pop2piano.py
+++ b/src/transformers/models/pop2piano/modeling_pop2piano.py
@@ -591,9 +591,7 @@ def forward(
         # .contiguous() does not behave correctly for tensors with singleton dimensions
         # the following is a workaround as suggested in https://github.com/pytorch/pytorch/issues/127523
         if position_bias_masked.stride(-1) != 1:
-            position_bias_masked = torch.empty_like(position_bias_masked, memory_format=torch.contiguous_format).copy_(
-                position_bias_masked
-            )
+            position_bias_masked = torch.clone(position_bias_masked, memory_format=torch.contiguous_format)
 
         attn_output = self._unshape(
             torch.nn.functional.scaled_dot_product_attention(
diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 60f4d446668f..7e3d59cc2f20 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -690,9 +690,7 @@ def forward(
         # .contiguous() does not behave correctly for tensors with singleton dimensions
         # the following is a workaround as suggested in https://github.com/pytorch/pytorch/issues/127523
         if position_bias_masked.stride(-1) != 1:
-            position_bias_masked = torch.empty_like(position_bias_masked, memory_format=torch.contiguous_format).copy_(
-                position_bias_masked
-            )
+            position_bias_masked = torch.clone(position_bias_masked, memory_format=torch.contiguous_format)
 
         attn_output = self._unshape(
             torch.nn.functional.scaled_dot_product_attention(
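The hunks above replace the `empty_like(...).copy_(...)` idiom with a single `torch.clone(..., memory_format=torch.contiguous_format)` call; both produce a contiguous copy with identical values. A small sketch of the two forms side by side, using an arbitrary non-contiguous view as a stand-in for `position_bias_masked`:

```python
# Both idioms yield a contiguous copy with the same values; clone does it in one call.
import torch

# A non-contiguous view (last stride != 1), standing in for `position_bias_masked`.
position_bias_masked = torch.randn(1, 4, 6, 6).transpose(-1, -2)
assert position_bias_masked.stride(-1) != 1

# Old idiom: allocate a contiguous buffer, then copy into it.
old = torch.empty_like(position_bias_masked, memory_format=torch.contiguous_format).copy_(
    position_bias_masked
)

# New idiom: clone directly into contiguous memory.
new = torch.clone(position_bias_masked, memory_format=torch.contiguous_format)

assert old.is_contiguous() and new.is_contiguous()
assert torch.equal(old, new)
```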
From 8748841d9ea2a7182978cd773a926777109300e0 Mon Sep 17 00:00:00 2001
From: Abdul Fatir Ansari
Date: Wed, 19 Jun 2024 13:12:49 +0000
Subject: [PATCH 12/14] Use higher opset version

---
 tests/models/mt5/test_modeling_mt5.py             | 2 +-
 tests/models/pop2piano/test_modeling_pop2piano.py | 2 +-
 tests/models/t5/test_modeling_t5.py               | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/models/mt5/test_modeling_mt5.py b/tests/models/mt5/test_modeling_mt5.py
index 7f0d7d992ad9..55ce5c82ac88 100644
--- a/tests/models/mt5/test_modeling_mt5.py
+++ b/tests/models/mt5/test_modeling_mt5.py
@@ -848,7 +848,7 @@ def test_export_to_onnx(self):
                 (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
                 f"{tmpdirname}/t5_test.onnx",
                 export_params=True,
-                opset_version=9,
+                opset_version=14,
                 input_names=["input_ids", "decoder_input_ids"],
             )
 
diff --git a/tests/models/pop2piano/test_modeling_pop2piano.py b/tests/models/pop2piano/test_modeling_pop2piano.py
index 3a33b5a98128..39ff67f08ce5 100644
--- a/tests/models/pop2piano/test_modeling_pop2piano.py
+++ b/tests/models/pop2piano/test_modeling_pop2piano.py
@@ -620,7 +620,7 @@ def test_export_to_onnx(self):
                 (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
                 f"{tmpdirname}/Pop2Piano_test.onnx",
                 export_params=True,
-                opset_version=9,
+                opset_version=14,
                 input_names=["input_ids", "decoder_input_ids"],
             )
 
diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py
index 23e8640d5b99..620111ba57bd 100644
--- a/tests/models/t5/test_modeling_t5.py
+++ b/tests/models/t5/test_modeling_t5.py
@@ -851,7 +851,7 @@ def test_export_to_onnx(self):
                 (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
                 f"{tmpdirname}/t5_test.onnx",
                 export_params=True,
-                opset_version=9,
+                opset_version=14,
                 input_names=["input_ids", "decoder_input_ids"],
             )
 
From 0160651fd6cf650acc85a1ffcdb1d5f158bca0d4 Mon Sep 17 00:00:00 2001
From: Abdul Fatir Ansari
Date: Wed, 19 Jun 2024 15:20:54 +0000
Subject: [PATCH 13/14] Skip test

---
 tests/models/pop2piano/test_modeling_pop2piano.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/pop2piano/test_modeling_pop2piano.py b/tests/models/pop2piano/test_modeling_pop2piano.py
index 39ff67f08ce5..a7f5cac8905c 100644
--- a/tests/models/pop2piano/test_modeling_pop2piano.py
+++ b/tests/models/pop2piano/test_modeling_pop2piano.py
@@ -610,7 +610,7 @@ def test_model_from_pretrained(self):
             model = Pop2PianoForConditionalGeneration.from_pretrained(model_name)
             self.assertIsNotNone(model)
 
-    @require_onnx
+    @unittest.skip("ONNX support deprecated")
     def test_export_to_onnx(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         model = Pop2PianoForConditionalGeneration(config_and_inputs[0]).to(torch_device)

From 44107ac3beac209b8882f8657b820a9a7dd729f5 Mon Sep 17 00:00:00 2001
From: Abdul Fatir Ansari
Date: Wed, 19 Jun 2024 15:22:49 +0000
Subject: [PATCH 14/14] fix style

---
 tests/models/pop2piano/test_modeling_pop2piano.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/models/pop2piano/test_modeling_pop2piano.py b/tests/models/pop2piano/test_modeling_pop2piano.py
index a7f5cac8905c..22d60e41d64e 100644
--- a/tests/models/pop2piano/test_modeling_pop2piano.py
+++ b/tests/models/pop2piano/test_modeling_pop2piano.py
@@ -26,7 +26,6 @@
 from transformers.testing_utils import (
     require_essentia,
     require_librosa,
-    require_onnx,
     require_scipy,
     require_torch,
     slow,