From 15abc14d6170fa434c2cd1e73d309d64572ad62d Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 19 Sep 2024 18:27:05 +0200 Subject: [PATCH 01/30] this worked in normal generation, needs more tests --- src/transformers/models/t5/modeling_t5.py | 383 ++++++++++++++++------ 1 file changed, 285 insertions(+), 98 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index a90101924c5b..32df71fd0359 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -25,6 +25,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -227,6 +229,60 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): """ +# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + class T5LayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -338,7 +394,12 @@ def forward(self, hidden_states): class T5Attention(nn.Module): - def __init__(self, config: T5Config, has_relative_attention_bias=False): + def __init__( + self, + config: T5Config, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + ): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -349,6 +410,13 @@ def __init__(self, config: T5Config, has_relative_attention_bias=False): self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -453,6 +521,7 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). @@ -463,13 +532,7 @@ def forward( batch_size, seq_length = hidden_states.shape[:2] real_seq_length = seq_length - - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + real_seq_length += cache_position[0] if query_length is None else query_length key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] @@ -481,43 +544,34 @@ def unshape(states): """reshape""" return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - # get query states query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + # get past key value + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if key_value_states is not None: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + current_states = key_value_states if key_value_states is not None else hidden_states + if key_value_states is not None and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = shape(self.k(current_states)) + value_states = shape(self.v(current_states)) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not key_value_states is not None else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) # compute scores scores = torch.matmul( @@ -564,8 +618,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output,) + (past_key_value,) + (position_bias,) if 
output_attentions: outputs = outputs + (attn_weights,) @@ -573,9 +626,11 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): class T5LayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() - self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias) + self.SelfAttention = T5Attention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -588,6 +643,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -598,6 +654,7 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -605,9 +662,9 @@ def forward( class T5LayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False) + self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -622,6 +679,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -634,6 +692,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -641,13 +700,15 @@ def forward( class T5Block(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder self.layer = nn.ModuleList() - self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + self.layer.append( + T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx) + ) if self.is_decoder: - self.layer.append(T5LayerCrossAttention(config)) + self.layer.append(T5LayerCrossAttention(config, layer_idx=layer_idx)) self.layer.append(T5LayerFF(config)) @@ -665,34 +726,19 @@ def forward( use_cache=False, output_attentions=False, return_dict=True, + cache_position=None, ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (key / value) for cross attention. 
' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - self_attention_outputs = self.layer[0]( hidden_states, attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] + hidden_states, self_attn_present_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -708,10 +754,7 @@ def forward( if do_cross_attention: # the actual query length is unknown for cross attention # if using past key value states. Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None + query_length = cache_position[0] cross_attention_outputs = self.layer[1]( hidden_states, @@ -719,7 +762,7 @@ def forward( attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, query_length=query_length, use_cache=use_cache, output_attentions=output_attentions, @@ -735,10 +778,6 @@ def forward( ) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -757,7 +796,7 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + outputs = outputs + (past_key_value,) + attention_outputs else: outputs = outputs + attention_outputs @@ -904,7 +943,7 @@ def __init__(self, config, embed_tokens=None): self.is_decoder = config.is_decoder self.block = nn.ModuleList( - [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + [T5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)] ) self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -980,6 +1019,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): # Model parallel if self.model_parallel: @@ -1020,16 +1060,45 @@ def forward( if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if use_cache or past_key_values is not None: + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, 
DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + past_key_values_length = 0 + if cache_position is not None: + past_key_values_length = cache_position[0] + elif past_key_values is not None: + past_key_values_length = past_key_values.get_seq_length() + + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) if attention_mask is None: attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -1054,7 +1123,6 @@ def forward( # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None @@ -1063,7 +1131,7 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] # Model parallel @@ -1091,7 +1159,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -1101,20 +1169,22 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) # layer_outputs is a tuple with: @@ -1122,7 +1192,7 @@ def forward( if use_cache is False: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - 
hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, nexy_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), @@ -1130,9 +1200,6 @@ def forward( position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -1152,12 +1219,18 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = nexy_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -1166,12 +1239,79 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
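        # Illustrative walk-through (assumed values, not taken from this patch): with a
        # StaticCache of max length 8, a single decoding step (sequence_length=1),
        # batch_size=1, attention_mask=torch.ones(1, 8) and cache_position=torch.tensor([3]),
        # the helper called below returns a (1, 1, 1, 8) mask that is 0.0 for key positions
        # 0..3 (cache slots up to and including the current position) and `min_dtype` for the
        # not-yet-filled slots 4..7, so attention ignores the zero padding of the static cache.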
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + T5_START_DOCSTRING = r""" @@ -1285,6 +1425,9 @@ def forward( more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ T5_ENCODER_INPUTS_DOCSTRING = r""" @@ -1445,6 +1588,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" Returns: @@ -1524,6 +1668,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1655,6 +1800,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1749,6 +1895,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = decoder_outputs[0] @@ -1801,11 +1948,15 @@ def prepare_inputs_for_generation( cross_attn_head_mask=None, use_cache=None, encoder_outputs=None, + cache_position=None, **kwargs, ): # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + if isinstance(past_key_values, EncoderDecoderCache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + else: + past_length = past_key_values[0][0].shape[2] # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: @@ -1816,6 +1967,41 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, remove_prefix_length:] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_ids.shape[1] :] + + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. 
Ref: https://github.com/huggingface/transformers/pull/29114 + input_ids = input_ids.contiguous() + + if ( + isinstance(past_key_values, EncoderDecoderCache) + and ( + isinstance(past_key_values.self_attention_cache, StaticCache) + or isinstance(past_key_values.cross_attention_cache, StaticCache) + ) + and decoder_attention_mask is not None + and decoder_attention_mask.ndim == 2 + ): + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + dtype = self.proj_out.weight.dtype + min_dtype = torch.finfo(dtype).min + + decoder_attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + decoder_attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.self_attention_cache.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + return { "decoder_input_ids": input_ids, "past_key_values": past_key_values, @@ -1826,6 +2012,7 @@ def prepare_inputs_for_generation( "decoder_attention_mask": decoder_attention_mask, "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, + "cache_position": cache_position, } def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): From 06d9d625c9a6ab3992e424f3689922ca36390cd7 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 30 Sep 2024 13:09:09 +0200 Subject: [PATCH 02/30] fix almost all tests in t5 --- src/transformers/generation/utils.py | 9 +- src/transformers/models/t5/modeling_t5.py | 50 ++++---- tests/models/t5/test_modeling_t5.py | 147 +++++++++++++--------- 3 files changed, 125 insertions(+), 81 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d8896f91267d..9f1304ce87a5 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -496,7 +496,7 @@ def _prepare_encoder_decoder_kwargs_for_generation( add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True)) # 2. Prepare encoder args and encoder kwargs from model kwargs and generation config. - irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] + irrelevant_prefix = ["decoder_", "cross_attn", "use_cache", "past_key_values"] encoder_kwargs = { argument: value for argument, value in model_kwargs.items() @@ -1386,6 +1386,7 @@ def _get_initial_cache_position(self, input_ids, model_kwargs): else: cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1 + print("initial", cache_position) past_length = 0 if model_kwargs.get("past_key_values") is not None: cache = model_kwargs["past_key_values"] @@ -1394,6 +1395,7 @@ def _get_initial_cache_position(self, input_ids, model_kwargs): past_length = cache[0][0].shape[2] elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None: past_length = cache.get_seq_length() + print("cropped", past_length, cache.self_attention_cache, cache.self_attention_cache.get_seq_length()) # TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty, # end-to-end compilation will yield bad results because `cache_position` will be incorrect. 
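# Usage sketch (editorial illustration, not part of this diff; the checkpoint name and
# prompt are only examples): with the cache classes wired into T5 above, a caller can pass
# an explicit EncoderDecoderCache to `generate()` instead of the deprecated tuple format.

from transformers import AutoTokenizer, T5ForConditionalGeneration
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
inputs = tokenizer("translate English to German: How old are you?", return_tensors="pt")

# one DynamicCache for decoder self-attention, one for cross-attention
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
output_ids = model.generate(**inputs, past_key_values=past_key_values, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))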
@@ -1865,11 +1867,15 @@ def generate( inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor ) + if "past_key_values" in model_kwargs: + print("before", model_kwargs["past_key_values"].get_seq_length()) if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: # if model is encoder decoder encoder_outputs are created and added to `model_kwargs` model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( inputs_tensor, model_kwargs, model_input_name, generation_config ) + if "past_key_values" in model_kwargs: + print("after", model_kwargs["past_key_values"].get_seq_length()) # 5. Prepare `input_ids` which will be used for auto-regressive generation if self.config.is_encoder_decoder: @@ -2993,6 +2999,7 @@ def _sample( this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length ): # prepare model inputs + print(model_kwargs["cache_position"]) model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # prepare variable output controls (note: some models won't accept all output controls) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 32df71fd0359..ee19a27ed84a 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -604,12 +604,9 @@ def unshape(states): position_bias_masked = position_bias scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: @@ -618,10 +615,10 @@ def unshape(states): attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) attn_output = self.o(attn_output) - outputs = (attn_output,) + (past_key_value,) + (position_bias,) + outputs = (attn_output, past_key_value, position_bias) if output_attentions: - outputs = outputs + (attn_weights,) + outputs += (attn_weights,) return outputs @@ -832,6 +829,9 @@ class T5PreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" is_parallelizable = True supports_gradient_checkpointing = True + _supports_quantized_cache = False # enc-dec models don't support yet + _supports_static_cache = True + _supports_cache_class = True _no_split_modules = ["T5Block"] _keep_in_fp32_modules = ["wo"] @@ -1053,9 +1053,6 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if use_cache is True: if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") @@ -1090,15 +1087,24 @@ def forward( ) if attention_mask is None: + # required mask seq length can be calculated via length of past + mask_seq_length = ( + past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length + ) attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - causal_mask = self._update_causal_mask( - 
attention_mask, - inputs_embeds, - cache_position, - past_key_values.self_attention_cache if past_key_values is not None else None, - output_attentions, - ) + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + else: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -1139,7 +1145,7 @@ def forward( torch.cuda.set_device(hidden_states.device) # Ensure that attention_mask is always on the same device as hidden_states if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) + causal_mask = causal_mask.to(hidden_states.device) if position_bias is not None: position_bias = position_bias.to(hidden_states.device) if encoder_hidden_states is not None: @@ -1169,6 +1175,7 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + return_dict, cache_position, ) else: @@ -1184,6 +1191,7 @@ def forward( past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + return_dict=return_dict, cache_position=cache_position, ) @@ -1192,7 +1200,7 @@ def forward( if use_cache is False: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, nexy_decoder_cache = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), @@ -1219,7 +1227,7 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = nexy_decoder_cache if use_cache else None + next_cache = next_decoder_cache if use_cache else None if return_self_attention_cache: next_cache = past_key_values.self_attention_cache if return_legacy_cache: diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index 93634ef2a670..f38e8c58cbbb 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -44,6 +44,7 @@ if is_torch_available(): import torch + import torch.nn.functional as F from transformers import ( AutoTokenizer, @@ -609,65 +610,58 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa model.eval() inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss) - try: - if model.config.is_encoder_decoder: - model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward - labels = inputs.get("labels", None) - input_names = [ - "attention_mask", - "decoder_attention_mask", - "decoder_input_ids", - "input_features", - "input_ids", - "input_values", - ] - if labels is not None: - input_names.append("labels") - - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} - input_names = list(filtered_inputs.keys()) - - model_output = model(**filtered_inputs) - - traced_model = symbolic_trace(model, input_names) - traced_output = traced_model(**filtered_inputs) - else: - input_names = [ - "attention_mask", - "bbox", - 
"input_features", - "input_ids", - "input_values", - "pixel_values", - "token_type_ids", - "visual_feats", - "visual_pos", - ] - - labels = inputs.get("labels", None) - start_positions = inputs.get("start_positions", None) - end_positions = inputs.get("end_positions", None) - if labels is not None: - input_names.append("labels") - if start_positions is not None: - input_names.append("start_positions") - if end_positions is not None: - input_names.append("end_positions") - - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} - input_names = list(filtered_inputs.keys()) - - if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and ( - not hasattr(model.config, "problem_type") or model.config.problem_type is None - ): - model.config.problem_type = "single_label_classification" - - traced_model = symbolic_trace(model, input_names) - traced_output = traced_model(**filtered_inputs) - model_output = model(**filtered_inputs) - - except Exception as e: - self.fail(f"Couldn't trace module: {e}") + # try: + if model.config.is_encoder_decoder: + model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward + labels = inputs.get("labels", None) + input_names = [ + "attention_mask", + "decoder_attention_mask", + "decoder_input_ids", + "input_features", + "input_ids", + "input_values", + ] + if labels is not None: + input_names.append("labels") + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + input_names = list(filtered_inputs.keys()) + model_output = model(**filtered_inputs) + traced_model = symbolic_trace(model, input_names) + traced_output = traced_model(**filtered_inputs) + else: + input_names = [ + "attention_mask", + "bbox", + "input_features", + "input_ids", + "input_values", + "pixel_values", + "token_type_ids", + "visual_feats", + "visual_pos", + ] + labels = inputs.get("labels", None) + start_positions = inputs.get("start_positions", None) + end_positions = inputs.get("end_positions", None) + if labels is not None: + input_names.append("labels") + if start_positions is not None: + input_names.append("start_positions") + if end_positions is not None: + input_names.append("end_positions") + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + input_names = list(filtered_inputs.keys()) + if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and ( + not hasattr(model.config, "problem_type") or model.config.problem_type is None + ): + model.config.problem_type = "single_label_classification" + traced_model = symbolic_trace(model, input_names) + traced_output = traced_model(**filtered_inputs) + model_output = model(**filtered_inputs) + + # except Exception as e: + # self.fail(f"Couldn't trace module: {e}") def flatten_output(output): flatten = [] @@ -714,6 +708,41 @@ def flatten_output(output): # (Even with this call, there are still memory leak by ~0.04MB) self.clear_torch_jit_class_registry() + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + 
decoder_input_ids=input_ids, + input_ids=input_dict["input_ids"][:3], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + input_ids=input_dict["input_ids"][:1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + def test_config(self): self.config_tester.run_common_tests() From 51c689c7e039594eddfc3158740a051d76bfa986 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 30 Sep 2024 13:21:47 +0200 Subject: [PATCH 03/30] nit --- src/transformers/generation/utils.py | 7 ------- src/transformers/models/t5/modeling_t5.py | 14 +++++++------- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9f1304ce87a5..47ec185508f8 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1386,7 +1386,6 @@ def _get_initial_cache_position(self, input_ids, model_kwargs): else: cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1 - print("initial", cache_position) past_length = 0 if model_kwargs.get("past_key_values") is not None: cache = model_kwargs["past_key_values"] @@ -1395,7 +1394,6 @@ def _get_initial_cache_position(self, input_ids, model_kwargs): past_length = cache[0][0].shape[2] elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None: past_length = cache.get_seq_length() - print("cropped", past_length, cache.self_attention_cache, cache.self_attention_cache.get_seq_length()) # TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty, # end-to-end compilation will yield bad results because `cache_position` will be incorrect. @@ -1867,15 +1865,11 @@ def generate( inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor ) - if "past_key_values" in model_kwargs: - print("before", model_kwargs["past_key_values"].get_seq_length()) if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: # if model is encoder decoder encoder_outputs are created and added to `model_kwargs` model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( inputs_tensor, model_kwargs, model_input_name, generation_config ) - if "past_key_values" in model_kwargs: - print("after", model_kwargs["past_key_values"].get_seq_length()) # 5. 
Prepare `input_ids` which will be used for auto-regressive generation if self.config.is_encoder_decoder: @@ -2999,7 +2993,6 @@ def _sample( this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length ): # prepare model inputs - print(model_kwargs["cache_position"]) model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # prepare variable output controls (note: some models won't accept all output controls) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index ee19a27ed84a..1087c55a9523 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1046,6 +1046,13 @@ def forward( err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + if inputs_embeds is None: if self.embed_tokens is None: raise ValueError("You have to initialize the model with valid token embeddings") @@ -1119,13 +1126,6 @@ def forward( else: encoder_extended_attention_mask = None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) From 9e5244c7f00a2aae0f23adaecc9e03ea60038411 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 30 Sep 2024 14:19:15 +0200 Subject: [PATCH 04/30] longt5, umt5, mt5 --- .../models/longt5/modeling_longt5.py | 433 +++++++++++++----- src/transformers/models/mt5/modeling_mt5.py | 423 ++++++++++++----- src/transformers/models/t5/modeling_t5.py | 6 +- src/transformers/models/umt5/modeling_umt5.py | 360 ++++++++++++--- tests/models/longt5/test_modeling_longt5.py | 36 ++ tests/models/mt5/test_modeling_mt5.py | 36 ++ tests/models/umt5/test_modeling_umt5.py | 36 ++ 7 files changed, 1014 insertions(+), 316 deletions(-) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index b2a6ed11ca57..55bbda4e20b2 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -24,6 +24,8 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -52,6 +54,60 @@ # TODO: Update before the merge +# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 
4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + def _pad_to_multiple(x: torch.Tensor, block_len: int, dim: int, pad_value: int = 0) -> torch.Tensor: """Pad a tensor so that a sequence length will be a multiple of `block_len`""" pad_len = -x.shape[dim] % block_len @@ -316,7 +372,12 @@ def forward(self, hidden_states): # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->LongT5 class LongT5Attention(nn.Module): - def __init__(self, config: LongT5Config, has_relative_attention_bias=False): + def __init__( + self, + config: LongT5Config, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + ): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -327,6 +388,13 @@ def __init__(self, config: LongT5Config, has_relative_attention_bias=False): self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -431,6 +499,7 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
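# Sketch of the intent behind `layer_idx` (editorial, not part of this diff): the index lets
# a shared EncoderDecoderCache route key/value updates to the correct per-layer slot, e.g. a
# decoder block built as
#   LongT5Attention(config, has_relative_attention_bias=False, layer_idx=3)
# will later call `past_key_value.update(key_states, value_states, 3, ...)` in its forward
# pass. Omitting `layer_idx` on a decoder attention module triggers the warning above, and
# cache updates will fail during generation.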
@@ -441,13 +510,7 @@ def forward( batch_size, seq_length = hidden_states.shape[:2] real_seq_length = seq_length - - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + real_seq_length += cache_position[0] if query_length is None else query_length key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] @@ -459,43 +522,34 @@ def unshape(states): """reshape""" return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - # get query states query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + # get past key value + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if key_value_states is not None: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + current_states = key_value_states if key_value_states is not None else hidden_states + if key_value_states is not None and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = shape(self.k(current_states)) + value_states = shape(self.v(current_states)) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not key_value_states is not None else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) # compute scores scores = torch.matmul( @@ -528,12 +582,9 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), 
dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: @@ -542,11 +593,10 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output, past_key_value, position_bias) if output_attentions: - outputs = outputs + (attn_weights,) + outputs += (attn_weights,) return outputs @@ -1007,9 +1057,11 @@ def unshape(states): # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->LongT5 class LongT5LayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() - self.SelfAttention = LongT5Attention(config, has_relative_attention_bias=has_relative_attention_bias) + self.SelfAttention = LongT5Attention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -1022,6 +1074,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -1032,6 +1085,7 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -1041,7 +1095,7 @@ def forward( class LongT5LayerLocalSelfAttention(nn.Module): """Local self attention used in encoder""" - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.LocalSelfAttention = LongT5LocalAttention(config, has_relative_attention_bias=has_relative_attention_bias) self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) @@ -1072,7 +1126,7 @@ def forward( class LongT5LayerTransientGlobalSelfAttention(nn.Module): """Transient-Global self attention used in encoder""" - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.TransientGlobalSelfAttention = LongT5TransientGlobalAttention( config, has_relative_attention_bias=has_relative_attention_bias @@ -1104,9 +1158,9 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->LongT5 class LongT5LayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - 
self.EncDecAttention = LongT5Attention(config, has_relative_attention_bias=False) + self.EncDecAttention = LongT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -1121,6 +1175,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -1133,6 +1188,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -1140,7 +1196,7 @@ def forward( class LongT5Block(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder if config.is_decoder: @@ -1155,9 +1211,11 @@ def __init__(self, config, has_relative_attention_bias=False): f"but got {config.encoder_attention_type}." ) self.layer = nn.ModuleList() - self.layer.append(attention_layer(config, has_relative_attention_bias=has_relative_attention_bias)) + self.layer.append( + attention_layer(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx) + ) if self.is_decoder: - self.layer.append(LongT5LayerCrossAttention(config)) + self.layer.append(LongT5LayerCrossAttention(config, layer_idx=layer_idx)) self.layer.append(LongT5LayerFF(config)) @@ -1175,32 +1233,17 @@ def forward( use_cache=False, output_attentions=False, return_dict=True, + cache_position=None, ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - self_attention_outputs = self.layer[0]( hidden_states, attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states, present_key_value_state = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights @@ -1214,10 +1257,7 @@ def forward( if do_cross_attention: # the actual query length is unknown for cross attention # if using past key value states. 
Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None + query_length = cache_position[0] cross_attention_outputs = self.layer[1]( hidden_states, @@ -1225,10 +1265,11 @@ def forward( attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, query_length=query_length, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = cross_attention_outputs[0] @@ -1237,10 +1278,6 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -1255,7 +1292,7 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + outputs = outputs + (past_key_value,) + attention_outputs else: outputs = outputs + attention_outputs @@ -1272,6 +1309,8 @@ class LongT5PreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" supports_gradient_checkpointing = True _no_split_modules = ["LongT5Block"] + _supports_cache_class = True + _supports_static_cache = True @property # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs @@ -1375,7 +1414,10 @@ def __init__(self, config, embed_tokens=None): self.block_len = self.local_radius + 1 self.block = nn.ModuleList( - [LongT5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + [ + LongT5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) + for i in range(config.num_layers) + ] ) self.final_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -1407,6 +1449,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -1429,36 +1472,68 @@ def forward( err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + if inputs_embeds is None: assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" inputs_embeds = self.embed_tokens(input_ids) batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - - if use_cache is True: - assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder" + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if use_cache or past_key_values is not None: + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + past_key_values_length = 0 + if cache_position is not None: + past_key_values_length = cache_position[0] + elif past_key_values is not None: + past_key_values_length = past_key_values.get_seq_length() + + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) if attention_mask is None: + # required mask seq length can be calculated via length of past + mask_seq_length = ( + past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length + ) attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
- # We use local attention in encoder self-attention, otherwise standard self & cross attentions are used if self.is_decoder: - extended_attention_mask = self.get_extended_attention_mask( - attention_mask, input_shape, inputs_embeds.device + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, ) + # We use local attention in encoder self-attention, otherwise standard self & cross attentions are used elif self.config.encoder_attention_type == "local": - extended_attention_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device) + causal_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device) else: # we need to use both local attention mask and standard extended mask for transient-global attention - extended_attention_mask = attention_mask + causal_mask = attention_mask # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -1471,17 +1546,9 @@ def forward( else: encoder_extended_attention_mask = None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None @@ -1490,7 +1557,7 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] @@ -1501,7 +1568,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -1511,20 +1578,24 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + return_dict, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, ) # layer_outputs is a tuple with: @@ -1532,7 +1603,7 @@ def forward( if use_cache is False: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states 
(self-attention position bias), (self-attention weights), @@ -1540,9 +1611,6 @@ def forward( position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -1556,12 +1624,18 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -1570,12 +1644,79 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + LONGT5_START_DOCSTRING = r""" @@ -1692,6 +1833,9 @@ def forward( more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ LONGT5_ENCODER_INPUTS_DOCSTRING = r""" @@ -1816,6 +1960,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" Returns: @@ -1882,6 +2027,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1974,6 +2120,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -2049,6 +2196,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = decoder_outputs[0] @@ -2094,11 +2242,15 @@ def prepare_inputs_for_generation( cross_attn_head_mask=None, use_cache=None, encoder_outputs=None, + cache_position=None, **kwargs, ): # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + if isinstance(past_key_values, EncoderDecoderCache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + else: + past_length = past_key_values[0][0].shape[2] # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: @@ -2109,6 +2261,41 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, remove_prefix_length:] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_ids.shape[1] :] + + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. 
Ref: https://github.com/huggingface/transformers/pull/29114 + input_ids = input_ids.contiguous() + + if ( + isinstance(past_key_values, EncoderDecoderCache) + and ( + isinstance(past_key_values.self_attention_cache, StaticCache) + or isinstance(past_key_values.cross_attention_cache, StaticCache) + ) + and attention_mask is not None + and attention_mask.ndim == 2 + ): + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + dtype = self.proj_out.weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.self_attention_cache.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + return { "decoder_input_ids": input_ids, "past_key_values": past_key_values, diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 54943cf982dd..92efa230b526 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -25,6 +25,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -113,6 +115,60 @@ """ +# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + # Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->MT5 class MT5LayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -213,7 +269,12 @@ def forward(self, hidden_states): # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->MT5 class MT5Attention(nn.Module): - def __init__(self, config: MT5Config, has_relative_attention_bias=False): + def __init__( + self, + config: MT5Config, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + ): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -224,6 +285,13 @@ def __init__(self, config: MT5Config, has_relative_attention_bias=False): self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -328,6 +396,7 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). @@ -338,13 +407,7 @@ def forward( batch_size, seq_length = hidden_states.shape[:2] real_seq_length = seq_length - - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + real_seq_length += cache_position[0] if query_length is None else query_length key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] @@ -356,43 +419,34 @@ def unshape(states): """reshape""" return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - # get query states query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + # get past key value + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if key_value_states is not None: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + current_states = key_value_states if key_value_states is not None else hidden_states + if key_value_states is not None and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = shape(self.k(current_states)) + value_states = shape(self.v(current_states)) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not key_value_states is not None else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) # compute scores scores = torch.matmul( @@ -425,12 +479,9 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) + # (batch_size, n_heads, seq_length, 
key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: @@ -439,19 +490,20 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output, past_key_value, position_bias) if output_attentions: - outputs = outputs + (attn_weights,) + outputs += (attn_weights,) return outputs # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->MT5 class MT5LayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() - self.SelfAttention = MT5Attention(config, has_relative_attention_bias=has_relative_attention_bias) + self.SelfAttention = MT5Attention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -464,6 +516,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -474,6 +527,7 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -482,9 +536,9 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->MT5 class MT5LayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.EncDecAttention = MT5Attention(config, has_relative_attention_bias=False) + self.EncDecAttention = MT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -499,6 +553,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -511,6 +566,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -519,13 +575,15 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5Block with T5->MT5 class MT5Block(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder self.layer = nn.ModuleList() - self.layer.append(MT5LayerSelfAttention(config, 
has_relative_attention_bias=has_relative_attention_bias)) + self.layer.append( + MT5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx) + ) if self.is_decoder: - self.layer.append(MT5LayerCrossAttention(config)) + self.layer.append(MT5LayerCrossAttention(config, layer_idx=layer_idx)) self.layer.append(MT5LayerFF(config)) @@ -543,34 +601,19 @@ def forward( use_cache=False, output_attentions=False, return_dict=True, + cache_position=None, ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - self_attention_outputs = self.layer[0]( hidden_states, attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] + hidden_states, self_attn_present_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -586,10 +629,7 @@ def forward( if do_cross_attention: # the actual query length is unknown for cross attention # if using past key value states. 
Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None + query_length = cache_position[0] cross_attention_outputs = self.layer[1]( hidden_states, @@ -597,7 +637,7 @@ def forward( attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, query_length=query_length, use_cache=use_cache, output_attentions=output_attentions, @@ -613,10 +653,6 @@ def forward( ) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -635,7 +671,7 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + outputs = outputs + (past_key_value,) + attention_outputs else: outputs = outputs + attention_outputs @@ -779,6 +815,8 @@ class MT5PreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" is_parallelizable = True supports_gradient_checkpointing = True + _supports_static_cache = True + _supports_cache_class = True _no_split_modules = ["MT5Block"] _keep_in_fp32_modules = ["wo"] @@ -891,7 +929,7 @@ def __init__(self, config, embed_tokens=None): self.is_decoder = config.is_decoder self.block = nn.ModuleList( - [MT5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + [MT5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)] ) self.final_layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -967,6 +1005,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): # Model parallel if self.model_parallel: @@ -993,6 +1032,13 @@ def forward( err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + if inputs_embeds is None: if self.embed_tokens is None: raise ValueError("You have to initialize the model with valid token embeddings") @@ -1000,23 +1046,56 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if use_cache is True: if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if use_cache or past_key_values is not None: + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + past_key_values_length = 0 + if cache_position is not None: + past_key_values_length = cache_position[0] + elif past_key_values is not None: + past_key_values_length = past_key_values.get_seq_length() + + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) if attention_mask is None: + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + else: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -1031,17 +1110,9 @@ def forward( else: encoder_extended_attention_mask = None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None @@ -1050,15 +1121,15 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] # Model parallel if self.model_parallel: torch.cuda.set_device(hidden_states.device) # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) + if causal_mask is not None: + causal_mask = causal_mask.to(hidden_states.device) if position_bias is not None: position_bias = position_bias.to(hidden_states.device) if encoder_hidden_states is not None: @@ -1078,7 +1149,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -1088,20 +1159,24 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + return_dict, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, ) # layer_outputs is a tuple with: @@ -1109,7 +1184,7 @@ def forward( if use_cache is False: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), @@ -1117,9 +1192,6 @@ def forward( position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -1139,12 +1211,18 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, 
all_attentions, all_cross_attentions, @@ -1153,12 +1231,79 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + MT5_START_DOCSTRING = r""" @@ -1453,6 +1598,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" Returns: @@ -1532,6 +1678,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1684,6 +1831,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1778,6 +1926,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = decoder_outputs[0] @@ -1831,11 +1980,15 @@ def prepare_inputs_for_generation( cross_attn_head_mask=None, use_cache=None, encoder_outputs=None, + cache_position=None, **kwargs, ): # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + if isinstance(past_key_values, EncoderDecoderCache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + else: + past_length = past_key_values[0][0].shape[2] # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: @@ -1846,6 +1999,41 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, remove_prefix_length:] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_ids.shape[1] :] + + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. 
Ref: https://github.com/huggingface/transformers/pull/29114 + input_ids = input_ids.contiguous() + + if ( + isinstance(past_key_values, EncoderDecoderCache) + and ( + isinstance(past_key_values.self_attention_cache, StaticCache) + or isinstance(past_key_values.cross_attention_cache, StaticCache) + ) + and decoder_attention_mask is not None + and decoder_attention_mask.ndim == 2 + ): + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + dtype = self.proj_out.weight.dtype + min_dtype = torch.finfo(dtype).min + + decoder_attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + decoder_attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.self_attention_cache.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + return { "decoder_input_ids": input_ids, "past_key_values": past_key_values, @@ -1856,6 +2044,7 @@ def prepare_inputs_for_generation( "decoder_attention_mask": decoder_attention_mask, "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, + "cache_position": cache_position, } # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_decoder_input_ids_from_labels diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 1087c55a9523..a36a2796c95c 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1144,7 +1144,7 @@ def forward( if self.model_parallel: torch.cuda.set_device(hidden_states.device) # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: + if causal_mask is not None: causal_mask = causal_mask.to(hidden_states.device) if position_bias is not None: position_bias = position_bias.to(hidden_states.device) @@ -1434,8 +1434,8 @@ def _update_causal_mask( return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. It is used to update the - cache in the correct position and to infer the complete sequence length. + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. 
""" T5_ENCODER_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 3271689540b9..920292ed124a 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -51,6 +52,60 @@ _CHECKPOINT_FOR_DOC = "google/umt5-small" +# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + # Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->UMT5 class UMT5LayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -154,7 +209,7 @@ class UMT5Attention(nn.Module): T5's attention using relative_attention_bias. 
""" - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -165,6 +220,13 @@ def __init__(self, config, has_relative_attention_bias=False): self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -248,25 +310,35 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, ): - is_cross_attention = encoder_hidden_states is not None batch_size, seq_length = hidden_states.shape[:2] - # use encoder_hidden_states if cross attention + # get past key value + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if encoder_hidden_states is not None: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + + # get key/value states current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - # checking that the `sequence_length` of the `past_key_value` is the same as the he provided - # `encoder_hidden_states` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + if encoder_hidden_states is not None and past_key_value and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] else: key_states = self._shape(self.k(current_states)) value_states = self._shape(self.v(current_states)) - if past_key_value is not None and not is_cross_attention: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not encoder_hidden_states is not None else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) query_states = self._shape(self.q(hidden_states)) attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) @@ -275,7 +347,7 @@ def forward( if self.has_relative_attention_bias: query_length = seq_length if past_key_value is not None: - query_length += past_key_value[0].shape[2] + query_length += past_key_value.get_seq_length() position_bias = self.compute_bias(query_length, key_states.size(2), device=attention_scores.device) 
else: position_bias = torch.zeros( @@ -289,16 +361,6 @@ def forward( if attention_mask is not None: position_bias = position_bias + attention_mask # (batch_size, n_heads, seq_length, key_length) - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - attention_scores += position_bias # (batch_size, n_heads, seq_length, key_length) attn_weights = nn.functional.softmax(attention_scores.float(), dim=-1).type_as(attention_scores) @@ -317,9 +379,9 @@ def forward( class UMT5LayerSelfAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.SelfAttention = UMT5Attention(config, has_relative_attention_bias=True) + self.SelfAttention = UMT5Attention(config, has_relative_attention_bias=True, layer_idx=layer_idx) self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -329,6 +391,7 @@ def forward( attention_mask=None, layer_head_mask=None, past_key_value=None, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -336,6 +399,7 @@ def forward( attention_mask=attention_mask, layer_head_mask=layer_head_mask, past_key_value=past_key_value, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -343,9 +407,9 @@ def forward( class UMT5LayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.EncDecAttention = UMT5Attention(config, has_relative_attention_bias=False) + self.EncDecAttention = UMT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -356,6 +420,7 @@ def forward( attention_mask=None, layer_head_mask=None, past_key_value=None, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -364,6 +429,7 @@ def forward( attention_mask=attention_mask, layer_head_mask=layer_head_mask, past_key_value=past_key_value, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -371,13 +437,13 @@ def forward( class UMT5Block(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder self.layer = nn.ModuleList() - self.layer.append(UMT5LayerSelfAttention(config)) + self.layer.append(UMT5LayerSelfAttention(config, layer_idx=layer_idx)) if self.is_decoder: - self.layer.append(UMT5LayerCrossAttention(config)) + 
self.layer.append(UMT5LayerCrossAttention(config, layer_idx=layer_idx)) self.layer.append(UMT5LayerFF(config)) @@ -392,16 +458,14 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): - # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - hidden_states, self_attn_weights, present_key_value = self.layer[0]( hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) # clamp inf values to enable fp16 training @@ -411,18 +475,17 @@ def forward( hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None hidden_states, cross_attn_weights, cross_attn_present_key_value = self.layer[1]( hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -430,8 +493,6 @@ def forward( clamp_value = torch.where(torch.isinf(hidden_states).any(), max_dtype - 1000, max_dtype) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - present_key_value += cross_attn_present_key_value - # Apply Feed Forward layer hidden_states = self.layer[-1](hidden_states) @@ -443,7 +504,7 @@ def forward( outputs = ( hidden_states, - present_key_value, + past_key_value, ) if output_attentions: @@ -480,6 +541,8 @@ class UMT5PreTrainedModel(PreTrainedModel): config_class = UMT5Config base_model_prefix = "transformer" supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = True _no_split_modules = ["UMT5Block"] _keep_in_fp32_modules = ["wo"] @@ -593,7 +656,7 @@ def __init__(self, config, embed_tokens=None): super().__init__(config) self.embed_tokens = embed_tokens self.is_decoder = config.is_decoder - self.block = nn.ModuleList([UMT5Block(config) for i in range(config.num_layers)]) + self.block = nn.ModuleList([UMT5Block(config, layer_idx=i) for i in range(config.num_layers)]) self.final_layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -621,6 +684,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -643,6 +707,13 @@ def forward( err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + if inputs_embeds is None: if self.embed_tokens is None: raise ValueError("You have to initialize the model with valid token embeddings") @@ -650,28 +721,58 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if use_cache is True: if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones( - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if use_cache or past_key_values is not None: + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + past_key_values_length = 0 + if cache_position is not None: + past_key_values_length = cache_position[0] + elif past_key_values is not None: + past_key_values_length = past_key_values.get_seq_length() + + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + if attention_mask is None: + # required mask seq length can be calculated via length of past + mask_seq_length = ( + past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length + ) + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + else: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -684,24 +785,16 @@ def forward( else: encoder_extended_attention_mask = None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.is_decoder else None hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] @@ -712,7 +805,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, encoder_hidden_states, encoder_extended_attention_mask, layer_head_mask, @@ -720,24 +813,26 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - present_key_value_states += (layer_outputs[1],) + next_decoder_cache = layer_outputs[1] if output_attentions: all_attentions += (layer_outputs[2],) @@ -751,12 +846,18 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -765,12 +866,79 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) + # Copied from 
transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + UMT5_START_DOCSTRING = r""" @@ -884,6 +1052,9 @@ def forward( more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. 
""" UMT5_ENCODER_INPUTS_DOCSTRING = r""" @@ -1021,6 +1192,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" Returns: @@ -1083,6 +1255,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1196,6 +1369,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1313,11 +1487,15 @@ def prepare_inputs_for_generation( cross_attn_head_mask=None, use_cache=None, encoder_outputs=None, + cache_position=None, **kwargs, ): # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + if isinstance(past_key_values, EncoderDecoderCache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + else: + past_length = past_key_values[0][0].shape[2] # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: @@ -1328,6 +1506,41 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, remove_prefix_length:] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_ids.shape[1] :] + + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. 
Ref: https://github.com/huggingface/transformers/pull/29114 + input_ids = input_ids.contiguous() + + if ( + isinstance(past_key_values, EncoderDecoderCache) + and ( + isinstance(past_key_values.self_attention_cache, StaticCache) + or isinstance(past_key_values.cross_attention_cache, StaticCache) + ) + and decoder_attention_mask is not None + and decoder_attention_mask.ndim == 2 + ): + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + dtype = self.proj_out.weight.dtype + min_dtype = torch.finfo(dtype).min + + decoder_attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + decoder_attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.self_attention_cache.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + return { "decoder_input_ids": input_ids, "past_key_values": past_key_values, @@ -1338,6 +1551,7 @@ def prepare_inputs_for_generation( "decoder_attention_mask": decoder_attention_mask, "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, + "cache_position": cache_position, } # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_decoder_input_ids_from_labels diff --git a/tests/models/longt5/test_modeling_longt5.py b/tests/models/longt5/test_modeling_longt5.py index c0cf21b2369d..93b0556ff120 100644 --- a/tests/models/longt5/test_modeling_longt5.py +++ b/tests/models/longt5/test_modeling_longt5.py @@ -31,6 +31,7 @@ if is_torch_available(): import torch + import torch.nn.functional as F from transformers import ( MODEL_FOR_QUESTION_ANSWERING_MAPPING, @@ -574,6 +575,41 @@ def test_decoder_model_past_with_3d_attn_mask(self): lm_labels, ) + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + decoder_input_ids=input_ids, + input_ids=input_dict["input_ids"][:3], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + input_ids=input_dict["input_ids"][:1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) diff --git a/tests/models/mt5/test_modeling_mt5.py b/tests/models/mt5/test_modeling_mt5.py index ec6ec6cd85c6..d457ae338096 100644 --- a/tests/models/mt5/test_modeling_mt5.py +++ b/tests/models/mt5/test_modeling_mt5.py @@ -40,6 +40,7 @@ if is_torch_available(): import torch + import torch.nn.functional as F from transformers import ( 
AutoModelForSeq2SeqLM, @@ -711,6 +712,41 @@ def flatten_output(output): # (Even with this call, there are still memory leak by ~0.04MB) self.clear_torch_jit_class_registry() + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + decoder_input_ids=input_ids, + input_ids=input_dict["input_ids"][:3], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + input_ids=input_dict["input_ids"][:1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + def test_config(self): self.config_tester.run_common_tests() diff --git a/tests/models/umt5/test_modeling_umt5.py b/tests/models/umt5/test_modeling_umt5.py index 2bb841e65e65..cfb9d2be4a48 100644 --- a/tests/models/umt5/test_modeling_umt5.py +++ b/tests/models/umt5/test_modeling_umt5.py @@ -41,6 +41,7 @@ if is_torch_available(): import torch + import torch.nn.functional as F from transformers import ( AutoTokenizer, @@ -479,6 +480,41 @@ def test_inputs_embeds(self): with torch.no_grad(): model(**inputs)[0] + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + decoder_input_ids=input_ids, + input_ids=input_dict["input_ids"][:3], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + input_ids=input_dict["input_ids"][:1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + def test_with_sequence_classification_head(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_with_sequence_classification_head(*config_and_inputs) From 417dd6d718d5e5290deaf3b02a3bb385cf37a838 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 30 Sep 2024 14:19:55 +0200 Subject: [PATCH 05/30] style --- 
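
Note: a minimal usage sketch of how the cache refactor in this series is expected to be driven at generation time. It assumes a UMT5 checkpoint such as google/umt5-small and the EncoderDecoderCache/DynamicCache utilities these patches wire in; the checkpoint name and prompt are illustrative placeholders, not something exercised by the patches themselves.

import torch
from transformers import AutoTokenizer, UMT5ForConditionalGeneration
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")  # assumed checkpoint
model = UMT5ForConditionalGeneration.from_pretrained("google/umt5-small").eval()

inputs = tokenizer("The cache refactor keeps decoding fast.", return_tensors="pt")

# Pass an EncoderDecoderCache up front instead of the legacy tuple-of-tuples;
# the decoder fills it layer by layer via Cache.update(...) and, after the first
# decoding step, reuses the cross-attention key/value states it already stored.
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
with torch.no_grad():
    generated = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=8)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))

Passing a legacy tuple still works for now but triggers the deprecation warning added in this series; wrapping it with EncoderDecoderCache.from_legacy_cache(past_key_values) keeps the new code path.
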
src/transformers/models/mt5/modeling_mt5.py | 4 +++- src/transformers/models/umt5/modeling_umt5.py | 1 + tests/models/mt5/test_modeling_mt5.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 92efa230b526..0c1359ec1632 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1081,7 +1081,9 @@ def forward( if attention_mask is None: # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length + mask_seq_length = ( + past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length + ) attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) if self.config.is_decoder: diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 920292ed124a..8a34fbff06ff 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -24,6 +24,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, diff --git a/tests/models/mt5/test_modeling_mt5.py b/tests/models/mt5/test_modeling_mt5.py index d457ae338096..4b32b5cec789 100644 --- a/tests/models/mt5/test_modeling_mt5.py +++ b/tests/models/mt5/test_modeling_mt5.py @@ -746,7 +746,7 @@ def test_custom_4d_attention_mask(self): normalized_0 = F.softmax(out_last_tokens) normalized_1 = F.softmax(out_shared_prefix_last_tokens) torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) - + def test_config(self): self.config_tester.run_common_tests() From 814a40527d73c3fe167e833d988315fd65393a62 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 30 Sep 2024 14:47:32 +0200 Subject: [PATCH 06/30] udop, pix2struct --- .../models/pix2struct/modeling_pix2struct.py | 404 ++++++++++++----- src/transformers/models/udop/modeling_udop.py | 426 +++++++++++++----- tests/models/udop/test_modeling_udop.py | 73 +++ 3 files changed, 680 insertions(+), 223 deletions(-) diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 94d882c80566..27f628765a46 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -22,6 +22,8 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, @@ -49,6 +51,60 @@ _CONFIG_FOR_DOC = "Pix2StructConfig" +# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, 
key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + # Adapted from transformers.models.t5.modeling_t5.T5LayerNorm with T5->Pix2Struct class Pix2StructLayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -354,6 +410,8 @@ class Pix2StructPreTrainedModel(PreTrainedModel): """ config_class = Pix2StructConfig + _supports_cache_class = True + _supports_static_cache = True @property def dummy_inputs(self): @@ -672,7 +730,9 @@ def forward(self, hidden_states): class Pix2StructTextAttention(nn.Module): - def __init__(self, config: Pix2StructTextConfig, has_relative_attention_bias=False): + def __init__( + self, config: Pix2StructTextConfig, has_relative_attention_bias=False, layer_idx: Optional[int] = None + ): super().__init__() self.has_relative_attention_bias = has_relative_attention_bias self.relative_attention_num_buckets = config.relative_attention_num_buckets @@ -682,6 +742,13 @@ def __init__(self, config: Pix2StructTextConfig, has_relative_attention_bias=Fal self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) # Mesh TensorFlow initialization to avoid scaling before softmax self.query = nn.Linear(self.hidden_size, self.hidden_size, bias=False) @@ -772,6 +839,7 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). @@ -782,13 +850,7 @@ def forward( batch_size, seq_length = hidden_states.shape[:2] real_seq_length = seq_length - - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + real_seq_length += cache_position[0] if query_length is None else query_length key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] @@ -796,44 +858,35 @@ def to_projection_shape(states): """projection""" return states.contiguous().view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = to_projection_shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = to_projection_shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = to_projection_shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - # get query states # (batch_size, n_heads, seq_length, dim_per_head) query_states = to_projection_shape(self.query(hidden_states)) + # get past key value + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if key_value_states is not None: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + # get key/value states - key_states = project( - hidden_states, self.key, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.value, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + current_states = key_value_states if key_value_states is not None else hidden_states + if key_value_states is not None and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = to_projection_shape(self.key(current_states)) + value_states = to_projection_shape(self.value(current_states)) + if past_key_value is not None: + # save all key/value_states to cache to be 
re-used for fast auto-regressive generation + cache_position = cache_position if not key_value_states is not None else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) # compute scores scores = torch.matmul( @@ -850,11 +903,6 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): else: position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] - if mask is not None: position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) @@ -882,19 +930,20 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): attn_output = self.output(attn_output) - present_key_value_state = (key_states, value_states) if use_cache else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output,) + (past_key_value,) + (position_bias,) if output_attentions: outputs = outputs + (attn_weights,) return outputs -# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,self.SelfAttention->self.attention,config.d_model->config.hidden_size +# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerSelfAttention->Pix2StructTextLayerSelfAttention,self.SelfAttention->self.attention,config.d_model->config.hidden_size class Pix2StructTextLayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() - self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=has_relative_attention_bias) + self.attention = Pix2StructTextAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -907,6 +956,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.attention( @@ -917,17 +967,18 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them return outputs -# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,self.EncDecAttention->self.attention,config.d_model->config.hidden_size +# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerCrossAttention->Pix2StructTextLayerCrossAttention,self.EncDecAttention->self.attention,config.d_model->config.hidden_size class Pix2StructTextLayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.attention = 
Pix2StructTextAttention(config, has_relative_attention_bias=False) + self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -942,6 +993,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.attention( @@ -954,6 +1006,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -961,11 +1014,13 @@ def forward( class Pix2StructTextBlock(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.self_attention = Pix2StructTextLayerSelfAttention( - config, has_relative_attention_bias=has_relative_attention_bias + config, + has_relative_attention_bias=has_relative_attention_bias, + layer_idx=layer_idx, ) self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(config) @@ -986,30 +1041,17 @@ def forward( use_cache=False, output_attentions=False, return_dict=True, + cache_position=None, ): - if past_key_value is not None: - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - self_attention_outputs = self.self_attention( hidden_states, attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states, present_key_value_state = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights @@ -1023,10 +1065,7 @@ def forward( if do_cross_attention: # the actual query length is unknown for cross attention # if using past key value states. 
Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None + query_length = cache_position[0] cross_attention_outputs = self.encoder_decoder_attention( hidden_states, @@ -1034,10 +1073,11 @@ def forward( attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, query_length=query_length, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = cross_attention_outputs[0] @@ -1046,10 +1086,6 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -1064,7 +1100,7 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + outputs = outputs + (past_key_value,) + attention_outputs else: outputs = outputs + attention_outputs @@ -1186,6 +1222,9 @@ def forward( more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ PIX2STRUCT_INPUTS_DOCSTRING = r""" @@ -1292,7 +1331,10 @@ def __init__(self, config): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.layer = nn.ModuleList( - [Pix2StructTextBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + [ + Pix2StructTextBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) + for i in range(config.num_layers) + ] ) self.final_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -1363,6 +1405,7 @@ def forward( output_hidden_states: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Union[Tuple[torch.FloatTensor, ...], CausalLMOutputWithCrossAttentions]: r""" @@ -1404,24 +1447,54 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if use_cache or past_key_values is not None: + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. 
" + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + past_key_values_length = 0 + if cache_position is not None: + past_key_values_length = cache_position[0] + elif past_key_values is not None: + past_key_values_length = past_key_values.get_seq_length() + + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - if encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones( - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + # required mask seq length can be calculated via length of past + mask_seq_length = ( + past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length ) + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.layer) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + else: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -1437,7 +1510,6 @@ def forward( # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions) else None @@ -1446,7 +1518,7 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.layer, past_key_values)): + for i, layer_module in enumerate(self.layer): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] if output_hidden_states: @@ -1461,7 +1533,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -1471,20 +1543,22 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - 
attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) # layer_outputs is a tuple with: @@ -1492,7 +1566,7 @@ def forward( if use_cache is False: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), @@ -1500,9 +1574,6 @@ def forward( position_bias = layer_outputs[2] if encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -1526,13 +1597,19 @@ def forward( loss = loss_fct(logits.contiguous().view(-1, logits.size(-1)), labels.contiguous().view(-1)) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ loss, logits, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -1542,12 +1619,79 @@ def forward( return CausalLMOutputWithCrossAttentions( loss=loss, logits=logits, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + @add_start_docstrings( "A conditional generation model with a language modeling head. Can be used for sequence generation tasks.", @@ -1614,6 +1758,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" Returns: @@ -1722,6 +1867,7 @@ def forward( output_hidden_states=output_hidden_states, labels=labels, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1751,6 +1897,7 @@ def prepare_inputs_for_generation( cross_attn_head_mask=None, use_cache=None, encoder_outputs=None, + cache_position=None, **kwargs, ): if decoder_attention_mask is None: @@ -1758,7 +1905,10 @@ def prepare_inputs_for_generation( # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + if isinstance(past_key_values, EncoderDecoderCache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + else: + past_length = past_key_values[0][0].shape[2] # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: @@ -1769,6 +1919,41 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, remove_prefix_length:] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_ids.shape[1] :] + + # The `contiguous()` here is necessary to have a static stride during decoding. 
torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + input_ids = input_ids.contiguous() + + if ( + isinstance(past_key_values, EncoderDecoderCache) + and ( + isinstance(past_key_values.self_attention_cache, StaticCache) + or isinstance(past_key_values.cross_attention_cache, StaticCache) + ) + and decoder_attention_mask is not None + and decoder_attention_mask.ndim == 2 + ): + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + dtype = self.proj_out.weight.dtype + min_dtype = torch.finfo(dtype).min + + decoder_attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + decoder_attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.self_attention_cache.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + return { "flattened_patches": flattened_patches, "decoder_input_ids": input_ids, @@ -1780,4 +1965,5 @@ def prepare_inputs_for_generation( "decoder_head_mask": decoder_head_mask, "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, + "cache_position": cache_position, } diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 972248daaae5..262ad304b100 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -34,6 +34,8 @@ ) from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -50,6 +52,60 @@ _CONFIG_FOR_DOC = "UdopConfig" +# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + UDOP_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -153,6 +209,9 @@ more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ @@ -410,6 +469,8 @@ class UdopPreTrainedModel(PreTrainedModel): config_class = UdopConfig base_model_prefix = "transformer" supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = True _keep_in_fp32_modules = ["wo"] def _init_weights(self, module): @@ -597,7 +658,12 @@ def forward(self, hidden_states): # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->Udop class UdopAttention(nn.Module): - def __init__(self, config: UdopConfig, has_relative_attention_bias=False): + def __init__( + self, + config: UdopConfig, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + ): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -608,6 +674,13 @@ def __init__(self, config: UdopConfig, has_relative_attention_bias=False): self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -712,6 +785,7 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). @@ -722,13 +796,7 @@ def forward( batch_size, seq_length = hidden_states.shape[:2] real_seq_length = seq_length - - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + real_seq_length += cache_position[0] if query_length is None else query_length key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] @@ -740,43 +808,34 @@ def unshape(states): """reshape""" return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - # get query states query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + # get past key value + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if key_value_states is not None: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + current_states = key_value_states if key_value_states is not None else hidden_states + if key_value_states is not None and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = shape(self.k(current_states)) + value_states = shape(self.v(current_states)) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not key_value_states is not None else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) # compute scores scores = torch.matmul( @@ -809,12 +868,9 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) + # (batch_size, n_heads, seq_length, 
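[Editor's note] To make the new cache path in `UdopAttention.forward` concrete, here is a hedged sketch using the `transformers.cache_utils` classes this patch imports: self-attention keys/values are appended step by step via `update()`, while cross-attention keys/values are written once and then read back through the `is_updated` flag. The tensor sizes and layer index are invented for illustration.

import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

layer_idx, n_heads, head_dim = 0, 2, 4
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())

# self-attention: one new token per step, states shaped (batch, heads, seq, head_dim)
for step in range(3):
    k = torch.randn(1, n_heads, 1, head_dim)
    v = torch.randn(1, n_heads, 1, head_dim)
    cache_position = torch.tensor([step])
    k_all, v_all = cache.self_attention_cache.update(
        k, v, layer_idx, {"cache_position": cache_position}
    )
print(k_all.shape)  # torch.Size([1, 2, 3, 4]): keys for all 3 decoded tokens

# cross-attention: computed from the encoder states on the first step only,
# then flagged as updated so later steps reuse the cached tensors directly
enc_k = torch.randn(1, n_heads, 7, head_dim)
enc_v = torch.randn(1, n_heads, 7, head_dim)
cache.cross_attention_cache.update(enc_k, enc_v, layer_idx)
cache.is_updated[layer_idx] = True
print(cache.cross_attention_cache.key_cache[layer_idx].shape)  # torch.Size([1, 2, 7, 4])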
key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: @@ -823,19 +879,20 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output, past_key_value, position_bias) if output_attentions: - outputs = outputs + (attn_weights,) + outputs += (attn_weights,) return outputs # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Udop class UdopLayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() - self.SelfAttention = UdopAttention(config, has_relative_attention_bias=has_relative_attention_bias) + self.SelfAttention = UdopAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -848,6 +905,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -858,6 +916,7 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -866,9 +925,9 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Udop class UdopLayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.EncDecAttention = UdopAttention(config, has_relative_attention_bias=False) + self.EncDecAttention = UdopAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -883,6 +942,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -895,6 +955,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -903,13 +964,17 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5Block with T5->Udop class UdopBlock(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder self.layer = nn.ModuleList() - self.layer.append(UdopLayerSelfAttention(config, 
has_relative_attention_bias=has_relative_attention_bias)) + self.layer.append( + UdopLayerSelfAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) + ) if self.is_decoder: - self.layer.append(UdopLayerCrossAttention(config)) + self.layer.append(UdopLayerCrossAttention(config, layer_idx=layer_idx)) self.layer.append(UdopLayerFF(config)) @@ -927,34 +992,19 @@ def forward( use_cache=False, output_attentions=False, return_dict=True, + cache_position=None, ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - self_attention_outputs = self.layer[0]( hidden_states, attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] + hidden_states, self_attn_present_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -970,10 +1020,7 @@ def forward( if do_cross_attention: # the actual query length is unknown for cross attention # if using past key value states. 
Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None + query_length = cache_position[0] cross_attention_outputs = self.layer[1]( hidden_states, @@ -981,7 +1028,7 @@ def forward( attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, query_length=query_length, use_cache=use_cache, output_attentions=output_attentions, @@ -997,10 +1044,6 @@ def forward( ) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -1019,7 +1062,7 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + outputs = outputs + (past_key_value,) + attention_outputs else: outputs = outputs + attention_outputs @@ -1285,7 +1328,7 @@ def __init__(self, config, embed_tokens=None, embed_patches=None): self.num_layers = config.num_layers self.block = nn.ModuleList( - [UdopBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(self.num_layers)] + [UdopBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(self.num_layers)] ) self.final_layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) @@ -1337,6 +1380,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -1398,26 +1442,57 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if use_cache is True: assert self.is_decoder, "`use_cache` can only be set to `True` if {} is used as a decoder".format(self) - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) - if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones( - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if use_cache or past_key_values is not None: + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
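[Editor's note] As a usage note for the deprecation path above, the sketch below relies on the `EncoderDecoderCache` helpers imported by this patch to show how a legacy per-layer tuple cache round-trips through the new class and how the cached length feeds `cache_position`; the tensor shapes are illustrative only.

import torch
from transformers.cache_utils import EncoderDecoderCache

# legacy layout: one tuple per decoder layer, (self_k, self_v, cross_k, cross_v)
legacy = tuple(
    (
        torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4),   # 3 decoded tokens so far
        torch.randn(1, 2, 7, 4), torch.randn(1, 2, 7, 4),   # 7 encoder positions
    )
    for _ in range(2)
)

cache = EncoderDecoderCache.from_legacy_cache(legacy)
past_len = cache.get_seq_length()                       # 3
cache_position = torch.arange(past_len, past_len + 1)   # positions of the new token(s)
print(past_len, cache_position)                         # 3 tensor([3])

# the stack converts back when the caller passed the legacy format
assert len(cache.to_legacy_cache()) == 2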
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + past_key_values_length = 0 + if cache_position is not None: + past_key_values_length = cache_position[0] + elif past_key_values is not None: + past_key_values_length = past_key_values.get_seq_length() + + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + if attention_mask is None: + # required mask seq length can be calculated via length of past + mask_seq_length = ( + past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length + ) + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + else: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min if self.is_decoder and encoder_attention_mask is not None: encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) @@ -1426,7 +1501,6 @@ def forward( # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None @@ -1435,34 +1509,35 @@ def forward( position_bias = None else: position_bias = self.relative_bias(attention_mask=attention_mask, bbox=bbox) - position_bias = position_bias + extended_attention_mask + position_bias = position_bias + causal_mask encoder_decoder_position_bias = None hidden_states = inputs_embeds hidden_states = self.dropout(hidden_states) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=head_mask[i], - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) if use_cache is False: # MP fixes layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = 
layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention weights), @@ -1471,9 +1546,6 @@ def forward( position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = all_attentions + (layer_outputs[2],) # We keep only self-attention weights for now @@ -1487,13 +1559,19 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, attention_mask, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -1504,12 +1582,79 @@ def forward( return BaseModelOutputWithAttentionMask( last_hidden_state=hidden_states, attention_mask=attention_mask, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
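[Editor's note] To illustrate the dispatch inside `_update_causal_mask` (copied from Llama), the short sketch below shows how the key-value extent of the 4D mask is chosen; the `DynamicCache` contents are made up, and with a `StaticCache` the branch noted in the comment would be taken instead.

import torch
from transformers.cache_utils import DynamicCache, StaticCache

past = DynamicCache()
past.update(torch.randn(1, 2, 5, 4), torch.randn(1, 2, 5, 4), 0)  # 5 tokens already cached

attention_mask = None        # e.g. no padding on the decoder side
sequence_length = 1          # one new decoder token
past_seen_tokens = past.get_seq_length()            # 5
using_static_cache = isinstance(past, StaticCache)  # False here

if using_static_cache:
    target_length = past.get_max_length()           # mask padded to the full static cache
else:
    target_length = (
        attention_mask.shape[-1]
        if isinstance(attention_mask, torch.Tensor)
        else past_seen_tokens + sequence_length + 1
    )
print(target_length)  # 7: past tokens + current token + 1 slack, as in the copied helper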
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + @add_start_docstrings( "The bare UDOP encoder-decoder Transformer outputting raw hidden-states without any specific head on top.", @@ -1583,6 +1728,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[Tensor, ...]: r""" Returns: @@ -1652,6 +1798,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1758,6 +1905,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[Tensor, ...]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1836,6 +1984,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = decoder_outputs[0] @@ -1878,11 +2027,59 @@ def prepare_inputs_for_generation( cross_attn_head_mask=None, use_cache=None, encoder_outputs=None, + cache_position=None, **kwargs, ): - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + if isinstance(past_key_values, EncoderDecoderCache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + else: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_ids.shape[1] :] + + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. 
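[Editor's note] Since `prepare_inputs_for_generation` now trims `decoder_input_ids` against the cache instead of always keeping only the last token, here is a small torch-only sketch of that bookkeeping; the values are illustrative, and the real method also handles the legacy tuple cache format.

import torch

# pretend state during decoding
input_ids = torch.tensor([[0, 11, 12, 13, 14]])   # full decoder sequence so far
past_length = 4                                   # tokens already in the self-attention cache
cache_position = None

if input_ids.shape[1] > past_length:
    remove_prefix_length = past_length            # keep only the not-yet-cached suffix
else:
    remove_prefix_length = input_ids.shape[1] - 1 # old behaviour: keep the final id
input_ids = input_ids[:, remove_prefix_length:]

if cache_position is None:
    cache_position = torch.arange(past_length, past_length + input_ids.shape[1])

print(input_ids, cache_position)    # tensor([[14]]) tensor([4])
input_ids = input_ids.contiguous()  # static stride, avoids torchdynamo recompilation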
Ref: https://github.com/huggingface/transformers/pull/29114 + input_ids = input_ids.contiguous() + + if ( + isinstance(past_key_values, EncoderDecoderCache) + and ( + isinstance(past_key_values.self_attention_cache, StaticCache) + or isinstance(past_key_values.cross_attention_cache, StaticCache) + ) + and attention_mask is not None + and attention_mask.ndim == 2 + ): + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + dtype = self.proj_out.weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.self_attention_cache.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) return { "decoder_input_ids": input_ids, @@ -1896,6 +2093,7 @@ def prepare_inputs_for_generation( "bbox": kwargs.get("bbox", None), "pixel_values": kwargs.get("pixel_values", None), "visual_bbox": kwargs.get("visual_bbox", None), + "cache_position": cache_position, } # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._reorder_cache diff --git a/tests/models/udop/test_modeling_udop.py b/tests/models/udop/test_modeling_udop.py index a3ae498606a3..1fd0f01bea38 100644 --- a/tests/models/udop/test_modeling_udop.py +++ b/tests/models/udop/test_modeling_udop.py @@ -37,6 +37,7 @@ if is_torch_available(): import torch + import torch.nn.functional as F from transformers import UdopEncoderModel, UdopForConditionalGeneration, UdopModel, UdopProcessor @@ -365,6 +366,43 @@ def test_forward_signature(self): expected_arg_names = sorted(expected_arg_names) self.assertListEqual(sorted(arg_names[: len(expected_arg_names)]), expected_arg_names) + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + decoder_input_ids=input_ids, + input_ids=input_dict["input_ids"][:3], + bbox=input_dict["bbox"][:3], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + input_ids=input_dict["input_ids"][:1], + bbox=input_dict["bbox"][:1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + @unittest.skip( "Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!" 
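[Editor's note] The new `test_custom_4d_attention_mask` compares a regular 3-sequence batch against a single packed row that shares a 3-token prefix. Below is a standalone sketch of the kind of 4D `decoder_attention_mask` involved, using the 1 = attend / 0 = masked convention; the actual data comes from `_get_custom_4d_mask_test_data` in the common test utilities and may differ in detail, and on some code paths a 4D mask is expected in additive ("inverted") form instead, as the helper above notes.

import torch

prefix_len, n_branches = 3, 3
total_len = prefix_len + n_branches              # 3 shared tokens + 3 branch tokens in one row

allowed = torch.tril(torch.ones(total_len, total_len, dtype=torch.int64))
for i in range(prefix_len, total_len):
    allowed[i, prefix_len:] = 0                  # branch tokens must not see each other...
    allowed[i, i] = 1                            # ...but each still sees itself
mask_shared_prefix = allowed[None, None, :, :]   # shape (1, 1, 6, 6)

print(mask_shared_prefix[0, 0])
# rows 3-5 attend to the shared prefix and themselves only, so the last three
# positions of the packed row should reproduce the per-sequence last-token logits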
) @@ -534,6 +572,41 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + decoder_input_ids=input_ids, + input_ids=input_dict["input_ids"][:3], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + input_ids=input_dict["input_ids"][:1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + @unittest.skip( "Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!" ) From 0bc8b54d5714de510ffb04a90df7a864ee263ff7 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 30 Sep 2024 16:29:46 +0200 Subject: [PATCH 07/30] more models --- src/transformers/models/mt5/modeling_mt5.py | 1 + .../models/pop2piano/modeling_pop2piano.py | 392 +++++++++++----- .../modeling_switch_transformers.py | 437 +++++++++++++----- tests/models/mt5/test_modeling_mt5.py | 113 +++-- .../test_modeling_switch_transformers.py | 36 ++ 5 files changed, 675 insertions(+), 304 deletions(-) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 0c1359ec1632..e676c52fb747 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -815,6 +815,7 @@ class MT5PreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" is_parallelizable = True supports_gradient_checkpointing = True + _supports_quantized_cache = False # enc-dec models don't support yet _supports_static_cache = True _supports_cache_class = True _no_split_modules = ["MT5Block"] diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index c769cff3c454..43ae0dded424 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -25,6 +25,8 @@ from transformers.generation import GenerationConfig from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -64,6 +66,60 @@ _CHECKPOINT_FOR_DOC = "sweetcocoa/pop2piano" +# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: 
torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + POP2PIANO_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -135,6 +191,9 @@ more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. 
""" @@ -244,7 +303,12 @@ def forward(self, hidden_states): # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->Pop2Piano,t5->pop2piano class Pop2PianoAttention(nn.Module): - def __init__(self, config: Pop2PianoConfig, has_relative_attention_bias=False): + def __init__( + self, + config: Pop2PianoConfig, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + ): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -255,6 +319,13 @@ def __init__(self, config: Pop2PianoConfig, has_relative_attention_bias=False): self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -359,6 +430,7 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). @@ -369,13 +441,7 @@ def forward( batch_size, seq_length = hidden_states.shape[:2] real_seq_length = seq_length - - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + real_seq_length += cache_position[0] if query_length is None else query_length key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] @@ -387,43 +453,34 @@ def unshape(states): """reshape""" return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - # get query states query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + # get past key value + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if key_value_states is not None: + # after the first generated id, we can subsequently re-use all key/value_states from cache + 
past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + current_states = key_value_states if key_value_states is not None else hidden_states + if key_value_states is not None and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = shape(self.k(current_states)) + value_states = shape(self.v(current_states)) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not key_value_states is not None else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) # compute scores scores = torch.matmul( @@ -456,12 +513,9 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: @@ -470,19 +524,20 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output, past_key_value, position_bias) if output_attentions: - outputs = outputs + (attn_weights,) + outputs += (attn_weights,) return outputs # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Pop2Piano,t5->pop2piano class Pop2PianoLayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() - self.SelfAttention = Pop2PianoAttention(config, has_relative_attention_bias=has_relative_attention_bias) + self.SelfAttention = Pop2PianoAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -495,6 +550,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -505,6 +561,7 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, 
output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -513,9 +570,9 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Pop2Piano,t5->pop2piano class Pop2PianoLayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.EncDecAttention = Pop2PianoAttention(config, has_relative_attention_bias=False) + self.EncDecAttention = Pop2PianoAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -530,6 +587,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -542,6 +600,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -550,13 +609,17 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5Block with T5->Pop2Piano,t5->pop2piano class Pop2PianoBlock(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder self.layer = nn.ModuleList() - self.layer.append(Pop2PianoLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + self.layer.append( + Pop2PianoLayerSelfAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) + ) if self.is_decoder: - self.layer.append(Pop2PianoLayerCrossAttention(config)) + self.layer.append(Pop2PianoLayerCrossAttention(config, layer_idx=layer_idx)) self.layer.append(Pop2PianoLayerFF(config)) @@ -574,34 +637,19 @@ def forward( use_cache=False, output_attentions=False, return_dict=True, + cache_position=None, ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (key / value) for cross attention. 
' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - self_attention_outputs = self.layer[0]( hidden_states, attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] + hidden_states, self_attn_present_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -617,10 +665,7 @@ def forward( if do_cross_attention: # the actual query length is unknown for cross attention # if using past key value states. Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None + query_length = cache_position[0] cross_attention_outputs = self.layer[1]( hidden_states, @@ -628,7 +673,7 @@ def forward( attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, query_length=query_length, use_cache=use_cache, output_attentions=output_attentions, @@ -644,10 +689,6 @@ def forward( ) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -666,7 +707,7 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + outputs = outputs + (past_key_value,) + attention_outputs else: outputs = outputs + attention_outputs @@ -683,6 +724,8 @@ class Pop2PianoPreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" is_parallelizable = False supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = True _no_split_modules = ["Pop2PianoBlock"] _keep_in_fp32_modules = ["wo"] @@ -768,7 +811,10 @@ def __init__(self, config, embed_tokens=None): self.is_decoder = config.is_decoder self.block = nn.ModuleList( - [Pop2PianoBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + [ + Pop2PianoBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) + for i in range(config.num_layers) + ] ) self.final_layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -802,6 +848,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -824,6 +871,13 @@ def forward( err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or 
{err_msg_prefix}inputs_embeds") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + if inputs_embeds is None: if self.embed_tokens is None: raise ValueError("You have to initialize the model with valid token embeddings") @@ -831,28 +885,58 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if use_cache is True: if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones( - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if use_cache or past_key_values is not None: + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + past_key_values_length = 0 + if cache_position is not None: + past_key_values_length = cache_position[0] + elif past_key_values is not None: + past_key_values_length = past_key_values.get_seq_length() + + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + if attention_mask is None: + # required mask seq length can be calculated via length of past + mask_seq_length = ( + past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length + ) + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + else: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -865,17 +949,9 @@ def forward( else: encoder_extended_attention_mask = None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None @@ -884,7 +960,7 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] if output_hidden_states: @@ -894,7 +970,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -904,20 +980,22 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) # layer_outputs is a tuple with: @@ -925,7 +1003,7 @@ def forward( if use_cache is False: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), @@ -933,9 +1011,6 @@ def forward( position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -949,12 +1024,18 @@ def 
forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -963,12 +1044,79 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + class Pop2PianoConcatEmbeddingToMel(nn.Module): """Embedding Matrix for `composer` tokens.""" @@ -1121,6 +1269,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1176,6 +1325,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = decoder_outputs[0] @@ -1310,7 +1460,7 @@ def prepare_inputs_for_generation( encoder_outputs=None, **kwargs, ): - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: input_ids = input_ids[:, -1:] diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index c5797d4573b7..8a3a3879abdf 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -24,6 +24,8 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( MoEModelOutput, MoEModelOutputWithPastAndCrossAttentions, @@ -55,6 +57,60 @@ #################################################### +# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + def router_z_loss_func(router_logits: torch.Tensor) -> float: r""" Compute the router z-loss implemented in PyTorch. @@ -354,7 +410,12 @@ def forward(self, hidden_states, output_router_logits): # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->SwitchTransformers class SwitchTransformersAttention(nn.Module): - def __init__(self, config: SwitchTransformersConfig, has_relative_attention_bias=False): + def __init__( + self, + config: SwitchTransformersConfig, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + ): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -365,6 +426,13 @@ def __init__(self, config: SwitchTransformersConfig, has_relative_attention_bias self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -469,6 +537,7 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). @@ -479,13 +548,7 @@ def forward( batch_size, seq_length = hidden_states.shape[:2] real_seq_length = seq_length - - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + real_seq_length += cache_position[0] if query_length is None else query_length key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] @@ -497,43 +560,34 @@ def unshape(states): """reshape""" return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - # get query states query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + # get past key value + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if key_value_states is not None: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + current_states = key_value_states if key_value_states is not None else hidden_states + if key_value_states is not None and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = shape(self.k(current_states)) + value_states = shape(self.v(current_states)) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not key_value_states is not None else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) # compute scores scores = torch.matmul( @@ -566,12 +620,9 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) + # (batch_size, n_heads, seq_length, 
key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: @@ -580,20 +631,19 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output, past_key_value, position_bias) if output_attentions: - outputs = outputs + (attn_weights,) + outputs += (attn_weights,) return outputs # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->SwitchTransformers class SwitchTransformersLayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.SelfAttention = SwitchTransformersAttention( - config, has_relative_attention_bias=has_relative_attention_bias + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx ) self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -607,6 +657,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -617,6 +668,7 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -625,9 +677,11 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->SwitchTransformers class SwitchTransformersLayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.EncDecAttention = SwitchTransformersAttention(config, has_relative_attention_bias=False) + self.EncDecAttention = SwitchTransformersAttention( + config, has_relative_attention_bias=False, layer_idx=layer_idx + ) self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -642,6 +696,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -654,6 +709,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -661,16 +717,18 @@ def forward( class SwitchTransformersBlock(nn.Module): - def __init__(self, config, has_relative_attention_bias=False, is_sparse=False): + def __init__(self, config, has_relative_attention_bias=False, is_sparse=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder self.is_sparse = is_sparse self.layer = 
nn.ModuleList() self.layer.append( - SwitchTransformersLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias) + SwitchTransformersLayerSelfAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) ) if self.is_decoder: - self.layer.append(SwitchTransformersLayerCrossAttention(config)) + self.layer.append(SwitchTransformersLayerCrossAttention(config, layer_idx=layer_idx)) self.layer.append(SwitchTransformersLayerFF(config, is_sparse=self.is_sparse)) @@ -689,32 +747,17 @@ def forward( output_attentions=False, output_router_logits=True, return_dict=True, + cache_position=None, ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - self_attention_outputs = self.layer[0]( hidden_states, attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states, present_key_value_state = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights @@ -728,10 +771,7 @@ def forward( if do_cross_attention: # the actual query length is unknown for cross attention # if using past key value states. 
Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None + query_length = cache_position[0] cross_attention_outputs = self.layer[1]( hidden_states, @@ -739,10 +779,11 @@ def forward( attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, query_length=query_length, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = cross_attention_outputs[0] @@ -751,10 +792,6 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -774,11 +811,11 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + (router_tuple,) + outputs = outputs + (past_key_value,) + attention_outputs + (router_tuple,) else: outputs = outputs + attention_outputs + (router_tuple,) - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights), (router_tuple) + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights), (router_tuple) class SwitchTransformersPreTrainedModel(PreTrainedModel): @@ -790,6 +827,8 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel): config_class = SwitchTransformersConfig base_model_prefix = "switch_transformers" supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = True _no_split_modules = ["SwitchTransformersBlock"] @property @@ -896,7 +935,9 @@ def __init__(self, config, embed_tokens=None): is_sparse = (i % sparse_step == 1 or sparse_step == 1) if sparse_step > 0 else False self.block.append( - SwitchTransformersBlock(config, has_relative_attention_bias=bool(i == 0), is_sparse=is_sparse) + SwitchTransformersBlock( + config, has_relative_attention_bias=bool(i == 0), is_sparse=is_sparse, layer_idx=i + ) ) self.final_layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) @@ -929,6 +970,7 @@ def forward( output_hidden_states=None, output_router_logits=True, return_dict=None, + cache_position=None, ): use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -951,6 +993,13 @@ def forward( err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + if inputs_embeds is None: if self.embed_tokens is None: raise ValueError("You have to initialize the model with valid token embeddings") @@ -958,28 +1007,58 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if use_cache is True: if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones( - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if use_cache or past_key_values is not None: + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + past_key_values_length = 0 + if cache_position is not None: + past_key_values_length = cache_position[0] + elif past_key_values is not None: + past_key_values_length = past_key_values.get_seq_length() + + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + if attention_mask is None: + # required mask seq length can be calculated via length of past + mask_seq_length = ( + past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length + ) + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + else: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -992,17 +1071,9 @@ def forward( else: encoder_extended_attention_mask = None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_router_probs = () if output_router_logits else None @@ -1012,7 +1083,7 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] @@ -1023,7 +1094,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -1033,21 +1104,26 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + output_router_logits, + return_dict, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, ) router_probs = layer_outputs[-1] @@ -1058,7 +1134,7 @@ def forward( if use_cache is False: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), @@ -1066,9 +1142,6 @@ def forward( position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = 
all_attentions + (layer_outputs[3],) @@ -1085,12 +1158,18 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -1100,13 +1179,80 @@ def forward( ) return MoEModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, router_probs=all_router_probs, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + SWITCH_TRANSFORMERS_START_DOCSTRING = r""" @@ -1227,6 +1373,9 @@ def forward( should not be returned during inference. 
return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ SWITCH_TRANSFORMERS_ENCODER_INPUTS_DOCSTRING = r""" @@ -1354,6 +1503,7 @@ def forward( output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqMoEModelOutput]: r""" Returns: @@ -1434,6 +1584,7 @@ def forward( output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1534,6 +1685,7 @@ def forward( output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = True, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqMoEOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1617,6 +1769,7 @@ def forward( output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = decoder_outputs[0] @@ -1711,11 +1864,15 @@ def prepare_inputs_for_generation( cross_attn_head_mask=None, use_cache=None, encoder_outputs=None, + cache_position=None, **kwargs, ): # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + if isinstance(past_key_values, EncoderDecoderCache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + else: + past_length = past_key_values[0][0].shape[2] # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: @@ -1726,7 +1883,40 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, remove_prefix_length:] - output_router_logits = kwargs.get("output_router_logits", True) + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_ids.shape[1] :] + + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. 
Ref: https://github.com/huggingface/transformers/pull/29114 + input_ids = input_ids.contiguous() + + if ( + isinstance(past_key_values, EncoderDecoderCache) + and ( + isinstance(past_key_values.self_attention_cache, StaticCache) + or isinstance(past_key_values.cross_attention_cache, StaticCache) + ) + and attention_mask is not None + and attention_mask.ndim == 2 + ): + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + dtype = self.proj_out.weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.self_attention_cache.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) return { "decoder_input_ids": input_ids, @@ -1737,7 +1927,8 @@ def prepare_inputs_for_generation( "decoder_head_mask": decoder_head_mask, "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, - "output_router_logits": output_router_logits, + "output_router_logits": kwargs.get("output_router_logits", True), + "cache_position": cache_position, } def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): diff --git a/tests/models/mt5/test_modeling_mt5.py b/tests/models/mt5/test_modeling_mt5.py index 4b32b5cec789..9169c4cdef4d 100644 --- a/tests/models/mt5/test_modeling_mt5.py +++ b/tests/models/mt5/test_modeling_mt5.py @@ -607,65 +607,58 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa model.eval() inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss) - try: - if model.config.is_encoder_decoder: - model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward - labels = inputs.get("labels", None) - input_names = [ - "attention_mask", - "decoder_attention_mask", - "decoder_input_ids", - "input_features", - "input_ids", - "input_values", - ] - if labels is not None: - input_names.append("labels") - - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} - input_names = list(filtered_inputs.keys()) - - model_output = model(**filtered_inputs) - - traced_model = symbolic_trace(model, input_names) - traced_output = traced_model(**filtered_inputs) - else: - input_names = [ - "attention_mask", - "bbox", - "input_features", - "input_ids", - "input_values", - "pixel_values", - "token_type_ids", - "visual_feats", - "visual_pos", - ] - - labels = inputs.get("labels", None) - start_positions = inputs.get("start_positions", None) - end_positions = inputs.get("end_positions", None) - if labels is not None: - input_names.append("labels") - if start_positions is not None: - input_names.append("start_positions") - if end_positions is not None: - input_names.append("end_positions") - - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} - input_names = list(filtered_inputs.keys()) - - if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and ( - not hasattr(model.config, "problem_type") or model.config.problem_type is None - ): - model.config.problem_type = "single_label_classification" - - traced_model = symbolic_trace(model, input_names) - traced_output = traced_model(**filtered_inputs) - model_output = model(**filtered_inputs) - - except Exception as e: - self.fail(f"Couldn't trace module: {e}") + # try: + if model.config.is_encoder_decoder: + 
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward + labels = inputs.get("labels", None) + input_names = [ + "attention_mask", + "decoder_attention_mask", + "decoder_input_ids", + "input_features", + "input_ids", + "input_values", + ] + if labels is not None: + input_names.append("labels") + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + input_names = list(filtered_inputs.keys()) + model_output = model(**filtered_inputs) + traced_model = symbolic_trace(model, input_names) + traced_output = traced_model(**filtered_inputs) + else: + input_names = [ + "attention_mask", + "bbox", + "input_features", + "input_ids", + "input_values", + "pixel_values", + "token_type_ids", + "visual_feats", + "visual_pos", + ] + labels = inputs.get("labels", None) + start_positions = inputs.get("start_positions", None) + end_positions = inputs.get("end_positions", None) + if labels is not None: + input_names.append("labels") + if start_positions is not None: + input_names.append("start_positions") + if end_positions is not None: + input_names.append("end_positions") + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + input_names = list(filtered_inputs.keys()) + if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and ( + not hasattr(model.config, "problem_type") or model.config.problem_type is None + ): + model.config.problem_type = "single_label_classification" + traced_model = symbolic_trace(model, input_names) + traced_output = traced_model(**filtered_inputs) + model_output = model(**filtered_inputs) + + # except Exception as e: + # self.fail(f"Couldn't trace module: {e}") def flatten_output(output): flatten = [] @@ -712,7 +705,7 @@ def flatten_output(output): # (Even with this call, there are still memory leak by ~0.04MB) self.clear_torch_jit_class_registry() - # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` + # overwrite because MT5 doesn't accept position ids as input and expects `decoder_input_ids` def test_custom_4d_attention_mask(self): for model_class in self.all_generative_model_classes: config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 13215b2826fe..7adb1f40c6e6 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -36,6 +36,7 @@ if is_torch_available(): import torch + import torch.nn.functional as F from transformers import ( AutoTokenizer, @@ -645,6 +646,41 @@ def test_decoder_model_past_with_3d_attn_mask(self): lm_labels, ) + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + decoder_input_ids=input_ids, + input_ids=input_dict["input_ids"][:3], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + 
input_ids=input_dict["input_ids"][:1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) From 7c5925b58d9cd54f42e03b951454bb29098540dd Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 30 Sep 2024 17:12:27 +0200 Subject: [PATCH 08/30] fix some tests --- src/transformers/models/t5/modeling_t5.py | 1 + src/transformers/models/umt5/modeling_umt5.py | 2 +- tests/models/rag/test_modeling_rag.py | 2 +- tests/models/udop/test_modeling_udop.py | 1 + 4 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index a36a2796c95c..813b8b426e24 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1053,6 +1053,7 @@ def forward( ) use_cache = False + print(input_ids.shape, self.embed_tokens.weight.data.shape) if inputs_embeds is None: if self.embed_tokens is None: raise ValueError("You have to initialize the model with valid token embeddings") diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 8a34fbff06ff..4b72ae232f23 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -1528,7 +1528,7 @@ def prepare_inputs_for_generation( batch_size, sequence_length = input_ids.shape device = input_ids.device - dtype = self.proj_out.weight.dtype + dtype = self.get_output_embeddings().weight.dtype min_dtype = torch.finfo(dtype).min decoder_attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( diff --git a/tests/models/rag/test_modeling_rag.py b/tests/models/rag/test_modeling_rag.py index 392ff40d7702..d00c06344118 100644 --- a/tests/models/rag/test_modeling_rag.py +++ b/tests/models/rag/test_modeling_rag.py @@ -653,7 +653,7 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase): def config_and_inputs(self): question_encoder_tester = DPRModelTester(self) dpr_config_and_inputs = question_encoder_tester.prepare_config_and_inputs() - generator_tester = T5ModelTester(self, vocab_size=1100) + generator_tester = T5ModelTester(self, vocab_size=1101) t5_config_and_inputs = generator_tester.prepare_config_and_inputs() (question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs diff --git a/tests/models/udop/test_modeling_udop.py b/tests/models/udop/test_modeling_udop.py index 1fd0f01bea38..9d82173b1aed 100644 --- a/tests/models/udop/test_modeling_udop.py +++ b/tests/models/udop/test_modeling_udop.py @@ -349,6 +349,7 @@ def test_forward_signature(self): expected_arg_names = [ "attention_mask", "bbox", + "cache_position", "cross_attn_head_mask", "decoder_attention_mask", "decoder_head_mask", From 038bb1e912132b30df51a292753854ff9eea5014 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 30 Sep 2024 19:34:45 +0200 Subject: [PATCH 09/30] fix onnx 
tests --- src/transformers/models/longt5/modeling_longt5.py | 6 +++++- src/transformers/models/mt5/modeling_mt5.py | 8 ++++++-- src/transformers/models/pop2piano/modeling_pop2piano.py | 6 +++++- .../switch_transformers/modeling_switch_transformers.py | 6 +++++- src/transformers/models/t5/modeling_t5.py | 9 ++++++--- src/transformers/models/udop/modeling_udop.py | 6 +++++- tests/models/longt5/test_modeling_longt5.py | 2 +- tests/models/pop2piano/test_modeling_pop2piano.py | 2 +- 8 files changed, 34 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 55bbda4e20b2..04425c84c2dc 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -509,6 +509,10 @@ def forward( # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) batch_size, seq_length = hidden_states.shape[:2] + # failure that tensors are not on the same device otherwise + if torch.jit.is_tracing(): + seq_length = seq_length.to(hidden_states.device) + real_seq_length = seq_length real_seq_length += cache_position[0] if query_length is None else query_length @@ -596,7 +600,7 @@ def unshape(states): outputs = (attn_output, past_key_value, position_bias) if output_attentions: - outputs += (attn_weights,) + outputs = outputs + (attn_weights,) return outputs diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index e676c52fb747..587e16abc2b4 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -406,6 +406,10 @@ def forward( # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) batch_size, seq_length = hidden_states.shape[:2] + # failure that tensors are not on the same device otherwise + if torch.jit.is_tracing(): + seq_length = seq_length.to(hidden_states.device) + real_seq_length = seq_length real_seq_length += cache_position[0] if query_length is None else query_length @@ -493,7 +497,7 @@ def unshape(states): outputs = (attn_output, past_key_value, position_bias) if output_attentions: - outputs += (attn_weights,) + outputs = outputs + (attn_weights,) return outputs @@ -2023,7 +2027,7 @@ def prepare_inputs_for_generation( batch_size, sequence_length = input_ids.shape device = input_ids.device - dtype = self.proj_out.weight.dtype + dtype = self.get_output_embeddings().weight.dtype min_dtype = torch.finfo(dtype).min decoder_attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 43ae0dded424..ed65ac023431 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -440,6 +440,10 @@ def forward( # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) batch_size, seq_length = hidden_states.shape[:2] + # failure that tensors are not on the same device otherwise + if torch.jit.is_tracing(): + seq_length = seq_length.to(hidden_states.device) + real_seq_length = seq_length real_seq_length += cache_position[0] if query_length is None else query_length @@ -527,7 +531,7 @@ def unshape(states): outputs = (attn_output, past_key_value, position_bias) if output_attentions: - outputs += (attn_weights,) + outputs = outputs + (attn_weights,) return outputs diff --git 
a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 8a3a3879abdf..fc577883c627 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -547,6 +547,10 @@ def forward( # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) batch_size, seq_length = hidden_states.shape[:2] + # failure that tensors are not on the same device otherwise + if torch.jit.is_tracing(): + seq_length = seq_length.to(hidden_states.device) + real_seq_length = seq_length real_seq_length += cache_position[0] if query_length is None else query_length @@ -634,7 +638,7 @@ def unshape(states): outputs = (attn_output, past_key_value, position_bias) if output_attentions: - outputs += (attn_weights,) + outputs = outputs + (attn_weights,) return outputs diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 813b8b426e24..bc911eaafe3d 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -531,6 +531,10 @@ def forward( # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) batch_size, seq_length = hidden_states.shape[:2] + # failure that tensors are not on the same device otherwise + if torch.jit.is_tracing(): + seq_length = seq_length.to(hidden_states.device) + real_seq_length = seq_length real_seq_length += cache_position[0] if query_length is None else query_length @@ -618,7 +622,7 @@ def unshape(states): outputs = (attn_output, past_key_value, position_bias) if output_attentions: - outputs += (attn_weights,) + outputs = outputs + (attn_weights,) return outputs @@ -1053,7 +1057,6 @@ def forward( ) use_cache = False - print(input_ids.shape, self.embed_tokens.weight.data.shape) if inputs_embeds is None: if self.embed_tokens is None: raise ValueError("You have to initialize the model with valid token embeddings") @@ -1997,7 +2000,7 @@ def prepare_inputs_for_generation( batch_size, sequence_length = input_ids.shape device = input_ids.device - dtype = self.proj_out.weight.dtype + dtype = self.get_output_embeddings().weight.dtype min_dtype = torch.finfo(dtype).min decoder_attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 262ad304b100..7c219aa494a4 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -795,6 +795,10 @@ def forward( # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) batch_size, seq_length = hidden_states.shape[:2] + # failure that tensors are not on the same device otherwise + if torch.jit.is_tracing(): + seq_length = seq_length.to(hidden_states.device) + real_seq_length = seq_length real_seq_length += cache_position[0] if query_length is None else query_length @@ -882,7 +886,7 @@ def unshape(states): outputs = (attn_output, past_key_value, position_bias) if output_attentions: - outputs += (attn_weights,) + outputs = outputs + (attn_weights,) return outputs diff --git a/tests/models/longt5/test_modeling_longt5.py b/tests/models/longt5/test_modeling_longt5.py index 93b0556ff120..a9d3e7479e95 100644 --- a/tests/models/longt5/test_modeling_longt5.py +++ b/tests/models/longt5/test_modeling_longt5.py @@ -638,7 +638,7 @@ def test_export_to_onnx(self): 
(config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]), f"{tmpdirname}/longt5_test.onnx", export_params=True, - opset_version=13, + opset_version=14, input_names=["input_ids", "decoder_input_ids"], ) diff --git a/tests/models/pop2piano/test_modeling_pop2piano.py b/tests/models/pop2piano/test_modeling_pop2piano.py index 3a33b5a98128..39ff67f08ce5 100644 --- a/tests/models/pop2piano/test_modeling_pop2piano.py +++ b/tests/models/pop2piano/test_modeling_pop2piano.py @@ -620,7 +620,7 @@ def test_export_to_onnx(self): (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]), f"{tmpdirname}/Pop2Piano_test.onnx", export_params=True, - opset_version=9, + opset_version=14, input_names=["input_ids", "decoder_input_ids"], ) From df988425cb9e873255388106666164cb6779cc8d Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 1 Oct 2024 10:40:12 +0200 Subject: [PATCH 10/30] tracing tests fixed --- .../models/longt5/modeling_longt5.py | 20 +++++++++---------- src/transformers/models/mt5/modeling_mt5.py | 20 +++++++++---------- .../models/pop2piano/modeling_pop2piano.py | 20 +++++++++---------- .../modeling_switch_transformers.py | 20 +++++++++---------- src/transformers/models/t5/modeling_t5.py | 20 +++++++++---------- src/transformers/models/udop/modeling_udop.py | 20 +++++++++---------- 6 files changed, 60 insertions(+), 60 deletions(-) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 04425c84c2dc..2d83a76bc68f 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -513,10 +513,10 @@ def forward( if torch.jit.is_tracing(): seq_length = seq_length.to(hidden_states.device) - real_seq_length = seq_length - real_seq_length += cache_position[0] if query_length is None else query_length + if past_key_value is not None: + seq_length += cache_position[0] if query_length is None else query_length - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + key_length = seq_length if key_value_states is None else key_value_states.shape[1] def shape(states): """projection""" @@ -555,20 +555,18 @@ def unshape(states): key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) + position_bias = self.compute_bias(seq_length, key_length, device=scores.device) # if key and values are already calculated # we want only the last query position bias @@ -576,7 +574,8 @@ def unshape(states): position_bias = position_bias[:, :, -hidden_states.size(1) :, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = mask[:, :, :, : 
key_states.shape[-2]] + position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -586,6 +585,7 @@ def unshape(states): position_bias_masked = position_bias scores += position_bias_masked + # (batch_size, n_heads, seq_length, key_length) attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 587e16abc2b4..5d7d69f98167 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -410,10 +410,10 @@ def forward( if torch.jit.is_tracing(): seq_length = seq_length.to(hidden_states.device) - real_seq_length = seq_length - real_seq_length += cache_position[0] if query_length is None else query_length + if past_key_value is not None: + seq_length += cache_position[0] if query_length is None else query_length - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + key_length = seq_length if key_value_states is None else key_value_states.shape[1] def shape(states): """projection""" @@ -452,20 +452,18 @@ def unshape(states): key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) + position_bias = self.compute_bias(seq_length, key_length, device=scores.device) # if key and values are already calculated # we want only the last query position bias @@ -473,7 +471,8 @@ def unshape(states): position_bias = position_bias[:, :, -hidden_states.size(1) :, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -483,6 +482,7 @@ def unshape(states): position_bias_masked = position_bias scores += position_bias_masked + # (batch_size, n_heads, seq_length, key_length) attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index ed65ac023431..5cdf5d26009d 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -444,10 +444,10 @@ def forward( if torch.jit.is_tracing(): seq_length = seq_length.to(hidden_states.device) - real_seq_length = seq_length - real_seq_length += cache_position[0] if 
query_length is None else query_length + if past_key_value is not None: + seq_length += cache_position[0] if query_length is None else query_length - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + key_length = seq_length if key_value_states is None else key_value_states.shape[1] def shape(states): """projection""" @@ -486,20 +486,18 @@ def unshape(states): key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) + position_bias = self.compute_bias(seq_length, key_length, device=scores.device) # if key and values are already calculated # we want only the last query position bias @@ -507,7 +505,8 @@ def unshape(states): position_bias = position_bias[:, :, -hidden_states.size(1) :, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -517,6 +516,7 @@ def unshape(states): position_bias_masked = position_bias scores += position_bias_masked + # (batch_size, n_heads, seq_length, key_length) attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index fc577883c627..a2ffe9be5f13 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -551,10 +551,10 @@ def forward( if torch.jit.is_tracing(): seq_length = seq_length.to(hidden_states.device) - real_seq_length = seq_length - real_seq_length += cache_position[0] if query_length is None else query_length + if past_key_value is not None: + seq_length += cache_position[0] if query_length is None else query_length - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + key_length = seq_length if key_value_states is None else key_value_states.shape[1] def shape(states): """projection""" @@ -593,20 +593,18 @@ def unshape(states): key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = 
torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) + position_bias = self.compute_bias(seq_length, key_length, device=scores.device) # if key and values are already calculated # we want only the last query position bias @@ -614,7 +612,8 @@ def unshape(states): position_bias = position_bias[:, :, -hidden_states.size(1) :, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -624,6 +623,7 @@ def unshape(states): position_bias_masked = position_bias scores += position_bias_masked + # (batch_size, n_heads, seq_length, key_length) attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index bc911eaafe3d..a8515aa2543a 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -535,10 +535,10 @@ def forward( if torch.jit.is_tracing(): seq_length = seq_length.to(hidden_states.device) - real_seq_length = seq_length - real_seq_length += cache_position[0] if query_length is None else query_length + if past_key_value is not None: + seq_length += cache_position[0] if query_length is None else query_length - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + key_length = seq_length if key_value_states is None else key_value_states.shape[1] def shape(states): """projection""" @@ -577,20 +577,18 @@ def unshape(states): key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) + position_bias = self.compute_bias(seq_length, key_length, device=scores.device) # if key and values are already calculated # we want only the last query position bias @@ -598,7 +596,8 @@ def unshape(states): position_bias = position_bias[:, :, -hidden_states.size(1) :, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = mask[:, :, :, : key_states.shape[-2]] + 
position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -608,6 +607,7 @@ def unshape(states): position_bias_masked = position_bias scores += position_bias_masked + # (batch_size, n_heads, seq_length, key_length) attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 7c219aa494a4..e4eb72e5b6c4 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -799,10 +799,10 @@ def forward( if torch.jit.is_tracing(): seq_length = seq_length.to(hidden_states.device) - real_seq_length = seq_length - real_seq_length += cache_position[0] if query_length is None else query_length + if past_key_value is not None: + seq_length += cache_position[0] if query_length is None else query_length - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + key_length = seq_length if key_value_states is None else key_value_states.shape[1] def shape(states): """projection""" @@ -841,20 +841,18 @@ def unshape(states): key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) + position_bias = self.compute_bias(seq_length, key_length, device=scores.device) # if key and values are already calculated # we want only the last query position bias @@ -862,7 +860,8 @@ def unshape(states): position_bias = position_bias[:, :, -hidden_states.size(1) :, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -872,6 +871,7 @@ def unshape(states): position_bias_masked = position_bias scores += position_bias_masked + # (batch_size, n_heads, seq_length, key_length) attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) From 0544b65ab4699e05f776d9110129e39b8463ccd1 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 1 Oct 2024 14:36:36 +0200 Subject: [PATCH 11/30] compile enabled and tested for t5 models --- src/transformers/generation/utils.py | 2 +- .../models/longt5/configuration_longt5.py | 7 +++++- .../models/longt5/modeling_longt5.py | 22 +++++++++++-------- .../models/mt5/configuration_mt5.py | 7 +++++- src/transformers/models/mt5/modeling_mt5.py | 21 +++++++++++------- 
.../models/pix2struct/modeling_pix2struct.py | 14 +++++++----- .../models/pop2piano/modeling_pop2piano.py | 12 +++++----- .../modeling_switch_transformers.py | 14 +++++++----- .../models/t5/configuration_t5.py | 7 +++++- src/transformers/models/t5/modeling_t5.py | 21 +++++++++++------- src/transformers/models/udop/modeling_udop.py | 14 +++++++----- .../models/umt5/configuration_umt5.py | 7 +++++- src/transformers/models/umt5/modeling_umt5.py | 15 +++++++++---- tests/models/mt5/test_modeling_mt5.py | 3 +++ tests/models/t5/test_modeling_t5.py | 3 +++ tests/models/umt5/test_modeling_umt5.py | 3 +++ tests/test_modeling_common.py | 11 ++++++++-- 17 files changed, 125 insertions(+), 58 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 47ec185508f8..06af7cc26a49 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1477,7 +1477,7 @@ def get_layer_device_map(execution_device_map: Optional[dict] = None): cache_kwargs = { "config": self.config if hasattr(self.config, "text_config") else self.config, - "max_batch_size": batch_size, + "batch_size": batch_size, "max_cache_len": max_cache_len, "device": device, "dtype": cache_dtype, diff --git a/src/transformers/models/longt5/configuration_longt5.py b/src/transformers/models/longt5/configuration_longt5.py index 0e541ae2a1b4..b6e7d21b3d67 100644 --- a/src/transformers/models/longt5/configuration_longt5.py +++ b/src/transformers/models/longt5/configuration_longt5.py @@ -79,7 +79,12 @@ class LongT5Config(PretrainedConfig): model_type = "longt5" keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + "head_dim": "d_kv", + } def __init__( self, diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 2d83a76bc68f..4ec4b8f0f46a 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -40,6 +40,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_fx_proxy, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -513,11 +514,6 @@ def forward( if torch.jit.is_tracing(): seq_length = seq_length.to(hidden_states.device) - if past_key_value is not None: - seq_length += cache_position[0] if query_length is None else query_length - - key_length = seq_length if key_value_states is None else key_value_states.shape[1] - def shape(states): """projection""" return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) @@ -539,6 +535,13 @@ def unshape(states): else: past_key_value = past_key_value.self_attention_cache + if isinstance(past_key_value, StaticCache): + seq_length = past_key_value.get_max_length() + elif past_key_value is not None: + seq_length += cache_position[0] if query_length is None else query_length + + key_length = seq_length if key_value_states is None else key_value_states.shape[1] + # get key/value states current_states = key_value_states if key_value_states is not None else hidden_states if key_value_states is not None and past_key_value and is_updated: @@ -1314,7 +1317,7 @@ class LongT5PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LongT5Block"] _supports_cache_class = True - 
_supports_static_cache = True
+    _supports_static_cache = False  # TODO: @raushan more involved due to local/global attn
 
     @property
     # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs
@@ -1518,7 +1521,7 @@ def forward(
             past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
         )
 
-        if attention_mask is None:
+        if attention_mask is None and not is_torchdynamo_compiling():
             # required mask seq length can be calculated via length of past
             mask_seq_length = (
                 past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length
             )
@@ -2272,7 +2275,7 @@ def prepare_inputs_for_generation(
             # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
             # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
-            input_ids = input_ids.contiguous()
+            input_ids = input_ids.clone(memory_format=torch.contiguous_format)
 
         if (
             isinstance(past_key_values, EncoderDecoderCache)
@@ -2286,7 +2289,7 @@ def prepare_inputs_for_generation(
                 batch_size, sequence_length = input_ids.shape
                 device = input_ids.device
 
-                dtype = self.proj_out.weight.dtype
+                dtype = self.get_output_embeddings().weight.dtype
                 min_dtype = torch.finfo(dtype).min
 
                 attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
@@ -2309,6 +2312,7 @@ def prepare_inputs_for_generation(
             "decoder_head_mask": decoder_head_mask,
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,
+            "cache_position": cache_position,
         }
 
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
diff --git a/src/transformers/models/mt5/configuration_mt5.py b/src/transformers/models/mt5/configuration_mt5.py
index ef629718b1b5..267179f81247 100644
--- a/src/transformers/models/mt5/configuration_mt5.py
+++ b/src/transformers/models/mt5/configuration_mt5.py
@@ -72,7 +72,12 @@ class MT5Config(PretrainedConfig):
 
     model_type = "mt5"
     keys_to_ignore_at_inference = ["past_key_values"]
-    attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
+    attribute_map = {
+        "hidden_size": "d_model",
+        "num_attention_heads": "num_heads",
+        "num_hidden_layers": "num_layers",
+        "head_dim": "d_kv",
+    }
 
     def __init__(
         self,
diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py
index 5d7d69f98167..f790507636c6 100644
--- a/src/transformers/models/mt5/modeling_mt5.py
+++ b/src/transformers/models/mt5/modeling_mt5.py
@@ -44,6 +44,7 @@
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_torch_fx_proxy,
+    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )
@@ -410,11 +411,6 @@ def forward(
         if torch.jit.is_tracing():
             seq_length = seq_length.to(hidden_states.device)
 
-        if past_key_value is not None:
-            seq_length += cache_position[0] if query_length is None else query_length
-
-        key_length = seq_length if key_value_states is None else key_value_states.shape[1]
-
         def shape(states):
             """projection"""
             return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
@@ -436,6 +432,13 @@ def unshape(states):
             else:
                 past_key_value = past_key_value.self_attention_cache
 
+        if isinstance(past_key_value, StaticCache):
+            seq_length = past_key_value.get_max_length()
+        elif past_key_value is not None:
+            seq_length += cache_position[0] if query_length is None else query_length
+
+        key_length = seq_length if key_value_states is None else key_value_states.shape[1]
+
         # 
get key/value states current_states = key_value_states if key_value_states is not None else hidden_states if key_value_states is not None and past_key_value and is_updated: @@ -1084,7 +1087,7 @@ def forward( past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - if attention_mask is None: + if attention_mask is None and not is_torchdynamo_compiling(): # required mask seq length can be calculated via length of past mask_seq_length = ( past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length @@ -1099,10 +1102,12 @@ def forward( past_key_values.self_attention_cache if past_key_values is not None else None, output_attentions, ) - else: + elif attention_mask is not None: causal_mask = attention_mask[:, None, None, :] causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + else: + causal_mask = None # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -2013,7 +2018,7 @@ def prepare_inputs_for_generation( # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 - input_ids = input_ids.contiguous() + input_ids = input_ids.clone(memory_format=torch.contiguous_format) if ( isinstance(past_key_values, EncoderDecoderCache) diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 27f628765a46..9630c3707d10 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -39,6 +39,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_fx_proxy, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -239,14 +240,17 @@ def to_projection_shape(states): if self.gradient_checkpointing and self.training: position_bias.requires_grad = True - if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length), device=scores.device, dtype=scores.dtype) - if attention_mask.dim() == 2: position_bias = position_bias + attention_mask[:, None, None, :].to(position_bias.device) - else: + elif attention_mask is not None: # (batch_size, n_heads, seq_length, key_length) position_bias = position_bias + attention_mask.to(position_bias.device) + elif not is_torchdynamo_compiling(): + attention_mask = torch.ones( + (batch_size, seq_length), device=position_bias.device, dtype=position_bias.dtype + ) + position_bias = position_bias + attention_mask.to(position_bias.device) + position_bias = 1 - position_bias position_bias_masked = position_bias.masked_fill(position_bias == 1, torch.finfo(scores.dtype).min) @@ -1926,7 +1930,7 @@ def prepare_inputs_for_generation( # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise # recompiles graphs as the stride of the inputs is a guard. 
Ref: https://github.com/huggingface/transformers/pull/29114 - input_ids = input_ids.contiguous() + input_ids = input_ids.clone(memory_format=torch.contiguous_format) if ( isinstance(past_key_values, EncoderDecoderCache) diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 5cdf5d26009d..819642983904 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -444,11 +444,6 @@ def forward( if torch.jit.is_tracing(): seq_length = seq_length.to(hidden_states.device) - if past_key_value is not None: - seq_length += cache_position[0] if query_length is None else query_length - - key_length = seq_length if key_value_states is None else key_value_states.shape[1] - def shape(states): """projection""" return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) @@ -470,6 +465,13 @@ def unshape(states): else: past_key_value = past_key_value.self_attention_cache + if isinstance(past_key_value, StaticCache): + seq_length = past_key_value.get_max_length() + elif past_key_value is not None: + seq_length += cache_position[0] if query_length is None else query_length + + key_length = seq_length if key_value_states is None else key_value_states.shape[1] + # get key/value states current_states = key_value_states if key_value_states is not None else hidden_states if key_value_states is not None and past_key_value and is_updated: diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index a2ffe9be5f13..1f7ca762c8a4 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -551,11 +551,6 @@ def forward( if torch.jit.is_tracing(): seq_length = seq_length.to(hidden_states.device) - if past_key_value is not None: - seq_length += cache_position[0] if query_length is None else query_length - - key_length = seq_length if key_value_states is None else key_value_states.shape[1] - def shape(states): """projection""" return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) @@ -577,6 +572,13 @@ def unshape(states): else: past_key_value = past_key_value.self_attention_cache + if isinstance(past_key_value, StaticCache): + seq_length = past_key_value.get_max_length() + elif past_key_value is not None: + seq_length += cache_position[0] if query_length is None else query_length + + key_length = seq_length if key_value_states is None else key_value_states.shape[1] + # get key/value states current_states = key_value_states if key_value_states is not None else hidden_states if key_value_states is not None and past_key_value and is_updated: @@ -1894,7 +1896,7 @@ def prepare_inputs_for_generation( # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise # recompiles graphs as the stride of the inputs is a guard. 
Ref: https://github.com/huggingface/transformers/pull/29114 - input_ids = input_ids.contiguous() + input_ids = input_ids.clone(memory_format=torch.contiguous_format) if ( isinstance(past_key_values, EncoderDecoderCache) diff --git a/src/transformers/models/t5/configuration_t5.py b/src/transformers/models/t5/configuration_t5.py index e5f2615611b8..be6fbe9528d1 100644 --- a/src/transformers/models/t5/configuration_t5.py +++ b/src/transformers/models/t5/configuration_t5.py @@ -73,7 +73,12 @@ class T5Config(PretrainedConfig): model_type = "t5" keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + "head_dim": "d_kv", + } def __init__( self, diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index a8515aa2543a..d6f10738ee9e 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -44,6 +44,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_fx_proxy, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -535,11 +536,6 @@ def forward( if torch.jit.is_tracing(): seq_length = seq_length.to(hidden_states.device) - if past_key_value is not None: - seq_length += cache_position[0] if query_length is None else query_length - - key_length = seq_length if key_value_states is None else key_value_states.shape[1] - def shape(states): """projection""" return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) @@ -561,6 +557,13 @@ def unshape(states): else: past_key_value = past_key_value.self_attention_cache + if isinstance(past_key_value, StaticCache): + seq_length = past_key_value.get_max_length() + elif past_key_value is not None: + seq_length += cache_position[0] if query_length is None else query_length + + key_length = seq_length if key_value_states is None else key_value_states.shape[1] + # get key/value states current_states = key_value_states if key_value_states is not None else hidden_states if key_value_states is not None and past_key_value and is_updated: @@ -1097,7 +1100,7 @@ def forward( past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - if attention_mask is None: + if attention_mask is None and not is_torchdynamo_compiling(): # required mask seq length can be calculated via length of past mask_seq_length = ( past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length @@ -1112,10 +1115,12 @@ def forward( past_key_values.self_attention_cache if past_key_values is not None else None, output_attentions, ) - else: + elif attention_mask is not None: causal_mask = attention_mask[:, None, None, :] causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + else: + causal_mask = None # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -1986,7 +1991,7 @@ def prepare_inputs_for_generation( # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise # recompiles graphs as the stride of the inputs is a guard. 
Ref: https://github.com/huggingface/transformers/pull/29114 - input_ids = input_ids.contiguous() + input_ids = input_ids.clone(memory_format=torch.contiguous_format) if ( isinstance(past_key_values, EncoderDecoderCache) diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index e4eb72e5b6c4..3bd137922fa6 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -799,11 +799,6 @@ def forward( if torch.jit.is_tracing(): seq_length = seq_length.to(hidden_states.device) - if past_key_value is not None: - seq_length += cache_position[0] if query_length is None else query_length - - key_length = seq_length if key_value_states is None else key_value_states.shape[1] - def shape(states): """projection""" return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) @@ -825,6 +820,13 @@ def unshape(states): else: past_key_value = past_key_value.self_attention_cache + if isinstance(past_key_value, StaticCache): + seq_length = past_key_value.get_max_length() + elif past_key_value is not None: + seq_length += cache_position[0] if query_length is None else query_length + + key_length = seq_length if key_value_states is None else key_value_states.shape[1] + # get key/value states current_states = key_value_states if key_value_states is not None else hidden_states if key_value_states is not None and past_key_value and is_updated: @@ -2057,7 +2059,7 @@ def prepare_inputs_for_generation( # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 - input_ids = input_ids.contiguous() + input_ids = input_ids.clone(memory_format=torch.contiguous_format) if ( isinstance(past_key_values, EncoderDecoderCache) diff --git a/src/transformers/models/umt5/configuration_umt5.py b/src/transformers/models/umt5/configuration_umt5.py index d7323d759fd0..ba8ea0460ba0 100644 --- a/src/transformers/models/umt5/configuration_umt5.py +++ b/src/transformers/models/umt5/configuration_umt5.py @@ -72,7 +72,12 @@ class UMT5Config(PretrainedConfig): model_type = "umt5" keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + "head_dim": "d_kv", + } def __init__( self, diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 4b72ae232f23..ec8a8771f995 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -41,6 +41,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_fx_proxy, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -347,8 +348,11 @@ def forward( # compute positional bias if self.has_relative_attention_bias: query_length = seq_length - if past_key_value is not None: + if isinstance(past_key_value, StaticCache): + query_length = past_key_value.get_max_length() + elif past_key_value is not None: query_length += past_key_value.get_seq_length() + position_bias = self.compute_bias(query_length, key_states.size(2), device=attention_scores.device) else: position_bias = torch.zeros( @@ -755,7 +759,7 @@ def forward( past_key_values_length, 
past_key_values_length + seq_length, device=inputs_embeds.device ) - if attention_mask is None: + if attention_mask is None and not is_torchdynamo_compiling(): # required mask seq length can be calculated via length of past mask_seq_length = ( past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length @@ -770,10 +774,12 @@ def forward( past_key_values.self_attention_cache if past_key_values is not None else None, output_attentions, ) - else: + elif attention_mask is not None: causal_mask = attention_mask[:, None, None, :] causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + else: + causal_mask = None # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -1442,6 +1448,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = decoder_outputs[0] @@ -1514,7 +1521,7 @@ def prepare_inputs_for_generation( # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 - input_ids = input_ids.contiguous() + input_ids = input_ids.clone(memory_format=torch.contiguous_format) if ( isinstance(past_key_values, EncoderDecoderCache) diff --git a/tests/models/mt5/test_modeling_mt5.py b/tests/models/mt5/test_modeling_mt5.py index 9169c4cdef4d..75346ad6b572 100644 --- a/tests/models/mt5/test_modeling_mt5.py +++ b/tests/models/mt5/test_modeling_mt5.py @@ -576,6 +576,9 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, # The small MT5 model needs higher percentages for CPU/MP tests model_split_percents = [0.5, 0.8, 0.9] + # used in `test_torch_compile` + _torch_compile_test_ckpt = "google/mt5-small" + def setUp(self): self.model_tester = MT5ModelTester(self) self.config_tester = ConfigTester(self, config_class=MT5Config, d_model=37) diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index f38e8c58cbbb..3d34e91d243c 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -579,6 +579,9 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, # The small T5 model needs higher percentages for CPU/MP tests model_split_percents = [0.5, 0.8, 0.9] + # used in `test_torch_compile` + _torch_compile_test_ckpt = "google-t5/t5-small" + def setUp(self): self.model_tester = T5ModelTester(self) self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37) diff --git a/tests/models/umt5/test_modeling_umt5.py b/tests/models/umt5/test_modeling_umt5.py index cfb9d2be4a48..8cde27ceceea 100644 --- a/tests/models/umt5/test_modeling_umt5.py +++ b/tests/models/umt5/test_modeling_umt5.py @@ -317,6 +317,9 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin # The small UMT5 model needs higher percentages for CPU/MP tests model_split_percents = [0.5, 0.8, 0.9] + # used in `test_torch_compile` + _torch_compile_test_ckpt = "google/umt5-small" + def setUp(self): self.model_tester = UMT5ModelTester(self) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index c7af0b1c9f5b..d28b41b24d04 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py 
@@ -36,6 +36,7 @@ from transformers import ( AutoModel, AutoModelForCausalLM, + AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig, @@ -4764,7 +4765,10 @@ def test_torch_compile(self): n_iter = 3 tokenizer = AutoTokenizer.from_pretrained(ckpt) - model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device) + if self.is_encoder_decoder: + model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device) + else: + model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device) model.generation_config.max_new_tokens = 4 @@ -4793,7 +4797,10 @@ def test_compile_cuda_graph_time(self): os.environ["TOKENIZERS_PARALLELISM"] = "false" tokenizer = AutoTokenizer.from_pretrained(ckpt) - model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device) + if self.is_encoder_decoder: + model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device) + else: + model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device) cache_implementation = "static" if model.config.model_type == "gemma2": From 1063971776f8da4386e95d8d0d8cd64051050291 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 1 Oct 2024 14:53:43 +0200 Subject: [PATCH 12/30] fix small bug in slow tests --- tests/models/t5/test_modeling_t5.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index 3d34e91d243c..d3f2c686de35 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -1487,6 +1487,7 @@ def test_summarization(self): [model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]], padding="max_length", truncation=True, + max_length=512, return_tensors="pt", ).to(torch_device) self.assertEqual(512, dct["input_ids"].shape[1]) From 0e7fb50a08a462bd4329d98aa2e7bab4ec39df59 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 1 Oct 2024 14:54:14 +0200 Subject: [PATCH 13/30] [run-slow] t5 From c4ccdeacc039427f29190ea2c5a072a3b987bcf3 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 1 Oct 2024 15:10:55 +0200 Subject: [PATCH 14/30] uncomment --- tests/models/mt5/test_modeling_mt5.py | 104 +++++++++++++------------- tests/models/t5/test_modeling_t5.py | 104 +++++++++++++------------- 2 files changed, 104 insertions(+), 104 deletions(-) diff --git a/tests/models/mt5/test_modeling_mt5.py b/tests/models/mt5/test_modeling_mt5.py index 75346ad6b572..75bb919c6ea9 100644 --- a/tests/models/mt5/test_modeling_mt5.py +++ b/tests/models/mt5/test_modeling_mt5.py @@ -610,58 +610,58 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa model.eval() inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss) - # try: - if model.config.is_encoder_decoder: - model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward - labels = inputs.get("labels", None) - input_names = [ - "attention_mask", - "decoder_attention_mask", - "decoder_input_ids", - "input_features", - "input_ids", - "input_values", - ] - if labels is not None: - input_names.append("labels") - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} - input_names = list(filtered_inputs.keys()) - model_output = model(**filtered_inputs) - traced_model = symbolic_trace(model, input_names) - traced_output 
= traced_model(**filtered_inputs) - else: - input_names = [ - "attention_mask", - "bbox", - "input_features", - "input_ids", - "input_values", - "pixel_values", - "token_type_ids", - "visual_feats", - "visual_pos", - ] - labels = inputs.get("labels", None) - start_positions = inputs.get("start_positions", None) - end_positions = inputs.get("end_positions", None) - if labels is not None: - input_names.append("labels") - if start_positions is not None: - input_names.append("start_positions") - if end_positions is not None: - input_names.append("end_positions") - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} - input_names = list(filtered_inputs.keys()) - if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and ( - not hasattr(model.config, "problem_type") or model.config.problem_type is None - ): - model.config.problem_type = "single_label_classification" - traced_model = symbolic_trace(model, input_names) - traced_output = traced_model(**filtered_inputs) - model_output = model(**filtered_inputs) - - # except Exception as e: - # self.fail(f"Couldn't trace module: {e}") + try: + if model.config.is_encoder_decoder: + model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward + labels = inputs.get("labels", None) + input_names = [ + "attention_mask", + "decoder_attention_mask", + "decoder_input_ids", + "input_features", + "input_ids", + "input_values", + ] + if labels is not None: + input_names.append("labels") + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + input_names = list(filtered_inputs.keys()) + model_output = model(**filtered_inputs) + traced_model = symbolic_trace(model, input_names) + traced_output = traced_model(**filtered_inputs) + else: + input_names = [ + "attention_mask", + "bbox", + "input_features", + "input_ids", + "input_values", + "pixel_values", + "token_type_ids", + "visual_feats", + "visual_pos", + ] + labels = inputs.get("labels", None) + start_positions = inputs.get("start_positions", None) + end_positions = inputs.get("end_positions", None) + if labels is not None: + input_names.append("labels") + if start_positions is not None: + input_names.append("start_positions") + if end_positions is not None: + input_names.append("end_positions") + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + input_names = list(filtered_inputs.keys()) + if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and ( + not hasattr(model.config, "problem_type") or model.config.problem_type is None + ): + model.config.problem_type = "single_label_classification" + traced_model = symbolic_trace(model, input_names) + traced_output = traced_model(**filtered_inputs) + model_output = model(**filtered_inputs) + + except Exception as e: + self.fail(f"Couldn't trace module: {e}") def flatten_output(output): flatten = [] diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index d3f2c686de35..d56ba0ef48b4 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -613,58 +613,58 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa model.eval() inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss) - # try: - if model.config.is_encoder_decoder: - model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar 
to BART afterward - labels = inputs.get("labels", None) - input_names = [ - "attention_mask", - "decoder_attention_mask", - "decoder_input_ids", - "input_features", - "input_ids", - "input_values", - ] - if labels is not None: - input_names.append("labels") - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} - input_names = list(filtered_inputs.keys()) - model_output = model(**filtered_inputs) - traced_model = symbolic_trace(model, input_names) - traced_output = traced_model(**filtered_inputs) - else: - input_names = [ - "attention_mask", - "bbox", - "input_features", - "input_ids", - "input_values", - "pixel_values", - "token_type_ids", - "visual_feats", - "visual_pos", - ] - labels = inputs.get("labels", None) - start_positions = inputs.get("start_positions", None) - end_positions = inputs.get("end_positions", None) - if labels is not None: - input_names.append("labels") - if start_positions is not None: - input_names.append("start_positions") - if end_positions is not None: - input_names.append("end_positions") - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} - input_names = list(filtered_inputs.keys()) - if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and ( - not hasattr(model.config, "problem_type") or model.config.problem_type is None - ): - model.config.problem_type = "single_label_classification" - traced_model = symbolic_trace(model, input_names) - traced_output = traced_model(**filtered_inputs) - model_output = model(**filtered_inputs) - - # except Exception as e: - # self.fail(f"Couldn't trace module: {e}") + try: + if model.config.is_encoder_decoder: + model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward + labels = inputs.get("labels", None) + input_names = [ + "attention_mask", + "decoder_attention_mask", + "decoder_input_ids", + "input_features", + "input_ids", + "input_values", + ] + if labels is not None: + input_names.append("labels") + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + input_names = list(filtered_inputs.keys()) + model_output = model(**filtered_inputs) + traced_model = symbolic_trace(model, input_names) + traced_output = traced_model(**filtered_inputs) + else: + input_names = [ + "attention_mask", + "bbox", + "input_features", + "input_ids", + "input_values", + "pixel_values", + "token_type_ids", + "visual_feats", + "visual_pos", + ] + labels = inputs.get("labels", None) + start_positions = inputs.get("start_positions", None) + end_positions = inputs.get("end_positions", None) + if labels is not None: + input_names.append("labels") + if start_positions is not None: + input_names.append("start_positions") + if end_positions is not None: + input_names.append("end_positions") + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + input_names = list(filtered_inputs.keys()) + if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and ( + not hasattr(model.config, "problem_type") or model.config.problem_type is None + ): + model.config.problem_type = "single_label_classification" + traced_model = symbolic_trace(model, input_names) + traced_output = traced_model(**filtered_inputs) + model_output = model(**filtered_inputs) + + except Exception as e: + self.fail(f"Couldn't trace module: {e}") def flatten_output(output): flatten = [] From 993f3186df82919be2fb70cadd83f14de7bacd0d Mon Sep 17 00:00:00 2001 From: 
raushan Date: Tue, 1 Oct 2024 15:12:46 +0200 Subject: [PATCH 15/30] style --- src/transformers/models/longt5/modeling_longt5.py | 2 +- src/transformers/models/mt5/modeling_mt5.py | 2 +- src/transformers/models/pix2struct/modeling_pix2struct.py | 2 +- src/transformers/models/pop2piano/modeling_pop2piano.py | 2 +- .../models/switch_transformers/modeling_switch_transformers.py | 2 +- src/transformers/models/t5/modeling_t5.py | 2 +- src/transformers/models/udop/modeling_udop.py | 2 +- src/transformers/models/umt5/modeling_umt5.py | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 539abb10bf22..f91c8b3dac1f 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -25,8 +25,8 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache -from ...modeling_attn_mask_utils import AttentionMaskConverter from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 0babc7b0aeab..301e6391ea3b 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -26,8 +26,8 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache -from ...modeling_attn_mask_utils import AttentionMaskConverter from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 1cb49372919d..7be5544f5f8e 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -23,8 +23,8 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache -from ...modeling_attn_mask_utils import AttentionMaskConverter from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index e9385a264f95..fc655827052c 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -26,8 +26,8 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache -from ...modeling_attn_mask_utils import AttentionMaskConverter from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 6e775a776a75..1e04f806e45f 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ 
b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -25,8 +25,8 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache -from ...modeling_attn_mask_utils import AttentionMaskConverter from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( MoEModelOutput, MoEModelOutputWithPastAndCrossAttentions, diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index c39b152ede73..8f0a679c2cbc 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -26,8 +26,8 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache -from ...modeling_attn_mask_utils import AttentionMaskConverter from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 8cae743eb4f2..02619d98ab03 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -35,8 +35,8 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache -from ...modeling_attn_mask_utils import AttentionMaskConverter from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index c1021c9fe388..570444bf156b 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -24,8 +24,8 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache -from ...modeling_attn_mask_utils import AttentionMaskConverter from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, From 41911b74c47f9efd62d21cbde25c579fbd930251 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 1 Oct 2024 15:37:29 +0200 Subject: [PATCH 16/30] update with new generation refactoring --- .../models/longt5/modeling_longt5.py | 117 ++++++++--------- src/transformers/models/mt5/modeling_mt5.py | 114 ++++++++--------- .../models/pix2struct/modeling_pix2struct.py | 119 +++++++++--------- .../models/pop2piano/modeling_pop2piano.py | 117 ++++++++--------- .../modeling_switch_transformers.py | 119 +++++++++--------- src/transformers/models/t5/modeling_t5.py | 117 ++++++++--------- src/transformers/models/udop/modeling_udop.py | 119 +++++++++--------- src/transformers/models/umt5/modeling_umt5.py | 117 ++++++++--------- 8 files changed, 474 insertions(+), 465 deletions(-) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index f91c8b3dac1f..f2b445cc4d7f 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -56,60 +56,6 @@ # TODO: Update before the merge -# 
Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - def _pad_to_multiple(x: torch.Tensor, block_len: int, dim: int, pad_value: int = 0) -> torch.Tensor: """Pad a tensor so that a sequence length will be a multiple of `block_len`""" pad_len = -x.shape[dim] % block_len @@ -1689,7 +1635,6 @@ def _update_causal_mask( return None dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() @@ -1701,13 +1646,12 @@ def _update_causal_mask( ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
- causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, - min_dtype=min_dtype, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1721,10 +1665,67 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + LONGT5_START_DOCSTRING = r""" @@ -2293,7 +2294,7 @@ def prepare_inputs_for_generation( dtype = self.get_output_embeddings().weight.dtype min_dtype = torch.finfo(dtype).min - attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask = self.decoder._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=past_key_values.self_attention_cache.get_max_length(), diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 301e6391ea3b..4eb699f8619e 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -117,60 +117,6 @@ """ -# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - # Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->MT5 class MT5LayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -1293,7 +1239,7 @@ def _update_causal_mask( ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, @@ -1317,6 +1263,62 @@ def _update_causal_mask( return causal_mask + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + MT5_START_DOCSTRING = r""" @@ -2036,7 +2038,7 @@ def prepare_inputs_for_generation( dtype = self.get_output_embeddings().weight.dtype min_dtype = torch.finfo(dtype).min - decoder_attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + decoder_attention_mask = self.decoder._prepare_4d_causal_attention_mask_with_cache_position( decoder_attention_mask, sequence_length=sequence_length, target_length=past_key_values.self_attention_cache.get_max_length(), diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 7be5544f5f8e..5ab8ae1e4484 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -53,60 +53,6 @@ _CONFIG_FOR_DOC = "Pix2StructConfig" -# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - # Adapted from transformers.models.t5.modeling_t5.T5LayerNorm with T5->Pix2Struct class Pix2StructLayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -416,7 +362,7 @@ class Pix2StructPreTrainedModel(PreTrainedModel): config_class = Pix2StructConfig _supports_cache_class = True - _supports_static_cache = True + _supports_static_cache = False @property def dummy_inputs(self): @@ -1661,7 +1607,6 @@ def _update_causal_mask( return None dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() @@ -1673,13 +1618,12 @@ def _update_causal_mask( ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, - min_dtype=min_dtype, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1693,10 +1637,67 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. 
+ device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + @add_start_docstrings( "A conditional generation model with a language modeling head. Can be used for sequence generation tasks.", @@ -1948,7 +1949,7 @@ def prepare_inputs_for_generation( dtype = self.proj_out.weight.dtype min_dtype = torch.finfo(dtype).min - decoder_attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + decoder_attention_mask = self.decoder._prepare_4d_causal_attention_mask_with_cache_position( decoder_attention_mask, sequence_length=sequence_length, target_length=past_key_values.self_attention_cache.get_max_length(), diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index fc655827052c..f1fc46f89598 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -67,60 +67,6 @@ _CHECKPOINT_FOR_DOC = "sweetcocoa/pop2piano" -# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. 
- cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - POP2PIANO_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -732,7 +678,7 @@ class Pop2PianoPreTrainedModel(PreTrainedModel): is_parallelizable = False supports_gradient_checkpointing = True _supports_cache_class = True - _supports_static_cache = True + _supports_static_cache = False _no_split_modules = ["Pop2PianoBlock"] _keep_in_fp32_modules = ["wo"] @@ -1088,7 +1034,6 @@ def _update_causal_mask( return None dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() @@ -1100,13 +1045,12 @@ def _update_causal_mask( ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, - min_dtype=min_dtype, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1120,10 +1064,67 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. 
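The same move repeats across these files, so here is the rough shape of the refactor in isolation (toy class names, illustration only): the helper stops being a module-level function and becomes a @staticmethod on the pretrained-model class, which is why call sites change to self._prepare_...(...) inside the stack and to self.decoder._prepare_...(...) in prepare_inputs_for_generation.

    class ToyPreTrainedModel:
        @staticmethod
        def _prepare_4d_causal_attention_mask_with_cache_position(attention_mask, **kwargs):
            return attention_mask  # stand-in for the real body shown in the diff

    class ToyStack(ToyPreTrainedModel):
        def _update_causal_mask(self, attention_mask, **kwargs):
            # resolved through the class, so individual models can override it
            return self._prepare_4d_causal_attention_mask_with_cache_position(attention_mask, **kwargs)

    class ToyForConditionalGeneration(ToyPreTrainedModel):
        def __init__(self):
            self.decoder = ToyStack()

        def prepare_inputs_for_generation(self, decoder_attention_mask, **kwargs):
            # the generation head reaches the helper through its decoder stack
            return self.decoder._prepare_4d_causal_attention_mask_with_cache_position(decoder_attention_mask, **kwargs)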
+ + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + class Pop2PianoConcatEmbeddingToMel(nn.Module): """Embedding Matrix for `composer` tokens.""" diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 1e04f806e45f..6ebaf1a14e1f 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -58,60 +58,6 @@ #################################################### -# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. 
- min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - def router_z_loss_func(router_logits: torch.Tensor) -> float: r""" Compute the router z-loss implemented in PyTorch. @@ -835,7 +781,7 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel): base_model_prefix = "switch_transformers" supports_gradient_checkpointing = True _supports_cache_class = True - _supports_static_cache = True + _supports_static_cache = False _no_split_modules = ["SwitchTransformersBlock"] @property @@ -1224,7 +1170,6 @@ def _update_causal_mask( return None dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() @@ -1236,13 +1181,12 @@ def _update_causal_mask( ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, - min_dtype=min_dtype, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1256,10 +1200,67 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. 
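Several of these models flip _supports_static_cache to False while keeping _supports_cache_class, so here is a hedged usage sketch of what that flag controls at generation time (the checkpoint name and the exact error type are assumptions, not taken from the patch): requesting a static cache should be refused up front, while the default dynamic-cache path keeps working.

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
    model = AutoModelForSeq2SeqLM.from_pretrained("google/switch-base-8")
    inputs = tokenizer("A <extra_id_0> walks into a <extra_id_1>.", return_tensors="pt")

    model.generate(**inputs, max_new_tokens=5)  # dynamic cache: expected to work
    try:
        model.generate(**inputs, max_new_tokens=5, cache_implementation="static")
    except ValueError as err:
        print(err)  # expected to be rejected while _supports_static_cache is False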
+ + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + SWITCH_TRANSFORMERS_START_DOCSTRING = r""" @@ -1914,7 +1915,7 @@ def prepare_inputs_for_generation( dtype = self.proj_out.weight.dtype min_dtype = torch.finfo(dtype).min - attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask = self.decoder._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=past_key_values.self_attention_cache.get_max_length(), diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 8f0a679c2cbc..f2a9dd8d480a 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -231,60 +231,6 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): """ -# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. 
- target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - class T5LayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -1294,7 +1240,6 @@ def _update_causal_mask( return None dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() @@ -1306,13 +1251,12 @@ def _update_causal_mask( ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, - min_dtype=min_dtype, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1326,10 +1270,67 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. 
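For T5 itself, which keeps static-cache support, a usage sketch of what the new cache plumbing enables end to end (checkpoint, prompt, and generate() arguments are illustrative, not from the patch's tests): an explicit EncoderDecoderCache can be passed to generate(), or a compile-friendly static cache can be requested via cache_implementation.

    from transformers import AutoTokenizer, T5ForConditionalGeneration
    from transformers.cache_utils import DynamicCache, EncoderDecoderCache

    tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
    model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
    inputs = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt")

    # explicit cache object (self-attention + cross-attention halves)
    cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
    out = model.generate(**inputs, past_key_values=cache, max_new_tokens=20)

    # or let generate() build the static cache itself
    out = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
    print(tokenizer.decode(out[0], skip_special_tokens=True))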
+ + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + T5_START_DOCSTRING = r""" @@ -2009,7 +2010,7 @@ def prepare_inputs_for_generation( dtype = self.get_output_embeddings().weight.dtype min_dtype = torch.finfo(dtype).min - decoder_attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + decoder_attention_mask = self.decoder._prepare_4d_causal_attention_mask_with_cache_position( decoder_attention_mask, sequence_length=sequence_length, target_length=past_key_values.self_attention_cache.get_max_length(), diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 02619d98ab03..a34c1c2323b6 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -53,60 +53,6 @@ _CONFIG_FOR_DOC = "UdopConfig" -# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. 
- target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - UDOP_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -471,7 +417,7 @@ class UdopPreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" supports_gradient_checkpointing = True _supports_cache_class = True - _supports_static_cache = True + _supports_static_cache = False _keep_in_fp32_modules = ["wo"] def _init_weights(self, module): @@ -1626,7 +1572,6 @@ def _update_causal_mask( return None dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() @@ -1638,13 +1583,12 @@ def _update_causal_mask( ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, - min_dtype=min_dtype, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1658,10 +1602,67 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
# Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + @add_start_docstrings( "The bare UDOP encoder-decoder Transformer outputting raw hidden-states without any specific head on top.", @@ -2077,7 +2078,7 @@ def prepare_inputs_for_generation( dtype = self.proj_out.weight.dtype min_dtype = torch.finfo(dtype).min - attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask = self.decoder._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=past_key_values.self_attention_cache.get_max_length(), diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 570444bf156b..fe6ff784ec76 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -55,60 +55,6 @@ _CHECKPOINT_FOR_DOC = "google/umt5-small" -# Copied from 
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - # Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->UMT5 class UMT5LayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -911,7 +857,6 @@ def _update_causal_mask( return None dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() @@ -923,13 +868,12 @@ def _update_causal_mask( ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, - min_dtype=min_dtype, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -943,10 +887,67 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. 
This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
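A side note on the fill value used just below (general PyTorch behavior, not something introduced by this patch): torch.finfo(dtype).min is the most negative finite value of the dtype, which keeps masked scores representable in half precision instead of pushing -inf through the softmax.

    import torch

    print(torch.finfo(torch.float32).min)   # -3.4028234663852886e+38
    print(torch.finfo(torch.float16).min)   # -65504.0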
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + UMT5_START_DOCSTRING = r""" @@ -1539,7 +1540,7 @@ def prepare_inputs_for_generation( dtype = self.get_output_embeddings().weight.dtype min_dtype = torch.finfo(dtype).min - decoder_attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + decoder_attention_mask = self.decoder._prepare_4d_causal_attention_mask_with_cache_position( decoder_attention_mask, sequence_length=sequence_length, target_length=past_key_values.self_attention_cache.get_max_length(), From 2449e32c49625cac48868187560d0a5049491e7a Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 1 Oct 2024 15:38:16 +0200 Subject: [PATCH 17/30] nit --- src/transformers/models/vipllava/modeling_vipllava.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 2228f99d4d2a..e26da0ed397f 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -283,7 +283,7 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m return model_embeds # Ignore copy - def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_layers: list[int]): + def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_layers: List[int]): image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) # For VIP-llava, the image features are computed this way From df0a05cd35de1ce83c4a59b3b99858f92572e6f1 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 1 Oct 2024 15:48:50 +0200 Subject: [PATCH 18/30] fix copies --- src/transformers/models/mt5/modeling_mt5.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 4eb699f8619e..8b193240ce61 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1227,7 +1227,6 @@ def _update_causal_mask( return None dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() @@ -1245,7 +1244,6 @@ def _update_causal_mask( target_length=target_length, dtype=dtype, device=device, - min_dtype=min_dtype, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1259,6 +1257,7 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. 
This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @@ -1271,7 +1270,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position( target_length: int, dtype: torch.dtype, device: torch.device, - min_dtype: float, cache_position: torch.Tensor, batch_size: int, ): @@ -1281,17 +1279,17 @@ def _prepare_4d_causal_attention_mask_with_cache_position( Args: attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): @@ -1301,6 +1299,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
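As a reminder of the convention behind cache_position in the code above (a small illustration, not part of the diff): it carries the absolute cache-slot indices written during the current forward pass, so prefill covers the whole prompt and each subsequent decode step carries a single index.

    import torch

    prompt_len = 4
    prefill_cache_position = torch.arange(prompt_len)                 # tensor([0, 1, 2, 3])
    decode_cache_position = torch.arange(prompt_len, prompt_len + 1)  # tensor([4]) for the next token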
causal_mask = attention_mask else: + min_dtype = torch.finfo(dtype).min causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) From c98e54187166ce865501dad2fc126204fe600fd9 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 1 Oct 2024 16:07:44 +0200 Subject: [PATCH 19/30] this is the fix, had to change t5 to fix copies --- src/transformers/models/longt5/modeling_longt5.py | 2 -- src/transformers/models/mt5/modeling_mt5.py | 2 -- src/transformers/models/pix2struct/modeling_pix2struct.py | 2 -- .../switch_transformers/modeling_switch_transformers.py | 6 ++---- src/transformers/models/t5/modeling_t5.py | 2 -- src/transformers/models/udop/modeling_udop.py | 8 ++------ src/transformers/models/umt5/modeling_umt5.py | 2 -- 7 files changed, 4 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index f2b445cc4d7f..c5337fae5ebd 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -2292,7 +2292,6 @@ def prepare_inputs_for_generation( device = input_ids.device dtype = self.get_output_embeddings().weight.dtype - min_dtype = torch.finfo(dtype).min attention_mask = self.decoder._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, @@ -2300,7 +2299,6 @@ def prepare_inputs_for_generation( target_length=past_key_values.self_attention_cache.get_max_length(), dtype=dtype, device=device, - min_dtype=min_dtype, cache_position=cache_position, batch_size=batch_size, ) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 8b193240ce61..1c59857c0a60 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -2035,7 +2035,6 @@ def prepare_inputs_for_generation( device = input_ids.device dtype = self.get_output_embeddings().weight.dtype - min_dtype = torch.finfo(dtype).min decoder_attention_mask = self.decoder._prepare_4d_causal_attention_mask_with_cache_position( decoder_attention_mask, @@ -2043,7 +2042,6 @@ def prepare_inputs_for_generation( target_length=past_key_values.self_attention_cache.get_max_length(), dtype=dtype, device=device, - min_dtype=min_dtype, cache_position=cache_position, batch_size=batch_size, ) diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 5ab8ae1e4484..c7d771403cf9 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -1947,7 +1947,6 @@ def prepare_inputs_for_generation( device = input_ids.device dtype = self.proj_out.weight.dtype - min_dtype = torch.finfo(dtype).min decoder_attention_mask = self.decoder._prepare_4d_causal_attention_mask_with_cache_position( decoder_attention_mask, @@ -1955,7 +1954,6 @@ def prepare_inputs_for_generation( target_length=past_key_values.self_attention_cache.get_max_length(), dtype=dtype, device=device, - min_dtype=min_dtype, cache_position=cache_position, batch_size=batch_size, ) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 6ebaf1a14e1f..bcdb58a1e585 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1862,6 
+1862,7 @@ def _unpack_router_logits(self, router_outputs): total_expert_indexes.append(expert_indexes) return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1) + # Copied from transformers.models.longt5.modeling_longt5.LongT5ForConditionalGeneration.prepare_inputs_for_generation def prepare_inputs_for_generation( self, input_ids, @@ -1912,8 +1913,7 @@ def prepare_inputs_for_generation( batch_size, sequence_length = input_ids.shape device = input_ids.device - dtype = self.proj_out.weight.dtype - min_dtype = torch.finfo(dtype).min + dtype = self.get_output_embeddings().weight.dtype attention_mask = self.decoder._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, @@ -1921,7 +1921,6 @@ def prepare_inputs_for_generation( target_length=past_key_values.self_attention_cache.get_max_length(), dtype=dtype, device=device, - min_dtype=min_dtype, cache_position=cache_position, batch_size=batch_size, ) @@ -1935,7 +1934,6 @@ def prepare_inputs_for_generation( "decoder_head_mask": decoder_head_mask, "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, - "output_router_logits": kwargs.get("output_router_logits", True), "cache_position": cache_position, } diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index f2a9dd8d480a..35b41aa4347b 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -2008,7 +2008,6 @@ def prepare_inputs_for_generation( device = input_ids.device dtype = self.get_output_embeddings().weight.dtype - min_dtype = torch.finfo(dtype).min decoder_attention_mask = self.decoder._prepare_4d_causal_attention_mask_with_cache_position( decoder_attention_mask, @@ -2016,7 +2015,6 @@ def prepare_inputs_for_generation( target_length=past_key_values.self_attention_cache.get_max_length(), dtype=dtype, device=device, - min_dtype=min_dtype, cache_position=cache_position, batch_size=batch_size, ) diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index a34c1c2323b6..6f2a2b82c055 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -2025,6 +2025,7 @@ def forward( encoder_attentions=encoder_outputs.attentions, ) + # Copied from transformers.models.longt5.modeling_longt5.LongT5ForConditionalGeneration.prepare_inputs_for_generation def prepare_inputs_for_generation( self, input_ids, @@ -2075,8 +2076,7 @@ def prepare_inputs_for_generation( batch_size, sequence_length = input_ids.shape device = input_ids.device - dtype = self.proj_out.weight.dtype - min_dtype = torch.finfo(dtype).min + dtype = self.get_output_embeddings().weight.dtype attention_mask = self.decoder._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, @@ -2084,7 +2084,6 @@ def prepare_inputs_for_generation( target_length=past_key_values.self_attention_cache.get_max_length(), dtype=dtype, device=device, - min_dtype=min_dtype, cache_position=cache_position, batch_size=batch_size, ) @@ -2098,9 +2097,6 @@ def prepare_inputs_for_generation( "decoder_head_mask": decoder_head_mask, "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, - "bbox": kwargs.get("bbox", None), - "pixel_values": kwargs.get("pixel_values", None), - "visual_bbox": kwargs.get("visual_bbox", None), "cache_position": cache_position, } diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index fe6ff784ec76..76ed98d402ba 
100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -1538,7 +1538,6 @@ def prepare_inputs_for_generation( device = input_ids.device dtype = self.get_output_embeddings().weight.dtype - min_dtype = torch.finfo(dtype).min decoder_attention_mask = self.decoder._prepare_4d_causal_attention_mask_with_cache_position( decoder_attention_mask, @@ -1546,7 +1545,6 @@ def prepare_inputs_for_generation( target_length=past_key_values.self_attention_cache.get_max_length(), dtype=dtype, device=device, - min_dtype=min_dtype, cache_position=cache_position, batch_size=batch_size, ) From 4f168560c825bd570c7a242f7500c4460607d5a6 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 8 Oct 2024 15:20:18 +0200 Subject: [PATCH 20/30] update --- src/transformers/cache_utils.py | 6 +- src/transformers/generation/utils.py | 2 +- .../models/longt5/modeling_longt5.py | 110 +++++++-------- src/transformers/models/mt5/modeling_mt5.py | 114 +++++++--------- .../models/pix2struct/modeling_pix2struct.py | 53 +++----- .../models/pop2piano/modeling_pop2piano.py | 117 ++++++++-------- .../modeling_switch_transformers.py | 115 +++++++--------- src/transformers/models/t5/modeling_t5.py | 114 +++++++--------- src/transformers/models/udop/modeling_udop.py | 117 ++++++++-------- src/transformers/models/umt5/modeling_umt5.py | 125 ++++++++++-------- .../test_modeling_switch_transformers.py | 1 + 11 files changed, 397 insertions(+), 477 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 0b82b17dcde0..12003a041ce8 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1428,11 +1428,7 @@ def from_legacy_cache( def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor` - if self.self_attention_cache.key_cache == []: - return 0 - if len(self.self_attention_cache.key_cache) > 1 and self.self_attention_cache.key_cache[layer_idx] == []: - return 0 - return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + return self.self_attention_cache.get_seq_length(layer_idx) def reset(self): if hasattr(self.self_attention_cache, "reset"): diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c50e305fb5fe..39551ede23bf 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -580,7 +580,7 @@ def _prepare_encoder_decoder_kwargs_for_generation( add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True)) # 2. Prepare encoder args and encoder kwargs from model kwargs and generation config. 
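A quick check of the get_seq_length delegation above (illustrative values; the tensor shapes are assumptions): after the change the wrapper simply reports whatever its self-attention cache reports.

    import torch
    from transformers.cache_utils import DynamicCache, EncoderDecoderCache

    cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
    print(cache.get_seq_length())              # 0 while empty

    key = value = torch.ones(1, 8, 3, 64)      # (batch, heads, seq_len, head_dim)
    cache.self_attention_cache.update(key, value, layer_idx=0)
    print(cache.get_seq_length())              # 3, taken from the self-attention cache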
- irrelevant_prefix = ["decoder_", "cross_attn", "use_cache", "past_key_values"] + irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] encoder_kwargs = { argument: value for argument, value in model_kwargs.items() diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index c5337fae5ebd..f96d028069af 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -419,11 +419,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None] memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( @@ -453,62 +456,51 @@ def forward( Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] - # failure that tensors are not on the same device otherwise - if torch.jit.is_tracing(): - seq_length = seq_length.to(hidden_states.device) + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) - - # get past key value if past_key_value is not None: is_updated = past_key_value.is_updated.get(self.layer_idx) - if key_value_states is not None: + if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - past_key_value.is_updated[self.layer_idx] = True - past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_value.cross_attention_cache else: - past_key_value = past_key_value.self_attention_cache - - if isinstance(past_key_value, StaticCache): - seq_length = past_key_value.get_max_length() - elif past_key_value is not None: - seq_length += cache_position[0] if query_length is None else query_length + curr_past_key_value = 
past_key_value.self_attention_cache - key_length = seq_length if key_value_states is None else key_value_states.shape[1] - - # get key/value states - current_states = key_value_states if key_value_states is not None else hidden_states - if key_value_states is not None and past_key_value and is_updated: + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - key_states = shape(self.k(current_states)) - value_states = shape(self.v(current_states)) + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not key_value_states is not None else None - key_states, value_states = past_key_value.update( + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype @@ -516,12 +508,10 @@ def unshape(states): if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: causal_mask = mask[:, :, :, : key_states.shape[-2]] @@ -544,7 +534,10 @@ def unshape(states): if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) outputs = (attn_output, past_key_value, position_bias) @@ -1199,7 +1192,7 @@ def forward( output_attentions=output_attentions, 
cache_position=cache_position, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] + hidden_states, past_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/ @@ -1209,10 +1202,6 @@ def forward( do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. Need to inject it here - query_length = cache_position[0] - cross_attention_outputs = self.layer[1]( hidden_states, key_value_states=encoder_hidden_states, @@ -1220,12 +1209,12 @@ def forward( position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, - query_length=query_length, + query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, ) - hidden_states = cross_attention_outputs[0] + hidden_states, past_key_value = cross_attention_outputs[:2] # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/ if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): @@ -1442,7 +1431,7 @@ def forward( # initialize past_key_values return_legacy_cache = False return_self_attention_cache = False - if use_cache or past_key_values is not None: + if self.is_decoder and (use_cache or past_key_values is not None): if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): return_self_attention_cache = True past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) @@ -1456,13 +1445,12 @@ def forward( past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) elif past_key_values is None: past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None - past_key_values_length = 0 - if cache_position is not None: - past_key_values_length = cache_position[0] - elif past_key_values is not None: - past_key_values_length = past_key_values.get_seq_length() - + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if cache_position is None: cache_position = torch.arange( past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device @@ -1470,9 +1458,7 @@ def forward( if attention_mask is None and not is_torchdynamo_compiling(): # required mask seq length can be calculated via length of past - mask_seq_length = ( - past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length - ) + mask_seq_length = past_key_values_length + seq_length attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) if self.is_decoder: diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 1c59857c0a60..c239ed33c1e3 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -316,11 +316,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets relative_buckets += torch.where(is_small, relative_position, 
relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None] memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( @@ -350,62 +353,51 @@ def forward( Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] - # failure that tensors are not on the same device otherwise - if torch.jit.is_tracing(): - seq_length = seq_length.to(hidden_states.device) - - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - # get past key value if past_key_value is not None: is_updated = past_key_value.is_updated.get(self.layer_idx) - if key_value_states is not None: + if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - past_key_value.is_updated[self.layer_idx] = True - past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_value.cross_attention_cache else: - past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_value.self_attention_cache - if isinstance(past_key_value, StaticCache): - seq_length = past_key_value.get_max_length() - elif past_key_value is not None: - seq_length += cache_position[0] if query_length is None else query_length - - key_length = seq_length if key_value_states is None else key_value_states.shape[1] - - # get key/value states - current_states = key_value_states if key_value_states is not None else hidden_states - if key_value_states is not None and past_key_value and is_updated: + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = 
curr_past_key_value.value_cache[self.layer_idx] else: - key_states = shape(self.k(current_states)) - value_states = shape(self.v(current_states)) + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not key_value_states is not None else None - key_states, value_states = past_key_value.update( + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype @@ -413,12 +405,10 @@ def unshape(states): if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: causal_mask = mask[:, :, :, : key_states.shape[-2]] @@ -441,7 +431,10 @@ def unshape(states): if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) outputs = (attn_output, past_key_value, position_bias) @@ -567,7 +560,7 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - hidden_states, self_attn_present_key_value = self_attention_outputs[:2] + hidden_states, past_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -581,10 +574,6 @@ def forward( do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. 
Need to inject it here - query_length = cache_position[0] - cross_attention_outputs = self.layer[1]( hidden_states, key_value_states=encoder_hidden_states, @@ -592,11 +581,11 @@ def forward( position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, - query_length=query_length, + query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, ) - hidden_states = cross_attention_outputs[0] + hidden_states, past_key_value = cross_attention_outputs[:2] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -629,7 +618,7 @@ def forward( else: outputs = outputs + attention_outputs - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) def load_tf_weights_in_mt5(model, config, tf_checkpoint_path): @@ -1008,7 +997,7 @@ def forward( # initialize past_key_values return_legacy_cache = False return_self_attention_cache = False - if use_cache or past_key_values is not None: + if self.is_decoder and (use_cache or past_key_values is not None): if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): return_self_attention_cache = True past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) @@ -1022,23 +1011,20 @@ def forward( past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) elif past_key_values is None: past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None - past_key_values_length = 0 - if cache_position is not None: - past_key_values_length = cache_position[0] - elif past_key_values is not None: - past_key_values_length = past_key_values.get_seq_length() - + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if cache_position is None: cache_position = torch.arange( past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) if attention_mask is None and not is_torchdynamo_compiling(): - # required mask seq length can be calculated via length of past - mask_seq_length = ( - past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length - ) + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) if self.config.is_decoder: diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index c7d771403cf9..84ee25ed5c3e 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -796,55 +796,50 @@ def forward( Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
""" # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, query_length, key_length) batch_size, seq_length = hidden_states.shape[:2] - real_seq_length = seq_length - real_seq_length += cache_position[0] if query_length is None else query_length - - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - - def to_projection_shape(states): - """projection""" - return states.contiguous().view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None - # get query states - # (batch_size, n_heads, seq_length, dim_per_head) - query_states = to_projection_shape(self.query(hidden_states)) + query_states = self.query(hidden_states).contiguous() + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - # get past key value if past_key_value is not None: is_updated = past_key_value.is_updated.get(self.layer_idx) - if key_value_states is not None: + if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - past_key_value.is_updated[self.layer_idx] = True past_key_value = past_key_value.cross_attention_cache else: past_key_value = past_key_value.self_attention_cache # get key/value states - current_states = key_value_states if key_value_states is not None else hidden_states - if key_value_states is not None and past_key_value and is_updated: + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions key_states = past_key_value.key_cache[self.layer_idx] value_states = past_key_value.value_cache[self.layer_idx] else: - key_states = to_projection_shape(self.key(current_states)) - value_states = to_projection_shape(self.value(current_states)) + key_states = self.key(current_states).contiguous() + value_states = self.value(current_states).contiguous() + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not key_value_states is not None else None + cache_position = cache_position if not is_cross_attention else None key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: + real_seq_length = cache_position[-1] + 1 if query_length is None else query_length + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] if not 
self.has_relative_attention_bias: position_bias = torch.zeros( (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype @@ -1004,7 +999,7 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] + hidden_states, past_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -1014,10 +1009,6 @@ def forward( do_cross_attention = encoder_hidden_states is not None if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. Need to inject it here - query_length = cache_position[0] - cross_attention_outputs = self.encoder_decoder_attention( hidden_states, key_value_states=encoder_hidden_states, @@ -1025,12 +1016,12 @@ def forward( position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, - query_length=query_length, + query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, ) - hidden_states = cross_attention_outputs[0] + hidden_states, past_key_value = cross_attention_outputs[:2] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index f1fc46f89598..9a649cc13e7d 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -39,6 +39,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_fx_proxy, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -349,11 +350,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None] memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( @@ -383,62 +387,51 @@ def forward( Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
""" # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] - # failure that tensors are not on the same device otherwise - if torch.jit.is_tracing(): - seq_length = seq_length.to(hidden_states.device) - - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - # get past key value if past_key_value is not None: is_updated = past_key_value.is_updated.get(self.layer_idx) - if key_value_states is not None: + if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - past_key_value.is_updated[self.layer_idx] = True - past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_value.cross_attention_cache else: - past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_value.self_attention_cache - if isinstance(past_key_value, StaticCache): - seq_length = past_key_value.get_max_length() - elif past_key_value is not None: - seq_length += cache_position[0] if query_length is None else query_length - - key_length = seq_length if key_value_states is None else key_value_states.shape[1] - - # get key/value states - current_states = key_value_states if key_value_states is not None else hidden_states - if key_value_states is not None and past_key_value and is_updated: + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - key_states = shape(self.k(current_states)) - value_states = shape(self.v(current_states)) + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not key_value_states is not None else None - key_states, value_states = past_key_value.update( + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) + # set flag that curr layer for cross-attn is already updated so we can 
re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype @@ -446,12 +439,10 @@ def unshape(states): if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: causal_mask = mask[:, :, :, : key_states.shape[-2]] @@ -474,7 +465,10 @@ def unshape(states): if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) outputs = (attn_output, past_key_value, position_bias) @@ -602,7 +596,7 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - hidden_states, self_attn_present_key_value = self_attention_outputs[:2] + hidden_states, past_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -616,10 +610,6 @@ def forward( do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. 
Need to inject it here - query_length = cache_position[0] - cross_attention_outputs = self.layer[1]( hidden_states, key_value_states=encoder_hidden_states, @@ -627,11 +617,11 @@ def forward( position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, - query_length=query_length, + query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, ) - hidden_states = cross_attention_outputs[0] + hidden_states, past_key_value = cross_attention_outputs[:2] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -664,7 +654,7 @@ def forward( else: outputs = outputs + attention_outputs - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) class Pop2PianoPreTrainedModel(PreTrainedModel): @@ -845,7 +835,7 @@ def forward( # initialize past_key_values return_legacy_cache = False return_self_attention_cache = False - if use_cache or past_key_values is not None: + if self.is_decoder and (use_cache or past_key_values is not None): if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): return_self_attention_cache = True past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) @@ -859,23 +849,20 @@ def forward( past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) elif past_key_values is None: past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None - past_key_values_length = 0 - if cache_position is not None: - past_key_values_length = cache_position[0] - elif past_key_values is not None: - past_key_values_length = past_key_values.get_seq_length() - + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if cache_position is None: cache_position = torch.arange( past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - if attention_mask is None: - # required mask seq length can be calculated via length of past - mask_seq_length = ( - past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length - ) + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) if self.config.is_decoder: diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index bcdb58a1e585..5b6b42dd6134 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -41,6 +41,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_fx_proxy, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -456,11 +457,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, 
num_buckets relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None] memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( @@ -490,62 +494,51 @@ def forward( Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] - # failure that tensors are not on the same device otherwise - if torch.jit.is_tracing(): - seq_length = seq_length.to(hidden_states.device) - - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - # get past key value if past_key_value is not None: is_updated = past_key_value.is_updated.get(self.layer_idx) - if key_value_states is not None: + if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - past_key_value.is_updated[self.layer_idx] = True - past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_value.cross_attention_cache else: - past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_value.self_attention_cache - if isinstance(past_key_value, StaticCache): - seq_length = past_key_value.get_max_length() - elif past_key_value is not None: - seq_length += cache_position[0] if query_length is None else query_length - - key_length = seq_length if key_value_states is None else key_value_states.shape[1] - - # get key/value states - current_states = key_value_states if key_value_states is not None else hidden_states - if key_value_states is not None and past_key_value and is_updated: + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + 
key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - key_states = shape(self.k(current_states)) - value_states = shape(self.v(current_states)) + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not key_value_states is not None else None - key_states, value_states = past_key_value.update( + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype @@ -553,12 +546,10 @@ def unshape(states): if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: causal_mask = mask[:, :, :, : key_states.shape[-2]] @@ -581,7 +572,10 @@ def unshape(states): if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) outputs = (attn_output, past_key_value, position_bias) @@ -712,7 +706,7 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] + hidden_states, past_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -722,10 +716,6 @@ def forward( do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. 
Need to inject it here - query_length = cache_position[0] - cross_attention_outputs = self.layer[1]( hidden_states, key_value_states=encoder_hidden_states, @@ -733,12 +723,12 @@ def forward( position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, - query_length=query_length, + query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, ) - hidden_states = cross_attention_outputs[0] + hidden_states, past_key_value = cross_attention_outputs[:2] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): @@ -967,7 +957,7 @@ def forward( # initialize past_key_values return_legacy_cache = False return_self_attention_cache = False - if use_cache or past_key_values is not None: + if self.is_decoder and (use_cache or past_key_values is not None): if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): return_self_attention_cache = True past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) @@ -981,23 +971,20 @@ def forward( past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) elif past_key_values is None: past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None - past_key_values_length = 0 - if cache_position is not None: - past_key_values_length = cache_position[0] - elif past_key_values is not None: - past_key_values_length = past_key_values.get_seq_length() - + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if cache_position is None: cache_position = torch.arange( past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - if attention_mask is None: - # required mask seq length can be calculated via length of past - mask_seq_length = ( - past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length - ) + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) if self.config.is_decoder: diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 35b41aa4347b..d51beb8cea8c 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -441,11 +441,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None] memory_position = torch.arange(key_length, 
dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( @@ -475,62 +478,51 @@ def forward( Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] - # failure that tensors are not on the same device otherwise - if torch.jit.is_tracing(): - seq_length = seq_length.to(hidden_states.device) - - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - # get past key value if past_key_value is not None: is_updated = past_key_value.is_updated.get(self.layer_idx) - if key_value_states is not None: + if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - past_key_value.is_updated[self.layer_idx] = True - past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_value.cross_attention_cache else: - past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_value.self_attention_cache - if isinstance(past_key_value, StaticCache): - seq_length = past_key_value.get_max_length() - elif past_key_value is not None: - seq_length += cache_position[0] if query_length is None else query_length - - key_length = seq_length if key_value_states is None else key_value_states.shape[1] - - # get key/value states - current_states = key_value_states if key_value_states is not None else hidden_states - if key_value_states is not None and past_key_value and is_updated: + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - key_states = shape(self.k(current_states)) - value_states = shape(self.v(current_states)) + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not key_value_states is not None else 
None - key_states, value_states = past_key_value.update( + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype @@ -538,12 +530,10 @@ def unshape(states): if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: causal_mask = mask[:, :, :, : key_states.shape[-2]] @@ -566,7 +556,10 @@ def unshape(states): if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) outputs = (attn_output, past_key_value, position_bias) @@ -689,7 +682,7 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - hidden_states, self_attn_present_key_value = self_attention_outputs[:2] + hidden_states, past_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -703,10 +696,6 @@ def forward( do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. 
Need to inject it here - query_length = cache_position[0] - cross_attention_outputs = self.layer[1]( hidden_states, key_value_states=encoder_hidden_states, @@ -714,11 +703,11 @@ def forward( position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, - query_length=query_length, + query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, ) - hidden_states = cross_attention_outputs[0] + hidden_states, past_key_value = cross_attention_outputs[:2] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -751,7 +740,7 @@ def forward( else: outputs = outputs + attention_outputs - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) class T5ClassificationHead(nn.Module): @@ -1021,7 +1010,7 @@ def forward( # initialize past_key_values return_legacy_cache = False return_self_attention_cache = False - if use_cache or past_key_values is not None: + if self.is_decoder and (use_cache or past_key_values is not None): if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): return_self_attention_cache = True past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) @@ -1035,23 +1024,20 @@ def forward( past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) elif past_key_values is None: past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None - past_key_values_length = 0 - if cache_position is not None: - past_key_values_length = cache_position[0] - elif past_key_values is not None: - past_key_values_length = past_key_values.get_seq_length() - + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if cache_position is None: cache_position = torch.arange( past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) if attention_mask is None and not is_torchdynamo_compiling(): - # required mask seq length can be calculated via length of past - mask_seq_length = ( - past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length - ) + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) if self.config.is_decoder: diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 6f2a2b82c055..5dd0768ae634 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -43,6 +43,7 @@ ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torchdynamo_compiling, replace_return_docstrings, ) @@ -704,11 +705,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, 
key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None] memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( @@ -738,62 +742,51 @@ def forward( Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] - # failure that tensors are not on the same device otherwise - if torch.jit.is_tracing(): - seq_length = seq_length.to(hidden_states.device) - - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - # get past key value if past_key_value is not None: is_updated = past_key_value.is_updated.get(self.layer_idx) - if key_value_states is not None: + if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - past_key_value.is_updated[self.layer_idx] = True - past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_value.cross_attention_cache else: - past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_value.self_attention_cache - if isinstance(past_key_value, StaticCache): - seq_length = past_key_value.get_max_length() - elif past_key_value is not None: - seq_length += cache_position[0] if query_length is None else query_length - - key_length = seq_length if key_value_states is None else key_value_states.shape[1] - - # get key/value states - current_states = key_value_states if key_value_states is not None else hidden_states - if key_value_states is not None and past_key_value and is_updated: + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - key_states = 
shape(self.k(current_states)) - value_states = shape(self.v(current_states)) + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not key_value_states is not None else None - key_states, value_states = past_key_value.update( + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype @@ -801,12 +794,10 @@ def unshape(states): if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: causal_mask = mask[:, :, :, : key_states.shape[-2]] @@ -829,7 +820,10 @@ def unshape(states): if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) outputs = (attn_output, past_key_value, position_bias) @@ -957,7 +951,7 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - hidden_states, self_attn_present_key_value = self_attention_outputs[:2] + hidden_states, past_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -971,10 +965,6 @@ def forward( do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. 
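Note on the cross-attention branch above: it projects the encoder hidden states into key/value once, stores them in the cache, and flips `is_updated` so later decoding steps read them back instead of re-projecting. A minimal standalone sketch of that pattern follows; `ToyCrossAttnCache` and `cross_attn_kv` are made-up names for illustration only, not the transformers `Cache` API.

import torch
import torch.nn as nn

class ToyCrossAttnCache:
    """Toy stand-in for a per-layer cross-attention cache with is_updated bookkeeping."""
    def __init__(self):
        self.key_cache = {}      # layer_idx -> (batch, heads, src_len, head_dim)
        self.value_cache = {}
        self.is_updated = {}     # layer_idx -> True once K/V have been stored

def cross_attn_kv(cache, layer_idx, k_proj, v_proj, encoder_states, n_heads, head_dim):
    batch = encoder_states.shape[0]
    if cache.is_updated.get(layer_idx, False):
        # every step after the first: reuse the cached projections of the (fixed) encoder output
        return cache.key_cache[layer_idx], cache.value_cache[layer_idx]
    key = k_proj(encoder_states).view(batch, -1, n_heads, head_dim).transpose(1, 2)
    value = v_proj(encoder_states).view(batch, -1, n_heads, head_dim).transpose(1, 2)
    cache.key_cache[layer_idx], cache.value_cache[layer_idx] = key, value
    cache.is_updated[layer_idx] = True
    return key, value

torch.manual_seed(0)
d_model, n_heads = 16, 4
k_proj = nn.Linear(d_model, d_model, bias=False)
v_proj = nn.Linear(d_model, d_model, bias=False)
encoder_out = torch.randn(2, 7, d_model)   # stays constant across decoding steps
cache = ToyCrossAttnCache()
k1, _ = cross_attn_kv(cache, 0, k_proj, v_proj, encoder_out, n_heads, d_model // n_heads)  # projects
k2, _ = cross_attn_kv(cache, 0, k_proj, v_proj, encoder_out, n_heads, d_model // n_heads)  # cache hit
print(torch.equal(k1, k2))  # True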
Need to inject it here - query_length = cache_position[0] - cross_attention_outputs = self.layer[1]( hidden_states, key_value_states=encoder_hidden_states, @@ -982,11 +972,11 @@ def forward( position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, - query_length=query_length, + query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, ) - hidden_states = cross_attention_outputs[0] + hidden_states, past_key_value = cross_attention_outputs[:2] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -1019,7 +1009,7 @@ def forward( else: outputs = outputs + attention_outputs - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) class UdopCellEmbeddings(nn.Module): @@ -1401,7 +1391,7 @@ def forward( # initialize past_key_values return_legacy_cache = False return_self_attention_cache = False - if use_cache or past_key_values is not None: + if self.is_decoder and (use_cache or past_key_values is not None): if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): return_self_attention_cache = True past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) @@ -1415,23 +1405,20 @@ def forward( past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) elif past_key_values is None: past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None - past_key_values_length = 0 - if cache_position is not None: - past_key_values_length = cache_position[0] - elif past_key_values is not None: - past_key_values_length = past_key_values.get_seq_length() - + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if cache_position is None: cache_position = torch.arange( past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - if attention_mask is None: - # required mask seq length can be calculated via length of past - mask_seq_length = ( - past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length - ) + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) if self.config.is_decoder: diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 76ed98d402ba..c8535ccf3681 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -240,11 +240,14 @@ def _relative_position_bucket(self, relative_position): relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position 
bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None] memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket(relative_position) @@ -263,70 +266,84 @@ def forward( ): batch_size, seq_length = hidden_states.shape[:2] - # get past key value + # if encoder_hidden_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = encoder_hidden_states is not None + + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + if past_key_value is not None: is_updated = past_key_value.is_updated.get(self.layer_idx) - if encoder_hidden_states is not None: + if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - past_key_value.is_updated[self.layer_idx] = True - past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_value.cross_attention_cache else: - past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_value.self_attention_cache - # get key/value states - current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - if encoder_hidden_states is not None and past_key_value and is_updated: + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - key_states = self._shape(self.k(current_states)) - value_states = self._shape(self.v(current_states)) + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not encoder_hidden_states is not None else None - key_states, value_states = past_key_value.update( + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True - query_states = self._shape(self.q(hidden_states)) - attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) - # compute positional bias - if 
self.has_relative_attention_bias: - query_length = seq_length - if isinstance(past_key_value, StaticCache): - query_length = past_key_value.get_max_length() - elif past_key_value is not None: - query_length += past_key_value.get_seq_length() - - position_bias = self.compute_bias(query_length, key_states.size(2), device=attention_scores.device) - else: + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = seq_length + past_key_value.get_seq_length() if past_key_value is not None else seq_length + key_length = key_states.shape[-2] + if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, seq_length, key_states.size(2)), - device=attention_scores.device, - dtype=attention_scores.dtype, - requires_grad=self.training, + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] + if attention_mask is not None: - position_bias = position_bias + attention_mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked - attention_scores += position_bias # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.softmax(attention_scores.float(), dim=-1).type_as(attention_scores) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - # attn_output = torch.bmm(attn_probs, value_states) ? - context_states = torch.matmul(attn_weights, value_states) - # attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) ? 
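Note on the bias handling above: with a cache, `compute_bias` is called with the full (past + current) query length, only the rows for the current positions are kept via `position_bias[:, :, -seq_length:, :]`, and `cache_position` indexes exactly those rows. A toy check of that equivalence, using plain relative offsets in place of the bucketed bias:

import torch

def relative_positions(query_positions, key_length, device=None):
    context_position = query_positions[:, None]                                           # (q, 1)
    memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]  # (1, k)
    return memory_position - context_position                                             # (q, k)

past_len, new_tokens = 5, 1
key_length = past_len + new_tokens
cache_position = torch.arange(past_len, past_len + new_tokens)    # positions of the new token(s)

full = relative_positions(torch.arange(key_length), key_length)   # cache-less forward: all rows
incremental = relative_positions(cache_position, key_length)      # cached forward: current rows only

# the incremental rows match the last `new_tokens` rows of the full matrix,
# mirroring the position_bias[:, :, -seq_length:, :] slice in the diff above
print(torch.equal(incremental, full[-new_tokens:]))  # True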
- context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1) - attn_output = self.o(context_states) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, seq_length, -1) + + attn_output = self.o(attn_output) return attn_output, attn_weights, past_key_value @@ -412,7 +429,7 @@ def forward( output_attentions=False, cache_position=None, ): - hidden_states, self_attn_weights, present_key_value = self.layer[0]( + hidden_states, self_attn_weights, past_key_value = self.layer[0]( hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -430,8 +447,7 @@ def forward( cross_attn_weights = None do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.layer[1]( + hidden_states, cross_attn_weights, past_key_value = self.layer[1]( hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -680,7 +696,7 @@ def forward( # initialize past_key_values return_legacy_cache = False return_self_attention_cache = False - if use_cache or past_key_values is not None: + if self.is_decoder and (use_cache or past_key_values is not None): if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): return_self_attention_cache = True past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) @@ -694,23 +710,20 @@ def forward( past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) elif past_key_values is None: past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None - past_key_values_length = 0 - if cache_position is not None: - past_key_values_length = cache_position[0] - elif past_key_values is not None: - past_key_values_length = past_key_values.get_seq_length() - + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if cache_position is None: cache_position = torch.arange( past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) if attention_mask is None and not is_torchdynamo_compiling(): - # required mask seq length can be calculated via length of past - mask_seq_length = ( - past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length - ) + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) if self.is_decoder: diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 7adb1f40c6e6..41758c2acfbb 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -425,6 +425,7 @@ def create_and_check_generate_with_past_key_values( ) torch.manual_seed(0) output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True) + 
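Note on the stack-level changes above: decoder caching is wrapped in an `EncoderDecoderCache` holding one `DynamicCache` for self-attention and one for cross-attention, while the encoder stack is given no cache at all. A small usage sketch of those objects, assuming a transformers version that exposes them under `cache_utils` as the imports at the top of this patch do:

from transformers.cache_utils import DynamicCache, EncoderDecoderCache

cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
print(cache.get_seq_length())        # 0: no decoder tokens cached yet
print(cache.self_attention_cache)    # grows by one position per generated token
print(cache.cross_attention_cache)   # filled once from the encoder output, then only read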
print(output_with_past_cache, output_without_past_cache) self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache)) def create_and_check_model_fp16_forward( From d7260d3b7e3f00101920dafcec0ac6ee9212db32 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 8 Oct 2024 15:31:29 +0200 Subject: [PATCH 21/30] [run-slow] t5 --- src/transformers/models/longt5/modeling_longt5.py | 2 +- .../switch_transformers/test_modeling_switch_transformers.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index f96d028069af..065ee45d7b87 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -1253,7 +1253,7 @@ class LongT5PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LongT5Block"] _supports_cache_class = True - _supports_static_cache = False # TODO: @raushan more involvede due to local/global attn + _supports_static_cache = False # TODO: @raushan more involved due to local/global attn @property # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 41758c2acfbb..7adb1f40c6e6 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -425,7 +425,6 @@ def create_and_check_generate_with_past_key_values( ) torch.manual_seed(0) output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True) - print(output_with_past_cache, output_without_past_cache) self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache)) def create_and_check_model_fp16_forward( From 5f5f66f7727a5ae961357544a0a3d341a0b96690 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 9 Oct 2024 13:29:17 +0200 Subject: [PATCH 22/30] [run-slow] t5 --- tests/models/t5/test_modeling_t5.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index d56ba0ef48b4..195e42800fef 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -1613,8 +1613,8 @@ def test_contrastive_search_t5(self): self.assertListEqual( generated_text, [ - "Liana Barrientos has been married 10 times, nine of them in the Bronx. Her husbands filed for " - "permanent residence after the marriages, prosecutors say." + "Liana Barrientos has been married 10 times, nine of them in the Bronx . Her husbands filed for " + "permanent residence after the marriages, prosecutors say ." 
], ) From e4040635012fbcad006ce745fa3c778b84075d26 Mon Sep 17 00:00:00 2001 From: raushan Date: Fri, 11 Oct 2024 11:39:31 +0200 Subject: [PATCH 23/30] update --- .../models/longt5/modeling_longt5.py | 2 +- src/transformers/models/mt5/modeling_mt5.py | 2 +- .../models/pop2piano/modeling_pop2piano.py | 2 +- .../modeling_switch_transformers.py | 2 +- src/transformers/models/t5/modeling_t5.py | 2 +- src/transformers/models/udop/modeling_udop.py | 2 +- tests/models/t5/test_modeling_t5.py | 41 +++++++++++++++++++ 7 files changed, 47 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 065ee45d7b87..642f649f6a2f 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -426,7 +426,7 @@ def compute_bias(self, query_length, key_length, device=None, cache_position=Non if cache_position is None: context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] else: - context_position = cache_position[:, None] + context_position = cache_position[:, None].to(device) memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index c239ed33c1e3..8529af826139 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -323,7 +323,7 @@ def compute_bias(self, query_length, key_length, device=None, cache_position=Non if cache_position is None: context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] else: - context_position = cache_position[:, None] + context_position = cache_position[:, None].to(device) memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 9a649cc13e7d..c8ed3629e15d 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -357,7 +357,7 @@ def compute_bias(self, query_length, key_length, device=None, cache_position=Non if cache_position is None: context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] else: - context_position = cache_position[:, None] + context_position = cache_position[:, None].to(device) memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 5b6b42dd6134..7c1bac267dbd 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -464,7 +464,7 @@ def compute_bias(self, query_length, key_length, device=None, cache_position=Non if cache_position is None: context_position = 
torch.arange(query_length, dtype=torch.long, device=device)[:, None] else: - context_position = cache_position[:, None] + context_position = cache_position[:, None].to(device) memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index d51beb8cea8c..0bda6b59e0ae 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -448,7 +448,7 @@ def compute_bias(self, query_length, key_length, device=None, cache_position=Non if cache_position is None: context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] else: - context_position = cache_position[:, None] + context_position = cache_position[:, None].to(device) memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 5dd0768ae634..8bf8d47a79f2 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -712,7 +712,7 @@ def compute_bias(self, query_length, key_length, device=None, cache_position=Non if cache_position is None: context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] else: - context_position = cache_position[:, None] + context_position = cache_position[:, None].to(device) memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index 195e42800fef..6800b5f75b02 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -27,6 +27,7 @@ require_sentencepiece, require_tokenizers, require_torch, + require_torch_gpu, slow, torch_device, ) @@ -1618,6 +1619,46 @@ def test_contrastive_search_t5(self): ], ) + @slow + @require_torch_gpu + def test_compile_static_cache(self): + NUM_TOKENS_TO_GENERATE = 40 + EXPECTED_TEXT_COMPLETION = [ + "theory of relativity states that 1) the speed of light is constant in all inertial reference frames. the laws of physics are the same for all inertial reference frames.", + "ketchup is my favorite condiment.", + ] + + prompts = [ + "summarize: Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " + "theory of relativity is not hard to grasp.", + "summarize: My favorite all time favorite condiment is ketchup. I love it on everything. 
I love it on my eggs, " + "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my pizza.", + ] + model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small").to(torch_device) + tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small") + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) + + # Static Cache + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) + + # Static Cache + compile + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) + @require_torch class TestAsymmetricT5(unittest.TestCase): From 47d70c5881f529f76be2ab9b77c6b0173643c1d7 Mon Sep 17 00:00:00 2001 From: raushan Date: Fri, 11 Oct 2024 12:25:11 +0200 Subject: [PATCH 24/30] add test for encoder only T5 --- tests/models/t5/test_modeling_t5.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index 6800b5f75b02..2eb2a2cd004f 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -1659,6 +1659,26 @@ def test_compile_static_cache(self): static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) + @slow + @require_torch_gpu + def test_compile_static_cache_encoder(self): + prompts = [ + "summarize: Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " + "theory of relativity is not hard to grasp.", + "summarize: My favorite all time favorite condiment is ketchup. I love it on everything. 
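Note on `test_compile_static_cache` above: it exercises the combination this patch series targets, where a static cache keeps the key/value tensors at a fixed shape so the compiled decoder forward does not recompile per step. A condensed usage sketch along the same lines; it assumes a CUDA device, sentencepiece for `T5Tokenizer`, and an arbitrary prompt:

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small").to("cuda")
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
inputs = tokenizer(["summarize: studies have shown that owning a dog is good for you"],
                   return_tensors="pt").to(model.device)

# compile once; the static cache keeps K/V shapes fixed so each decoding step hits the same graph
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
out = model.generate(**inputs, max_new_tokens=20, do_sample=False, cache_implementation="static")
print(tokenizer.batch_decode(out, skip_special_tokens=True))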
I love it on my eggs, " + "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my pizza.", + ] + model = T5EncoderModel.from_pretrained("google-t5/t5-small").to(torch_device) + tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small") + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + logits = model(**inputs) + + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + logits_compiled = model(**inputs) + self.assertTrue(torch.allclose(logits[0][:, -3:, -3], logits_compiled[0][:, -3:, -3], atol=1e-5)) + @require_torch class TestAsymmetricT5(unittest.TestCase): From 3048ab848bba5b531011bf0b3f9cfa0c7d454c02 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 14 Oct 2024 09:51:56 +0200 Subject: [PATCH 25/30] clean up after rebase --- .../models/longt5/modeling_longt5.py | 76 +---------------- src/transformers/models/mt5/modeling_mt5.py | 79 +----------------- .../models/pix2struct/modeling_pix2struct.py | 83 +------------------ .../models/pop2piano/modeling_pop2piano.py | 29 +------ .../modeling_switch_transformers.py | 77 +---------------- src/transformers/models/t5/modeling_t5.py | 78 +---------------- src/transformers/models/udop/modeling_udop.py | 77 +---------------- src/transformers/models/umt5/modeling_umt5.py | 79 +----------------- 8 files changed, 8 insertions(+), 570 deletions(-) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 642f649f6a2f..c0bae9dc194f 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -1623,7 +1623,7 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device sequence_length = input_tensor.shape[1] if using_static_cache: - target_length = past_key_values.get_max_length() + target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] @@ -2227,80 +2227,6 @@ def forward( encoder_attentions=encoder_outputs.attentions, ) - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - cache_position=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - if isinstance(past_key_values, EncoderDecoderCache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - else: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_ids.shape[1] :] - - # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise - # recompiles graphs as the stride of the inputs is a guard. 
Ref: https://github.com/huggingface/transformers/pull/29114 - input_ids = input_ids.clone(memory_format=torch.contiguous_format) - - if ( - isinstance(past_key_values, EncoderDecoderCache) - and ( - isinstance(past_key_values.self_attention_cache, StaticCache) - or isinstance(past_key_values.cross_attention_cache, StaticCache) - ) - and attention_mask is not None - and attention_mask.ndim == 2 - ): - batch_size, sequence_length = input_ids.shape - device = input_ids.device - - dtype = self.get_output_embeddings().weight.dtype - - attention_mask = self.decoder._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.self_attention_cache.get_max_length(), - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - ) - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - "cache_position": cache_position, - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 8529af826139..2b50b4ba0e0c 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1215,7 +1215,7 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device sequence_length = input_tensor.shape[1] if using_static_cache: - target_length = past_key_values.get_max_length() + target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] @@ -1968,83 +1968,6 @@ def forward( encoder_attentions=encoder_outputs.attentions, ) - # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_inputs_for_generation - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - decoder_attention_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - cache_position=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - if isinstance(past_key_values, EncoderDecoderCache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - else: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_ids.shape[1] :] - - # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise - # recompiles graphs as the stride of the inputs is a guard. 
Ref: https://github.com/huggingface/transformers/pull/29114 - input_ids = input_ids.clone(memory_format=torch.contiguous_format) - - if ( - isinstance(past_key_values, EncoderDecoderCache) - and ( - isinstance(past_key_values.self_attention_cache, StaticCache) - or isinstance(past_key_values.cross_attention_cache, StaticCache) - ) - and decoder_attention_mask is not None - and decoder_attention_mask.ndim == 2 - ): - batch_size, sequence_length = input_ids.shape - device = input_ids.device - - dtype = self.get_output_embeddings().weight.dtype - - decoder_attention_mask = self.decoder._prepare_4d_causal_attention_mask_with_cache_position( - decoder_attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.self_attention_cache.get_max_length(), - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - ) - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "decoder_attention_mask": decoder_attention_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - "cache_position": cache_position, - } - # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_decoder_input_ids_from_labels def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 84ee25ed5c3e..672ea2c10986 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -1600,7 +1600,7 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device sequence_length = input_tensor.shape[1] if using_static_cache: - target_length = past_key_values.get_max_length() + target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] @@ -1881,84 +1881,3 @@ def forward( encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, ) - - def prepare_inputs_for_generation( - self, - input_ids, - flattened_patches: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - decoder_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - cache_position=None, - **kwargs, - ): - if decoder_attention_mask is None: - decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device) - - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - if isinstance(past_key_values, EncoderDecoderCache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - else: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device) - elif use_cache: - cache_position = 
cache_position[-input_ids.shape[1] :] - - # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise - # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 - input_ids = input_ids.clone(memory_format=torch.contiguous_format) - - if ( - isinstance(past_key_values, EncoderDecoderCache) - and ( - isinstance(past_key_values.self_attention_cache, StaticCache) - or isinstance(past_key_values.cross_attention_cache, StaticCache) - ) - and decoder_attention_mask is not None - and decoder_attention_mask.ndim == 2 - ): - batch_size, sequence_length = input_ids.shape - device = input_ids.device - - dtype = self.proj_out.weight.dtype - - decoder_attention_mask = self.decoder._prepare_4d_causal_attention_mask_with_cache_position( - decoder_attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.self_attention_cache.get_max_length(), - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - ) - - return { - "flattened_patches": flattened_patches, - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "decoder_attention_mask": decoder_attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - "cache_position": cache_position, - } diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index c8ed3629e15d..7e8d14800159 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -1023,7 +1023,7 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device sequence_length = input_tensor.shape[1] if using_static_cache: - target_length = past_key_values.get_max_length() + target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] @@ -1443,33 +1443,6 @@ def generate( **kwargs, ) - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - input_ids = input_ids[:, -1:] - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 282218b84708..0cc9cc0a4525 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1159,7 +1159,7 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device sequence_length = input_tensor.shape[1] if using_static_cache: - target_length = past_key_values.get_max_length() + target_length = 
past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] @@ -1849,81 +1849,6 @@ def _unpack_router_logits(self, router_outputs): total_expert_indexes.append(expert_indexes) return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1) - # Copied from transformers.models.longt5.modeling_longt5.LongT5ForConditionalGeneration.prepare_inputs_for_generation - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - cache_position=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - if isinstance(past_key_values, EncoderDecoderCache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - else: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_ids.shape[1] :] - - # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise - # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 - input_ids = input_ids.clone(memory_format=torch.contiguous_format) - - if ( - isinstance(past_key_values, EncoderDecoderCache) - and ( - isinstance(past_key_values.self_attention_cache, StaticCache) - or isinstance(past_key_values.cross_attention_cache, StaticCache) - ) - and attention_mask is not None - and attention_mask.ndim == 2 - ): - batch_size, sequence_length = input_ids.shape - device = input_ids.device - - dtype = self.get_output_embeddings().weight.dtype - - attention_mask = self.decoder._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.self_attention_cache.get_max_length(), - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - ) - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - "cache_position": cache_position, - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 0bda6b59e0ae..02d67b64c031 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1228,7 +1228,7 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device sequence_length = input_tensor.shape[1] if using_static_cache: - target_length = past_key_values.get_max_length() + target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] @@ -1942,82 +1942,6 @@ def forward( 
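Note on the blocks being deleted above: each model re-implemented the trimming of `decoder_input_ids` against the cached length and the construction of `cache_position`, which the shared generation code now does once. A toy standalone sketch of that trimming logic; `trim_decoder_inputs` is an illustrative name, not a library function:

import torch

def trim_decoder_inputs(input_ids, past_length):
    # keep only the tokens that are not already represented in the cache
    remove_prefix_length = past_length if input_ids.shape[1] > past_length else input_ids.shape[1] - 1
    trimmed = input_ids[:, remove_prefix_length:]
    cache_position = torch.arange(past_length, past_length + trimmed.shape[1])
    return trimmed, cache_position

ids = torch.tensor([[0, 5, 9, 2]])                                  # 4 decoder tokens so far
trimmed, cache_position = trim_decoder_inputs(ids, past_length=3)
print(trimmed, cache_position)                                       # tensor([[2]]) tensor([3])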
encoder_attentions=encoder_outputs.attentions, ) - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - decoder_attention_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - cache_position=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - if isinstance(past_key_values, EncoderDecoderCache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - else: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_ids.shape[1] :] - - # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise - # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 - input_ids = input_ids.clone(memory_format=torch.contiguous_format) - - if ( - isinstance(past_key_values, EncoderDecoderCache) - and ( - isinstance(past_key_values.self_attention_cache, StaticCache) - or isinstance(past_key_values.cross_attention_cache, StaticCache) - ) - and decoder_attention_mask is not None - and decoder_attention_mask.ndim == 2 - ): - batch_size, sequence_length = input_ids.shape - device = input_ids.device - - dtype = self.get_output_embeddings().weight.dtype - - decoder_attention_mask = self.decoder._prepare_4d_causal_attention_mask_with_cache_position( - decoder_attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.self_attention_cache.get_max_length(), - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - ) - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "decoder_attention_mask": decoder_attention_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - "cache_position": cache_position, - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 8bf8d47a79f2..6b773e57e332 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -1561,7 +1561,7 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device sequence_length = input_tensor.shape[1] if using_static_cache: - target_length = past_key_values.get_max_length() + target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] @@ -2012,81 +2012,6 @@ def forward( encoder_attentions=encoder_outputs.attentions, ) - # Copied from transformers.models.longt5.modeling_longt5.LongT5ForConditionalGeneration.prepare_inputs_for_generation - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - 
attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - cache_position=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - if isinstance(past_key_values, EncoderDecoderCache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - else: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_ids.shape[1] :] - - # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise - # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 - input_ids = input_ids.clone(memory_format=torch.contiguous_format) - - if ( - isinstance(past_key_values, EncoderDecoderCache) - and ( - isinstance(past_key_values.self_attention_cache, StaticCache) - or isinstance(past_key_values.cross_attention_cache, StaticCache) - ) - and attention_mask is not None - and attention_mask.ndim == 2 - ): - batch_size, sequence_length = input_ids.shape - device = input_ids.device - - dtype = self.get_output_embeddings().weight.dtype - - attention_mask = self.decoder._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.self_attention_cache.get_max_length(), - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - ) - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - "cache_position": cache_position, - } - # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._reorder_cache def _reorder_cache(self, past_key_values, beam_idx): # if decoder past is not included in output diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index c8535ccf3681..631b16917156 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -872,7 +872,7 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device sequence_length = input_tensor.shape[1] if using_static_cache: - target_length = past_key_values.get_max_length() + target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] @@ -1498,83 +1498,6 @@ def forward( encoder_attentions=encoder_outputs.attentions, ) - # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_inputs_for_generation - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - decoder_attention_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - 
cache_position=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - if isinstance(past_key_values, EncoderDecoderCache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - else: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device) - elif use_cache: - cache_position = cache_position[-input_ids.shape[1] :] - - # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise - # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 - input_ids = input_ids.clone(memory_format=torch.contiguous_format) - - if ( - isinstance(past_key_values, EncoderDecoderCache) - and ( - isinstance(past_key_values.self_attention_cache, StaticCache) - or isinstance(past_key_values.cross_attention_cache, StaticCache) - ) - and decoder_attention_mask is not None - and decoder_attention_mask.ndim == 2 - ): - batch_size, sequence_length = input_ids.shape - device = input_ids.device - - dtype = self.get_output_embeddings().weight.dtype - - decoder_attention_mask = self.decoder._prepare_4d_causal_attention_mask_with_cache_position( - decoder_attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.self_attention_cache.get_max_length(), - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - ) - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "decoder_attention_mask": decoder_attention_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - "cache_position": cache_position, - } - # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_decoder_input_ids_from_labels def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) From 2c805f2563b98ec16e09fcdfcc847b94d6ce1d39 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 14 Oct 2024 12:21:05 +0200 Subject: [PATCH 26/30] fix pop2piano --- src/transformers/generation/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 74c5123033b4..4291b84efaba 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1525,8 +1525,12 @@ def _prepare_generation_config( def _get_initial_cache_position(self, input_ids, model_kwargs): """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length""" # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` - if "inputs_embeds" in model_kwargs: + if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder: cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 + elif "decoder_inputs_embeds" 
in model_kwargs and self.config.is_encoder_decoder: + cache_position = ( + torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 + ) else: cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1 From 9e1fefa217403c7af8f50df45296abb8c9a609d6 Mon Sep 17 00:00:00 2001 From: raushan Date: Fri, 18 Oct 2024 12:09:39 +0200 Subject: [PATCH 27/30] add comment --- tests/models/t5/test_modeling_t5.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index eb7d8bbf8087..68dd5a52b3d6 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -1638,6 +1638,8 @@ def test_contrastive_search_t5(self): outputs = t5_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64) generated_text = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True) + # TODO: @arthur? + # PR #31938 caused regression on this test which was fixed by PR #34089 self.assertListEqual( generated_text, [ From 56d036cb965e7d3f1a48496bd754eea3d69a70bb Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 21 Oct 2024 10:03:26 +0200 Subject: [PATCH 28/30] style --- tests/test_modeling_common.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index f3c005f711cc..d1b9fd0b362c 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4948,9 +4948,13 @@ def test_torch_compile(self): tokenizer = AutoTokenizer.from_pretrained(ckpt) if self.is_encoder_decoder: - model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(torch_device) + model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( + torch_device + ) else: - model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(torch_device) + model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( + torch_device + ) model.generation_config.max_new_tokens = 4 @@ -5024,9 +5028,13 @@ def test_compile_cuda_graph_time(self): tokenizer = AutoTokenizer.from_pretrained(ckpt) if self.is_encoder_decoder: - model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(torch_device) + model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( + torch_device + ) else: - model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(torch_device) + model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( + torch_device + ) cache_implementation = "static" if model.config.model_type == "gemma2": From c25a8a4f06f36e2cd11202420d6583edd655c7ea Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 22 Oct 2024 07:49:35 +0200 Subject: [PATCH 29/30] fix copies after rebase --- src/transformers/models/longt5/modeling_longt5.py | 1 + src/transformers/models/pix2struct/modeling_pix2struct.py | 1 + src/transformers/models/pop2piano/modeling_pop2piano.py | 1 + .../models/switch_transformers/modeling_switch_transformers.py | 1 + src/transformers/models/t5/modeling_t5.py | 1 + src/transformers/models/udop/modeling_udop.py | 1 + src/transformers/models/umt5/modeling_umt5.py | 1 + 7 files changed, 7 insertions(+) diff --git a/src/transformers/models/longt5/modeling_longt5.py 
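Note on `_get_initial_cache_position` above: the prefill `cache_position` is built from `ones_like(...).cumsum(0) - 1` rather than `torch.arange`, so the length comes from a traced tensor shape, and for pop2piano the encoder-decoder branch now keys off `decoder_inputs_embeds`. A quick check that the shape-based form matches a plain arange:

import torch

input_ids = torch.tensor([[10, 11, 12, 13, 14]])
cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
print(torch.equal(cache_position, torch.arange(input_ids.shape[1])))  # True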
From 9e1fefa217403c7af8f50df45296abb8c9a609d6 Mon Sep 17 00:00:00 2001
From: raushan
Date: Fri, 18 Oct 2024 12:09:39 +0200
Subject: [PATCH 27/30] add comment

---
 tests/models/t5/test_modeling_t5.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py
index eb7d8bbf8087..68dd5a52b3d6 100644
--- a/tests/models/t5/test_modeling_t5.py
+++ b/tests/models/t5/test_modeling_t5.py
@@ -1638,6 +1638,8 @@ def test_contrastive_search_t5(self):
         outputs = t5_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
         generated_text = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)

+        # TODO: @arthur?
+        # PR #31938 caused regression on this test which was fixed by PR #34089
         self.assertListEqual(
             generated_text,
             [

From 56d036cb965e7d3f1a48496bd754eea3d69a70bb Mon Sep 17 00:00:00 2001
From: raushan
Date: Mon, 21 Oct 2024 10:03:26 +0200
Subject: [PATCH 28/30] style

---
 tests/test_modeling_common.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index f3c005f711cc..d1b9fd0b362c 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -4948,9 +4948,13 @@ def test_torch_compile(self):

         tokenizer = AutoTokenizer.from_pretrained(ckpt)
         if self.is_encoder_decoder:
-            model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(torch_device)
+            model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
+                torch_device
+            )
         else:
-            model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(torch_device)
+            model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
+                torch_device
+            )

         model.generation_config.max_new_tokens = 4

@@ -5024,9 +5028,13 @@ def test_compile_cuda_graph_time(self):

         tokenizer = AutoTokenizer.from_pretrained(ckpt)
         if self.is_encoder_decoder:
-            model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(torch_device)
+            model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
+                torch_device
+            )
         else:
-            model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(torch_device)
+            model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
+                torch_device
+            )

         cache_implementation = "static"
         if model.config.model_type == "gemma2":

From c25a8a4f06f36e2cd11202420d6583edd655c7ea Mon Sep 17 00:00:00 2001
From: raushan
Date: Tue, 22 Oct 2024 07:49:35 +0200
Subject: [PATCH 29/30] fix copies after rebase

---
 src/transformers/models/longt5/modeling_longt5.py              | 1 +
 src/transformers/models/pix2struct/modeling_pix2struct.py      | 1 +
 src/transformers/models/pop2piano/modeling_pop2piano.py        | 1 +
 .../models/switch_transformers/modeling_switch_transformers.py | 1 +
 src/transformers/models/t5/modeling_t5.py                      | 1 +
 src/transformers/models/udop/modeling_udop.py                  | 1 +
 src/transformers/models/umt5/modeling_umt5.py                  | 1 +
 7 files changed, 7 insertions(+)

diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py
index c0bae9dc194f..29536d9ad6f2 100644
--- a/src/transformers/models/longt5/modeling_longt5.py
+++ b/src/transformers/models/longt5/modeling_longt5.py
@@ -1666,6 +1666,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
+        **kwargs,
     ):
         """
         Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py
index 672ea2c10986..b1ac81bb1f21 100644
--- a/src/transformers/models/pix2struct/modeling_pix2struct.py
+++ b/src/transformers/models/pix2struct/modeling_pix2struct.py
@@ -1643,6 +1643,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
+        **kwargs,
     ):
         """
         Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py
index 7e8d14800159..6a64a27e007b 100644
--- a/src/transformers/models/pop2piano/modeling_pop2piano.py
+++ b/src/transformers/models/pop2piano/modeling_pop2piano.py
@@ -1066,6 +1066,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
+        **kwargs,
     ):
         """
         Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
index 0cc9cc0a4525..b150b04eea57 100644
--- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py
+++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
@@ -1202,6 +1202,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
+        **kwargs,
     ):
         """
         Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 02d67b64c031..9012c8db9feb 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -1271,6 +1271,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
+        **kwargs,
     ):
         """
         Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py
index 6b773e57e332..1928ac8a5c20 100644
--- a/src/transformers/models/udop/modeling_udop.py
+++ b/src/transformers/models/udop/modeling_udop.py
@@ -1604,6 +1604,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
+        **kwargs,
     ):
         """
         Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py
index 631b16917156..985dc5e4426d 100644
--- a/src/transformers/models/umt5/modeling_umt5.py
+++ b/src/transformers/models/umt5/modeling_umt5.py
@@ -915,6 +915,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
+        **kwargs,
     ):
         """
         Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
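
The only change fanned out across these copies is that the mask helper now accepts `**kwargs`, so call sites that pass extra keyword arguments keep working without every copy having to name them. A toy illustration of the pattern (function and argument names are made up, not the library's):

def make_mask(sequence_length: int, target_length: int, **kwargs) -> list:
    # Extra keyword arguments that only some callers pass are accepted and ignored.
    return [[0] * target_length for _ in range(sequence_length)]


# A caller that still passes an extra keyword no longer raises a TypeError.
mask = make_mask(2, 4, unused_option=True)
print(len(mask), len(mask[0]))  # 2 4
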
From befe2d82310320ef08c06892fb4657735706e38c Mon Sep 17 00:00:00 2001
From: raushan
Date: Tue, 22 Oct 2024 07:51:37 +0200
Subject: [PATCH 30/30] fix copies missed this one

---
 src/transformers/models/mt5/modeling_mt5.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py
index 2b50b4ba0e0c..659a84c5fe37 100644
--- a/src/transformers/models/mt5/modeling_mt5.py
+++ b/src/transformers/models/mt5/modeling_mt5.py
@@ -1258,6 +1258,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
+        **kwargs,
     ):
         """
         Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
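
With the full series applied, T5-family models are expected to accept the new cache classes end to end. A rough usage sketch, assuming a transformers build that includes these patches (the checkpoint and generation settings are only examples):

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small")

inputs = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt")

# Pass an explicit encoder-decoder cache (self-attention + cross-attention caches)
# instead of relying on the legacy tuple format.
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

with torch.no_grad():
    output_ids = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=20)

print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))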