[Idefics] add image_embeddings option in generate-related methods (#25442)
Changes from all commits: 76d3628, 0687433, 0914394, 6bb49c4, 86bdb7e, c8cc8f7, 08f9000, f6f5367, 2d4daf6, 784f270, aa79134, 190ea96, 6fdd61b, 480be33, 79349ad, e6781fb, df0d79b, cda719d, 6988051, 5f6fb1e, 603efa8, 46fa4b0, c5585e6, f0036e0, de358ee
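In short, the PR lets `generate()` consume precomputed image features (`image_encoder_embeddings`, or `perceiver_embeddings` when the resampler is used) in place of `pixel_values`. Below is a minimal sketch of the intended flow; the checkpoint name, processor usage, and prompt format are assumptions for illustration, not part of the diff:

```python
import torch
from transformers import IdeficsForVisionText2Text, IdeficsProcessor

checkpoint = "HuggingFaceM4/idefics-9b"  # assumed checkpoint name
processor = IdeficsProcessor.from_pretrained(checkpoint)
model = IdeficsForVisionText2Text.from_pretrained(checkpoint)

# Prompts mix text and images (URL or PIL); format assumed from the Idefics docs.
prompts = [["User: what is in this image?", "https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG"]]
inputs = processor(prompts, return_tensors="pt")

# First pass: run the vision encoder (and the perceiver, if config.use_resampler)
# from raw pixel_values, and keep the returned image features.
outputs = model(**inputs)
cached = outputs.image_hidden_states  # (batch, num_images, image_seq_len, image_hidden_size)

# Later calls can skip the vision encoder by feeding the cached features back in.
embed_kwargs = (
    {"perceiver_embeddings": cached}
    if model.config.use_resampler
    else {"image_encoder_embeddings": cached}
)
generated = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    image_attention_mask=inputs.image_attention_mask,
    max_new_tokens=20,
    **embed_kwargs,
)
print(processor.batch_decode(generated, skip_special_tokens=True))
```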
```diff
@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ PyTorch Idefics model."""
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

 import torch
```
```diff
@@ -28,7 +29,7 @@
 from ... import PreTrainedModel
 from ...activations import ACT2FN
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_outputs import ModelOutput
 from ...modeling_utils import PretrainedConfig
 from ...utils import (
     add_start_docstrings,
```
```diff
@@ -52,6 +53,93 @@
 ]


+@dataclass
+class IdeficsBaseModelOutputWithPast(ModelOutput):
+    """
+    Base class for Idefics model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+
+            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+            hidden_size)` is output.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
+            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
+            encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+            input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+            Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
+            sequence_length, hidden_size)`.
+
+            image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
+    """
+
+    last_hidden_state: torch.FloatTensor = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class IdeficsCausalLMOutputWithPast(ModelOutput):
+    """
+    Base class for Idefics causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+            Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
+            sequence_length, hidden_size)`.
+
+            image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    past_key_values: Optional[List[torch.FloatTensor]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
 def expand_inputs_for_generation(
     input_ids,
     expand_size=1,
```
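These two classes mirror the generic `BaseModelOutputWithPast` / `CausalLMOutputWithPast` with one extra `image_hidden_states` field. A tiny illustration of the new field (shapes hypothetical; the import path assumes the file's final location in the repo):

```python
import torch
from transformers.models.idefics.modeling_idefics import IdeficsBaseModelOutputWithPast

# Construct the output directly, as the model's forward() does, just to show the
# extra field; real values come from the vision encoder / perceiver.
out = IdeficsBaseModelOutputWithPast(
    last_hidden_state=torch.zeros(1, 8, 4096),
    image_hidden_states=torch.zeros(1, 1, 64, 4096),  # (batch, num_images, image_seq_len, hidden)
)
print(out.image_hidden_states.shape)  # torch.Size([1, 1, 64, 4096])
```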
```diff
@@ -64,29 +152,40 @@ def expand_inputs_for_generation(
         torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
     )
     input_ids = input_ids.index_select(0, expanded_return_idx)
+    model_kwargs["pixel_values"] = model_kwargs.get("pixel_values", None)
+    model_kwargs["image_encoder_embeddings"] = model_kwargs.get("image_encoder_embeddings", None)
+    model_kwargs["perceiver_embeddings"] = model_kwargs.get("perceiver_embeddings", None)
+    model_kwargs["image_attention_mask"] = model_kwargs.get("image_attention_mask", None)

     if "token_type_ids" in model_kwargs:
         token_type_ids = model_kwargs["token_type_ids"]
         model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)

     if attention_mask is not None:
         model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)

     if model_kwargs["image_attention_mask"] is not None:
         model_kwargs["image_attention_mask"] = model_kwargs["image_attention_mask"].index_select(
             0, expanded_return_idx
         )

     if model_kwargs["pixel_values"] is not None:
         model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(0, expanded_return_idx)

-    if is_encoder_decoder:
-        if encoder_outputs is None:
-            raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
-        encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
-            0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
-        )
-        model_kwargs["encoder_outputs"] = encoder_outputs
+    elif model_kwargs["image_encoder_embeddings"] is not None:
+        model_kwargs["image_encoder_embeddings"] = model_kwargs["image_encoder_embeddings"].index_select(
+            0, expanded_return_idx
+        )
+
+    elif model_kwargs["perceiver_embeddings"] is not None:
+        model_kwargs["perceiver_embeddings"] = model_kwargs["perceiver_embeddings"].index_select(
+            0, expanded_return_idx
+        )

     return input_ids, model_kwargs


-def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
+def update_model_kwargs_for_generation(outputs, model_kwargs):
     # must have this key set to at least None
     model_kwargs["past_key_values"] = model_kwargs.get("past_key_values", None)
```
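The `expanded_return_idx` pattern above is the standard beam-search / `num_return_sequences` expansion, now applied to whichever of the three image inputs is present. In isolation (pure PyTorch, values illustrative):

```python
import torch

expand_size = 3  # e.g. num_beams or num_return_sequences
input_ids = torch.tensor([[1, 2], [3, 4]])  # batch_size = 2
expanded_return_idx = (
    torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1)
)
print(expanded_return_idx)  # tensor([0, 0, 0, 1, 1, 1])

# Any per-sample tensor (pixel_values, image_encoder_embeddings, perceiver
# embeddings, masks) expands the same way along dim 0:
print(input_ids.index_select(0, expanded_return_idx).shape)  # torch.Size([6, 2])
```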
```diff
@@ -106,16 +205,18 @@ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder
         model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

     # update attention masks
-    if not is_encoder_decoder:
-        if "attention_mask" in model_kwargs:
-            attention_mask = model_kwargs["attention_mask"]
-            model_kwargs["attention_mask"] = torch.cat(
-                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
-            )
-        if "image_attention_mask" in model_kwargs:
-            image_attention_mask = model_kwargs["image_attention_mask"]
-            last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
-            model_kwargs["image_attention_mask"] = last_mask
+    if "attention_mask" in model_kwargs:
+        attention_mask = model_kwargs["attention_mask"]
+        model_kwargs["attention_mask"] = torch.cat(
+            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+        )
+    if "image_attention_mask" in model_kwargs:
+        image_attention_mask = model_kwargs["image_attention_mask"]
+        last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
+        model_kwargs["image_attention_mask"] = last_mask
+
+    # Get the precomputed image_hidden_states
+    model_kwargs["image_hidden_states"] = outputs.image_hidden_states

     return model_kwargs
```
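The `image_attention_mask` update keeps only the mask row of the most recent position, which every subsequently generated token then reuses. With assumed shapes:

```python
import torch

# (batch, text_seq_len, num_images): which images each text position attends to.
image_attention_mask = torch.ones(2, 5, 3, dtype=torch.long)
last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
print(last_mask.shape)  # torch.Size([2, 1, 3]) -- carried over for each new token
```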
```diff
@@ -139,9 +240,9 @@ def prepare_inputs_for_generation(input_ids, past=None, **kwargs):
         position_ids = position_ids[:, -1].unsqueeze(-1)

     pixel_values = kwargs.get("pixel_values", None)
+    image_encoder_embeddings = kwargs.get("image_encoder_embeddings", None)
+    perceiver_embeddings = kwargs.get("perceiver_embeddings", None)
     image_attention_mask = kwargs.get("image_attention_mask", None)
-    # if pixel_values is None or image_attention_mask is None:
-    #     raise ValueError("pixel values and image attention mask cannot be None")

     return {
         "input_ids": input_ids,
```
```diff
@@ -151,6 +252,8 @@ def prepare_inputs_for_generation(input_ids, past=None, **kwargs):
         "attention_mask": attention_mask,
         "token_type_ids": token_type_ids,
         "pixel_values": pixel_values,
+        "image_encoder_embeddings": image_encoder_embeddings,
+        "perceiver_embeddings": perceiver_embeddings,
         "image_attention_mask": image_attention_mask,
     }
```
```diff
@@ -1055,13 +1158,14 @@ def forward(
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         pixel_values: Optional[torch.FloatTensor] = None,
-        image_embeddings: Optional[torch.FloatTensor] = None,
+        image_encoder_embeddings: Optional[torch.FloatTensor] = None,
+        perceiver_embeddings: Optional[torch.FloatTensor] = None,
         image_attention_mask: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
+    ) -> Union[Tuple, IdeficsBaseModelOutputWithPast]:
         device = input_ids.device if input_ids is not None else inputs_embeds.device

         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
```
```diff
@@ -1103,11 +1207,10 @@ def forward(
             position_ids = position_ids.view(-1, seq_length).long()

         no_images = False
-        if pixel_values is None and image_embeddings is None:
-            raise ValueError("Either pixel_values and image_embeddings have to be not-None.")
-
-        elif pixel_values is not None and image_embeddings is not None:
-            raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
+        if (pixel_values, image_encoder_embeddings, perceiver_embeddings).count(None) != 2:
+            raise ValueError(
+                "Exactly 1 of pixel_values, image_encoder_embeddings or perceiver_embeddings has to be not-None."
+            )

         elif pixel_values is not None:
             no_images = len(torch.nonzero(pixel_values)) == 0
```
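The `count(None)` check is a compact mutual-exclusion guard over the three image inputs; in isolation:

```python
# Exactly one of the three inputs must be set, i.e. the other two must be None.
pixel_values, image_encoder_embeddings, perceiver_embeddings = None, "embeds", None
inputs = (pixel_values, image_encoder_embeddings, perceiver_embeddings)
print(inputs.count(None) == 2)  # True -> valid; any other count raises in forward()
```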
```diff
@@ -1118,14 +1221,23 @@ def forward(
             # Get sequence from the vision encoder
             image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state

-        elif image_embeddings is not None:
-            batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
-            image_hidden_states = image_embeddings.to(dtype=self.dtype, device=input_ids.device)
+        elif image_encoder_embeddings is not None:
+            batch_size, num_images, image_seq_len, image_hidden_size = image_encoder_embeddings.size()
+            image_hidden_states = image_encoder_embeddings.to(dtype=self.dtype, device=input_ids.device)
```
Review thread on the `.to(dtype=self.dtype, device=input_ids.device)` line:

Collaborator: Last nit, is this required? I would think that accelerate handles this, since it's not a tensor created on the fly (unless the casting is what's required, not necessarily moving to a different device).

Contributor (author): Not sure, maybe not. I kept it out of caution when modifying this part. @VictorSanh, do you know if there was a particular reason for adding this?
```diff
             image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)

         if self.config.use_resampler:
-            image_hidden_states = self.perceiver_resampler(image_hidden_states)
-            image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
+            if perceiver_embeddings is None:
+                perceiver_embeddings = self.perceiver_resampler(image_hidden_states)
+                image_seq_len, image_hidden_size = perceiver_embeddings.size(1), perceiver_embeddings.size(2)
+            else:
+                batch_size, num_images, image_seq_len, image_hidden_size = perceiver_embeddings.size()
+            image_hidden_states = perceiver_embeddings
+        elif perceiver_embeddings is None:
+            image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
+        else:
+            raise ValueError("If `perceiver_embeddings` are passed, use_resampler should be True")

         image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
         # # Hack to use the model in full language modeling mode
         # image_attention_mask = torch.zeros(batch_size, seq_length, 1, dtype=torch.long, device=image_hidden_states.device)
```
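Shape-wise, this branch goes from per-image token sequences to one flat image-token stream per sample. A standalone walk-through with assumed sizes:

```python
import torch

batch_size, num_images, image_seq_len, image_hidden_size = 2, 3, 257, 1280
# Images are flattened into the batch for the vision tower / resampler:
flat = torch.rand(batch_size * num_images, image_seq_len, image_hidden_size)

# With use_resampler=True, the perceiver shortens each image sequence,
# e.g. 257 patch tokens -> 64 latents (latent count assumed here):
resampled = torch.rand(batch_size * num_images, 64, image_hidden_size)

# All image tokens are then folded into one stream per sample for cross-attention.
merged = resampled.view(batch_size, num_images * 64, image_hidden_size)
print(merged.shape)  # torch.Size([2, 192, 1280])
```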
```diff
@@ -1272,13 +1384,19 @@ def vblock(
                 all_hidden_states += (hidden_states,)

         next_cache = next_decoder_cache if use_cache else None
+        image_hidden_states = image_hidden_states.view(batch_size, num_images, image_seq_len, image_hidden_size)
         if not return_dict:
-            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-        return BaseModelOutputWithPast(
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, image_hidden_states]
+                if v is not None
+            )
+        return IdeficsBaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
+            image_hidden_states=image_hidden_states,
         )
```
```diff
@@ -1341,7 +1459,7 @@ def tie_weights(self):
         output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings

     @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=IdeficsCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
```
```diff
@@ -1350,14 +1468,15 @@ def forward(
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         pixel_values: Optional[torch.FloatTensor] = None,
-        image_embeddings: Optional[torch.FloatTensor] = None,
+        image_encoder_embeddings: Optional[torch.FloatTensor] = None,
+        perceiver_embeddings: Optional[torch.FloatTensor] = None,
         image_attention_mask: Optional[torch.Tensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
+    ) -> Union[Tuple, IdeficsCausalLMOutputWithPast]:
         r"""
         Args:
             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
```
```diff
@@ -1398,7 +1517,8 @@ def forward(
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             pixel_values=pixel_values,
-            image_embeddings=image_embeddings,
+            image_encoder_embeddings=image_encoder_embeddings,
+            perceiver_embeddings=perceiver_embeddings,
             image_attention_mask=image_attention_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
```
```diff
@@ -1427,15 +1547,23 @@ def forward(
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output

-        return CausalLMOutputWithPast(
+        return IdeficsCausalLMOutputWithPast(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            image_hidden_states=outputs.image_hidden_states,
         )

     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+        image_hidden_states = kwargs.pop("image_hidden_states", None)
+        if image_hidden_states is not None:
+            if self.config.use_resampler:
+                kwargs["perceiver_embeddings"] = image_hidden_states
+            else:
+                kwargs["image_encoder_embeddings"] = image_hidden_states
+            kwargs["pixel_values"] = None
         inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
         unwanted_kwargs = ["token_type_ids"]
         for kwarg in unwanted_kwargs:
```
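The net effect of the new `prepare_inputs_for_generation`: after the first decoding step, the cached `image_hidden_states` re-enter under whichever kwarg the configuration can consume, and `pixel_values` is cleared so the vision encoder only runs once per `generate()` call. The dispatch, reduced to its core:

```python
use_resampler = True  # stands in for self.config.use_resampler
kwargs = {"image_hidden_states": "cached-features", "pixel_values": "pixels"}

image_hidden_states = kwargs.pop("image_hidden_states", None)
if image_hidden_states is not None:
    key = "perceiver_embeddings" if use_resampler else "image_encoder_embeddings"
    kwargs[key] = image_hidden_states
    kwargs["pixel_values"] = None
print(kwargs)  # {'pixel_values': None, 'perceiver_embeddings': 'cached-features'}
```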
```diff
@@ -1450,8 +1578,8 @@ def _expand_inputs_for_generation(
         return expand_inputs_for_generation(*args, **model_kwargs)

     @staticmethod
-    def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
-        return update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder)
+    def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder):
+        return update_model_kwargs_for_generation(outputs, model_kwargs)

     @staticmethod
     def _reorder_cache(past, beam_idx):
```