diff --git a/docs/source/en/model_doc/blip-2.mdx b/docs/source/en/model_doc/blip-2.mdx index 690fc02bd8cf..57b380f99fe0 100644 --- a/docs/source/en/model_doc/blip-2.mdx +++ b/docs/source/en/model_doc/blip-2.mdx @@ -71,6 +71,14 @@ If you're interested in submitting a resource to be included here, please feel f [[autodoc]] Blip2QFormerModel - forward +## Blip2Model + +[[autodoc]] Blip2Model + - forward + - get_text_features + - get_image_features + - get_qformer_features + ## Blip2ForConditionalGeneration [[autodoc]] Blip2ForConditionalGeneration diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 8e1b779a116f..7843b485fad2 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1182,6 +1182,7 @@ [ "BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST", "Blip2ForConditionalGeneration", + "Blip2Model", "Blip2PreTrainedModel", "Blip2QFormerModel", "Blip2VisionModel", @@ -4608,6 +4609,7 @@ from .models.blip_2 import ( BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST, Blip2ForConditionalGeneration, + Blip2Model, Blip2PreTrainedModel, Blip2QFormerModel, Blip2VisionModel, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 8b141b2cd458..30caf51c0c9b 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -42,6 +42,7 @@ ("blenderbot", "BlenderbotModel"), ("blenderbot-small", "BlenderbotSmallModel"), ("blip", "BlipModel"), + ("blip_2", "Blip2Model"), ("bloom", "BloomModel"), ("bridgetower", "BridgeTowerModel"), ("camembert", "CamembertModel"), diff --git a/src/transformers/models/blip_2/__init__.py b/src/transformers/models/blip_2/__init__.py index 1c25156654c2..6fbfd53b3703 100644 --- a/src/transformers/models/blip_2/__init__.py +++ b/src/transformers/models/blip_2/__init__.py @@ -34,6 +34,7 @@ else: _import_structure["modeling_blip_2"] = [ "BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST", + "Blip2Model", "Blip2QFormerModel", "Blip2PreTrainedModel", "Blip2ForConditionalGeneration", @@ -58,6 +59,7 @@ from .modeling_blip_2 import ( BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST, Blip2ForConditionalGeneration, + Blip2Model, Blip2PreTrainedModel, Blip2QFormerModel, Blip2VisionModel, diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 5006a7819c58..62967dc74e44 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -342,6 +342,43 @@ def _set_gradient_checkpointing(self, module, value=False): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ +BLIP_2_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 + Training](./t5#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + BLIP_2_INPUTS_DOCSTRING = r""" Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): @@ -1171,6 +1208,337 @@ def forward( ) +@add_start_docstrings( + """ + BLIP-2 Model for generating text and image features. The model consists of a vision encoder, Querying Transformer + (Q-Former) and a language model. + """, + BLIP_2_START_DOCSTRING, +) +class Blip2Model(Blip2PreTrainedModel): + config_class = Blip2Config + main_input_name = "pixel_values" + + def __init__(self, config: Blip2Config): + super().__init__(config) + + self.vision_model = Blip2VisionModel(config.vision_config) + + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + self.qformer = Blip2QFormerModel(config.qformer_config) + + self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) + if config.use_decoder_only_language_model: + language_model = AutoModelForCausalLM.from_config(config.text_config) + else: + language_model = AutoModelForSeq2SeqLM.from_config(config.text_config) + self.language_model = language_model + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(BLIP_2_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Returns: + text_outputs (`CausalLMOutputWithPast`, or `tuple(torch.FloatTensor)` if `return_dict=False`): + The language model outputs. If `return_dict=True`, the output is a [`CausalLMOutputWithPast`] that + contains the language model logits, the past key values and the hidden states if + `output_hidden_states=True`. + Examples: + ```python + >>> import torch + >>> from transformers import AutoTokenizer, Blip2Model + + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16) + + >>> model.to(device) # doctest: +IGNORE_RESULT + + >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/blip2-opt-2.7b") + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt").to(device) + >>> text_features = model.get_text_features(**inputs) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.use_decoder_only_language_model: + text_outputs = self.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + else: + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + + text_outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + labels=labels, + ) + + return text_outputs + + @add_start_docstrings_to_model_forward(BLIP_2_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Returns: + vision_outputs (`BaseModelOutputWithPooling` or tuple of `torch.FloatTensor`): + The vision model outputs. If `return_dict=True`, the output is a [`BaseModelOutputWithPooling`] that + contains the image features, the pooled image features and the hidden states if + `output_hidden_states=True`. + Examples: + ```python + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Blip2Model + + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16) + + >>> model.to(device) # doctest: +IGNORE_RESULT + + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16) + >>> image_outputs = model.get_image_features(**inputs) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return vision_outputs + + @add_start_docstrings_to_model_forward(BLIP_2_INPUTS_DOCSTRING) + def get_qformer_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Returns: + vision_outputs (`BaseModelOutputWithPooling` or tuple of `torch.FloatTensor`): + The vision model outputs. If `return_dict=True`, the output is a [`BaseModelOutputWithPooling`] that + contains the image features, the pooled image features and the hidden states if + `output_hidden_states=True`. + Examples: + ```python + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from transformers import Blip2Processor, Blip2Model + + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") + >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16) + >>> model.to(device) # doctest: +IGNORE_RESULT + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16) + >>> qformer_outputs = model.get_qformer_features(**inputs) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[0] + + # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_outputs = self.qformer( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return query_outputs + + @add_start_docstrings_to_model_forward(BLIP_2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Blip2ForConditionalGenerationModelOutput, config_class=Blip2VisionConfig) + def forward( + self, + pixel_values: torch.FloatTensor, + input_ids: torch.FloatTensor, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import Blip2Processor, Blip2Model + >>> import torch + + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") + >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16) + >>> model.to(device) # doctest: +IGNORE_RESULT + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> prompt = "Question: how many cats are there? Answer:" + >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16) + + >>> outputs = model(**inputs) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # step 1: forward the images through the vision encoder, + # to get image embeddings of shape (batch_size, seq_len, hidden_size) + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + image_embeds = vision_outputs[0] + + # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_outputs = self.qformer( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + query_output = query_outputs[0] + + # step 3: use the language model, conditioned on the query outputs and the prompt + language_model_inputs = self.language_projection(query_output) + language_model_attention_mask = torch.ones( + language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device + ) + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + inputs_embeds = torch.cat([language_model_inputs, inputs_embeds], dim=1) + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + expected_device = language_model_attention_mask.device + attention_mask = torch.cat([language_model_attention_mask, attention_mask.to(expected_device)], dim=1) + + if self.config.use_decoder_only_language_model: + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + logits = outputs.logits if return_dict else outputs[0] + loss = None + # we compute the loss here since we need to take into account the sequence length of the query embeds + if labels is not None: + logits = logits[:, -labels.size(1) :, :] + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous().to(logits.device) + + # Flatten the tokens + loss_fct = CrossEntropyLoss(reduction="mean") + + loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1)) + else: + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + labels=labels, + ) + loss = outputs.loss if return_dict else outputs[0] + logits = outputs.logits if return_dict else outputs[1] + + if not return_dict: + output = (logits, vision_outputs, query_outputs, outputs) + return ((loss,) + output) if loss is not None else output + + return Blip2ForConditionalGenerationModelOutput( + loss=loss, + logits=logits, + vision_outputs=vision_outputs, + qformer_outputs=query_outputs, + language_model_outputs=outputs, + ) + + @add_start_docstrings( """ BLIP-2 Model for generating text given an image and an optional text prompt. The model consists of a vision diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 36886cfe643c..b31387dc2d40 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1221,6 +1221,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Blip2Model(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class Blip2PreTrainedModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index c888eb080141..8bff72494053 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -40,7 +40,7 @@ import torch from torch import nn - from transformers import Blip2ForConditionalGeneration, Blip2VisionModel + from transformers import Blip2ForConditionalGeneration, Blip2Model, Blip2VisionModel from transformers.models.blip_2.modeling_blip_2 import BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST @@ -664,8 +664,8 @@ def prepare_config_and_inputs_for_common(self): @require_torch -class Blip2ForConditionalGenerationTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (Blip2ForConditionalGeneration,) if is_torch_available() else () +class Blip2ModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (Blip2ForConditionalGeneration, Blip2Model) if is_torch_available() else () fx_compatible = False test_head_masking = False test_pruning = False @@ -737,6 +737,56 @@ def test_model_from_pretrained(self): model = Blip2ForConditionalGeneration.from_pretrained(model_name) self.assertIsNotNone(model) + def test_get_text_features(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + inputs_dict = { + "input_ids": torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]).to(torch_device), + "attention_mask": torch.LongTensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]).to(torch_device), + "decoder_input_ids": torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]).to(torch_device), + } + + model = Blip2Model(config).to(torch_device) + model.eval() + text_features = model.get_text_features(**inputs_dict) + self.assertEqual(text_features[0].shape, (1, 10, config.text_config.vocab_size)) + + def test_get_image_features(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + keys_to_pop = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"] + + for key in keys_to_pop: + inputs_dict.pop(key) + + model = Blip2Model(config).to(torch_device) + model.eval() + image_features = model.get_image_features(**inputs_dict) + self.assertEqual( + image_features[0].shape, + ( + self.model_tester.vision_model_tester.batch_size, + self.model_tester.vision_model_tester.seq_length, + config.vision_config.hidden_size, + ), + ) + + def test_get_qformer_features(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + keys_to_pop = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"] + + for key in keys_to_pop: + inputs_dict.pop(key) + + model = Blip2Model(config).to(torch_device) + model.eval() + qformer_features = model.get_qformer_features(**inputs_dict) + self.assertEqual( + qformer_features[0].shape, + (self.model_tester.vision_model_tester.batch_size, 10, config.vision_config.hidden_size), + ) + # override from common to deal with nested configurations (`vision_config`, `text_config` and `qformer_config`) def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()