diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index dfccab53ea30..0acf31eb2d99 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -1010,15 +1010,52 @@ def __init__(self, config: AltCLIPVisionConfig): self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + position_embeddings = self.position_embedding.weight.unsqueeze(0) + num_patches = embeddings.shape[1] - 1 + num_positions = position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return position_embeddings + class_pos_embed = position_embeddings[:, 0] + patch_pos_embed = position_embeddings[:, 1:] + dim = embeddings.shape[-1] + height = height // self.config.patch_size + width = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + 0.1, width + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]: + raise ValueError("Width or height does not match with the interpolated position embeddings") + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding(self.position_ids) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) return embeddings @@ -1099,6 +1136,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: @@ -1113,7 +1151,7 @@ def forward( if pixel_values is None: raise ValueError("You have to specify pixel_values") - hidden_states = 
self.embeddings(pixel_values) + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) hidden_states = self.pre_layrnorm(hidden_states) encoder_outputs = self.encoder( @@ -1158,6 +1196,7 @@ def forward( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" @@ -1188,6 +1227,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1548,6 +1588,7 @@ def get_image_features( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> torch.FloatTensor: r""" @@ -1580,6 +1621,7 @@ def get_image_features( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1600,6 +1642,7 @@ def forward( return_loss: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, AltCLIPOutput]: r""" @@ -1644,6 +1687,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 91cbda9b72ed..f9c3c62ecaa9 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -276,15 +276,52 @@ def __init__(self, config: BridgeTowerVisionConfig): self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. 
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + position_embeddings = self.position_embedding.weight.unsqueeze(0) + num_patches = embeddings.shape[1] - 1 + num_positions = position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return position_embeddings + class_pos_embed = position_embeddings[:, 0] + patch_pos_embed = position_embeddings[:, 1:] + dim = embeddings.shape[-1] + height = height // self.config.patch_size + width = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + 0.1, width + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]: + raise ValueError("Width or height does not match with the interpolated position embeddings") + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding(self.position_ids) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) return embeddings @@ -302,8 +339,13 @@ def __init__(self, config): [nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) for _ in range(config.num_hidden_layers)] ) - def forward(self, pixel_values: torch.Tensor, attention_mask): - hidden_states = self.embeddings(pixel_values) + def forward( + self, + pixel_values: torch.Tensor, + attention_mask, + interpolate_pos_encoding: bool = False, + ): + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding) hidden_states = self.ln_pre(hidden_states) # NLD -> LND hidden_states = hidden_states.permute(1, 0, 2) @@ -324,8 +366,12 @@ def forward(self, pixel_values: torch.Tensor, attention_mask): hidden_states = torch.stack(hidden_states_stack, dim=0) return hidden_states - def forward_pre(self, pixel_values: torch.Tensor): - hidden_states = self.embeddings(pixel_values) + def forward_pre( + self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: bool = False, + ): + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) hidden_states = self.ln_pre(hidden_states) # NLD -> LND hidden_states = hidden_states.permute(1, 0, 2) @@ -1015,8 +1061,8 @@ def __init__(self, config): def dtype(self): return self.visual.embeddings.patch_embedding.weight.dtype - def forward(self, image, image_mask=None): - 
return self.visual(image.type(self.dtype), image_mask) + def forward(self, image, image_mask=None, interpolate_pos_encoding=False): + return self.visual(image.type(self.dtype), image_mask, interpolate_pos_encoding) class BridgeTowerTextModel(BridgeTowerPreTrainedModel): @@ -1280,6 +1326,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple[torch.Tensor], BridgeTowerModelOutput]: r""" output_hidden_states (`bool`, *optional*): @@ -1352,7 +1399,9 @@ def forward( all_hidden_states_text += (text_embeds,) if image_embeds is None: - image_embeds = self.vision_model.visual.forward_pre(pixel_values.type(self.vision_model.dtype)) + image_embeds = self.vision_model.visual.forward_pre( + pixel_values.type(self.vision_model.dtype), interpolate_pos_encoding=interpolate_pos_encoding + ) else: # Permute as BridgeTowerResidualAttention has batch_first=True image_embeds = image_embeds.permute(1, 0, 2) diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 573a39fb1c29..f69f41795ce1 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -189,15 +189,52 @@ def __init__(self, config: ChineseCLIPVisionConfig): self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. 
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + position_embeddings = self.position_embedding.weight.unsqueeze(0) + num_patches = embeddings.shape[1] - 1 + num_positions = position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return position_embeddings + class_pos_embed = position_embeddings[:, 0] + patch_pos_embed = position_embeddings[:, 1:] + dim = embeddings.shape[-1] + height = height // self.config.patch_size + width = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + 0.1, width + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]: + raise ValueError("Width or height does not match with the interpolated position embeddings") + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding(self.position_ids) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) return embeddings @@ -799,6 +836,8 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -814,6 +853,8 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
""" @@ -1053,6 +1094,7 @@ def forward( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" @@ -1067,7 +1109,7 @@ def forward( if pixel_values is None: raise ValueError("You have to specify pixel_values") - hidden_states = self.embeddings(pixel_values) + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) hidden_states = self.pre_layrnorm(hidden_states) encoder_outputs = self.encoder( @@ -1300,6 +1342,7 @@ def forward( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" @@ -1330,6 +1373,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1426,6 +1470,7 @@ def get_image_features( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> torch.FloatTensor: r""" @@ -1462,6 +1507,7 @@ def get_image_features( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1482,6 +1528,7 @@ def forward( return_loss: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, ChineseCLIPOutput]: r""" @@ -1517,6 +1564,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index c2fc3424a9ce..eaccac458399 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -15,6 +15,7 @@ """ PyTorch CLIP model.""" +import math from dataclasses import dataclass from typing import Any, Optional, Tuple, Union @@ -179,15 +180,52 @@ def __init__(self, config: CLIPVisionConfig): self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. 
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + position_embeddings = self.position_embedding.weight.unsqueeze(0) + num_patches = embeddings.shape[1] - 1 + num_positions = position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return position_embeddings + class_pos_embed = position_embeddings[:, 0] + patch_pos_embed = position_embeddings[:, 1:] + dim = embeddings.shape[-1] + height = height // self.config.patch_size + width = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + 0.1, width + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]: + raise ValueError("Width or height does not match with the interpolated position embeddings") + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding(self.position_ids) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) return embeddings @@ -518,8 +556,11 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ CLIP_INPUTS_DOCSTRING = r""" @@ -555,6 +596,8 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
""" @@ -833,6 +876,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: @@ -847,7 +891,7 @@ def forward( if pixel_values is None: raise ValueError("You have to specify pixel_values") - hidden_states = self.embeddings(pixel_values) + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) hidden_states = self.pre_layrnorm(hidden_states) encoder_outputs = self.encoder( @@ -898,6 +942,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: @@ -928,6 +973,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) @@ -1022,6 +1068,7 @@ def get_image_features( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> torch.FloatTensor: r""" Returns: @@ -1057,6 +1104,7 @@ def get_image_features( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) pooled_output = vision_outputs[1] # pooled_output @@ -1076,6 +1124,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, CLIPOutput]: r""" Returns: @@ -1113,6 +1162,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) text_outputs = self.text_model( @@ -1270,6 +1320,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, CLIPVisionModelOutput]: r""" Returns: @@ -1299,6 +1350,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) pooled_output = vision_outputs[1] # pooled_output diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index 00dcecff2d26..4cfc8870ae86 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -819,7 +819,6 @@ def __init__(self, config: CLIPSegVisionConfig): @add_start_docstrings_to_model_forward(CLIPSEG_VISION_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegVisionConfig) - # Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward def forward( self, pixel_values: Optional[torch.FloatTensor] = None, diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 87b0cb8073b5..f7ed9cab6d33 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -605,15 +605,52 @@ def __init__(self, config: GitVisionConfig): self.position_embedding = 
nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + position_embeddings = self.position_embedding.weight.unsqueeze(0) + num_patches = embeddings.shape[1] - 1 + num_positions = position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return position_embeddings + class_pos_embed = position_embeddings[:, 0] + patch_pos_embed = position_embeddings[:, 1:] + dim = embeddings.shape[-1] + height = height // self.config.patch_size + width = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + 0.1, width + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]: + raise ValueError("Width or height does not match with the interpolated position embeddings") + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding(self.position_ids) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) return embeddings diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index 161ebbf95c1f..b81417f35eeb 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -121,6 +121,8 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
""" @@ -259,6 +261,8 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -401,15 +405,52 @@ def __init__(self, config: Kosmos2VisionConfig): self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + position_embeddings = self.position_embedding.weight.unsqueeze(0) + num_patches = embeddings.shape[1] - 1 + num_positions = position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return position_embeddings + class_pos_embed = position_embeddings[:, 0] + patch_pos_embed = position_embeddings[:, 1:] + dim = embeddings.shape[-1] + height = height // self.config.patch_size + width = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + 0.1, width + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]: + raise ValueError("Width or height does not match with the interpolated position embeddings") + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding(self.position_ids) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) return embeddings @@ -701,6 +742,7 @@ def forward( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, 
return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -712,7 +754,7 @@ def forward( if pixel_values is None: raise ValueError("You have to specify pixel_values") - hidden_states = self.embeddings(pixel_values) + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) hidden_states = self.pre_layrnorm(hidden_states) encoder_outputs = self.encoder( @@ -1444,6 +1486,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: @@ -1453,6 +1496,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1769,6 +1813,7 @@ def forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, Kosmos2ModelOutput]: r""" @@ -1820,6 +1865,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`. diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index 092ea9476173..7508a7743d52 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -15,6 +15,7 @@ """ PyTorch X-CLIP model.""" +import math from copy import copy from dataclasses import dataclass from typing import Any, Optional, Tuple, Union @@ -122,15 +123,52 @@ def __init__(self, config: XCLIPVisionConfig): self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. 
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + position_embeddings = self.position_embedding.weight.unsqueeze(0) + num_patches = embeddings.shape[1] - 1 + num_positions = position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return position_embeddings + class_pos_embed = position_embeddings[:, 0] + patch_pos_embed = position_embeddings[:, 1:] + dim = embeddings.shape[-1] + height = height // self.config.patch_size + width = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + 0.1, width + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]: + raise ValueError("Width or height does not match with the interpolated position embeddings") + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding(self.position_ids) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) return embeddings @@ -568,6 +606,8 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -605,6 +645,8 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
""" @@ -955,6 +997,7 @@ def forward( pixel_values: torch.FloatTensor, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" @@ -967,7 +1010,7 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - hidden_states = self.embeddings(pixel_values) + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) hidden_states = self.pre_layernorm(hidden_states) encoder_outputs = self.encoder( @@ -1457,6 +1500,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, XCLIPOutput]: r""" Returns: @@ -1556,6 +1600,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) diff --git a/tests/models/altclip/test_modeling_altclip.py b/tests/models/altclip/test_modeling_altclip.py index 8c5be789c02b..64bd2a0c8d1d 100755 --- a/tests/models/altclip/test_modeling_altclip.py +++ b/tests/models/altclip/test_modeling_altclip.py @@ -594,3 +594,37 @@ def test_inference(self): expected_probs = torch.tensor([[9.9942e-01, 5.7805e-04]], device=torch_device) self.assertTrue(torch.allclose(probs, expected_probs, atol=5e-3)) + + @slow + def test_inference_interpolate_pos_encoding(self): + # ViT models have an `interpolate_pos_encoding` argument in their forward method, + # allowing to interpolate the pre-trained position embeddings in order to use + # the model on higher resolutions. The DINO model by Facebook AI leverages this + # to visualize self-attention on higher resolution images. 
+ model_name = "BAAI/AltCLIP" + model = AltCLIPModel.from_pretrained(model_name).to(torch_device) + + image_processor = AltCLIPProcessor.from_pretrained(model_name) + + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + inputs = image_processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs, interpolate_pos_encoding=True) + + # verify the logits + expected_shape = torch.Size((1, 257, 1024)) + + self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape) + + expected_slice = torch.tensor( + [[-0.5297, -0.7713, 0.4655], [0.8688, 0.1690, 0.6678], [1.1742, -0.7551, 0.0396]] + ).to(torch_device) + + self.assertTrue( + torch.allclose(outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4) + ) diff --git a/tests/models/bridgetower/test_modeling_bridgetower.py b/tests/models/bridgetower/test_modeling_bridgetower.py index c9ce8f076fc3..d6ae6ecae1ab 100644 --- a/tests/models/bridgetower/test_modeling_bridgetower.py +++ b/tests/models/bridgetower/test_modeling_bridgetower.py @@ -656,3 +656,32 @@ def test_training(self): for name, param in model.named_parameters(): if self._is_layer_used(model_class, name): self.assertIsNotNone(param.grad, f"Gradients should not be None - got {param.grad} for {name}") + + @slow + def test_inference_interpolate_pos_encoding(self): + # ViT models have an `interpolate_pos_encoding` argument in their forward method, + # allowing to interpolate the pre-trained position embeddings in order to use + # the model on higher resolutions. The DINO model by Facebook AI leverages this + # to visualize self-attention on higher resolution images. + model_name = "BridgeTower/bridgetower-base" + model = BridgeTowerModel.from_pretrained(model_name).to(torch_device) + + image_processor = BridgeTowerProcessor.from_pretrained(model_name) + + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + inputs = image_processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs, interpolate_pos_encoding=True) + + # verify the logits + expected_shape = torch.Size((1, 325, 768)) + + self.assertEqual(outputs.image_features.shape, expected_shape) + + expected_slice = torch.tensor( + [[0.3433, 0.4557, -0.5287], [-0.7111, 0.6576, -1.0850], [-0.2122, 0.2021, -0.0536]] + ).to(torch_device) + + self.assertTrue(torch.allclose(outputs.image_features[0, :3, :3], expected_slice, atol=1e-4)) diff --git a/tests/models/chinese_clip/test_modeling_chinese_clip.py b/tests/models/chinese_clip/test_modeling_chinese_clip.py index 8ee9028eca26..9396f39c36ef 100644 --- a/tests/models/chinese_clip/test_modeling_chinese_clip.py +++ b/tests/models/chinese_clip/test_modeling_chinese_clip.py @@ -736,3 +736,34 @@ def test_inference(self): expected_probs = torch.tensor([[1.2686e-03, 5.4499e-02, 6.7968e-04, 9.4355e-01]], device=torch_device) self.assertTrue(torch.allclose(probs, expected_probs, atol=5e-3)) + + @slow + def test_inference_interpolate_pos_encoding(self): + # ViT models have an `interpolate_pos_encoding` argument in their forward method, + # allowing to interpolate the pre-trained position embeddings in order to use + # the model on higher resolutions.
The DINO model by Facebook AI leverages this + # to visualize self-attention on higher resolution images. + model_name = "OFA-Sys/chinese-clip-vit-base-patch16" + model = ChineseCLIPModel.from_pretrained(model_name).to(torch_device) + + image_processor = ChineseCLIPProcessor.from_pretrained(model_name) + + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + inputs = image_processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs, interpolate_pos_encoding=True) + + # verify the logits + expected_shape = torch.Size((1, 197, 768)) + + self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape) + + expected_slice = torch.tensor( + [[-0.3374, 0.3212, -0.1293], [-0.2208, -0.6150, 0.7010], [-0.1901, -0.6576, 0.4843]] + ).to(torch_device) + + self.assertTrue( + torch.allclose(outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4) + ) diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py index 16c8f47b782f..65ace71aee9e 100644 --- a/tests/models/clip/test_modeling_clip.py +++ b/tests/models/clip/test_modeling_clip.py @@ -842,3 +842,33 @@ def test_inference(self): expected_logits = torch.tensor([[24.5701, 19.3049]], device=torch_device) self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3)) + + @slow + def test_inference_interpolate_pos_encoding(self): + # ViT models have an `interpolate_pos_encoding` argument in their forward method, + # allowing to interpolate the pre-trained position embeddings in order to use + # the model on higher resolutions. The DINO model by Facebook AI leverages this + # to visualize self-attention on higher resolution images. + model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(torch_device) + + image_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", size=480) + + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + inputs = image_processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs, interpolate_pos_encoding=True) + + # verify the logits + expected_shape = torch.Size((1, 50, 768)) + + self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape) + + expected_slice = torch.tensor( + [[-0.0966, 0.3521, -0.3485], [0.5785, 0.8967, 0.3586], [0.2314, 0.3896, 0.2557]] + ).to(torch_device) + + self.assertTrue( + torch.allclose(outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4) + ) diff --git a/tests/models/kosmos2/test_modeling_kosmos2.py b/tests/models/kosmos2/test_modeling_kosmos2.py index 9bc95b8bd44c..2c66bdf207a5 100644 --- a/tests/models/kosmos2/test_modeling_kosmos2.py +++ b/tests/models/kosmos2/test_modeling_kosmos2.py @@ -761,3 +761,33 @@ def test_snowman_image_captioning_batch(self): self.assertEqual(processed_text[0], EXPECTED_PROCESSED_TEXT_0) self.assertEqual(all_final_text[0], EXPECTED_FINAL_TEXT_0) self.assertListEqual(all_entities[0], EXPECTED_ENTITIES_0) + + @slow + def test_inference_interpolate_pos_encoding(self): + # ViT models have an `interpolate_pos_encoding` argument in their forward method, + # allowing to interpolate the pre-trained position embeddings in order to use + # the model on higher resolutions. 
The DINO model by Facebook AI leverages this + # to visualize self-attention on higher resolution images. + model = Kosmos2Model.from_pretrained("microsoft/kosmos-2-patch14-224").to(torch_device) + + processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224", padding_side="left") + + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + inputs = processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs, interpolate_pos_encoding=True) + + # verify the logits + expected_shape = torch.Size((1, 257, 1024)) + + self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape) + + expected_slice = torch.tensor( + [[1.4228, -1.9611, 3.8449], [3.4988, 2.0516, 0.3597], [3.1699, 0.2604, -0.4210]] + ).to(torch_device) + + self.assertTrue( + torch.allclose(outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4) + ) diff --git a/tests/models/x_clip/test_modeling_x_clip.py b/tests/models/x_clip/test_modeling_x_clip.py index fc5c1679a659..a42a8c3acf85 100644 --- a/tests/models/x_clip/test_modeling_x_clip.py +++ b/tests/models/x_clip/test_modeling_x_clip.py @@ -728,3 +728,33 @@ def test_inference(self): expected_logits = torch.tensor([[14.0181, 20.2771, 14.4776]], device=torch_device) self.assertTrue(torch.allclose(outputs.logits_per_video, expected_logits, atol=1e-3)) + + @slow + def test_inference_interpolate_pos_encoding(self): + # ViT models have an `interpolate_pos_encoding` argument in their forward method, + # allowing to interpolate the pre-trained position embeddings in order to use + # the model on higher resolutions. The DINO model by Facebook AI leverages this + # to visualize self-attention on higher resolution images. + model = XCLIPModel.from_pretrained("microsoft/xclip-base-patch32").to(torch_device) + + image_processor = XCLIPProcessor.from_pretrained("microsoft/xclip-base-patch32", size=480) + + video = prepare_video() + inputs = image_processor(text="what's in the video", videos=video, return_tensors="pt").to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs, interpolate_pos_encoding=True) + + # verify the logits + expected_shape = torch.Size((8, 50, 768)) + + self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape) + + expected_slice = torch.tensor( + [[0.1806, 0.3649, -0.0850], [0.0210, 0.3411, -0.0637], [0.2307, 0.3106, -0.2027]] + ).to(torch_device) + + self.assertTrue( + torch.allclose(outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4) + )
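For reference, here is a minimal usage sketch of the new `interpolate_pos_encoding` flag, following the pattern of the tests above. The image path, caption, and the 448x448 target resolution are illustrative assumptions rather than values taken from this diff; only the `interpolate_pos_encoding=True` argument and the CLIP checkpoint name come from the changes above.

```python
# Minimal sketch (assumptions: "cats.png" and the 448x448 size are hypothetical examples).
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

image = Image.open("cats.png")  # hypothetical local image
inputs = processor(text=["two cats on a couch"], images=image, return_tensors="pt")

# Emulate a higher-resolution input: upsample the preprocessed 224x224 pixels to 448x448,
# which gives (448 // 32) ** 2 = 196 patch tokens instead of the 49 covered by the
# pre-trained position-embedding table.
inputs["pixel_values"] = torch.nn.functional.interpolate(
    inputs["pixel_values"], size=(448, 448), mode="bicubic", align_corners=False
)

with torch.no_grad():
    # Without interpolate_pos_encoding=True this forward pass fails with a shape mismatch
    # between the 1 + 196 embedded tokens and the 1 + 49 stored position embeddings.
    outputs = model(**inputs, interpolate_pos_encoding=True)

print(outputs.logits_per_image.shape)                       # torch.Size([1, 1])
print(outputs.vision_model_output.last_hidden_state.shape)  # torch.Size([1, 197, 768])
```

The flag leaves checkpoints untouched: the pre-trained 7x7 grid of patch position embeddings is resized bicubically at run time to match the new patch grid, while the class-token embedding is reused as-is.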