Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,6 +1228,7 @@
[
"BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST",
"BridgeTowerForImageAndTextRetrieval",
"BridgeTowerForITC",
"BridgeTowerForMaskedLM",
"BridgeTowerModel",
"BridgeTowerPreTrainedModel",
Expand Down Expand Up @@ -4700,6 +4701,7 @@
from .models.bridgetower import (
BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST,
BridgeTowerForImageAndTextRetrieval,
BridgeTowerForITC,
BridgeTowerForMaskedLM,
BridgeTowerModel,
BridgeTowerPreTrainedModel,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/bridgetower/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
_import_structure["modeling_bridgetower"] = [
"BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST",
"BridgeTowerForImageAndTextRetrieval",
"BridgeTowerForITC",
"BridgeTowerForMaskedLM",
"BridgeTowerModel",
"BridgeTowerPreTrainedModel",
Expand Down Expand Up @@ -75,6 +76,7 @@
from .modeling_bridgetower import (
BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST,
BridgeTowerForImageAndTextRetrieval,
BridgeTowerForITC,
BridgeTowerForMaskedLM,
BridgeTowerModel,
BridgeTowerPreTrainedModel,
Expand Down
169 changes: 168 additions & 1 deletion src/transformers/models/bridgetower/modeling_bridgetower.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn.functional import normalize

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As said below, let's not import this and use it via nn.functional.normalize.


from ...activations import ACT2FN, QuickGELUActivation
from ...modeling_outputs import (
Expand Down Expand Up @@ -57,6 +58,10 @@
]


def contrastive_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
return nn.functional.cross_entropy(logits, labels)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it's just one line, let's not define a function for this and just use cross entropy.



BRIDGETOWER_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ subclass. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
Expand Down Expand Up @@ -144,7 +149,6 @@ class BridgeTowerModelOutput(ModelOutput):
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
Expand All @@ -161,6 +165,36 @@ class BridgeTowerModelOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class BridgeTowerITCOutput(ModelOutput):
"""
Args:
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
text_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`):
The text embeddings obtained by applying the projection layer to the pooler_output.
image_embeds
cross_embeds
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
ITC loss.
"""

attentions: Optional[Tuple[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
text_embeds: Optional[Tuple[torch.FloatTensor]] = None
image_embeds: Optional[Tuple[torch.FloatTensor]] = None
cross_embeds: Optional[Tuple[torch.FloatTensor]] = None
logits: Optional[torch.FloatTensor] = None
loss: Optional[torch.FloatTensor] = None


class BridgeTowerResidualAttention(nn.Module):
def __init__(self, config):
super().__init__()
Expand Down Expand Up @@ -1698,3 +1732,136 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


class BridgeTowerITCHead(nn.Module):
def __init__(self, hidden_size, embed_size):
super().__init__()
self.fc = nn.Linear(hidden_size, embed_size)

def forward(self, x):
x = self.fc(x)
return x


@add_start_docstrings(
"""
BridgeTower Model with a image-text contrastive head on top as done during pretraining.
""",
BRIDGETOWER_START_DOCSTRING,
)
class BridgeTowerForITC(BridgeTowerPreTrainedModel):
def __init__(self, config):
super().__init__(config)

self.bridgetower = BridgeTowerModel(config)

self.itc_text_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)
self.itc_image_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)
self.itc_cross_modal_head = BridgeTowerITCHead(config.hidden_size * 2, config.contrastive_hidden_size)

self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value)
# Initialize weights and apply final processing
self.post_init()

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
image_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
The pairs with 0 will be skipped for calculation.
Returns:

Examples:

```python
>>> from transformers import BridgeTowerProcessor, BridgeTowerForITC
>>> import requests
>>> from PIL import Image

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> texts = "An image of two cats chilling on a couch"

>>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
>>> model = BridgeTowerForITC.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
>>> outputs = model(**inputs, output_hidden_states=True)
```"""
assert output_hidden_states, 'output_hidden_states should be set to True for BridgeTowerForITC'

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No new asserts in the code base, please use a test and raise an appropriate error.

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.bridgetower(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
pixel_values=pixel_values,
pixel_mask=pixel_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
image_embeds=image_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

pooler_output = outputs.pooler_output if return_dict else outputs[2]
hidden_states_txt, hidden_states_img, hidden_states_cross_modal = outputs.hidden_states

text_embeds = hidden_states_txt[-1]
image_embeds = hidden_states_img[-1]

image_embeds_with_ln = self.bridgetower.vision_model.visual.forward_post(image_embeds)
image_token_type_embeddings = self.bridgetower.token_type_embeddings(
torch.full((1,), 1, dtype=torch.long, device=self.bridgetower.token_type_embeddings.weight.device)
).expand_as(image_embeds_with_ln)

image_embeds = (
self.bridgetower.cross_modal_image_transform(image_embeds_with_ln)
+ image_token_type_embeddings
)

# normalized features
text_embeds = normalize(self.itc_text_head(text_embeds[:,0,:]), dim=-1, p=2)
image_embeds = normalize(self.itc_image_head(image_embeds[:,0,:]), dim=-1, p=2)
cross_embeds = normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use nn.functional.normalize here, so it's clear to the user where the normalize comes from.


logits = torch.stack([text_embeds, image_embeds, cross_embeds], dim=-2)

logit_scale = self.logit_scale.exp()
logits_text_to_image = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
logits_text_to_cross = torch.matmul(text_embeds, cross_embeds.t()) * logit_scale

itc_loss = None

if labels is not None:
labels = torch.arange(len(labels), device=labels.device)
text_to_image_loss = contrastive_loss(logits_text_to_image, labels)
text_to_cross_loss = contrastive_loss(logits_text_to_cross, labels)
itc_loss = (text_to_image_loss + text_to_cross_loss) / 2.0

if not return_dict:
output = tuple(logits)
return ((itc_loss,) + output) if itc_loss is not None else output

return BridgeTowerITCOutput(
attentions=outputs.attentions,
hidden_states=outputs.hidden_states,
text_embeds=text_embeds,
image_embeds=image_embeds,
cross_embeds=cross_embeds,
logits=logits,
loss=itc_loss
)
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1335,6 +1335,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class BridgeTowerForITC(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class BridgeTowerForMaskedLM(metaclass=DummyObject):
_backends = ["torch"]

Expand Down