src/transformers/models/blip_2/configuration_blip_2.py (3 additions, 0 deletions)
@@ -338,6 +338,9 @@ def __init__(self, vision_config=None, qformer_config=None, text_config=None, nu
text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
self.text_config = CONFIG_MAPPING[text_model_type](**text_config)

self.tie_word_embeddings = self.text_config.tie_word_embeddings
self.is_encoder_decoder = self.text_config.is_encoder_decoder

self.num_query_tokens = num_query_tokens
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
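The two new attributes on `Blip2Config` simply mirror the text backbone's `tie_word_embeddings` and `is_encoder_decoder` flags onto the composite config, where the weight-tying and `generate` utilities look for them. A minimal sketch of the effect (illustrative, not part of the PR, assuming a default OPT text config):

from transformers import Blip2Config, OPTConfig

# Build a composite config from a plain OPT text backbone (illustrative values).
text_config = OPTConfig()  # decoder-only; ties word embeddings by default
config = Blip2Config(text_config=text_config.to_dict())

# The composite config now exposes the backbone's flags at the top level.
print(config.tie_word_embeddings)  # same as config.text_config.tie_word_embeddings
print(config.is_encoder_decoder)   # False for OPT, True for a T5 backbone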
src/transformers/models/blip_2/modeling_blip_2.py (51 additions, 6 deletions)
@@ -283,6 +283,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
r"position_ids",
r"language_model.encoder.embed_tokens.weight",
r"language_model.decoder.embed_tokens.weight",
r"language_model.lm_head.weight",
]
_no_split_modules = ["Blip2Attention", "T5Block", "OPTDecoderLayer"]
_keep_in_fp32_modules = ["wo"]
@@ -1203,8 +1204,48 @@ def __init__(self, config: Blip2Config):
# Initialize weights and apply final processing
self.post_init()

- def get_input_embeddings(self) -> nn.Module:
- return self.vision_model.embeddings.patch_embedding
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()

def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)

def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)

def get_output_embeddings(self) -> nn.Module:
return self.language_model.get_output_embeddings()

def get_encoder(self):
return self.language_model.get_encoder()

def get_decoder(self):
return self.language_model.get_decoder()

def _tie_weights(self):
if not self.config.use_decoder_only_language_model:
self.language_model.encoder.embed_tokens = self.language_model.shared
self.language_model.decoder.embed_tokens = self.language_model.shared

def _preprocess_accelerate(self):
r"""
Some pre-processing hacks to make the model `accelerate` compatible. Check
https://github.com/huggingface/transformers/pull/21707 for more details.
"""
hf_device_map = self.hf_device_map

if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
# warn users about unexpected behavior when using multi-GPU + BLIP-2 + `accelerate`.
logger.warning(
"The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
" in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
" Please pass a `device_map` that contains `language_model` to remove this warning."
" Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for",
" more details on creating a `device_map` for large models.",
)

if hasattr(self.language_model, "_hf_hook"):
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility

@add_start_docstrings_to_model_forward(BLIP_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Blip2ForConditionalGenerationModelOutput, config_class=Blip2VisionConfig)
@@ -1311,7 +1352,7 @@ def forward(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
- inputs_embeds = torch.cat([language_model_inputs, inputs_embeds], dim=1)
+ inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)

if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
@@ -1387,6 +1428,10 @@ def generate(
Returns:
captions (list): A list of strings of length batch_size * num_captions.
"""
if hasattr(self, "hf_device_map"):
# preprocess for `accelerate`
self._preprocess_accelerate()

batch_size = pixel_values.shape[0]
image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
@@ -1412,11 +1457,11 @@
)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
- attention_mask = torch.cat([language_attention_mask, attention_mask], dim=1)
+ attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1)

# concatenate query embeddings with prompt embeddings
- inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
- inputs_embeds = torch.cat([language_model_inputs, inputs_embeds], dim=1)
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+ inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)

outputs = self.language_model.generate(
inputs_embeds=inputs_embeds,
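Taken together, the embedding getters/setters that now delegate to `language_model` and the new `_preprocess_accelerate` hook are what let the whole model be sharded with an `accelerate` `device_map`. A rough usage sketch (illustrative only, assuming two GPUs and an arbitrary example image URL; the explicit `device_map` mirrors the T5 integration test below and keeps `language_model` on a single device so the new warning is not triggered):

import requests
import torch
from PIL import Image

from transformers import Blip2ForConditionalGeneration, Blip2Processor

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")

# Keep `language_model` on its own device; all other top-level modules on GPU 0.
device_map = {
    "query_tokens": 0,
    "vision_model": 0,
    "language_projection": 0,
    "qformer": 0,
    "language_model": 1,
}
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map=device_map
)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # any test image works
image = Image.open(requests.get(url, stream=True).raw)

# Inputs go to the first device; `_preprocess_accelerate` sets `io_same_device`
# on the language model's hook so that `generate` still works across devices.
inputs = processor(images=image, return_tensors="pt").to(0, dtype=torch.float16)
generated_ids = model.generate(**inputs)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip())

Passing device_map="balanced" (as in the OPT integration test below) also works and lets `accelerate` decide the split itself.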
tests/models/blip_2/test_modeling_blip_2.py (73 additions, 1 deletion)
@@ -23,7 +23,7 @@
import requests

from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig
- from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+ from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
@@ -859,3 +859,75 @@ def test_inference_t5_batched_beam_search(self):
# Test output (in this case, slightly different from greedy search)
self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
self.assertEqual(predictions[1].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])

@require_torch_multi_gpu
def test_inference_opt_multi_gpu(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
"Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="balanced"
)

# prepare image
image = prepare_img()
inputs = processor(images=image, return_tensors="pt").to(0, dtype=torch.float16)

predictions = model.generate(**inputs)
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()

# Test output
self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
self.assertEqual("a woman sitting on the beach with a dog", generated_text)

# image and context
prompt = "Question: which city is this? Answer:"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(0, dtype=torch.float16)

predictions = model.generate(**inputs)
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()

# Test output
self.assertEqual(
predictions[0].tolist(),
[2, 24, 18, 45, 10, 343, 6, 24, 18, 10, 4105, 50118],
)
self.assertEqual(generated_text, "it's not a city, it's a beach")

@require_torch_multi_gpu
def test_inference_t5_multi_gpu(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
device_map = {
"query_tokens": 0,
"vision_model": 0,
"language_model": 1,
"language_projection": 0,
"qformer": 0,
}

model = Blip2ForConditionalGeneration.from_pretrained(
"Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float16, device_map=device_map
)

# prepare image
image = prepare_img()
inputs = processor(images=image, return_tensors="pt").to(0, dtype=torch.float16)

predictions = model.generate(**inputs)
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()

# Test output
self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
self.assertEqual("woman playing with dog on the beach", generated_text)

# image and context
prompt = "Question: which city is this? Answer:"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(0, dtype=torch.float16)

predictions = model.generate(**inputs)
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()

# Test output
self.assertEqual(
predictions[0].tolist(),
[0, 3, 7, 152, 67, 839, 1],
)
self.assertEqual(generated_text, "san diego")