diff --git a/docs/source/en/model_doc/bridgetower.mdx b/docs/source/en/model_doc/bridgetower.mdx
index dec59225c521..87015877dc9c 100644
--- a/docs/source/en/model_doc/bridgetower.mdx
+++ b/docs/source/en/model_doc/bridgetower.mdx
@@ -28,7 +28,7 @@ This enables effective bottom-up cross-modal alignment and fusion between visual
In particular, on the VQAv2 test-std set, BRIDGETOWER achieves an accuracy of 78.73%, outperforming the previous state-of-the-art model METER by 1.09% with the same pre-training data and almost negligible additional parameters and computational costs.
Notably, when further scaling the model, BRIDGETOWER achieves an accuracy of 81.15%, surpassing models that are pre-trained on orders-of-magnitude larger datasets.*
-
BridgeTower architecture. Taken from the original paper.
diff --git a/src/transformers/models/bridgetower/__init__.py b/src/transformers/models/bridgetower/__init__.py
index aaa453cedf73..e92cf933c6cb 100644
--- a/src/transformers/models/bridgetower/__init__.py
+++ b/src/transformers/models/bridgetower/__init__.py
@@ -27,8 +27,9 @@
"BridgeTowerConfig",
"BridgeTowerTextConfig",
"BridgeTowerVisionConfig",
- ]
+ ],
}
+_import_structure["processing_bridgetower"] = ["BridgeTowerProcessor"]
try:
if not is_vision_available():
@@ -37,7 +38,6 @@
pass
else:
_import_structure["image_processing_bridgetower"] = ["BridgeTowerImageProcessor"]
-_import_structure["processing_bridgetower"] = ["BridgeTowerProcessor"]
try:
if not is_torch_available():
diff --git a/src/transformers/models/bridgetower/configuration_bridgetower.py b/src/transformers/models/bridgetower/configuration_bridgetower.py
index 1677a773d7b1..bd1457719d15 100644
--- a/src/transformers/models/bridgetower/configuration_bridgetower.py
+++ b/src/transformers/models/bridgetower/configuration_bridgetower.py
@@ -80,6 +80,7 @@ def __init__(
self,
hidden_size=768,
num_hidden_layers=12,
+ num_channels=3,
patch_size=16,
image_size=288,
initializer_factor=1,
@@ -92,6 +93,7 @@ def __init__(
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
+ self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.initializer_factor = initializer_factor
diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py
index b196c821b017..ac8e6a772d3f 100644
--- a/src/transformers/models/bridgetower/modeling_bridgetower.py
+++ b/src/transformers/models/bridgetower/modeling_bridgetower.py
@@ -127,6 +127,39 @@
"""
+@dataclass
+class BridgeTowerModelOutput(ModelOutput):
+ """
+ Output type of [`BridgeTowerModel`].
+
+ Args:
+ text_features (`torch.FloatTensor` of shape `(batch_size, text_sequence_length, hidden_size)`):
+ Sequence of hidden-states at the text output of the last layer of the model.
+ image_features (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, hidden_size)`):
+ Sequence of hidden-states at the image output of the last layer of the model.
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size x 2)`):
+ Concatenation of last layer hidden-state of the first token of the text and image sequence (classification
+ token), respectively, after further processing through layers used for auxiliary pretraining tasks.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ text_features: torch.FloatTensor = None
+ image_features: torch.FloatTensor = None
+ pooler_output: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
class BridgeTowerResidualAttention(nn.Module):
def __init__(self, config):
super().__init__()
@@ -197,22 +230,45 @@ def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Ten
return hidden_states
-class BridgeTowerVisualTransformer(nn.Module):
- def __init__(self, config):
+class BridgeTowerVisionEmbeddings(nn.Module):
+ def __init__(self, config: BridgeTowerVisionConfig):
super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
- self.conv1 = nn.Conv2d(
- in_channels=3,
- out_channels=config.hidden_size,
- kernel_size=config.patch_size,
- stride=config.patch_size,
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
bias=False,
)
- scale = config.hidden_size**-0.5
- self.class_embedding = nn.Parameter(scale * torch.randn(config.hidden_size))
- self.positional_embedding = nn.Parameter(
- scale * torch.randn((config.image_size // config.patch_size) ** 2 + 1, config.hidden_size)
- )
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches + 1
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ batch_size = pixel_values.shape[0]
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ embeddings = embeddings + self.position_embedding(self.position_ids)
+ return embeddings
+
+
+class BridgeTowerVisionTransformer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.embeddings = BridgeTowerVisionEmbeddings(config)
self.ln_pre = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.transformer = BridgeTowerTransformer(config)
self.ln_post = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -222,60 +278,34 @@ def __init__(self, config):
[nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) for _ in range(config.num_hidden_layers)]
)
- def forward(self, hidden_state: torch.Tensor, attention_mask):
- # shape = [*, hidden_size, grid, grid]
- visual_output = self.conv1(hidden_state)
- # shape = [*, hidden_size, grid ** 2]
- visual_output = visual_output.reshape(visual_output.shape[0], visual_output.shape[1], -1)
- # shape = [*, grid ** 2, hidden_size]
- visual_output = visual_output.permute(0, 2, 1)
- t = self.class_embedding.to(visual_output.dtype) + torch.zeros(
- visual_output.shape[0], 1, visual_output.shape[-1], dtype=visual_output.dtype, device=visual_output.device
- )
- # shape = [*, grid ** 2 + 1, hidden_size]
- visual_output = torch.cat([t, visual_output], dim=1)
- visual_output = visual_output + self.positional_embedding.to(visual_output.dtype)
- visual_output = self.ln_pre(visual_output)
+ def forward(self, pixel_values: torch.Tensor, attention_mask):
+ hidden_states = self.embeddings(pixel_values)
+ hidden_states = self.ln_pre(hidden_states)
# NLD -> LND
- visual_output = visual_output.permute(1, 0, 2)
+ hidden_states = hidden_states.permute(1, 0, 2)
- visual_outputs = self.transformer(visual_output, attention_mask)
+ hidden_states = self.transformer(hidden_states, attention_mask)
# shape = [num_hidden_layers, hidden_size, *, grid ** 2]
- visual_outputs = torch.stack(visual_outputs, dim=0)
+ hidden_states = torch.stack(hidden_states, dim=0)
# shape = [num_hidden_layers, *, hidden_size, grid ** 2]
- visual_outputs = visual_outputs.permute(0, 2, 1, 3)
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
if self.share_layernorm:
- visual_outputs = self.ln_post(visual_outputs)
+ hidden_states = self.ln_post(hidden_states)
else:
- visual_outputs_stack = []
- for visual_output, ln in zip(visual_outputs, self.ln_separate):
- visual_output = ln(visual_output)
- visual_outputs_stack.append(visual_output)
+ hidden_states_stack = []
+ for hidden_states, ln in zip(hidden_states, self.ln_separate):
+ hidden_states = ln(hidden_states)
+ hidden_states_stack.append(hidden_states)
# shape = [num_hidden_layers, *, hidden_size, grid ** 2]
- visual_outputs = torch.stack(visual_outputs_stack, dim=0)
- return visual_outputs
-
- def forward_pre(self, hidden_state: torch.Tensor):
- # shape = [*, hidden_size, grid, grid]
- visual_outputs_pre = self.conv1(hidden_state)
- # shape = [*, hidden_size, grid ** 2]
- visual_outputs_pre = visual_outputs_pre.reshape(visual_outputs_pre.shape[0], visual_outputs_pre.shape[1], -1)
- # shape = [*, grid ** 2, hidden_size]
- visual_outputs_pre = visual_outputs_pre.permute(0, 2, 1)
- embeddings_to = self.class_embedding.to(visual_outputs_pre.dtype) + torch.zeros(
- visual_outputs_pre.shape[0],
- 1,
- visual_outputs_pre.shape[-1],
- dtype=visual_outputs_pre.dtype,
- device=visual_outputs_pre.device,
- )
- # shape = [*, grid ** 2 + 1, hidden_size]
- visual_outputs_pre = torch.cat([embeddings_to, visual_outputs_pre], dim=1)
- visual_outputs_pre = visual_outputs_pre + self.positional_embedding.to(visual_outputs_pre.dtype)
- visual_outputs_pre = self.ln_pre(visual_outputs_pre)
+ hidden_states = torch.stack(hidden_states_stack, dim=0)
+ return hidden_states
+
+ def forward_pre(self, pixel_values: torch.Tensor):
+ hidden_states = self.embeddings(pixel_values)
+ hidden_states = self.ln_pre(hidden_states)
# NLD -> LND
- visual_outputs_pre = visual_outputs_pre.permute(1, 0, 2)
- return visual_outputs_pre
+ hidden_states = hidden_states.permute(1, 0, 2)
+ return hidden_states
def forward_post(self, hidden_state: torch.Tensor):
visual_output_post = hidden_state.permute(1, 0, 2)
@@ -370,6 +400,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return pooled_output
+# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->BridgeTower
class BridgeTowerSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
@@ -456,15 +487,16 @@ def forward(
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
- position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device)
- position_ids_l = position_ids_l.view(-1, 1)
+ position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+ -1, 1
+ )
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
- positional_embedding = positional_embedding.to(dtype=query_layer.dtype)
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
@@ -476,7 +508,7 @@ def forward(
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
- # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+ # Apply the attention mask is (precomputed for all layers in BridgeTowerModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
@@ -486,6 +518,7 @@ def forward(
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
+ # Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
@@ -700,6 +733,7 @@ def feed_forward_chunk(self, attention_output):
return layer_output
+# Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->BridgeTowerText
class BridgeTowerTextEncoder(nn.Module):
def __init__(self, config):
super().__init__()
@@ -717,7 +751,7 @@ def forward(
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False,
- output_hidden_states: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None
@@ -733,6 +767,7 @@ def forward(
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
+
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
@@ -884,6 +919,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
return position_ids.unsqueeze(0).expand(input_shape)
+# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
@@ -900,39 +936,6 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
return incremental_indices.long() + padding_idx
-@dataclass
-class BridgeTowerModelOutput(ModelOutput):
- """
- Output type of [`BridgeTowerModel`].
-
- Args:
- text_features (`torch.FloatTensor` of shape `(batch_size, text_sequence_length, hidden_size)`):
- Sequence of hidden-states at the text output of the last layer of the model.
- image_features (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, hidden_size)`):
- Sequence of hidden-states at the image output of the last layer of the model.
- pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size x 2)`):
- Concatenation of last layer hidden-state of the first token of the text and image sequence (classification
- token), respectively, after further processing through layers used for auxiliary pretraining tasks.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
-
- text_features: torch.FloatTensor = None
- image_features: torch.FloatTensor = None
- pooler_output: torch.FloatTensor = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-
class BridgeTowerPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -974,11 +977,11 @@ class BridgeTowerVisionModel(BridgeTowerPreTrainedModel):
def __init__(self, config):
super().__init__(config)
- self.visual = BridgeTowerVisualTransformer(config)
+ self.visual = BridgeTowerVisionTransformer(config)
@property
def dtype(self):
- return self.visual.conv1.weight.dtype
+ return self.visual.embeddings.patch_embedding.weight.dtype
def forward(self, image, image_mask=None):
return self.visual(image.type(self.dtype), image_mask)