diff --git a/docs/source/en/model_doc/bridgetower.mdx b/docs/source/en/model_doc/bridgetower.mdx index dec59225c521..87015877dc9c 100644 --- a/docs/source/en/model_doc/bridgetower.mdx +++ b/docs/source/en/model_doc/bridgetower.mdx @@ -28,7 +28,7 @@ This enables effective bottom-up cross-modal alignment and fusion between visual In particular, on the VQAv2 test-std set, BRIDGETOWER achieves an accuracy of 78.73%, outperforming the previous state-of-the-art model METER by 1.09% with the same pre-training data and almost negligible additional parameters and computational costs. Notably, when further scaling the model, BRIDGETOWER achieves an accuracy of 81.15%, surpassing models that are pre-trained on orders-of-magnitude larger datasets.* -drawing BridgeTower architecture. Taken from the original paper. diff --git a/src/transformers/models/bridgetower/__init__.py b/src/transformers/models/bridgetower/__init__.py index aaa453cedf73..e92cf933c6cb 100644 --- a/src/transformers/models/bridgetower/__init__.py +++ b/src/transformers/models/bridgetower/__init__.py @@ -27,8 +27,9 @@ "BridgeTowerConfig", "BridgeTowerTextConfig", "BridgeTowerVisionConfig", - ] + ], } +_import_structure["processing_bridgetower"] = ["BridgeTowerProcessor"] try: if not is_vision_available(): @@ -37,7 +38,6 @@ pass else: _import_structure["image_processing_bridgetower"] = ["BridgeTowerImageProcessor"] -_import_structure["processing_bridgetower"] = ["BridgeTowerProcessor"] try: if not is_torch_available(): diff --git a/src/transformers/models/bridgetower/configuration_bridgetower.py b/src/transformers/models/bridgetower/configuration_bridgetower.py index 1677a773d7b1..bd1457719d15 100644 --- a/src/transformers/models/bridgetower/configuration_bridgetower.py +++ b/src/transformers/models/bridgetower/configuration_bridgetower.py @@ -80,6 +80,7 @@ def __init__( self, hidden_size=768, num_hidden_layers=12, + num_channels=3, patch_size=16, image_size=288, initializer_factor=1, @@ -92,6 +93,7 @@ def __init__( super().__init__(**kwargs) self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers + self.num_channels = num_channels self.patch_size = patch_size self.image_size = image_size self.initializer_factor = initializer_factor diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index b196c821b017..ac8e6a772d3f 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -127,6 +127,39 @@ """ +@dataclass +class BridgeTowerModelOutput(ModelOutput): + """ + Output type of [`BridgeTowerModel`]. + + Args: + text_features (`torch.FloatTensor` of shape `(batch_size, text_sequence_length, hidden_size)`): + Sequence of hidden-states at the text output of the last layer of the model. + image_features (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, hidden_size)`): + Sequence of hidden-states at the image output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size x 2)`): + Concatenation of last layer hidden-state of the first token of the text and image sequence (classification + token), respectively, after further processing through layers used for auxiliary pretraining tasks. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + text_features: torch.FloatTensor = None + image_features: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + class BridgeTowerResidualAttention(nn.Module): def __init__(self, config): super().__init__() @@ -197,22 +230,45 @@ def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Ten return hidden_states -class BridgeTowerVisualTransformer(nn.Module): - def __init__(self, config): +class BridgeTowerVisionEmbeddings(nn.Module): + def __init__(self, config: BridgeTowerVisionConfig): super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) - self.conv1 = nn.Conv2d( - in_channels=3, - out_channels=config.hidden_size, - kernel_size=config.patch_size, - stride=config.patch_size, + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, bias=False, ) - scale = config.hidden_size**-0.5 - self.class_embedding = nn.Parameter(scale * torch.randn(config.hidden_size)) - self.positional_embedding = nn.Parameter( - scale * torch.randn((config.image_size // config.patch_size) ** 2 + 1, config.hidden_size) - ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1))) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +class BridgeTowerVisionTransformer(nn.Module): + def __init__(self, config): + super().__init__() + + self.embeddings = BridgeTowerVisionEmbeddings(config) self.ln_pre = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.transformer = BridgeTowerTransformer(config) self.ln_post = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -222,60 +278,34 @@ def __init__(self, config): [nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) for _ in range(config.num_hidden_layers)] ) - def forward(self, hidden_state: torch.Tensor, attention_mask): - # shape = [*, hidden_size, grid, grid] - visual_output = self.conv1(hidden_state) - # shape = [*, hidden_size, grid ** 2] - visual_output = visual_output.reshape(visual_output.shape[0], visual_output.shape[1], -1) - # shape = [*, grid ** 2, hidden_size] - visual_output = visual_output.permute(0, 2, 1) - t = self.class_embedding.to(visual_output.dtype) + torch.zeros( - visual_output.shape[0], 1, visual_output.shape[-1], dtype=visual_output.dtype, device=visual_output.device - ) - # shape = [*, grid ** 2 + 1, hidden_size] - visual_output = torch.cat([t, visual_output], dim=1) - visual_output = visual_output + self.positional_embedding.to(visual_output.dtype) - visual_output = self.ln_pre(visual_output) + def forward(self, pixel_values: torch.Tensor, attention_mask): + hidden_states = self.embeddings(pixel_values) + hidden_states = self.ln_pre(hidden_states) # NLD -> LND - visual_output = visual_output.permute(1, 0, 2) + hidden_states = hidden_states.permute(1, 0, 2) - visual_outputs = self.transformer(visual_output, attention_mask) + hidden_states = self.transformer(hidden_states, attention_mask) # shape = [num_hidden_layers, hidden_size, *, grid ** 2] - visual_outputs = torch.stack(visual_outputs, dim=0) + hidden_states = torch.stack(hidden_states, dim=0) # shape = [num_hidden_layers, *, hidden_size, grid ** 2] - visual_outputs = visual_outputs.permute(0, 2, 1, 3) + hidden_states = hidden_states.permute(0, 2, 1, 3) if self.share_layernorm: - visual_outputs = self.ln_post(visual_outputs) + hidden_states = self.ln_post(hidden_states) else: - visual_outputs_stack = [] - for visual_output, ln in zip(visual_outputs, self.ln_separate): - visual_output = ln(visual_output) - visual_outputs_stack.append(visual_output) + hidden_states_stack = [] + for hidden_states, ln in zip(hidden_states, self.ln_separate): + hidden_states = ln(hidden_states) + hidden_states_stack.append(hidden_states) # shape = [num_hidden_layers, *, hidden_size, grid ** 2] - visual_outputs = torch.stack(visual_outputs_stack, dim=0) - return visual_outputs - - def forward_pre(self, hidden_state: torch.Tensor): - # shape = [*, hidden_size, grid, grid] - visual_outputs_pre = self.conv1(hidden_state) - # shape = [*, hidden_size, grid ** 2] - visual_outputs_pre = visual_outputs_pre.reshape(visual_outputs_pre.shape[0], visual_outputs_pre.shape[1], -1) - # shape = [*, grid ** 2, hidden_size] - visual_outputs_pre = visual_outputs_pre.permute(0, 2, 1) - embeddings_to = self.class_embedding.to(visual_outputs_pre.dtype) + torch.zeros( - visual_outputs_pre.shape[0], - 1, - visual_outputs_pre.shape[-1], - dtype=visual_outputs_pre.dtype, - device=visual_outputs_pre.device, - ) - # shape = [*, grid ** 2 + 1, hidden_size] - visual_outputs_pre = torch.cat([embeddings_to, visual_outputs_pre], dim=1) - visual_outputs_pre = visual_outputs_pre + self.positional_embedding.to(visual_outputs_pre.dtype) - visual_outputs_pre = self.ln_pre(visual_outputs_pre) + hidden_states = torch.stack(hidden_states_stack, dim=0) + return hidden_states + + def forward_pre(self, pixel_values: torch.Tensor): + hidden_states = self.embeddings(pixel_values) + hidden_states = self.ln_pre(hidden_states) # NLD -> LND - visual_outputs_pre = visual_outputs_pre.permute(1, 0, 2) - return visual_outputs_pre + hidden_states = hidden_states.permute(1, 0, 2) + return hidden_states def forward_post(self, hidden_state: torch.Tensor): visual_output_post = hidden_state.permute(1, 0, 2) @@ -370,6 +400,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return pooled_output +# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->BridgeTower class BridgeTowerSelfAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() @@ -456,15 +487,16 @@ def forward( if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device) - position_ids_l = position_ids_l.view(-1, 1) + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) else: position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) distance = position_ids_l - position_ids_r positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility if self.position_embedding_type == "relative_key": relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) @@ -476,7 +508,7 @@ def forward( attention_scores = attention_scores / math.sqrt(self.attention_head_size) if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + # Apply the attention mask is (precomputed for all layers in BridgeTowerModel forward() function) attention_scores = attention_scores + attention_mask # Normalize the attention scores to probabilities. @@ -486,6 +518,7 @@ def forward( # seem a bit unusual, but is taken from the original Transformer paper. attention_probs = self.dropout(attention_probs) + # Mask heads if we want to if head_mask is not None: attention_probs = attention_probs * head_mask @@ -700,6 +733,7 @@ def feed_forward_chunk(self, attention_output): return layer_output +# Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->BridgeTowerText class BridgeTowerTextEncoder(nn.Module): def __init__(self, config): super().__init__() @@ -717,7 +751,7 @@ def forward( past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = False, - output_hidden_states: Optional[bool] = None, + output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None @@ -733,6 +767,7 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: + if use_cache: logger.warning( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." @@ -884,6 +919,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): return position_ids.unsqueeze(0).expand(input_shape) +# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): """ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols @@ -900,39 +936,6 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l return incremental_indices.long() + padding_idx -@dataclass -class BridgeTowerModelOutput(ModelOutput): - """ - Output type of [`BridgeTowerModel`]. - - Args: - text_features (`torch.FloatTensor` of shape `(batch_size, text_sequence_length, hidden_size)`): - Sequence of hidden-states at the text output of the last layer of the model. - image_features (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, hidden_size)`): - Sequence of hidden-states at the image output of the last layer of the model. - pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size x 2)`): - Concatenation of last layer hidden-state of the first token of the text and image sequence (classification - token), respectively, after further processing through layers used for auxiliary pretraining tasks. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - text_features: torch.FloatTensor = None - image_features: torch.FloatTensor = None - pooler_output: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - - class BridgeTowerPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -974,11 +977,11 @@ class BridgeTowerVisionModel(BridgeTowerPreTrainedModel): def __init__(self, config): super().__init__(config) - self.visual = BridgeTowerVisualTransformer(config) + self.visual = BridgeTowerVisionTransformer(config) @property def dtype(self): - return self.visual.conv1.weight.dtype + return self.visual.embeddings.patch_embedding.weight.dtype def forward(self, image, image_mask=None): return self.visual(image.type(self.dtype), image_mask)