4 changes: 4 additions & 0 deletions docs/source/en/model_doc/blip.mdx
@@ -31,6 +31,10 @@ However, most existing pre-trained models only excel in either understanding-bas
This model was contributed by [ybelkada](https://huggingface.co/ybelkada).
The original code can be found [here](https://github.com/salesforce/BLIP).

## Resources

- [Jupyter notebook](https://github.com/huggingface/notebooks/blob/main/examples/image_captioning_blip.ipynb) on how to fine-tune BLIP for image captioning on a custom dataset
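
A minimal captioning-inference sketch to pair with the linked notebook (assuming the `Salesforce/blip-image-captioning-base` checkpoint and a sample COCO image; the notebook itself covers the fine-tuning loop):

```python
import requests
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Load a sample image (any RGB image works here).
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Unconditional captioning: only pixel values are passed.
inputs = processor(images=image, return_tensors="pt")
out = model.generate(**inputs)
print(processor.decode(out[0], skip_special_tokens=True))
```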


## BlipConfig

55 changes: 43 additions & 12 deletions src/transformers/models/blip/modeling_blip.py
@@ -1014,6 +1014,7 @@ def forward(
encoder_hidden_states=image_embeds,
labels=labels,
return_dict=return_dict,
reduction="mean",
)

if not return_dict:
@@ -1125,14 +1126,27 @@ def __init__(self, config: BlipConfig):
self.text_decoder = BlipTextLMHeadModel(config.text_config)

self.decoder_pad_token_id = config.text_config.pad_token_id
self.decoder_bos_token_id = config.text_config.bos_token_id
self.decoder_start_token_id = config.text_config.bos_token_id

# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding

# Adapted from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right
def _shift_right(self, input_ids):
pad_token_id = self.decoder_pad_token_id

shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = self.decoder_start_token_id

# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
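# Example (hypothetical ids, e.g. bos=30522, pad=0):
#   labels            = [[23, 45, -100, -100]]
#   shifted_input_ids = [[30522, 23, 45, 0]]  # the -100 that survives the shift becomes pad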

return shifted_input_ids

@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
def forward(
@@ -1168,8 +1182,14 @@ def forward(

>>> outputs = model(**inputs)
```"""
if labels is None and decoder_input_ids is None:
raise ValueError(
"Either `decoder_input_ids` or `labels` should be passed when calling `forward` with"
" `BlipForQuestionAnswering`. If you are training the model, make sure that `labels` is passed. If you"
" are using the model for inference, make sure that `decoder_input_ids` is passed."
)

return_dict = return_dict if return_dict is not None else self.config.use_return_dict
batch_size = input_ids.shape[0]

vision_outputs = self.vision_model(
pixel_values=pixel_values,
@@ -1191,11 +1211,11 @@

question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state

if decoder_input_ids is None:
decoder_input_ids = torch.LongTensor([self.decoder_bos_token_id]).repeat((batch_size, 1))

if labels is None:
labels = decoder_input_ids.masked_fill(decoder_input_ids == self.decoder_pad_token_id, -100)
if labels is not None and decoder_input_ids is None:
# get decoder inputs from shifting lm labels to the right - this is used in training mode
decoder_input_ids = self._shift_right(labels)
# replace possible -100 values in labels by `pad_token_id`
labels = labels.masked_fill(labels == self.decoder_pad_token_id, -100)

answer_output = self.text_decoder(
input_ids=decoder_input_ids,
@@ -1204,10 +1224,13 @@
encoder_attention_mask=attention_mask,
labels=labels,
return_dict=return_dict,
reduction="none",
reduction="mean",
)

decoder_loss = answer_output.loss.mean() if return_dict else answer_output[0].mean()
if labels is not None:
decoder_loss = answer_output.loss.mean() if return_dict else answer_output[0].mean()
else:
decoder_loss = None

if not return_dict:
outputs = (decoder_loss, image_embeds, vision_outputs[0]) + vision_outputs[2:]
@@ -1288,7 +1311,7 @@ def generate(
question_attention_mask = torch.ones(question_embeds.size()[:-1], dtype=torch.long).to(question_embeds.device)

bos_ids = torch.full(
(question_embeds.size(0), 1), fill_value=self.decoder_bos_token_id, device=question_embeds.device
(question_embeds.size(0), 1), fill_value=self.decoder_start_token_id, device=question_embeds.device
)

outputs = self.text_decoder.generate(
@@ -1330,8 +1353,16 @@ def __init__(self, config: BlipConfig):
# image text matching head
self.itm_head = nn.Linear(config.text_config.hidden_size, 2)

self.decoder_pad_token_id = config.text_config.pad_token_id
self.decoder_bos_token_id = config.text_config.bos_token_id
self.decoder_pad_token_id = (
config.text_config.pad_token_id
if not hasattr(config, "decoder_pad_token_id")
else config.decoder_pad_token_id
)
self.decoder_start_token_id = (
config.text_config.bos_token_id
if not hasattr(config, "decoder_start_token_id")
else config.decoder_start_token_id
)

# Initialize weights and apply final processing
self.post_init()
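For context, a minimal usage sketch of the contract this diff enforces for `BlipForQuestionAnswering` (pass `labels` when training so decoder inputs are derived via `_shift_right`; use `generate` or explicit `decoder_input_ids` for inference), assuming the `Salesforce/blip-vqa-base` checkpoint and a sample COCO image:

```python
import requests
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForQuestionAnswering

processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Training-style call: only `labels` is passed; decoder inputs are built internally
# by shifting the labels to the right (see `_shift_right` above).
inputs = processor(images=image, text="how many cats are in the picture?", return_tensors="pt")
labels = processor(text="two", return_tensors="pt").input_ids
outputs = model(**inputs, labels=labels)
print(outputs.loss)  # scalar loss, mean-reduced per the reduction="mean" change

# Inference: no `labels`; `generate` seeds the decoder with the start token.
with torch.no_grad():
    answer_ids = model.generate(**inputs)
print(processor.decode(answer_ids[0], skip_special_tokens=True))
```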
2 changes: 1 addition & 1 deletion src/transformers/models/blip/modeling_blip_text.py
@@ -731,7 +731,7 @@ def forward(
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)))
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length))).to(device)

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
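For context, the one-line change above creates the default attention mask on the same device as the rest of the inputs. A standalone sketch of that pattern (illustrative only, not the model's actual code path):

```python
import torch

def default_attention_mask(batch_size: int, seq_length: int, device: torch.device) -> torch.Tensor:
    # Build the all-ones mask directly on the target device so that a GPU forward
    # pass without an explicit attention_mask does not mix CPU and GPU tensors.
    return torch.ones((batch_size, seq_length), device=device)

# Example: the mask lives on the same device as the embeddings that will consume it.
mask = default_attention_mask(2, 16, torch.device("cuda" if torch.cuda.is_available() else "cpu"))
```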