11 changes: 10 additions & 1 deletion src/transformers/generation/candidate_generator.py
@@ -726,14 +726,23 @@ def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_en
elif mask_length_diff > 0:
model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)

# Handle cross attention models
if "cross_attention_mask" in model_kwargs:
# Mllama case is special and has another mask for cross attention model
# Mllama case
cross_mask = model_kwargs["cross_attention_mask"]
if mask_length_diff < 0:
model_kwargs["cross_attention_mask"] = cross_mask[:, :mask_length_diff]
elif mask_length_diff > 0:
new_mask = cross_mask[:, -1:, :, :].repeat(1, mask_length_diff, 1, 1)
model_kwargs["cross_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)
elif "image_attention_mask" in model_kwargs:
# IDEFICS case
cross_mask = model_kwargs["image_attention_mask"]
if mask_length_diff < 0:
model_kwargs["image_attention_mask"] = cross_mask[:, :mask_length_diff]
elif mask_length_diff > 0:
new_mask = cross_mask[:, -1:, :].repeat(1, mask_length_diff, 1)
model_kwargs["image_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)

return model_kwargs
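A minimal, self-contained sketch of the mask bookkeeping above, using dummy tensors and hypothetical shapes (batch of 1, two images): when the candidate sequence is longer than the stored mask, the last time step of the extra cross-attention mask is repeated for each newly appended token; when it is shorter, the mask is cropped.

```python
import torch

# Mllama-style 4-D cross-attention mask: (batch, seq_len, num_images, num_tiles)
cross_mask = torch.ones(1, 5, 2, 4)
# IDEFICS-style 3-D image attention mask: (batch, seq_len, num_images)
image_mask = torch.ones(1, 5, 2)

mask_length_diff = 3  # candidate ids are 3 tokens longer than the current mask

if mask_length_diff < 0:
    cross_mask = cross_mask[:, :mask_length_diff]
    image_mask = image_mask[:, :mask_length_diff]
elif mask_length_diff > 0:
    # repeat the last time step once per newly appended token
    new_cross = cross_mask[:, -1:, :, :].repeat(1, mask_length_diff, 1, 1)
    new_image = image_mask[:, -1:, :].repeat(1, mask_length_diff, 1)
    cross_mask = torch.cat([cross_mask, new_cross], dim=1)
    image_mask = torch.cat([image_mask, new_image], dim=1)

print(cross_mask.shape)  # torch.Size([1, 8, 2, 4])
print(image_mask.shape)  # torch.Size([1, 8, 2])
```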

4 changes: 3 additions & 1 deletion src/transformers/generation/utils.py
@@ -2005,6 +2005,7 @@ def generate(
# generating the first new token or not, and we only want to use the embeddings for the first new token)
if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
model_kwargs["use_cache"] = True
generation_config.use_cache = True
else:
model_kwargs["use_cache"] = generation_config.use_cache

@@ -4299,7 +4300,8 @@ def _assisted_decoding(
newly_added_length,
is_decoder_attention=True,
)
else:
# some (V)LLMs have a hard requirement on SDPA and thus never return attn
elif outputs.attentions[0] is not None:
Collaborator: ouuu very nice!

decoder_attentions = _split_model_outputs(
decoder_attentions,
outputs.attentions,
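For context on the `elif outputs.attentions[0] is not None` guard: `torch.nn.functional.scaled_dot_product_attention` returns only the attention output, never the weights, so models that dispatch exclusively to SDPA report `None` attention entries even when `output_attentions=True` is requested. A small plain-PyTorch illustration (no transformers model involved):

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 2, 4, 8)  # (batch, heads, seq_len, head_dim)

# An eager implementation can expose the weights it computes...
weights = torch.softmax(q @ k.transpose(-2, -1) / 8 ** 0.5, dim=-1)
eager_out = weights @ v

# ...but SDPA returns only the output tensor, so a model built on it has
# nothing to place in `outputs.attentions`.
sdpa_out = F.scaled_dot_product_attention(q, k, v)

per_layer_attentions = (None,)  # what such a model effectively returns
if per_layer_attentions[0] is not None:  # the guard added above
    print("collect attention weights for the assistant bookkeeping")
else:
    print("skip: attention weights were never materialized")
```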
37 changes: 17 additions & 20 deletions src/transformers/models/idefics/modeling_idefics.py
@@ -28,12 +28,12 @@
from torch import nn
from torch.nn import CrossEntropyLoss

from ... import PreTrainedModel
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PretrainedConfig
from ...modeling_utils import PretrainedConfig, PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_start_docstrings,
@@ -622,11 +622,9 @@ def forward(
query_states = self.q_layer_norm(query_states)
key_states = self.k_layer_norm(key_states)

causal_mask = attention_mask
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -638,13 +636,13 @@
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
attn_mask=causal_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
)
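A hedged sketch (dummy shapes, not the actual module) of the pattern the IDEFICS SDPA path follows after this change: keep a single `causal_mask` variable, crop it to the real key length instead of validating its size and raising, and derive `is_causal` from that same variable.

```python
import torch
import torch.nn.functional as F

bsz, num_heads, q_len, head_dim = 1, 4, 6, 16
kv_seq_len = 10  # keys can be longer than queries once a KV cache is in use

query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, kv_seq_len, head_dim)
value_states = torch.randn(bsz, num_heads, kv_seq_len, head_dim)

# 4-D additive mask prepared upstream; it may be padded past the real key length
attention_mask = torch.zeros(bsz, 1, q_len, kv_seq_len + 2)
module_is_causal = True  # stand-in for `self.is_causal`

causal_mask = attention_mask
if attention_mask is not None:
    causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

# only let SDPA build its own causal mask when no explicit mask was provided
is_causal = module_is_causal and causal_mask is None and q_len > 1

attn_output = F.scaled_dot_product_attention(
    query_states, key_states, value_states, attn_mask=causal_mask, is_causal=is_causal
)
print(attn_output.shape)  # torch.Size([1, 4, 6, 16])
```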
@@ -1490,7 +1488,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask


class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
_tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"]

@@ -1670,6 +1668,7 @@ def prepare_inputs_for_generation(
position_ids=None,
pixel_values=None,
image_hidden_states=None,
image_attention_mask=None,
use_cache=None,
cache_position=None,
**kwargs,
@@ -1678,6 +1677,8 @@
if past_key_values is not None:
if input_ids.shape[1] != cache_position.shape[0]:
input_ids = input_ids[:, cache_position]
if image_attention_mask is not None:
image_attention_mask = image_attention_mask[:, -input_ids.shape[1] :]

if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
@@ -1696,7 +1697,8 @@
model_inputs["perceiver_embeddings"] = image_hidden_states
else:
model_inputs["image_encoder_embeddings"] = image_hidden_states
pixel_values = None
else:
model_inputs["pixel_values"] = pixel_values

model_inputs.update(
{
Expand All @@ -1706,21 +1708,13 @@ def prepare_inputs_for_generation(
"cache_position": cache_position,
"position_ids": position_ids,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"image_attention_mask": kwargs.get("image_attention_mask", None),
"image_attention_mask": image_attention_mask,
"interpolate_pos_encoding": kwargs.get("interpolate_pos_encoding", False),
}
)

return model_inputs

@staticmethod
def _expand_inputs_for_generation(
*args,
**model_kwargs,
):
return expand_inputs_for_generation(*args, **model_kwargs)

def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
@@ -1738,7 +1732,10 @@ def _update_model_kwargs_for_generation(
if "image_attention_mask" in model_kwargs:
image_attention_mask = model_kwargs["image_attention_mask"]
last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
model_kwargs["image_attention_mask"] = last_mask
if model_kwargs.get("use_cache", True):
model_kwargs["image_attention_mask"] = last_mask
else:
model_kwargs["image_attention_mask"] = torch.cat([image_attention_mask, last_mask], dim=1)

# Get the precomputed image_hidden_states
model_kwargs["image_hidden_states"] = outputs.image_hidden_states
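A simplified sketch of the `image_attention_mask` bookkeeping added in `_update_model_kwargs_for_generation` (dummy tensors, hypothetical shapes): with a KV cache only the new token is fed to the model, so only the last mask row is kept; without a cache the whole sequence is re-fed, so the mask has to grow by one row per step.

```python
import torch

def update_image_mask(image_attention_mask: torch.Tensor, use_cache: bool) -> torch.Tensor:
    # image_attention_mask: (batch, seq_len, num_images)
    last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
    if use_cache:
        # the next forward pass only sees the newly generated token
        return last_mask
    # the next forward pass re-processes the full sequence plus the new token
    return torch.cat([image_attention_mask, last_mask], dim=1)

mask = torch.ones(1, 7, 2)
print(update_image_mask(mask, use_cache=True).shape)   # torch.Size([1, 1, 2])
print(update_image_mask(mask, use_cache=False).shape)  # torch.Size([1, 8, 2])
```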
48 changes: 17 additions & 31 deletions src/transformers/models/idefics2/modeling_idefics2.py
@@ -1427,6 +1427,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@@ -1657,35 +1658,19 @@ def prepare_inputs_for_generation(
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
pixel_values=None,
pixel_attention_mask=None,
image_hidden_states=None,
num_logits_to_keep=None,
**kwargs,
):
past_length = 0
# Omit tokens covered by past_key_values
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
if past_key_values is not None:
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
past_length = past_key_values.get_seq_length()
max_cache_length = past_key_values.get_max_cache_shape()

# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and past_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]:
input_ids = input_ids[:, cache_position]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
@@ -1696,21 +1681,22 @@ def prepare_inputs_for_generation(
position_ids = position_ids[:, -input_ids.shape[1] :]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
# but IDEFICS requires both ids and embeds to be present
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids}
else:
model_inputs = {"input_ids": input_ids}
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep

image_hidden_states = kwargs.get("image_hidden_states", None)
if image_hidden_states is not None:
pixel_values = None
pixel_attention_mask = None
else:
pixel_values = kwargs.get("pixel_values", None)
pixel_attention_mask = kwargs.get("pixel_attention_mask", None)
pixel_values = pixel_values
pixel_attention_mask = pixel_attention_mask
model_inputs.update(
{
"position_ids": position_ids,
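A minimal sketch (dummy tensors) of the `cache_position`-based slicing that replaces the old `past_length` bookkeeping in `prepare_inputs_for_generation`: instead of reconstructing how many tokens the cache already holds, the ids are indexed directly by the positions the cache has not yet seen.

```python
import torch

input_ids = torch.arange(12).unsqueeze(0)  # full prompt plus tokens generated so far
cache_position = torch.tensor([11])        # only the last position is still unprocessed
inputs_embeds = None
has_cache = True

if has_cache:
    if inputs_embeds is not None:
        # embeddings already cover the cached prefix; keep only the trailing ids
        input_ids = input_ids[:, -cache_position.shape[0]:]
    elif input_ids.shape[1] != cache_position.shape[0]:
        # index directly by the unprocessed positions
        input_ids = input_ids[:, cache_position]

print(input_ids)  # tensor([[11]])
```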
55 changes: 22 additions & 33 deletions src/transformers/models/idefics3/modeling_idefics3.py
@@ -24,7 +24,8 @@

from ... import PreTrainedModel
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...utils import (
@@ -953,6 +954,8 @@ def forward(

past_seen_tokens = 0
if use_cache:
if past_key_values is None:
past_key_values = DynamicCache()
past_seen_tokens = past_key_values.get_seq_length()

if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
@@ -1019,6 +1022,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@@ -1040,7 +1044,7 @@ def forward(
"""The Idefics3 Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top. """,
IDEFICS3_START_DOCSTRING,
)
class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel):
class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]

# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.__init__ with Idefics2->Idefics3
@@ -1245,35 +1249,19 @@ def prepare_inputs_for_generation(
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
pixel_values=None,
pixel_attention_mask=None,
image_hidden_states=None,
num_logits_to_keep=None,
**kwargs,
):
past_length = 0
# Omit tokens covered by past_key_values
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
if past_key_values is not None:
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
past_length = past_key_values.get_seq_length()
max_cache_length = past_key_values.get_max_cache_shape()

# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and past_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]:
input_ids = input_ids[:, cache_position]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
@@ -1284,21 +1272,22 @@
position_ids = position_ids[:, -input_ids.shape[1] :]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
# but IDEFICS requires both ids and embeds to be present
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids}
else:
model_inputs = {"input_ids": input_ids}
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep

image_hidden_states = kwargs.get("image_hidden_states", None)
if image_hidden_states is not None:
pixel_values = None
pixel_attention_mask = None
else:
pixel_values = kwargs.get("pixel_values", None)
pixel_attention_mask = kwargs.get("pixel_attention_mask", None)
pixel_values = pixel_values
pixel_attention_mask = pixel_attention_mask
model_inputs.update(
{
"position_ids": position_ids,
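A short sketch of the cache default added to `Idefics3Model.forward`, assuming the `DynamicCache` API from `transformers.cache_utils` (`get_seq_length` and per-layer `update`): when `use_cache=True` and no cache object is passed, a fresh `DynamicCache` is created instead of assuming the legacy tuple format.

```python
import torch
from transformers.cache_utils import DynamicCache

use_cache = True
past_key_values = None  # caller did not provide a cache

past_seen_tokens = 0
if use_cache:
    if past_key_values is None:
        past_key_values = DynamicCache()
    past_seen_tokens = past_key_values.get_seq_length()
print(past_seen_tokens)  # 0 for a fresh cache

# the cache then grows as decoder layers call `update(...)` during the forward pass
key = value = torch.randn(1, 4, 3, 16)  # (batch, kv_heads, seq_len, head_dim)
past_key_values.update(key, value, layer_idx=0)
print(past_key_values.get_seq_length())  # 3
```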
17 changes: 11 additions & 6 deletions tests/generation/test_utils.py
@@ -153,7 +153,11 @@ def _get_logits_processor_kwargs(self, do_sample=False, config=None):
# This is a band-aid for VLM models, to ensure they don't generate image/video tokens which would cause them
# to crash. On pretrained models this isn't a risk, as they are trained to not generate these tokens.
if config is not None:
image_token_index = config.image_token_index if hasattr(config, "image_token_index") else None
image_token_index = (
config.image_token_index
if getattr(config, "image_token_index", None) is not None
else getattr(config, "image_token_id", None)
)
video_token_index = config.video_token_index if hasattr(config, "video_token_index") else None
if image_token_index is not None and image_token_index < config.get_text_config().vocab_size:
logits_processor_kwargs["bad_words_ids"].append([image_token_index])
@@ -1496,13 +1500,14 @@ def test_past_key_values_format(self):
if "past_key_values" not in outputs:
self.skipTest(reason="This model doesn't return `past_key_values`")

text_config = config.get_text_config()
num_hidden_layers = (
getattr(config, "decoder_layers", None)
or getattr(config, "num_decoder_layers", None)
or config.num_hidden_layers
getattr(text_config, "decoder_layers", None)
or getattr(text_config, "num_decoder_layers", None)
or text_config.num_hidden_layers
)
num_attention_heads = getattr(config, "decoder_attention_heads", config.num_attention_heads)
embed_dim = getattr(config, "d_model", config.hidden_size)
num_attention_heads = getattr(text_config, "decoder_attention_heads", text_config.num_attention_heads)
embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
per_head_embed_dim = embed_dim // num_attention_heads

past_kv = outputs["past_key_values"]
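A small sketch of the lookup pattern the updated test uses (hypothetical config classes): composite vision-language configs keep the decoder sizes on their text sub-config, so the shapes are resolved through `config.get_text_config()` with the same attribute fallbacks as before.

```python
class TextConfig:
    num_hidden_layers = 24
    num_attention_heads = 16
    hidden_size = 2048

class VLMConfig:
    def get_text_config(self):
        return TextConfig()

text_config = VLMConfig().get_text_config()

num_hidden_layers = (
    getattr(text_config, "decoder_layers", None)
    or getattr(text_config, "num_decoder_layers", None)
    or text_config.num_hidden_layers
)
num_attention_heads = getattr(text_config, "decoder_attention_heads", text_config.num_attention_heads)
embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
per_head_embed_dim = embed_dim // num_attention_heads

print(num_hidden_layers, num_attention_heads, per_head_embed_dim)  # 24 16 128
```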