diff --git a/docs/source/en/model_doc/llava.md b/docs/source/en/model_doc/llava.md index ee7d9bbd1af9..1d0bdd49b91c 100644 --- a/docs/source/en/model_doc/llava.md +++ b/docs/source/en/model_doc/llava.md @@ -66,6 +66,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h - A [similar notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/LLaVa/Inference_with_LLaVa_for_multimodal_generation.ipynb) showcasing batched inference. 🌎 +## LlavaImageProcessor + +[[autodoc]] LlavaImageProcessor + - preprocess + ## LlavaConfig [[autodoc]] LlavaConfig diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 027cf495466c..384dcf29684d 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1303,6 +1303,7 @@ _import_structure["models.layoutlmv2"].extend(["LayoutLMv2FeatureExtractor", "LayoutLMv2ImageProcessor"]) _import_structure["models.layoutlmv3"].extend(["LayoutLMv3FeatureExtractor", "LayoutLMv3ImageProcessor"]) _import_structure["models.levit"].extend(["LevitFeatureExtractor", "LevitImageProcessor"]) + _import_structure["models.llava"].append("LlavaImageProcessor") _import_structure["models.mask2former"].append("Mask2FormerImageProcessor") _import_structure["models.maskformer"].extend(["MaskFormerFeatureExtractor", "MaskFormerImageProcessor"]) _import_structure["models.mobilenet_v1"].extend(["MobileNetV1FeatureExtractor", "MobileNetV1ImageProcessor"]) @@ -6071,6 +6072,7 @@ LayoutLMv3ImageProcessor, ) from .models.levit import LevitFeatureExtractor, LevitImageProcessor + from .models.llava import LlavaImageProcessor from .models.mask2former import Mask2FormerImageProcessor from .models.maskformer import ( MaskFormerFeatureExtractor, diff --git a/src/transformers/models/llava/__init__.py b/src/transformers/models/llava/__init__.py index 11aedf9476cf..53825ca84651 100644 --- a/src/transformers/models/llava/__init__.py +++ b/src/transformers/models/llava/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available _import_structure = {"configuration_llava": ["LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LlavaConfig"]} @@ -32,6 +32,14 @@ ] _import_structure["processing_llava"] = ["LlavaProcessor"] +try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["image_processing_llava"] = ["LlavaImageProcessor"] + if TYPE_CHECKING: from .configuration_llava import LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlavaConfig @@ -49,6 +57,14 @@ ) from .processing_llava import LlavaProcessor + try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .image_processing_llava import LlavaImageProcessor + else: import sys diff --git a/src/transformers/models/llava/configuration_llava.py b/src/transformers/models/llava/configuration_llava.py index 1f174bc1b423..4f0d6cf4112f 100644 --- a/src/transformers/models/llava/configuration_llava.py +++ b/src/transformers/models/llava/configuration_llava.py @@ -53,7 +53,12 @@ class LlavaConfig(PretrainedConfig): The index of the layer to select the vision feature. vocab_size (`int`, *optional*, defaults to 32000): Vocabulary size of the Llava model. 
Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`~LlavaForConditionalGeneration`] + `inputs_ids` passed when calling [`~LlavaForConditionalGeneration`]. + use_image_newline_parameter (`bool`, *optional*, defaults to `False`): + Whether to add a trainable parameter for the image newline token. + image_grid_pinpoints (`List`, *optional*): + A list of possible resolutions to use for processing high resolution images. Each item in the list should be a tuple or list + of the form `(height, width)`. Only used by the newer LLaVa 1.6 variant. Example: @@ -89,14 +94,25 @@ def __init__( vision_feature_select_strategy="default", vision_feature_layer=-2, vocab_size=32000, + use_image_newline_parameter=False, + image_grid_pinpoints=None, **kwargs, ): self.ignore_index = ignore_index self.image_token_index = image_token_index self.projector_hidden_act = projector_hidden_act + + if vision_feature_select_strategy not in ["default", "full"]: + raise ValueError( + "vision_feature_select_strategy should be one of 'default', 'full'." + f"Got: {vision_feature_select_strategy}" + ) + self.vision_feature_select_strategy = vision_feature_select_strategy self.vision_feature_layer = vision_feature_layer self.vocab_size = vocab_size + self.use_image_newline_parameter = use_image_newline_parameter + self.image_grid_pinpoints = image_grid_pinpoints self.vision_config = vision_config diff --git a/src/transformers/models/llava/convert_llava_1_6_to_hf.py b/src/transformers/models/llava/convert_llava_1_6_to_hf.py new file mode 100644 index 000000000000..b00cc3af6c0f --- /dev/null +++ b/src/transformers/models/llava/convert_llava_1_6_to_hf.py @@ -0,0 +1,279 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert LLaVa 1.6 checkpoints from the original repository. + +URL: https://github.com/haotian-liu/LLaVA/tree/main. + + +The command used to obtain original logits is the following: +python llava/eval/run_llava.py --model-path "liuhaotian/llava-v1.6-mistral-7b" --image-file "images/llava_v1_5_radar.jpg" --query "What is shown in this image?" 
--max_new_tokens 100 --temperature 0 +""" + +import argparse +import glob +import json +from pathlib import Path + +import requests +import torch +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download, snapshot_download +from PIL import Image +from safetensors import safe_open + +from transformers import ( + AddedToken, + AutoConfig, + AutoTokenizer, + LlavaConfig, + LlavaForConditionalGeneration, + LlavaImageProcessor, + LlavaProcessor, +) + + +KEYS_TO_MODIFY_MAPPING = { + "model.vision_tower.": "", + "model.mm_projector": "multi_modal_projector", + "model": "model.model", + "vision_model.model": "vision_model", + "lm_head": "language_model.lm_head", + "model.model": "language_model.model", + "multi_modal_projector.0": "multi_modal_projector.linear_1", + "multi_modal_projector.2": "multi_modal_projector.linear_2", + "language_model.model.image_newline": "image_newline", +} + + +def load_original_state_dict(model_id): + directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"]) + + original_state_dict = {} + for path in glob.glob(f"{directory_path}/*"): + if path.endswith(".safetensors"): + with safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + original_state_dict[key] = f.get_tensor(key) + + return original_state_dict + + +def convert_state_dict_to_hf(state_dict): + new_state_dict = {} + for key, value in state_dict.items(): + if key.endswith(".inv_freq"): + continue + for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in key: + key = key.replace(key_to_modify, new_key) + + new_state_dict[key] = value.to(torch.float16) + return new_state_dict + + +def load_image(): + url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true" + image = Image.open(requests.get(url, stream=True).raw) + return image + + +def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): + # load original config + filepath = hf_hub_download(repo_id=model_id, filename="config.json", repo_type="model") + # read json + with open(filepath) as f: + data = json.load(f) + print(data) + + if model_id == "liuhaotian/llava-v1.6-mistral-7b": + text_model_id = data["_name_or_path"] + elif model_id == "liuhaotian/llava-v1.6-vicuna-7b": + text_model_id = "lmsys/vicuna-7b-v1.5" + vision_model_id = data["mm_vision_tower"] + + torch.set_default_dtype(torch.float16) + text_config = AutoConfig.from_pretrained(text_model_id) + + tokenizer = AutoTokenizer.from_pretrained(text_model_id) + tokenizer.add_tokens(AddedToken("", special=True, normalized=False), special_tokens=True) + tokenizer.add_special_tokens({"pad_token": ""}) + + image_processor = LlavaImageProcessor.from_pretrained(vision_model_id) + processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor) + + config = LlavaConfig( + text_config=text_config.to_dict(), + image_grid_pinpoints=image_processor.image_grid_pinpoints, + use_image_newline_parameter=True, + ) + config.pad_token_id = 32001 + + with init_empty_weights(): + model = LlavaForConditionalGeneration(config) + + # load original state dict + state_dict = load_original_state_dict(model_id) + state_dict = convert_state_dict_to_hf(state_dict) + model.load_state_dict(state_dict, assign=True) + model.eval() + + pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data + mu = torch.mean(pre_expansion_embeddings, dim=0).float() + n = pre_expansion_embeddings.size()[0] + sigma = 
((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n + dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma) + + # We add an image token so we resize the model + # Pad to 64 for performance reasons + pad_shape = 64 + model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape) + model.language_model.model.embed_tokens.weight.data[32000:] = torch.stack( + tuple((dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[32000:].shape[0]))), + dim=0, + ) + model.language_model.lm_head.weight.data[32000:] = torch.stack( + tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[32000:].shape[0]))), + dim=0, + ) + + device = "cuda:2" + model.to(device) + + # prepare inputs + image = load_image() + if model_id == "liuhaotian/llava-v1.6-mistral-7b": + prompt = "[INST] \nWhat is shown in this image? [/INST]" + elif model_id == "liuhaotian/llava-v1.6-vicuna-7b": + prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: \nWhat is shown in this image? ASSISTANT:" + inputs = processor(images=image, text=prompt, return_tensors="pt") + + filepath = hf_hub_download(repo_id="nielsr/test-image", filename="llava_1_6_input_ids.pt", repo_type="dataset") + original_input_ids = torch.load(filepath, map_location="cpu") + filepath = hf_hub_download(repo_id="nielsr/test-image", filename="llava_1_6_pixel_values.pt", repo_type="dataset") + original_pixel_values = torch.load(filepath, map_location="cpu") + + # verify inputs + if model_id == "liuhaotian/llava-v1.6-mistral-7b": + # replace -200 by 32000 (since we use token ID = 32000 for the image token) + original_input_ids[original_input_ids == -200] = 32000 + print(tokenizer.decode([id for id in original_input_ids.tolist()[0] if id != -200])) + + assert original_input_ids[0].tolist() == inputs.input_ids[0].tolist() + assert torch.allclose(original_pixel_values, inputs.pixel_values.half()) + + # verify single forward pass + image_sizes = torch.tensor([[899, 1024]]) + assert image_sizes[0].tolist() == inputs.image_sizes[0].tolist() + + print("Single forward pass") + with torch.inference_mode(): + inputs = inputs.to(device) + outputs = model(**inputs) + print("Shape of logits:", outputs.logits.shape) + print("First values of logits:", outputs.logits[0, :3, :3]) + + if model_id == "liuhaotian/llava-v1.6-mistral-7b": + expected_slice = torch.tensor( + [[-4.8555, -4.6992, -0.1996], [-10.5703, -10.7344, -2.7246], [-7.0391, -7.3672, -0.2634]], + dtype=torch.float32, + device=device, + ) + elif model_id == "liuhaotian/llava-v1.6-vicuna-7b": + expected_slice = torch.tensor( + [[1.4883, 0.9976, -0.6992], [-9.7031, -5.7031, -1.5557], [-5.1328, -5.5586, 8.8281]], + dtype=torch.float32, + device=device, + ) + else: + raise ValueError(f"Model {model_id} not supported") + + assert torch.allclose(outputs.logits[0, :3, :3], expected_slice, atol=1e-4) + print("Logits are ok!") + + # verify generation + output_ids = model.generate( + **inputs, + max_new_tokens=100, + use_cache=True, + ) + + generated_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip() + + if model_id == "liuhaotian/llava-v1.6-mistral-7b": + expected_text = '[INST] \nWhat is shown in this image? 
[/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point.\n\nIn this particular radar chart, there are several axes labeled with different metrics or benchmarks, such as "MMM-Vet," "MMM-Bench," "LLaVA-Bench," "SLED-Bench," "' + elif model_id == "liuhaotian/llava-v1.6-vicuna-7b": + expected_text = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\'s questions. USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a graphical representation of a benchmarking study comparing the performance of various models or systems. It\'s a scatter plot with a circular layout, where each point represents a different model or system, and the axes represent different metrics or dimensions of comparison.\n\nThe metrics are likely related to machine learning or artificial intelligence performance, as indicated by the terms like "BLIP-2," "Instruct BLIP," "POE," "QWA," "V""" + else: + raise ValueError(f"Model {model_id} not supported") + + assert generated_text == expected_text + print("Generated text is ok!") + + # verify batched generation + print("Batched generation...") + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + cats_image = Image.open(requests.get(url, stream=True).raw) + + inputs = processor(images=[image, cats_image], text=[prompt, prompt], padding=True, return_tensors="pt").to(device) + + for k, v in inputs.items(): + print(k, v.shape) + + print("Image sizes:", inputs.image_sizes) + + # make sure image_sizes are the same + # as otherwise batched generation doesn't work + inputs.image_sizes[1] = inputs.image_sizes[0] + + print("Batched generation...") + output_ids = model.generate( + **inputs, + max_new_tokens=20, + use_cache=True, + ) + + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + print(outputs) + + if pytorch_dump_folder_path is not None: + print(f"Saving model and processor for {model_id} to {pytorch_dump_folder_path}") + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + model.save_pretrained(pytorch_dump_folder_path) + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + repo_id = model_id.split("/")[-1] + model.push_to_hub(f"llava-hf/{repo_id}-hf") + processor.push_to_hub(f"llava-hf/{repo_id}-hf") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_id", + help="Hub location of the model to convert", + default="liuhaotian/llava-v1.6-mistral-7b", + choices=["liuhaotian/llava-v1.6-mistral-7b", "liuhaotian/llava-v1.6-vicuna-7b"], + required=False, + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." 
+ ) + args = parser.parse_args() + + convert_llava_to_hf(args.model_id, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/src/transformers/models/llava/convert_llava_weights_to_hf.py b/src/transformers/models/llava/convert_llava_weights_to_hf.py index bb40668f32c7..ba4ee6b5347f 100644 --- a/src/transformers/models/llava/convert_llava_weights_to_hf.py +++ b/src/transformers/models/llava/convert_llava_weights_to_hf.py @@ -11,6 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +""" +Convert LLaVa 1.5 and BakLLaVa models from the original repository. + +URL: https://github.com/haotian-liu/LLaVA/tree/main. +""" + import argparse import torch diff --git a/src/transformers/models/llava/image_processing_llava.py b/src/transformers/models/llava/image_processing_llava.py new file mode 100644 index 000000000000..6ac3fc8cf381 --- /dev/null +++ b/src/transformers/models/llava/image_processing_llava.py @@ -0,0 +1,709 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for LLaVa.""" + +import math +from typing import Dict, List, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + convert_to_rgb, + get_resize_output_image_size, + pad, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, is_vision_available, logging + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + from PIL import Image + + +def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple: + """ + Selects the best resolution from a list of possible resolutions based on the original size. + + This is done by calculating the effective and wasted resolution for each possible resolution. + + The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution. + + Args: + original_size (tuple): + The original size of the image in the format (height, width). + possible_resolutions (list): + A list of possible resolutions in the format [(height1, width1), (height2, width2), ...]. + + Returns: + tuple: The best fit resolution in the format (height, width). 
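    Example (a minimal doctest-style sketch; the 900 x 1200 input and the expected answer are hand-computed for
    illustration, and the pinpoints are the default grid used elsewhere in this file):

    ```python
    >>> # The 672 x 672 candidate keeps the largest effective resolution
    >>> # (672 * 504 = 338688 px, more than any other candidate here).
    >>> select_best_resolution((900, 1200), [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]])
    (672, 672)
    ```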
+ """ + original_height, original_width = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float("inf") + + for height, width in possible_resolutions: + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) + effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or ( + effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution + ): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (height, width) + + return best_fit + + +def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]: + """ + Divides an image into patches of a specified size. + + Args: + image (`np.array`): + The input image. + patch_size (`int`): + The size of each patch. + input_data_format (`ChannelDimension` or `str`): + The channel dimension format of the input image. + + Returns: + list: A list of np.array representing the patches. + """ + patches = [] + height, width = image.shape[:-1] if input_data_format == ChannelDimension.LAST else image.shape[1:] + for i in range(0, height, patch_size): + for j in range(0, width, patch_size): + if input_data_format == ChannelDimension.LAST: + patch = image[i : i + patch_size, j : j + patch_size] + else: + patch = image[:, i : i + patch_size, j : j + patch_size] + patches.append(patch) + + return patches + + +def expand_to_square(image: np.array, background_color, input_data_format) -> np.array: + """ + Expands an image to a square by adding a background color. + """ + + height, width = get_image_size(image, channel_dim=input_data_format) + if width == height: + return image + elif width > height: + result = np.ones((width, width, image.shape[2]), dtype=image.dtype) * background_color + result[(width - height) // 2 : (width - height) // 2 + height, :] = image + return result + else: + result = np.ones((height, height, image.shape[2]), dtype=image.dtype) * background_color + result[:, (height - width) // 2 : (height - width) // 2 + width] = image + return result + + +class LlavaImageProcessor(BaseImageProcessor): + r""" + Constructs a LLaVa image processor. Based on [`CLIPImageProcessor`] with incorporation of additional techniques + for processing high resolution images as explained in the [LLaVa paper](https://arxiv.org/abs/2310.03744). + + Args: + aspect_ratio_setting (`str`, *optional*, defaults to `"anyres"`): + The aspect ratio setting to use. Can be "clip" (as in CLIP), "pad" (as in LLaVa 1.5) or "anyres" (as in LLaVa 1.6). + an be overridden by `aspect_ratio_setting` in the `preprocess` method. + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`): + Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` + method. 
+        image_grid_pinpoints (`List`, *optional*, defaults to `[[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]`):
+            A list of possible resolutions to use for processing high resolution images. The best resolution is selected
+            based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
+            method.
+        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+        do_center_crop (`bool`, *optional*, defaults to `True`):
+            Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
+            `preprocess` method.
+        crop_size (`Dict[str, int]`, *optional*, defaults to 224):
+            Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
+            method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+            the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+            method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
+        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_convert_rgb (`bool`, *optional*, defaults to `True`):
+            Whether to convert the image to RGB.
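    Example of constructing the processor (a minimal sketch; the 336-pixel sizes are an assumption matching the
    default `image_grid_pinpoints`, while the class defaults above are the CLIP 224-pixel values):

    ```python
    from transformers import LlavaImageProcessor

    # Override only the sizes; "anyres" and the grid pinpoints are already the defaults.
    image_processor = LlavaImageProcessor(
        aspect_ratio_setting="anyres",
        size={"shortest_edge": 336},
        crop_size={"height": 336, "width": 336},
        image_grid_pinpoints=[[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]],
    )
    print(image_processor.aspect_ratio_setting, image_processor.crop_size)
    ```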
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + aspect_ratio_setting: str = "anyres", + do_resize: bool = True, + size: Dict[str, int] = None, + image_grid_pinpoints: List = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_center_crop: bool = True, + crop_size: Dict[str, int] = None, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"shortest_edge": 224} + size = get_size_dict(size, default_to_square=False) + image_grid_pinpoints = ( + image_grid_pinpoints + if image_grid_pinpoints is not None + else [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]] + ) + crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} + crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") + + self.aspect_ratio_setting = aspect_ratio_setting + self.do_resize = do_resize + self.size = size + self.image_grid_pinpoints = image_grid_pinpoints + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.do_convert_rgb = do_convert_rgb + + # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize with CLIP->LLaVa + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. 
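        Example (a quick sanity check of the shortest-edge convention on a dummy array; the input size is
        hand-picked for illustration):

        ```python
        import numpy as np

        from transformers import LlavaImageProcessor

        # 500 x 1000 (height x width): the shorter side becomes 336 and the longer side follows
        # the aspect ratio, so the output should be 336 x 672 with channels kept last.
        image_processor = LlavaImageProcessor()
        image = np.random.randint(0, 256, (500, 1000, 3), dtype=np.uint8)
        resized = image_processor.resize(image, size={"shortest_edge": 336})
        print(resized.shape)  # expected: (336, 672, 3)
        ```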
+ """ + default_to_square = True + if "shortest_edge" in size: + size = size["shortest_edge"] + default_to_square = False + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + else: + raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.") + + output_size = get_resize_output_image_size( + image, + size=size, + default_to_square=default_to_square, + input_data_format=input_data_format, + ) + + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def _preprocess( + self, + images: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_center_crop: bool = None, + crop_size: int = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> Image.Image: + """ + Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the center crop. Only has an effect if `do_center_crop` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. 
+ input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + images = make_list_of_images(images) + + if do_resize: + images = [ + self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_center_crop: + images = [ + self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + return images + + def _resize_for_patching( + self, image: np.array, target_resolution: tuple, resample, input_data_format: ChannelDimension + ) -> np.array: + """ + Resizes an image to a target resolution while maintaining aspect ratio. + + Args: + image (np.array): + The input image. + target_resolution (tuple): + The target resolution (height, width) of the image. + resample (`PILImageResampling`): + Resampling filter to use if resizing the image. + input_data_format (`ChannelDimension` or `str`): + The channel dimension format of the input image. + + Returns: + np.array: The resized and padded image. + """ + original_height, original_width = get_image_size(image, channel_dim=input_data_format) + target_height, target_width = target_resolution + + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + new_width = target_width + new_height = min(math.ceil(original_height * scale_w), target_height) + else: + new_height = target_height + new_width = min(math.ceil(original_width * scale_h), target_width) + + # Resize the image + resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format) + + return resized_image + + def _pad_for_patching( + self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension + ) -> np.array: + """ + Pad an image to a target resolution while maintaining aspect ratio. 
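        The scaling arithmetic mirrors `_resize_for_patching`; a hand-worked example with illustrative numbers:

        ```python
        import math

        # Fit a 480 x 640 (height x width) image into a 336 x 672 target resolution.
        original_height, original_width = 480, 640
        target_height, target_width = 336, 672
        scale_w = target_width / original_width    # 1.05
        scale_h = target_height / original_height  # 0.7 -> the height is the limiting side
        new_height = target_height                                           # 336
        new_width = min(math.ceil(original_width * scale_h), target_width)   # 448
        paste_x = (target_width - new_width) // 2    # 112 px of padding on the left and right
        paste_y = (target_height - new_height) // 2  # 0 px of padding on the top and bottom
        print(new_height, new_width, paste_x, paste_y)  # 336 448 112 0
        ```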
+ """ + + original_height, original_width = get_image_size(image, channel_dim=input_data_format) + target_height, target_width = target_resolution + + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + new_width = target_width + new_height = min(math.ceil(original_height * scale_w), target_height) + else: + new_height = target_height + new_width = min(math.ceil(original_width * scale_h), target_width) + + paste_x = (target_width - new_width) // 2 + paste_y = (target_height - new_height) // 2 + + padded_image = pad(image, padding=((paste_y, paste_y), (paste_x, paste_x))) + + return padded_image + + def get_image_patches( + self, + image: np.array, + grid_pinpoints, + size: tuple, + patch_size: int, + resample: PILImageResampling, + data_format: ChannelDimension, + input_data_format: ChannelDimension, + ) -> List[np.array]: + """ + Process an image with variable resolutions by dividing it into patches. + + Args: + image (np.array): + The input image to be processed. + grid_pinpoints (List): + A string representation of a list of possible resolutions. + size (`tuple`): + Size to resize the original image to. + patch_size (`int`): + Size of the patches to divide the image into. + resample (`PILImageResampling`): + Resampling filter to use if resizing the image. + data_format (`ChannelDimension` or `str`): + The channel dimension format for the output image. + input_data_format (`ChannelDimension` or `str`): + The channel dimension format of the input image. + + Returns: + List[np.array]: A list of NumPy arrays containing the processed image patches. + """ + if not isinstance(grid_pinpoints, list): + raise ValueError("grid_pinpoints must be a list of possible resolutions.") + + possible_resolutions = grid_pinpoints + + image_size = get_image_size(image, channel_dim=input_data_format) + best_resolution = select_best_resolution(image_size, possible_resolutions) + resized_image = self._resize_for_patching( + image, best_resolution, resample=resample, input_data_format=input_data_format + ) + padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=input_data_format) + + patches = divide_to_patches(padded_image, patch_size=patch_size, input_data_format=input_data_format) + + # make sure that all patches use the input data format + patches = [ + to_channel_dimension_format(patch, channel_dim=data_format, input_channel_dim=input_data_format) + for patch in patches + ] + + resized_original_image = resize( + image, + size=size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + ) + + image_patches = [resized_original_image] + patches + + return image_patches + + def preprocess( + self, + images: ImageInput, + aspect_ratio_setting=None, + do_resize: bool = None, + size: Dict[str, int] = None, + image_grid_pinpoints: List = None, + resample: PILImageResampling = None, + do_center_crop: bool = None, + crop_size: int = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Args: + images (`ImageInput`): + Image to preprocess. 
Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + aspect_ratio_setting (`str`, *optional*, defaults to `"anyres"`): + The aspect ratio setting to use. Can be "clip" (as in CLIP), "pad" (LLaVa 1.5) or "anyres" (LLaVa 1.6). + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + image_grid_pinpoints (`List` *optional*, defaults to `self.image_grid_pinpoints`): + A list of possible resolutions to use for processing high resolution images. The best resolution is + selected based on the original size of the image. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the center crop. Only has an effect if `do_center_crop` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 
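        Example of the `"anyres"` path (a minimal sketch; the 336-pixel settings and the random 500 x 500 input
        are assumptions chosen for illustration, not values from this patch):

        ```python
        import numpy as np
        from PIL import Image

        from transformers import LlavaImageProcessor

        image_processor = LlavaImageProcessor(
            size={"shortest_edge": 336},
            crop_size={"height": 336, "width": 336},
        )
        image = Image.fromarray(np.random.randint(0, 256, (500, 500, 3), dtype=np.uint8))
        outputs = image_processor.preprocess(image, aspect_ratio_setting="anyres", return_tensors="pt")

        # A 500 x 500 input selects the 672 x 672 pinpoint, i.e. one resized base image plus four
        # patches, so pixel_values should be (1, 5, 3, 336, 336) and image_sizes [[500, 500]].
        print(outputs["pixel_values"].shape)
        print(outputs["image_sizes"])
        ```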
+ """ + aspect_ratio_setting = aspect_ratio_setting if aspect_ratio_setting is not None else self.aspect_ratio_setting + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, param_name="size", default_to_square=False) + image_grid_pinpoints = image_grid_pinpoints if image_grid_pinpoints is not None else self.image_grid_pinpoints + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True) + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
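+            # `infer_channel_dimension_format` guesses the layout from whichever axis holds 1 or
+            # 3 entries, so callers with unusually shaped inputs can pass `input_data_format`
+            # explicitly instead of relying on this inference.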
+ input_data_format = infer_channel_dimension_format(images[0]) + + if aspect_ratio_setting not in ["clip", "pad", "anyres"]: + raise ValueError(f"Invalid aspect ratio setting: {aspect_ratio_setting}") + + new_images = [] + image_sizes = [get_image_size(image, channel_dim=input_data_format) for image in images] + for image in images: + if aspect_ratio_setting == "anyres": + # convert image into a list of patches + # we intentially use the same data format as the input data format + image_patches = self.get_image_patches( + image, + image_grid_pinpoints, + size=(size["shortest_edge"], size["shortest_edge"]), + patch_size=crop_size["height"], + resample=resample, + data_format=input_data_format, + input_data_format=input_data_format, + ) + + # preprocess patches + pixel_values = self._preprocess( + image_patches, + do_resize=do_resize, + size=size, + resample=resample, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + input_data_format=input_data_format, + ) + pixel_values = np.array(pixel_values) + + elif aspect_ratio_setting == "pad": + # pad image to square + image = expand_to_square( + image, + background_color=tuple(int(x * 255) for x in self.image_mean), + input_data_format=input_data_format, + ) + # preprocess image + pixel_values = self._preprocess( + image, + do_resize=do_resize, + size=size, + resample=resample, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + input_data_format=input_data_format, + )[0] + + elif aspect_ratio_setting == "clip": + pixel_values = self._preprocess( + image, + do_resize=do_resize, + size=size, + resample=resample, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + input_data_format=input_data_format, + )[0] + + new_images.append(pixel_values) + + if aspect_ratio_setting == "anyres": + data = {"pixel_values": new_images, "image_sizes": image_sizes} + elif aspect_ratio_setting in ["clip", "pad"]: + data = {"pixel_values": new_images} + + return BatchFeature(data=data, tensor_type=return_tensors) diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 4264af04a4f0..d26dfb8f1f6d 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch Llava model.""" + from dataclasses import dataclass from typing import List, Optional, Tuple, Union @@ -32,6 +33,7 @@ ) from ..auto import AutoModel, AutoModelForCausalLM from .configuration_llava import LlavaConfig +from .image_processing_llava import select_best_resolution logger = logging.get_logger(__name__) @@ -46,6 +48,62 @@ ] +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (`tuple`): + The size of the input image in the format (width, height). + grid_pinpoints (`List`): + A list containing possible resolutions. 
Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (width, height). + """ + if not isinstance(grid_pinpoints, list): + raise ValueError("grid_pinpoints should be a list of tuples or lists") + + height, width = select_best_resolution(image_size, grid_pinpoints) + return height // patch_size, width // patch_size + + +def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (`torch.Tensor`): + The image tensor, assumed to be of shape (num_channels, height, width). + original_size (`tuple`): + The original size of the image (height, width). + + Returns: + `torch.Tensor`: The unpadded image tensor. + """ + original_height, original_width = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding : current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding : current_width - padding] + + return unpadded_tensor + + @dataclass # Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Llava class LlavaCausalLMOutputWithPast(ModelOutput): @@ -240,6 +298,10 @@ def __init__(self, config: LlavaConfig): self.vision_tower = AutoModel.from_config(config.vision_config) self.multi_modal_projector = LlavaMultiModalProjector(config) + + if config.use_image_newline_parameter: + self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size, dtype=self.dtype)) + self.vocab_size = config.vocab_size self.language_model = AutoModelForCausalLM.from_config( config.text_config, attn_implementation=config._attn_implementation @@ -279,6 +341,7 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): num_images, num_image_patches, embed_dim = image_features.shape batch_size, sequence_length = input_ids.shape + left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) # 1. Create a mask to know where special image tokens are special_image_token_mask = input_ids == self.config.image_token_index @@ -351,6 +414,7 @@ def forward( self, input_ids: torch.LongTensor = None, pixel_values: torch.FloatTensor = None, + image_sizes: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, @@ -409,25 +473,77 @@ def forward( ) if inputs_embeds is None: - # 1. Extra the input embeddings + # 1. Extract the input embeddings inputs_embeds = self.get_input_embeddings()(input_ids) # 2. Merge text and images if pixel_values is not None and input_ids.shape[1] != 1: - image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) - # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated. 
- selected_image_feature = image_outputs.hidden_states[vision_feature_layer] - - if vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature + if pixel_values.ndim == 5: + batch_size, num_patches, num_channels, height, width = pixel_values.shape + reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width) + image_features = self.vision_tower(reshaped_pixel_values, output_hidden_states=True) + + selected_image_feature = image_features.hidden_states[vision_feature_layer] + + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + image_features = self.multi_modal_projector(selected_image_feature) + + # split up image_features for each of the individual images + # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) + # if we assume each image has 5 image features (base image + 4 patches) + split_sizes = [image.shape[0] for image in pixel_values] + image_features = torch.split(image_features, split_sizes, dim=0) + + # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" + height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size + + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + + if height * width != base_image_feature.shape[0]: + raise ValueError("The number of patches is not consistent with the image size.") + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat( + ( + image_feature, + self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: + image_feature = image_feature[0] + image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0) + new_image_features.append(image_feature) + image_features = torch.stack(new_image_features, dim=0) + else: - raise ValueError( - f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}" - ) + # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated. 
+ image_features = self.vision_tower(pixel_values, output_hidden_states=True) + selected_image_feature = image_features.hidden_states[vision_feature_layer] + + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + + image_features = self.multi_modal_projector(selected_image_feature) - image_features = self.multi_modal_projector(selected_image_feature) inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( image_features, inputs_embeds, input_ids, attention_mask, labels ) @@ -508,7 +624,14 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + image_sizes=None, + attention_mask=None, + **kwargs, ): if past_key_values is not None: if isinstance(past_key_values, Cache): @@ -556,6 +679,7 @@ def prepare_inputs_for_generation( "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "pixel_values": pixel_values, + "image_sizes": image_sizes, } ) return model_inputs diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py index 1ba1b30e6590..89d423ed0bd2 100644 --- a/src/transformers/models/llava/processing_llava.py +++ b/src/transformers/models/llava/processing_llava.py @@ -28,23 +28,28 @@ class LlavaProcessor(ProcessorMixin): r""" - Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor. + Constructs a Llava processor which wraps any image processor and a Llava tokenizer into a single processor. [`LlavaProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the [`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information. Args: - image_processor ([`CLIPImageProcessor`], *optional*): + image_processor ([`CLIPImageProcessor` or `LlavaImageProcessor`], *optional*): The image processor is a required input. tokenizer ([`LlamaTokenizerFast`], *optional*): The tokenizer is a required input. """ attributes = ["image_processor", "tokenizer"] - image_processor_class = "CLIPImageProcessor" + image_processor_class = ("CLIPImageProcessor", "LlavaImageProcessor") tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") def __init__(self, image_processor=None, tokenizer=None): + if image_processor.__class__.__name__ not in ["CLIPImageProcessor", "LlavaImageProcessor"]: + raise ValueError( + f"`image_processor` has to be of type `CLIPImageProcessor` or `LlavaImageProcessor`, but is {type(image_processor)}" + ) + super().__init__(image_processor, tokenizer) def __call__( @@ -103,14 +108,14 @@ def __call__( - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 
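        To illustrate the new return behaviour, a minimal sketch (the tokenizer checkpoint, the `"<image>"` token
        and the prompt format below are assumptions, not values taken from this patch):

        ```python
        import numpy as np
        from PIL import Image

        from transformers import AutoTokenizer, LlavaImageProcessor, LlavaProcessor

        # Any Llama-style tokenizer with an added "<image>" token works for this sketch.
        tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
        tokenizer.add_tokens("<image>", special_tokens=True)
        processor = LlavaProcessor(tokenizer=tokenizer, image_processor=LlavaImageProcessor())

        image = Image.fromarray(np.random.randint(0, 256, (336, 336, 3), dtype=np.uint8))
        prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
        inputs = processor(images=image, text=prompt, return_tensors="pt")

        # The processor now forwards every key returned by the image processor, so with the
        # "anyres" setting the batch carries image_sizes next to input_ids, attention_mask
        # and pixel_values.
        print(sorted(inputs.keys()))
        ```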
""" if images is not None: - pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"] + image_inputs = self.image_processor(images, return_tensors=return_tensors) else: - pixel_values = None + image_inputs = {} text_inputs = self.tokenizer( text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length ) - return BatchFeature(data={**text_inputs, "pixel_values": pixel_values}) + return BatchFeature(data={**text_inputs, **image_inputs}) # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama def batch_decode(self, *args, **kwargs): diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index ecb8613a7e4d..f0b58b1c9af6 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -241,7 +241,6 @@ def _supports_sdpa(self): """The VIPLLAVA model which consists of a vision backbone and a language model.""", VIPLLAVA_START_DOCSTRING, ) -# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration with LLAVA->VIPLLAVA,Llava->VipLlava class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel): def __init__(self, config: VipLlavaConfig): super().__init__(config) diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 25a35558fe9c..becc08df069c 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -310,6 +310,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class LlavaImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class Mask2FormerImageProcessor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/models/llava/test_image_processing_llava.py b/tests/models/llava/test_image_processing_llava.py new file mode 100644 index 000000000000..2df860632e3e --- /dev/null +++ b/tests/models/llava/test_image_processing_llava.py @@ -0,0 +1,253 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np + +from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD +from transformers.models.llava.image_processing_llava import select_best_resolution +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image + + from transformers import LlavaImageProcessor + + +class LlavaImageProcessingTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=400, + do_resize=True, + size=None, + do_center_crop=True, + crop_size=None, + do_normalize=True, + image_mean=OPENAI_CLIP_MEAN, + image_std=OPENAI_CLIP_STD, + do_convert_rgb=True, + aspect_ratio_setting="clip", + ): + size = size if size is not None else {"shortest_edge": 20} + crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18} + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.size = size + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + self.aspect_ratio_setting = aspect_ratio_setting + + def prepare_image_processor_dict(self): + return { + "aspect_ratio_setting": self.aspect_ratio_setting, + "do_resize": self.do_resize, + "size": self.size, + "do_center_crop": self.do_center_crop, + "crop_size": self.crop_size, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + "do_convert_rgb": self.do_convert_rgb, + } + + # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.expected_output_image_shape + def expected_output_image_shape(self, images): + return self.num_channels, self.crop_size["height"], self.crop_size["width"] + + # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.prepare_image_inputs + def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): + return prepare_image_inputs( + batch_size=self.batch_size, + num_channels=self.num_channels, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + equal_resolution=equal_resolution, + numpify=numpify, + torchify=torchify, + ) + + +@require_torch +@require_vision +class LlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): + image_processing_class = LlavaImageProcessor if is_vision_available() else None + + # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->Llava + def setUp(self): + self.image_processor_tester = LlavaImageProcessingTester(self) + + @property + # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.image_processor_dict + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + def test_image_processor_properties(self): + image_processing = self.image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + 
self.assertTrue(hasattr(image_processing, "do_center_crop")) + self.assertTrue(hasattr(image_processing, "center_crop")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_convert_rgb")) + self.assertTrue(hasattr(image_processing, "aspect_ratio_setting")) + self.assertTrue(hasattr(image_processing, "image_grid_pinpoints")) + + # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.test_image_processor_from_dict_with_kwargs + def test_image_processor_from_dict_with_kwargs(self): + image_processor = self.image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"shortest_edge": 20}) + self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) + + image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) + self.assertEqual(image_processor.size, {"shortest_edge": 42}) + self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) + + def test_select_best_resolution(self): + possible_resolutions = [[672, 336], [336, 672], [672, 672], [336, 1008], [1008, 336]] + + # Test with a square aspect ratio + best_resolution = select_best_resolution((336, 336), possible_resolutions) + self.assertEqual(best_resolution, (672, 336)) + + @unittest.skip("LlavaImageProcessor doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy + def test_call_numpy_4_channels(self): + pass + + +@require_torch +@require_vision +class LlavaImageProcessingAnyAspectRatioTest(ImageProcessingTestMixin, unittest.TestCase): + image_processing_class = LlavaImageProcessor if is_vision_available() else None + + def setUp(self): + self.image_processor_tester = LlavaImageProcessingTester(self, aspect_ratio_setting="anyres") + + @property + # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.image_processor_dict + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + def test_image_processor_properties(self): + image_processing = self.image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "do_center_crop")) + self.assertTrue(hasattr(image_processing, "center_crop")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_convert_rgb")) + self.assertTrue(hasattr(image_processing, "aspect_ratio_setting")) + self.assertTrue(hasattr(image_processing, "image_grid_pinpoints")) + + # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.test_image_processor_from_dict_with_kwargs + def test_image_processor_from_dict_with_kwargs(self): + image_processor = self.image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"shortest_edge": 20}) + self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) + + image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) + self.assertEqual(image_processor.size, {"shortest_edge": 42}) + self.assertEqual(image_processor.crop_size, {"height": 84, "width": 
84}) + + def test_select_best_resolution(self): + possible_resolutions = [[672, 336], [336, 672], [672, 672], [336, 1008], [1008, 336]] + + # Test with a square aspect ratio + best_resolution = select_best_resolution((336, 336), possible_resolutions) + self.assertEqual(best_resolution, (672, 336)) + + def test_call_pil(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 1445, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 1445, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + def test_call_numpy(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random numpy tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 1445, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 1445, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + def test_call_pytorch(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True) + + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 1445, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 1445, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + @unittest.skip("LlavaImageProcessor doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy + def test_call_numpy_4_channels(self): + pass diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index 2ece22f12a20..81812793f0ba 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -370,3 +370,21 @@ def test_llava_merge_inputs_error_bug(self): labels=input_ids, ).loss loss.backward() + + @slow + def test_llava_1_6_model_integration_test(self): + processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf") + model = LlavaForConditionalGeneration.from_pretrained( + "llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch.float16, 
low_cpu_mem_usage=True
+        ).to(torch_device)
+
+        # Single image-and-prompt check against the LLaVa 1.6 Mistral checkpoint.
+        prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
+        image = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
+
+        inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=20)
+
+        EXPECTED_DECODED_TEXT = ['USER:  \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring along', 'USER:  \nWhat is this?\nASSISTANT: Cats']  # fmt: skip
+        self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
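
Note on the "anyres" grid selection exercised by test_select_best_resolution: the tests above only pin down one input/output pair, so the snippet below is a minimal standalone sketch of the selection heuristic as it is commonly described for LLaVa 1.6 (pick the grid pinpoint that preserves the most of the original image after aspect-preserving downscaling, breaking ties by the least padding). pick_best_resolution is an illustrative helper written for this note, not the select_best_resolution function imported in the tests; under those assumptions it reproduces the (336, 336) -> (672, 336) case checked there.

# pick_best_resolution: illustrative sketch of the "anyres" grid choice, written
# for this note only (not the transformers implementation). Resolutions are
# (height, width) pairs, matching the test fixtures above.
from typing import List, Tuple


def pick_best_resolution(
    original_size: Tuple[int, int], possible_resolutions: List[Tuple[int, int]]
) -> Tuple[int, int]:
    original_height, original_width = original_size
    best_fit = None
    max_effective = 0
    min_wasted = float("inf")

    for height, width in possible_resolutions:
        # Scale the image so it fits inside the candidate grid, keeping its aspect ratio.
        scale = min(width / original_width, height / original_height)
        downscaled_w, downscaled_h = int(original_width * scale), int(original_height * scale)
        # Effective pixels: how much of the original image survives the downscaling.
        effective = min(downscaled_w * downscaled_h, original_width * original_height)
        # Wasted pixels: padding left over in the candidate grid.
        wasted = height * width - effective
        if effective > max_effective or (effective == max_effective and wasted < min_wasted):
            max_effective, min_wasted, best_fit = effective, wasted, (height, width)

    return best_fit


# Reproduces the case from test_select_best_resolution: a square 336x336 image
# among the LLaVa 1.6 pinpoints resolves to the (672, 336) grid.
assert pick_best_resolution((336, 336), [(672, 336), (336, 672), (672, 672), (336, 1008), (1008, 336)]) == (672, 336)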