Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ The following model architectures, tasks and device distributions have been vali
| VideoMAE | | <div style="text-align:left"><li>Single card</li></div> | <li>[Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification)</li> |
| TableTransformer | | <div style="text-align:left"><li>Single card</li></div> | <li>[table object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/table-detection) </li> |
| DETR | | <div style="text-align:left"><li>Single card</li></div> | <li>[object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/object-detection)</li> |
| Mllama | <div style="text-align:left"><li>LoRA</li></div> | :heavy_check_mark: | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |

</div>

Expand Down
1 change: 1 addition & 0 deletions docs/source/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| VideoMAE | | <div style="text-align:left"><li>Single card</li></div> | <li>[Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification)</li> |
| TableTransformer | | <div style="text-align:left"><li>Single card</li></div> | <li>[table object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/table-detection)</li> |
| DETR | | <div style="text-align:left"><li>Single card</li></div> | <li>[object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/object-detection)</li> |
| Mllama | <div style="text-align:left"><li>LoRA</li></div> |✅ | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |

- Diffusers

Expand Down
79 changes: 79 additions & 0 deletions examples/image-to-text/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Models that have been validated:
- [llava-hf/llava-v1.6-34b-hf](https://huggingface.co/llava-hf/llava-v1.6-34b-hf)
- [llava-hf/llama3-llava-next-8b-hf](https://huggingface.co/llava-hf/llama3-llava-next-8b-hf)
- [HuggingFaceM4/idefics2-8b](https://huggingface.co/HuggingFaceM4/idefics2-8b)
- [meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)

### Inference with BF16

Expand Down Expand Up @@ -102,6 +103,15 @@ python3 run_pipeline.py \
--bf16
```

To run mllama inference, use the following command:

```bash
python3 run_pipeline.py \
--model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct \
--use_hpu_graphs \
--bf16
```

### Inference with FP8
Inference for Llava-1.5-7b, Llava-1.5-13b, Llava-v1.6-mistral-7b and Llava-v1.6-vicuna-13b in FP8 precision are enabled using [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html), which provides model measurement and quantization capabilities in PyTorch.

Expand Down Expand Up @@ -286,6 +296,75 @@ python3 ../gaudi_spawn.py \
--lora_target_modules '".*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"'
```

Here are single-/multi-device command examples for meta-llama/Llama-3.2-11B-Vision-Instruct.

```bash
python3 run_image2text_lora_finetune.py \
--model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct \
--dataset_name nielsr/docvqa_1200_examples \
--bf16 True \
--output_dir ./model_lora_llama \
--num_train_epochs 2 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 8 \
--weight_decay 0.01 \
--logging_steps 25 \
--eval_strategy "no" \
--save_strategy "no" \
--learning_rate 5e-5 \
--warmup_steps 50 \
--lr_scheduler_type "constant" \
--input_column_names 'image' 'query' \
--output_column_names 'answers' \
--remove_unused_columns False \
--do_train \
--do_eval \
--use_habana \
--use_lazy_mode \
--lora_rank=8 \
--lora_alpha=8 \
--lora_dropout=0.1 \
--low_cpu_mem_usage True \
--max_seq_length=512 \
--use_hpu_graphs_for_inference True \
--lora_target_modules ".*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"
```

```bash
python3 ../gaudi_spawn.py \
--world_size 8 --use_mpi run_image2text_lora_finetune.py \
--model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct \
--dataset_name nielsr/docvqa_1200_examples \
--bf16 True \
--output_dir ./model_lora_llama \
--num_train_epochs 2 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 8 \
--weight_decay 0.01 \
--logging_steps 25 \
--eval_strategy "no" \
--save_strategy "no" \
--learning_rate 5e-5 \
--warmup_steps 50 \
--lr_scheduler_type "constant" \
--input_column_names 'image' 'query' \
--output_column_names 'answers' \
--remove_unused_columns False \
--do_train \
--do_eval \
--use_habana \
--use_lazy_mode \
--lora_rank=8 \
--lora_alpha=8 \
--lora_dropout=0.1 \
--low_cpu_mem_usage True \
--max_seq_length=512 \
--use_hpu_graphs_for_inference True \
--lora_target_modules '".*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"'
```

## Multi-HPU inference

To enable multi-card inference, you must set the environment variable `PT_HPU_ENABLE_LAZY_COLLECTIVES=true`,
Expand Down
17 changes: 11 additions & 6 deletions examples/image-to-text/run_image2text_lora_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,9 @@ class FinetuneArguments:


class MyDataCollator:
def __init__(self, processor, max_seq_length):
def __init__(self, processor, max_seq_length, image_token_id):
self.processor = processor
self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
processor.tokenizer.additional_special_tokens.index("<image>")
]
self.image_token_id = image_token_id
self.max_seq_length = max_seq_length

def __call__(self, examples):
Expand Down Expand Up @@ -458,8 +456,15 @@ def main():
if col not in (data_args.input_column_names + data_args.output_column_names)
]
)

data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length)
if hasattr(config, "image_token_id"):
# idefics
image_token_id = config.image_token_id
elif hasattr(config, "image_token_index"):
# mllama
image_token_id = config.image_token_index
else:
raise ValueError("Please provide value for image_token_id")
data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length, image_token_id=image_token_id)

gaudi_config = GaudiConfig()
gaudi_config.use_fused_adam = True
Expand Down
41 changes: 27 additions & 14 deletions examples/image-to-text/run_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import PIL.Image
import requests
import torch
from transformers import AutoConfig, AutoProcessor, pipeline
from transformers import AutoConfig, AutoModelForVision2Seq, AutoProcessor, pipeline

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

Expand Down Expand Up @@ -185,14 +185,13 @@ def main():
adapt_transformers_to_gaudi()

model_type = AutoConfig.from_pretrained(args.model_name_or_path).model_type
if args.image_path is None and model_type in ["llava", "idefics2"]:
if args.image_path is None and model_type in ["llava", "idefics2", "mllama"]:
args.image_path = ["https://llava-vl.github.io/static/images/view.jpg"]
elif args.image_path is None and model_type == "llava_next":
args.image_path = [
"https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
]

if args.prompt is None and model_type in ["llava", "idefics2", "llava_next"]:
if args.prompt is None and model_type in ["llava", "idefics2", "llava_next", "mllama"]:
processor = AutoProcessor.from_pretrained(args.model_name_or_path)
conversation = [
{
Expand Down Expand Up @@ -231,17 +230,31 @@ def main():

htcore.hpu_set_env()

generator = pipeline(
"image-to-text",
model=args.model_name_or_path,
torch_dtype=model_dtype,
device="hpu",
)

if args.world_size > 1:
generator.model = initialize_distributed_model(args, generator.model, logger, model_dtype)

import deepspeed

with deepspeed.OnDevice(dtype=model_dtype, device="cpu"):
model = AutoModelForVision2Seq.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype)
if model_type == "mllama":
model.language_model = initialize_distributed_model(args, model.language_model, logger, model_dtype)
else:
model = initialize_distributed_model(args, model, logger, model_dtype)
generator = pipeline(
"image-to-text",
model=model,
config=args.model_name_or_path,
tokenizer=args.model_name_or_path,
image_processor=args.model_name_or_path,
torch_dtype=model_dtype,
device="hpu",
)
else:
generator = pipeline(
"image-to-text",
model=args.model_name_or_path,
torch_dtype=model_dtype,
device="hpu",
)
if args.use_hpu_graphs:
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

Expand All @@ -263,7 +276,7 @@ def main():
htcore.hpu_initialize(generator.model)

# delete once pipeline integrate AutoProcessor as preprocess engine
if model_type in ["idefics2"]:
if model_type in ["idefics2", "mllama"]:
from transformers.image_utils import load_image

def preprocess(self, image, prompt=None, timeout=None):
Expand Down
48 changes: 36 additions & 12 deletions optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
"qwen2_moe",
"whisper",
"idefics2",
"mllama",
]


Expand Down Expand Up @@ -329,11 +330,13 @@ def _expand_dict_for_generation(dict_to_expand):

def _pad_past_key_values(self, model_kwargs):
pad_amount = model_kwargs.get("kv_cache_pad_len", 0)
kv_cache_len = model_kwargs.get("kv_cache_len", 0)
if model_kwargs["past_key_values"]:
if model_kwargs.get("mqa_model", False):
for i in range(len(model_kwargs["past_key_values"])): # layer
if torch.is_tensor(
model_kwargs["past_key_values"][i]
if (
torch.is_tensor(model_kwargs["past_key_values"][i])
and model_kwargs["past_key_values"][i].shape[-2] == kv_cache_len - pad_amount
): # tensor(batch_size, kv_cache_len, n_heads * head_dim * 2) k and v stacked
model_kwargs["past_key_values"][i] = torch.nn.functional.pad(
model_kwargs["past_key_values"][i], (0, 0, 0, pad_amount)
Expand All @@ -343,8 +346,9 @@ def _pad_past_key_values(self, model_kwargs):
else:
for i in range(len(model_kwargs["past_key_values"])): # layer
for j in range(len(model_kwargs["past_key_values"][i])): # k or v
if torch.is_tensor(
model_kwargs["past_key_values"][i][j]
if (
torch.is_tensor(model_kwargs["past_key_values"][i][j])
and model_kwargs["past_key_values"][i][j].shape[-2] == kv_cache_len - pad_amount
): # tensor(batch_size, n_heads, kv_cache_len, head_dim)
model_kwargs["past_key_values"][i][j] = torch.nn.functional.pad(
model_kwargs["past_key_values"][i][j], (0, 0, 0, pad_amount)
Expand Down Expand Up @@ -460,6 +464,14 @@ def update_model_kwargs_for_bucketing(
)
else:
assert False, "Not tested for cases where attn_mask isnt passed"

if model_kwargs.get("cross_attention_mask") is not None:
model_kwargs["cross_attention_mask"] = torch.nn.functional.pad(
model_kwargs["cross_attention_mask"],
(0, 0, 0, 0, 0, pad_amount),
value=0,
)

if reduce_recompile and params["passnum"] == 0:
position_ids_cpu = model_kwargs["attention_mask"].long().cumsum(-1) - 1
position_ids_cpu.masked_fill_(model_kwargs["attention_mask"] == 0, 1)
Expand Down Expand Up @@ -502,14 +514,20 @@ def create_pad_arg(pad_amount, i, j):
# This is a necessary (but not sufficient) condition: what ever dimension we are padding, should be a multiple of bucket_size
# This check is added in case we get a new model with a new kv-cache structure, and we attempt to pad some wrong dimension
# in peft case, if there's virtual token. the model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] % bucket_size == num_virtual_token, no need of assert, the pad length of past_key_value should be aligned with input id and attention_mask
num_virtual_tokens = model_kwargs.get("num_virtual_tokens", 0)
assert (
model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] % bucket_size
== num_virtual_tokens
)
tmp_lst[j] = torch.nn.functional.pad(
model_kwargs["past_key_values"][i][j], pad_tuple, value=pad_token_id
)
if (
model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)]
== params["allocated_space"] - pad_amount
):
num_virtual_tokens = model_kwargs.get("num_virtual_tokens", 0)
assert (
model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] % bucket_size
== num_virtual_tokens
)
tmp_lst[j] = torch.nn.functional.pad(
model_kwargs["past_key_values"][i][j], pad_tuple, value=pad_token_id
)
else:
tmp_lst[j] = model_kwargs["past_key_values"][i][j]
new_kv[i] = tuple(tmp_lst)
model_kwargs["past_key_values"] = tuple(new_kv)

Expand Down Expand Up @@ -1109,6 +1127,12 @@ def generate(
(0, generation_config.max_new_tokens),
value=0,
)
if model_kwargs.get("cross_attention_mask") is not None:
model_kwargs["cross_attention_mask"] = torch.nn.functional.pad(
model_kwargs["cross_attention_mask"],
(0, 0, 0, 0, 0, generation_config.max_new_tokens),
value=0,
)
else:
assert generation_config.bucket_size <= 0, "Untested path for bucket>0"
if model_kwargs.get("decoder_input_ids", None) is None:
Expand Down
18 changes: 18 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@
GaudiMixtralDecoderLayer,
GaudiMixtralForCausalLM,
GaudiMixtralModel,
GaudiMllamaCrossAttentionDecoderLayer,
GaudiMllamaForCausalLM,
GaudiMllamaForConditionalGeneration,
GaudiMllamaSelfAttentionDecoderLayer,
GaudiMllamaTextCrossAttention,
GaudiMllamaTextModel,
GaudiMllamaTextSelfAttention,
GaudiMllamaVisionModel,
GaudiMptAttention,
GaudiMptBlock,
GaudiMptForCausalLM,
Expand Down Expand Up @@ -618,6 +626,16 @@ def adapt_transformers_to_gaudi():
transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration = GaudiWhisperForConditionalGeneration
transformers.models.whisper.modeling_whisper.WHISPER_ATTENTION_CLASSES = GAUDI_WHISPER_ATTENTION_CLASSES

# Optimization for mllama on Gaudi
transformers.models.mllama.modeling_mllama.MllamaSelfAttentionDecoderLayer = GaudiMllamaSelfAttentionDecoderLayer
transformers.models.mllama.modeling_mllama.MllamaCrossAttentionDecoderLayer = GaudiMllamaCrossAttentionDecoderLayer
transformers.models.mllama.modeling_mllama.MllamaForCausalLM = GaudiMllamaForCausalLM
transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention = GaudiMllamaTextSelfAttention
transformers.models.mllama.modeling_mllama.MllamaTextCrossAttention = GaudiMllamaTextCrossAttention
transformers.models.mllama.modeling_mllama.MllamaForConditionalGeneration = GaudiMllamaForConditionalGeneration
transformers.models.mllama.modeling_mllama.MllamaTextModel = GaudiMllamaTextModel
transformers.models.mllama.modeling_mllama.MllamaVisionModel = GaudiMllamaVisionModel

transformers.AutoConfig.register("deci", DeciLMConfig)
transformers.AutoModelForCausalLM.register(DeciLMConfig, DeciLMForCausalLM)

Expand Down
10 changes: 10 additions & 0 deletions optimum/habana/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,16 @@
gaudi_mixtral_block_sparse_moe_forward,
gaudi_mixtral_rmsnorm_forward,
)
from .mllama import (
GaudiMllamaCrossAttentionDecoderLayer,
GaudiMllamaForCausalLM,
GaudiMllamaForConditionalGeneration,
GaudiMllamaSelfAttentionDecoderLayer,
GaudiMllamaTextCrossAttention,
GaudiMllamaTextModel,
GaudiMllamaTextSelfAttention,
GaudiMllamaVisionModel,
)
from .modeling_all_models import (
gaudi_check_and_enable_sdpa,
gaudi_conv1d_forward,
Expand Down
10 changes: 10 additions & 0 deletions optimum/habana/transformers/models/mllama/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from .modeling_mllama import (
GaudiMllamaCrossAttentionDecoderLayer,
GaudiMllamaForCausalLM,
GaudiMllamaForConditionalGeneration,
GaudiMllamaSelfAttentionDecoderLayer,
GaudiMllamaTextCrossAttention,
GaudiMllamaTextModel,
GaudiMllamaTextSelfAttention,
GaudiMllamaVisionModel,
)
Loading