From 488802db185168bf90997fccc4c0ff2b236687b2 Mon Sep 17 00:00:00 2001 From: Anahita Bhiwandiwalla Date: Sun, 5 Mar 2023 22:50:46 -0800 Subject: [PATCH 1/5] Add BridgeTower for ITC --- src/transformers/__init__.py | 2 + .../models/bridgetower/__init__.py | 2 + .../bridgetower/modeling_bridgetower.py | 169 +++++++++++++++++- src/transformers/utils/dummy_pt_objects.py | 7 + 4 files changed, 179 insertions(+), 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 6514acd20389..40ab5371000b 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1212,6 +1212,7 @@ [ "BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST", "BridgeTowerForImageAndTextRetrieval", + "BridgeTowerForITC", "BridgeTowerForMaskedLM", "BridgeTowerModel", "BridgeTowerPreTrainedModel", @@ -4669,6 +4670,7 @@ from .models.bridgetower import ( BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST, BridgeTowerForImageAndTextRetrieval, + BridgeTowerForITC, BridgeTowerForMaskedLM, BridgeTowerModel, BridgeTowerPreTrainedModel, diff --git a/src/transformers/models/bridgetower/__init__.py b/src/transformers/models/bridgetower/__init__.py index 7058fffa529e..eb361f19803d 100644 --- a/src/transformers/models/bridgetower/__init__.py +++ b/src/transformers/models/bridgetower/__init__.py @@ -43,6 +43,7 @@ _import_structure["modeling_bridgetower"] = [ "BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST", "BridgeTowerForImageAndTextRetrieval", + "BridgeTowerForITC", "BridgeTowerForMaskedLM", "BridgeTowerModel", "BridgeTowerPreTrainedModel", @@ -75,6 +76,7 @@ from .modeling_bridgetower import ( BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST, BridgeTowerForImageAndTextRetrieval, + BridgeTowerForITC, BridgeTowerForMaskedLM, BridgeTowerModel, BridgeTowerPreTrainedModel, diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 1fbc85ad314f..6c12d2213446 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -23,6 +23,7 @@ import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss +from torch.nn.functional import normalize from ...activations import ACT2FN, QuickGELUActivation from ...modeling_outputs import ( @@ -57,6 +58,10 @@ ] +def contrastive_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + return nn.functional.cross_entropy(logits, labels) + + BRIDGETOWER_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and @@ -144,7 +149,6 @@ class BridgeTowerModelOutput(ModelOutput): hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, @@ -161,6 +165,36 @@ class BridgeTowerModelOutput(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None +@dataclass +class BridgeTowerITCOutput(ModelOutput): + """ + Args: + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + text_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + image_embeds + cross_embeds + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + ITC loss. + """ + + attentions: Optional[Tuple[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + text_embeds: Optional[Tuple[torch.FloatTensor]] = None + image_embeds: Optional[Tuple[torch.FloatTensor]] = None + cross_embeds: Optional[Tuple[torch.FloatTensor]] = None + logits: Optional[torch.FloatTensor] = None + loss: Optional[torch.FloatTensor] = None + + class BridgeTowerResidualAttention(nn.Module): def __init__(self, config): super().__init__() @@ -1698,3 +1732,136 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +class BridgeTowerITCHead(nn.Module): + def __init__(self, hidden_size, embed_size): + super().__init__() + self.fc = nn.Linear(hidden_size, embed_size) + + def forward(self, x): + x = self.fc(x) + return x + + +@add_start_docstrings( + """ + BridgeTower Model with a image-text contrastive head on top as done during pretraining. + """, + BRIDGETOWER_START_DOCSTRING, +) +class BridgeTowerForITC(BridgeTowerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bridgetower = BridgeTowerModel(config) + + self.itc_text_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size) + self.itc_image_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size) + self.itc_cross_modal_head = BridgeTowerITCHead(config.hidden_size * 2, config.contrastive_hidden_size) + + self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value) + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + image_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + ) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): + Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match. + The pairs with 0 will be skipped for calculation. + Returns: + + Examples: + + ```python + >>> from transformers import BridgeTowerProcessor, BridgeTowerForITC + >>> import requests + >>> from PIL import Image + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> texts = "An image of two cats chilling on a couch" + + >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc") + >>> model = BridgeTowerForITC.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc") + >>> outputs = model(**inputs, output_hidden_states=True) + ```""" + assert output_hidden_states, 'output_hidden_states should be set to True for BridgeTowerForITC' + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bridgetower( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + pixel_values=pixel_values, + pixel_mask=pixel_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + image_embeds=image_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooler_output = outputs.pooler_output if return_dict else outputs[2] + hidden_states_txt, hidden_states_img, hidden_states_cross_modal = outputs.hidden_states + + text_embeds = hidden_states_txt[-1] + image_embeds = hidden_states_img[-1] + + image_embeds_with_ln = self.bridgetower.vision_model.visual.forward_post(image_embeds) + image_token_type_embeddings = self.bridgetower.token_type_embeddings( + torch.full((1,), 1, dtype=torch.long, device=self.bridgetower.token_type_embeddings.weight.device) + ).expand_as(image_embeds_with_ln) + + image_embeds = ( + self.bridgetower.cross_modal_image_transform(image_embeds_with_ln) + + image_token_type_embeddings + ) + + # normalized features + text_embeds = normalize(self.itc_text_head(text_embeds[:,0,:]), dim=-1, p=2) + image_embeds = normalize(self.itc_image_head(image_embeds[:,0,:]), dim=-1, p=2) + cross_embeds = normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2) + + logits = torch.stack([text_embeds, image_embeds, cross_embeds], dim=-2) + + logit_scale = self.logit_scale.exp() + logits_text_to_image = torch.matmul(text_embeds, image_embeds.t()) * logit_scale + logits_text_to_cross = torch.matmul(text_embeds, cross_embeds.t()) * logit_scale + + itc_loss = None + + if labels is not None: + labels = torch.arange(len(labels), device=labels.device) + text_to_image_loss = contrastive_loss(logits_text_to_image, labels) + text_to_cross_loss = contrastive_loss(logits_text_to_cross, labels) + itc_loss = (text_to_image_loss + text_to_cross_loss) / 2.0 + + if not return_dict: + output = tuple(logits) + return ((itc_loss,) + output) if itc_loss is not None else output + + return BridgeTowerITCOutput( + attentions=outputs.attentions, + hidden_states=outputs.hidden_states, + text_embeds=text_embeds, + image_embeds=image_embeds, + cross_embeds=cross_embeds, + logits=logits, + loss=itc_loss + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index d520c7dd1bee..5f0da55184e3 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1304,6 +1304,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class BridgeTowerForITC(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class BridgeTowerForMaskedLM(metaclass=DummyObject): _backends = ["torch"] From 58b334fc014b6ec449529adbf5f4a11f6f57f85d Mon Sep 17 00:00:00 2001 From: Anahita Bhiwandiwalla Date: Mon, 6 Mar 2023 11:41:55 -0800 Subject: [PATCH 2/5] Fix review feedback --- .../bridgetower/modeling_bridgetower.py | 50 +++++++++---------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 6c12d2213446..0ad9aac829df 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -23,7 +23,6 @@ import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss -from torch.nn.functional import normalize from ...activations import ACT2FN, QuickGELUActivation from ...modeling_outputs import ( @@ -58,10 +57,6 @@ ] -def contrastive_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - return nn.functional.cross_entropy(logits, labels) - - BRIDGETOWER_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and @@ -148,8 +143,8 @@ class BridgeTowerModelOutput(ModelOutput): token), respectively, after further processing through layers used for auxiliary pretraining tasks. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of + the model at the output of each layer plus the optional initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. @@ -168,22 +163,26 @@ class BridgeTowerModelOutput(ModelOutput): @dataclass class BridgeTowerITCOutput(ModelOutput): """ + Output type of ['BridgeTowerForITC'] + Args: attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of + the model at the output of each layer plus the optional initial embedding outputs. text_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`): The text embeddings obtained by applying the projection layer to the pooler_output. - image_embeds - cross_embeds + image_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + cross_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`): + The text-image cross-modal embeddings obtained by applying the projection layer to the pooler_output. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - ITC loss. + Image-text contrastive loss """ attentions: Optional[Tuple[torch.FloatTensor]] = None @@ -1746,7 +1745,7 @@ def forward(self, x): @add_start_docstrings( """ - BridgeTower Model with a image-text contrastive head on top as done during pretraining. + BridgeTower Model with a image-text contrastive head on top computing image-text contrastive loss. """, BRIDGETOWER_START_DOCSTRING, ) @@ -1764,6 +1763,8 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() + @add_start_docstrings_to_model_forward(BRIDGETOWER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BridgeTowerITCOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1775,10 +1776,10 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, image_embeds: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, + output_hidden_states: Optional[bool] = True, return_dict: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, - ) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]: + ) -> Union[BridgeTowerITCOutput, Tuple[torch.FloatTensor]]: r""" labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match. @@ -1800,7 +1801,6 @@ def forward( >>> model = BridgeTowerForITC.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc") >>> outputs = model(**inputs, output_hidden_states=True) ```""" - assert output_hidden_states, 'output_hidden_states should be set to True for BridgeTowerForITC' return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.bridgetower( @@ -1828,15 +1828,12 @@ def forward( torch.full((1,), 1, dtype=torch.long, device=self.bridgetower.token_type_embeddings.weight.device) ).expand_as(image_embeds_with_ln) - image_embeds = ( - self.bridgetower.cross_modal_image_transform(image_embeds_with_ln) - + image_token_type_embeddings - ) + image_embeds = self.bridgetower.cross_modal_image_transform(image_embeds_with_ln) + image_token_type_embeddings # normalized features - text_embeds = normalize(self.itc_text_head(text_embeds[:,0,:]), dim=-1, p=2) - image_embeds = normalize(self.itc_image_head(image_embeds[:,0,:]), dim=-1, p=2) - cross_embeds = normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2) + text_embeds = nn.funtional.normalize(self.itc_text_head(text_embeds[:, 0, :]), dim=-1, p=2) + image_embeds = nn.functional.normalize(self.itc_image_head(image_embeds[:, 0, :]), dim=-1, p=2) + cross_embeds = nn.funtional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2) logits = torch.stack([text_embeds, image_embeds, cross_embeds], dim=-2) @@ -1848,8 +1845,9 @@ def forward( if labels is not None: labels = torch.arange(len(labels), device=labels.device) - text_to_image_loss = contrastive_loss(logits_text_to_image, labels) - text_to_cross_loss = contrastive_loss(logits_text_to_cross, labels) + loss_fct = CrossEntropyLoss() + text_to_image_loss = loss_fct(logits_text_to_image, labels) + text_to_cross_loss = loss_fct(logits_text_to_cross, labels) itc_loss = (text_to_image_loss + text_to_cross_loss) / 2.0 if not return_dict: @@ -1863,5 +1861,5 @@ def forward( image_embeds=image_embeds, cross_embeds=cross_embeds, logits=logits, - loss=itc_loss + loss=itc_loss, ) From be2be4a5c095da8642207a2507bfbe97738421b9 Mon Sep 17 00:00:00 2001 From: Anahita Bhiwandiwalla Date: Tue, 7 Mar 2023 15:09:17 -0800 Subject: [PATCH 3/5] Rename BridgeTowerForITC, cleanup --- docs/source/en/model_doc/bridgetower.mdx | 27 +++++++++++ src/transformers/__init__.py | 4 +- .../models/bridgetower/__init__.py | 4 +- .../bridgetower/modeling_bridgetower.py | 46 +++++++++++-------- src/transformers/utils/dummy_pt_objects.py | 2 +- 5 files changed, 58 insertions(+), 25 deletions(-) diff --git a/docs/source/en/model_doc/bridgetower.mdx b/docs/source/en/model_doc/bridgetower.mdx index 87015877dc9c..9f7572f3122b 100644 --- a/docs/source/en/model_doc/bridgetower.mdx +++ b/docs/source/en/model_doc/bridgetower.mdx @@ -42,6 +42,28 @@ In principle, one can apply any visual, textual or cross-modal encoder in the pr The [`BridgeTowerProcessor`] wraps [`RobertaTokenizer`] and [`BridgeTowerImageProcessor`] into a single instance to both encode the text and prepare the images respectively. +The following example shows how to run contrastive learning using [`BridgeTowerProcessor`] and [`BridgeTowerForContrastiveLearning`]. +```python +>>> from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning +>>> import requests +>>> from PIL import Image + +>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" +>>> image = Image.open(requests.get(url, stream=True).raw) +>>> texts = ["An image of two cats chilling on a couch", "A football player scoring a goal"] + +>>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc") +>>> model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc") + +>>> # forward pass +>>> scores = dict() +>>> for text in texts: +... # prepare inputs +... encoding = processor(image, text, return_tensors="pt") +... outputs = model(**encoding) +... scores[text] = outputs +``` + The following example shows how to run image-text retrieval using [`BridgeTowerProcessor`] and [`BridgeTowerForImageAndTextRetrieval`]. ```python >>> from transformers import BridgeTowerProcessor, BridgeTowerForImageAndTextRetrieval @@ -128,6 +150,11 @@ Tips: [[autodoc]] BridgeTowerModel - forward +## BridgeTowerForContrastiveLearning + +[[autodoc]] BridgeTowerForContrastiveLearning + - forward + ## BridgeTowerForMaskedLM [[autodoc]] BridgeTowerForMaskedLM diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 5b06baa0c12f..0448cc723cd5 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1227,8 +1227,8 @@ _import_structure["models.bridgetower"].extend( [ "BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST", + "BridgeTowerForContrastiveLearning", "BridgeTowerForImageAndTextRetrieval", - "BridgeTowerForITC", "BridgeTowerForMaskedLM", "BridgeTowerModel", "BridgeTowerPreTrainedModel", @@ -4700,8 +4700,8 @@ ) from .models.bridgetower import ( BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST, + BridgeTowerForContrastiveLearning, BridgeTowerForImageAndTextRetrieval, - BridgeTowerForITC, BridgeTowerForMaskedLM, BridgeTowerModel, BridgeTowerPreTrainedModel, diff --git a/src/transformers/models/bridgetower/__init__.py b/src/transformers/models/bridgetower/__init__.py index eb361f19803d..cbd5bd4a366a 100644 --- a/src/transformers/models/bridgetower/__init__.py +++ b/src/transformers/models/bridgetower/__init__.py @@ -42,8 +42,8 @@ else: _import_structure["modeling_bridgetower"] = [ "BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST", + "BridgeTowerForContrastiveLearning", "BridgeTowerForImageAndTextRetrieval", - "BridgeTowerForITC", "BridgeTowerForMaskedLM", "BridgeTowerModel", "BridgeTowerPreTrainedModel", @@ -75,8 +75,8 @@ else: from .modeling_bridgetower import ( BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST, + BridgeTowerForContrastiveLearning, BridgeTowerForImageAndTextRetrieval, - BridgeTowerForITC, BridgeTowerForMaskedLM, BridgeTowerModel, BridgeTowerPreTrainedModel, diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 5a077fcad031..66b2bba59ab2 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -161,9 +161,9 @@ class BridgeTowerModelOutput(ModelOutput): @dataclass -class BridgeTowerITCOutput(ModelOutput): +class BridgeTowerContrastiveOutput(ModelOutput): """ - Output type of ['BridgeTowerForITC'] + Output type of ['BridgeTowerForContrastiveLearning'] Args: attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): @@ -1347,7 +1347,12 @@ def forward( if output_hidden_states: all_hidden_states_text += (text_embeds,) - image_embeds = self.vision_model.visual.forward_pre(pixel_values.type(self.vision_model.dtype)) + if image_embeds is None: + image_embeds = self.vision_model.visual.forward_pre(pixel_values.type(self.vision_model.dtype)) + else: + # Permute as BridgeTowerResidualAttention has batch_first=True + image_embeds = image_embeds.permute(1,0,2) + if output_hidden_states: all_hidden_states_image += (image_embeds,) @@ -1735,7 +1740,7 @@ def forward( ) -class BridgeTowerITCHead(nn.Module): +class BridgeTowerContrastiveHead(nn.Module): def __init__(self, hidden_size, embed_size): super().__init__() self.fc = nn.Linear(hidden_size, embed_size) @@ -1751,22 +1756,22 @@ def forward(self, x): """, BRIDGETOWER_START_DOCSTRING, ) -class BridgeTowerForITC(BridgeTowerPreTrainedModel): +class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel): def __init__(self, config): super().__init__(config) self.bridgetower = BridgeTowerModel(config) - self.itc_text_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size) - self.itc_image_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size) - self.itc_cross_modal_head = BridgeTowerITCHead(config.hidden_size * 2, config.contrastive_hidden_size) + self.itc_text_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size) + self.itc_image_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size) + self.itc_cross_modal_head = BridgeTowerContrastiveHead(config.hidden_size * 2, config.contrastive_hidden_size) self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value) # Initialize weights and apply final processing self.post_init() @add_start_docstrings_to_model_forward(BRIDGETOWER_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=BridgeTowerITCOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=BridgeTowerContrastiveOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1781,7 +1786,7 @@ def forward( output_hidden_states: Optional[bool] = True, return_dict: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, - ) -> Union[BridgeTowerITCOutput, Tuple[torch.FloatTensor]]: + ) -> Union[BridgeTowerContrastiveOutput, Tuple[torch.FloatTensor]]: r""" labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match. @@ -1791,7 +1796,7 @@ def forward( Examples: ```python - >>> from transformers import BridgeTowerProcessor, BridgeTowerForITC + >>> from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning >>> import requests >>> from PIL import Image @@ -1800,7 +1805,7 @@ def forward( >>> texts = "An image of two cats chilling on a couch" >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc") - >>> model = BridgeTowerForITC.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc") + >>> model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc") >>> outputs = model(**inputs, output_hidden_states=True) ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1833,30 +1838,31 @@ def forward( image_embeds = self.bridgetower.cross_modal_image_transform(image_embeds_with_ln) + image_token_type_embeddings # normalized features - text_embeds = nn.funtional.normalize(self.itc_text_head(text_embeds[:, 0, :]), dim=-1, p=2) + text_embeds = nn.functional.normalize(self.itc_text_head(text_embeds[:, 0, :]), dim=-1, p=2) image_embeds = nn.functional.normalize(self.itc_image_head(image_embeds[:, 0, :]), dim=-1, p=2) - cross_embeds = nn.funtional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2) + cross_embeds = nn.functional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2) logits = torch.stack([text_embeds, image_embeds, cross_embeds], dim=-2) logit_scale = self.logit_scale.exp() logits_text_to_image = torch.matmul(text_embeds, image_embeds.t()) * logit_scale logits_text_to_cross = torch.matmul(text_embeds, cross_embeds.t()) * logit_scale + logits_image_to_cross = torch.matmul(image_embeds, cross_embeds.t()) * logit_scale itc_loss = None if labels is not None: - labels = torch.arange(len(labels), device=labels.device) - loss_fct = CrossEntropyLoss() - text_to_image_loss = loss_fct(logits_text_to_image, labels) - text_to_cross_loss = loss_fct(logits_text_to_cross, labels) - itc_loss = (text_to_image_loss + text_to_cross_loss) / 2.0 + labels = torch.arange(len(labels), device=logits.device) + text_to_image_loss = nn.functional.cross_entropy(logits_text_to_image, labels) + text_to_cross_loss = nn.functional.cross_entropy(logits_text_to_cross, labels) + image_to_cross_loss = nn.functional.cross_entropy(logits_image_to_cross, labels) + itc_loss = (text_to_image_loss + text_to_cross_loss + image_to_cross_loss) / 3.0 if not return_dict: output = tuple(logits) return ((itc_loss,) + output) if itc_loss is not None else output - return BridgeTowerITCOutput( + return BridgeTowerContrastiveOutput( attentions=outputs.attentions, hidden_states=outputs.hidden_states, text_embeds=text_embeds, diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 3f7ec3deee0c..0baa6f5c27eb 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1335,7 +1335,7 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class BridgeTowerForITC(metaclass=DummyObject): +class BridgeTowerForContrastiveLearning(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): From cc34c2b4755c88cc81bf4bd9e81d9b04c65048a1 Mon Sep 17 00:00:00 2001 From: Anahita Bhiwandiwalla Date: Tue, 7 Mar 2023 15:25:24 -0800 Subject: [PATCH 4/5] Fix style and quality --- src/transformers/models/bridgetower/modeling_bridgetower.py | 2 +- src/transformers/utils/dummy_pt_objects.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 66b2bba59ab2..2b1976a433f1 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -1351,7 +1351,7 @@ def forward( image_embeds = self.vision_model.visual.forward_pre(pixel_values.type(self.vision_model.dtype)) else: # Permute as BridgeTowerResidualAttention has batch_first=True - image_embeds = image_embeds.permute(1,0,2) + image_embeds = image_embeds.permute(1, 0, 2) if output_hidden_states: all_hidden_states_image += (image_embeds,) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 0baa6f5c27eb..45c5b9d1f24e 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1328,14 +1328,14 @@ def __init__(self, *args, **kwargs): BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST = None -class BridgeTowerForImageAndTextRetrieval(metaclass=DummyObject): +class BridgeTowerForContrastiveLearning(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class BridgeTowerForContrastiveLearning(metaclass=DummyObject): +class BridgeTowerForImageAndTextRetrieval(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): From 19b5a79b8a4c188c6740b248e3057affe597d549 Mon Sep 17 00:00:00 2001 From: Tiep Le Date: Wed, 8 Mar 2023 01:28:05 -0800 Subject: [PATCH 5/5] implement tests --- .../bridgetower/modeling_bridgetower.py | 38 ++++++---- .../bridgetower/test_modeling_bridgetower.py | 76 +++++++++++++++++-- utils/check_repo.py | 1 + 3 files changed, 94 insertions(+), 21 deletions(-) diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 2b1976a433f1..f405407d7d9b 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -166,32 +166,32 @@ class BridgeTowerContrastiveOutput(ModelOutput): Output type of ['BridgeTowerForContrastiveLearning'] Args: - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of - the model at the output of each layer plus the optional initial embedding outputs. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). text_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`): The text embeddings obtained by applying the projection layer to the pooler_output. image_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`): The image embeddings obtained by applying the projection layer to the pooler_output. cross_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`): The text-image cross-modal embeddings obtained by applying the projection layer to the pooler_output. - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Image-text contrastive loss + Image-text contrastive loss. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of + the model at the output of each layer plus the optional initial embedding outputs. """ - attentions: Optional[Tuple[torch.FloatTensor]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None + logits: torch.FloatTensor = None text_embeds: Optional[Tuple[torch.FloatTensor]] = None image_embeds: Optional[Tuple[torch.FloatTensor]] = None cross_embeds: Optional[Tuple[torch.FloatTensor]] = None - logits: Optional[torch.FloatTensor] = None loss: Optional[torch.FloatTensor] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None class BridgeTowerResidualAttention(nn.Module): @@ -1476,7 +1476,11 @@ def forward( all_hidden_states = (all_hidden_states_text, all_hidden_states_image, all_hidden_states_cross) if not return_dict: - return tuple(v for v in [text_features, image_features, cls_features] if v is not None) + return tuple( + v + for v in [text_features, image_features, cls_features, all_hidden_states, all_self_attentions] + if v is not None + ) return BridgeTowerModelOutput( text_features=text_features, @@ -1820,12 +1824,14 @@ def forward( inputs_embeds=inputs_embeds, image_embeds=image_embeds, output_attentions=output_attentions, - output_hidden_states=output_hidden_states, + output_hidden_states=True, return_dict=return_dict, ) pooler_output = outputs.pooler_output if return_dict else outputs[2] - hidden_states_txt, hidden_states_img, hidden_states_cross_modal = outputs.hidden_states + hidden_states_txt, hidden_states_img, hidden_states_cross_modal = ( + outputs.hidden_states if return_dict else outputs[3] + ) text_embeds = hidden_states_txt[-1] image_embeds = hidden_states_img[-1] diff --git a/tests/models/bridgetower/test_modeling_bridgetower.py b/tests/models/bridgetower/test_modeling_bridgetower.py index afe3febb6983..9e70c4cdcd44 100644 --- a/tests/models/bridgetower/test_modeling_bridgetower.py +++ b/tests/models/bridgetower/test_modeling_bridgetower.py @@ -24,14 +24,25 @@ from transformers.utils import cached_property from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask +from ...test_modeling_common import ( + ModelTesterMixin, + _config_zero_init, + floats_tensor, + ids_tensor, + random_attention_mask, +) from ...test_pipeline_mixin import PipelineTesterMixin if is_torch_available(): import torch - from transformers import BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM, BridgeTowerModel + from transformers import ( + BridgeTowerForContrastiveLearning, + BridgeTowerForImageAndTextRetrieval, + BridgeTowerForMaskedLM, + BridgeTowerModel, + ) from transformers.models.bridgetower.modeling_bridgetower import BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_10 else: @@ -65,6 +76,8 @@ def __init__( text_config=None, vision_config=None, image_size=288, + contrastive_hidden_size=512, + logit_scale_init_value=2.6592, ): self.parent = parent self.share_cross_modal_transformer_layers = share_cross_modal_transformer_layers @@ -90,6 +103,8 @@ def __init__( self.is_training = False self.expected_num_hidden_layers = 32 self.output_hidden_states = output_hidden_states + self.contrastive_hidden_size = contrastive_hidden_size + self.logit_scale_init_value = logit_scale_init_value def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -118,6 +133,8 @@ def get_config(self): init_layernorm_from_vision_encoder=self.init_layernorm_from_vision_encoder, num_channels=self.num_channels, output_hidden_states=self.output_hidden_states, + contrastive_hidden_size=self.contrastive_hidden_size, + logit_scale_init_value=self.logit_scale_init_value, ) def create_and_check_model( @@ -189,7 +206,14 @@ def prepare_config_and_inputs_for_common(self): @unittest.skipIf(not is_torch_greater_or_equal_than_1_10, "BridgeTower is only available in torch v1.10+") class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (BridgeTowerModel, BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM) if is_torch_available() else () + ( + BridgeTowerModel, + BridgeTowerForImageAndTextRetrieval, + BridgeTowerForMaskedLM, + BridgeTowerForContrastiveLearning, + ) + if is_torch_available() + else () ) pipeline_model_mapping = {"feature-extraction": BridgeTowerModel} if is_torch_available() else {} @@ -347,6 +371,29 @@ def test_retain_grad_hidden_states_attentions(self): if self.has_attentions: self.assertIsNotNone(attentions.grad) + # override as the `logit_scale` parameter initilization is different for BRIDGE TOWER + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + if param.requires_grad: + if name == "logit_scale": + self.assertAlmostEqual( + param.data.item(), + config.logit_scale_init_value, + delta=1e-3, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + else: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + @unittest.skip(reason="""Bridge Tower does not have input/output embeddings. So this test is not applicable.""") def test_model_common_attributes(self): pass @@ -429,12 +476,31 @@ def test_masked_language_modeling(self): outputs = model(**inputs) self.assertAlmostEqual(outputs.loss.item(), 5.7373, places=4) + @slow + def test_constrastive_learning(self): + model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc").to( + torch_device + ) + model.eval() + processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc") + image = prepare_img() + text = "a bunch of cats laying on a tower." + inputs = processor(image, text, return_tensors="pt").to(torch_device) + with torch.no_grad(): + outputs = model(**inputs, output_hidden_states=True) + + # verify the logits + expected_shape = torch.Size([1, 3, 512]) + self.assertEqual(outputs.logits.shape, expected_shape) + @require_torch @unittest.skipIf(not is_torch_greater_or_equal_than_1_10, "BridgeTower is only available in torch v1.10+") class BridgeTowerModelTrainingTest(unittest.TestCase): all_training_supported_model_classes = ( - (BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM) if is_torch_available() else () + (BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM, BridgeTowerForContrastiveLearning) + if is_torch_available() + else () ) def setUp(self): @@ -445,7 +511,7 @@ def _prepare_inputs_for_training(self, model_class): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() if model_class == BridgeTowerForMaskedLM: inputs_dict["labels"] = inputs_dict["input_ids"] - elif model_class == BridgeTowerForImageAndTextRetrieval: + elif model_class == BridgeTowerForImageAndTextRetrieval or model_class == BridgeTowerForContrastiveLearning: inputs_dict["labels"] = ids_tensor([1], 2) return config, inputs_dict diff --git a/utils/check_repo.py b/utils/check_repo.py index af0237d38c8a..f16c4fb851bf 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -204,6 +204,7 @@ "Swin2SRForImageSuperResolution", "BridgeTowerForImageAndTextRetrieval", "BridgeTowerForMaskedLM", + "BridgeTowerForContrastiveLearning", "CLIPSegForImageSegmentation", "CLIPSegVisionModel", "CLIPSegTextModel",