[Feature] Support LLaVA 1.5 #1853

Merged 2 commits on Dec 22, 2023

Changes from all commits

30 changes: 6 additions & 24 deletions configs/llava/README.md
@@ -16,46 +16,28 @@ Instruction tuning large language models (LLMs) using machine-generated instruct

<!-- [TABS-BEGIN] -->

**Prepare the checkpoint**

According to the license of LLaMA, we cannot provide the merged checkpoint directly. Please use the
script below to download the weights and obtain the merged checkpoint.

```shell
python tools/model_converters/llava-delta2mmpre.py huggyllama/llama-7b liuhaotian/LLaVA-Lightning-7B-delta-v1-1 ./LLaVA-Lightning-7B-delta-v1-1.pth
```

**Use the model**

```python
import torch
from mmpretrain import get_model, inference_model

model = get_model('llava-7b-v1_caption', pretrained='MERGED_CHECKPOINT_PATH', device='cuda')
out = inference_model(model, 'demo/cat-dog.png')
out = inference_model('llava-7b-v1_caption', 'demo/cat-dog.png', device='cuda')
print(out)
# {'pred_caption': 'In the image, there are two cats sitting on a blanket.'}
```

**Test Command**

Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).

Test:

```shell
python tools/test.py configs/llava/llava-7b-v1_caption.py MERGED_CHECKPOINT_PATH
```

<!-- [TABS-END] -->

## Models and results

### Image Caption on COCO

| Model | Params (M) | BLEU-4 | CIDER | Config | Download |
| :-------------------- | :--------: | :------: | :------: | :------------------------------: | :--------------------: |
| `llava-7b-v1_caption` | 7045.82 | Upcoming | Upcoming | [config](llava-7b-v1_caption.py) | See the above tutorial |
| Model | Params (M) | Config | Download |
| :---------------------- | :--------: | :--------------------------------: | :-------------------------------------------------------------------------------------------------------------: |
| `llava-7b-v1_caption` | 7045.82 | [config](llava-7b-v1_caption.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1_liuhaotian_20231025-c9e119b6.pth) |
| `llava-7b-v1.5_caption` | 7062.90 | [config](llava-7b-v1.5_caption.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth) |
| `llava-7b-v1.5_vqa` | 7062.90 | [config](llava-7b-v1.5_vqa.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth) |
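
For reference, both v1.5 checkpoints load through the same high-level API shown in the usage snippet above. A minimal sketch follows; the caption call mirrors that snippet, while passing the question positionally to the VQA inferencer is an assumption based on mmpretrain's other VQA models, not something this PR documents.

```python
from mmpretrain import inference_model

# Caption with the released LLaVA-1.5 weights (resolved from the metafile).
out = inference_model('llava-7b-v1.5_caption', 'demo/cat-dog.png', device='cuda')
print(out)  # e.g. {'pred_caption': '...'}

# VQA with the same weights; the positional question argument is an assumption.
out = inference_model('llava-7b-v1.5_vqa', 'demo/cat-dog.png',
                      'How many cats are in the image?', device='cuda')
print(out)  # e.g. {'pred_answer': '...'}
```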

## Citation

76 changes: 76 additions & 0 deletions configs/llava/llava-7b-v1.5_caption.py
@@ -0,0 +1,76 @@
_base_ = '../_base_/default_runtime.py'

meta_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." # noqa: E501
image_size = 336
prompt_tmpl = f'''{meta_prompt} User: <image>
Describe the image in detail. ASSISTANT:'''

# model settings
model = dict(
type='Llava',
tokenizer=dict(
type='AutoTokenizer', name_or_path='liuhaotian/llava-v1.5-7b'),
vision_encoder=dict(
type='VisionTransformer',
arch='l',
patch_size=14,
img_size=image_size,
pre_norm=True,
norm_cfg=dict(type='LN', eps=1e-5),
layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
final_norm=False,
out_type='raw',
pretrained='https://download.openmmlab.com/mmclassification/v0/clip/'
'vit-large-p14_clip-openai-pre_336px_20231025-fb1315ed.pth',
),
mm_hidden_size=1024,
use_im_patch=False,
use_im_start_end=False,
mm_proj_depth=2,
lang_encoder=dict(
type='AutoModelForCausalLM',
name_or_path='huggyllama/llama-7b',
),
task='caption',
prompt_tmpl=prompt_tmpl,
generation_cfg=dict(num_beams=3, max_new_tokens=50, length_penalty=-1.0),
)

# data settings
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(image_size, image_size),
interpolation='bicubic',
backend='pillow'),
dict(type='PackInputs', meta_keys=['image_id']),
]

test_dataloader = dict(
batch_size=8,
num_workers=5,
dataset=dict(
type='COCOCaption',
data_root='data/coco',
ann_file='annotations/coco_karpathy_val.json',
pipeline=test_pipeline,
),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)

test_evaluator = dict(
type='COCOCaption',
ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
)

# schedule settings
test_cfg = dict()
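
The key architectural difference from the v1 config further down is `mm_proj_depth=2`: LLaVA-1.5 projects the 1024-dim CLIP features into the language model through a two-layer MLP instead of a single linear layer. A rough sketch of such a projector, assuming a GELU activation and LLaMA-7B's 4096-dim hidden size (illustrative only, not the literal mmpretrain module):

```python
import torch
from torch import nn


def build_mm_projector(mm_hidden_size: int = 1024,
                       lang_hidden_size: int = 4096,
                       depth: int = 2) -> nn.Module:
    """Stack `depth` linear layers with GELU in between (depth=1 is a plain linear)."""
    if depth == 1:
        return nn.Linear(mm_hidden_size, lang_hidden_size)
    layers = [nn.Linear(mm_hidden_size, lang_hidden_size)]
    for _ in range(depth - 1):
        layers += [nn.GELU(), nn.Linear(lang_hidden_size, lang_hidden_size)]
    return nn.Sequential(*layers)


# A 336px input with 14px patches gives 24 * 24 = 576 visual tokens per image.
feats = torch.randn(1, 576, 1024)
print(build_mm_projector(depth=2)(feats).shape)  # torch.Size([1, 576, 4096])
```
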
76 changes: 76 additions & 0 deletions configs/llava/llava-7b-v1.5_vqa.py
@@ -0,0 +1,76 @@
_base_ = '../_base_/default_runtime.py'

meta_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." # noqa: E501
image_size = 336
prompt_tmpl = f'''{meta_prompt} User: <image>
{{question}} ASSISTANT:'''

# model settings
model = dict(
type='Llava',
tokenizer=dict(
type='AutoTokenizer', name_or_path='liuhaotian/llava-v1.5-7b'),
vision_encoder=dict(
type='VisionTransformer',
arch='l',
patch_size=14,
img_size=image_size,
pre_norm=True,
norm_cfg=dict(type='LN', eps=1e-5),
layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
final_norm=False,
out_type='raw',
pretrained='https://download.openmmlab.com/mmclassification/v0/clip/'
'vit-large-p14_clip-openai-pre_336px_20231025-fb1315ed.pth',
),
mm_hidden_size=1024,
use_im_patch=False,
use_im_start_end=False,
mm_proj_depth=2,
lang_encoder=dict(
type='AutoModelForCausalLM',
name_or_path='huggyllama/llama-7b',
),
task='vqa',
prompt_tmpl=prompt_tmpl,
generation_cfg=dict(max_new_tokens=100),
)

# data settings
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(image_size, image_size),
interpolation='bicubic',
backend='pillow'),
dict(type='PackInputs', meta_keys=['image_id', 'question']),
]

test_dataloader = dict(
batch_size=8,
num_workers=5,
dataset=dict(
type='COCOCaption',
data_root='data/coco',
ann_file='annotations/coco_karpathy_val.json',
pipeline=test_pipeline,
),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)

test_evaluator = dict(
type='COCOCaption',
ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
)

# schedule settings
test_cfg = dict()
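
At test time, the `{question}` placeholder in `prompt_tmpl` is filled from each data sample; the `preprocess_text` change further down calls `self.prompt_tmpl.format(**sample.to_dict())`. A quick illustration with a made-up question:

```python
meta_prompt = ("A chat between a curious human and an artificial intelligence "
               "assistant. The assistant gives helpful, detailed, and polite "
               "answers to the human's questions.")
prompt_tmpl = f'''{meta_prompt} User: <image>
{{question}} ASSISTANT:'''

# Hypothetical question; in the real pipeline it comes from the dataset sample.
print(prompt_tmpl.format(question='How many cats are in the image?'))
# ...polite answers to the human's questions. User: <image>
# How many cats are in the image? ASSISTANT:
```
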
21 changes: 8 additions & 13 deletions configs/llava/llava-7b-v1_caption.py
@@ -1,16 +1,9 @@
_base_ = '../_base_/default_runtime.py'

meta_prompt = 'You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.Follow the instructions carefully and explain your answers in detail.' # noqa: E501
im_patch_token = '<im_patch>'
patch_size = 14
image_size = 224
num_patches = (image_size // patch_size)**2
caption_prompt = ' '.join([
meta_prompt,
'User: a photo of\n',
im_patch_token * num_patches,
'ASSISTANT:',
])
prompt_tmpl = f'''{meta_prompt} User: <im_start><image><im_end>
Describe the image in detail. ASSISTANT:'''

# model settings
model = dict(
@@ -22,6 +15,7 @@
type='VisionTransformer',
arch='l',
patch_size=14,
img_size=image_size,
pre_norm=True,
norm_cfg=dict(type='LN', eps=1e-5),
layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
Expand All @@ -32,15 +26,16 @@
'vit-large-p14_clip-openai-pre_3rdparty_20230517-95e2af0b.pth'),
),
mm_hidden_size=1024,
use_im_start_end=False,
use_mm_proj=True,
use_im_patch=False,
use_im_start_end=True,
mm_proj_depth=1,
lang_encoder=dict(
type='AutoModelForCausalLM',
name_or_path='huggyllama/llama-7b',
),
task='caption',
prompt_tmpl=caption_prompt,
generation_cfg=dict(num_beams=3, max_new_tokens=20, length_penalty=-2.0),
prompt_tmpl=prompt_tmpl,
generation_cfg=dict(max_new_tokens=50),
)

# data settings
28 changes: 27 additions & 1 deletion configs/llava/metafile.yml
@@ -21,5 +21,31 @@ Models:
Metrics:
BLEU-4: null
CIDER: null
Weights: null
Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1_liuhaotian_20231025-c9e119b6.pth
Config: configs/llava/llava-7b-v1_caption.py
- Name: llava-7b-v1.5_caption
Metadata:
FLOPs: null
Parameters: 7062900736
In Collection: LLaVA
Results:
- Task: Image Caption
Dataset: COCO
Metrics:
BLEU-4: null
CIDER: null
Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth
Config: configs/llava/llava-7b-v1.5_caption.py
- Name: llava-7b-v1.5_vqa
Metadata:
FLOPs: null
Parameters: 7062900736
In Collection: LLaVA
Results:
- Task: Visual Question Answering
Dataset: COCO
Metrics:
BLEU-4: null
CIDER: null
Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth
Config: configs/llava/llava-7b-v1.5_vqa.py
35 changes: 22 additions & 13 deletions mmpretrain/models/multimodal/llava/llava.py
@@ -24,8 +24,8 @@ class Llava(BaseModel):
use_im_start_end (bool): Whether to use the im_start and im_end tokens
mm_vision_select_layer (int): The index from vision encoder output.
Defaults to -1.
use_mm_proj (bool): Whether to enable multi-modal projection.
Defaults to True.
mm_proj_depth (int): The number of linear layers for multi-modal
projection. Defaults to 1.
load_lang_pretrained (bool): Whether to load the pretrained model of
language encoder. Defaults to False.
generation_cfg (dict): The extra generation config, accept the keyword
@@ -51,9 +51,10 @@ def __init__(self,
mm_hidden_size: int,
prompt_tmpl: str,
task: str = 'caption',
use_im_patch: bool = True,
use_im_start_end: bool = False,
mm_vision_select_layer: int = -1,
use_mm_proj: bool = True,
mm_proj_depth: int = 1,
generation_cfg: dict = dict(),
load_lang_pretrained: bool = False,
data_preprocessor: Optional[dict] = None,
@@ -75,7 +76,9 @@ def __init__(self,
# init tokenizer
self.tokenizer = TOKENIZER.build(tokenizer)
# add Llava special tokens to the tokenizer
self.tokenizer.add_tokens([self.im_patch_token], special_tokens=True)
if use_im_patch:
self.tokenizer.add_tokens([self.im_patch_token],
special_tokens=True)
if use_im_start_end:
self.tokenizer.add_tokens([self.im_start_token, self.im_end_token],
special_tokens=True)
@@ -108,14 +111,12 @@ def __init__(self,
vision_encoder=vision_encoder,
lang_encoder=lang_encoder,
mm_hidden_size=mm_hidden_size,
use_mm_proj=use_mm_proj,
mm_proj_depth=mm_proj_depth,
use_im_start_end=use_im_start_end,
im_start_token=self.tokenizer.convert_tokens_to_ids(
self.im_start_token),
im_end_token=self.tokenizer.convert_tokens_to_ids(
self.im_end_token),
im_patch_token=self.tokenizer.convert_tokens_to_ids(
self.im_patch_token),
mm_vision_select_layer=mm_vision_select_layer)

self.generation_cfg = generation_cfg
@@ -207,16 +208,24 @@ def preprocess_text(self, data_samples: List[DataSample],
Returns:
List[DataSample]: Return list of data samples.
"""
prompts = []
tokens = []
for sample in data_samples:
final_prompt = self.prompt_tmpl.format(**sample.to_dict())
prompts.append(final_prompt)
prompt = self.prompt_tmpl.format(**sample.to_dict())
input_ids = []
while '<image>' in prompt:
prefix, _, prompt = prompt.partition('<image>')
input_ids.extend(
self.tokenizer(prefix, add_special_tokens=False).input_ids)
input_ids.append(-200)
if prompt:
input_ids.extend(
self.tokenizer(prompt, add_special_tokens=False).input_ids)
tokens.append(dict(input_ids=input_ids))

self.tokenizer.padding_side = 'left'
input_text = self.tokenizer(
prompts,
input_text = self.tokenizer.pad(
tokens,
padding='longest',
truncation=True,
return_tensors='pt',
max_length=2000,
).to(device)
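
The rewritten `preprocess_text` above no longer relies on repeated `<im_patch>` tokens: it splits each prompt at every `<image>` placeholder, tokenizes the surrounding text, and inserts the sentinel id `-200` (the image-token index used by upstream LLaVA) where the projected image features are later spliced in. A self-contained sketch of that splitting logic, using a Hugging Face tokenizer directly instead of the model's own:

```python
from transformers import AutoTokenizer

IMAGE_TOKEN_INDEX = -200  # sentinel later replaced by projected image features


def encode_prompt(prompt: str, tokenizer) -> list:
    """Tokenize `prompt`, emitting IMAGE_TOKEN_INDEX for each '<image>'."""
    input_ids = []
    while '<image>' in prompt:
        prefix, _, prompt = prompt.partition('<image>')
        input_ids.extend(tokenizer(prefix, add_special_tokens=False).input_ids)
        input_ids.append(IMAGE_TOKEN_INDEX)
    if prompt:  # trailing text after the last placeholder
        input_ids.extend(tokenizer(prompt, add_special_tokens=False).input_ids)
    return input_ids


# Example (downloads the tokenizer, so network access is required):
# tok = AutoTokenizer.from_pretrained('liuhaotian/llava-v1.5-7b')
# ids = encode_prompt('User: <image>\nDescribe the image. ASSISTANT:', tok)
# assert ids.count(IMAGE_TOKEN_INDEX) == 1
```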