Add LLaVa 1.6 #29012
Closed
Changes from all commits — 72 commits, all by NielsRogge:
de62574 First draft
32e8de8 More improvements
4675e20 More improvements
621f956 More improvements
e47d690 Improve conversion script
3af438b Improve conversion script
5be6e77 Improve script
f6fe6ca Update script
fd68ed9 More improvements
9be0258 Convert logits
66b7ebc Add generation
9790272 More improvements
908c122 More improvements
11a5902 Make image_sizes a tensor
ef82dc5 Add support for batched generation
adb84f0 Use appropriate prompt
37e6d16 More improvements
28405c5 Make fixup
c2848ad Fix docstrings
92228b0 Improve conversion script
b8911d9 More improvements
e44c47b Make fixup
05a5cfe Debug
7a71fa2 Merge remote-tracking branch 'upstream/main' into add_llava_1_6
9fa5c0a More improvements
dc46dc1 Support padding of image
eaf307e Merge remote-tracking branch 'upstream/main' into add_llava_1_6
31978d8 Remove unused image_aspect_ratio
f7218a9 Add use_image_newline_parameter
bab61c9 More improvements
e3adbec Remove script
b5c19ba Merge remote-tracking branch 'upstream/main' into add_llava_1_6
756e4be Add integration test
4cb3083 Merge remote-tracking branch 'upstream/main' into add_llava_1_6
e3677f8 Address comments
db7ffa8 Address comments
46638d9 Address comments
92d7c3a Address comments
0372a62 Address comments
0b3ac7a Address comments
6b2da1f Address comments
955a945 Address comments
6ead7f7 Address comments
329ec41 Address comments
088d446 Address comments
84d7210 Address comments
1daa83d Merge remote-tracking branch 'upstream/main' into add_llava_1_6
a8a13f8 Address comment
4413426 Make image_sizes height width
f9e5276 Improve _preprocess
08fa428 Address comments
297d6e0 Address comment
77ecb17 Rename attribute
c889558 Address comment
ecbe64b Use height width everywhere
7b4da2f Use pad
ce2ea8c Rename variables
a0997ae Add image processor tests
947ff66 Improve tests
a47d4ff Improve image processor
7b91acd Update modeling
e0dea6a Address comment
071be69 Address comment
3539acc Make fixup
2f6e28d Address comments
05e3611 Add resample and input_data_format
1ccb416 Address comments
0f1357a Fix image processor tests
c5db76e Add data_format
a70df67 Add type hints
0360dbe Remove script
adb4a82 Test batched generation

src/transformers/models/llava/convert_llava_1_6_to_hf.py (279 additions, 0 deletions)

```python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert LLaVa 1.6 checkpoints from the original repository.

URL: https://github.com/haotian-liu/LLaVA/tree/main.

The command used to obtain original logits is the following:
python llava/eval/run_llava.py --model-path "liuhaotian/llava-v1.6-mistral-7b" --image-file "images/llava_v1_5_radar.jpg" --query "What is shown in this image?" --max_new_tokens 100 --temperature 0
"""

import argparse
import glob
import json
from pathlib import Path

import requests
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download, snapshot_download
from PIL import Image
from safetensors import safe_open

from transformers import (
    AddedToken,
    AutoConfig,
    AutoTokenizer,
    LlavaConfig,
    LlavaForConditionalGeneration,
    LlavaImageProcessor,
    LlavaProcessor,
)


KEYS_TO_MODIFY_MAPPING = {
    "model.vision_tower.": "",
    "model.mm_projector": "multi_modal_projector",
    "model": "model.model",
    "vision_model.model": "vision_model",
    "lm_head": "language_model.lm_head",
    "model.model": "language_model.model",
    "multi_modal_projector.0": "multi_modal_projector.linear_1",
    "multi_modal_projector.2": "multi_modal_projector.linear_2",
    "language_model.model.image_newline": "image_newline",
}
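# Added illustration (not part of the original file): two example traces of how
# the substring rules above compose when applied in dict order by
# convert_state_dict_to_hf below:
#
#   "model.layers.0.self_attn.q_proj.weight"
#     --("model" -> "model.model")--> "model.model.layers.0.self_attn.q_proj.weight"
#     --("model.model" -> "language_model.model")--> "language_model.model.layers.0.self_attn.q_proj.weight"
#
#   "model.mm_projector.0.weight"
#     --("model.mm_projector" -> "multi_modal_projector")--> "multi_modal_projector.0.weight"
#     --("multi_modal_projector.0" -> "multi_modal_projector.linear_1")--> "multi_modal_projector.linear_1.weight"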


def load_original_state_dict(model_id):
    directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"])

    original_state_dict = {}
    for path in glob.glob(f"{directory_path}/*"):
        if path.endswith(".safetensors"):
            with safe_open(path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    original_state_dict[key] = f.get_tensor(key)

    return original_state_dict


def convert_state_dict_to_hf(state_dict):
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.endswith(".inv_freq"):
            continue
        for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
            if key_to_modify in key:
                key = key.replace(key_to_modify, new_key)

        new_state_dict[key] = value.to(torch.float16)
    return new_state_dict


def load_image():
    url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
    image = Image.open(requests.get(url, stream=True).raw)
    return image


def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
    # load original config
    filepath = hf_hub_download(repo_id=model_id, filename="config.json", repo_type="model")
    # read json
    with open(filepath) as f:
        data = json.load(f)
        print(data)

    if model_id == "liuhaotian/llava-v1.6-mistral-7b":
        text_model_id = data["_name_or_path"]
    elif model_id == "liuhaotian/llava-v1.6-vicuna-7b":
        text_model_id = "lmsys/vicuna-7b-v1.5"
    vision_model_id = data["mm_vision_tower"]

    torch.set_default_dtype(torch.float16)
    text_config = AutoConfig.from_pretrained(text_model_id)

    tokenizer = AutoTokenizer.from_pretrained(text_model_id)
    tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
    tokenizer.add_special_tokens({"pad_token": "<pad>"})

    image_processor = LlavaImageProcessor.from_pretrained(vision_model_id)
    processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)

    config = LlavaConfig(
        text_config=text_config.to_dict(),
        image_grid_pinpoints=image_processor.image_grid_pinpoints,
        use_image_newline_parameter=True,
    )
    config.pad_token_id = 32001

    with init_empty_weights():
        model = LlavaForConditionalGeneration(config)

    # load original state dict
    state_dict = load_original_state_dict(model_id)
    state_dict = convert_state_dict_to_hf(state_dict)
    model.load_state_dict(state_dict, assign=True)
    model.eval()

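    # Added note: the block below follows the original LLaVa conversion recipe —
    # rows for the newly added <image> and <pad> tokens are sampled from a
    # multivariate normal fitted to the existing embedding matrix (mean mu,
    # covariance sigma scaled by 1e-5), so they start in-distribution rather
    # than at zero. Sampling is unseeded, so converted weights can differ
    # slightly between runs (see the review thread at the bottom of this page).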
    pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data
    mu = torch.mean(pre_expansion_embeddings, dim=0).float()
    n = pre_expansion_embeddings.size()[0]
    sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n
    dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma)

    # We add an image token so we resize the model
    # Pad to 64 for performance reasons
    pad_shape = 64
    model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape)
    model.language_model.model.embed_tokens.weight.data[32000:] = torch.stack(
        tuple((dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[32000:].shape[0]))),
        dim=0,
    )
    model.language_model.lm_head.weight.data[32000:] = torch.stack(
        tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[32000:].shape[0]))),
        dim=0,
    )

    device = "cuda:2"
    model.to(device)

    # prepare inputs
    image = load_image()
    if model_id == "liuhaotian/llava-v1.6-mistral-7b":
        prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
    elif model_id == "liuhaotian/llava-v1.6-vicuna-7b":
        prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nWhat is shown in this image? ASSISTANT:"
    inputs = processor(images=image, text=prompt, return_tensors="pt")

    filepath = hf_hub_download(repo_id="nielsr/test-image", filename="llava_1_6_input_ids.pt", repo_type="dataset")
    original_input_ids = torch.load(filepath, map_location="cpu")
    filepath = hf_hub_download(repo_id="nielsr/test-image", filename="llava_1_6_pixel_values.pt", repo_type="dataset")
    original_pixel_values = torch.load(filepath, map_location="cpu")

    # verify inputs
    if model_id == "liuhaotian/llava-v1.6-mistral-7b":
        # replace -200 by 32000 (since we use token ID = 32000 for the image token)
        original_input_ids[original_input_ids == -200] = 32000
        print(tokenizer.decode([id for id in original_input_ids.tolist()[0] if id != -200]))

        assert original_input_ids[0].tolist() == inputs.input_ids[0].tolist()
        assert torch.allclose(original_pixel_values, inputs.pixel_values.half())

    # verify single forward pass
    image_sizes = torch.tensor([[899, 1024]])
    assert image_sizes[0].tolist() == inputs.image_sizes[0].tolist()

    print("Single forward pass")
    with torch.inference_mode():
        inputs = inputs.to(device)
        outputs = model(**inputs)
        print("Shape of logits:", outputs.logits.shape)
        print("First values of logits:", outputs.logits[0, :3, :3])

    if model_id == "liuhaotian/llava-v1.6-mistral-7b":
        expected_slice = torch.tensor(
            [[-4.8555, -4.6992, -0.1996], [-10.5703, -10.7344, -2.7246], [-7.0391, -7.3672, -0.2634]],
            dtype=torch.float32,
            device=device,
        )
    elif model_id == "liuhaotian/llava-v1.6-vicuna-7b":
        expected_slice = torch.tensor(
            [[1.4883, 0.9976, -0.6992], [-9.7031, -5.7031, -1.5557], [-5.1328, -5.5586, 8.8281]],
            dtype=torch.float32,
            device=device,
        )
    else:
        raise ValueError(f"Model {model_id} not supported")

    assert torch.allclose(outputs.logits[0, :3, :3], expected_slice, atol=1e-4)
    print("Logits are ok!")

    # verify generation
    output_ids = model.generate(
        **inputs,
        max_new_tokens=100,
        use_cache=True,
    )

    generated_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

    if model_id == "liuhaotian/llava-v1.6-mistral-7b":
        expected_text = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point.\n\nIn this particular radar chart, there are several axes labeled with different metrics or benchmarks, such as "MMM-Vet," "MMM-Bench," "LLaVA-Bench," "SLED-Bench," "'
    elif model_id == "liuhaotian/llava-v1.6-vicuna-7b":
        expected_text = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\'s questions. USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a graphical representation of a benchmarking study comparing the performance of various models or systems. It\'s a scatter plot with a circular layout, where each point represents a different model or system, and the axes represent different metrics or dimensions of comparison.\n\nThe metrics are likely related to machine learning or artificial intelligence performance, as indicated by the terms like "BLIP-2," "Instruct BLIP," "POE," "QWA," "V"""
    else:
        raise ValueError(f"Model {model_id} not supported")

    assert generated_text == expected_text
    print("Generated text is ok!")

    # verify batched generation
    print("Batched generation...")
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    cats_image = Image.open(requests.get(url, stream=True).raw)

    inputs = processor(images=[image, cats_image], text=[prompt, prompt], padding=True, return_tensors="pt").to(device)

    for k, v in inputs.items():
        print(k, v.shape)

    print("Image sizes:", inputs.image_sizes)

    # make sure image_sizes are the same
    # as otherwise batched generation doesn't work
    inputs.image_sizes[1] = inputs.image_sizes[0]

    print("Batched generation...")
    output_ids = model.generate(
        **inputs,
        max_new_tokens=20,
        use_cache=True,
    )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    print(outputs)

    if pytorch_dump_folder_path is not None:
        print(f"Saving model and processor for {model_id} to {pytorch_dump_folder_path}")
        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
        model.save_pretrained(pytorch_dump_folder_path)
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        repo_id = model_id.split("/")[-1]
        model.push_to_hub(f"llava-hf/{repo_id}-hf")
        processor.push_to_hub(f"llava-hf/{repo_id}-hf")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        help="Hub location of the model to convert",
        default="liuhaotian/llava-v1.6-mistral-7b",
        choices=["liuhaotian/llava-v1.6-mistral-7b", "liuhaotian/llava-v1.6-vicuna-7b"],
        required=False,
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )
    parser.add_argument(
        "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
    )
    args = parser.parse_args()

    convert_llava_to_hf(args.model_id, args.pytorch_dump_folder_path, args.push_to_hub)
```
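For context, a minimal usage sketch (not part of the diff): after running the script, e.g. `python src/transformers/models/llava/convert_llava_1_6_to_hf.py --model_id liuhaotian/llava-v1.6-mistral-7b --pytorch_dump_folder_path ./llava-v1.6-mistral-7b-hf`, the converted checkpoint can be loaded with the same classes the script imports. The local folder name below is a hypothetical output of `--pytorch_dump_folder_path`, and the prompt mirrors the Mistral format used in the verification steps above.

```python
import requests
import torch
from PIL import Image

from transformers import LlavaForConditionalGeneration, LlavaProcessor

# Hypothetical local folder produced by the conversion script.
processor = LlavaProcessor.from_pretrained("./llava-v1.6-mistral-7b-hf")
model = LlavaForConditionalGeneration.from_pretrained(
    "./llava-v1.6-mistral-7b-hf", torch_dtype=torch.float16
).to("cuda")

# Same Mistral-style prompt format the script uses for verification.
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda")
output_ids = model.generate(**inputs, max_new_tokens=100)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
```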
One review thread remained on the embedding-initialization block:

Reviewer: If we're sampling from `dist` here then we should set a seed.

NielsRogge: cc @ArthurZucker I took this from the original llava conversion script
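A seed would make the conversion reproducible; a one-line sketch of the reviewer's suggestion, placed before the `dist.sample()` calls (the value 0 is illustrative):

```python
torch.manual_seed(0)  # fix the RNG so the sampled embedding rows are identical across runs
```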