diff --git a/src/transformers/models/blip_2/configuration_blip_2.py b/src/transformers/models/blip_2/configuration_blip_2.py index 17d16cd3be5c..9db098adf1a3 100644 --- a/src/transformers/models/blip_2/configuration_blip_2.py +++ b/src/transformers/models/blip_2/configuration_blip_2.py @@ -338,6 +338,9 @@ def __init__(self, vision_config=None, qformer_config=None, text_config=None, nu text_model_type = text_config["model_type"] if "model_type" in text_config else "opt" self.text_config = CONFIG_MAPPING[text_model_type](**text_config) + self.tie_word_embeddings = self.text_config.tie_word_embeddings + self.is_encoder_decoder = self.text_config.is_encoder_decoder + self.num_query_tokens = num_query_tokens self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index da4563e51aa1..672d845bcaf7 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -283,6 +283,7 @@ class Blip2PreTrainedModel(PreTrainedModel): r"position_ids", r"language_model.encoder.embed_tokens.weight", r"language_model.decoder.embed_tokens.weight", + r"language_model.lm_head.weight", ] _no_split_modules = ["Blip2Attention", "T5Block", "OPTDecoderLayer"] _keep_in_fp32_modules = ["wo"] @@ -1203,8 +1204,48 @@ def __init__(self, config: Blip2Config): # Initialize weights and apply final processing self.post_init() - def get_input_embeddings(self) -> nn.Module: - return self.vision_model.embeddings.patch_embedding + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def get_output_embeddings(self) -> nn.Module: + return self.language_model.get_output_embeddings() + + def get_encoder(self): + return self.language_model.get_encoder() + + def get_decoder(self): + return self.language_model.get_decoder() + + def _tie_weights(self): + if not self.config.use_decoder_only_language_model: + self.language_model.encoder.embed_tokens = self.language_model.shared + self.language_model.decoder.embed_tokens = self.language_model.shared + + def _preprocess_accelerate(self): + r""" + Some pre-processing hacks to make the model `accelerate` compatible. Check + https://github.com/huggingface/transformers/pull/21707 for more details. + """ + hf_device_map = self.hf_device_map + + if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1: + # warn users about unexpected behavior when using multi-GPU + BLIP-2 + `accelerate`. + logger.warning( + "The `language_model` is not in the `hf_device_map` dictionary and you are running your script" + " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`." + " Please pass a `device_map` that contains `language_model` to remove this warning." + " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for", + " more details on creating a `device_map` for large models.", + ) + + if hasattr(self.language_model, "_hf_hook"): + self.language_model._hf_hook.io_same_device = True # For `generate` compatibility @add_start_docstrings_to_model_forward(BLIP_2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Blip2ForConditionalGenerationModelOutput, config_class=Blip2VisionConfig) @@ -1311,7 +1352,7 @@ def forward( language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device ) inputs_embeds = self.language_model.get_input_embeddings()(input_ids) - inputs_embeds = torch.cat([language_model_inputs, inputs_embeds], dim=1) + inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) if attention_mask is None: attention_mask = torch.ones_like(input_ids) @@ -1387,6 +1428,10 @@ def generate( Returns: captions (list): A list of strings of length batch_size * num_captions. """ + if hasattr(self, "hf_device_map"): + # preprocess for `accelerate` + self._preprocess_accelerate() + batch_size = pixel_values.shape[0] image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) @@ -1412,11 +1457,11 @@ def generate( ) if attention_mask is None: attention_mask = torch.ones_like(input_ids) - attention_mask = torch.cat([language_attention_mask, attention_mask], dim=1) + attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1) # concatenate query embeddings with prompt embeddings - inputs_embeds = self.language_model.get_input_embeddings()(input_ids) - inputs_embeds = torch.cat([language_model_inputs, inputs_embeds], dim=1) + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) outputs = self.language_model.generate( inputs_embeds=inputs_embeds, diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index c888eb080141..67dcaa86241d 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -23,7 +23,7 @@ import requests from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device from transformers.utils import is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester @@ -859,3 +859,75 @@ def test_inference_t5_batched_beam_search(self): # Test output (in this case, slightly different from greedy search) self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1]) self.assertEqual(predictions[1].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1]) + + @require_torch_multi_gpu + def test_inference_opt_multi_gpu(self): + processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") + model = Blip2ForConditionalGeneration.from_pretrained( + "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="balanced" + ) + + # prepare image + image = prepare_img() + inputs = processor(images=image, return_tensors="pt").to(0, dtype=torch.float16) + + predictions = model.generate(**inputs) + generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() + + # Test output + self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118]) + self.assertEqual("a woman sitting on the beach with a dog", generated_text) + + # image and context + prompt = "Question: which city is this? Answer:" + inputs = processor(images=image, text=prompt, return_tensors="pt").to(0, dtype=torch.float16) + + predictions = model.generate(**inputs) + generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() + + # Test output + self.assertEqual( + predictions[0].tolist(), + [2, 24, 18, 45, 10, 343, 6, 24, 18, 10, 4105, 50118], + ) + self.assertEqual(generated_text, "it's not a city, it's a beach") + + @require_torch_multi_gpu + def test_inference_t5_multi_gpu(self): + processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl") + device_map = device_map = { + "query_tokens": 0, + "vision_model": 0, + "language_model": 1, + "language_projection": 0, + "qformer": 0, + } + + model = Blip2ForConditionalGeneration.from_pretrained( + "Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float16, device_map=device_map + ) + + # prepare image + image = prepare_img() + inputs = processor(images=image, return_tensors="pt").to(0, dtype=torch.float16) + + predictions = model.generate(**inputs) + generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() + + # Test output + self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1]) + self.assertEqual("woman playing with dog on the beach", generated_text) + + # image and context + prompt = "Question: which city is this? Answer:" + inputs = processor(images=image, text=prompt, return_tensors="pt").to(0, dtype=torch.float16) + + predictions = model.generate(**inputs) + generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() + + # Test output + self.assertEqual( + predictions[0].tolist(), + [0, 3, 7, 152, 67, 839, 1], + ) + self.assertEqual(generated_text, "san diego")