From 67bd5334370873abcf9ab126103a363bdc89df60 Mon Sep 17 00:00:00 2001 From: Anahita Bhiwandiwalla Date: Tue, 24 Jan 2023 11:12:48 -0800 Subject: [PATCH 1/4] Added copied to, some more review feedback --- docs/source/en/model_doc/bridgetower.mdx | 2 +- .../models/bridgetower/__init__.py | 4 +- .../bridgetower/modeling_bridgetower.py | 73 ++++++++++--------- 3 files changed, 41 insertions(+), 38 deletions(-) diff --git a/docs/source/en/model_doc/bridgetower.mdx b/docs/source/en/model_doc/bridgetower.mdx index dec59225c521..87015877dc9c 100644 --- a/docs/source/en/model_doc/bridgetower.mdx +++ b/docs/source/en/model_doc/bridgetower.mdx @@ -28,7 +28,7 @@ This enables effective bottom-up cross-modal alignment and fusion between visual In particular, on the VQAv2 test-std set, BRIDGETOWER achieves an accuracy of 78.73%, outperforming the previous state-of-the-art model METER by 1.09% with the same pre-training data and almost negligible additional parameters and computational costs. Notably, when further scaling the model, BRIDGETOWER achieves an accuracy of 81.15%, surpassing models that are pre-trained on orders-of-magnitude larger datasets.* -drawing BridgeTower architecture. Taken from the original paper. diff --git a/src/transformers/models/bridgetower/__init__.py b/src/transformers/models/bridgetower/__init__.py index aaa453cedf73..0282a48b694a 100644 --- a/src/transformers/models/bridgetower/__init__.py +++ b/src/transformers/models/bridgetower/__init__.py @@ -27,7 +27,8 @@ "BridgeTowerConfig", "BridgeTowerTextConfig", "BridgeTowerVisionConfig", - ] + ], + "processing_bridgetower": "BridgeTowerProcessor" } try: @@ -37,7 +38,6 @@ pass else: _import_structure["image_processing_bridgetower"] = ["BridgeTowerImageProcessor"] -_import_structure["processing_bridgetower"] = ["BridgeTowerProcessor"] try: if not is_torch_available(): diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index b196c821b017..129ae00a4f5f 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -127,6 +127,39 @@ """ +@dataclass +class BridgeTowerModelOutput(ModelOutput): + """ + Output type of [`BridgeTowerModel`]. + + Args: + text_features (`torch.FloatTensor` of shape `(batch_size, text_sequence_length, hidden_size)`): + Sequence of hidden-states at the text output of the last layer of the model. + image_features (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, hidden_size)`): + Sequence of hidden-states at the image output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size x 2)`): + Concatenation of last layer hidden-state of the first token of the text and image sequence (classification + token), respectively, after further processing through layers used for auxiliary pretraining tasks. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + text_features: torch.FloatTensor = None + image_features: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + class BridgeTowerResidualAttention(nn.Module): def __init__(self, config): super().__init__() @@ -197,7 +230,7 @@ def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Ten return hidden_states -class BridgeTowerVisualTransformer(nn.Module): +class BridgeTowerVisionTransformer(nn.Module): def __init__(self, config): super().__init__() @@ -370,6 +403,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return pooled_output +# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->BridgeTower class BridgeTowerSelfAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() @@ -700,6 +734,7 @@ def feed_forward_chunk(self, attention_output): return layer_output +# Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->BridgeTowerText class BridgeTowerTextEncoder(nn.Module): def __init__(self, config): super().__init__() @@ -884,6 +919,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): return position_ids.unsqueeze(0).expand(input_shape) +# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): """ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols @@ -900,39 +936,6 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l return incremental_indices.long() + padding_idx -@dataclass -class BridgeTowerModelOutput(ModelOutput): - """ - Output type of [`BridgeTowerModel`]. - - Args: - text_features (`torch.FloatTensor` of shape `(batch_size, text_sequence_length, hidden_size)`): - Sequence of hidden-states at the text output of the last layer of the model. - image_features (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, hidden_size)`): - Sequence of hidden-states at the image output of the last layer of the model. - pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size x 2)`): - Concatenation of last layer hidden-state of the first token of the text and image sequence (classification - token), respectively, after further processing through layers used for auxiliary pretraining tasks. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - text_features: torch.FloatTensor = None - image_features: torch.FloatTensor = None - pooler_output: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - - class BridgeTowerPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -974,7 +977,7 @@ class BridgeTowerVisionModel(BridgeTowerPreTrainedModel): def __init__(self, config): super().__init__(config) - self.visual = BridgeTowerVisualTransformer(config) + self.visual = BridgeTowerVisionTransformer(config) @property def dtype(self): From 1d7e6631df75464cb951261575ce23f2695fdadb Mon Sep 17 00:00:00 2001 From: Anahita Bhiwandiwalla Date: Tue, 24 Jan 2023 12:19:32 -0800 Subject: [PATCH 2/4] make fixup --- src/transformers/models/bridgetower/__init__.py | 2 +- .../models/bridgetower/modeling_bridgetower.py | 13 ++++++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/bridgetower/__init__.py b/src/transformers/models/bridgetower/__init__.py index 0282a48b694a..e92cf933c6cb 100644 --- a/src/transformers/models/bridgetower/__init__.py +++ b/src/transformers/models/bridgetower/__init__.py @@ -28,8 +28,8 @@ "BridgeTowerTextConfig", "BridgeTowerVisionConfig", ], - "processing_bridgetower": "BridgeTowerProcessor" } +_import_structure["processing_bridgetower"] = ["BridgeTowerProcessor"] try: if not is_vision_available(): diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 129ae00a4f5f..e3462f815d24 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -490,15 +490,16 @@ def forward( if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device) - position_ids_l = position_ids_l.view(-1, 1) + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) else: position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) distance = position_ids_l - position_ids_r positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility if self.position_embedding_type == "relative_key": relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) @@ -510,7 +511,7 @@ def forward( attention_scores = attention_scores / math.sqrt(self.attention_head_size) if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + # Apply the attention mask is (precomputed for all layers in BridgeTowerModel forward() function) attention_scores = attention_scores + attention_mask # Normalize the attention scores to probabilities. @@ -520,6 +521,7 @@ def forward( # seem a bit unusual, but is taken from the original Transformer paper. attention_probs = self.dropout(attention_probs) + # Mask heads if we want to if head_mask is not None: attention_probs = attention_probs * head_mask @@ -752,7 +754,7 @@ def forward( past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = False, - output_hidden_states: Optional[bool] = None, + output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None @@ -768,6 +770,7 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: + if use_cache: logger.warning( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." From 90c2ec47450f692e0932d5a85af7d616f0a75484 Mon Sep 17 00:00:00 2001 From: Shaoyen Tseng Date: Tue, 24 Jan 2023 15:25:38 -0800 Subject: [PATCH 3/4] Use BridgeTowerVisionEmbeddings --- .../bridgetower/configuration_bridgetower.py | 2 + .../bridgetower/modeling_bridgetower.py | 151 +++++++++++------- 2 files changed, 96 insertions(+), 57 deletions(-) diff --git a/src/transformers/models/bridgetower/configuration_bridgetower.py b/src/transformers/models/bridgetower/configuration_bridgetower.py index 1677a773d7b1..bd1457719d15 100644 --- a/src/transformers/models/bridgetower/configuration_bridgetower.py +++ b/src/transformers/models/bridgetower/configuration_bridgetower.py @@ -80,6 +80,7 @@ def __init__( self, hidden_size=768, num_hidden_layers=12, + num_channels=3, patch_size=16, image_size=288, initializer_factor=1, @@ -92,6 +93,7 @@ def __init__( super().__init__(**kwargs) self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers + self.num_channels = num_channels self.patch_size = patch_size self.image_size = image_size self.initializer_factor = initializer_factor diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index e3462f815d24..ef152baeb437 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -230,22 +230,57 @@ def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Ten return hidden_states -class BridgeTowerVisionTransformer(nn.Module): - def __init__(self, config): +class BridgeTowerVisionEmbeddings(nn.Module): + def __init__(self, config: BridgeTowerVisionConfig): super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size - self.conv1 = nn.Conv2d( - in_channels=3, - out_channels=config.hidden_size, - kernel_size=config.patch_size, - stride=config.patch_size, + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, bias=False, ) - scale = config.hidden_size**-0.5 - self.class_embedding = nn.Parameter(scale * torch.randn(config.hidden_size)) - self.positional_embedding = nn.Parameter( - scale * torch.randn((config.image_size // config.patch_size) ** 2 + 1, config.hidden_size) - ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1))) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +class BridgeTowerVisionTransformer(nn.Module): + def __init__(self, config): + super().__init__() + + # self.conv1 = nn.Conv2d( + # in_channels=3, + # out_channels=config.hidden_size, + # kernel_size=config.patch_size, + # stride=config.patch_size, + # bias=False, + # ) + # scale = config.hidden_size**-0.5 + # self.class_embedding = nn.Parameter(scale * torch.randn(config.hidden_size)) + # self.positional_embedding = nn.Parameter( + # scale * torch.randn((config.image_size // config.patch_size) ** 2 + 1, config.hidden_size) + # ) + self.embeddings = BridgeTowerVisionEmbeddings(config) self.ln_pre = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.transformer = BridgeTowerTransformer(config) self.ln_post = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -255,60 +290,62 @@ def __init__(self, config): [nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) for _ in range(config.num_hidden_layers)] ) - def forward(self, hidden_state: torch.Tensor, attention_mask): + def forward(self, pixel_values: torch.Tensor, attention_mask): # shape = [*, hidden_size, grid, grid] - visual_output = self.conv1(hidden_state) - # shape = [*, hidden_size, grid ** 2] - visual_output = visual_output.reshape(visual_output.shape[0], visual_output.shape[1], -1) - # shape = [*, grid ** 2, hidden_size] - visual_output = visual_output.permute(0, 2, 1) - t = self.class_embedding.to(visual_output.dtype) + torch.zeros( - visual_output.shape[0], 1, visual_output.shape[-1], dtype=visual_output.dtype, device=visual_output.device - ) - # shape = [*, grid ** 2 + 1, hidden_size] - visual_output = torch.cat([t, visual_output], dim=1) - visual_output = visual_output + self.positional_embedding.to(visual_output.dtype) - visual_output = self.ln_pre(visual_output) + # visual_output = self.conv1(hidden_state) + # # shape = [*, hidden_size, grid ** 2] + # visual_output = visual_output.reshape(visual_output.shape[0], visual_output.shape[1], -1) + # # shape = [*, grid ** 2, hidden_size] + # visual_output = visual_output.permute(0, 2, 1) + # t = self.class_embedding.to(visual_output.dtype) + torch.zeros( + # visual_output.shape[0], 1, visual_output.shape[-1], dtype=visual_output.dtype, device=visual_output.device + # ) + # # shape = [*, grid ** 2 + 1, hidden_size] + # visual_output = torch.cat([t, visual_output], dim=1) + # visual_output = visual_output + self.positional_embedding.to(visual_output.dtype) + hidden_states = self.embeddings(pixel_values) + hidden_states = self.ln_pre(hidden_states) # NLD -> LND - visual_output = visual_output.permute(1, 0, 2) + hidden_states = hidden_states.permute(1, 0, 2) - visual_outputs = self.transformer(visual_output, attention_mask) + hidden_states = self.transformer(hidden_states, attention_mask) # shape = [num_hidden_layers, hidden_size, *, grid ** 2] - visual_outputs = torch.stack(visual_outputs, dim=0) + hidden_states = torch.stack(hidden_states, dim=0) # shape = [num_hidden_layers, *, hidden_size, grid ** 2] - visual_outputs = visual_outputs.permute(0, 2, 1, 3) + hidden_states = hidden_states.permute(0, 2, 1, 3) if self.share_layernorm: - visual_outputs = self.ln_post(visual_outputs) + hidden_states = self.ln_post(hidden_states) else: - visual_outputs_stack = [] - for visual_output, ln in zip(visual_outputs, self.ln_separate): - visual_output = ln(visual_output) - visual_outputs_stack.append(visual_output) + hidden_states_stack = [] + for hidden_states, ln in zip(hidden_states, self.ln_separate): + hidden_states = ln(hidden_states) + hidden_states_stack.append(hidden_states) # shape = [num_hidden_layers, *, hidden_size, grid ** 2] - visual_outputs = torch.stack(visual_outputs_stack, dim=0) - return visual_outputs + hidden_states = torch.stack(hidden_states_stack, dim=0) + return hidden_states - def forward_pre(self, hidden_state: torch.Tensor): - # shape = [*, hidden_size, grid, grid] - visual_outputs_pre = self.conv1(hidden_state) - # shape = [*, hidden_size, grid ** 2] - visual_outputs_pre = visual_outputs_pre.reshape(visual_outputs_pre.shape[0], visual_outputs_pre.shape[1], -1) - # shape = [*, grid ** 2, hidden_size] - visual_outputs_pre = visual_outputs_pre.permute(0, 2, 1) - embeddings_to = self.class_embedding.to(visual_outputs_pre.dtype) + torch.zeros( - visual_outputs_pre.shape[0], - 1, - visual_outputs_pre.shape[-1], - dtype=visual_outputs_pre.dtype, - device=visual_outputs_pre.device, - ) - # shape = [*, grid ** 2 + 1, hidden_size] - visual_outputs_pre = torch.cat([embeddings_to, visual_outputs_pre], dim=1) - visual_outputs_pre = visual_outputs_pre + self.positional_embedding.to(visual_outputs_pre.dtype) - visual_outputs_pre = self.ln_pre(visual_outputs_pre) + def forward_pre(self, pixel_values: torch.Tensor): + # # shape = [*, hidden_size, grid, grid] + # visual_outputs_pre = self.conv1(hidden_state) + # # shape = [*, hidden_size, grid ** 2] + # visual_outputs_pre = visual_outputs_pre.reshape(visual_outputs_pre.shape[0], visual_outputs_pre.shape[1], -1) + # # shape = [*, grid ** 2, hidden_size] + # visual_outputs_pre = visual_outputs_pre.permute(0, 2, 1) + # embeddings_to = self.class_embedding.to(visual_outputs_pre.dtype) + torch.zeros( + # visual_outputs_pre.shape[0], + # 1, + # visual_outputs_pre.shape[-1], + # dtype=visual_outputs_pre.dtype, + # device=visual_outputs_pre.device, + # ) + # # shape = [*, grid ** 2 + 1, hidden_size] + # visual_outputs_pre = torch.cat([embeddings_to, visual_outputs_pre], dim=1) + # visual_outputs_pre = visual_outputs_pre + self.positional_embedding.to(visual_outputs_pre.dtype) + hidden_states = self.embeddings(pixel_values) + hidden_states = self.ln_pre(hidden_states) # NLD -> LND - visual_outputs_pre = visual_outputs_pre.permute(1, 0, 2) - return visual_outputs_pre + hidden_states = hidden_states.permute(1, 0, 2) + return hidden_states def forward_post(self, hidden_state: torch.Tensor): visual_output_post = hidden_state.permute(1, 0, 2) @@ -984,7 +1021,7 @@ def __init__(self, config): @property def dtype(self): - return self.visual.conv1.weight.dtype + return self.visual.embeddings.patch_embedding.weight.dtype def forward(self, image, image_mask=None): return self.visual(image.type(self.dtype), image_mask) From 09457d411eb920811feefcd777c34411d1a7408f Mon Sep 17 00:00:00 2001 From: Anahita Bhiwandiwalla Date: Tue, 24 Jan 2023 17:11:26 -0800 Subject: [PATCH 4/4] Code cleanup --- .../bridgetower/modeling_bridgetower.py | 40 ------------------- 1 file changed, 40 deletions(-) diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index ef152baeb437..ac8e6a772d3f 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -268,18 +268,6 @@ class BridgeTowerVisionTransformer(nn.Module): def __init__(self, config): super().__init__() - # self.conv1 = nn.Conv2d( - # in_channels=3, - # out_channels=config.hidden_size, - # kernel_size=config.patch_size, - # stride=config.patch_size, - # bias=False, - # ) - # scale = config.hidden_size**-0.5 - # self.class_embedding = nn.Parameter(scale * torch.randn(config.hidden_size)) - # self.positional_embedding = nn.Parameter( - # scale * torch.randn((config.image_size // config.patch_size) ** 2 + 1, config.hidden_size) - # ) self.embeddings = BridgeTowerVisionEmbeddings(config) self.ln_pre = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.transformer = BridgeTowerTransformer(config) @@ -291,18 +279,6 @@ def __init__(self, config): ) def forward(self, pixel_values: torch.Tensor, attention_mask): - # shape = [*, hidden_size, grid, grid] - # visual_output = self.conv1(hidden_state) - # # shape = [*, hidden_size, grid ** 2] - # visual_output = visual_output.reshape(visual_output.shape[0], visual_output.shape[1], -1) - # # shape = [*, grid ** 2, hidden_size] - # visual_output = visual_output.permute(0, 2, 1) - # t = self.class_embedding.to(visual_output.dtype) + torch.zeros( - # visual_output.shape[0], 1, visual_output.shape[-1], dtype=visual_output.dtype, device=visual_output.device - # ) - # # shape = [*, grid ** 2 + 1, hidden_size] - # visual_output = torch.cat([t, visual_output], dim=1) - # visual_output = visual_output + self.positional_embedding.to(visual_output.dtype) hidden_states = self.embeddings(pixel_values) hidden_states = self.ln_pre(hidden_states) # NLD -> LND @@ -325,22 +301,6 @@ def forward(self, pixel_values: torch.Tensor, attention_mask): return hidden_states def forward_pre(self, pixel_values: torch.Tensor): - # # shape = [*, hidden_size, grid, grid] - # visual_outputs_pre = self.conv1(hidden_state) - # # shape = [*, hidden_size, grid ** 2] - # visual_outputs_pre = visual_outputs_pre.reshape(visual_outputs_pre.shape[0], visual_outputs_pre.shape[1], -1) - # # shape = [*, grid ** 2, hidden_size] - # visual_outputs_pre = visual_outputs_pre.permute(0, 2, 1) - # embeddings_to = self.class_embedding.to(visual_outputs_pre.dtype) + torch.zeros( - # visual_outputs_pre.shape[0], - # 1, - # visual_outputs_pre.shape[-1], - # dtype=visual_outputs_pre.dtype, - # device=visual_outputs_pre.device, - # ) - # # shape = [*, grid ** 2 + 1, hidden_size] - # visual_outputs_pre = torch.cat([embeddings_to, visual_outputs_pre], dim=1) - # visual_outputs_pre = visual_outputs_pre + self.positional_embedding.to(visual_outputs_pre.dtype) hidden_states = self.embeddings(pixel_values) hidden_states = self.ln_pre(hidden_states) # NLD -> LND