diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py
index 098adacc8dee..d46e9d9e731b 100644
--- a/src/transformers/models/idefics/modeling_idefics.py
+++ b/src/transformers/models/idefics/modeling_idefics.py
@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ PyTorch Idefics model."""
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -28,7 +29,7 @@
 
 from ... import PreTrainedModel
 from ...activations import ACT2FN
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_outputs import ModelOutput
 from ...modeling_utils import PretrainedConfig
 from ...utils import (
     add_start_docstrings,
@@ -52,6 +53,93 @@
 ]
 
 
+@dataclass
+class IdeficsBaseModelOutputWithPast(ModelOutput):
+    """
+    Base class for Idefics model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+
+            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+            hidden_size)` is output.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
+            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
+            encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+            input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+            Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
+            sequence_length, hidden_size)`.
+
+            image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
+    """
+
+    last_hidden_state: torch.FloatTensor = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class IdeficsCausalLMOutputWithPast(ModelOutput):
+    """
+    Base class for Idefics causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+            Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
+            sequence_length, hidden_size)`.
+
+            image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    past_key_values: Optional[List[torch.FloatTensor]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
 def expand_inputs_for_generation(
     input_ids,
     expand_size=1,
@@ -64,6 +152,10 @@ def expand_inputs_for_generation(
         torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
     )
     input_ids = input_ids.index_select(0, expanded_return_idx)
+    model_kwargs["pixel_values"] = model_kwargs.get("pixel_values", None)
+    model_kwargs["image_encoder_embeddings"] = model_kwargs.get("image_encoder_embeddings", None)
+    model_kwargs["perceiver_embeddings"] = model_kwargs.get("perceiver_embeddings", None)
+    model_kwargs["image_attention_mask"] = model_kwargs.get("image_attention_mask", None)
 
     if "token_type_ids" in model_kwargs:
         token_type_ids = model_kwargs["token_type_ids"]
@@ -71,22 +163,29 @@ def expand_inputs_for_generation(
 
     if attention_mask is not None:
         model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
+
+    if model_kwargs["image_attention_mask"] is not None:
         model_kwargs["image_attention_mask"] = model_kwargs["image_attention_mask"].index_select(
             0, expanded_return_idx
         )
+
+    if model_kwargs["pixel_values"] is not None:
         model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(0, expanded_return_idx)
 
-    if is_encoder_decoder:
-        if encoder_outputs is None:
-            raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
-        encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
-            0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
+    elif model_kwargs["image_encoder_embeddings"] is not None:
+        model_kwargs["image_encoder_embeddings"] = model_kwargs["image_encoder_embeddings"].index_select(
+            0, expanded_return_idx
+        )
+
+    elif model_kwargs["perceiver_embeddings"] is not None:
+        model_kwargs["perceiver_embeddings"] = model_kwargs["perceiver_embeddings"].index_select(
+            0, expanded_return_idx
         )
-        model_kwargs["encoder_outputs"] = encoder_outputs
+
     return input_ids, model_kwargs
 
 
-def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
+def update_model_kwargs_for_generation(outputs, model_kwargs):
     # must have this key set to at least None
     model_kwargs["past_key_values"] = model_kwargs.get("past_key_values", None)
 
@@ -106,16 +205,18 @@ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder
         model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
 
     # update attention masks
-    if not is_encoder_decoder:
-        if "attention_mask" in model_kwargs:
-            attention_mask = model_kwargs["attention_mask"]
-            model_kwargs["attention_mask"] = torch.cat(
-                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
-            )
-        if "image_attention_mask" in model_kwargs:
-            image_attention_mask = model_kwargs["image_attention_mask"]
-            last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
-            model_kwargs["image_attention_mask"] = last_mask
+    if "attention_mask" in model_kwargs:
+        attention_mask = model_kwargs["attention_mask"]
+        model_kwargs["attention_mask"] = torch.cat(
+            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+        )
+    if "image_attention_mask" in model_kwargs:
+        image_attention_mask = model_kwargs["image_attention_mask"]
+        last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
+        model_kwargs["image_attention_mask"] = last_mask
+
+    # Get the precomputed image_hidden_states
+    model_kwargs["image_hidden_states"] = outputs.image_hidden_states
 
     return model_kwargs
 
@@ -139,9 +240,9 @@ def prepare_inputs_for_generation(input_ids, past=None, **kwargs):
         position_ids = position_ids[:, -1].unsqueeze(-1)
 
     pixel_values = kwargs.get("pixel_values", None)
+    image_encoder_embeddings = kwargs.get("image_encoder_embeddings", None)
+    perceiver_embeddings = kwargs.get("perceiver_embeddings", None)
     image_attention_mask = kwargs.get("image_attention_mask", None)
-    # if pixel_values is None or image_attention_mask is None:
-    #     raise ValueError("pixel values and image attention mask cannot be None")
 
     return {
         "input_ids": input_ids,
@@ -151,6 +252,8 @@ def prepare_inputs_for_generation(input_ids, past=None, **kwargs):
         "attention_mask": attention_mask,
         "token_type_ids": token_type_ids,
         "pixel_values": pixel_values,
+        "image_encoder_embeddings": image_encoder_embeddings,
+        "perceiver_embeddings": perceiver_embeddings,
         "image_attention_mask": image_attention_mask,
     }
 
@@ -1055,13 +1158,14 @@ def forward(
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         pixel_values: Optional[torch.FloatTensor] = None,
-        image_embeddings: Optional[torch.FloatTensor] = None,
+        image_encoder_embeddings: Optional[torch.FloatTensor] = None,
+        perceiver_embeddings: Optional[torch.FloatTensor] = None,
         image_attention_mask: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
+    ) -> Union[Tuple, IdeficsBaseModelOutputWithPast]:
         device = input_ids.device if input_ids is not None else inputs_embeds.device
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1103,11 +1207,10 @@ def forward(
             position_ids = position_ids.view(-1, seq_length).long()
 
         no_images = False
-        if pixel_values is None and image_embeddings is None:
-            raise ValueError("Either pixel_values and image_embeddings have to be not-None.")
-
-        elif pixel_values is not None and image_embeddings is not None:
-            raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
+        if (pixel_values, image_encoder_embeddings, perceiver_embeddings).count(None) != 2:
+            raise ValueError(
+                "Exactly 1 of pixel_values, image_encoder_embeddings or perceiver_embeddings has to be not-None."
+            )
 
         elif pixel_values is not None:
             no_images = len(torch.nonzero(pixel_values)) == 0
@@ -1118,14 +1221,23 @@ def forward(
             # Get sequence from the vision encoder
             image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
 
-        elif image_embeddings is not None:
-            batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
-            image_hidden_states = image_embeddings.to(dtype=self.dtype, device=input_ids.device)
+        elif image_encoder_embeddings is not None:
+            batch_size, num_images, image_seq_len, image_hidden_size = image_encoder_embeddings.size()
+            image_hidden_states = image_encoder_embeddings.to(dtype=self.dtype, device=input_ids.device)
             image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)
 
         if self.config.use_resampler:
-            image_hidden_states = self.perceiver_resampler(image_hidden_states)
-            image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
+            if perceiver_embeddings is None:
+                perceiver_embeddings = self.perceiver_resampler(image_hidden_states)
+                image_seq_len, image_hidden_size = perceiver_embeddings.size(1), perceiver_embeddings.size(2)
+            else:
+                batch_size, num_images, image_seq_len, image_hidden_size = perceiver_embeddings.size()
+            image_hidden_states = perceiver_embeddings
+        elif perceiver_embeddings is None:
+            image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
+        else:
+            raise ValueError("If `perceiver_embeddings` are passed, use_resampler should be True")
+        image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
 
         # # Hack to use the model in full language modeling mode
         # image_attention_mask = torch.zeros(batch_size, seq_length, 1, dtype=torch.long, device=image_hidden_states.device)
@@ -1272,13 +1384,19 @@ def vblock(
             all_hidden_states += (hidden_states,)
 
         next_cache = next_decoder_cache if use_cache else None
+        image_hidden_states = image_hidden_states.view(batch_size, num_images, image_seq_len, image_hidden_size)
         if not return_dict:
-            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-        return BaseModelOutputWithPast(
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, image_hidden_states]
+                if v is not None
+            )
+        return IdeficsBaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
+            image_hidden_states=image_hidden_states,
         )
 
 
@@ -1341,7 +1459,7 @@ def tie_weights(self):
             output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings
 
     @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=IdeficsCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -1350,14 +1468,15 @@ def forward(
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         pixel_values: Optional[torch.FloatTensor] = None,
-        image_embeddings: Optional[torch.FloatTensor] = None,
+        image_encoder_embeddings: Optional[torch.FloatTensor] = None,
+        perceiver_embeddings: Optional[torch.FloatTensor] = None,
         image_attention_mask: Optional[torch.Tensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
+    ) -> Union[Tuple, IdeficsCausalLMOutputWithPast]:
         r"""
         Args:
             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1398,7 +1517,8 @@ def forward(
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             pixel_values=pixel_values,
-            image_embeddings=image_embeddings,
+            image_encoder_embeddings=image_encoder_embeddings,
+            perceiver_embeddings=perceiver_embeddings,
             image_attention_mask=image_attention_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
@@ -1427,15 +1547,23 @@ def forward(
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
 
-        return CausalLMOutputWithPast(
+        return IdeficsCausalLMOutputWithPast(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            image_hidden_states=outputs.image_hidden_states,
         )
 
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+        image_hidden_states = kwargs.pop("image_hidden_states", None)
+        if image_hidden_states is not None:
+            if self.config.use_resampler:
+                kwargs["perceiver_embeddings"] = image_hidden_states
+            else:
+                kwargs["image_encoder_embeddings"] = image_hidden_states
+            kwargs["pixel_values"] = None
         inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
         unwanted_kwargs = ["token_type_ids"]
         for kwarg in unwanted_kwargs:
@@ -1450,8 +1578,8 @@ def _expand_inputs_for_generation(
         return expand_inputs_for_generation(*args, **model_kwargs)
 
     @staticmethod
-    def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
-        return update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder)
+    def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder):
+        return update_model_kwargs_for_generation(outputs, model_kwargs)
 
     @staticmethod
     def _reorder_cache(past, beam_idx):
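
Beyond the patch itself, the following is a minimal usage sketch of the two image-input paths this change enables. The checkpoint name, the dummy PIL image, and the processor call are illustrative assumptions, not part of the diff:

```python
# Sketch only: exercises the mutually exclusive image inputs introduced in this patch.
# Assumed for illustration: the "HuggingFaceM4/idefics-9b" checkpoint and a dummy image.
import torch
from PIL import Image

from transformers import AutoProcessor, IdeficsForVisionText2Text

checkpoint = "HuggingFaceM4/idefics-9b"
processor = AutoProcessor.from_pretrained(checkpoint)
model = IdeficsForVisionText2Text.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)

# One prompt containing one (dummy) image.
prompts = [["User: Describe this image.", Image.new("RGB", (224, 224))]]
inputs = processor(prompts, return_tensors="pt")

# Path 1: pass pixel_values; forward() runs the vision encoder (and the resampler, if configured).
outputs = model(**inputs)

# Path 2: reuse the features now exposed on the output and skip the vision encoder.
# Exactly one of pixel_values / image_encoder_embeddings / perceiver_embeddings may be non-None.
image_hidden_states = outputs.image_hidden_states  # (batch_size, num_images, image_seq_len, image_hidden_size)
inputs.pop("pixel_values")
if model.config.use_resampler:
    cached_outputs = model(**inputs, perceiver_embeddings=image_hidden_states)
else:
    cached_outputs = model(**inputs, image_encoder_embeddings=image_hidden_states)
```

Routing the cached features through `perceiver_embeddings` when `config.use_resampler` is set mirrors what `prepare_inputs_for_generation` now does with `outputs.image_hidden_states`, so repeated decoding steps only run the vision encoder (and resampler) once.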