-
Notifications
You must be signed in to change notification settings - Fork 33.6k
Update BridgeTowerForContrastiveLearning #22145
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
0d0525e
dbba2d7
8974ba7
280abc0
5d54c37
5e60f56
38e4bb7
b2471ee
dff0fe1
07509f4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -166,6 +166,8 @@ class BridgeTowerContrastiveOutput(ModelOutput): | |
| Output type of ['BridgeTowerForContrastiveLearning'] | ||
|
|
||
| Args: | ||
| loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss=True`): | ||
| Image-text contrastive loss. | ||
| logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): | ||
| Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). | ||
| text_embeds (`torch.FloatTensor`, *optional*, returned when model is initialized with `with_projection=True`): | ||
|
|
@@ -174,24 +176,22 @@ class BridgeTowerContrastiveOutput(ModelOutput): | |
| The image embeddings obtained by applying the projection layer to the pooler_output. | ||
| cross_embeds (`torch.FloatTensor`, *optional*, returned when model is initialized with `with_projection=True`): | ||
| The text-image cross-modal embeddings obtained by applying the projection layer to the pooler_output. | ||
| loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): | ||
| Image-text contrastive loss. | ||
| attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): | ||
| Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, | ||
| sequence_length)`. | ||
| hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): | ||
| Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + | ||
| one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of | ||
| the model at the output of each layer plus the optional initial embedding outputs. | ||
| attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): | ||
| Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, | ||
| sequence_length)`. | ||
| """ | ||
|
|
||
| loss: Optional[torch.FloatTensor] = None | ||
| logits: torch.FloatTensor = None | ||
| text_embeds: Optional[Tuple[torch.FloatTensor]] = None | ||
| image_embeds: Optional[Tuple[torch.FloatTensor]] = None | ||
| cross_embeds: Optional[Tuple[torch.FloatTensor]] = None | ||
| loss: Optional[torch.FloatTensor] = None | ||
| attentions: Optional[Tuple[torch.FloatTensor]] = None | ||
| hidden_states: Optional[Tuple[torch.FloatTensor]] = None | ||
| attentions: Optional[Tuple[torch.FloatTensor]] = None | ||
|
|
||
|
|
||
| class BridgeTowerResidualAttention(nn.Module): | ||
|
|
@@ -1789,12 +1789,11 @@ def forward( | |
| output_attentions: Optional[bool] = None, | ||
| output_hidden_states: Optional[bool] = True, | ||
| return_dict: Optional[bool] = None, | ||
| labels: Optional[torch.LongTensor] = None, | ||
| return_loss: Optional[bool] = None, | ||
| ) -> Union[BridgeTowerContrastiveOutput, Tuple[torch.FloatTensor]]: | ||
| r""" | ||
| labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): | ||
| Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match. | ||
| The pairs with 0 will be skipped for calculation. | ||
| return_loss (`bool`, *optional*): | ||
| Whether or not to return the contrastive loss. | ||
| Returns: | ||
|
|
||
| Examples: | ||
|
|
@@ -1803,14 +1802,29 @@ def forward( | |
| >>> from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning | ||
| >>> import requests | ||
| >>> from PIL import Image | ||
| >>> import torch | ||
|
|
||
| >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" | ||
| >>> image = Image.open(requests.get(url, stream=True).raw) | ||
| >>> texts = "An image of two cats chilling on a couch" | ||
| >>> image_urls = [ | ||
| ... "https://farm4.staticflickr.com/3395/3428278415_81c3e27f15_z.jpg", | ||
| ... "http://images.cocodataset.org/val2017/000000039769.jpg", | ||
| ... ] | ||
| >>> texts = ["two dogs in a car", "two cats sleeping on a couch"] | ||
| >>> images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls] | ||
|
|
||
| >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc") | ||
| >>> model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc") | ||
| >>> outputs = model(**inputs, output_hidden_states=True) | ||
|
|
||
| >>> inputs = processor(images, texts, padding=True, return_tensors="pt") | ||
| >>> loss = model(**inputs, return_loss=True).loss | ||
|
|
||
| >>> inputs = processor(images, texts[::-1], padding=True, return_tensors="pt") | ||
| >>> loss_swapped = model(**inputs, return_loss=True).loss | ||
|
|
||
| >>> print("Loss", round(loss.item(), 4)) | ||
| Loss 0.0019 | ||
|
|
||
| >>> print("Loss with swapped images", round(loss_swapped.item(), 4)) | ||
| Loss with swapped images 2.126 | ||
| ```""" | ||
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
|
|
||
|
|
@@ -1857,23 +1871,23 @@ def forward( | |
|
|
||
| itc_loss = None | ||
|
|
||
| if labels is not None: | ||
| labels = torch.arange(len(labels), device=logits.device) | ||
| if return_loss: | ||
| labels = torch.arange(len(logits), device=logits.device) | ||
| text_to_image_loss = nn.functional.cross_entropy(logits_text_to_image, labels) | ||
| text_to_cross_loss = nn.functional.cross_entropy(logits_text_to_cross, labels) | ||
| image_to_cross_loss = nn.functional.cross_entropy(logits_image_to_cross, labels) | ||
| itc_loss = (text_to_image_loss + text_to_cross_loss + image_to_cross_loss) / 3.0 | ||
|
|
||
| if not return_dict: | ||
| output = tuple(logits) | ||
| output = (logits, text_embeds, image_embeds, cross_embeds) + outputs[3:] | ||
| return ((itc_loss,) + output) if itc_loss is not None else output | ||
|
|
||
| return BridgeTowerContrastiveOutput( | ||
| attentions=outputs.attentions, | ||
| hidden_states=outputs.hidden_states, | ||
| loss=itc_loss, | ||
| logits=logits, | ||
| text_embeds=text_embeds, | ||
| image_embeds=image_embeds, | ||
| cross_embeds=cross_embeds, | ||
| logits=logits, | ||
| loss=itc_loss, | ||
| attentions=outputs.attentions, | ||
| hidden_states=outputs.hidden_states, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I might have missed this in the previous review. But let's put
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks @ydshieh for helping with this.
ydshieh marked this conversation as resolved.
Outdated
|
||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -94,7 +94,7 @@ def __init__( | |
| self.num_hidden_layers = num_hidden_layers | ||
| self.tie_word_embeddings = tie_word_embeddings | ||
| self.init_layernorm_from_vision_encoder = init_layernorm_from_vision_encoder | ||
| self.vocab_size = 50265 | ||
| self.vocab_size = 99 | ||
| self.num_channels = 3 | ||
| self.seq_length = 4 | ||
| self.num_image_features = 325 | ||
|
|
@@ -115,6 +115,8 @@ def prepare_config_and_inputs(self): | |
| return (config, input_ids, attention_mask, pixel_values, pixel_mask) | ||
|
|
||
| def get_config(self): | ||
| text_config = {"vocab_size": self.vocab_size} | ||
|
|
||
| return BridgeTowerConfig( | ||
| share_cross_modal_transformer_layers=self.share_cross_modal_transformer_layers, | ||
| drop_rate=self.drop_rate, | ||
|
|
@@ -135,6 +137,7 @@ def get_config(self): | |
| output_hidden_states=self.output_hidden_states, | ||
| contrastive_hidden_size=self.contrastive_hidden_size, | ||
| logit_scale_init_value=self.logit_scale_init_value, | ||
| text_config=text_config, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Regarding the test, the complete picture is given below. But this requires much more work, and probably doesn't match perfectly with your interest. So we can also work on this internally if you find that better. Let me know your thoughts :-). Here are the details: We would like to have something similar in
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's revert the change regarding
|
||
| ) | ||
|
|
||
| def create_and_check_model( | ||
|
|
@@ -231,7 +234,7 @@ def extract_output(self, outputs, model_class): | |
|
|
||
| def setUp(self): | ||
| self.model_tester = BridgeTowerModelTester(self) | ||
| self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265) | ||
| self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=99) | ||
|
|
||
| def test_config(self): | ||
| self.config_tester.run_common_tests() | ||
|
|
@@ -486,9 +489,9 @@ def test_constrastive_learning(self): | |
| processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc") | ||
| image = prepare_img() | ||
| text = "a bunch of cats laying on a tower." | ||
| inputs = processor(image, text, return_tensors="pt").to(torch_device) | ||
| inputs = processor(image, text, padding=True, return_tensors="pt").to(torch_device) | ||
| with torch.no_grad(): | ||
| outputs = model(**inputs, output_hidden_states=True) | ||
| outputs = model(**inputs, output_hidden_states=True, return_loss=True) | ||
|
|
||
| # verify the logits | ||
| expected_shape = torch.Size([1, 3, 512]) | ||
|
|
@@ -507,14 +510,16 @@ class BridgeTowerModelTrainingTest(unittest.TestCase): | |
|
|
||
| def setUp(self): | ||
| self.model_tester = BridgeTowerModelTester(self) | ||
| self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265) | ||
| self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=99) | ||
|
|
||
| def _prepare_inputs_for_training(self, model_class): | ||
| config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() | ||
| if model_class == BridgeTowerForMaskedLM: | ||
| inputs_dict["labels"] = inputs_dict["input_ids"] | ||
| elif model_class == BridgeTowerForImageAndTextRetrieval or model_class == BridgeTowerForContrastiveLearning: | ||
| elif model_class == BridgeTowerForImageAndTextRetrieval: | ||
| inputs_dict["labels"] = ids_tensor([1], 2) | ||
| elif model_class == BridgeTowerForContrastiveLearning: | ||
| inputs_dict["return_loss"] = True | ||
| return config, inputs_dict | ||
|
|
||
| def _get_non_used_layer_names(self, model_class): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We updated this accordingly in our latest commit. Thanks for the suggestion.