diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py
index aef9b8cebec9..e7df05785886 100644
--- a/src/transformers/models/blip/modeling_blip.py
+++ b/src/transformers/models/blip/modeling_blip.py
@@ -233,7 +233,6 @@ def __init__(self, config: BlipVisionConfig):
 
         self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
 
-    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
         This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
@@ -245,14 +244,14 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width:
         """
 
         num_patches = embeddings.shape[1] - 1
-        num_positions = self.position_embeddings.shape[1] - 1
+        num_positions = self.position_embedding.shape[1] - 1
 
         # always interpolate when tracing to ensure the exported model works for dynamic input shapes
         if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
-            return self.position_embeddings
+            return self.position_embedding
 
-        class_pos_embed = self.position_embeddings[:, :1]
-        patch_pos_embed = self.position_embeddings[:, 1:]
+        class_pos_embed = self.position_embedding[:, :1]
+        patch_pos_embed = self.position_embedding[:, 1:]
 
         dim = embeddings.shape[-1]
 
diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py
index 0b33572a689c..4b0ed4f71d9c 100644
--- a/src/transformers/models/blip_2/modeling_blip_2.py
+++ b/src/transformers/models/blip_2/modeling_blip_2.py
@@ -200,7 +200,6 @@ def __init__(self, config: Blip2VisionConfig):
 
         self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
 
-    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
         This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
@@ -212,14 +211,14 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width:
         """
 
         num_patches = embeddings.shape[1] - 1
-        num_positions = self.position_embeddings.shape[1] - 1
+        num_positions = self.position_embedding.shape[1] - 1
 
         # always interpolate when tracing to ensure the exported model works for dynamic input shapes
         if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
-            return self.position_embeddings
+            return self.position_embedding
 
-        class_pos_embed = self.position_embeddings[:, :1]
-        patch_pos_embed = self.position_embeddings[:, 1:]
+        class_pos_embed = self.position_embedding[:, :1]
+        patch_pos_embed = self.position_embedding[:, 1:]
 
         dim = embeddings.shape[-1]
 
diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py
index dff897f59d2d..de4e84b82f83 100644
--- a/src/transformers/models/instructblip/modeling_instructblip.py
+++ b/src/transformers/models/instructblip/modeling_instructblip.py
@@ -104,7 +104,6 @@ def __init__(self, config: InstructBlipVisionConfig):
 
         self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
 
-    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
         This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
@@ -116,14 +115,14 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width:
         """
 
         num_patches = embeddings.shape[1] - 1
-        num_positions = self.position_embeddings.shape[1] - 1
+        num_positions = self.position_embedding.shape[1] - 1
 
         # always interpolate when tracing to ensure the exported model works for dynamic input shapes
         if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
-            return self.position_embeddings
+            return self.position_embedding
 
-        class_pos_embed = self.position_embeddings[:, :1]
-        patch_pos_embed = self.position_embeddings[:, 1:]
+        class_pos_embed = self.position_embedding[:, :1]
+        patch_pos_embed = self.position_embedding[:, 1:]
 
         dim = embeddings.shape[-1]
 
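All of the modeling hunks in this patch (the three above plus the instructblipvideo one below) apply the same one-line fix per model: __init__ defines the attribute as self.position_embedding (singular), while interpolate_pos_encoding referenced self.position_embeddings (plural), so any call that actually reached the method raised an AttributeError. The stale "Copied from" markers are dropped because the fixed methods no longer mirror the ViT original, whose attribute really is plural. A minimal standalone repro of the failure mode (the sizes are illustrative, not taken from any checkpoint config):

import torch
from torch import nn

class VisionEmbeddings(nn.Module):
    # mirrors the __init__ line shown in the hunks: the attribute name is singular
    def __init__(self, num_positions: int = 577, embed_dim: int = 768):
        super().__init__()
        self.position_embedding = nn.Parameter(torch.randn(1, num_positions, embed_dim))

emb = VisionEmbeddings()
print(emb.position_embedding.shape)  # torch.Size([1, 577, 768])
# the pre-patch code accessed the plural name and failed with:
# AttributeError: 'VisionEmbeddings' object has no attribute 'position_embeddings'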
diff --git a/src/transformers/models/instructblip/processing_instructblip.py b/src/transformers/models/instructblip/processing_instructblip.py
index f6d35c1e6f72..dc6c9deaf177 100644
--- a/src/transformers/models/instructblip/processing_instructblip.py
+++ b/src/transformers/models/instructblip/processing_instructblip.py
@@ -122,8 +122,10 @@ def __call__(
         elif not isinstance(text, list) and not isinstance(text[0], str):
             raise ValueError("Invalid input text. Please provide a string, or a list of strings")
 
-        _text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
-
+        # we have to concatenate lists - so we keep track of return_tensors here
+        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+        _text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
+        output_kwargs["text_kwargs"]["return_tensors"] = return_tensors
         # if we know how many query tokens, expand text inside processor. We need this hacky manipulation
         # because BLIP expects image tokens to be at the beginning even before BOS token
         if self.num_query_tokens is not None and images is not None:
@@ -145,9 +147,7 @@ def __call__(
         )
 
         # cast to desired return tensors type after concatenating
-        text_encoding = BatchEncoding(
-            text_encoding, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")
-        )
+        text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors)
         encoding.update(text_encoding)
 
         qformer_text_encoding = self.qformer_tokenizer(text, **output_kwargs["text_kwargs"])
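The processor change above reorders when tensor conversion happens: return_tensors is popped before tokenization so the tokenizer hands back plain Python lists, the expanded image/query tokens can then be prepended to each sequence by list concatenation, and the requested tensor type is applied once at the end through BatchEncoding. A rough sketch of the idea, using a stand-in tokenizer checkpoint and made-up placeholder token ids:

from transformers import AutoTokenizer, BatchEncoding

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # stand-in checkpoint

# tokenize without tensor conversion: values stay plain Python lists
enc = tokenizer(["a photo of a cat", "two dogs"], padding=True, return_tensors=None)

image_token_ids = [3, 3, 3]  # made-up ids standing in for the expanded query tokens
data = {
    "input_ids": [image_token_ids + ids for ids in enc["input_ids"]],
    "attention_mask": [[1] * len(image_token_ids) + mask for mask in enc["attention_mask"]],
}

# cast to the desired tensor type only after concatenating, as the patch does
batch = BatchEncoding(data, tensor_type="pt")
print(batch["input_ids"].shape)  # e.g. torch.Size([2, 11])

Had the tokenizer been called with return_tensors="pt" up front, the prepend would have required per-key torch.cat gymnastics instead of simple list concatenation.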
diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py
index e165fb411af8..0808aa58b855 100644
--- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py
+++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py
@@ -111,7 +111,6 @@ def __init__(self, config: InstructBlipVideoVisionConfig):
 
         self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
 
-    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
         This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
@@ -123,14 +122,14 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width:
         """
 
         num_patches = embeddings.shape[1] - 1
-        num_positions = self.position_embeddings.shape[1] - 1
+        num_positions = self.position_embedding.shape[1] - 1
 
         # always interpolate when tracing to ensure the exported model works for dynamic input shapes
         if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
-            return self.position_embeddings
+            return self.position_embedding
 
-        class_pos_embed = self.position_embeddings[:, :1]
-        patch_pos_embed = self.position_embeddings[:, 1:]
+        class_pos_embed = self.position_embedding[:, :1]
+        patch_pos_embed = self.position_embedding[:, 1:]
 
         dim = embeddings.shape[-1]
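The modeling hunks all cut off at dim = embeddings.shape[-1], but the removed "Copied from" markers point at ViTEmbeddings.interpolate_pos_encoding as the origin of these methods. For context, here is a sketch of what the remainder computes, following the ViT version; the default patch_size and the interpolation options below are assumptions for illustration, not read from these files:

import torch
from torch import nn

def interpolate_patch_grid(position_embedding: torch.Tensor, height: int, width: int, patch_size: int = 16) -> torch.Tensor:
    # split off the class token, exactly as the hunks above do
    class_pos_embed = position_embedding[:, :1]
    patch_pos_embed = position_embedding[:, 1:]
    dim = position_embedding.shape[-1]

    # the pretrained table covers a square grid of patches
    num_positions = position_embedding.shape[1] - 1
    side = int(num_positions**0.5)

    # lay the flat sequence back out as a 2D grid of per-patch embeddings
    patch_pos_embed = patch_pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2)

    # bicubic-resize the grid to the shape implied by the new resolution
    new_size = (height // patch_size, width // patch_size)
    patch_pos_embed = nn.functional.interpolate(patch_pos_embed, size=new_size, mode="bicubic", align_corners=False)

    # flatten back and re-attach the class token
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
    return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

# e.g. a 577-position table (24x24 grid + class token) stretched for a 480x480 input
pos = torch.randn(1, 577, 768)
print(interpolate_patch_grid(pos, 480, 480).shape)  # torch.Size([1, 901, 768])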