diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 4e4a1ee26c12..0f696cc3ac6a 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1475,11 +1475,7 @@ def from_legacy_cache( def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor` - if self.self_attention_cache.key_cache == []: - return 0 - if len(self.self_attention_cache.key_cache) > 1 and self.self_attention_cache.key_cache[layer_idx] == []: - return 0 - return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + return self.self_attention_cache.get_seq_length(layer_idx) def reset(self): if hasattr(self.self_attention_cache, "reset"): diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9ede527ecb7b..c399a8a2c829 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1535,8 +1535,12 @@ def _prepare_generation_config( def _get_initial_cache_position(self, input_ids, model_kwargs): """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length""" # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` - if "inputs_embeds" in model_kwargs: + if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder: cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 + elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder: + cache_position = ( + torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 + ) else: cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1 @@ -1633,7 +1637,7 @@ def get_layer_device_map(execution_device_map: Optional[dict] = None): cache_kwargs = { "config": self.config.get_text_config(), - "max_batch_size": batch_size, + "batch_size": batch_size, "max_cache_len": max_cache_len, "device": device, "dtype": cache_dtype, diff --git a/src/transformers/models/longt5/configuration_longt5.py b/src/transformers/models/longt5/configuration_longt5.py index 0e541ae2a1b4..b6e7d21b3d67 100644 --- a/src/transformers/models/longt5/configuration_longt5.py +++ b/src/transformers/models/longt5/configuration_longt5.py @@ -79,7 +79,12 @@ class LongT5Config(PretrainedConfig): model_type = "longt5" keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + "head_dim": "d_kv", + } def __init__( self, diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index d351e798ac7f..29536d9ad6f2 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -24,7 +24,9 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -39,6 +41,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_fx_proxy, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -317,7 +320,12 @@ def forward(self, hidden_states): # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->LongT5 class LongT5Attention(nn.Module): - def __init__(self, config: LongT5Config, has_relative_attention_bias=False): + def __init__( + self, + config: LongT5Config, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + ): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -328,6 +336,13 @@ def __init__(self, config: LongT5Config, has_relative_attention_bias=False): self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -404,11 +419,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( @@ -432,94 +450,72 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] - real_seq_length = seq_length + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) - - # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -529,22 +525,22 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output, past_key_value, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -1008,9 +1004,11 @@ def unshape(states): # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->LongT5 class LongT5LayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() - self.SelfAttention = LongT5Attention(config, has_relative_attention_bias=has_relative_attention_bias) + self.SelfAttention = LongT5Attention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -1023,6 +1021,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -1033,6 +1032,7 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -1042,7 +1042,7 @@ def forward( class LongT5LayerLocalSelfAttention(nn.Module): """Local self attention used in encoder""" - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.LocalSelfAttention = LongT5LocalAttention(config, has_relative_attention_bias=has_relative_attention_bias) self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) @@ -1073,7 +1073,7 @@ def forward( class LongT5LayerTransientGlobalSelfAttention(nn.Module): """Transient-Global self attention used in encoder""" - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.TransientGlobalSelfAttention = LongT5TransientGlobalAttention( config, has_relative_attention_bias=has_relative_attention_bias @@ -1105,9 +1105,9 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->LongT5 class LongT5LayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.EncDecAttention = LongT5Attention(config, has_relative_attention_bias=False) + self.EncDecAttention = LongT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -1122,6 +1122,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -1134,6 +1135,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -1141,7 +1143,7 @@ def forward( class LongT5Block(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder if config.is_decoder: @@ -1156,9 +1158,11 @@ def __init__(self, config, has_relative_attention_bias=False): f"but got {config.encoder_attention_type}." ) self.layer = nn.ModuleList() - self.layer.append(attention_layer(config, has_relative_attention_bias=has_relative_attention_bias)) + self.layer.append( + attention_layer(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx) + ) if self.is_decoder: - self.layer.append(LongT5LayerCrossAttention(config)) + self.layer.append(LongT5LayerCrossAttention(config, layer_idx=layer_idx)) self.layer.append(LongT5LayerFF(config)) @@ -1176,34 +1180,19 @@ def forward( use_cache=False, output_attentions=False, return_dict=True, + cache_position=None, ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - self_attention_outputs = self.layer[0]( hidden_states, attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] + hidden_states, past_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/ @@ -1213,35 +1202,25 @@ def forward( do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None - cross_attention_outputs = self.layer[1]( hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, + past_key_value=past_key_value, + query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states = cross_attention_outputs[0] + hidden_states, past_key_value = cross_attention_outputs[:2] # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/ if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -1256,7 +1235,7 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + outputs = outputs + (past_key_value,) + attention_outputs else: outputs = outputs + attention_outputs @@ -1273,6 +1252,8 @@ class LongT5PreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" supports_gradient_checkpointing = True _no_split_modules = ["LongT5Block"] + _supports_cache_class = True + _supports_static_cache = False # TODO: @raushan more involved due to local/global attn @property # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs @@ -1376,7 +1357,10 @@ def __init__(self, config, embed_tokens=None): self.block_len = self.local_radius + 1 self.block = nn.ModuleList( - [LongT5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + [ + LongT5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) + for i in range(config.num_layers) + ] ) self.final_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -1408,6 +1392,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -1430,36 +1415,65 @@ def forward( err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + if inputs_embeds is None: assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" inputs_embeds = self.embed_tokens(input_ids) batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - - if use_cache is True: - assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder" + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) - if attention_mask is None: + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values_length + seq_length attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - # We use local attention in encoder self-attention, otherwise standard self & cross attentions are used if self.is_decoder: - extended_attention_mask = self.get_extended_attention_mask( - attention_mask, input_shape, inputs_embeds.device + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, ) + # We use local attention in encoder self-attention, otherwise standard self & cross attentions are used elif self.config.encoder_attention_type == "local": - extended_attention_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device) + causal_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device) else: # we need to use both local attention mask and standard extended mask for transient-global attention - extended_attention_mask = attention_mask + causal_mask = attention_mask # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -1472,17 +1486,9 @@ def forward( else: encoder_extended_attention_mask = None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None @@ -1491,7 +1497,7 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] @@ -1502,7 +1508,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -1512,20 +1518,24 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + return_dict, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, ) # layer_outputs is a tuple with: @@ -1533,7 +1543,7 @@ def forward( if use_cache is False: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), @@ -1541,9 +1551,6 @@ def forward( position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -1557,12 +1564,18 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -1571,12 +1584,135 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + LONGT5_START_DOCSTRING = r""" @@ -1693,6 +1829,9 @@ def forward( more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ LONGT5_ENCODER_INPUTS_DOCSTRING = r""" @@ -1817,6 +1956,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" Returns: @@ -1883,6 +2023,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1975,6 +2116,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -2050,6 +2192,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = decoder_outputs[0] diff --git a/src/transformers/models/mt5/configuration_mt5.py b/src/transformers/models/mt5/configuration_mt5.py index ef629718b1b5..267179f81247 100644 --- a/src/transformers/models/mt5/configuration_mt5.py +++ b/src/transformers/models/mt5/configuration_mt5.py @@ -72,7 +72,12 @@ class MT5Config(PretrainedConfig): model_type = "mt5" keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + "head_dim": "d_kv", + } def __init__( self, diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 9051414d7414..659a84c5fe37 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -25,7 +25,9 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -43,6 +45,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_fx_proxy, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -214,7 +217,12 @@ def forward(self, hidden_states): # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->MT5 class MT5Attention(nn.Module): - def __init__(self, config: MT5Config, has_relative_attention_bias=False): + def __init__( + self, + config: MT5Config, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + ): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -225,6 +233,13 @@ def __init__(self, config: MT5Config, has_relative_attention_bias=False): self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -301,11 +316,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( @@ -329,94 +347,72 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] - real_seq_length = seq_length + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) - - # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -426,22 +422,22 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output, past_key_value, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -450,9 +446,11 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->MT5 class MT5LayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() - self.SelfAttention = MT5Attention(config, has_relative_attention_bias=has_relative_attention_bias) + self.SelfAttention = MT5Attention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -465,6 +463,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -475,6 +474,7 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -483,9 +483,9 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->MT5 class MT5LayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.EncDecAttention = MT5Attention(config, has_relative_attention_bias=False) + self.EncDecAttention = MT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -500,6 +500,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -512,6 +513,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -520,13 +522,15 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5Block with T5->MT5 class MT5Block(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder self.layer = nn.ModuleList() - self.layer.append(MT5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + self.layer.append( + MT5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx) + ) if self.is_decoder: - self.layer.append(MT5LayerCrossAttention(config)) + self.layer.append(MT5LayerCrossAttention(config, layer_idx=layer_idx)) self.layer.append(MT5LayerFF(config)) @@ -544,34 +548,19 @@ def forward( use_cache=False, output_attentions=False, return_dict=True, + cache_position=None, ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - self_attention_outputs = self.layer[0]( hidden_states, attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] + hidden_states, past_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -585,25 +574,18 @@ def forward( do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None - cross_attention_outputs = self.layer[1]( hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, + past_key_value=past_key_value, + query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, ) - hidden_states = cross_attention_outputs[0] + hidden_states, past_key_value = cross_attention_outputs[:2] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -614,10 +596,6 @@ def forward( ) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -636,11 +614,11 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + outputs = outputs + (past_key_value,) + attention_outputs else: outputs = outputs + attention_outputs - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) def load_tf_weights_in_mt5(model, config, tf_checkpoint_path): @@ -780,6 +758,9 @@ class MT5PreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" is_parallelizable = True supports_gradient_checkpointing = True + _supports_quantized_cache = False # enc-dec models don't support yet + _supports_static_cache = True + _supports_cache_class = True _no_split_modules = ["MT5Block"] _keep_in_fp32_modules = ["wo"] @@ -892,7 +873,7 @@ def __init__(self, config, embed_tokens=None): self.is_decoder = config.is_decoder self.block = nn.ModuleList( - [MT5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + [MT5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)] ) self.final_layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -968,6 +949,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): # Model parallel if self.model_parallel: @@ -994,6 +976,13 @@ def forward( err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + if inputs_embeds is None: if self.embed_tokens is None: raise ValueError("You have to initialize the model with valid token embeddings") @@ -1001,23 +990,57 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if use_cache is True: if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) - if attention_mask is None: + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + elif attention_mask is not None: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + else: + causal_mask = None # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -1032,17 +1055,9 @@ def forward( else: encoder_extended_attention_mask = None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None @@ -1051,15 +1066,15 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] # Model parallel if self.model_parallel: torch.cuda.set_device(hidden_states.device) # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) + if causal_mask is not None: + causal_mask = causal_mask.to(hidden_states.device) if position_bias is not None: position_bias = position_bias.to(hidden_states.device) if encoder_hidden_states is not None: @@ -1079,7 +1094,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -1089,20 +1104,24 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + return_dict, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, ) # layer_outputs is a tuple with: @@ -1110,7 +1129,7 @@ def forward( if use_cache is False: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), @@ -1118,9 +1137,6 @@ def forward( position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -1140,12 +1156,18 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -1154,12 +1176,135 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + MT5_START_DOCSTRING = r""" @@ -1454,6 +1599,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" Returns: @@ -1533,6 +1679,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1685,6 +1832,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1779,6 +1927,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = decoder_outputs[0] diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 37090677a6e2..b1ac81bb1f21 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -22,7 +22,9 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, @@ -38,6 +40,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_fx_proxy, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -184,14 +187,17 @@ def to_projection_shape(states): if self.gradient_checkpointing and self.training: position_bias.requires_grad = True - if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length), device=scores.device, dtype=scores.dtype) - if attention_mask.dim() == 2: position_bias = position_bias + attention_mask[:, None, None, :].to(position_bias.device) - else: + elif attention_mask is not None: # (batch_size, n_heads, seq_length, key_length) position_bias = position_bias + attention_mask.to(position_bias.device) + elif not is_torchdynamo_compiling(): + attention_mask = torch.ones( + (batch_size, seq_length), device=position_bias.device, dtype=position_bias.dtype + ) + position_bias = position_bias + attention_mask.to(position_bias.device) + position_bias = 1 - position_bias position_bias_masked = position_bias.masked_fill(position_bias == 1, torch.finfo(scores.dtype).min) @@ -355,6 +361,8 @@ class Pix2StructPreTrainedModel(PreTrainedModel): """ config_class = Pix2StructConfig + _supports_cache_class = True + _supports_static_cache = False @property def dummy_inputs(self): @@ -673,7 +681,9 @@ def forward(self, hidden_states): class Pix2StructTextAttention(nn.Module): - def __init__(self, config: Pix2StructTextConfig, has_relative_attention_bias=False): + def __init__( + self, config: Pix2StructTextConfig, has_relative_attention_bias=False, layer_idx: Optional[int] = None + ): super().__init__() self.has_relative_attention_bias = has_relative_attention_bias self.relative_attention_num_buckets = config.relative_attention_num_buckets @@ -683,6 +693,13 @@ def __init__(self, config: Pix2StructTextConfig, has_relative_attention_bias=Fal self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) # Mesh TensorFlow initialization to avoid scaling before softmax self.query = nn.Linear(self.hidden_size, self.hidden_size, bias=False) @@ -773,75 +790,56 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, query_length, key_length) batch_size, seq_length = hidden_states.shape[:2] - real_seq_length = seq_length - - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - - def to_projection_shape(states): - """projection""" - return states.contiguous().view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = to_projection_shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = to_projection_shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = to_projection_shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states + query_states = self.query(hidden_states).contiguous() + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - # get query states - # (batch_size, n_heads, seq_length, dim_per_head) - query_states = to_projection_shape(self.query(hidden_states)) + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache # get key/value states - key_states = project( - hidden_states, self.key, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.value, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = self.key(current_states).contiguous() + value_states = self.value(current_states).contiguous() + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: + real_seq_length = cache_position[-1] + 1 if query_length is None else query_length + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] if not self.has_relative_attention_bias: position_bias = torch.zeros( (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype @@ -851,11 +849,6 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): else: position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] - if mask is not None: position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) @@ -883,19 +876,20 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): attn_output = self.output(attn_output) - present_key_value_state = (key_states, value_states) if use_cache else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output,) + (past_key_value,) + (position_bias,) if output_attentions: outputs = outputs + (attn_weights,) return outputs -# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,self.SelfAttention->self.attention,config.d_model->config.hidden_size +# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerSelfAttention->Pix2StructTextLayerSelfAttention,self.SelfAttention->self.attention,config.d_model->config.hidden_size class Pix2StructTextLayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() - self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=has_relative_attention_bias) + self.attention = Pix2StructTextAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -908,6 +902,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.attention( @@ -918,17 +913,18 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them return outputs -# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,self.EncDecAttention->self.attention,config.d_model->config.hidden_size +# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerCrossAttention->Pix2StructTextLayerCrossAttention,self.EncDecAttention->self.attention,config.d_model->config.hidden_size class Pix2StructTextLayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=False) + self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -943,6 +939,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.attention( @@ -955,6 +952,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -962,11 +960,13 @@ def forward( class Pix2StructTextBlock(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.self_attention = Pix2StructTextLayerSelfAttention( - config, has_relative_attention_bias=has_relative_attention_bias + config, + has_relative_attention_bias=has_relative_attention_bias, + layer_idx=layer_idx, ) self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(config) @@ -987,32 +987,19 @@ def forward( use_cache=False, output_attentions=False, return_dict=True, + cache_position=None, ): - if past_key_value is not None: - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - self_attention_outputs = self.self_attention( hidden_states, attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] + hidden_states, past_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -1022,35 +1009,25 @@ def forward( do_cross_attention = encoder_hidden_states is not None if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None - cross_attention_outputs = self.encoder_decoder_attention( hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, + past_key_value=past_key_value, + query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states = cross_attention_outputs[0] + hidden_states, past_key_value = cross_attention_outputs[:2] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -1065,7 +1042,7 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + outputs = outputs + (past_key_value,) + attention_outputs else: outputs = outputs + attention_outputs @@ -1187,6 +1164,9 @@ def forward( more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ PIX2STRUCT_INPUTS_DOCSTRING = r""" @@ -1293,7 +1273,10 @@ def __init__(self, config): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.layer = nn.ModuleList( - [Pix2StructTextBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + [ + Pix2StructTextBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) + for i in range(config.num_layers) + ] ) self.final_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -1364,6 +1347,7 @@ def forward( output_hidden_states: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Union[Tuple[torch.FloatTensor, ...], CausalLMOutputWithCrossAttentions]: r""" @@ -1405,24 +1389,54 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if use_cache or past_key_values is not None: + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + past_key_values_length = 0 + if cache_position is not None: + past_key_values_length = cache_position[0] + elif past_key_values is not None: + past_key_values_length = past_key_values.get_seq_length() + + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - if encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones( - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + # required mask seq length can be calculated via length of past + mask_seq_length = ( + past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length ) + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.layer) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + else: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -1438,7 +1452,6 @@ def forward( # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions) else None @@ -1447,7 +1460,7 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.layer, past_key_values)): + for i, layer_module in enumerate(self.layer): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] if output_hidden_states: @@ -1462,7 +1475,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -1472,20 +1485,22 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) # layer_outputs is a tuple with: @@ -1493,7 +1508,7 @@ def forward( if use_cache is False: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), @@ -1501,9 +1516,6 @@ def forward( position_bias = layer_outputs[2] if encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -1527,13 +1539,19 @@ def forward( loss = loss_fct(logits.contiguous().view(-1, logits.size(-1)), labels.contiguous().view(-1)) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ loss, logits, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -1543,12 +1561,135 @@ def forward( return CausalLMOutputWithCrossAttentions( loss=loss, logits=logits, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + @add_start_docstrings( "A conditional generation model with a language modeling head. Can be used for sequence generation tasks.", @@ -1615,6 +1756,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" Returns: @@ -1723,6 +1865,7 @@ def forward( output_hidden_states=output_hidden_states, labels=labels, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index d6f92e9fe034..6a64a27e007b 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -25,7 +25,9 @@ from transformers.generation import GenerationConfig from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -37,6 +39,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_fx_proxy, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -136,6 +139,9 @@ more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ @@ -245,7 +251,12 @@ def forward(self, hidden_states): # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->Pop2Piano,t5->pop2piano class Pop2PianoAttention(nn.Module): - def __init__(self, config: Pop2PianoConfig, has_relative_attention_bias=False): + def __init__( + self, + config: Pop2PianoConfig, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + ): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -256,6 +267,13 @@ def __init__(self, config: Pop2PianoConfig, has_relative_attention_bias=False): self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -332,11 +350,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( @@ -360,94 +381,72 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] - real_seq_length = seq_length + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) - - # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -457,22 +456,22 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output, past_key_value, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -481,9 +480,11 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Pop2Piano,t5->pop2piano class Pop2PianoLayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() - self.SelfAttention = Pop2PianoAttention(config, has_relative_attention_bias=has_relative_attention_bias) + self.SelfAttention = Pop2PianoAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -496,6 +497,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -506,6 +508,7 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -514,9 +517,9 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Pop2Piano,t5->pop2piano class Pop2PianoLayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.EncDecAttention = Pop2PianoAttention(config, has_relative_attention_bias=False) + self.EncDecAttention = Pop2PianoAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -531,6 +534,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -543,6 +547,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -551,13 +556,17 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5Block with T5->Pop2Piano,t5->pop2piano class Pop2PianoBlock(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder self.layer = nn.ModuleList() - self.layer.append(Pop2PianoLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + self.layer.append( + Pop2PianoLayerSelfAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) + ) if self.is_decoder: - self.layer.append(Pop2PianoLayerCrossAttention(config)) + self.layer.append(Pop2PianoLayerCrossAttention(config, layer_idx=layer_idx)) self.layer.append(Pop2PianoLayerFF(config)) @@ -575,34 +584,19 @@ def forward( use_cache=False, output_attentions=False, return_dict=True, + cache_position=None, ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - self_attention_outputs = self.layer[0]( hidden_states, attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] + hidden_states, past_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -616,25 +610,18 @@ def forward( do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None - cross_attention_outputs = self.layer[1]( hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, + past_key_value=past_key_value, + query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, ) - hidden_states = cross_attention_outputs[0] + hidden_states, past_key_value = cross_attention_outputs[:2] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -645,10 +632,6 @@ def forward( ) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -667,11 +650,11 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + outputs = outputs + (past_key_value,) + attention_outputs else: outputs = outputs + attention_outputs - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) class Pop2PianoPreTrainedModel(PreTrainedModel): @@ -684,6 +667,8 @@ class Pop2PianoPreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" is_parallelizable = False supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = False _no_split_modules = ["Pop2PianoBlock"] _keep_in_fp32_modules = ["wo"] @@ -769,7 +754,10 @@ def __init__(self, config, embed_tokens=None): self.is_decoder = config.is_decoder self.block = nn.ModuleList( - [Pop2PianoBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + [ + Pop2PianoBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) + for i in range(config.num_layers) + ] ) self.final_layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -803,6 +791,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -825,6 +814,13 @@ def forward( err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + if inputs_embeds is None: if self.embed_tokens is None: raise ValueError("You have to initialize the model with valid token embeddings") @@ -832,28 +828,55 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if use_cache is True: if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones( - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + else: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -866,17 +889,9 @@ def forward( else: encoder_extended_attention_mask = None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None @@ -885,7 +900,7 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] if output_hidden_states: @@ -895,7 +910,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -905,20 +920,22 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) # layer_outputs is a tuple with: @@ -926,7 +943,7 @@ def forward( if use_cache is False: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), @@ -934,9 +951,6 @@ def forward( position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -950,12 +964,18 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -964,12 +984,135 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + class Pop2PianoConcatEmbeddingToMel(nn.Module): """Embedding Matrix for `composer` tokens.""" @@ -1122,6 +1265,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1177,6 +1321,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = decoder_outputs[0] diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index c39e85bacdd3..b150b04eea57 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -24,7 +24,9 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( MoEModelOutput, MoEModelOutputWithPastAndCrossAttentions, @@ -39,6 +41,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_fx_proxy, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -355,7 +358,12 @@ def forward(self, hidden_states, output_router_logits): # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->SwitchTransformers class SwitchTransformersAttention(nn.Module): - def __init__(self, config: SwitchTransformersConfig, has_relative_attention_bias=False): + def __init__( + self, + config: SwitchTransformersConfig, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + ): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -366,6 +374,13 @@ def __init__(self, config: SwitchTransformersConfig, has_relative_attention_bias self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -442,11 +457,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( @@ -470,94 +488,72 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] - real_seq_length = seq_length + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) - - # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -567,22 +563,22 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output, past_key_value, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -591,10 +587,10 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->SwitchTransformers class SwitchTransformersLayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.SelfAttention = SwitchTransformersAttention( - config, has_relative_attention_bias=has_relative_attention_bias + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx ) self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -608,6 +604,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -618,6 +615,7 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -626,9 +624,11 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->SwitchTransformers class SwitchTransformersLayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.EncDecAttention = SwitchTransformersAttention(config, has_relative_attention_bias=False) + self.EncDecAttention = SwitchTransformersAttention( + config, has_relative_attention_bias=False, layer_idx=layer_idx + ) self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -643,6 +643,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -655,6 +656,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -662,16 +664,18 @@ def forward( class SwitchTransformersBlock(nn.Module): - def __init__(self, config, has_relative_attention_bias=False, is_sparse=False): + def __init__(self, config, has_relative_attention_bias=False, is_sparse=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder self.is_sparse = is_sparse self.layer = nn.ModuleList() self.layer.append( - SwitchTransformersLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias) + SwitchTransformersLayerSelfAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) ) if self.is_decoder: - self.layer.append(SwitchTransformersLayerCrossAttention(config)) + self.layer.append(SwitchTransformersLayerCrossAttention(config, layer_idx=layer_idx)) self.layer.append(SwitchTransformersLayerFF(config, is_sparse=self.is_sparse)) @@ -690,34 +694,19 @@ def forward( output_attentions=False, output_router_logits=True, return_dict=True, + cache_position=None, ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - self_attention_outputs = self.layer[0]( hidden_states, attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] + hidden_states, past_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -727,35 +716,25 @@ def forward( do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None - cross_attention_outputs = self.layer[1]( hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, + past_key_value=past_key_value, + query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states = cross_attention_outputs[0] + hidden_states, past_key_value = cross_attention_outputs[:2] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -775,11 +754,11 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + (router_tuple,) + outputs = outputs + (past_key_value,) + attention_outputs + (router_tuple,) else: outputs = outputs + attention_outputs + (router_tuple,) - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights), (router_tuple) + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights), (router_tuple) class SwitchTransformersPreTrainedModel(PreTrainedModel): @@ -791,6 +770,8 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel): config_class = SwitchTransformersConfig base_model_prefix = "switch_transformers" supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = False _no_split_modules = ["SwitchTransformersBlock"] @property @@ -897,7 +878,9 @@ def __init__(self, config, embed_tokens=None): is_sparse = (i % sparse_step == 1 or sparse_step == 1) if sparse_step > 0 else False self.block.append( - SwitchTransformersBlock(config, has_relative_attention_bias=bool(i == 0), is_sparse=is_sparse) + SwitchTransformersBlock( + config, has_relative_attention_bias=bool(i == 0), is_sparse=is_sparse, layer_idx=i + ) ) self.final_layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) @@ -930,6 +913,7 @@ def forward( output_hidden_states=None, output_router_logits=True, return_dict=None, + cache_position=None, ): use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -952,6 +936,13 @@ def forward( err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + if inputs_embeds is None: if self.embed_tokens is None: raise ValueError("You have to initialize the model with valid token embeddings") @@ -959,28 +950,55 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if use_cache is True: if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones( - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + else: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -993,17 +1011,9 @@ def forward( else: encoder_extended_attention_mask = None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_router_probs = () if output_router_logits else None @@ -1013,7 +1023,7 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] @@ -1024,7 +1034,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -1034,21 +1044,26 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + output_router_logits, + return_dict, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, ) router_probs = layer_outputs[-1] @@ -1059,7 +1074,7 @@ def forward( if use_cache is False: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), @@ -1067,9 +1082,6 @@ def forward( position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -1086,12 +1098,18 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -1101,13 +1119,136 @@ def forward( ) return MoEModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, router_probs=all_router_probs, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + SWITCH_TRANSFORMERS_START_DOCSTRING = r""" @@ -1228,6 +1369,9 @@ def forward( should not be returned during inference. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ SWITCH_TRANSFORMERS_ENCODER_INPUTS_DOCSTRING = r""" @@ -1355,6 +1499,7 @@ def forward( output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqMoEModelOutput]: r""" Returns: @@ -1435,6 +1580,7 @@ def forward( output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1535,6 +1681,7 @@ def forward( output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = True, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqMoEOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1618,6 +1765,7 @@ def forward( output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = decoder_outputs[0] diff --git a/src/transformers/models/t5/configuration_t5.py b/src/transformers/models/t5/configuration_t5.py index e5f2615611b8..be6fbe9528d1 100644 --- a/src/transformers/models/t5/configuration_t5.py +++ b/src/transformers/models/t5/configuration_t5.py @@ -73,7 +73,12 @@ class T5Config(PretrainedConfig): model_type = "t5" keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + "head_dim": "d_kv", + } def __init__( self, diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 91596f013ab4..9012c8db9feb 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -25,7 +25,9 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -43,6 +45,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_fx_proxy, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -339,7 +342,12 @@ def forward(self, hidden_states): class T5Attention(nn.Module): - def __init__(self, config: T5Config, has_relative_attention_bias=False): + def __init__( + self, + config: T5Config, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + ): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -350,6 +358,13 @@ def __init__(self, config: T5Config, has_relative_attention_bias=False): self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -426,11 +441,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( @@ -454,94 +472,72 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] - real_seq_length = seq_length + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) - - # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -551,22 +547,22 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output, past_key_value, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -574,9 +570,11 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): class T5LayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() - self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias) + self.SelfAttention = T5Attention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -589,6 +587,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -599,6 +598,7 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -606,9 +606,9 @@ def forward( class T5LayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False) + self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -623,6 +623,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -635,6 +636,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -642,13 +644,15 @@ def forward( class T5Block(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder self.layer = nn.ModuleList() - self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + self.layer.append( + T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx) + ) if self.is_decoder: - self.layer.append(T5LayerCrossAttention(config)) + self.layer.append(T5LayerCrossAttention(config, layer_idx=layer_idx)) self.layer.append(T5LayerFF(config)) @@ -666,34 +670,19 @@ def forward( use_cache=False, output_attentions=False, return_dict=True, + cache_position=None, ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - self_attention_outputs = self.layer[0]( hidden_states, attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] + hidden_states, past_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -707,25 +696,18 @@ def forward( do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None - cross_attention_outputs = self.layer[1]( hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, + past_key_value=past_key_value, + query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, ) - hidden_states = cross_attention_outputs[0] + hidden_states, past_key_value = cross_attention_outputs[:2] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -736,10 +718,6 @@ def forward( ) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -758,11 +736,11 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + outputs = outputs + (past_key_value,) + attention_outputs else: outputs = outputs + attention_outputs - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) class T5ClassificationHead(nn.Module): @@ -794,6 +772,9 @@ class T5PreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" is_parallelizable = True supports_gradient_checkpointing = True + _supports_quantized_cache = False # enc-dec models don't support yet + _supports_static_cache = True + _supports_cache_class = True _no_split_modules = ["T5Block"] _keep_in_fp32_modules = ["wo"] @@ -905,7 +886,7 @@ def __init__(self, config, embed_tokens=None): self.is_decoder = config.is_decoder self.block = nn.ModuleList( - [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + [T5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)] ) self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -981,6 +962,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): # Model parallel if self.model_parallel: @@ -1007,6 +989,13 @@ def forward( err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + if inputs_embeds is None: if self.embed_tokens is None: raise ValueError("You have to initialize the model with valid token embeddings") @@ -1014,23 +1003,57 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if use_cache is True: if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) - if attention_mask is None: + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + elif attention_mask is not None: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + else: + causal_mask = None # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -1045,17 +1068,9 @@ def forward( else: encoder_extended_attention_mask = None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None @@ -1064,15 +1079,15 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] # Model parallel if self.model_parallel: torch.cuda.set_device(hidden_states.device) # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) + if causal_mask is not None: + causal_mask = causal_mask.to(hidden_states.device) if position_bias is not None: position_bias = position_bias.to(hidden_states.device) if encoder_hidden_states is not None: @@ -1092,7 +1107,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -1102,20 +1117,24 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + return_dict, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, ) # layer_outputs is a tuple with: @@ -1123,7 +1142,7 @@ def forward( if use_cache is False: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), @@ -1131,9 +1150,6 @@ def forward( position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -1153,12 +1169,18 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -1167,12 +1189,135 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + T5_START_DOCSTRING = r""" @@ -1286,6 +1431,9 @@ def forward( more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ T5_ENCODER_INPUTS_DOCSTRING = r""" @@ -1446,6 +1594,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" Returns: @@ -1525,6 +1674,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1656,6 +1806,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1750,6 +1901,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = decoder_outputs[0] diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 6be8752d5b63..1928ac8a5c20 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -34,13 +34,16 @@ ) from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torchdynamo_compiling, replace_return_docstrings, ) @@ -154,6 +157,9 @@ more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ @@ -411,6 +417,8 @@ class UdopPreTrainedModel(PreTrainedModel): config_class = UdopConfig base_model_prefix = "transformer" supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = False _keep_in_fp32_modules = ["wo"] def _init_weights(self, module): @@ -598,7 +606,12 @@ def forward(self, hidden_states): # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->Udop class UdopAttention(nn.Module): - def __init__(self, config: UdopConfig, has_relative_attention_bias=False): + def __init__( + self, + config: UdopConfig, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + ): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -609,6 +622,13 @@ def __init__(self, config: UdopConfig, has_relative_attention_bias=False): self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -685,11 +705,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( @@ -713,94 +736,72 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] - real_seq_length = seq_length - - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) - - # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -810,22 +811,22 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output, past_key_value, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -834,9 +835,11 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Udop class UdopLayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() - self.SelfAttention = UdopAttention(config, has_relative_attention_bias=has_relative_attention_bias) + self.SelfAttention = UdopAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -849,6 +852,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -859,6 +863,7 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -867,9 +872,9 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Udop class UdopLayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.EncDecAttention = UdopAttention(config, has_relative_attention_bias=False) + self.EncDecAttention = UdopAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -884,6 +889,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -896,6 +902,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -904,13 +911,17 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5Block with T5->Udop class UdopBlock(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder self.layer = nn.ModuleList() - self.layer.append(UdopLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + self.layer.append( + UdopLayerSelfAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) + ) if self.is_decoder: - self.layer.append(UdopLayerCrossAttention(config)) + self.layer.append(UdopLayerCrossAttention(config, layer_idx=layer_idx)) self.layer.append(UdopLayerFF(config)) @@ -928,34 +939,19 @@ def forward( use_cache=False, output_attentions=False, return_dict=True, + cache_position=None, ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - self_attention_outputs = self.layer[0]( hidden_states, attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] + hidden_states, past_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -969,25 +965,18 @@ def forward( do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None - cross_attention_outputs = self.layer[1]( hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, + past_key_value=past_key_value, + query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, ) - hidden_states = cross_attention_outputs[0] + hidden_states, past_key_value = cross_attention_outputs[:2] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -998,10 +987,6 @@ def forward( ) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -1020,11 +1005,11 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + outputs = outputs + (past_key_value,) + attention_outputs else: outputs = outputs + attention_outputs - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) class UdopCellEmbeddings(nn.Module): @@ -1286,7 +1271,7 @@ def __init__(self, config, embed_tokens=None, embed_patches=None): self.num_layers = config.num_layers self.block = nn.ModuleList( - [UdopBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(self.num_layers)] + [UdopBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(self.num_layers)] ) self.final_layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) @@ -1338,6 +1323,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -1399,26 +1385,54 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if use_cache is True: assert self.is_decoder, "`use_cache` can only be set to `True` if {} is used as a decoder".format(self) - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) - if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones( - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + else: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min if self.is_decoder and encoder_attention_mask is not None: encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) @@ -1427,7 +1441,6 @@ def forward( # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None @@ -1436,34 +1449,35 @@ def forward( position_bias = None else: position_bias = self.relative_bias(attention_mask=attention_mask, bbox=bbox) - position_bias = position_bias + extended_attention_mask + position_bias = position_bias + causal_mask encoder_decoder_position_bias = None hidden_states = inputs_embeds hidden_states = self.dropout(hidden_states) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=head_mask[i], - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) if use_cache is False: # MP fixes layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention weights), @@ -1472,9 +1486,6 @@ def forward( position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = all_attentions + (layer_outputs[2],) # We keep only self-attention weights for now @@ -1488,13 +1499,19 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, attention_mask, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -1505,12 +1522,135 @@ def forward( return BaseModelOutputWithAttentionMask( last_hidden_state=hidden_states, attention_mask=attention_mask, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + @add_start_docstrings( "The bare UDOP encoder-decoder Transformer outputting raw hidden-states without any specific head on top.", @@ -1584,6 +1724,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[Tensor, ...]: r""" Returns: @@ -1653,6 +1794,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1759,6 +1901,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[Tensor, ...]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1837,6 +1980,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = decoder_outputs[0] diff --git a/src/transformers/models/umt5/configuration_umt5.py b/src/transformers/models/umt5/configuration_umt5.py index d7323d759fd0..ba8ea0460ba0 100644 --- a/src/transformers/models/umt5/configuration_umt5.py +++ b/src/transformers/models/umt5/configuration_umt5.py @@ -72,7 +72,12 @@ class UMT5Config(PretrainedConfig): model_type = "umt5" keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + "head_dim": "d_kv", + } def __init__( self, diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index bd621fc2fb3a..985dc5e4426d 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -23,7 +23,9 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -40,6 +42,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_fx_proxy, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -155,7 +158,7 @@ class UMT5Attention(nn.Module): T5's attention using relative_attention_bias. """ - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -166,6 +169,13 @@ def __init__(self, config, has_relative_attention_bias=False): self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -230,11 +240,14 @@ def _relative_position_bucket(self, relative_position): relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None] memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket(relative_position) @@ -249,78 +262,95 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, ): - is_cross_attention = encoder_hidden_states is not None batch_size, seq_length = hidden_states.shape[:2] - # use encoder_hidden_states if cross attention - current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - # checking that the `sequence_length` of the `past_key_value` is the same as the he provided - # `encoder_hidden_states` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + # if encoder_hidden_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = encoder_hidden_states is not None + + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - key_states = self._shape(self.k(current_states)) - value_states = self._shape(self.v(current_states)) - if past_key_value is not None and not is_cross_attention: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - query_states = self._shape(self.q(hidden_states)) - attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - # compute positional bias - if self.has_relative_attention_bias: - query_length = seq_length if past_key_value is not None: - query_length += past_key_value[0].shape[2] - position_bias = self.compute_bias(query_length, key_states.size(2), device=attention_scores.device) - else: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) + + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = seq_length + past_key_value.get_seq_length() if past_key_value is not None else seq_length + key_length = key_states.shape[-2] + if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, seq_length, key_states.size(2)), - device=attention_scores.device, - dtype=attention_scores.dtype, - requires_grad=self.training, + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] + if attention_mask is not None: - position_bias = position_bias + attention_mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - attention_scores += position_bias # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.softmax(attention_scores.float(), dim=-1).type_as(attention_scores) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - # attn_output = torch.bmm(attn_probs, value_states) ? - context_states = torch.matmul(attn_weights, value_states) - # attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) ? - context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1) - attn_output = self.o(context_states) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, seq_length, -1) + + attn_output = self.o(attn_output) return attn_output, attn_weights, past_key_value class UMT5LayerSelfAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.SelfAttention = UMT5Attention(config, has_relative_attention_bias=True) + self.SelfAttention = UMT5Attention(config, has_relative_attention_bias=True, layer_idx=layer_idx) self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -330,6 +360,7 @@ def forward( attention_mask=None, layer_head_mask=None, past_key_value=None, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -337,6 +368,7 @@ def forward( attention_mask=attention_mask, layer_head_mask=layer_head_mask, past_key_value=past_key_value, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -344,9 +376,9 @@ def forward( class UMT5LayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.EncDecAttention = UMT5Attention(config, has_relative_attention_bias=False) + self.EncDecAttention = UMT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -357,6 +389,7 @@ def forward( attention_mask=None, layer_head_mask=None, past_key_value=None, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -365,6 +398,7 @@ def forward( attention_mask=attention_mask, layer_head_mask=layer_head_mask, past_key_value=past_key_value, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -372,13 +406,13 @@ def forward( class UMT5Block(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder self.layer = nn.ModuleList() - self.layer.append(UMT5LayerSelfAttention(config)) + self.layer.append(UMT5LayerSelfAttention(config, layer_idx=layer_idx)) if self.is_decoder: - self.layer.append(UMT5LayerCrossAttention(config)) + self.layer.append(UMT5LayerCrossAttention(config, layer_idx=layer_idx)) self.layer.append(UMT5LayerFF(config)) @@ -393,16 +427,14 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): - # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - - hidden_states, self_attn_weights, present_key_value = self.layer[0]( + hidden_states, self_attn_weights, past_key_value = self.layer[0]( hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) # clamp inf values to enable fp16 training @@ -412,18 +444,16 @@ def forward( hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.layer[1]( + hidden_states, cross_attn_weights, past_key_value = self.layer[1]( hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -431,8 +461,6 @@ def forward( clamp_value = torch.where(torch.isinf(hidden_states).any(), max_dtype - 1000, max_dtype) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - present_key_value += cross_attn_present_key_value - # Apply Feed Forward layer hidden_states = self.layer[-1](hidden_states) @@ -444,7 +472,7 @@ def forward( outputs = ( hidden_states, - present_key_value, + past_key_value, ) if output_attentions: @@ -481,6 +509,8 @@ class UMT5PreTrainedModel(PreTrainedModel): config_class = UMT5Config base_model_prefix = "transformer" supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = True _no_split_modules = ["UMT5Block"] _keep_in_fp32_modules = ["wo"] @@ -594,7 +624,7 @@ def __init__(self, config, embed_tokens=None): super().__init__(config) self.embed_tokens = embed_tokens self.is_decoder = config.is_decoder - self.block = nn.ModuleList([UMT5Block(config) for i in range(config.num_layers)]) + self.block = nn.ModuleList([UMT5Block(config, layer_idx=i) for i in range(config.num_layers)]) self.final_layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -622,6 +652,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -644,6 +675,13 @@ def forward( err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + if inputs_embeds is None: if self.embed_tokens is None: raise ValueError("You have to initialize the model with valid token embeddings") @@ -651,28 +689,57 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if use_cache is True: if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones( - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + elif attention_mask is not None: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + else: + causal_mask = None # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -685,24 +752,16 @@ def forward( else: encoder_extended_attention_mask = None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.is_decoder else None hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] @@ -713,7 +772,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, encoder_hidden_states, encoder_extended_attention_mask, layer_head_mask, @@ -721,24 +780,26 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - present_key_value_states += (layer_outputs[1],) + next_decoder_cache = layer_outputs[1] if output_attentions: all_attentions += (layer_outputs[2],) @@ -752,12 +813,18 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -766,12 +833,135 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + UMT5_START_DOCSTRING = r""" @@ -885,6 +1075,9 @@ def forward( more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ UMT5_ENCODER_INPUTS_DOCSTRING = r""" @@ -1022,6 +1215,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" Returns: @@ -1084,6 +1278,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1197,6 +1392,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1268,6 +1464,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = decoder_outputs[0] diff --git a/tests/models/longt5/test_modeling_longt5.py b/tests/models/longt5/test_modeling_longt5.py index c0cf21b2369d..a9d3e7479e95 100644 --- a/tests/models/longt5/test_modeling_longt5.py +++ b/tests/models/longt5/test_modeling_longt5.py @@ -31,6 +31,7 @@ if is_torch_available(): import torch + import torch.nn.functional as F from transformers import ( MODEL_FOR_QUESTION_ANSWERING_MAPPING, @@ -574,6 +575,41 @@ def test_decoder_model_past_with_3d_attn_mask(self): lm_labels, ) + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + decoder_input_ids=input_ids, + input_ids=input_dict["input_ids"][:3], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + input_ids=input_dict["input_ids"][:1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) @@ -602,7 +638,7 @@ def test_export_to_onnx(self): (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]), f"{tmpdirname}/longt5_test.onnx", export_params=True, - opset_version=13, + opset_version=14, input_names=["input_ids", "decoder_input_ids"], ) diff --git a/tests/models/mt5/test_modeling_mt5.py b/tests/models/mt5/test_modeling_mt5.py index 6e912ec3607d..20412da2e1db 100644 --- a/tests/models/mt5/test_modeling_mt5.py +++ b/tests/models/mt5/test_modeling_mt5.py @@ -40,6 +40,7 @@ if is_torch_available(): import torch + import torch.nn.functional as F from transformers import ( AutoModelForSeq2SeqLM, @@ -575,6 +576,9 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, # The small MT5 model needs higher percentages for CPU/MP tests model_split_percents = [0.5, 0.8, 0.9] + # used in `test_torch_compile` + _torch_compile_test_ckpt = "google/mt5-small" + def setUp(self): self.model_tester = MT5ModelTester(self) self.config_tester = ConfigTester(self, config_class=MT5Config, d_model=37) @@ -627,12 +631,9 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa ] if labels is not None: input_names.append("labels") - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} input_names = list(filtered_inputs.keys()) - model_output = model(**filtered_inputs) - traced_model = symbolic_trace(model, input_names) traced_output = traced_model(**filtered_inputs) else: @@ -647,7 +648,6 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa "visual_feats", "visual_pos", ] - labels = inputs.get("labels", None) start_positions = inputs.get("start_positions", None) end_positions = inputs.get("end_positions", None) @@ -657,15 +657,12 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa input_names.append("start_positions") if end_positions is not None: input_names.append("end_positions") - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} input_names = list(filtered_inputs.keys()) - if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and ( not hasattr(model.config, "problem_type") or model.config.problem_type is None ): model.config.problem_type = "single_label_classification" - traced_model = symbolic_trace(model, input_names) traced_output = traced_model(**filtered_inputs) model_output = model(**filtered_inputs) @@ -718,6 +715,41 @@ def flatten_output(output): # (Even with this call, there are still memory leak by ~0.04MB) self.clear_torch_jit_class_registry() + # overwrite because MT5 doesn't accept position ids as input and expects `decoder_input_ids` + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + decoder_input_ids=input_ids, + input_ids=input_dict["input_ids"][:3], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + input_ids=input_dict["input_ids"][:1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + def test_config(self): self.config_tester.run_common_tests() diff --git a/tests/models/pop2piano/test_modeling_pop2piano.py b/tests/models/pop2piano/test_modeling_pop2piano.py index 3a33b5a98128..39ff67f08ce5 100644 --- a/tests/models/pop2piano/test_modeling_pop2piano.py +++ b/tests/models/pop2piano/test_modeling_pop2piano.py @@ -620,7 +620,7 @@ def test_export_to_onnx(self): (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]), f"{tmpdirname}/Pop2Piano_test.onnx", export_params=True, - opset_version=9, + opset_version=14, input_names=["input_ids", "decoder_input_ids"], ) diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 13215b2826fe..7adb1f40c6e6 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -36,6 +36,7 @@ if is_torch_available(): import torch + import torch.nn.functional as F from transformers import ( AutoTokenizer, @@ -645,6 +646,41 @@ def test_decoder_model_past_with_3d_attn_mask(self): lm_labels, ) + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + decoder_input_ids=input_ids, + input_ids=input_dict["input_ids"][:3], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + input_ids=input_dict["input_ids"][:1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index fe9b40a54abe..68dd5a52b3d6 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -27,6 +27,7 @@ require_sentencepiece, require_tokenizers, require_torch, + require_torch_gpu, slow, torch_device, ) @@ -44,6 +45,7 @@ if is_torch_available(): import torch + import torch.nn.functional as F from transformers import ( AutoTokenizer, @@ -578,6 +580,9 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, # The small T5 model needs higher percentages for CPU/MP tests model_split_percents = [0.5, 0.8, 0.9] + # used in `test_torch_compile` + _torch_compile_test_ckpt = "google-t5/t5-small" + def setUp(self): self.model_tester = T5ModelTester(self) self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37) @@ -630,12 +635,9 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa ] if labels is not None: input_names.append("labels") - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} input_names = list(filtered_inputs.keys()) - model_output = model(**filtered_inputs) - traced_model = symbolic_trace(model, input_names) traced_output = traced_model(**filtered_inputs) else: @@ -650,7 +652,6 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa "visual_feats", "visual_pos", ] - labels = inputs.get("labels", None) start_positions = inputs.get("start_positions", None) end_positions = inputs.get("end_positions", None) @@ -660,15 +661,12 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa input_names.append("start_positions") if end_positions is not None: input_names.append("end_positions") - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} input_names = list(filtered_inputs.keys()) - if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and ( not hasattr(model.config, "problem_type") or model.config.problem_type is None ): model.config.problem_type = "single_label_classification" - traced_model = symbolic_trace(model, input_names) traced_output = traced_model(**filtered_inputs) model_output = model(**filtered_inputs) @@ -721,6 +719,41 @@ def flatten_output(output): # (Even with this call, there are still memory leak by ~0.04MB) self.clear_torch_jit_class_registry() + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + decoder_input_ids=input_ids, + input_ids=input_dict["input_ids"][:3], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + input_ids=input_dict["input_ids"][:1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + def test_config(self): self.config_tester.run_common_tests() @@ -1482,6 +1515,7 @@ def test_summarization(self): [model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]], padding="max_length", truncation=True, + max_length=512, return_tensors="pt", ).to(torch_device) self.assertEqual(512, dct["input_ids"].shape[1]) @@ -1604,14 +1638,76 @@ def test_contrastive_search_t5(self): outputs = t5_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64) generated_text = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True) + # TODO: @arthur? + # PR #31938 caused regression on this test which was fixed by PR #34089 self.assertListEqual( generated_text, [ - "Liana Barrientos has been married 10 times, nine of them in the Bronx. Her husbands filed for " - "permanent residence after the marriages, prosecutors say." + "Liana Barrientos has been married 10 times, nine of them in the Bronx . Her husbands filed for " + "permanent residence after the marriages, prosecutors say ." ], ) + @slow + @require_torch_gpu + def test_compile_static_cache(self): + NUM_TOKENS_TO_GENERATE = 40 + EXPECTED_TEXT_COMPLETION = [ + "theory of relativity states that 1) the speed of light is constant in all inertial reference frames. the laws of physics are the same for all inertial reference frames.", + "ketchup is my favorite condiment.", + ] + + prompts = [ + "summarize: Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " + "theory of relativity is not hard to grasp.", + "summarize: My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, " + "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my pizza.", + ] + model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small").to(torch_device) + tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small") + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) + + # Static Cache + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) + + # Static Cache + compile + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) + + @slow + @require_torch_gpu + def test_compile_static_cache_encoder(self): + prompts = [ + "summarize: Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " + "theory of relativity is not hard to grasp.", + "summarize: My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, " + "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my pizza.", + ] + model = T5EncoderModel.from_pretrained("google-t5/t5-small").to(torch_device) + tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small") + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + logits = model(**inputs) + + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + logits_compiled = model(**inputs) + self.assertTrue(torch.allclose(logits[0][:, -3:, -3], logits_compiled[0][:, -3:, -3], atol=1e-5)) + @require_torch class TestAsymmetricT5(unittest.TestCase): diff --git a/tests/models/udop/test_modeling_udop.py b/tests/models/udop/test_modeling_udop.py index a3ae498606a3..9d82173b1aed 100644 --- a/tests/models/udop/test_modeling_udop.py +++ b/tests/models/udop/test_modeling_udop.py @@ -37,6 +37,7 @@ if is_torch_available(): import torch + import torch.nn.functional as F from transformers import UdopEncoderModel, UdopForConditionalGeneration, UdopModel, UdopProcessor @@ -348,6 +349,7 @@ def test_forward_signature(self): expected_arg_names = [ "attention_mask", "bbox", + "cache_position", "cross_attn_head_mask", "decoder_attention_mask", "decoder_head_mask", @@ -365,6 +367,43 @@ def test_forward_signature(self): expected_arg_names = sorted(expected_arg_names) self.assertListEqual(sorted(arg_names[: len(expected_arg_names)]), expected_arg_names) + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + decoder_input_ids=input_ids, + input_ids=input_dict["input_ids"][:3], + bbox=input_dict["bbox"][:3], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + input_ids=input_dict["input_ids"][:1], + bbox=input_dict["bbox"][:1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + @unittest.skip( "Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!" ) @@ -534,6 +573,41 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + decoder_input_ids=input_ids, + input_ids=input_dict["input_ids"][:3], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + input_ids=input_dict["input_ids"][:1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + @unittest.skip( "Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!" ) diff --git a/tests/models/umt5/test_modeling_umt5.py b/tests/models/umt5/test_modeling_umt5.py index 1bd01da8e6ca..ec4c1d019b6d 100644 --- a/tests/models/umt5/test_modeling_umt5.py +++ b/tests/models/umt5/test_modeling_umt5.py @@ -41,6 +41,7 @@ if is_torch_available(): import torch + import torch.nn.functional as F from transformers import ( AutoTokenizer, @@ -316,6 +317,9 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin # The small UMT5 model needs higher percentages for CPU/MP tests model_split_percents = [0.5, 0.8, 0.9] + # used in `test_torch_compile` + _torch_compile_test_ckpt = "google/umt5-small" + def setUp(self): self.model_tester = UMT5ModelTester(self) @@ -486,6 +490,41 @@ def test_inputs_embeds(self): with torch.no_grad(): model(**inputs)[0] + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + decoder_input_ids=input_ids, + input_ids=input_dict["input_ids"][:3], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + input_ids=input_dict["input_ids"][:1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + def test_with_sequence_classification_head(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_with_sequence_classification_head(*config_and_inputs) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index dec1482f562a..964b7b912b4e 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -37,6 +37,7 @@ from transformers import ( AutoModel, AutoModelForCausalLM, + AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig, @@ -5109,10 +5110,15 @@ def test_torch_compile(self): batch_size = 1 n_iter = 3 - tokenizer = AutoTokenizer.from_pretrained(ckpt, revision=revision) - model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( - torch_device - ) + tokenizer = AutoTokenizer.from_pretrained(ckpt) + if self.is_encoder_decoder: + model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( + torch_device + ) + else: + model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( + torch_device + ) model.generation_config.max_new_tokens = 4 @@ -5184,10 +5190,15 @@ def test_compile_cuda_graph_time(self): os.environ["TOKENIZERS_PARALLELISM"] = "false" - tokenizer = AutoTokenizer.from_pretrained(ckpt, revision=revision) - model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( - torch_device - ) + tokenizer = AutoTokenizer.from_pretrained(ckpt) + if self.is_encoder_decoder: + model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( + torch_device + ) + else: + model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( + torch_device + ) cache_implementation = "static" if model.config.model_type == "gemma2":