-
Notifications
You must be signed in to change notification settings - Fork 33.6k
Add BridgeTowerForContrastiveLearning #21964
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
488802d
08fa3c0
c474659
58b334f
f922a13
6ffb05b
be2be4a
cc34c2b
c9984da
27d076a
19b5a79
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |
| import torch.utils.checkpoint | ||
| from torch import nn | ||
| from torch.nn import CrossEntropyLoss | ||
| from torch.nn.functional import normalize | ||
|
|
||
| from ...activations import ACT2FN, QuickGELUActivation | ||
| from ...modeling_outputs import ( | ||
|
|
@@ -57,6 +58,10 @@ | |
| ] | ||
|
|
||
|
|
||
def contrastive_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """
    Compute the cross-entropy between similarity `logits` and target `labels`.

    Args:
        logits (`torch.Tensor`): Unnormalized scores, forwarded unchanged to `nn.functional.cross_entropy`.
        labels (`torch.Tensor`): Targets, forwarded unchanged to `nn.functional.cross_entropy`.

    Returns:
        `torch.Tensor`: The loss returned by `nn.functional.cross_entropy` with its default reduction.
    """
    loss = nn.functional.cross_entropy(input=logits, target=labels)
    return loss
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since it's just one line, let's not define a function for this and just use cross entropy. |
||
|
|
||
|
|
||
| BRIDGETOWER_START_DOCSTRING = r""" | ||
| This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ subclass. Use | ||
| it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and | ||
|
|
@@ -144,7 +149,6 @@ class BridgeTowerModelOutput(ModelOutput): | |
| hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): | ||
| Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + | ||
| one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. | ||
|
|
||
| Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. | ||
| attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): | ||
| Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, | ||
|
|
@@ -161,6 +165,36 @@ class BridgeTowerModelOutput(ModelOutput): | |
| attentions: Optional[Tuple[torch.FloatTensor]] = None | ||
|
|
||
|
|
||
@dataclass
class BridgeTowerITCOutput(ModelOutput):
    """
    Output type of the BridgeTower image-text contrastive (ITC) model.

    Args:
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, config.contrastive_hidden_size)`, *optional*):
            L2-normalized text embeddings obtained by applying the text ITC head to the first token of the last text
            hidden state.
        image_embeds (`torch.FloatTensor` of shape `(batch_size, config.contrastive_hidden_size)`, *optional*):
            L2-normalized image embeddings obtained by applying the image ITC head to the first token of the
            transformed last image hidden state.
        cross_embeds (`torch.FloatTensor` of shape `(batch_size, config.contrastive_hidden_size)`, *optional*):
            L2-normalized cross-modal embeddings obtained by applying the cross-modal ITC head to the pooler output.
        logits (`torch.FloatTensor` of shape `(batch_size, 3, config.contrastive_hidden_size)`):
            `text_embeds`, `image_embeds` and `cross_embeds` stacked, in that order, along the second-to-last
            dimension.
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            ITC loss.
    """

    # Field order is part of the public interface (positional construction / tuple conversion); do not reorder.
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # The three embeddings are single tensors, not tuples of tensors.
    text_embeds: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    cross_embeds: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    loss: Optional[torch.FloatTensor] = None
|
|
||
|
|
||
| class BridgeTowerResidualAttention(nn.Module): | ||
| def __init__(self, config): | ||
| super().__init__() | ||
|
|
@@ -1698,3 +1732,136 @@ def forward( | |
| hidden_states=outputs.hidden_states, | ||
| attentions=outputs.attentions, | ||
| ) | ||
|
|
||
|
|
||
class BridgeTowerITCHead(nn.Module):
    """Single linear projection from `hidden_size` features to `embed_size` features."""

    def __init__(self, hidden_size, embed_size):
        super().__init__()
        # Attribute name `fc` is kept: it determines the parameter names in the state dict.
        self.fc = nn.Linear(hidden_size, embed_size)

    def forward(self, x):
        """Apply the linear projection to `x` and return the result."""
        return self.fc(x)
|
|
||
|
|
||
@add_start_docstrings(
    """
    BridgeTower Model with an image-text contrastive head on top as done during pretraining.
    """,
    BRIDGETOWER_START_DOCSTRING,
)
class BridgeTowerForITC(BridgeTowerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bridgetower = BridgeTowerModel(config)

        # One projection head per stream. The cross-modal head consumes the pooler output, which is
        # `2 * hidden_size` wide.
        self.itc_text_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)
        self.itc_image_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)
        self.itc_cross_modal_head = BridgeTowerITCHead(config.hidden_size * 2, config.contrastive_hidden_size)

        # Learnable scalar; it is exponentiated in `forward` before scaling the similarity logits.
        self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[BridgeTowerITCOutput, Tuple[torch.FloatTensor]]:
        r"""
        labels (`torch.LongTensor`, *optional*):
            When provided, the image-text contrastive loss is computed. Only `len(labels)` and `labels.device` are
            used: the targets are rebuilt as `torch.arange(len(labels))`, i.e. the i-th text is paired with the i-th
            image / cross-modal embedding of the batch.
        Returns:

        Raises:
            ValueError: If hidden states are not requested, either through `output_hidden_states=True` or through
            `config.output_hidden_states`. The contrastive heads are computed from the hidden states.

        Examples:

        ```python
        >>> from transformers import BridgeTowerProcessor, BridgeTowerForITC
        >>> import requests
        >>> from PIL import Image

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> texts = "An image of two cats chilling on a couch"

        >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
        >>> model = BridgeTowerForITC.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
        >>> inputs = processor(image, texts, return_tensors="pt")
        >>> outputs = model(**inputs, output_hidden_states=True)
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # Raise instead of assert: asserts are stripped under `python -O`, and a ValueError tells the
        # caller what to change.
        if not output_hidden_states:
            raise ValueError("output_hidden_states should be set to True for BridgeTowerForITC")

        # Always request a ModelOutput from the backbone so the fields below can be read by name,
        # whatever `return_dict` the caller asked for. The caller's choice is applied on return.
        outputs = self.bridgetower(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            image_embeds=image_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        pooler_output = outputs.pooler_output
        # The backbone returns (text, image, cross-modal) hidden states; the cross-modal ones are unused here.
        hidden_states_txt, hidden_states_img, _ = outputs.hidden_states

        last_text_hidden = hidden_states_txt[-1]
        last_image_hidden = hidden_states_img[-1]

        image_hidden_with_ln = self.bridgetower.vision_model.visual.forward_post(last_image_hidden)
        token_type_embeddings = self.bridgetower.token_type_embeddings
        # Token type id 1 is looked up once and broadcast over the image hidden states.
        image_token_type_embeddings = token_type_embeddings(
            torch.full((1,), 1, dtype=torch.long, device=token_type_embeddings.weight.device)
        ).expand_as(image_hidden_with_ln)

        image_hidden = self.bridgetower.cross_modal_image_transform(image_hidden_with_ln) + image_token_type_embeddings

        # L2-normalized features. `normalize` is called through `nn.functional` so its origin is explicit.
        text_embeds = nn.functional.normalize(self.itc_text_head(last_text_hidden[:, 0, :]), dim=-1, p=2)
        image_embeds = nn.functional.normalize(self.itc_image_head(image_hidden[:, 0, :]), dim=-1, p=2)
        cross_embeds = nn.functional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2)

        logits = torch.stack([text_embeds, image_embeds, cross_embeds], dim=-2)

        itc_loss = None

        if labels is not None:
            # The similarity matrices are only needed for the loss, so they are built only when it is requested.
            logit_scale = self.logit_scale.exp()
            logits_text_to_image = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
            logits_text_to_cross = torch.matmul(text_embeds, cross_embeds.t()) * logit_scale

            # In-batch contrastive targets: row i should score highest against column i.
            targets = torch.arange(len(labels), device=labels.device)
            text_to_image_loss = nn.functional.cross_entropy(logits_text_to_image, targets)
            text_to_cross_loss = nn.functional.cross_entropy(logits_text_to_cross, targets)
            itc_loss = (text_to_image_loss + text_to_cross_loss) / 2.0

        if not return_dict:
            # Wrap the tensor in a 1-tuple; `tuple(logits)` would unstack it along the batch dimension.
            output = (logits,)
            return ((itc_loss,) + output) if itc_loss is not None else output

        return BridgeTowerITCOutput(
            attentions=outputs.attentions,
            hidden_states=outputs.hidden_states,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            cross_embeds=cross_embeds,
            logits=logits,
            loss=itc_loss,
        )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As said below, let's not import this and use it via
nn.functional.normalize.