Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/en/model_doc/bridgetower.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ This enables effective bottom-up cross-modal alignment and fusion between visual
In particular, on the VQAv2 test-std set, BRIDGETOWER achieves an accuracy of 78.73%, outperforming the previous state-of-the-art model METER by 1.09% with the same pre-training data and almost negligible additional parameters and computational costs.
Notably, when further scaling the model, BRIDGETOWER achieves an accuracy of 81.15%, surpassing models that are pre-trained on orders-of-magnitude larger datasets.*

<img src="https://huggingface.co/datasets/huggingface/documentation-images/blob/main/transformers/model_doc/bridgetower_architecture%20.jpg"
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/bridgetower_architecture%20.jpg"
alt="drawing" width="600"/>

<small> BridgeTower architecture. Taken from the <a href="https://arxiv.org/abs/2206.08657">original paper.</a> </small>
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/bridgetower/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@
"BridgeTowerConfig",
"BridgeTowerTextConfig",
"BridgeTowerVisionConfig",
]
],
}
_import_structure["processing_bridgetower"] = ["BridgeTowerProcessor"]

try:
if not is_vision_available():
Expand All @@ -37,7 +38,6 @@
pass
else:
_import_structure["image_processing_bridgetower"] = ["BridgeTowerImageProcessor"]
_import_structure["processing_bridgetower"] = ["BridgeTowerProcessor"]

try:
if not is_torch_available():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(
self,
hidden_size=768,
num_hidden_layers=12,
num_channels=3,
patch_size=16,
image_size=288,
initializer_factor=1,
Expand All @@ -92,6 +93,7 @@ def __init__(
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.initializer_factor = initializer_factor
Expand Down
199 changes: 101 additions & 98 deletions src/transformers/models/bridgetower/modeling_bridgetower.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,39 @@
"""


@dataclass
class BridgeTowerModelOutput(ModelOutput):
    """
    Output type of [`BridgeTowerModel`].

    Every field defaults to `None`.

    Args:
        text_features (`torch.FloatTensor` of shape `(batch_size, text_sequence_length, hidden_size)`):
            Sequence of hidden-states at the text output of the last layer of the model.
        image_features (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, hidden_size)`):
            Sequence of hidden-states at the image output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size x 2)`):
            Concatenation of last layer hidden-state of the first token of the text and image sequence (classification
            token), respectively, after further processing through layers used for auxiliary pretraining tasks.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    text_features: torch.FloatTensor = None
    image_features: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class BridgeTowerResidualAttention(nn.Module):
def __init__(self, config):
super().__init__()
Expand Down Expand Up @@ -197,22 +230,45 @@ def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Ten
return hidden_states


class BridgeTowerVisualTransformer(nn.Module):
def __init__(self, config):
class BridgeTowerVisionEmbeddings(nn.Module):
def __init__(self, config: BridgeTowerVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size

self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=config.hidden_size,
kernel_size=config.patch_size,
stride=config.patch_size,
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
scale = config.hidden_size**-0.5
self.class_embedding = nn.Parameter(scale * torch.randn(config.hidden_size))
self.positional_embedding = nn.Parameter(
scale * torch.randn((config.image_size // config.patch_size) ** 2 + 1, config.hidden_size)
)

self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))

def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings


class BridgeTowerVisionTransformer(nn.Module):
def __init__(self, config):
super().__init__()

self.embeddings = BridgeTowerVisionEmbeddings(config)
self.ln_pre = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.transformer = BridgeTowerTransformer(config)
self.ln_post = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
Expand All @@ -222,60 +278,34 @@ def __init__(self, config):
[nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) for _ in range(config.num_hidden_layers)]
)

def forward(self, hidden_state: torch.Tensor, attention_mask):
# shape = [*, hidden_size, grid, grid]
visual_output = self.conv1(hidden_state)
# shape = [*, hidden_size, grid ** 2]
visual_output = visual_output.reshape(visual_output.shape[0], visual_output.shape[1], -1)
# shape = [*, grid ** 2, hidden_size]
visual_output = visual_output.permute(0, 2, 1)
t = self.class_embedding.to(visual_output.dtype) + torch.zeros(
visual_output.shape[0], 1, visual_output.shape[-1], dtype=visual_output.dtype, device=visual_output.device
)
# shape = [*, grid ** 2 + 1, hidden_size]
visual_output = torch.cat([t, visual_output], dim=1)
visual_output = visual_output + self.positional_embedding.to(visual_output.dtype)
visual_output = self.ln_pre(visual_output)
def forward(self, pixel_values: torch.Tensor, attention_mask):
hidden_states = self.embeddings(pixel_values)
hidden_states = self.ln_pre(hidden_states)
# NLD -> LND
visual_output = visual_output.permute(1, 0, 2)
hidden_states = hidden_states.permute(1, 0, 2)

visual_outputs = self.transformer(visual_output, attention_mask)
hidden_states = self.transformer(hidden_states, attention_mask)
# shape = [num_hidden_layers, grid ** 2 + 1, *, hidden_size]
visual_outputs = torch.stack(visual_outputs, dim=0)
hidden_states = torch.stack(hidden_states, dim=0)
# shape = [num_hidden_layers, *, grid ** 2 + 1, hidden_size]
visual_outputs = visual_outputs.permute(0, 2, 1, 3)
hidden_states = hidden_states.permute(0, 2, 1, 3)
if self.share_layernorm:
visual_outputs = self.ln_post(visual_outputs)
hidden_states = self.ln_post(hidden_states)
else:
visual_outputs_stack = []
for visual_output, ln in zip(visual_outputs, self.ln_separate):
visual_output = ln(visual_output)
visual_outputs_stack.append(visual_output)
hidden_states_stack = []
for hidden_states, ln in zip(hidden_states, self.ln_separate):
hidden_states = ln(hidden_states)
hidden_states_stack.append(hidden_states)
# shape = [num_hidden_layers, *, grid ** 2 + 1, hidden_size]
visual_outputs = torch.stack(visual_outputs_stack, dim=0)
return visual_outputs

def forward_pre(self, hidden_state: torch.Tensor):
# shape = [*, hidden_size, grid, grid]
visual_outputs_pre = self.conv1(hidden_state)
# shape = [*, hidden_size, grid ** 2]
visual_outputs_pre = visual_outputs_pre.reshape(visual_outputs_pre.shape[0], visual_outputs_pre.shape[1], -1)
# shape = [*, grid ** 2, hidden_size]
visual_outputs_pre = visual_outputs_pre.permute(0, 2, 1)
embeddings_to = self.class_embedding.to(visual_outputs_pre.dtype) + torch.zeros(
visual_outputs_pre.shape[0],
1,
visual_outputs_pre.shape[-1],
dtype=visual_outputs_pre.dtype,
device=visual_outputs_pre.device,
)
# shape = [*, grid ** 2 + 1, hidden_size]
visual_outputs_pre = torch.cat([embeddings_to, visual_outputs_pre], dim=1)
visual_outputs_pre = visual_outputs_pre + self.positional_embedding.to(visual_outputs_pre.dtype)
visual_outputs_pre = self.ln_pre(visual_outputs_pre)
hidden_states = torch.stack(hidden_states_stack, dim=0)
return hidden_states

def forward_pre(self, pixel_values: torch.Tensor):
hidden_states = self.embeddings(pixel_values)
hidden_states = self.ln_pre(hidden_states)
# NLD -> LND
visual_outputs_pre = visual_outputs_pre.permute(1, 0, 2)
return visual_outputs_pre
hidden_states = hidden_states.permute(1, 0, 2)
return hidden_states

def forward_post(self, hidden_state: torch.Tensor):
visual_output_post = hidden_state.permute(1, 0, 2)
Expand Down Expand Up @@ -370,6 +400,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return pooled_output


# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->BridgeTower
class BridgeTowerSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
Expand Down Expand Up @@ -456,15 +487,16 @@ def forward(
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device)
position_ids_l = position_ids_l.view(-1, 1)
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r

positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility

if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
Expand All @@ -476,7 +508,7 @@ def forward(

attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
# Apply the attention mask is (precomputed for all layers in BridgeTowerModel forward() function)
attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities.
Expand All @@ -486,6 +518,7 @@ def forward(
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)

# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask

Expand Down Expand Up @@ -700,6 +733,7 @@ def feed_forward_chunk(self, attention_output):
return layer_output


# Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->BridgeTowerText
class BridgeTowerTextEncoder(nn.Module):
def __init__(self, config):
super().__init__()
Expand All @@ -717,7 +751,7 @@ def forward(
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = None,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None
Expand All @@ -733,6 +767,7 @@ def forward(
past_key_value = past_key_values[i] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:

if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
Expand Down Expand Up @@ -884,6 +919,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
return position_ids.unsqueeze(0).expand(input_shape)


# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
Expand All @@ -900,39 +936,6 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
return incremental_indices.long() + padding_idx


@dataclass
class BridgeTowerModelOutput(ModelOutput):
    """
    Output type of [`BridgeTowerModel`].

    Every field defaults to `None`.

    Args:
        text_features (`torch.FloatTensor` of shape `(batch_size, text_sequence_length, hidden_size)`):
            Sequence of hidden-states at the text output of the last layer of the model.
        image_features (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, hidden_size)`):
            Sequence of hidden-states at the image output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size x 2)`):
            Concatenation of last layer hidden-state of the first token of the text and image sequence (classification
            token), respectively, after further processing through layers used for auxiliary pretraining tasks.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    text_features: torch.FloatTensor = None
    image_features: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class BridgeTowerPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
Expand Down Expand Up @@ -974,11 +977,11 @@ class BridgeTowerVisionModel(BridgeTowerPreTrainedModel):

def __init__(self, config):
super().__init__(config)
self.visual = BridgeTowerVisualTransformer(config)
self.visual = BridgeTowerVisionTransformer(config)

@property
def dtype(self):
return self.visual.conv1.weight.dtype
return self.visual.embeddings.patch_embedding.weight.dtype

def forward(self, image, image_mask=None):
return self.visual(image.type(self.dtype), image_mask)
Expand Down