Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions docs/source/en/model_doc/bridgetower.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,28 @@ In principle, one can apply any visual, textual or cross-modal encoder in the pr
The [`BridgeTowerProcessor`] wraps [`RobertaTokenizer`] and [`BridgeTowerImageProcessor`] into a single instance to both
encode the text and prepare the images respectively.

The following example shows how to run contrastive learning using [`BridgeTowerProcessor`] and [`BridgeTowerForContrastiveLearning`].
```python
>>> from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
>>> import requests
>>> from PIL import Image

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> texts = ["An image of two cats chilling on a couch", "A football player scoring a goal"]

>>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
>>> model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")

>>> # forward pass
>>> scores = dict()
>>> for text in texts:
... # prepare inputs
... encoding = processor(image, text, return_tensors="pt")
... outputs = model(**encoding)
... scores[text] = outputs
```

The following example shows how to run image-text retrieval using [`BridgeTowerProcessor`] and [`BridgeTowerForImageAndTextRetrieval`].
```python
>>> from transformers import BridgeTowerProcessor, BridgeTowerForImageAndTextRetrieval
Expand Down Expand Up @@ -128,6 +150,11 @@ Tips:
[[autodoc]] BridgeTowerModel
- forward

## BridgeTowerForContrastiveLearning

[[autodoc]] BridgeTowerForContrastiveLearning
- forward

## BridgeTowerForMaskedLM

[[autodoc]] BridgeTowerForMaskedLM
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,6 +1182,7 @@
_import_structure["models.bridgetower"].extend(
[
"BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST",
"BridgeTowerForContrastiveLearning",
"BridgeTowerForImageAndTextRetrieval",
"BridgeTowerForMaskedLM",
"BridgeTowerModel",
Expand Down Expand Up @@ -4666,6 +4667,7 @@
)
from .models.bridgetower import (
BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST,
BridgeTowerForContrastiveLearning,
BridgeTowerForImageAndTextRetrieval,
BridgeTowerForMaskedLM,
BridgeTowerModel,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/bridgetower/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
else:
_import_structure["modeling_bridgetower"] = [
"BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST",
"BridgeTowerForContrastiveLearning",
"BridgeTowerForImageAndTextRetrieval",
"BridgeTowerForMaskedLM",
"BridgeTowerModel",
Expand Down Expand Up @@ -74,6 +75,7 @@
else:
from .modeling_bridgetower import (
BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST,
BridgeTowerForContrastiveLearning,
BridgeTowerForImageAndTextRetrieval,
BridgeTowerForMaskedLM,
BridgeTowerModel,
Expand Down
187 changes: 182 additions & 5 deletions src/transformers/models/bridgetower/modeling_bridgetower.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,8 @@ class BridgeTowerModelOutput(ModelOutput):
token), respectively, after further processing through layers used for auxiliary pretraining tasks.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Expand All @@ -161,6 +160,40 @@ class BridgeTowerModelOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class BridgeTowerContrastiveOutput(ModelOutput):
    """
    Output type of [`BridgeTowerForContrastiveLearning`].

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, 3, config.contrastive_hidden_size)`):
            The L2-normalized text, image and cross-modal embeddings, stacked in that order along the
            second-to-last dimension.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, config.contrastive_hidden_size)`, *optional*):
            The L2-normalized text embeddings obtained by applying the text projection head to the first token of
            the last text hidden state.
        image_embeds (`torch.FloatTensor` of shape `(batch_size, config.contrastive_hidden_size)`, *optional*):
            The L2-normalized image embeddings obtained by applying the image projection head to the first token of
            the transformed last image hidden state.
        cross_embeds (`torch.FloatTensor` of shape `(batch_size, config.contrastive_hidden_size)`, *optional*):
            The L2-normalized text-image cross-modal embeddings obtained by applying the cross-modal projection head
            to the pooler_output.
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Image-text contrastive loss.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
            the model at the output of each layer plus the optional initial embedding outputs.
    """

    # NOTE: field order is part of the interface (it defines the positional layout of the output); keep it stable.
    logits: torch.FloatTensor = None
    text_embeds: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    cross_embeds: Optional[torch.FloatTensor] = None
    loss: Optional[torch.FloatTensor] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


class BridgeTowerResidualAttention(nn.Module):
def __init__(self, config):
super().__init__()
Expand Down Expand Up @@ -1314,7 +1347,12 @@ def forward(
if output_hidden_states:
all_hidden_states_text += (text_embeds,)

image_embeds = self.vision_model.visual.forward_pre(pixel_values.type(self.vision_model.dtype))
if image_embeds is None:
image_embeds = self.vision_model.visual.forward_pre(pixel_values.type(self.vision_model.dtype))
else:
# Permute as BridgeTowerResidualAttention has batch_first=True
image_embeds = image_embeds.permute(1, 0, 2)

if output_hidden_states:
all_hidden_states_image += (image_embeds,)

Expand Down Expand Up @@ -1438,7 +1476,11 @@ def forward(
all_hidden_states = (all_hidden_states_text, all_hidden_states_image, all_hidden_states_cross)

if not return_dict:
return tuple(v for v in [text_features, image_features, cls_features] if v is not None)
return tuple(
v
for v in [text_features, image_features, cls_features, all_hidden_states, all_self_attentions]
if v is not None
)

return BridgeTowerModelOutput(
text_features=text_features,
Expand Down Expand Up @@ -1700,3 +1742,138 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


class BridgeTowerContrastiveHead(nn.Module):
    """Single linear projection from `hidden_size` features to `embed_size` features."""

    def __init__(self, hidden_size, embed_size):
        super().__init__()
        # The attribute name `fc` determines the parameter names of this module; do not rename it.
        self.fc = nn.Linear(in_features=hidden_size, out_features=embed_size)

    def forward(self, x):
        """Project `x` through the linear layer and return the result."""
        return self.fc(x)


@add_start_docstrings(
    """
    BridgeTower Model with an image-text contrastive head on top computing image-text contrastive loss.
    """,
    BRIDGETOWER_START_DOCSTRING,
)
class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bridgetower = BridgeTowerModel(config)

        # One projection head per embedding kind. The cross-modal head consumes the pooler output, which is why its
        # input width is `hidden_size * 2`.
        self.itc_text_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size)
        self.itc_image_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size)
        self.itc_cross_modal_head = BridgeTowerContrastiveHead(config.hidden_size * 2, config.contrastive_hidden_size)

        # Stored un-exponentiated; `forward` applies `exp()` before scaling the similarity scores.
        self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value)
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(BRIDGETOWER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BridgeTowerContrastiveOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = True,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[BridgeTowerContrastiveOutput, Tuple[torch.FloatTensor]]:
        r"""
        labels (`torch.LongTensor`, *optional*):
            When provided, the image-text contrastive loss is computed. Only the length of `labels` (its first
            dimension, expected to be the batch size) is used: the contrastive targets are always
            `torch.arange(len(labels))`, i.e. the i-th text is paired with the i-th image. The values stored in
            `labels` are ignored.
        Returns:

        Examples:

        ```python
        >>> from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
        >>> import requests
        >>> from PIL import Image

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> texts = "An image of two cats chilling on a couch"

        >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
        >>> model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")

        >>> inputs = processor(image, texts, return_tensors="pt")
        >>> outputs = model(**inputs, output_hidden_states=True)
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Hidden states are always requested from the base model, regardless of the `output_hidden_states` argument,
        # because the text and image embeddings below are read from the last hidden states.
        outputs = self.bridgetower(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            image_embeds=image_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )

        # In tuple mode the base model returns (text_features, image_features, cls_features, hidden_states, ...).
        pooler_output = outputs.pooler_output if return_dict else outputs[2]
        hidden_states_txt, hidden_states_img, hidden_states_cross_modal = (
            outputs.hidden_states if return_dict else outputs[3]
        )

        text_embeds = hidden_states_txt[-1]
        image_embeds = hidden_states_img[-1]

        image_embeds_with_ln = self.bridgetower.vision_model.visual.forward_post(image_embeds)
        # Token type id 1 is looked up once and broadcast over the whole image sequence.
        image_token_type_embeddings = self.bridgetower.token_type_embeddings(
            torch.full((1,), 1, dtype=torch.long, device=self.bridgetower.token_type_embeddings.weight.device)
        ).expand_as(image_embeds_with_ln)

        image_embeds = self.bridgetower.cross_modal_image_transform(image_embeds_with_ln) + image_token_type_embeddings

        # normalized features
        text_embeds = nn.functional.normalize(self.itc_text_head(text_embeds[:, 0, :]), dim=-1, p=2)
        image_embeds = nn.functional.normalize(self.itc_image_head(image_embeds[:, 0, :]), dim=-1, p=2)
        cross_embeds = nn.functional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2)

        logits = torch.stack([text_embeds, image_embeds, cross_embeds], dim=-2)

        itc_loss = None

        if labels is not None:
            # The pairwise similarity matrices are only needed for the loss, so they are built only here.
            logit_scale = self.logit_scale.exp()
            logits_text_to_image = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
            logits_text_to_cross = torch.matmul(text_embeds, cross_embeds.t()) * logit_scale
            logits_image_to_cross = torch.matmul(image_embeds, cross_embeds.t()) * logit_scale

            # The matching pair for row i is column i; the values of `labels` themselves are not used.
            targets = torch.arange(len(labels), device=logits.device)
            text_to_image_loss = nn.functional.cross_entropy(logits_text_to_image, targets)
            text_to_cross_loss = nn.functional.cross_entropy(logits_text_to_cross, targets)
            image_to_cross_loss = nn.functional.cross_entropy(logits_image_to_cross, targets)
            itc_loss = (text_to_image_loss + text_to_cross_loss + image_to_cross_loss) / 3.0

        if not return_dict:
            # Return the output fields themselves. `tuple(logits)` would instead iterate over the batch dimension
            # and yield one slice per sample. `outputs[3:]` carries the base model's hidden states (and attentions,
            # when present).
            output = (logits, text_embeds, image_embeds, cross_embeds) + outputs[3:]
            return ((itc_loss,) + output) if itc_loss is not None else output

        return BridgeTowerContrastiveOutput(
            attentions=outputs.attentions,
            hidden_states=outputs.hidden_states,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            cross_embeds=cross_embeds,
            logits=logits,
            loss=itc_loss,
        )
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,13 @@ def __init__(self, *args, **kwargs):
BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST = None


# Placeholder for `BridgeTowerForContrastiveLearning`, declared with the `DummyObject` metaclass and the
# "torch" backend. Constructing it only calls `requires_backends(self, ["torch"])`.
# NOTE(review): presumably this reports that the torch backend is missing — confirm against `requires_backends`.
class BridgeTowerForContrastiveLearning(metaclass=DummyObject):
_backends = ["torch"]

# Accepts any arguments so the call signature never fails before the backend check runs.
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class BridgeTowerForImageAndTextRetrieval(metaclass=DummyObject):
_backends = ["torch"]

Expand Down
Loading