Closed

Commits (72, all by NielsRogge):
de62574  First draft (Feb 11, 2024)
32e8de8  More improvements (Feb 12, 2024)
4675e20  More improvements (Feb 12, 2024)
621f956  More improvements (Feb 12, 2024)
e47d690  Improve conversion script (Feb 12, 2024)
3af438b  Improve conversion script (Feb 12, 2024)
5be6e77  Improve script (Feb 12, 2024)
f6fe6ca  Update script (Feb 12, 2024)
fd68ed9  More improvements (Feb 12, 2024)
9be0258  Convert logits (Feb 12, 2024)
66b7ebc  Add generation (Feb 12, 2024)
9790272  More improvements (Feb 12, 2024)
908c122  More improvements (Feb 12, 2024)
11a5902  Make image_sizes a tensor (Feb 14, 2024)
ef82dc5  Add support for batched generation (Feb 16, 2024)
adb84f0  Use appropriate prompt (Feb 16, 2024)
37e6d16  More improvements (Feb 16, 2024)
28405c5  Make fixup (Feb 16, 2024)
c2848ad  Fix docstrings (Feb 16, 2024)
92228b0  Improve conversion script (Feb 17, 2024)
b8911d9  More improvements (Feb 17, 2024)
e44c47b  Make fixup (Feb 17, 2024)
05a5cfe  Debug (Feb 17, 2024)
7a71fa2  Merge remote-tracking branch 'upstream/main' into add_llava_1_6 (Feb 17, 2024)
9fa5c0a  More improvements (Feb 17, 2024)
dc46dc1  Support padding of image (Feb 17, 2024)
eaf307e  Merge remote-tracking branch 'upstream/main' into add_llava_1_6 (Feb 19, 2024)
31978d8  Remove unused image_aspect_ratio (Feb 19, 2024)
f7218a9  Add use_image_newline_parameter (Feb 19, 2024)
bab61c9  More improvements (Feb 19, 2024)
e3adbec  Remove script (Feb 19, 2024)
b5c19ba  Merge remote-tracking branch 'upstream/main' into add_llava_1_6 (Feb 19, 2024)
756e4be  Add integration test (Feb 20, 2024)
4cb3083  Merge remote-tracking branch 'upstream/main' into add_llava_1_6 (Feb 20, 2024)
e3677f8  Address comments (Mar 1, 2024)
db7ffa8  Address comments (Mar 1, 2024)
46638d9  Address comments (Mar 2, 2024)
92d7c3a  Address comments (Mar 2, 2024)
0372a62  Address comments (Mar 2, 2024)
0b3ac7a  Address comments (Mar 2, 2024)
6b2da1f  Address comments (Mar 2, 2024)
955a945  Address comments (Mar 2, 2024)
6ead7f7  Address comments (Mar 2, 2024)
329ec41  Address comments (Mar 2, 2024)
088d446  Address comments (Mar 2, 2024)
84d7210  Address comments (Mar 2, 2024)
1daa83d  Merge remote-tracking branch 'upstream/main' into add_llava_1_6 (Mar 2, 2024)
a8a13f8  Address comment (Mar 2, 2024)
4413426  Make image_sizes height width (Mar 2, 2024)
f9e5276  Improve _preprocess (Mar 2, 2024)
08fa428  Address comments (Mar 2, 2024)
297d6e0  Address comment (Mar 2, 2024)
77ecb17  Rename attribute (Mar 2, 2024)
c889558  Address comment (Mar 2, 2024)
ecbe64b  Use height width everywhere (Mar 2, 2024)
7b4da2f  Use pad (Mar 2, 2024)
ce2ea8c  Rename variables (Mar 2, 2024)
a0997ae  Add image processor tests (Mar 2, 2024)
947ff66  Improve tests (Mar 2, 2024)
a47d4ff  Improve image processor (Mar 2, 2024)
7b91acd  Update modeling (Mar 3, 2024)
e0dea6a  Address comment (Mar 4, 2024)
071be69  Address comment (Mar 4, 2024)
3539acc  Make fixup (Mar 4, 2024)
2f6e28d  Address comments (Mar 4, 2024)
05e3611  Add resample and input_data_format (Mar 4, 2024)
1ccb416  Address comments (Mar 4, 2024)
0f1357a  Fix image processor tests (Mar 4, 2024)
c5db76e  Add data_format (Mar 4, 2024)
a70df67  Add type hints (Mar 4, 2024)
0360dbe  Remove script (Mar 4, 2024)
adb4a82  Test batched generation (Mar 8, 2024)
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/llava.md
@@ -66,6 +66,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
- A [similar notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/LLaVa/Inference_with_LLaVa_for_multimodal_generation.ipynb) showcasing batched inference. 🌎


## LlavaImageProcessor

[[autodoc]] LlavaImageProcessor
- preprocess

## LlavaConfig

[[autodoc]] LlavaConfig
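For orientation, here is a minimal inference sketch showing where the newly documented LlavaImageProcessor fits; it is not taken from the docs diff. The repo id is an assumption inferred from the conversion script's push_to_hub target later in this PR, and the prompt copies the Mistral template used there.

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Assumed repo id, derived from the conversion script's push_to_hub target below.
model_id = "llava-hf/llava-v1.6-mistral-7b-hf"

# AutoProcessor bundles the tokenizer together with the LlavaImageProcessor documented above.
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"

inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=100)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])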
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1303,6 +1303,7 @@
_import_structure["models.layoutlmv2"].extend(["LayoutLMv2FeatureExtractor", "LayoutLMv2ImageProcessor"])
_import_structure["models.layoutlmv3"].extend(["LayoutLMv3FeatureExtractor", "LayoutLMv3ImageProcessor"])
_import_structure["models.levit"].extend(["LevitFeatureExtractor", "LevitImageProcessor"])
_import_structure["models.llava"].append("LlavaImageProcessor")
_import_structure["models.mask2former"].append("Mask2FormerImageProcessor")
_import_structure["models.maskformer"].extend(["MaskFormerFeatureExtractor", "MaskFormerImageProcessor"])
_import_structure["models.mobilenet_v1"].extend(["MobileNetV1FeatureExtractor", "MobileNetV1ImageProcessor"])
@@ -6071,6 +6072,7 @@
LayoutLMv3ImageProcessor,
)
from .models.levit import LevitFeatureExtractor, LevitImageProcessor
from .models.llava import LlavaImageProcessor
from .models.mask2former import Mask2FormerImageProcessor
from .models.maskformer import (
MaskFormerFeatureExtractor,
18 changes: 17 additions & 1 deletion src/transformers/models/llava/__init__.py
@@ -13,7 +13,7 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available


_import_structure = {"configuration_llava": ["LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LlavaConfig"]}
@@ -32,6 +32,14 @@
]
_import_structure["processing_llava"] = ["LlavaProcessor"]

try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_llava"] = ["LlavaImageProcessor"]


if TYPE_CHECKING:
from .configuration_llava import LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlavaConfig
@@ -49,6 +57,14 @@
)
from .processing_llava import LlavaProcessor

try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_llava import LlavaImageProcessor


else:
import sys
18 changes: 17 additions & 1 deletion src/transformers/models/llava/configuration_llava.py
@@ -53,7 +53,12 @@ class LlavaConfig(PretrainedConfig):
The index of the layer to select the vision feature.
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the Llava model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`~LlavaForConditionalGeneration`]
`inputs_ids` passed when calling [`~LlavaForConditionalGeneration`].
use_image_newline_parameter (`bool`, *optional*, defaults to `False`):
Whether to add a trainable parameter for the image newline token.
image_grid_pinpoints (`List`, *optional*):
A list of possible resolutions to use for processing high resolution images. Each item in the list should be a tuple or list
of the form `(height, width)`. Only used by the newer LLaVa 1.6 variant.

Example:

@@ -89,14 +94,25 @@ def __init__(
vision_feature_select_strategy="default",
vision_feature_layer=-2,
vocab_size=32000,
use_image_newline_parameter=False,
image_grid_pinpoints=None,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act

if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError(
"vision_feature_select_strategy should be one of 'default', 'full'."
f"Got: {vision_feature_select_strategy}"
)

self.vision_feature_select_strategy = vision_feature_select_strategy
self.vision_feature_layer = vision_feature_layer
self.vocab_size = vocab_size
self.use_image_newline_parameter = use_image_newline_parameter
self.image_grid_pinpoints = image_grid_pinpoints

self.vision_config = vision_config

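As a quick illustration of the two new configuration fields, a LlavaConfig could be instantiated as sketched below; the pinpoint values are placeholders chosen only to show the (height, width) format, not the grid of any real checkpoint.

from transformers import LlavaConfig

# Placeholder (height, width) resolutions; actual checkpoints define their own grid.
config = LlavaConfig(
    use_image_newline_parameter=True,
    image_grid_pinpoints=[[336, 672], [672, 336], [672, 672]],
)
print(config.use_image_newline_parameter, config.image_grid_pinpoints)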
279 changes: 279 additions & 0 deletions src/transformers/models/llava/convert_llava_1_6_to_hf.py
@@ -0,0 +1,279 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convert LLaVa 1.6 checkpoints from the original repository.

URL: https://github.com/haotian-liu/LLaVA/tree/main.


The command used to obtain original logits is the following:
python llava/eval/run_llava.py --model-path "liuhaotian/llava-v1.6-mistral-7b" --image-file "images/llava_v1_5_radar.jpg" --query "What is shown in this image?" --max_new_tokens 100 --temperature 0
"""

import argparse
import glob
import json
from pathlib import Path

import requests
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download, snapshot_download
from PIL import Image
from safetensors import safe_open

from transformers import (
AddedToken,
AutoConfig,
AutoTokenizer,
LlavaConfig,
LlavaForConditionalGeneration,
LlavaImageProcessor,
LlavaProcessor,
)


KEYS_TO_MODIFY_MAPPING = {
"model.vision_tower.": "",
"model.mm_projector": "multi_modal_projector",
"model": "model.model",
"vision_model.model": "vision_model",
"lm_head": "language_model.lm_head",
"model.model": "language_model.model",
"multi_modal_projector.0": "multi_modal_projector.linear_1",
"multi_modal_projector.2": "multi_modal_projector.linear_2",
"language_model.model.image_newline": "image_newline",
}


def load_original_state_dict(model_id):
directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"])

original_state_dict = {}
for path in glob.glob(f"{directory_path}/*"):
if path.endswith(".safetensors"):
with safe_open(path, framework="pt", device="cpu") as f:
for key in f.keys():
original_state_dict[key] = f.get_tensor(key)

return original_state_dict


def convert_state_dict_to_hf(state_dict):
new_state_dict = {}
for key, value in state_dict.items():
if key.endswith(".inv_freq"):
continue
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)

new_state_dict[key] = value.to(torch.float16)
return new_state_dict


def load_image():
url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)
return image


def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
# load original config
filepath = hf_hub_download(repo_id=model_id, filename="config.json", repo_type="model")
# read json
with open(filepath) as f:
data = json.load(f)
print(data)

if model_id == "liuhaotian/llava-v1.6-mistral-7b":
text_model_id = data["_name_or_path"]
elif model_id == "liuhaotian/llava-v1.6-vicuna-7b":
text_model_id = "lmsys/vicuna-7b-v1.5"
vision_model_id = data["mm_vision_tower"]

torch.set_default_dtype(torch.float16)
text_config = AutoConfig.from_pretrained(text_model_id)

tokenizer = AutoTokenizer.from_pretrained(text_model_id)
tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
tokenizer.add_special_tokens({"pad_token": "<pad>"})

image_processor = LlavaImageProcessor.from_pretrained(vision_model_id)
processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)

config = LlavaConfig(
text_config=text_config.to_dict(),
image_grid_pinpoints=image_processor.image_grid_pinpoints,
use_image_newline_parameter=True,
)
config.pad_token_id = 32001

with init_empty_weights():
model = LlavaForConditionalGeneration(config)

# load original state dict
state_dict = load_original_state_dict(model_id)
state_dict = convert_state_dict_to_hf(state_dict)
model.load_state_dict(state_dict, assign=True)
model.eval()

pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data
mu = torch.mean(pre_expansion_embeddings, dim=0).float()
n = pre_expansion_embeddings.size()[0]
sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n
dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma)

# We add an image token so we resize the model
# Pad to 64 for performance reasons
pad_shape = 64
model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape)
model.language_model.model.embed_tokens.weight.data[32000:] = torch.stack(
tuple((dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[32000:].shape[0]))),
dim=0,
)
model.language_model.lm_head.weight.data[32000:] = torch.stack(
tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[32000:].shape[0]))),
dim=0,
)
Comment on lines +142 to +149

Contributor: If we're sampling from dist here then we should set a seed.

Contributor Author: cc @ArthurZucker, I took this from the original llava conversion script.
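A minimal sketch of the reviewer's suggestion, with an arbitrary seed value: seeding the global RNG right before drawing from dist would make the newly initialized embedding rows reproducible across conversion runs.

# Hypothetical addition, not part of the original script: fix the RNG state so that
# the dist.sample() calls in the lines referenced above yield the same rows on every run.
torch.manual_seed(0)
num_new_tokens = model.language_model.model.embed_tokens.weight.data[32000:].shape[0]
model.language_model.model.embed_tokens.weight.data[32000:] = torch.stack(
    tuple(dist.sample() for _ in range(num_new_tokens)), dim=0
)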


device = "cuda:2"
model.to(device)

# prepare inputs
image = load_image()
if model_id == "liuhaotian/llava-v1.6-mistral-7b":
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
elif model_id == "liuhaotian/llava-v1.6-vicuna-7b":
prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nWhat is shown in this image? ASSISTANT:"
inputs = processor(images=image, text=prompt, return_tensors="pt")

filepath = hf_hub_download(repo_id="nielsr/test-image", filename="llava_1_6_input_ids.pt", repo_type="dataset")
original_input_ids = torch.load(filepath, map_location="cpu")
filepath = hf_hub_download(repo_id="nielsr/test-image", filename="llava_1_6_pixel_values.pt", repo_type="dataset")
original_pixel_values = torch.load(filepath, map_location="cpu")

# verify inputs
if model_id == "liuhaotian/llava-v1.6-mistral-7b":
# replace -200 by 32000 (since we use token ID = 32000 for the image token)
original_input_ids[original_input_ids == -200] = 32000
print(tokenizer.decode([id for id in original_input_ids.tolist()[0] if id != -200]))

assert original_input_ids[0].tolist() == inputs.input_ids[0].tolist()
assert torch.allclose(original_pixel_values, inputs.pixel_values.half())

# verify single forward pass
image_sizes = torch.tensor([[899, 1024]])
assert image_sizes[0].tolist() == inputs.image_sizes[0].tolist()

print("Single forward pass")
with torch.inference_mode():
inputs = inputs.to(device)
outputs = model(**inputs)
print("Shape of logits:", outputs.logits.shape)
print("First values of logits:", outputs.logits[0, :3, :3])

if model_id == "liuhaotian/llava-v1.6-mistral-7b":
expected_slice = torch.tensor(
[[-4.8555, -4.6992, -0.1996], [-10.5703, -10.7344, -2.7246], [-7.0391, -7.3672, -0.2634]],
dtype=torch.float32,
device=device,
)
elif model_id == "liuhaotian/llava-v1.6-vicuna-7b":
expected_slice = torch.tensor(
[[1.4883, 0.9976, -0.6992], [-9.7031, -5.7031, -1.5557], [-5.1328, -5.5586, 8.8281]],
dtype=torch.float32,
device=device,
)
else:
raise ValueError(f"Model {model_id} not supported")

assert torch.allclose(outputs.logits[0, :3, :3], expected_slice, atol=1e-4)
print("Logits are ok!")

# verify generation
output_ids = model.generate(
**inputs,
max_new_tokens=100,
use_cache=True,
)

generated_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

if model_id == "liuhaotian/llava-v1.6-mistral-7b":
expected_text = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point.\n\nIn this particular radar chart, there are several axes labeled with different metrics or benchmarks, such as "MMM-Vet," "MMM-Bench," "LLaVA-Bench," "SLED-Bench," "'
elif model_id == "liuhaotian/llava-v1.6-vicuna-7b":
expected_text = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\'s questions. USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a graphical representation of a benchmarking study comparing the performance of various models or systems. It\'s a scatter plot with a circular layout, where each point represents a different model or system, and the axes represent different metrics or dimensions of comparison.\n\nThe metrics are likely related to machine learning or artificial intelligence performance, as indicated by the terms like "BLIP-2," "Instruct BLIP," "POE," "QWA," "V"""
else:
raise ValueError(f"Model {model_id} not supported")

assert generated_text == expected_text
print("Generated text is ok!")

# verify batched generation
print("Batched generation...")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
cats_image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=[image, cats_image], text=[prompt, prompt], padding=True, return_tensors="pt").to(device)

for k, v in inputs.items():
print(k, v.shape)

print("Image sizes:", inputs.image_sizes)

# make sure image_sizes are the same
# as otherwise batched generation doesn't work
inputs.image_sizes[1] = inputs.image_sizes[0]

print("Batched generation...")
output_ids = model.generate(
**inputs,
max_new_tokens=20,
use_cache=True,
)

outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
print(outputs)

if pytorch_dump_folder_path is not None:
print(f"Saving model and processor for {model_id} to {pytorch_dump_folder_path}")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)

if push_to_hub:
repo_id = model_id.split("/")[-1]
model.push_to_hub(f"llava-hf/{repo_id}-hf")
processor.push_to_hub(f"llava-hf/{repo_id}-hf")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_id",
help="Hub location of the model to convert",
default="liuhaotian/llava-v1.6-mistral-7b",
choices=["liuhaotian/llava-v1.6-mistral-7b", "liuhaotian/llava-v1.6-vicuna-7b"],
required=False,
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
)
args = parser.parse_args()

convert_llava_to_hf(args.model_id, args.pytorch_dump_folder_path, args.push_to_hub)
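Given the argument parser above, the script would presumably be invoked along these lines (the dump folder path is a placeholder):

python src/transformers/models/llava/convert_llava_1_6_to_hf.py \
    --model_id liuhaotian/llava-v1.6-mistral-7b \
    --pytorch_dump_folder_path ./llava-v1.6-mistral-7b-hf \
    --push_to_hub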