Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions src/transformers/models/blip/modeling_blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,17 +1176,30 @@ def forward(

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = "How many cats are in the picture?"

>>> # training
>>> text = "How many cats are in the picture?"
>>> label = "2"
>>> inputs = processor(images=image, text=text, return_tensors="pt")
>>> labels = processor(text=label, return_tensors="pt").input_ids

>>> inputs["labels"] = labels
>>> outputs = model(**inputs)
>>> loss = outputs.loss
>>> loss.backward()

>>> # inference
>>> text = "How many cats are in the picture?"
>>> inputs = processor(images=image, text=text, return_tensors="pt")
>>> outputs = model.generate(**inputs)
>>> print(processor.decode(outputs[0], skip_special_tokens=True))
2
```"""
if labels is None and decoder_input_ids is None:
raise ValueError(
"Either `decoder_input_ids` or `labels` should be passed when calling `forward` with"
" `BlipForQuestionAnswering`. if you are training the model make sure that `labels` is passed, if you"
" are using the model for inference make sure that `decoder_input_ids` is passed."
" are using the model for inference make sure that `decoder_input_ids` is passed or call `generate`"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No for the change in this PR, but if you are training should be If you are training. And the second part should be , and if you ... or . If you ...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let @NielsRogge for a second eye 👀 on this PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes agree here!

)

return_dict = return_dict if return_dict is not None else self.config.use_return_dict
Expand Down Expand Up @@ -1392,8 +1405,8 @@ def forward(
>>> import requests
>>> from transformers import BlipProcessor, BlipForImageTextRetrieval

>>> model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base")
>>> processor = BlipProcessor.from_pretrained("Salesforce/blip-itm-base")
>>> model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
>>> processor = BlipProcessor.from_pretrained("Salesforce/blip-itm-base-coco")

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
Expand Down