diff --git a/docs/source/en/model_doc/bridgetower.mdx b/docs/source/en/model_doc/bridgetower.mdx
index 87015877dc9c..9f7572f3122b 100644
--- a/docs/source/en/model_doc/bridgetower.mdx
+++ b/docs/source/en/model_doc/bridgetower.mdx
@@ -42,6 +42,28 @@ In principle, one can apply any visual, textual or cross-modal encoder in the pr
 
 The [`BridgeTowerProcessor`] wraps [`RobertaTokenizer`] and [`BridgeTowerImageProcessor`] into a single instance to both encode the text and prepare the images respectively.
 
+The following example shows how to run contrastive learning using [`BridgeTowerProcessor`] and [`BridgeTowerForContrastiveLearning`].
+```python
+>>> from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
+>>> import requests
+>>> from PIL import Image
+
+>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+>>> image = Image.open(requests.get(url, stream=True).raw)
+>>> texts = ["An image of two cats chilling on a couch", "A football player scoring a goal"]
+
+>>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
+>>> model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
+
+>>> # forward pass
+>>> scores = dict()
+>>> for text in texts:
+...     # prepare inputs
+...     encoding = processor(image, text, return_tensors="pt")
+...     outputs = model(**encoding)
+...     scores[text] = outputs
+```
+
 The following example shows how to run image-text retrieval using [`BridgeTowerProcessor`] and [`BridgeTowerForImageAndTextRetrieval`].
 ```python
 >>> from transformers import BridgeTowerProcessor, BridgeTowerForImageAndTextRetrieval
@@ -128,6 +150,11 @@ Tips:
 [[autodoc]] BridgeTowerModel
     - forward
 
+## BridgeTowerForContrastiveLearning
+
+[[autodoc]] BridgeTowerForContrastiveLearning
+    - forward
+
 ## BridgeTowerForMaskedLM
 
 [[autodoc]] BridgeTowerForMaskedLM
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index fafe8ba287d4..cceefc393718 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1182,6 +1182,7 @@
     _import_structure["models.bridgetower"].extend(
         [
             "BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "BridgeTowerForContrastiveLearning",
             "BridgeTowerForImageAndTextRetrieval",
             "BridgeTowerForMaskedLM",
             "BridgeTowerModel",
@@ -4666,6 +4667,7 @@
         )
         from .models.bridgetower import (
             BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST,
+            BridgeTowerForContrastiveLearning,
             BridgeTowerForImageAndTextRetrieval,
             BridgeTowerForMaskedLM,
             BridgeTowerModel,
diff --git a/src/transformers/models/bridgetower/__init__.py b/src/transformers/models/bridgetower/__init__.py
index 7058fffa529e..cbd5bd4a366a 100644
--- a/src/transformers/models/bridgetower/__init__.py
+++ b/src/transformers/models/bridgetower/__init__.py
@@ -42,6 +42,7 @@
 else:
     _import_structure["modeling_bridgetower"] = [
         "BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "BridgeTowerForContrastiveLearning",
         "BridgeTowerForImageAndTextRetrieval",
         "BridgeTowerForMaskedLM",
         "BridgeTowerModel",
@@ -74,6 +75,7 @@
     else:
         from .modeling_bridgetower import (
             BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST,
+            BridgeTowerForContrastiveLearning,
             BridgeTowerForImageAndTextRetrieval,
             BridgeTowerForMaskedLM,
             BridgeTowerModel,
diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py
index bae846201098..f405407d7d9b 100644
--- a/src/transformers/models/bridgetower/modeling_bridgetower.py
+++ b/src/transformers/models/bridgetower/modeling_bridgetower.py
@@ -143,9 +143,8 @@ class BridgeTowerModelOutput(ModelOutput):
             token), respectively, after further processing through layers used for auxiliary pretraining tasks.
         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
-            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
-
-            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
+            the model at the output of each layer plus the optional initial embedding outputs.
         attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
             Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
             sequence_length)`.
@@ -161,6 +160,40 @@ class BridgeTowerModelOutput(ModelOutput):
     attentions: Optional[Tuple[torch.FloatTensor]] = None
 
 
+@dataclass
+class BridgeTowerContrastiveOutput(ModelOutput):
+    """
+    Output type of [`BridgeTowerForContrastiveLearning`]
+
+    Args:
+        logits (`torch.FloatTensor` of shape `(batch_size, 3, config.contrastive_hidden_size)`):
+            The normalized text, image and cross-modal embeddings stacked along the second-to-last dimension.
+        text_embeds (`torch.FloatTensor` of shape `(batch_size, config.contrastive_hidden_size)`):
+            The normalized text embeddings obtained by applying the text contrastive head to the first text token.
+        image_embeds (`torch.FloatTensor` of shape `(batch_size, config.contrastive_hidden_size)`):
+            The normalized image embeddings obtained by applying the image contrastive head to the first image token.
+        cross_embeds (`torch.FloatTensor` of shape `(batch_size, config.contrastive_hidden_size)`):
+            The normalized cross-modal embeddings obtained by applying the contrastive head to the pooler_output.
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Image-text contrastive loss.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
+            the model at the output of each layer plus the optional initial embedding outputs.
+    """
+
+    logits: torch.FloatTensor = None
+    text_embeds: Optional[torch.FloatTensor] = None
+    image_embeds: Optional[torch.FloatTensor] = None
+    cross_embeds: Optional[torch.FloatTensor] = None
+    loss: Optional[torch.FloatTensor] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
 class BridgeTowerResidualAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -1314,7 +1347,12 @@ def forward(
         if output_hidden_states:
             all_hidden_states_text += (text_embeds,)
 
-        image_embeds = self.vision_model.visual.forward_pre(pixel_values.type(self.vision_model.dtype))
+        if image_embeds is None:
+            image_embeds = self.vision_model.visual.forward_pre(pixel_values.type(self.vision_model.dtype))
+        else:
+            # Permute as BridgeTowerResidualAttention has batch_first=True
+            image_embeds = image_embeds.permute(1, 0, 2)
+
         if output_hidden_states:
             all_hidden_states_image += (image_embeds,)
 
@@ -1438,7 +1476,11 @@ def forward(
             all_hidden_states = (all_hidden_states_text, all_hidden_states_image, all_hidden_states_cross)
 
         if not return_dict:
-            return tuple(v for v in [text_features, image_features, cls_features] if v is not None)
+            return tuple(
+                v
+                for v in [text_features, image_features, cls_features, all_hidden_states, all_self_attentions]
+                if v is not None
+            )
 
         return BridgeTowerModelOutput(
             text_features=text_features,
@@ -1700,3 +1742,141 @@ def forward(
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+
+class BridgeTowerContrastiveHead(nn.Module):
+    def __init__(self, hidden_size, embed_size):
+        super().__init__()
+        self.fc = nn.Linear(hidden_size, embed_size)
+
+    def forward(self, x):
+        x = self.fc(x)
+        return x
+
+
+@add_start_docstrings(
+    """
+    BridgeTower Model with an image-text contrastive head on top computing image-text contrastive loss.
+    """,
+    BRIDGETOWER_START_DOCSTRING,
+)
+class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.bridgetower = BridgeTowerModel(config)
+
+        self.itc_text_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size)
+        self.itc_image_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size)
+        self.itc_cross_modal_head = BridgeTowerContrastiveHead(config.hidden_size * 2, config.contrastive_hidden_size)
+
+        self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(BRIDGETOWER_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=BridgeTowerContrastiveOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        pixel_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        image_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = True,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+    ) -> Union[BridgeTowerContrastiveOutput, Tuple[torch.FloatTensor]]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
+            When provided, the image-text contrastive loss is computed. Only the number of entries is used: the
+            i-th text is treated as the match of the i-th image, so the label values themselves are ignored.
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
+        >>> import requests
+        >>> from PIL import Image
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> texts = "An image of two cats chilling on a couch"
+
+        >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
+        >>> model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
+        >>> inputs = processor(image, texts, return_tensors="pt")
+        >>> outputs = model(**inputs, output_hidden_states=True)
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Hidden states are always requested: the contrastive heads read the last text/image hidden states.
+        outputs = self.bridgetower(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            pixel_values=pixel_values,
+            pixel_mask=pixel_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            image_embeds=image_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=True,
+            return_dict=return_dict,
+        )
+
+        pooler_output = outputs.pooler_output if return_dict else outputs[2]
+        hidden_states_txt, hidden_states_img, hidden_states_cross_modal = (
+            outputs.hidden_states if return_dict else outputs[3]
+        )
+
+        text_embeds = hidden_states_txt[-1]
+        image_embeds = hidden_states_img[-1]
+
+        image_embeds_with_ln = self.bridgetower.vision_model.visual.forward_post(image_embeds)
+        image_token_type_embeddings = self.bridgetower.token_type_embeddings(
+            torch.full((1,), 1, dtype=torch.long, device=self.bridgetower.token_type_embeddings.weight.device)
+        ).expand_as(image_embeds_with_ln)
+
+        image_embeds = self.bridgetower.cross_modal_image_transform(image_embeds_with_ln) + image_token_type_embeddings
+
+        # normalized features
+        text_embeds = nn.functional.normalize(self.itc_text_head(text_embeds[:, 0, :]), dim=-1, p=2)
+        image_embeds = nn.functional.normalize(self.itc_image_head(image_embeds[:, 0, :]), dim=-1, p=2)
+        cross_embeds = nn.functional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2)
+
+        logits = torch.stack([text_embeds, image_embeds, cross_embeds], dim=-2)
+
+        logit_scale = self.logit_scale.exp()
+        logits_text_to_image = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
+        logits_text_to_cross = torch.matmul(text_embeds, cross_embeds.t()) * logit_scale
+        logits_image_to_cross = torch.matmul(image_embeds, cross_embeds.t()) * logit_scale
+
+        itc_loss = None
+
+        if labels is not None:
+            labels = torch.arange(len(labels), device=logits.device)
+            text_to_image_loss = nn.functional.cross_entropy(logits_text_to_image, labels)
+            text_to_cross_loss = nn.functional.cross_entropy(logits_text_to_cross, labels)
+            image_to_cross_loss = nn.functional.cross_entropy(logits_image_to_cross, labels)
+            itc_loss = (text_to_image_loss + text_to_cross_loss + image_to_cross_loss) / 3.0
+
+        if not return_dict:
+            # Tuple output: logits, the three embeddings, then the base model's hidden states (and attentions).
+            output = (logits, text_embeds, image_embeds, cross_embeds) + outputs[3:]
+            return ((itc_loss,) + output) if itc_loss is not None else output
+
+        return BridgeTowerContrastiveOutput(
+            attentions=outputs.attentions,
+            hidden_states=outputs.hidden_states,
+            text_embeds=text_embeds,
+            image_embeds=image_embeds,
+            cross_embeds=cross_embeds,
+            logits=logits,
+            loss=itc_loss,
+        )
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index e6323118b7e2..85b4010f38c4 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -1328,6 +1328,13 @@ def __init__(self, *args, **kwargs):
 BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST = None
 
 
+class BridgeTowerForContrastiveLearning(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class BridgeTowerForImageAndTextRetrieval(metaclass=DummyObject):
     _backends = ["torch"]
 
diff --git a/tests/models/bridgetower/test_modeling_bridgetower.py b/tests/models/bridgetower/test_modeling_bridgetower.py
index afe3febb6983..9e70c4cdcd44 100644
--- a/tests/models/bridgetower/test_modeling_bridgetower.py
+++ b/tests/models/bridgetower/test_modeling_bridgetower.py
@@ -24,14 +24,25 @@
 from transformers.utils import cached_property
 
 from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_modeling_common import (
+    ModelTesterMixin,
+    _config_zero_init,
+    floats_tensor,
+    ids_tensor,
+    random_attention_mask,
+)
 from ...test_pipeline_mixin import PipelineTesterMixin
 
 
 if is_torch_available():
     import torch
 
-    from transformers import BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM, BridgeTowerModel
+    from transformers import (
+        BridgeTowerForContrastiveLearning,
+        BridgeTowerForImageAndTextRetrieval,
+        BridgeTowerForMaskedLM,
+        BridgeTowerModel,
+    )
     from transformers.models.bridgetower.modeling_bridgetower import BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST
     from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_10
 else:
@@ -65,6 +76,8 @@ def __init__(
         text_config=None,
         vision_config=None,
         image_size=288,
+        contrastive_hidden_size=512,
+        logit_scale_init_value=2.6592,
     ):
         self.parent = parent
         self.share_cross_modal_transformer_layers = share_cross_modal_transformer_layers
@@ -90,6 +103,8 @@ def __init__(
         self.is_training = False
         self.expected_num_hidden_layers = 32
         self.output_hidden_states = output_hidden_states
+        self.contrastive_hidden_size = contrastive_hidden_size
+        self.logit_scale_init_value = logit_scale_init_value
 
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -118,6 +133,8 @@ def get_config(self):
             init_layernorm_from_vision_encoder=self.init_layernorm_from_vision_encoder,
             num_channels=self.num_channels,
             output_hidden_states=self.output_hidden_states,
+            contrastive_hidden_size=self.contrastive_hidden_size,
+            logit_scale_init_value=self.logit_scale_init_value,
         )
 
     def create_and_check_model(
@@ -189,7 +206,14 @@ def prepare_config_and_inputs_for_common(self):
 @unittest.skipIf(not is_torch_greater_or_equal_than_1_10, "BridgeTower is only available in torch v1.10+")
 class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
-        (BridgeTowerModel, BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM) if is_torch_available() else ()
+        (
+            BridgeTowerModel,
+            BridgeTowerForImageAndTextRetrieval,
+            BridgeTowerForMaskedLM,
+            BridgeTowerForContrastiveLearning,
+        )
+        if is_torch_available()
+        else ()
     )
     pipeline_model_mapping = {"feature-extraction": BridgeTowerModel} if is_torch_available() else {}
 
@@ -347,6 +371,29 @@ def test_retain_grad_hidden_states_attentions(self):
         if self.has_attentions:
             self.assertIsNotNone(attentions.grad)
 
+    # override as the `logit_scale` parameter initialization is different for BRIDGE TOWER
+    def test_initialization(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        configs_no_init = _config_zero_init(config)
+        for model_class in self.all_model_classes:
+            model = model_class(config=configs_no_init)
+            for name, param in model.named_parameters():
+                if param.requires_grad:
+                    if name == "logit_scale":
+                        self.assertAlmostEqual(
+                            param.data.item(),
+                            config.logit_scale_init_value,
+                            delta=1e-3,
+                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                        )
+                    else:
+                        self.assertIn(
+                            ((param.data.mean() * 1e9).round() / 1e9).item(),
+                            [0.0, 1.0],
+                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                        )
+
     @unittest.skip(reason="""Bridge Tower does not have input/output embeddings. So this test is not applicable.""")
     def test_model_common_attributes(self):
         pass
@@ -429,12 +476,31 @@ def test_masked_language_modeling(self):
         outputs = model(**inputs)
         self.assertAlmostEqual(outputs.loss.item(), 5.7373, places=4)
 
+    @slow
+    def test_contrastive_learning(self):
+        model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc").to(
+            torch_device
+        )
+        model.eval()
+        processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
+        image = prepare_img()
+        text = "a bunch of cats laying on a tower."
+        inputs = processor(image, text, return_tensors="pt").to(torch_device)
+        with torch.no_grad():
+            outputs = model(**inputs, output_hidden_states=True)
+
+        # verify the logits
+        expected_shape = torch.Size([1, 3, 512])
+        self.assertEqual(outputs.logits.shape, expected_shape)
+
 
 @require_torch
 @unittest.skipIf(not is_torch_greater_or_equal_than_1_10, "BridgeTower is only available in torch v1.10+")
 class BridgeTowerModelTrainingTest(unittest.TestCase):
     all_training_supported_model_classes = (
-        (BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM) if is_torch_available() else ()
+        (BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM, BridgeTowerForContrastiveLearning)
+        if is_torch_available()
+        else ()
     )
 
     def setUp(self):
@@ -445,7 +511,7 @@ def _prepare_inputs_for_training(self, model_class):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         if model_class == BridgeTowerForMaskedLM:
             inputs_dict["labels"] = inputs_dict["input_ids"]
-        elif model_class == BridgeTowerForImageAndTextRetrieval:
+        elif model_class == BridgeTowerForImageAndTextRetrieval or model_class == BridgeTowerForContrastiveLearning:
             inputs_dict["labels"] = ids_tensor([1], 2)
         return config, inputs_dict
 
diff --git a/utils/check_repo.py b/utils/check_repo.py
index af0237d38c8a..f16c4fb851bf 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -204,6 +204,7 @@
     "Swin2SRForImageSuperResolution",
     "BridgeTowerForImageAndTextRetrieval",
     "BridgeTowerForMaskedLM",
+    "BridgeTowerForContrastiveLearning",
     "CLIPSegForImageSegmentation",
     "CLIPSegVisionModel",
     "CLIPSegTextModel",