Merged
25 commits
76d3628
add image_embeddings option in generate-related methods
leot13 Aug 10, 2023
0687433
style
leot13 Aug 10, 2023
0914394
rename image_embeddings and allow perceiver embeddings precomputation
leot13 Aug 11, 2023
6bb49c4
compute embeddings within generate
leot13 Aug 14, 2023
86bdb7e
make is_encoder_decoder=True the default in config
leot13 Aug 14, 2023
c8cc8f7
nested if else fix
leot13 Aug 14, 2023
08f9000
better triple check
leot13 Aug 14, 2023
f6f5367
switch if elif order for pixel values / img embeds
leot13 Aug 14, 2023
2d4daf6
update model_kwargs perceiver only at the end
leot13 Aug 14, 2023
784f270
use _prepare_model_inputs instead of encoder_decoder logic
leot13 Aug 14, 2023
aa79134
fix comment typo
leot13 Aug 14, 2023
190ea96
fix config default for is_encoder_decoder
leot13 Aug 14, 2023
6fdd61b
style
leot13 Aug 15, 2023
480be33
add typehints
leot13 Aug 15, 2023
79349ad
precompute in forward
leot13 Aug 17, 2023
e6781fb
doc builder
leot13 Aug 17, 2023
df0d79b
style
leot13 Aug 17, 2023
cda719d
pop instead of get image hidden states
leot13 Aug 17, 2023
6988051
Trigger CI
leot13 Aug 17, 2023
5f6fb1e
Update src/transformers/models/idefics/modeling_idefics.py
leot13 Aug 18, 2023
603efa8
Update src/transformers/models/idefics/modeling_idefics.py
leot13 Aug 18, 2023
46fa4b0
fix * + indentation + style
leot13 Aug 18, 2023
c5585e6
simplify a bit the use_resampler logic using comments
leot13 Aug 18, 2023
f0036e0
update docstrings
leot13 Aug 18, 2023
de358ee
Trigger CI
leot13 Aug 18, 2023
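
Taken together, the commits above add two precomputed-embedding entry points to the IDEFICS generation path: image_encoder_embeddings (the raw vision-encoder output) and perceiver_embeddings (the resampled states), either of which can replace pixel_values. A minimal sketch of the resulting API follows; the checkpoint name, token ids, and shapes are illustrative assumptions, not taken from this PR:

import torch
from transformers import IdeficsForVisionText2Text

# Illustrative only: real prompts and image_attention_mask would come from IdeficsProcessor.
model = IdeficsForVisionText2Text.from_pretrained("HuggingFaceM4/idefics-9b")  # hypothetical checkpoint choice

input_ids = torch.tensor([[1, 887, 526]])     # some tokenized prompt (assumed ids)
pixel_values = torch.rand(1, 1, 3, 224, 224)  # (batch, num_images, C, H, W)

# Option 1: pass raw pixels; generate runs the vision encoder (and resampler) itself.
out = model.generate(input_ids=input_ids, pixel_values=pixel_values, max_new_tokens=10)

# Option 2: precompute the vision-encoder output once and reuse it across calls.
flat_pixels = pixel_values.view(1 * 1, 3, 224, 224)  # the vision encoder takes (batch * num_images, C, H, W)
image_encoder_embeddings = model.model.vision_model(pixel_values=flat_pixels).last_hidden_state
image_encoder_embeddings = image_encoder_embeddings.unsqueeze(1)  # back to (batch, num_images, seq, hidden)
out = model.generate(input_ids=input_ids, image_encoder_embeddings=image_encoder_embeddings, max_new_tokens=10)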
210 changes: 169 additions & 41 deletions src/transformers/models/idefics/modeling_idefics.py
@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Idefics model."""
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
@@ -28,7 +29,7 @@

from ... import PreTrainedModel
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PretrainedConfig
from ...utils import (
add_start_docstrings,
@@ -52,6 +53,93 @@
]


@dataclass
class IdeficsBaseModelOutputWithPast(ModelOutput):
"""
Base class for Idefics model's outputs that may also contain a past key/values (to speed up sequential decoding).

Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.

If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
hidden_size)` is output.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)` and, optionally if
`config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.

Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.

Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size,
num_images, sequence_length, hidden_size)`.

Image hidden states of the model, produced by the vision encoder and, optionally, by the perceiver resampler.
"""

last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class IdeficsCausalLMOutputWithPast(ModelOutput):
"""
Base class for Idefics causal language model (or autoregressive) outputs.

Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`.

Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.

Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size,
num_images, sequence_length, hidden_size)`.

Image hidden states of the model, produced by the vision encoder and, optionally, by the perceiver resampler.
"""

loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None


def expand_inputs_for_generation(
input_ids,
expand_size=1,
@@ -64,29 +152,40 @@ def expand_inputs_for_generation(
torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
)
input_ids = input_ids.index_select(0, expanded_return_idx)
model_kwargs["pixel_values"] = model_kwargs.get("pixel_values", None)
model_kwargs["image_encoder_embeddings"] = model_kwargs.get("image_encoder_embeddings", None)
model_kwargs["perceiver_embeddings"] = model_kwargs.get("perceiver_embeddings", None)
model_kwargs["image_attention_mask"] = model_kwargs.get("image_attention_mask", None)

if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)

if attention_mask is not None:
model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)

if model_kwargs["image_attention_mask"] is not None:
model_kwargs["image_attention_mask"] = model_kwargs["image_attention_mask"].index_select(
0, expanded_return_idx
)

if model_kwargs["pixel_values"] is not None:
model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(0, expanded_return_idx)

if is_encoder_decoder:
if encoder_outputs is None:
raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
elif model_kwargs["image_encoder_embeddings"] is not None:
model_kwargs["image_encoder_embeddings"] = model_kwargs["image_encoder_embeddings"].index_select(
0, expanded_return_idx
)

elif model_kwargs["perceiver_embeddings"] is not None:
model_kwargs["perceiver_embeddings"] = model_kwargs["perceiver_embeddings"].index_select(
0, expanded_return_idx
)
model_kwargs["encoder_outputs"] = encoder_outputs

return input_ids, model_kwargs
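
As an aside, the arange/view/repeat + index_select pattern above is what tiles every image-side tensor for beam search or multiple return sequences. A self-contained sketch of the same expansion, with assumed shapes:

import torch

batch_size, expand_size = 2, 3
pixel_values = torch.rand(batch_size, 1, 3, 224, 224)  # (batch, num_images, C, H, W)

# Repeat each batch index expand_size times: [0, 0, 0, 1, 1, 1]
expanded_return_idx = torch.arange(batch_size).view(-1, 1).repeat(1, expand_size).view(-1)
expanded = pixel_values.index_select(0, expanded_return_idx)
print(expanded.shape)  # torch.Size([6, 1, 3, 224, 224])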


def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
def update_model_kwargs_for_generation(outputs, model_kwargs):
# must have this key set to at least None
model_kwargs["past_key_values"] = model_kwargs.get("past_key_values", None)

@@ -106,16 +205,18 @@ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

# update attention masks
if not is_encoder_decoder:
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
if "image_attention_mask" in model_kwargs:
image_attention_mask = model_kwargs["image_attention_mask"]
last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
model_kwargs["image_attention_mask"] = last_mask
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
if "image_attention_mask" in model_kwargs:
image_attention_mask = model_kwargs["image_attention_mask"]
last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
model_kwargs["image_attention_mask"] = last_mask

# Get the precomputed image_hidden_states
model_kwargs["image_hidden_states"] = outputs.image_hidden_states

return model_kwargs

@@ -139,9 +240,9 @@ def prepare_inputs_for_generation(input_ids, past=None, **kwargs):
position_ids = position_ids[:, -1].unsqueeze(-1)

pixel_values = kwargs.get("pixel_values", None)
image_encoder_embeddings = kwargs.get("image_encoder_embeddings", None)
perceiver_embeddings = kwargs.get("perceiver_embeddings", None)
image_attention_mask = kwargs.get("image_attention_mask", None)
# if pixel_values is None or image_attention_mask is None:
# raise ValueError("pixel values and image attention mask cannot be None")

return {
"input_ids": input_ids,
Expand All @@ -151,6 +252,8 @@ def prepare_inputs_for_generation(input_ids, past=None, **kwargs):
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
"pixel_values": pixel_values,
"image_encoder_embeddings": image_encoder_embeddings,
"perceiver_embeddings": perceiver_embeddings,
"image_attention_mask": image_attention_mask,
}

@@ -1055,13 +1158,14 @@ def forward(
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
image_embeddings: Optional[torch.FloatTensor] = None,
image_encoder_embeddings: Optional[torch.FloatTensor] = None,
perceiver_embeddings: Optional[torch.FloatTensor] = None,
image_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
) -> Union[Tuple, IdeficsBaseModelOutputWithPast]:
device = input_ids.device if input_ids is not None else inputs_embeds.device

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1103,11 +1207,10 @@ def forward(
position_ids = position_ids.view(-1, seq_length).long()

no_images = False
if pixel_values is None and image_embeddings is None:
raise ValueError("Either pixel_values and image_embeddings have to be not-None.")

elif pixel_values is not None and image_embeddings is not None:
raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
if (pixel_values, image_encoder_embeddings, perceiver_embeddings).count(None) != 2:
raise ValueError(
"Exactly 1 of pixel_values, image_encoder_embeddings or perceiver_embeddings has to be not-None."
)

elif pixel_values is not None:
no_images = len(torch.nonzero(pixel_values)) == 0
@@ -1118,14 +1221,23 @@
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state

elif image_embeddings is not None:
batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
image_hidden_states = image_embeddings.to(dtype=self.dtype, device=input_ids.device)
elif image_encoder_embeddings is not None:
batch_size, num_images, image_seq_len, image_hidden_size = image_encoder_embeddings.size()
image_hidden_states = image_encoder_embeddings.to(dtype=self.dtype, device=input_ids.device)
Collaborator: Last nit, is this required? Would think that accelerate handles this since it's not a tensor created on the fly (unless the casting is what's required, not necessarily moving to a different device!).

Contributor Author: Not sure, maybe not. I kept it out of caution when modifying this part. @VictorSanh, do you know if there was a particular reason for adding this?
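
For readers following this thread: a small sketch of the mismatch the cast guards against (dtype, device, and shapes are assumed for illustration):

import torch

# Precomputed embeddings may arrive in a different dtype/device than the model
# (e.g. fp32 on CPU vs. fp16 on GPU); the explicit cast normalizes them up front.
model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_device = "cuda" if torch.cuda.is_available() else "cpu"

image_encoder_embeddings = torch.rand(1, 1, 257, 1280)  # fp32 on CPU; shape assumed
image_hidden_states = image_encoder_embeddings.to(dtype=model_dtype, device=model_device)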

image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)

if self.config.use_resampler:
image_hidden_states = self.perceiver_resampler(image_hidden_states)
image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
if perceiver_embeddings is None:
perceiver_embeddings = self.perceiver_resampler(image_hidden_states)
image_seq_len, image_hidden_size = perceiver_embeddings.size(1), perceiver_embeddings.size(2)
else:
batch_size, num_images, image_seq_len, image_hidden_size = perceiver_embeddings.size()
image_hidden_states = perceiver_embeddings
elif perceiver_embeddings is None:
image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
else:
raise ValueError("If `perceiver_embeddings` are passed, use_resampler should be True")

image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
# # Hack to use the model in full language modeling mode
# image_attention_mask = torch.zeros(batch_size, seq_length, 1, dtype=torch.long, device=image_hidden_states.device)
@@ -1272,13 +1384,19 @@ def vblock(
all_hidden_states += (hidden_states,)

next_cache = next_decoder_cache if use_cache else None
image_hidden_states = image_hidden_states.view(batch_size, num_images, image_seq_len, image_hidden_size)
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, image_hidden_states]
if v is not None
)
return IdeficsBaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
image_hidden_states=image_hidden_states,
)


@@ -1341,7 +1459,7 @@ def tie_weights(self):
output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings

@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=IdeficsCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
@@ -1350,14 +1468,15 @@ def forward(
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
image_embeddings: Optional[torch.FloatTensor] = None,
image_encoder_embeddings: Optional[torch.FloatTensor] = None,
perceiver_embeddings: Optional[torch.FloatTensor] = None,
image_attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
) -> Union[Tuple, IdeficsCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1398,7 +1517,8 @@ def forward(
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
pixel_values=pixel_values,
image_embeddings=image_embeddings,
image_encoder_embeddings=image_encoder_embeddings,
perceiver_embeddings=perceiver_embeddings,
image_attention_mask=image_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
@@ -1427,15 +1547,23 @@ def forward(
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output

return CausalLMOutputWithPast(
return IdeficsCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states,
)

def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
image_hidden_states = kwargs.pop("image_hidden_states", None)
if image_hidden_states is not None:
if self.config.use_resampler:
kwargs["perceiver_embeddings"] = image_hidden_states
else:
kwargs["image_encoder_embeddings"] = image_hidden_states
kwargs["pixel_values"] = None
inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
unwanted_kwargs = ["token_type_ids"]
for kwarg in unwanted_kwargs:
@@ -1450,8 +1578,8 @@ def _expand_inputs_for_generation(
return expand_inputs_for_generation(*args, **model_kwargs)

@staticmethod
def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
return update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder)
def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder):
return update_model_kwargs_for_generation(outputs, model_kwargs)

@staticmethod
def _reorder_cache(past, beam_idx):
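
To summarize the generation-side flow: _update_model_kwargs_for_generation stashes outputs.image_hidden_states after the first forward pass, and prepare_inputs_for_generation feeds it back as perceiver_embeddings or image_encoder_embeddings while nulling pixel_values, so the vision encoder (and resampler) runs only once per generate call. A self-contained sketch of just that routing step, with stand-in tensors and config:

import torch

class DummyConfig:
    use_resampler = True  # as when the perceiver resampler is enabled

config = DummyConfig()
kwargs = {
    "image_hidden_states": torch.rand(1, 1, 64, 4096),  # stand-in for outputs.image_hidden_states
    "pixel_values": torch.rand(1, 1, 3, 224, 224),
}

image_hidden_states = kwargs.pop("image_hidden_states", None)
if image_hidden_states is not None:
    # Route the cached states back in under the keyword the forward pass expects.
    if config.use_resampler:
        kwargs["perceiver_embeddings"] = image_hidden_states
    else:
        kwargs["image_encoder_embeddings"] = image_hidden_states
    kwargs["pixel_values"] = None  # skip re-encoding the images

print(sorted(kwargs))  # ['perceiver_embeddings', 'pixel_values']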