diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py
index 8a9597892c60..73d715e5941d 100644
--- a/src/transformers/models/llava/processing_llava.py
+++ b/src/transformers/models/llava/processing_llava.py
@@ -71,12 +71,14 @@ def __init__(
         tokenizer=None,
         patch_size=None,
         vision_feature_select_strategy=None,
+        vision_feature_use_cls=True,
         chat_template=None,
         image_token="<image>",  # set the default and let users change if they have peculiar special tokens in rare cases
         **kwargs,
     ):
         self.patch_size = patch_size
         self.vision_feature_select_strategy = vision_feature_select_strategy
+        self.vision_feature_use_cls = vision_feature_use_cls
         self.image_token = image_token
         super().__init__(image_processor, tokenizer, chat_template=chat_template)
 
@@ -147,7 +149,9 @@ def __call__(
                 # Replace the image token with the expanded image token sequence
                 pixel_values = image_inputs["pixel_values"]
                 height, width = get_image_size(to_numpy_array(pixel_values[0]))
-                num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1
+                num_image_tokens = (height // self.patch_size) * (width // self.patch_size)
+                if self.vision_feature_use_cls:
+                    num_image_tokens += 1
                 if self.vision_feature_select_strategy == "default":
                     num_image_tokens -= 1
 
diff --git a/src/transformers/models/llava_next/processing_llava_next.py b/src/transformers/models/llava_next/processing_llava_next.py
index ce11be6d6309..ac4fb9b0cd7a 100644
--- a/src/transformers/models/llava_next/processing_llava_next.py
+++ b/src/transformers/models/llava_next/processing_llava_next.py
@@ -74,12 +74,14 @@ def __init__(
         tokenizer=None,
         patch_size=None,
         vision_feature_select_strategy=None,
+        vision_feature_use_cls=True,
         chat_template=None,
         image_token="<image>",  # set the default and let users change if they have peculiar special tokens in rare cases
         **kwargs,
     ):
         self.patch_size = patch_size
         self.vision_feature_select_strategy = vision_feature_select_strategy
+        self.vision_feature_use_cls = vision_feature_use_cls
         self.image_token = image_token
         super().__init__(image_processor, tokenizer, chat_template=chat_template)
 
@@ -177,8 +179,11 @@ def _get_number_of_features(self, orig_height: int, orig_width: int, height: int
         unpadded_features, newline_features = self._get_unpadded_features(
             orig_height, orig_width, patches_height, patches_width, scale_height, scale_width
         )
+        base_features = patches_height * patches_width
+
         # The base patch covers the entire image (+1 for the CLS)
-        base_features = patches_height * patches_width + 1
+        if self.vision_feature_use_cls:
+            base_features += 1
 
         num_image_tokens = unpadded_features + newline_features + base_features
         return num_image_tokens
diff --git a/src/transformers/models/llava_next_video/processing_llava_next_video.py b/src/transformers/models/llava_next_video/processing_llava_next_video.py
index e0e4534e42b5..cd7dc7598838 100644
--- a/src/transformers/models/llava_next_video/processing_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/processing_llava_next_video.py
@@ -221,8 +221,11 @@ def _get_number_of_features(self, orig_height: int, orig_width: int, height: int
         unpadded_features, newline_features = self._get_unpadded_features(
             orig_height, orig_width, patches_height, patches_width, scale_height, scale_width
         )
+        base_features = patches_height * patches_width
+
         # The base patch covers the entire image (+1 for the CLS)
-        base_features = patches_height * patches_width + 1
+        if self.vision_feature_use_cls:
+            base_features += 1
 
         num_image_tokens = unpadded_features + newline_features + base_features
         return num_image_tokens
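
For intuition, the count the first hunk computes can be reproduced outside the processor. Below is a minimal standalone sketch of that arithmetic; the helper name num_image_tokens and the defaults (a 336 px input and patch size 14, i.e. a CLIP ViT-L/14-336 backbone) are illustrative assumptions, not values taken from this diff.

def num_image_tokens(
    height=336,
    width=336,
    patch_size=14,
    vision_feature_use_cls=True,
    vision_feature_select_strategy="default",
):
    # One placeholder token per non-overlapping vision patch.
    n = (height // patch_size) * (width // patch_size)
    # Backbones with a CLS token emit one extra feature...
    if vision_feature_use_cls:
        n += 1
    # ...which the "default" select strategy then drops again.
    if vision_feature_select_strategy == "default":
        n -= 1
    return n

print(num_image_tokens())                                       # 576
print(num_image_tokens(vision_feature_select_strategy="full"))  # 577
print(num_image_tokens(vision_feature_use_cls=False,
                       vision_feature_select_strategy="full"))  # 576

The new flag presumably matters for vision backbones that emit no CLS feature (e.g. SigLIP-based variants): with the previous unconditional +1, the "full" strategy would overcount such a model's image tokens by one, leaving the placeholder expansion out of sync with the actual number of projected features.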