From 2f50f104991a050f450840e2757a760d97e9cd08 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 26 May 2024 17:16:44 +0200 Subject: [PATCH 01/72] remove torch and mlx-lm --- requirements.txt | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index 259a4568..ddb02fe5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,6 @@ -mlx>=0.8.0 -mlx-lm>=0.4.0 +mlx>=0.14 numpy -transformers -torch -huggingface_hub +transformers>=4.39.3 gradio Pillow requests From d14849fc95e4773e62e013920483a17e977689ac Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 26 May 2024 17:16:44 +0200 Subject: [PATCH 02/72] remove torch and mlx-lm --- mlx_vlm/trainer/__init__.py | 2 ++ mlx_vlm/trainer/lora.py | 65 ++++++++++++++++++++++++++++++++++ mlx_vlm/trainer/trainer.py | 70 +++++++++++++++++++++++++++++++++++++ mlx_vlm/trainer/utils.py | 30 ++++++++++++++++ 4 files changed, 167 insertions(+) create mode 100644 mlx_vlm/trainer/__init__.py create mode 100644 mlx_vlm/trainer/lora.py create mode 100644 mlx_vlm/trainer/trainer.py create mode 100644 mlx_vlm/trainer/utils.py diff --git a/mlx_vlm/trainer/__init__.py b/mlx_vlm/trainer/__init__.py new file mode 100644 index 00000000..0450264a --- /dev/null +++ b/mlx_vlm/trainer/__init__.py @@ -0,0 +1,2 @@ +from .utils import collate_fn, find_all_linear_names +from .lora import LoRaLayer, replace_lora_with_linear \ No newline at end of file diff --git a/mlx_vlm/trainer/lora.py b/mlx_vlm/trainer/lora.py new file mode 100644 index 00000000..f27f1cc0 --- /dev/null +++ b/mlx_vlm/trainer/lora.py @@ -0,0 +1,65 @@ +import math +from typing import Union +import mlx.core as mx +import mlx.nn as nn + +class LoRaLayer(nn.Module): + def __init__( + self, + linear: Union[nn.Linear, nn.QuantizedLinear], + rank: int, + alpha: float = 0.1, + dropout: float = 0.0, + + ): + super().__init__() + + self.original_layer = linear + self.dropout = nn.Dropout(p=dropout) + + output_dims, input_dims = linear.weight.shape + + std_dev = 1 / math.sqrt(rank) + + self.A = mx.random.uniform( + low=-std_dev, + high=std_dev, + shape=(input_dims, rank), + ) + self.B = mx.zeros(rank, output_dims) + self.alpha = alpha + + def __call__(self, x): + y = self.original_layer(x) + lora_update = (self.dropout(x) @ self.A) @ self.B + return y + (self.alpha * lora_update).astype(x.dtype) + +def replace_lora_with_linear(model): + for i, layer in enumerate(model.layers): + if isinstance(layer, LoRaLayer): + # Compute the final merged weight + lora_update = layer.alpha * (layer.A @ layer.B) + updated_weight = layer.original_layer.weight + lora_update + use_bias = layer.original_layer.bias is not None + + updated_bias = layer.original_layer.bias + + # Create a new Linear layer with the updated parameters + new_linear_layer = nn.Linear(updated_weight.size(1), updated_weight.size(0), bias=use_bias) + + new_linear_layer.weight = updated_weight + + if use_bias: + new_linear_layer.bias = updated_bias + + if isinstance(layer.original_layer, nn.QuantizedLinear): + new_linear_layer = nn.QuantizedLinear.from_linear( + new_linear_layer, + new_linear_layer.group_size, + new_linear_layer.bits, + ) + + + # Replace the LoRaLayer with the new Linear layer in the model + model.layers[i] = new_linear_layer + diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py new file mode 100644 index 00000000..15111c63 --- /dev/null +++ b/mlx_vlm/trainer/trainer.py @@ -0,0 +1,70 @@ +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Union + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from mlx.utils import tree_flatten + + +def grad_checkpoint(layer): + """ + Update all instances of type(layer) to use gradient checkpointing. + """ + fn = type(layer).__call__ + + def checkpointed_fn(model, *args, **kwargs): + def inner_fn(params, *args, **kwargs): + model.update(params) + return fn(model, *args, **kwargs) + + return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs) + + type(layer).__call__ = checkpointed_fn + +@dataclass +class TrainingArgs: + batch_size: int = field(default=4, metadata={"help": "Minibatch size."}) + iters: int = field(default=100, metadata={"help": "Iterations to train for."}) + val_batches: int = field( + default=25, + metadata={ + "help": "Number of validation batches, -1 uses the entire validation set." + }, + ) + steps_per_report: int = field( + default=10, + metadata={"help": "Number of training steps between loss reporting."}, + ) + steps_per_eval: int = field( + default=200, metadata={"help": "Number of training steps between validations."} + ) + steps_per_save: int = field( + default=100, metadata={"help": "Save the model every number steps"} + ) + max_seq_length: int = field( + default=2048, metadata={"help": "Maximum sequence length."} + ) + adapter_file: str = field( + default="adapters.safetensors", + metadata={"help": "Save/load path for the trained adapter weights."}, + ) + grad_checkpoint: bool = field( + default=False, + metadata={"help": "Use gradient checkpointing to reduce memory use."}, + ) + + +def default_loss(model, inputs, targets, lengths): + logits = model(inputs) + logits = logits.astype(mx.float32) + + length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] + + ce = nn.losses.cross_entropy(logits, targets) * length_mask + ntoks = length_mask.sum() + ce = ce.sum() / ntoks + + return ce, ntoks \ No newline at end of file diff --git a/mlx_vlm/trainer/utils.py b/mlx_vlm/trainer/utils.py new file mode 100644 index 00000000..d75bc473 --- /dev/null +++ b/mlx_vlm/trainer/utils.py @@ -0,0 +1,30 @@ + +import mlx.nn as nn +import mlx.core as mx + +def find_all_linear_names(model): + cls = nn.Linear + lora_module_names = set() + multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler'] + for name, module in model.named_modules(): + if any(mm_keyword in name for mm_keyword in multimodal_keywords): + continue + if isinstance(module, cls): + names = name.split('.') + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + if 'lm_head' in lora_module_names: # needed for 16-bit + lora_module_names.remove('lm_head') + return list(lora_module_names) + + +def collate_fn(processor, examples): + texts = ["answer " + example["question"] for example in examples] + labels= [example['multiple_choice_answer'] for example in examples] + images = [example["image"].convert("RGB") for example in examples] + tokens = processor(text=texts, images=images, suffix=labels, + return_tensors="pt", padding="longest", + tokenize_newline_separately=False) + + tokens = tokens.to(mx.float16) + return tokens \ No newline at end of file From 2391df40c550c2a0fb50691a8987fe3adb1bee36 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 12 Jun 2024 13:15:08 +0200 Subject: [PATCH 03/72] add peft model creation --- mlx_vlm/trainer/__init__.py | 10 ++- mlx_vlm/trainer/lora.py | 15 +++-- mlx_vlm/trainer/utils.py | 130 +++++++++++++++++++++++++++++++++--- 3 files changed, 136 insertions(+), 19 deletions(-) diff --git a/mlx_vlm/trainer/__init__.py b/mlx_vlm/trainer/__init__.py index 0450264a..d0132a04 100644 --- a/mlx_vlm/trainer/__init__.py +++ b/mlx_vlm/trainer/__init__.py @@ -1,2 +1,8 @@ -from .utils import collate_fn, find_all_linear_names -from .lora import LoRaLayer, replace_lora_with_linear \ No newline at end of file +from .lora import LoRaLayer, replace_lora_with_linear +from .utils import ( + collate_fn, + count_parameters, + find_all_linear_names, + get_peft_model, + print_trainable_parameters, +) diff --git a/mlx_vlm/trainer/lora.py b/mlx_vlm/trainer/lora.py index f27f1cc0..c5afbf2f 100644 --- a/mlx_vlm/trainer/lora.py +++ b/mlx_vlm/trainer/lora.py @@ -1,8 +1,10 @@ import math from typing import Union + import mlx.core as mx import mlx.nn as nn + class LoRaLayer(nn.Module): def __init__( self, @@ -10,11 +12,11 @@ def __init__( rank: int, alpha: float = 0.1, dropout: float = 0.0, - ): super().__init__() self.original_layer = linear + self.dropout = nn.Dropout(p=dropout) output_dims, input_dims = linear.weight.shape @@ -26,13 +28,14 @@ def __init__( high=std_dev, shape=(input_dims, rank), ) - self.B = mx.zeros(rank, output_dims) + self.B = mx.zeros((rank, output_dims)) self.alpha = alpha def __call__(self, x): y = self.original_layer(x) lora_update = (self.dropout(x) @ self.A) @ self.B - return y + (self.alpha * lora_update).astype(x.dtype) + return y + (self.alpha * lora_update).astype(x.dtype) + def replace_lora_with_linear(model): for i, layer in enumerate(model.layers): @@ -45,7 +48,9 @@ def replace_lora_with_linear(model): updated_bias = layer.original_layer.bias # Create a new Linear layer with the updated parameters - new_linear_layer = nn.Linear(updated_weight.size(1), updated_weight.size(0), bias=use_bias) + new_linear_layer = nn.Linear( + updated_weight.size(1), updated_weight.size(0), bias=use_bias + ) new_linear_layer.weight = updated_weight @@ -59,7 +64,5 @@ def replace_lora_with_linear(model): new_linear_layer.bits, ) - # Replace the LoRaLayer with the new Linear layer in the model model.layers[i] = new_linear_layer - diff --git a/mlx_vlm/trainer/utils.py b/mlx_vlm/trainer/utils.py index d75bc473..f6089dae 100644 --- a/mlx_vlm/trainer/utils.py +++ b/mlx_vlm/trainer/utils.py @@ -1,30 +1,138 @@ - -import mlx.nn as nn import mlx.core as mx +import mlx.nn as nn +import numpy as np + +from .lora import LoRaLayer + + +def get_module_by_name(model, name): + parts = name.split(".") + module = model + for part in parts: + if part.isdigit(): + module = module[int(part)] + else: + module = getattr(module, part) + return module + + +def set_module_by_name(model, name, new_module): + parts = name.split(".") + module = model + for part in parts[:-1]: + if part.isdigit(): + module = module[int(part)] + else: + module = getattr(module, part) + if parts[-1].isdigit(): + module[int(parts[-1])] = new_module + else: + setattr(module, parts[-1], new_module) + + +def get_peft_model(model, linear_layers, freeze=True, verbose=True): + source_model_trainable = count_parameters( + model.language_model.trainable_parameters() + ) + + if freeze: + freeze_model(model) + + for name, module in model.named_modules(): + if isinstance(module, nn.Linear) and name.split(".")[-1] in linear_layers: + lora_layer = LoRaLayer(module, 10, 0.1, 0.1) + set_module_by_name(model, name, lora_layer) + + lora_model_trainable = count_parameters(model.language_model.trainable_parameters()) + if verbose: + print_trainable_parameters(source_model_trainable, lora_model_trainable) + + return model + + +def freeze_model(model): + for name, module in model.named_modules(): + if name in [ + "language_model", + "vision_model", + "vision_tower", + "aligner", + "connector", + "multi_modal_projector", + "mm_projector", + ]: + model[f"{name}"].freeze() + def find_all_linear_names(model): cls = nn.Linear lora_module_names = set() - multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler'] + multimodal_keywords = [ + "mm_projector", + "vision_tower", + "vision_resampler", + "aligner", + ] for name, module in model.named_modules(): if any(mm_keyword in name for mm_keyword in multimodal_keywords): continue if isinstance(module, cls): - names = name.split('.') + names = name.split(".") lora_module_names.add(names[0] if len(names) == 1 else names[-1]) - if 'lm_head' in lora_module_names: # needed for 16-bit - lora_module_names.remove('lm_head') + if "lm_head" in lora_module_names: # needed for 16-bit + lora_module_names.remove("lm_head") return list(lora_module_names) def collate_fn(processor, examples): texts = ["answer " + example["question"] for example in examples] - labels= [example['multiple_choice_answer'] for example in examples] + labels = [example["multiple_choice_answer"] for example in examples] images = [example["image"].convert("RGB") for example in examples] - tokens = processor(text=texts, images=images, suffix=labels, - return_tensors="pt", padding="longest", - tokenize_newline_separately=False) + tokens = processor( + text=texts, + images=images, + suffix=labels, + return_tensors="np", + padding="longest", + tokenize_newline_separately=False, + ) tokens = tokens.to(mx.float16) - return tokens \ No newline at end of file + return tokens + + +def flatten_dict(dd, separator="_", prefix=""): + return ( + { + prefix + separator + k if prefix else k: v + for kk, vv in dd.items() + for k, v in flatten_dict(vv, separator, kk).items() + } + if isinstance(dd, dict) + else {prefix: dd} + ) + + +def count_parameters(trainable_params_dict): + total_params = 0 + for k, v in flatten_dict(trainable_params_dict).items(): + if hasattr(v, "shape"): + total_params += np.prod(v.shape) + + if isinstance(v, list): + for v_ in v: + v_ = flatten_dict(v_) + if isinstance(v_, dict): + total_params += sum( + np.prod(p.shape) for p in v_.values() if hasattr(p, "shape") + ) + + return total_params + + +def print_trainable_parameters(source_model_trainable, lora_model_trainable): + lora_trainable_percent = (lora_model_trainable / source_model_trainable) * 100 + print( + f"#trainable params: {lora_model_trainable} || all params: {source_model_trainable} || trainable%: {lora_trainable_percent}" + ) From 5fcaed2505b0c402273f9cb62e25dce5224d03b2 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 7 Jul 2024 16:55:36 +0200 Subject: [PATCH 04/72] use tree flatten --- mlx_vlm/trainer/utils.py | 28 +++++----------------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/mlx_vlm/trainer/utils.py b/mlx_vlm/trainer/utils.py index f6089dae..47df1f52 100644 --- a/mlx_vlm/trainer/utils.py +++ b/mlx_vlm/trainer/utils.py @@ -1,6 +1,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np +from mlx.utils import tree_flatten from .lora import LoRaLayer @@ -102,31 +103,12 @@ def collate_fn(processor, examples): return tokens -def flatten_dict(dd, separator="_", prefix=""): - return ( - { - prefix + separator + k if prefix else k: v - for kk, vv in dd.items() - for k, v in flatten_dict(vv, separator, kk).items() - } - if isinstance(dd, dict) - else {prefix: dd} - ) - - def count_parameters(trainable_params_dict): total_params = 0 - for k, v in flatten_dict(trainable_params_dict).items(): - if hasattr(v, "shape"): - total_params += np.prod(v.shape) - - if isinstance(v, list): - for v_ in v: - v_ = flatten_dict(v_) - if isinstance(v_, dict): - total_params += sum( - np.prod(p.shape) for p in v_.values() if hasattr(p, "shape") - ) + for modules in tree_flatten(trainable_params_dict): + mx_array = modules[-1] + if hasattr(mx_array, "shape"): + total_params += np.prod(mx_array.shape) return total_params From a88029f2940cc4859953659db72ffec0232033fb Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Tue, 9 Jul 2024 12:55:18 +0200 Subject: [PATCH 05/72] add dataset loader --- mlx_vlm/trainer/trainer.py | 81 +++++++++++++++++++++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index 15111c63..4e1e6741 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -1,3 +1,4 @@ +import os import time from dataclasses import dataclass, field from pathlib import Path @@ -7,6 +8,40 @@ import mlx.nn as nn import numpy as np from mlx.utils import tree_flatten +from PIL import Image + + +class ImageTextDataset: + def __init__(self, image_dir, caption_file, img_size=(224, 224)): + self.image_dir = image_dir + self.img_size = img_size + self.image_captions = [] + self.unique_captions = set() + + with open(caption_file, "r") as f: + for line in f: + image_name, caption = line.strip().split(",") + self.image_captions.append((image_name, caption)) + self.unique_captions.add(caption) + + self.caption_to_id = { + caption: i for i, caption in enumerate(self.unique_captions) + } + + def __len__(self): + return len(self.image_captions) + + def __getitem__(self, idx): + image_name, caption = self.image_captions[idx] + image_path = os.path.join(self.image_dir, image_name) + + image = Image.open(image_path).convert("RGB") + image = image.resize(self.img_size) + image_array = np.array(image).astype(np.float32) / 255.0 + + caption_id = self.caption_to_id[caption] + + return mx.array(image_array), mx.array(caption_id, dtype=mx.int32) def grad_checkpoint(layer): @@ -24,6 +59,7 @@ def inner_fn(params, *args, **kwargs): type(layer).__call__ = checkpointed_fn + @dataclass class TrainingArgs: batch_size: int = field(default=4, metadata={"help": "Minibatch size."}) @@ -67,4 +103,47 @@ def default_loss(model, inputs, targets, lengths): ntoks = length_mask.sum() ce = ce.sum() / ntoks - return ce, ntoks \ No newline at end of file + return ce, ntoks + + +class Trainer: + def __init__(self, model, optimizer, loss_fn): + self.model = model + self.optimizer = optimizer + self.loss_fn = loss_fn + + def train_step(self, batch): + images, labels = batch + + def loss_fn(model): + logits = model(images) + return self.loss_fn(logits, labels) + + loss, grads = mx.value_and_grad(loss_fn)(self.model) + self.optimizer.update(self.model, grads) + return loss + + @mx.compile + def train_epoch(self, dataloader): + total_loss = 0 + for batch in dataloader: + loss = self.train_step(batch) + total_loss += loss + return total_loss / len(dataloader) + + def evaluate(self, dataloader): + correct = total = 0 + for images, labels in dataloader: + logits = self.model(images) + predictions = mx.argmax(logits, axis=1) + correct += mx.sum(predictions == labels) + total += labels.size + return correct / total + + +def save_adapter( + model: nn.Module, + adapter_file: Union[str, Path], +): + flattened_tree = tree_flatten(model.trainable_parameters()) + mx.save_safetensors(str(adapter_file), dict(flattened_tree)) From 9aa5072979ef49532999e2654d8c6af02bbd3b46 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 4 Sep 2024 00:17:20 +0200 Subject: [PATCH 06/72] fix dataset --- mlx_vlm/trainer/__init__.py | 1 + mlx_vlm/trainer/trainer.py | 103 +++++++++++++++++++++++++++--------- mlx_vlm/utils.py | 54 +++++++++++++++---- 3 files changed, 121 insertions(+), 37 deletions(-) diff --git a/mlx_vlm/trainer/__init__.py b/mlx_vlm/trainer/__init__.py index d0132a04..16df5fea 100644 --- a/mlx_vlm/trainer/__init__.py +++ b/mlx_vlm/trainer/__init__.py @@ -1,4 +1,5 @@ from .lora import LoRaLayer, replace_lora_with_linear +from .trainer import * from .utils import ( collate_fn, count_parameters, diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index 4e1e6741..52551ac3 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -2,6 +2,7 @@ import time from dataclasses import dataclass, field from pathlib import Path +from pprint import pprint from typing import Union import mlx.core as mx @@ -10,38 +11,88 @@ from mlx.utils import tree_flatten from PIL import Image +from mlx_vlm.prompt_utils import get_message_json +from mlx_vlm.utils import prepare_inputs + class ImageTextDataset: - def __init__(self, image_dir, caption_file, img_size=(224, 224)): - self.image_dir = image_dir - self.img_size = img_size - self.image_captions = [] - self.unique_captions = set() - - with open(caption_file, "r") as f: - for line in f: - image_name, caption = line.strip().split(",") - self.image_captions.append((image_name, caption)) - self.unique_captions.add(caption) - - self.caption_to_id = { - caption: i for i, caption in enumerate(self.unique_captions) - } + def __init__( + self, + hf_dataset, + config, + processor, + image_processor=None, + take=None, + split="train", + ): + self.dataset = hf_dataset[split] + if take is not None: + self.dataset = self.dataset.take(take) + self.processor = processor + self.config = config + self.image_processor = image_processor def __len__(self): - return len(self.image_captions) + return len(self.dataset) def __getitem__(self, idx): - image_name, caption = self.image_captions[idx] - image_path = os.path.join(self.image_dir, image_name) - - image = Image.open(image_path).convert("RGB") - image = image.resize(self.img_size) - image_array = np.array(image).astype(np.float32) / 255.0 - - caption_id = self.caption_to_id[caption] - - return mx.array(image_array), mx.array(caption_id, dtype=mx.int32) + item = self.dataset[idx] + + # Process image data + image = item["image"] + + conversations = item["conversations"] + # check if conversation is a list of list + if isinstance(conversations, list) and isinstance(conversations[0], list): + prompts = [] + for conversation in conversations: + if "chat_template" in self.processor.__dict__.keys(): + prompts.append( + self.processor.apply_chat_template(conversation, tokenize=False) + ) + + elif "tokenizer" in self.processor.__dict__.keys(): + if self.config["model_type"] != "paligemma": + prompts.append( + self.processor.tokenizer.apply_chat_template( + conversation, tokenize=False + ) + ) + else: + raise ValueError( + "Processor does not have 'chat_template' or 'tokenizer' attribute." + ) + + else: + if "chat_template" in self.processor.__dict__.keys(): + prompts = self.processor.apply_chat_template( + conversations, tokenize=False + ) + + elif "tokenizer" in self.processor.__dict__.keys(): + if self.config["model_type"] != "paligemma": + prompts = self.processor.tokenizer.apply_chat_template( + conversations, tokenize=False + ) + else: + raise ValueError( + "Processor does not have 'chat_template' or 'tokenizer' attribute." + ) + + print(prompts) + image_token_index = self.config["image_token_index"] + input_ids, pixel_values, mask = prepare_inputs( + self.image_processor, self.processor, image, prompts, image_token_index + ) + + if mask is None: + mask = mx.ones_like(input_ids) + + return { + "pixel_values": pixel_values, + "input_ids": input_ids, + "attention_mask": mask, + } def grad_checkpoint(layer): diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index 1b445866..139c4887 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -688,24 +688,56 @@ def load_image(image_source: Union[str, Path, BytesIO]): ) -def prepare_inputs(image_processor, processor, image, prompt, image_token_index): +def prepare_inputs(image_processor, processor, images, prompts, image_token_index): from transformers.image_utils import load_image - mask = None - if isinstance(image, str): - image = load_image(image) + # from pprint import pprint + if not isinstance(images, list): + images = [images] + if not isinstance(prompts, list): + prompts = [prompts] + + # print(len(images), len(prompts)) + # print(images) + # pprint(prompts) + assert len(images) == len( + prompts + ), f"Number of images ({len(images)}) and prompts ({len(prompts)}) must match" + + masks = None + loaded_images = [load_image(img) if isinstance(img, str) else img for img in images] if image_processor is not None: - text_chunks = [processor(chunk).input_ids for chunk in prompt.split("")] - input_ids = mx.array([text_chunks[0] + [image_token_index] + text_chunks[1]]) - pixel_values = image_processor.preprocess(images=[image])[0] - pixel_values = mx.array(np.expand_dims(pixel_values, axis=0)) + text_chunks = [ + [processor(chunk).input_ids for chunk in prompt.split("")] + for prompt in prompts + ] + + # Find the maximum length for padding + max_length = max( + sum(len(chunk) for chunk in chunks) + 1 for chunks in text_chunks + ) + + # Pad and create input_ids + input_ids = [] + for chunks in text_chunks: + ids = chunks[0] + [image_token_index] + chunks[1] + padding = [processor.pad_token_id] * (max_length - len(ids)) + input_ids.append(mx.array(ids + padding)) + + input_ids = mx.array(input_ids) + pixel_values = image_processor.preprocess(images=loaded_images) + pixel_values = mx.array(np.stack(pixel_values)) + masks = mx.array([(ids != processor.pad_token_id) for ids in input_ids]).astype( + mx.int32 + ) else: - inputs = processor(prompt, image, return_tensors="np") + inputs = processor(prompts, loaded_images, return_tensors="np", padding=True) pixel_values = mx.array(inputs["pixel_values"]) input_ids = mx.array(inputs["input_ids"]) - mask = mx.array(inputs["attention_mask"]) - return input_ids, pixel_values, mask + masks = mx.array(inputs["attention_mask"]) + + return input_ids, pixel_values, masks def sample(logits: mx.array, temp: float, top_p: float) -> Tuple[mx.array, float]: From 911eaaa947330b2ed2705c3f67b403e7d057a1d6 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 4 Sep 2024 11:19:14 +0200 Subject: [PATCH 07/72] fix masks and rename dataset --- mlx_vlm/models/llava_bunny/language.py | 7 +-- mlx_vlm/models/llava_bunny/llava_bunny.py | 2 +- mlx_vlm/models/llava_bunny/vision.py | 4 +- mlx_vlm/trainer/trainer.py | 54 +++++++++++++++-------- 4 files changed, 42 insertions(+), 25 deletions(-) diff --git a/mlx_vlm/models/llava_bunny/language.py b/mlx_vlm/models/llava_bunny/language.py index e4010da8..d46936d6 100644 --- a/mlx_vlm/models/llava_bunny/language.py +++ b/mlx_vlm/models/llava_bunny/language.py @@ -167,6 +167,7 @@ def __call__( inputs: mx.array, cache=None, inputs_embeds: Optional[mx.array] = None, + mask: Optional[mx.array] = None, ): # for passing merged input embeddings if inputs_embeds is None: @@ -174,8 +175,8 @@ def __call__( else: h = inputs_embeds - mask = None - if h.shape[1] > 1: + # mask = None + if h.shape[1] > 1 and mask is None: mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) mask = mask.astype(h.dtype) @@ -202,7 +203,7 @@ def __call__( inputs_embeds: Optional[mx.array] = None, mask: Optional[mx.array] = None, ): - out = self.model(inputs, cache=cache, inputs_embeds=inputs_embeds) + out = self.model(inputs, cache=cache, inputs_embeds=inputs_embeds, mask=None) return out def sanitize(self, weights): diff --git a/mlx_vlm/models/llava_bunny/llava_bunny.py b/mlx_vlm/models/llava_bunny/llava_bunny.py index cf210f9c..c95a418b 100644 --- a/mlx_vlm/models/llava_bunny/llava_bunny.py +++ b/mlx_vlm/models/llava_bunny/llava_bunny.py @@ -186,7 +186,7 @@ def __call__( ): input_embeddings = self.get_input_embeddings(input_ids, pixel_values) logits = self.language_model( - inputs=input_ids, cache=cache, inputs_embeds=input_embeddings + inputs=input_ids, cache=cache, inputs_embeds=input_embeddings, mask=mask ) return logits diff --git a/mlx_vlm/models/llava_bunny/vision.py b/mlx_vlm/models/llava_bunny/vision.py index 636cbf78..df3e3c57 100644 --- a/mlx_vlm/models/llava_bunny/vision.py +++ b/mlx_vlm/models/llava_bunny/vision.py @@ -207,9 +207,9 @@ def __call__(self, x: mx.array) -> mx.array: batch_size = x.shape[0] patch_embeddings = self.patch_embedding(x) patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) - self.position_ids = mx.array(np.arange(self.num_positions)[None, :]) + position_ids = mx.array(np.arange(self.num_positions)[None, :]) embeddings = patch_embeddings - embeddings += self.position_embedding(self.position_ids) + embeddings += self.position_embedding(position_ids) return embeddings diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index 52551ac3..f1152cab 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -15,7 +15,7 @@ from mlx_vlm.utils import prepare_inputs -class ImageTextDataset: +class Dataset: def __init__( self, hf_dataset, @@ -79,7 +79,6 @@ def __getitem__(self, idx): "Processor does not have 'chat_template' or 'tokenizer' attribute." ) - print(prompts) image_token_index = self.config["image_token_index"] input_ids, pixel_values, mask = prepare_inputs( self.image_processor, self.processor, image, prompts, image_token_index @@ -158,19 +157,45 @@ def default_loss(model, inputs, targets, lengths): class Trainer: - def __init__(self, model, optimizer, loss_fn): + def __init__(self, model, optimizer): self.model = model self.optimizer = optimizer - self.loss_fn = loss_fn - def train_step(self, batch): - images, labels = batch + def loss_fn(self, model, batch): + pixel_values = batch["pixel_values"] + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + labels = mx.where( + attention_mask == 1, input_ids, -100 + ) # Only compute loss on non-padded tokens + + logits = model(input_ids, pixel_values, attention_mask) + + # Ensure logits and labels have the same sequence length + min_length = min(logits.shape[1], labels.shape[1]) + logits = logits[:, :min_length, :] + labels = labels[:, :min_length] + attention_mask = attention_mask[:, :min_length] + + # Shift logits and labels for next-token prediction + shift_logits = logits[:, :-1, :] + shift_labels = labels[:, 1:] + shift_attention_mask = attention_mask[:, 1:] - def loss_fn(model): - logits = model(images) - return self.loss_fn(logits, labels) + # Flatten the tensors + flat_logits = shift_logits.reshape(-1, shift_logits.shape[-1]) + flat_labels = shift_labels.reshape(-1) + flat_attention_mask = shift_attention_mask.reshape(-1) - loss, grads = mx.value_and_grad(loss_fn)(self.model) + # Compute loss only on non-padded tokens + ce = nn.losses.cross_entropy(flat_logits, flat_labels, reduction="none") + ce = (ce * flat_attention_mask).sum() / flat_attention_mask.sum() + + return ce + + def train_step(self, batch): + loss_and_grad_fn = nn.value_and_grad(self.model, self.loss_fn) + loss, grads = loss_and_grad_fn(self.model, batch) self.optimizer.update(self.model, grads) return loss @@ -182,15 +207,6 @@ def train_epoch(self, dataloader): total_loss += loss return total_loss / len(dataloader) - def evaluate(self, dataloader): - correct = total = 0 - for images, labels in dataloader: - logits = self.model(images) - predictions = mx.argmax(logits, axis=1) - correct += mx.sum(predictions == labels) - total += labels.size - return correct / total - def save_adapter( model: nn.Module, From 8fa9bb9f7bcae4bdc9e4cc8b2daf594589f59034 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 7 Sep 2024 21:31:15 +0200 Subject: [PATCH 08/72] support batch processing and train on completions --- mlx_vlm/models/llava_bunny/llava_bunny.py | 36 +++++++++++------------ mlx_vlm/trainer/trainer.py | 25 ++++++++++++++-- 2 files changed, 40 insertions(+), 21 deletions(-) diff --git a/mlx_vlm/models/llava_bunny/llava_bunny.py b/mlx_vlm/models/llava_bunny/llava_bunny.py index c95a418b..96fd2da8 100644 --- a/mlx_vlm/models/llava_bunny/llava_bunny.py +++ b/mlx_vlm/models/llava_bunny/llava_bunny.py @@ -118,6 +118,7 @@ def __call__( class Model(nn.Module): def __init__(self, config: ModelConfig): + super().__init__() self.model_type = config.model_type self.config = config @@ -151,31 +152,28 @@ def get_input_embeddings( def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids): image_token_index = self.config.image_token_index - num_images, num_image_patches, embed_dim = image_features.shape - - # Positions of tokens in input_ids, assuming batch size is 1 - image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() + batch_size, seq_length, embed_dim = inputs_embeds.shape + num_images, num_image_patches, _ = image_features.shape - if len(image_positions) != num_images: - raise ValueError( - f"The number of image tokens ({len(image_positions)}) does not " - f" match the number of image inputs ({num_images})." - ) + # Positions of tokens in input_ids for each batch + image_positions = mx.argmax(input_ids == image_token_index, axis=1) - text_segments = [] - start_idx = 0 + final_embeddings = [] + for b in range(batch_size): + text_segments = [] + start_idx = 0 + position = int(image_positions[b].item()) - for position in image_positions: - text_segments.append(inputs_embeds[:, start_idx:position]) - start_idx = position + 1 + text_segments.append(inputs_embeds[b : b + 1, start_idx:position]) + text_segments.append(image_features[b : b + 1]) + text_segments.append(inputs_embeds[b : b + 1, position + 1 :]) - image_embeddings = mx.split(image_features, image_features.shape[0]) - final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] - final_embeddings += [inputs_embeds[:, start_idx:]] + batch_embeddings = mx.concatenate(text_segments, axis=1) + final_embeddings.append(batch_embeddings) # Create a final embedding of shape - # (1, num_image_patches*num_images + sequence_len, embed_dim) - return mx.concatenate(final_embeddings, axis=1) + # (batch_size, num_image_patches + sequence_len, embed_dim) + return mx.concatenate(final_embeddings, axis=0) def __call__( self, diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index f1152cab..04c1d75b 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -168,6 +168,19 @@ def loss_fn(self, model, batch): labels = mx.where( attention_mask == 1, input_ids, -100 ) # Only compute loss on non-padded tokens + weight_mask = mx.ones_like(attention_mask) + + assistant_response_index = np.where(input_ids == 77091)[1] + batch_size, seq_length = input_ids.shape + range_matrix = mx.repeat( + mx.expand_dims(mx.arange(seq_length), 0), batch_size, axis=0 + ) + assistant_mask = range_matrix <= mx.array(assistant_response_index).reshape( + -1, 1 + ) + + # Apply the mask to weight_mask + weight_mask = mx.where(assistant_mask, mx.zeros_like(weight_mask), weight_mask) logits = model(input_ids, pixel_values, attention_mask) @@ -181,15 +194,21 @@ def loss_fn(self, model, batch): shift_logits = logits[:, :-1, :] shift_labels = labels[:, 1:] shift_attention_mask = attention_mask[:, 1:] + shift_weight_mask = weight_mask[:, 1:] # Flatten the tensors flat_logits = shift_logits.reshape(-1, shift_logits.shape[-1]) flat_labels = shift_labels.reshape(-1) flat_attention_mask = shift_attention_mask.reshape(-1) + flat_weight_mask = shift_weight_mask.reshape(-1) # Compute loss only on non-padded tokens - ce = nn.losses.cross_entropy(flat_logits, flat_labels, reduction="none") - ce = (ce * flat_attention_mask).sum() / flat_attention_mask.sum() + ce = ( + nn.losses.cross_entropy(flat_logits, flat_labels, weights=flat_weight_mask) + * flat_attention_mask + ) + ntoks = flat_attention_mask.sum() + ce = ce.sum() / ntoks return ce @@ -197,6 +216,7 @@ def train_step(self, batch): loss_and_grad_fn = nn.value_and_grad(self.model, self.loss_fn) loss, grads = loss_and_grad_fn(self.model, batch) self.optimizer.update(self.model, grads) + return loss @mx.compile @@ -204,6 +224,7 @@ def train_epoch(self, dataloader): total_loss = 0 for batch in dataloader: loss = self.train_step(batch) + mx.eval(self.model, self.optimizer.state) total_loss += loss return total_loss / len(dataloader) From bf9bed6e2e29ec7202f3a7b07424cff8c50c385e Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 16 Sep 2024 10:41:42 +0200 Subject: [PATCH 09/72] fix trainer --- mlx_vlm/trainer/trainer.py | 46 ++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index 04c1d75b..5775dc04 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -157,17 +157,20 @@ def default_loss(model, inputs, targets, lengths): class Trainer: - def __init__(self, model, optimizer): + def __init__(self, model, optimizer, train_on_completions=False): self.model = model self.optimizer = optimizer + self.train_on_completions = train_on_completions def loss_fn(self, model, batch): pixel_values = batch["pixel_values"] input_ids = batch["input_ids"] attention_mask = batch["attention_mask"] + lengths = mx.sum(attention_mask, axis=1) labels = mx.where( attention_mask == 1, input_ids, -100 ) # Only compute loss on non-padded tokens + labels = labels[:, 1:] weight_mask = mx.ones_like(attention_mask) assistant_response_index = np.where(input_ids == 77091)[1] @@ -179,35 +182,34 @@ def loss_fn(self, model, batch): -1, 1 ) - # Apply the mask to weight_mask - weight_mask = mx.where(assistant_mask, mx.zeros_like(weight_mask), weight_mask) + if self.train_on_completions: + # Apply the mask to weight_mask + weight_mask = mx.where( + assistant_mask, mx.zeros_like(weight_mask), weight_mask + )[:, 1:] + else: + weight_mask = None + input_ids = input_ids[:, :-1] logits = model(input_ids, pixel_values, attention_mask) + logits.astype(mx.float32) # Ensure logits and labels have the same sequence length - min_length = min(logits.shape[1], labels.shape[1]) - logits = logits[:, :min_length, :] - labels = labels[:, :min_length] - attention_mask = attention_mask[:, :min_length] - - # Shift logits and labels for next-token prediction - shift_logits = logits[:, :-1, :] - shift_labels = labels[:, 1:] - shift_attention_mask = attention_mask[:, 1:] - shift_weight_mask = weight_mask[:, 1:] - - # Flatten the tensors - flat_logits = shift_logits.reshape(-1, shift_logits.shape[-1]) - flat_labels = shift_labels.reshape(-1) - flat_attention_mask = shift_attention_mask.reshape(-1) - flat_weight_mask = shift_weight_mask.reshape(-1) + if logits.shape[1] != labels.shape[1]: + logits = logits[:, -labels.shape[1] :, :] + + length_mask = mx.arange(input_ids.shape[1])[None, :] < lengths[:, None] # Compute loss only on non-padded tokens ce = ( - nn.losses.cross_entropy(flat_logits, flat_labels, weights=flat_weight_mask) - * flat_attention_mask + nn.losses.cross_entropy( + logits, + labels, + weights=weight_mask, + ) + * length_mask ) - ntoks = flat_attention_mask.sum() + ntoks = length_mask.sum() ce = ce.sum() / ntoks return ce From f00252d656f298eb77d277a31616c37adf81463e Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 16 Sep 2024 11:56:12 +0200 Subject: [PATCH 10/72] formatting --- mlx_vlm/trainer/trainer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index 5775dc04..a6bddf42 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -171,18 +171,18 @@ def loss_fn(self, model, batch): attention_mask == 1, input_ids, -100 ) # Only compute loss on non-padded tokens labels = labels[:, 1:] - weight_mask = mx.ones_like(attention_mask) - assistant_response_index = np.where(input_ids == 77091)[1] batch_size, seq_length = input_ids.shape - range_matrix = mx.repeat( - mx.expand_dims(mx.arange(seq_length), 0), batch_size, axis=0 - ) - assistant_mask = range_matrix <= mx.array(assistant_response_index).reshape( - -1, 1 - ) if self.train_on_completions: + weight_mask = mx.ones_like(attention_mask) + assistant_response_index = np.where(input_ids == 77091)[1] + range_matrix = mx.repeat( + mx.expand_dims(mx.arange(seq_length), 0), batch_size, axis=0 + ) + assistant_mask = range_matrix <= mx.array(assistant_response_index).reshape( + -1, 1 + ) # Apply the mask to weight_mask weight_mask = mx.where( assistant_mask, mx.zeros_like(weight_mask), weight_mask From f206dedeb65fb2a882b5261b2e92b7f7e916e426 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 28 Sep 2024 16:47:13 +0200 Subject: [PATCH 11/72] add support for none splits and fix assistant id --- mlx_vlm/trainer/trainer.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index a6bddf42..c81298e3 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -25,7 +25,10 @@ def __init__( take=None, split="train", ): - self.dataset = hf_dataset[split] + if split is not None: + self.dataset = hf_dataset[split] + else: + self.dataset = hf_dataset if take is not None: self.dataset = self.dataset.take(take) self.processor = processor @@ -157,7 +160,9 @@ def default_loss(model, inputs, targets, lengths): class Trainer: - def __init__(self, model, optimizer, train_on_completions=False): + def __init__( + self, model, optimizer, train_on_completions=False, assistant_id=77091 + ): self.model = model self.optimizer = optimizer self.train_on_completions = train_on_completions @@ -176,7 +181,8 @@ def loss_fn(self, model, batch): if self.train_on_completions: weight_mask = mx.ones_like(attention_mask) - assistant_response_index = np.where(input_ids == 77091)[1] + + assistant_response_index = np.where(input_ids == assistant_id)[1] range_matrix = mx.repeat( mx.expand_dims(mx.arange(seq_length), 0), batch_size, axis=0 ) From dab901c5ce8a3721260a409ae657cad7dcff1ddd Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 28 Sep 2024 17:19:25 +0200 Subject: [PATCH 12/72] Add lora script and docs --- mlx_vlm/LORA.MD | 63 +++++++++++++++++++++++++++++++ mlx_vlm/lora.py | 99 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 162 insertions(+) create mode 100644 mlx_vlm/LORA.MD create mode 100644 mlx_vlm/lora.py diff --git a/mlx_vlm/LORA.MD b/mlx_vlm/LORA.MD new file mode 100644 index 00000000..62b366b9 --- /dev/null +++ b/mlx_vlm/LORA.MD @@ -0,0 +1,63 @@ +# lora.py - NanoLLaVA LoRA Training Script + +## Overview + +`lora.py` is a Python script for fine-tuning a NanoLLaVA model using Low-Rank Adaptation (LoRA). This script allows you to train the model on your custom dataset, adjusting various parameters through command-line arguments. + +## Requirements + +- Python 3.7+ +- MLX VLM library +- Required Python packages: `argparse`, `mlx_vlm`, `mlx` + +## Usage + +To use the script, run it from the command line with the desired arguments: + +``` +python lora.py --dataset /path/to/your/dataset [other options] +``` + +## Arguments + +The script accepts the following command-line arguments: + +- `--model_path`: Path to the pre-trained model (default: "mlx-community/nanoLLaVA-1.5-bf16") +- `--dataset`: Path to your dataset (required) +- `--learning_rate`: Learning rate for the optimizer (default: 1e-4) +- `--batch_size`: Batch size for training (default: 2) +- `--epochs`: Number of epochs to train (default: 1) +- `--steps`: Number of steps per epoch (default: 100) +- `--print_every`: Print loss every n steps (default: 10) +- `--output_path`: Path to save the trained adapter (default: "nanollava_lora_adapter.safetensors") + +## Example + +Here's an example of how to run the script with custom parameters: + +``` +python lora.py --dataset /path/to/your/dataset --epochs 2 --steps 200 --batch_size 4 --learning_rate 5e-5 +``` + +This command will: +- Use the dataset at `/path/to/your/dataset` +- Train for 2 epochs +- Perform 200 steps per epoch +- Use a batch size of 4 +- Set the learning rate to 5e-5 + +## Output + +The script will print the training loss at regular intervals (defined by `--print_every`). After training, it will save the LoRA adapter to the specified output path. + +## Note + +Make sure you have the necessary permissions to read the dataset and write the output file. Also, ensure that your system has sufficient computational resources to handle the specified batch size and model. + +## Contributing + +Feel free to submit issues or pull requests if you find any bugs or have suggestions for improvements. + +## License + +[Specify the license here, e.g., MIT, Apache 2.0, etc.] \ No newline at end of file diff --git a/mlx_vlm/lora.py b/mlx_vlm/lora.py new file mode 100644 index 00000000..6648946b --- /dev/null +++ b/mlx_vlm/lora.py @@ -0,0 +1,99 @@ +import argparse + +import mlx.optimizers as optim + +from mlx_vlm.trainer import Dataset, Trainer +from mlx_vlm.trainer.lora import * +from mlx_vlm.trainer.utils import * +from mlx_vlm.utils import load, load_image_processor + + +def add_image_token(items, image_token=""): + conversations = [] + for item in items["conversations"]: + if item["role"] == "user": + if item["content"].startswith(image_token): + conversations.append({"role": "user", "content": item["content"]}) + else: + conversations.append( + {"role": "user", "content": image_token + "\n" + item["content"]} + ) + else: + conversations.append({"role": "assistant", "content": item["content"]}) + return {"conversations": conversations} + + +def main(args): + model, processor = load( + args.model_path, processor_config={"trust_remote_code": True} + ) + image_processor = load_image_processor(args.model_path) + + dataset = Dataset( + args.dataset, + model.config.__dict__, + processor, + image_processor=image_processor, + take=None, + split=None, + ) + dataset = dataset.map(add_image_token) + + optimizer = optim.Adam(learning_rate=args.learning_rate) + trainer = Trainer(model, optimizer) + + list_of_modules = find_all_linear_names(model.language_model.model) + model = get_peft_model(model, list_of_modules) + + model.vision_tower.freeze() + model.train() + + for epoch in range(args.epochs): + for i in range(args.steps): + loss = trainer.train_step( + dataset[i * args.batch_size : (i + 1) * args.batch_size] + ) + if i % args.print_every == 0: + print(f"Epoch {epoch} Step {i} Loss {loss}") + + save_adapter(model, args.output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Train NanoLLaVA model") + parser.add_argument( + "--model_path", + type=str, + default="mlx-community/nanoLLaVA-1.5-bf16", + help="Path to the pre-trained model", + ) + parser.add_argument( + "--dataset", type=str, required=True, help="Path to the dataset" + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Learning rate for the optimizer", + ) + parser.add_argument( + "--batch_size", type=int, default=2, help="Batch size for training" + ) + parser.add_argument( + "--epochs", type=int, default=1, help="Number of epochs to train" + ) + parser.add_argument( + "--steps", type=int, default=100, help="Number of steps per epoch" + ) + parser.add_argument( + "--print_every", type=int, default=10, help="Print loss every n steps" + ) + parser.add_argument( + "--output_path", + type=str, + default="nanollava_lora_adapter.safetensors", + help="Path to save the trained adapter", + ) + + args = parser.parse_args() + main(args) From 607b24991068c94cc7fff82c15ee574ac92a3160 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 26 May 2024 17:16:44 +0200 Subject: [PATCH 13/72] remove torch and mlx-lm --- mlx_vlm/trainer/__init__.py | 2 ++ mlx_vlm/trainer/lora.py | 65 ++++++++++++++++++++++++++++++++++ mlx_vlm/trainer/trainer.py | 70 +++++++++++++++++++++++++++++++++++++ mlx_vlm/trainer/utils.py | 30 ++++++++++++++++ 4 files changed, 167 insertions(+) create mode 100644 mlx_vlm/trainer/__init__.py create mode 100644 mlx_vlm/trainer/lora.py create mode 100644 mlx_vlm/trainer/trainer.py create mode 100644 mlx_vlm/trainer/utils.py diff --git a/mlx_vlm/trainer/__init__.py b/mlx_vlm/trainer/__init__.py new file mode 100644 index 00000000..0450264a --- /dev/null +++ b/mlx_vlm/trainer/__init__.py @@ -0,0 +1,2 @@ +from .utils import collate_fn, find_all_linear_names +from .lora import LoRaLayer, replace_lora_with_linear \ No newline at end of file diff --git a/mlx_vlm/trainer/lora.py b/mlx_vlm/trainer/lora.py new file mode 100644 index 00000000..f27f1cc0 --- /dev/null +++ b/mlx_vlm/trainer/lora.py @@ -0,0 +1,65 @@ +import math +from typing import Union +import mlx.core as mx +import mlx.nn as nn + +class LoRaLayer(nn.Module): + def __init__( + self, + linear: Union[nn.Linear, nn.QuantizedLinear], + rank: int, + alpha: float = 0.1, + dropout: float = 0.0, + + ): + super().__init__() + + self.original_layer = linear + self.dropout = nn.Dropout(p=dropout) + + output_dims, input_dims = linear.weight.shape + + std_dev = 1 / math.sqrt(rank) + + self.A = mx.random.uniform( + low=-std_dev, + high=std_dev, + shape=(input_dims, rank), + ) + self.B = mx.zeros(rank, output_dims) + self.alpha = alpha + + def __call__(self, x): + y = self.original_layer(x) + lora_update = (self.dropout(x) @ self.A) @ self.B + return y + (self.alpha * lora_update).astype(x.dtype) + +def replace_lora_with_linear(model): + for i, layer in enumerate(model.layers): + if isinstance(layer, LoRaLayer): + # Compute the final merged weight + lora_update = layer.alpha * (layer.A @ layer.B) + updated_weight = layer.original_layer.weight + lora_update + use_bias = layer.original_layer.bias is not None + + updated_bias = layer.original_layer.bias + + # Create a new Linear layer with the updated parameters + new_linear_layer = nn.Linear(updated_weight.size(1), updated_weight.size(0), bias=use_bias) + + new_linear_layer.weight = updated_weight + + if use_bias: + new_linear_layer.bias = updated_bias + + if isinstance(layer.original_layer, nn.QuantizedLinear): + new_linear_layer = nn.QuantizedLinear.from_linear( + new_linear_layer, + new_linear_layer.group_size, + new_linear_layer.bits, + ) + + + # Replace the LoRaLayer with the new Linear layer in the model + model.layers[i] = new_linear_layer + diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py new file mode 100644 index 00000000..15111c63 --- /dev/null +++ b/mlx_vlm/trainer/trainer.py @@ -0,0 +1,70 @@ +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Union + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from mlx.utils import tree_flatten + + +def grad_checkpoint(layer): + """ + Update all instances of type(layer) to use gradient checkpointing. + """ + fn = type(layer).__call__ + + def checkpointed_fn(model, *args, **kwargs): + def inner_fn(params, *args, **kwargs): + model.update(params) + return fn(model, *args, **kwargs) + + return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs) + + type(layer).__call__ = checkpointed_fn + +@dataclass +class TrainingArgs: + batch_size: int = field(default=4, metadata={"help": "Minibatch size."}) + iters: int = field(default=100, metadata={"help": "Iterations to train for."}) + val_batches: int = field( + default=25, + metadata={ + "help": "Number of validation batches, -1 uses the entire validation set." + }, + ) + steps_per_report: int = field( + default=10, + metadata={"help": "Number of training steps between loss reporting."}, + ) + steps_per_eval: int = field( + default=200, metadata={"help": "Number of training steps between validations."} + ) + steps_per_save: int = field( + default=100, metadata={"help": "Save the model every number steps"} + ) + max_seq_length: int = field( + default=2048, metadata={"help": "Maximum sequence length."} + ) + adapter_file: str = field( + default="adapters.safetensors", + metadata={"help": "Save/load path for the trained adapter weights."}, + ) + grad_checkpoint: bool = field( + default=False, + metadata={"help": "Use gradient checkpointing to reduce memory use."}, + ) + + +def default_loss(model, inputs, targets, lengths): + logits = model(inputs) + logits = logits.astype(mx.float32) + + length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] + + ce = nn.losses.cross_entropy(logits, targets) * length_mask + ntoks = length_mask.sum() + ce = ce.sum() / ntoks + + return ce, ntoks \ No newline at end of file diff --git a/mlx_vlm/trainer/utils.py b/mlx_vlm/trainer/utils.py new file mode 100644 index 00000000..d75bc473 --- /dev/null +++ b/mlx_vlm/trainer/utils.py @@ -0,0 +1,30 @@ + +import mlx.nn as nn +import mlx.core as mx + +def find_all_linear_names(model): + cls = nn.Linear + lora_module_names = set() + multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler'] + for name, module in model.named_modules(): + if any(mm_keyword in name for mm_keyword in multimodal_keywords): + continue + if isinstance(module, cls): + names = name.split('.') + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + if 'lm_head' in lora_module_names: # needed for 16-bit + lora_module_names.remove('lm_head') + return list(lora_module_names) + + +def collate_fn(processor, examples): + texts = ["answer " + example["question"] for example in examples] + labels= [example['multiple_choice_answer'] for example in examples] + images = [example["image"].convert("RGB") for example in examples] + tokens = processor(text=texts, images=images, suffix=labels, + return_tensors="pt", padding="longest", + tokenize_newline_separately=False) + + tokens = tokens.to(mx.float16) + return tokens \ No newline at end of file From 5c135ac952aea237ae80e90c45ac80c57cc09fbc Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 12 Jun 2024 13:15:08 +0200 Subject: [PATCH 14/72] add peft model creation --- mlx_vlm/trainer/__init__.py | 10 ++- mlx_vlm/trainer/lora.py | 15 +++-- mlx_vlm/trainer/utils.py | 130 +++++++++++++++++++++++++++++++++--- 3 files changed, 136 insertions(+), 19 deletions(-) diff --git a/mlx_vlm/trainer/__init__.py b/mlx_vlm/trainer/__init__.py index 0450264a..d0132a04 100644 --- a/mlx_vlm/trainer/__init__.py +++ b/mlx_vlm/trainer/__init__.py @@ -1,2 +1,8 @@ -from .utils import collate_fn, find_all_linear_names -from .lora import LoRaLayer, replace_lora_with_linear \ No newline at end of file +from .lora import LoRaLayer, replace_lora_with_linear +from .utils import ( + collate_fn, + count_parameters, + find_all_linear_names, + get_peft_model, + print_trainable_parameters, +) diff --git a/mlx_vlm/trainer/lora.py b/mlx_vlm/trainer/lora.py index f27f1cc0..c5afbf2f 100644 --- a/mlx_vlm/trainer/lora.py +++ b/mlx_vlm/trainer/lora.py @@ -1,8 +1,10 @@ import math from typing import Union + import mlx.core as mx import mlx.nn as nn + class LoRaLayer(nn.Module): def __init__( self, @@ -10,11 +12,11 @@ def __init__( rank: int, alpha: float = 0.1, dropout: float = 0.0, - ): super().__init__() self.original_layer = linear + self.dropout = nn.Dropout(p=dropout) output_dims, input_dims = linear.weight.shape @@ -26,13 +28,14 @@ def __init__( high=std_dev, shape=(input_dims, rank), ) - self.B = mx.zeros(rank, output_dims) + self.B = mx.zeros((rank, output_dims)) self.alpha = alpha def __call__(self, x): y = self.original_layer(x) lora_update = (self.dropout(x) @ self.A) @ self.B - return y + (self.alpha * lora_update).astype(x.dtype) + return y + (self.alpha * lora_update).astype(x.dtype) + def replace_lora_with_linear(model): for i, layer in enumerate(model.layers): @@ -45,7 +48,9 @@ def replace_lora_with_linear(model): updated_bias = layer.original_layer.bias # Create a new Linear layer with the updated parameters - new_linear_layer = nn.Linear(updated_weight.size(1), updated_weight.size(0), bias=use_bias) + new_linear_layer = nn.Linear( + updated_weight.size(1), updated_weight.size(0), bias=use_bias + ) new_linear_layer.weight = updated_weight @@ -59,7 +64,5 @@ def replace_lora_with_linear(model): new_linear_layer.bits, ) - # Replace the LoRaLayer with the new Linear layer in the model model.layers[i] = new_linear_layer - diff --git a/mlx_vlm/trainer/utils.py b/mlx_vlm/trainer/utils.py index d75bc473..f6089dae 100644 --- a/mlx_vlm/trainer/utils.py +++ b/mlx_vlm/trainer/utils.py @@ -1,30 +1,138 @@ - -import mlx.nn as nn import mlx.core as mx +import mlx.nn as nn +import numpy as np + +from .lora import LoRaLayer + + +def get_module_by_name(model, name): + parts = name.split(".") + module = model + for part in parts: + if part.isdigit(): + module = module[int(part)] + else: + module = getattr(module, part) + return module + + +def set_module_by_name(model, name, new_module): + parts = name.split(".") + module = model + for part in parts[:-1]: + if part.isdigit(): + module = module[int(part)] + else: + module = getattr(module, part) + if parts[-1].isdigit(): + module[int(parts[-1])] = new_module + else: + setattr(module, parts[-1], new_module) + + +def get_peft_model(model, linear_layers, freeze=True, verbose=True): + source_model_trainable = count_parameters( + model.language_model.trainable_parameters() + ) + + if freeze: + freeze_model(model) + + for name, module in model.named_modules(): + if isinstance(module, nn.Linear) and name.split(".")[-1] in linear_layers: + lora_layer = LoRaLayer(module, 10, 0.1, 0.1) + set_module_by_name(model, name, lora_layer) + + lora_model_trainable = count_parameters(model.language_model.trainable_parameters()) + if verbose: + print_trainable_parameters(source_model_trainable, lora_model_trainable) + + return model + + +def freeze_model(model): + for name, module in model.named_modules(): + if name in [ + "language_model", + "vision_model", + "vision_tower", + "aligner", + "connector", + "multi_modal_projector", + "mm_projector", + ]: + model[f"{name}"].freeze() + def find_all_linear_names(model): cls = nn.Linear lora_module_names = set() - multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler'] + multimodal_keywords = [ + "mm_projector", + "vision_tower", + "vision_resampler", + "aligner", + ] for name, module in model.named_modules(): if any(mm_keyword in name for mm_keyword in multimodal_keywords): continue if isinstance(module, cls): - names = name.split('.') + names = name.split(".") lora_module_names.add(names[0] if len(names) == 1 else names[-1]) - if 'lm_head' in lora_module_names: # needed for 16-bit - lora_module_names.remove('lm_head') + if "lm_head" in lora_module_names: # needed for 16-bit + lora_module_names.remove("lm_head") return list(lora_module_names) def collate_fn(processor, examples): texts = ["answer " + example["question"] for example in examples] - labels= [example['multiple_choice_answer'] for example in examples] + labels = [example["multiple_choice_answer"] for example in examples] images = [example["image"].convert("RGB") for example in examples] - tokens = processor(text=texts, images=images, suffix=labels, - return_tensors="pt", padding="longest", - tokenize_newline_separately=False) + tokens = processor( + text=texts, + images=images, + suffix=labels, + return_tensors="np", + padding="longest", + tokenize_newline_separately=False, + ) tokens = tokens.to(mx.float16) - return tokens \ No newline at end of file + return tokens + + +def flatten_dict(dd, separator="_", prefix=""): + return ( + { + prefix + separator + k if prefix else k: v + for kk, vv in dd.items() + for k, v in flatten_dict(vv, separator, kk).items() + } + if isinstance(dd, dict) + else {prefix: dd} + ) + + +def count_parameters(trainable_params_dict): + total_params = 0 + for k, v in flatten_dict(trainable_params_dict).items(): + if hasattr(v, "shape"): + total_params += np.prod(v.shape) + + if isinstance(v, list): + for v_ in v: + v_ = flatten_dict(v_) + if isinstance(v_, dict): + total_params += sum( + np.prod(p.shape) for p in v_.values() if hasattr(p, "shape") + ) + + return total_params + + +def print_trainable_parameters(source_model_trainable, lora_model_trainable): + lora_trainable_percent = (lora_model_trainable / source_model_trainable) * 100 + print( + f"#trainable params: {lora_model_trainable} || all params: {source_model_trainable} || trainable%: {lora_trainable_percent}" + ) From 534f20c04fb7a38b7516bbf3258adee23e619683 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 7 Jul 2024 16:55:36 +0200 Subject: [PATCH 15/72] use tree flatten --- mlx_vlm/trainer/utils.py | 28 +++++----------------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/mlx_vlm/trainer/utils.py b/mlx_vlm/trainer/utils.py index f6089dae..47df1f52 100644 --- a/mlx_vlm/trainer/utils.py +++ b/mlx_vlm/trainer/utils.py @@ -1,6 +1,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np +from mlx.utils import tree_flatten from .lora import LoRaLayer @@ -102,31 +103,12 @@ def collate_fn(processor, examples): return tokens -def flatten_dict(dd, separator="_", prefix=""): - return ( - { - prefix + separator + k if prefix else k: v - for kk, vv in dd.items() - for k, v in flatten_dict(vv, separator, kk).items() - } - if isinstance(dd, dict) - else {prefix: dd} - ) - - def count_parameters(trainable_params_dict): total_params = 0 - for k, v in flatten_dict(trainable_params_dict).items(): - if hasattr(v, "shape"): - total_params += np.prod(v.shape) - - if isinstance(v, list): - for v_ in v: - v_ = flatten_dict(v_) - if isinstance(v_, dict): - total_params += sum( - np.prod(p.shape) for p in v_.values() if hasattr(p, "shape") - ) + for modules in tree_flatten(trainable_params_dict): + mx_array = modules[-1] + if hasattr(mx_array, "shape"): + total_params += np.prod(mx_array.shape) return total_params From c1edc2279295ed8dd0691d563075881864043760 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Tue, 9 Jul 2024 12:55:18 +0200 Subject: [PATCH 16/72] add dataset loader --- mlx_vlm/trainer/trainer.py | 81 +++++++++++++++++++++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index 15111c63..4e1e6741 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -1,3 +1,4 @@ +import os import time from dataclasses import dataclass, field from pathlib import Path @@ -7,6 +8,40 @@ import mlx.nn as nn import numpy as np from mlx.utils import tree_flatten +from PIL import Image + + +class ImageTextDataset: + def __init__(self, image_dir, caption_file, img_size=(224, 224)): + self.image_dir = image_dir + self.img_size = img_size + self.image_captions = [] + self.unique_captions = set() + + with open(caption_file, "r") as f: + for line in f: + image_name, caption = line.strip().split(",") + self.image_captions.append((image_name, caption)) + self.unique_captions.add(caption) + + self.caption_to_id = { + caption: i for i, caption in enumerate(self.unique_captions) + } + + def __len__(self): + return len(self.image_captions) + + def __getitem__(self, idx): + image_name, caption = self.image_captions[idx] + image_path = os.path.join(self.image_dir, image_name) + + image = Image.open(image_path).convert("RGB") + image = image.resize(self.img_size) + image_array = np.array(image).astype(np.float32) / 255.0 + + caption_id = self.caption_to_id[caption] + + return mx.array(image_array), mx.array(caption_id, dtype=mx.int32) def grad_checkpoint(layer): @@ -24,6 +59,7 @@ def inner_fn(params, *args, **kwargs): type(layer).__call__ = checkpointed_fn + @dataclass class TrainingArgs: batch_size: int = field(default=4, metadata={"help": "Minibatch size."}) @@ -67,4 +103,47 @@ def default_loss(model, inputs, targets, lengths): ntoks = length_mask.sum() ce = ce.sum() / ntoks - return ce, ntoks \ No newline at end of file + return ce, ntoks + + +class Trainer: + def __init__(self, model, optimizer, loss_fn): + self.model = model + self.optimizer = optimizer + self.loss_fn = loss_fn + + def train_step(self, batch): + images, labels = batch + + def loss_fn(model): + logits = model(images) + return self.loss_fn(logits, labels) + + loss, grads = mx.value_and_grad(loss_fn)(self.model) + self.optimizer.update(self.model, grads) + return loss + + @mx.compile + def train_epoch(self, dataloader): + total_loss = 0 + for batch in dataloader: + loss = self.train_step(batch) + total_loss += loss + return total_loss / len(dataloader) + + def evaluate(self, dataloader): + correct = total = 0 + for images, labels in dataloader: + logits = self.model(images) + predictions = mx.argmax(logits, axis=1) + correct += mx.sum(predictions == labels) + total += labels.size + return correct / total + + +def save_adapter( + model: nn.Module, + adapter_file: Union[str, Path], +): + flattened_tree = tree_flatten(model.trainable_parameters()) + mx.save_safetensors(str(adapter_file), dict(flattened_tree)) From 91e93051df67f978feab95c04fae44fb729b2937 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 4 Sep 2024 00:17:20 +0200 Subject: [PATCH 17/72] fix dataset --- mlx_vlm/trainer/__init__.py | 1 + mlx_vlm/trainer/trainer.py | 103 +++++++++++++++++++++++++++--------- mlx_vlm/utils.py | 2 +- 3 files changed, 79 insertions(+), 27 deletions(-) diff --git a/mlx_vlm/trainer/__init__.py b/mlx_vlm/trainer/__init__.py index d0132a04..16df5fea 100644 --- a/mlx_vlm/trainer/__init__.py +++ b/mlx_vlm/trainer/__init__.py @@ -1,4 +1,5 @@ from .lora import LoRaLayer, replace_lora_with_linear +from .trainer import * from .utils import ( collate_fn, count_parameters, diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index 4e1e6741..52551ac3 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -2,6 +2,7 @@ import time from dataclasses import dataclass, field from pathlib import Path +from pprint import pprint from typing import Union import mlx.core as mx @@ -10,38 +11,88 @@ from mlx.utils import tree_flatten from PIL import Image +from mlx_vlm.prompt_utils import get_message_json +from mlx_vlm.utils import prepare_inputs + class ImageTextDataset: - def __init__(self, image_dir, caption_file, img_size=(224, 224)): - self.image_dir = image_dir - self.img_size = img_size - self.image_captions = [] - self.unique_captions = set() - - with open(caption_file, "r") as f: - for line in f: - image_name, caption = line.strip().split(",") - self.image_captions.append((image_name, caption)) - self.unique_captions.add(caption) - - self.caption_to_id = { - caption: i for i, caption in enumerate(self.unique_captions) - } + def __init__( + self, + hf_dataset, + config, + processor, + image_processor=None, + take=None, + split="train", + ): + self.dataset = hf_dataset[split] + if take is not None: + self.dataset = self.dataset.take(take) + self.processor = processor + self.config = config + self.image_processor = image_processor def __len__(self): - return len(self.image_captions) + return len(self.dataset) def __getitem__(self, idx): - image_name, caption = self.image_captions[idx] - image_path = os.path.join(self.image_dir, image_name) - - image = Image.open(image_path).convert("RGB") - image = image.resize(self.img_size) - image_array = np.array(image).astype(np.float32) / 255.0 - - caption_id = self.caption_to_id[caption] - - return mx.array(image_array), mx.array(caption_id, dtype=mx.int32) + item = self.dataset[idx] + + # Process image data + image = item["image"] + + conversations = item["conversations"] + # check if conversation is a list of list + if isinstance(conversations, list) and isinstance(conversations[0], list): + prompts = [] + for conversation in conversations: + if "chat_template" in self.processor.__dict__.keys(): + prompts.append( + self.processor.apply_chat_template(conversation, tokenize=False) + ) + + elif "tokenizer" in self.processor.__dict__.keys(): + if self.config["model_type"] != "paligemma": + prompts.append( + self.processor.tokenizer.apply_chat_template( + conversation, tokenize=False + ) + ) + else: + raise ValueError( + "Processor does not have 'chat_template' or 'tokenizer' attribute." + ) + + else: + if "chat_template" in self.processor.__dict__.keys(): + prompts = self.processor.apply_chat_template( + conversations, tokenize=False + ) + + elif "tokenizer" in self.processor.__dict__.keys(): + if self.config["model_type"] != "paligemma": + prompts = self.processor.tokenizer.apply_chat_template( + conversations, tokenize=False + ) + else: + raise ValueError( + "Processor does not have 'chat_template' or 'tokenizer' attribute." + ) + + print(prompts) + image_token_index = self.config["image_token_index"] + input_ids, pixel_values, mask = prepare_inputs( + self.image_processor, self.processor, image, prompts, image_token_index + ) + + if mask is None: + mask = mx.ones_like(input_ids) + + return { + "pixel_values": pixel_values, + "input_ids": input_ids, + "attention_mask": mask, + } def grad_checkpoint(layer): diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index 95453110..9904b8c3 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -704,7 +704,7 @@ def load_image(image_source: Union[str, Path, BytesIO]): ) -def prepare_inputs(image_processor, processor, image, prompt, image_token_index): +def prepare_inputs(image_processor, processor, images, prompts, image_token_index): from transformers.image_utils import load_image mask = None From e5c0424e28aa563635647c9d152c220eb6903ffc Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 4 Sep 2024 11:19:14 +0200 Subject: [PATCH 18/72] fix masks and rename dataset --- mlx_vlm/models/llava_bunny/language.py | 3 +- mlx_vlm/models/llava_bunny/llava_bunny.py | 2 +- mlx_vlm/models/llava_bunny/vision.py | 4 +- mlx_vlm/trainer/trainer.py | 54 +++++++++++++++-------- 4 files changed, 40 insertions(+), 23 deletions(-) diff --git a/mlx_vlm/models/llava_bunny/language.py b/mlx_vlm/models/llava_bunny/language.py index 153a650b..a5a4fb01 100644 --- a/mlx_vlm/models/llava_bunny/language.py +++ b/mlx_vlm/models/llava_bunny/language.py @@ -167,6 +167,7 @@ def __call__( inputs: mx.array, cache=None, inputs_embeds: Optional[mx.array] = None, + mask: Optional[mx.array] = None, ): # for passing merged input embeddings if inputs_embeds is None: @@ -199,7 +200,7 @@ def __call__( inputs_embeds: Optional[mx.array] = None, mask: Optional[mx.array] = None, ): - out = self.model(inputs, cache=cache, inputs_embeds=inputs_embeds) + out = self.model(inputs, cache=cache, inputs_embeds=inputs_embeds, mask=None) return out def sanitize(self, weights): diff --git a/mlx_vlm/models/llava_bunny/llava_bunny.py b/mlx_vlm/models/llava_bunny/llava_bunny.py index 0e145a2f..74b4fe6b 100644 --- a/mlx_vlm/models/llava_bunny/llava_bunny.py +++ b/mlx_vlm/models/llava_bunny/llava_bunny.py @@ -187,7 +187,7 @@ def __call__( ): input_embeddings = self.get_input_embeddings(input_ids, pixel_values) logits = self.language_model( - inputs=input_ids, cache=cache, inputs_embeds=input_embeddings + inputs=input_ids, cache=cache, inputs_embeds=input_embeddings, mask=mask ) return logits diff --git a/mlx_vlm/models/llava_bunny/vision.py b/mlx_vlm/models/llava_bunny/vision.py index 636cbf78..df3e3c57 100644 --- a/mlx_vlm/models/llava_bunny/vision.py +++ b/mlx_vlm/models/llava_bunny/vision.py @@ -207,9 +207,9 @@ def __call__(self, x: mx.array) -> mx.array: batch_size = x.shape[0] patch_embeddings = self.patch_embedding(x) patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) - self.position_ids = mx.array(np.arange(self.num_positions)[None, :]) + position_ids = mx.array(np.arange(self.num_positions)[None, :]) embeddings = patch_embeddings - embeddings += self.position_embedding(self.position_ids) + embeddings += self.position_embedding(position_ids) return embeddings diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index 52551ac3..f1152cab 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -15,7 +15,7 @@ from mlx_vlm.utils import prepare_inputs -class ImageTextDataset: +class Dataset: def __init__( self, hf_dataset, @@ -79,7 +79,6 @@ def __getitem__(self, idx): "Processor does not have 'chat_template' or 'tokenizer' attribute." ) - print(prompts) image_token_index = self.config["image_token_index"] input_ids, pixel_values, mask = prepare_inputs( self.image_processor, self.processor, image, prompts, image_token_index @@ -158,19 +157,45 @@ def default_loss(model, inputs, targets, lengths): class Trainer: - def __init__(self, model, optimizer, loss_fn): + def __init__(self, model, optimizer): self.model = model self.optimizer = optimizer - self.loss_fn = loss_fn - def train_step(self, batch): - images, labels = batch + def loss_fn(self, model, batch): + pixel_values = batch["pixel_values"] + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + labels = mx.where( + attention_mask == 1, input_ids, -100 + ) # Only compute loss on non-padded tokens + + logits = model(input_ids, pixel_values, attention_mask) + + # Ensure logits and labels have the same sequence length + min_length = min(logits.shape[1], labels.shape[1]) + logits = logits[:, :min_length, :] + labels = labels[:, :min_length] + attention_mask = attention_mask[:, :min_length] + + # Shift logits and labels for next-token prediction + shift_logits = logits[:, :-1, :] + shift_labels = labels[:, 1:] + shift_attention_mask = attention_mask[:, 1:] - def loss_fn(model): - logits = model(images) - return self.loss_fn(logits, labels) + # Flatten the tensors + flat_logits = shift_logits.reshape(-1, shift_logits.shape[-1]) + flat_labels = shift_labels.reshape(-1) + flat_attention_mask = shift_attention_mask.reshape(-1) - loss, grads = mx.value_and_grad(loss_fn)(self.model) + # Compute loss only on non-padded tokens + ce = nn.losses.cross_entropy(flat_logits, flat_labels, reduction="none") + ce = (ce * flat_attention_mask).sum() / flat_attention_mask.sum() + + return ce + + def train_step(self, batch): + loss_and_grad_fn = nn.value_and_grad(self.model, self.loss_fn) + loss, grads = loss_and_grad_fn(self.model, batch) self.optimizer.update(self.model, grads) return loss @@ -182,15 +207,6 @@ def train_epoch(self, dataloader): total_loss += loss return total_loss / len(dataloader) - def evaluate(self, dataloader): - correct = total = 0 - for images, labels in dataloader: - logits = self.model(images) - predictions = mx.argmax(logits, axis=1) - correct += mx.sum(predictions == labels) - total += labels.size - return correct / total - def save_adapter( model: nn.Module, From 130d87655f9f624aa9640749675949c2e68315c6 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 7 Sep 2024 21:31:15 +0200 Subject: [PATCH 19/72] support batch processing and train on completions --- mlx_vlm/models/llava_bunny/llava_bunny.py | 36 +++++++++++------------ mlx_vlm/trainer/trainer.py | 25 ++++++++++++++-- 2 files changed, 40 insertions(+), 21 deletions(-) diff --git a/mlx_vlm/models/llava_bunny/llava_bunny.py b/mlx_vlm/models/llava_bunny/llava_bunny.py index 74b4fe6b..bbbb5cd5 100644 --- a/mlx_vlm/models/llava_bunny/llava_bunny.py +++ b/mlx_vlm/models/llava_bunny/llava_bunny.py @@ -118,6 +118,7 @@ def __call__( class Model(nn.Module): def __init__(self, config: ModelConfig): + super().__init__() self.model_type = config.model_type self.config = config @@ -151,31 +152,28 @@ def get_input_embeddings( def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids): image_token_index = self.config.image_token_index - num_images, num_image_patches, embed_dim = image_features.shape - - # Positions of tokens in input_ids, assuming batch size is 1 - image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() + batch_size, seq_length, embed_dim = inputs_embeds.shape + num_images, num_image_patches, _ = image_features.shape - if len(image_positions) != num_images: - raise ValueError( - f"The number of image tokens ({len(image_positions)}) does not " - f" match the number of image inputs ({num_images})." - ) + # Positions of tokens in input_ids for each batch + image_positions = mx.argmax(input_ids == image_token_index, axis=1) - text_segments = [] - start_idx = 0 + final_embeddings = [] + for b in range(batch_size): + text_segments = [] + start_idx = 0 + position = int(image_positions[b].item()) - for position in image_positions: - text_segments.append(inputs_embeds[:, start_idx:position]) - start_idx = position + 1 + text_segments.append(inputs_embeds[b : b + 1, start_idx:position]) + text_segments.append(image_features[b : b + 1]) + text_segments.append(inputs_embeds[b : b + 1, position + 1 :]) - image_embeddings = mx.split(image_features, image_features.shape[0]) - final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] - final_embeddings += [inputs_embeds[:, start_idx:]] + batch_embeddings = mx.concatenate(text_segments, axis=1) + final_embeddings.append(batch_embeddings) # Create a final embedding of shape - # (1, num_image_patches*num_images + sequence_len, embed_dim) - return mx.concatenate(final_embeddings, axis=1) + # (batch_size, num_image_patches + sequence_len, embed_dim) + return mx.concatenate(final_embeddings, axis=0) def __call__( self, diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index f1152cab..04c1d75b 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -168,6 +168,19 @@ def loss_fn(self, model, batch): labels = mx.where( attention_mask == 1, input_ids, -100 ) # Only compute loss on non-padded tokens + weight_mask = mx.ones_like(attention_mask) + + assistant_response_index = np.where(input_ids == 77091)[1] + batch_size, seq_length = input_ids.shape + range_matrix = mx.repeat( + mx.expand_dims(mx.arange(seq_length), 0), batch_size, axis=0 + ) + assistant_mask = range_matrix <= mx.array(assistant_response_index).reshape( + -1, 1 + ) + + # Apply the mask to weight_mask + weight_mask = mx.where(assistant_mask, mx.zeros_like(weight_mask), weight_mask) logits = model(input_ids, pixel_values, attention_mask) @@ -181,15 +194,21 @@ def loss_fn(self, model, batch): shift_logits = logits[:, :-1, :] shift_labels = labels[:, 1:] shift_attention_mask = attention_mask[:, 1:] + shift_weight_mask = weight_mask[:, 1:] # Flatten the tensors flat_logits = shift_logits.reshape(-1, shift_logits.shape[-1]) flat_labels = shift_labels.reshape(-1) flat_attention_mask = shift_attention_mask.reshape(-1) + flat_weight_mask = shift_weight_mask.reshape(-1) # Compute loss only on non-padded tokens - ce = nn.losses.cross_entropy(flat_logits, flat_labels, reduction="none") - ce = (ce * flat_attention_mask).sum() / flat_attention_mask.sum() + ce = ( + nn.losses.cross_entropy(flat_logits, flat_labels, weights=flat_weight_mask) + * flat_attention_mask + ) + ntoks = flat_attention_mask.sum() + ce = ce.sum() / ntoks return ce @@ -197,6 +216,7 @@ def train_step(self, batch): loss_and_grad_fn = nn.value_and_grad(self.model, self.loss_fn) loss, grads = loss_and_grad_fn(self.model, batch) self.optimizer.update(self.model, grads) + return loss @mx.compile @@ -204,6 +224,7 @@ def train_epoch(self, dataloader): total_loss = 0 for batch in dataloader: loss = self.train_step(batch) + mx.eval(self.model, self.optimizer.state) total_loss += loss return total_loss / len(dataloader) From 5c028d4248f86f8c618c145a3cb59240c98e700b Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 16 Sep 2024 10:41:42 +0200 Subject: [PATCH 20/72] fix trainer --- mlx_vlm/trainer/trainer.py | 46 ++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index 04c1d75b..5775dc04 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -157,17 +157,20 @@ def default_loss(model, inputs, targets, lengths): class Trainer: - def __init__(self, model, optimizer): + def __init__(self, model, optimizer, train_on_completions=False): self.model = model self.optimizer = optimizer + self.train_on_completions = train_on_completions def loss_fn(self, model, batch): pixel_values = batch["pixel_values"] input_ids = batch["input_ids"] attention_mask = batch["attention_mask"] + lengths = mx.sum(attention_mask, axis=1) labels = mx.where( attention_mask == 1, input_ids, -100 ) # Only compute loss on non-padded tokens + labels = labels[:, 1:] weight_mask = mx.ones_like(attention_mask) assistant_response_index = np.where(input_ids == 77091)[1] @@ -179,35 +182,34 @@ def loss_fn(self, model, batch): -1, 1 ) - # Apply the mask to weight_mask - weight_mask = mx.where(assistant_mask, mx.zeros_like(weight_mask), weight_mask) + if self.train_on_completions: + # Apply the mask to weight_mask + weight_mask = mx.where( + assistant_mask, mx.zeros_like(weight_mask), weight_mask + )[:, 1:] + else: + weight_mask = None + input_ids = input_ids[:, :-1] logits = model(input_ids, pixel_values, attention_mask) + logits.astype(mx.float32) # Ensure logits and labels have the same sequence length - min_length = min(logits.shape[1], labels.shape[1]) - logits = logits[:, :min_length, :] - labels = labels[:, :min_length] - attention_mask = attention_mask[:, :min_length] - - # Shift logits and labels for next-token prediction - shift_logits = logits[:, :-1, :] - shift_labels = labels[:, 1:] - shift_attention_mask = attention_mask[:, 1:] - shift_weight_mask = weight_mask[:, 1:] - - # Flatten the tensors - flat_logits = shift_logits.reshape(-1, shift_logits.shape[-1]) - flat_labels = shift_labels.reshape(-1) - flat_attention_mask = shift_attention_mask.reshape(-1) - flat_weight_mask = shift_weight_mask.reshape(-1) + if logits.shape[1] != labels.shape[1]: + logits = logits[:, -labels.shape[1] :, :] + + length_mask = mx.arange(input_ids.shape[1])[None, :] < lengths[:, None] # Compute loss only on non-padded tokens ce = ( - nn.losses.cross_entropy(flat_logits, flat_labels, weights=flat_weight_mask) - * flat_attention_mask + nn.losses.cross_entropy( + logits, + labels, + weights=weight_mask, + ) + * length_mask ) - ntoks = flat_attention_mask.sum() + ntoks = length_mask.sum() ce = ce.sum() / ntoks return ce From 1bc0aa4001bbae8b52535fac381313dd6090f482 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 16 Sep 2024 11:56:12 +0200 Subject: [PATCH 21/72] formatting --- mlx_vlm/trainer/trainer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index 5775dc04..a6bddf42 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -171,18 +171,18 @@ def loss_fn(self, model, batch): attention_mask == 1, input_ids, -100 ) # Only compute loss on non-padded tokens labels = labels[:, 1:] - weight_mask = mx.ones_like(attention_mask) - assistant_response_index = np.where(input_ids == 77091)[1] batch_size, seq_length = input_ids.shape - range_matrix = mx.repeat( - mx.expand_dims(mx.arange(seq_length), 0), batch_size, axis=0 - ) - assistant_mask = range_matrix <= mx.array(assistant_response_index).reshape( - -1, 1 - ) if self.train_on_completions: + weight_mask = mx.ones_like(attention_mask) + assistant_response_index = np.where(input_ids == 77091)[1] + range_matrix = mx.repeat( + mx.expand_dims(mx.arange(seq_length), 0), batch_size, axis=0 + ) + assistant_mask = range_matrix <= mx.array(assistant_response_index).reshape( + -1, 1 + ) # Apply the mask to weight_mask weight_mask = mx.where( assistant_mask, mx.zeros_like(weight_mask), weight_mask From d62bf6393519d3fee26afbc14fb5e3f98583916e Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 28 Sep 2024 16:47:13 +0200 Subject: [PATCH 22/72] add support for none splits and fix assistant id --- mlx_vlm/trainer/trainer.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index a6bddf42..c81298e3 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -25,7 +25,10 @@ def __init__( take=None, split="train", ): - self.dataset = hf_dataset[split] + if split is not None: + self.dataset = hf_dataset[split] + else: + self.dataset = hf_dataset if take is not None: self.dataset = self.dataset.take(take) self.processor = processor @@ -157,7 +160,9 @@ def default_loss(model, inputs, targets, lengths): class Trainer: - def __init__(self, model, optimizer, train_on_completions=False): + def __init__( + self, model, optimizer, train_on_completions=False, assistant_id=77091 + ): self.model = model self.optimizer = optimizer self.train_on_completions = train_on_completions @@ -176,7 +181,8 @@ def loss_fn(self, model, batch): if self.train_on_completions: weight_mask = mx.ones_like(attention_mask) - assistant_response_index = np.where(input_ids == 77091)[1] + + assistant_response_index = np.where(input_ids == assistant_id)[1] range_matrix = mx.repeat( mx.expand_dims(mx.arange(seq_length), 0), batch_size, axis=0 ) From a6d411ec6a9c1970b5ef12ca6b68e7c24167e71d Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 28 Sep 2024 17:19:25 +0200 Subject: [PATCH 23/72] Add lora script and docs --- mlx_vlm/LORA.MD | 63 +++++++++++++++++++++++++++++++ mlx_vlm/lora.py | 99 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 162 insertions(+) create mode 100644 mlx_vlm/LORA.MD create mode 100644 mlx_vlm/lora.py diff --git a/mlx_vlm/LORA.MD b/mlx_vlm/LORA.MD new file mode 100644 index 00000000..62b366b9 --- /dev/null +++ b/mlx_vlm/LORA.MD @@ -0,0 +1,63 @@ +# lora.py - NanoLLaVA LoRA Training Script + +## Overview + +`lora.py` is a Python script for fine-tuning a NanoLLaVA model using Low-Rank Adaptation (LoRA). This script allows you to train the model on your custom dataset, adjusting various parameters through command-line arguments. + +## Requirements + +- Python 3.7+ +- MLX VLM library +- Required Python packages: `argparse`, `mlx_vlm`, `mlx` + +## Usage + +To use the script, run it from the command line with the desired arguments: + +``` +python lora.py --dataset /path/to/your/dataset [other options] +``` + +## Arguments + +The script accepts the following command-line arguments: + +- `--model_path`: Path to the pre-trained model (default: "mlx-community/nanoLLaVA-1.5-bf16") +- `--dataset`: Path to your dataset (required) +- `--learning_rate`: Learning rate for the optimizer (default: 1e-4) +- `--batch_size`: Batch size for training (default: 2) +- `--epochs`: Number of epochs to train (default: 1) +- `--steps`: Number of steps per epoch (default: 100) +- `--print_every`: Print loss every n steps (default: 10) +- `--output_path`: Path to save the trained adapter (default: "nanollava_lora_adapter.safetensors") + +## Example + +Here's an example of how to run the script with custom parameters: + +``` +python lora.py --dataset /path/to/your/dataset --epochs 2 --steps 200 --batch_size 4 --learning_rate 5e-5 +``` + +This command will: +- Use the dataset at `/path/to/your/dataset` +- Train for 2 epochs +- Perform 200 steps per epoch +- Use a batch size of 4 +- Set the learning rate to 5e-5 + +## Output + +The script will print the training loss at regular intervals (defined by `--print_every`). After training, it will save the LoRA adapter to the specified output path. + +## Note + +Make sure you have the necessary permissions to read the dataset and write the output file. Also, ensure that your system has sufficient computational resources to handle the specified batch size and model. + +## Contributing + +Feel free to submit issues or pull requests if you find any bugs or have suggestions for improvements. + +## License + +[Specify the license here, e.g., MIT, Apache 2.0, etc.] \ No newline at end of file diff --git a/mlx_vlm/lora.py b/mlx_vlm/lora.py new file mode 100644 index 00000000..6648946b --- /dev/null +++ b/mlx_vlm/lora.py @@ -0,0 +1,99 @@ +import argparse + +import mlx.optimizers as optim + +from mlx_vlm.trainer import Dataset, Trainer +from mlx_vlm.trainer.lora import * +from mlx_vlm.trainer.utils import * +from mlx_vlm.utils import load, load_image_processor + + +def add_image_token(items, image_token=""): + conversations = [] + for item in items["conversations"]: + if item["role"] == "user": + if item["content"].startswith(image_token): + conversations.append({"role": "user", "content": item["content"]}) + else: + conversations.append( + {"role": "user", "content": image_token + "\n" + item["content"]} + ) + else: + conversations.append({"role": "assistant", "content": item["content"]}) + return {"conversations": conversations} + + +def main(args): + model, processor = load( + args.model_path, processor_config={"trust_remote_code": True} + ) + image_processor = load_image_processor(args.model_path) + + dataset = Dataset( + args.dataset, + model.config.__dict__, + processor, + image_processor=image_processor, + take=None, + split=None, + ) + dataset = dataset.map(add_image_token) + + optimizer = optim.Adam(learning_rate=args.learning_rate) + trainer = Trainer(model, optimizer) + + list_of_modules = find_all_linear_names(model.language_model.model) + model = get_peft_model(model, list_of_modules) + + model.vision_tower.freeze() + model.train() + + for epoch in range(args.epochs): + for i in range(args.steps): + loss = trainer.train_step( + dataset[i * args.batch_size : (i + 1) * args.batch_size] + ) + if i % args.print_every == 0: + print(f"Epoch {epoch} Step {i} Loss {loss}") + + save_adapter(model, args.output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Train NanoLLaVA model") + parser.add_argument( + "--model_path", + type=str, + default="mlx-community/nanoLLaVA-1.5-bf16", + help="Path to the pre-trained model", + ) + parser.add_argument( + "--dataset", type=str, required=True, help="Path to the dataset" + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Learning rate for the optimizer", + ) + parser.add_argument( + "--batch_size", type=int, default=2, help="Batch size for training" + ) + parser.add_argument( + "--epochs", type=int, default=1, help="Number of epochs to train" + ) + parser.add_argument( + "--steps", type=int, default=100, help="Number of steps per epoch" + ) + parser.add_argument( + "--print_every", type=int, default=10, help="Print loss every n steps" + ) + parser.add_argument( + "--output_path", + type=str, + default="nanollava_lora_adapter.safetensors", + help="Path to save the trained adapter", + ) + + args = parser.parse_args() + main(args) From c1033b58835780405bdb71953a68352569569c99 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 29 Sep 2024 22:03:31 +0200 Subject: [PATCH 24/72] remove duplicates --- mlx_vlm/models/llava_bunny/language.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mlx_vlm/models/llava_bunny/language.py b/mlx_vlm/models/llava_bunny/language.py index 7e041ec3..a5a4fb01 100644 --- a/mlx_vlm/models/llava_bunny/language.py +++ b/mlx_vlm/models/llava_bunny/language.py @@ -168,7 +168,6 @@ def __call__( cache=None, inputs_embeds: Optional[mx.array] = None, mask: Optional[mx.array] = None, - mask: Optional[mx.array] = None, ): # for passing merged input embeddings if inputs_embeds is None: @@ -201,7 +200,6 @@ def __call__( inputs_embeds: Optional[mx.array] = None, mask: Optional[mx.array] = None, ): - out = self.model(inputs, cache=cache, inputs_embeds=inputs_embeds, mask=None) out = self.model(inputs, cache=cache, inputs_embeds=inputs_embeds, mask=None) return out From 80cdcd621072934a01d6ff3df762fe4a7215b663 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 29 Sep 2024 22:37:35 +0200 Subject: [PATCH 25/72] fix batch load --- mlx_vlm/utils.py | 66 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 45 insertions(+), 21 deletions(-) diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index 196ddfec..b97bdb73 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -708,11 +708,20 @@ def prepare_inputs(image_processor, processor, images, prompts, image_token_inde from transformers.image_utils import load_image mask = None - if isinstance(image, str): - image = load_image(image) + if not isinstance(images, list): + images = [images] + if not isinstance(prompts, list): + prompts = [prompts] + + assert len(images) == len( + prompts + ), f"Number of images ({len(images)}) and prompts ({len(prompts)}) must match" + + images = [load_image(img) if isinstance(img, str) else img for img in images] image_grid_thw = None if image_processor is not None: + processor.pad_token = processor.eos_token text_chunks = [ [processor(chunk).input_ids for chunk in prompt.split("")] for prompt in prompts @@ -731,36 +740,51 @@ def prepare_inputs(image_processor, processor, images, prompts, image_token_inde input_ids.append(mx.array(ids + padding)) input_ids = mx.array(input_ids) - pixel_values = image_processor.preprocess(images=loaded_images) + pixel_values = image_processor.preprocess(images=images) pixel_values = mx.array(np.stack(pixel_values)) - masks = mx.array([(ids != processor.pad_token_id) for ids in input_ids]).astype( + + mask = mx.array([(ids != processor.pad_token_id) for ids in input_ids]).astype( mx.int32 ) else: processor.tokenizer.pad_token = processor.tokenizer.eos_token + try: inputs = processor( - text=[prompt], images=[image], padding=True, return_tensors="mlx" + text=prompts, images=images, padding=True, return_tensors="mlx" ) - except Exception as e: - inputs = processor( - text=prompt, images=[image], padding=True, return_tensors="mlx" - ) # for phi3_v model - - if isinstance(inputs["pixel_values"], list): - pixel_values = mx.array(inputs["pixel_values"][0][0])[None, :] - elif isinstance(inputs["pixel_values"], np.ndarray): pixel_values = mx.array(inputs["pixel_values"]) - else: - raise ValueError( - f"Invalid pixel_values type: {type(inputs['pixel_values'])}" + input_ids = mx.array(inputs["input_ids"]) + mask = mx.array(inputs["attention_mask"]) + image_grid_thw = inputs.get("image_grid_thw", None) + if image_grid_thw is not None: + image_grid_thw = mx.array(image_grid_thw) + + except Exception as e: + inputs = [] + for i in range(len(images)): + inputs.append( + processor( + text=str(prompts[i]), + images=images[i], + padding=True, + return_tensors="mlx", + ) + ) + input_ids = mx.concatenate( + [mx.array(i["input_ids"]) for i in inputs], axis=0 + ) + pixel_values = mx.concatenate( + [mx.array(i["pixel_values"]) for i in inputs], axis=0 + ) + mask = mx.concatenate( + [mx.array(i["attention_mask"]) for i in inputs], axis=0 + ) + image_sizes = mx.concatenate( + [mx.array(i["image_sizes"]) for i in inputs], axis=0 ) - input_ids = mx.array(inputs["input_ids"]) - mask = inputs["attention_mask"] - image_grid_thw = inputs.get("image_grid_thw", None) - if "image_sizes" in inputs: - return input_ids, pixel_values, inputs["image_sizes"], image_grid_thw + return input_ids, pixel_values, image_sizes, image_grid_thw return input_ids, pixel_values, mask, image_grid_thw From 935e1bde46f840d5229cbde52ee84e49bcb80ebb Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 29 Sep 2024 23:25:01 +0200 Subject: [PATCH 26/72] load trained adapters and add super to all models --- mlx_vlm/__init__.py | 3 +- mlx_vlm/generate.py | 16 ++++++-- mlx_vlm/models/idefics2/idefics2.py | 1 + mlx_vlm/models/llava/llava.py | 1 + mlx_vlm/models/llava_next/llava_next.py | 1 + .../models/multi_modality/multi_modality.py | 1 + mlx_vlm/models/paligemma/paligemma.py | 1 + mlx_vlm/models/pixtral/pixtral.py | 1 + mlx_vlm/models/qwen2_vl/qwen2_vl.py | 1 + mlx_vlm/trainer/__init__.py | 3 +- mlx_vlm/trainer/trainer.py | 17 ++++++--- mlx_vlm/trainer/utils.py | 38 +++++++++++++++++-- mlx_vlm/utils.py | 9 ++++- 13 files changed, 78 insertions(+), 15 deletions(-) diff --git a/mlx_vlm/__init__.py b/mlx_vlm/__init__.py index 03cb6f1f..50494f86 100644 --- a/mlx_vlm/__init__.py +++ b/mlx_vlm/__init__.py @@ -1,2 +1,3 @@ -from .utils import convert, generate, load +from .prompt_utils import get_message_json +from .utils import convert, generate, load, prepare_inputs from .version import __version__ diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index 3e95cc20..514e1bca 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -25,6 +25,12 @@ def parse_arguments(): default=DEFAULT_MODEL_PATH, help="The path to the local model directory or Hugging Face repo.", ) + parser.add_argument( + "--adapter-path", + type=str, + default=None, + help="The path to the adapter weights.", + ) parser.add_argument( "--image", type=str, @@ -50,17 +56,21 @@ def parse_arguments(): return parser.parse_args() -def get_model_and_processors(model_path): +def get_model_and_processors(model_path, adapter_path): model_path = get_model_path(model_path) config = load_config(model_path) - model, processor = load(model_path, {"trust_remote_code": True}) + model, processor = load( + model_path, {"trust_remote_code": True}, adapter_path=adapter_path + ) image_processor = load_image_processor(model_path) return model, processor, image_processor, config def main(): args = parse_arguments() - model, processor, image_processor, config = get_model_and_processors(args.model) + model, processor, image_processor, config = get_model_and_processors( + args.model, args.adapter_path + ) prompt = codecs.decode(args.prompt, "unicode_escape") diff --git a/mlx_vlm/models/idefics2/idefics2.py b/mlx_vlm/models/idefics2/idefics2.py index 1c78365a..52085dd3 100644 --- a/mlx_vlm/models/idefics2/idefics2.py +++ b/mlx_vlm/models/idefics2/idefics2.py @@ -197,6 +197,7 @@ def __call__(self, x: mx.array, mask=None) -> mx.array: class Model(nn.Module): def __init__(self, config: ModelConfig): + super().__init__() self.model_type = config.model_type self.config = config diff --git a/mlx_vlm/models/llava/llava.py b/mlx_vlm/models/llava/llava.py index 39aae4a7..39298a9b 100644 --- a/mlx_vlm/models/llava/llava.py +++ b/mlx_vlm/models/llava/llava.py @@ -56,6 +56,7 @@ def __call__(self, x: mx.array) -> mx.array: class Model(nn.Module): def __init__(self, config: ModelConfig): + super().__init__() self.config = config self.vision_tower = VisionModel(config.vision_config) self.language_model = LanguageModel(config.text_config) diff --git a/mlx_vlm/models/llava_next/llava_next.py b/mlx_vlm/models/llava_next/llava_next.py index 878d7ca0..29abea1f 100644 --- a/mlx_vlm/models/llava_next/llava_next.py +++ b/mlx_vlm/models/llava_next/llava_next.py @@ -56,6 +56,7 @@ def __call__(self, x: mx.array) -> mx.array: class Model(nn.Module): def __init__(self, config: ModelConfig): + super().__init__() self.config = config self.vision_tower = VisionModel(config.vision_config) self.language_model = LanguageModel(config.text_config) diff --git a/mlx_vlm/models/multi_modality/multi_modality.py b/mlx_vlm/models/multi_modality/multi_modality.py index 52a0bc9f..c1d9df9b 100644 --- a/mlx_vlm/models/multi_modality/multi_modality.py +++ b/mlx_vlm/models/multi_modality/multi_modality.py @@ -238,6 +238,7 @@ def __call__(self, x: Union[mx.array, Tuple]) -> mx.array: class Model(nn.Module): def __init__(self, config: ModelConfig): + super().__init__() self.config = config self.vision_model = VisionModel(config.vision_config) self.language_model = LanguageModel(config.text_config) diff --git a/mlx_vlm/models/paligemma/paligemma.py b/mlx_vlm/models/paligemma/paligemma.py index 74007388..fabdac5b 100644 --- a/mlx_vlm/models/paligemma/paligemma.py +++ b/mlx_vlm/models/paligemma/paligemma.py @@ -52,6 +52,7 @@ def __call__(self, x: mx.array) -> mx.array: class Model(nn.Module): def __init__(self, config: ModelConfig): + super().__init__() self.model_type = config.model_type self.config = config diff --git a/mlx_vlm/models/pixtral/pixtral.py b/mlx_vlm/models/pixtral/pixtral.py index b49397b5..3ae1d2e7 100644 --- a/mlx_vlm/models/pixtral/pixtral.py +++ b/mlx_vlm/models/pixtral/pixtral.py @@ -56,6 +56,7 @@ def __call__(self, x: mx.array) -> mx.array: class Model(nn.Module): def __init__(self, config: ModelConfig): + super().__init__() self.config = config self.vision_tower = VisionModel(config.vision_config) self.language_model = LanguageModel(config.text_config) diff --git a/mlx_vlm/models/qwen2_vl/qwen2_vl.py b/mlx_vlm/models/qwen2_vl/qwen2_vl.py index dcdbbefa..b042aa3d 100644 --- a/mlx_vlm/models/qwen2_vl/qwen2_vl.py +++ b/mlx_vlm/models/qwen2_vl/qwen2_vl.py @@ -39,6 +39,7 @@ def from_dict(cls, params): class Model(nn.Module): def __init__(self, config: ModelConfig): + super().__init__() self.config = config self.vision_tower = VisionModel(config.vision_config) self.language_model = LanguageModel(config.text_config) diff --git a/mlx_vlm/trainer/__init__.py b/mlx_vlm/trainer/__init__.py index 16df5fea..9afca41c 100644 --- a/mlx_vlm/trainer/__init__.py +++ b/mlx_vlm/trainer/__init__.py @@ -1,6 +1,7 @@ from .lora import LoRaLayer, replace_lora_with_linear -from .trainer import * +from .trainer import Dataset, Trainer from .utils import ( + apply_lora_layers, collate_fn, count_parameters, find_all_linear_names, diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index c81298e3..4b0625f9 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -11,9 +11,6 @@ from mlx.utils import tree_flatten from PIL import Image -from mlx_vlm.prompt_utils import get_message_json -from mlx_vlm.utils import prepare_inputs - class Dataset: def __init__( @@ -39,6 +36,8 @@ def __len__(self): return len(self.dataset) def __getitem__(self, idx): + from mlx_vlm.utils import prepare_inputs + item = self.dataset[idx] # Process image data @@ -83,7 +82,7 @@ def __getitem__(self, idx): ) image_token_index = self.config["image_token_index"] - input_ids, pixel_values, mask = prepare_inputs( + input_ids, pixel_values, mask, image_grid_thw = prepare_inputs( self.image_processor, self.processor, image, prompts, image_token_index ) @@ -94,6 +93,7 @@ def __getitem__(self, idx): "pixel_values": pixel_values, "input_ids": input_ids, "attention_mask": mask, + "image_grid_thw": image_grid_thw, } @@ -197,7 +197,14 @@ def loss_fn(self, model, batch): weight_mask = None input_ids = input_ids[:, :-1] - logits = model(input_ids, pixel_values, attention_mask) + + kwargs = ( + {"image_grid_thw": batch["image_grid_thw"]} + if "image_grid_thw" in batch + else {} + ) + logits = model(input_ids, pixel_values, attention_mask, **kwargs) + logits.astype(mx.float32) # Ensure logits and labels have the same sequence length diff --git a/mlx_vlm/trainer/utils.py b/mlx_vlm/trainer/utils.py index 47df1f52..9f36ec17 100644 --- a/mlx_vlm/trainer/utils.py +++ b/mlx_vlm/trainer/utils.py @@ -1,3 +1,5 @@ +from pathlib import Path + import mlx.core as mx import mlx.nn as nn import numpy as np @@ -40,9 +42,10 @@ def get_peft_model(model, linear_layers, freeze=True, verbose=True): freeze_model(model) for name, module in model.named_modules(): - if isinstance(module, nn.Linear) and name.split(".")[-1] in linear_layers: - lora_layer = LoRaLayer(module, 10, 0.1, 0.1) - set_module_by_name(model, name, lora_layer) + if isinstance(module, nn.Linear) or isinstance(module, nn.QuantizedLinear): + if name.split(".")[-1] in linear_layers: + lora_layer = LoRaLayer(module, 10, 0.1, 0.1) + set_module_by_name(model, name, lora_layer) lora_model_trainable = count_parameters(model.language_model.trainable_parameters()) if verbose: @@ -67,6 +70,7 @@ def freeze_model(model): def find_all_linear_names(model): cls = nn.Linear + quantized_cls = nn.QuantizedLinear lora_module_names = set() multimodal_keywords = [ "mm_projector", @@ -77,7 +81,7 @@ def find_all_linear_names(model): for name, module in model.named_modules(): if any(mm_keyword in name for mm_keyword in multimodal_keywords): continue - if isinstance(module, cls): + if isinstance(module, cls) or isinstance(module, quantized_cls): names = name.split(".") lora_module_names.add(names[0] if len(names) == 1 else names[-1]) @@ -118,3 +122,29 @@ def print_trainable_parameters(source_model_trainable, lora_model_trainable): print( f"#trainable params: {lora_model_trainable} || all params: {source_model_trainable} || trainable%: {lora_trainable_percent}" ) + + +def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module: + """ + Apply LoRA layers to the model. + + Args: + model (nn.Module): The neural network model. + adapter_path (str): Path to the adapter configuration file. + + Returns: + nn.Module: The updated model with LoRA layers applied. + """ + adapter_path = Path(adapter_path) + + if not adapter_path.exists(): + raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}") + + # TODO: add lora params to the config and load them here + list_of_modules = find_all_linear_names(model.language_model.model) + model = get_peft_model(model, list_of_modules) + + # TODO: Use custom adapter name + model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False) + + return model diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index b97bdb73..2b517181 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -26,7 +26,8 @@ from .models.base import BaseImageProcessor, KVCache from .sample_utils import top_p_sampling -from .tokenizer_utils import TokenizerWrapper, load_tokenizer +from .tokenizer_utils import load_tokenizer +from .trainer import apply_lora_layers # Constants MODEL_REMAPPING = {"llava-qwen2": "llava_bunny", "bunny-llama": "llava_bunny"} @@ -223,6 +224,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module: def load( path_or_hf_repo: str, processor_config={}, + adapter_path: Optional[str] = None, lazy: bool = False, ) -> Tuple[nn.Module, Union[PreTrainedTokenizer, PreTrainedTokenizerFast]]: """ @@ -247,6 +249,11 @@ def load( model_path = get_model_path(path_or_hf_repo) model = load_model(model_path, lazy) + if adapter_path is not None: + # TODO: Support more modules than just language_model + model = apply_lora_layers(model, adapter_path) + model.eval() + processor = load_processor(model_path, processor_config=processor_config) return model, processor From 8ba507d22245d94c51f1db54687a7afe487d44d8 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 30 Sep 2024 00:36:16 +0200 Subject: [PATCH 27/72] fix pixtral quant --- mlx_vlm/models/pixtral/pixtral.py | 2 +- mlx_vlm/utils.py | 32 +++++++++++++++---------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/mlx_vlm/models/pixtral/pixtral.py b/mlx_vlm/models/pixtral/pixtral.py index 3ae1d2e7..f3857c17 100644 --- a/mlx_vlm/models/pixtral/pixtral.py +++ b/mlx_vlm/models/pixtral/pixtral.py @@ -182,7 +182,7 @@ def from_pretrained(path_or_hf_repo: str): def sanitize(self, weights): def transform_key(key): - if "vision_tower" in key: + if "vision_tower" in key and "vision_model" not in key: if "transformer" in key: key = key.replace("vision_tower", "vision_tower.vision_model") if "patch_conv" in key: diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index 2b517181..ec60ec09 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -295,14 +295,15 @@ def load_image_processor(model_path: Union[str, Path]) -> BaseImageProcessor: def load_processor( - model_path, processor_config={"trust_remote_code": True} + model_path, processor_config={"trust_remote_code": True}, add_detokenizer=True ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: processor = AutoProcessor.from_pretrained(model_path, **processor_config) - detokenizer_class = load_tokenizer(model_path, return_tokenizer=False) - if "tokenizer" in processor.__dict__.keys(): - processor.detokenizer = detokenizer_class(processor.tokenizer) - else: - processor.detokenizer = detokenizer_class(processor) + if add_detokenizer: + detokenizer_class = load_tokenizer(model_path, return_tokenizer=False) + if "tokenizer" in processor.__dict__.keys(): + processor.detokenizer = detokenizer_class(processor.tokenizer) + else: + processor.detokenizer = detokenizer_class(processor) return processor @@ -311,8 +312,7 @@ def fetch_from_hub( ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]: model = load_model(model_path, lazy) config = load_config(model_path) - processor = load_processor(model_path) - + processor = load_processor(model_path, add_detokenizer=False) return model, config, processor @@ -644,7 +644,7 @@ def convert( ): print("[INFO] Loading") model_path = get_model_path(hf_path, revision=revision) - model, config, tokenizer = fetch_from_hub(model_path, lazy=False) + model, config, processor = fetch_from_hub(model_path, lazy=False) weights = dict(tree_flatten(model.parameters())) dtype = mx.float16 if quantize else getattr(mx, dtype) @@ -673,7 +673,7 @@ def convert( for file in py_files: shutil.copy(file, mlx_path) - tokenizer.save_pretrained(mlx_path) + processor.save_pretrained(mlx_path) save_config(config, config_path=mlx_path / "config.json") @@ -724,7 +724,7 @@ def prepare_inputs(image_processor, processor, images, prompts, image_token_inde prompts ), f"Number of images ({len(images)}) and prompts ({len(prompts)}) must match" - images = [load_image(img) if isinstance(img, str) else img for img in images] + images = [img for img in images] image_grid_thw = None if image_processor is not None: @@ -755,12 +755,12 @@ def prepare_inputs(image_processor, processor, images, prompts, image_token_inde ) else: processor.tokenizer.pad_token = processor.tokenizer.eos_token - try: - inputs = processor( - text=prompts, images=images, padding=True, return_tensors="mlx" - ) - pixel_values = mx.array(inputs["pixel_values"]) + inputs = processor(text=prompts, images=images, return_tensors="mlx") + if isinstance(inputs["pixel_values"], list): + pixel_values = mx.array(inputs["pixel_values"][0][0])[None, :] + else: + pixel_values = mx.array(inputs["pixel_values"]) input_ids = mx.array(inputs["input_ids"]) mask = mx.array(inputs["attention_mask"]) image_grid_thw = inputs.get("image_grid_thw", None) From 23598ad1c7b04c59fa916b4f615c188da34907ab Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 30 Sep 2024 00:36:53 +0200 Subject: [PATCH 28/72] speed up qwen batch processing --- mlx_vlm/models/qwen2_vl/qwen2_vl.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/mlx_vlm/models/qwen2_vl/qwen2_vl.py b/mlx_vlm/models/qwen2_vl/qwen2_vl.py index b042aa3d..54efba29 100644 --- a/mlx_vlm/models/qwen2_vl/qwen2_vl.py +++ b/mlx_vlm/models/qwen2_vl/qwen2_vl.py @@ -73,10 +73,27 @@ def _merge_input_ids_with_image_features( ): image_token_index = self.config.image_token_index - # Positions of tokens in input_ids, assuming batch size is 1 + # Positions of tokens in input_ids image_positions = input_ids == image_token_index + + # Convert inputs_embeds to numpy array if it's not already inputs_embeds = np.array(inputs_embeds.astype(mx.float32)) - inputs_embeds[image_positions] = image_features + + # Reshape image_features to match the batch size and number of image tokens + batch_size, seq_length, hidden_size = inputs_embeds.shape + num_images = image_positions.sum(axis=1) + max_images = num_images.max() + + if max_images > 0: + image_features = image_features.reshape(batch_size, max_images, -1) + + # Create a mask for valid image positions + valid_image_mask = np.arange(max_images) < num_images[:, None] + + # Use broadcasting to assign image features to the correct positions + inputs_embeds[image_positions] = ( + image_features * valid_image_mask[:, :, None] + ).reshape(-1, hidden_size) # TODO: Add video features From dc2226efed80f583eb0d9dd849cd718092e27d1d Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Tue, 1 Oct 2024 15:13:36 +0200 Subject: [PATCH 29/72] fix qlora training --- mlx_vlm/trainer/lora.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlx_vlm/trainer/lora.py b/mlx_vlm/trainer/lora.py index c5afbf2f..c139e472 100644 --- a/mlx_vlm/trainer/lora.py +++ b/mlx_vlm/trainer/lora.py @@ -20,6 +20,8 @@ def __init__( self.dropout = nn.Dropout(p=dropout) output_dims, input_dims = linear.weight.shape + if isinstance(linear, nn.QuantizedLinear): + input_dims *= 32 // linear.bits std_dev = 1 / math.sqrt(rank) From 4cb7956427998e1ed28097a1bbb5b37b5a8e82ab Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Tue, 1 Oct 2024 15:14:18 +0200 Subject: [PATCH 30/72] fix dataloader --- mlx_vlm/trainer/trainer.py | 74 +++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 38 deletions(-) diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index 4b0625f9..eea35834 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -1,8 +1,8 @@ +import json import os import time from dataclasses import dataclass, field from pathlib import Path -from pprint import pprint from typing import Union import mlx.core as mx @@ -11,6 +11,26 @@ from mlx.utils import tree_flatten from PIL import Image +from ..prompt_utils import apply_chat_template + + +def get_prompt(processor, conversation): + if "chat_template" in processor.__dict__.keys(): + prompt = processor.apply_chat_template( + conversation, + tokenize=False, + add_generation_prompt=False, + ) + + elif "tokenizer" in processor.__dict__.keys(): + prompt = processor.tokenizer.apply_chat_template( + conversation, + tokenize=False, + add_generation_prompt=False, + ) + + return prompt + class Dataset: def __init__( @@ -40,50 +60,24 @@ def __getitem__(self, idx): item = self.dataset[idx] - # Process image data - image = item["image"] + images = item["images"] + conversations = item["messages"] + prompts = [] + + if self.config["model_type"] == "paligemma": + raise NotImplementedError("Paligemma is not supported yet") - conversations = item["conversations"] - # check if conversation is a list of list if isinstance(conversations, list) and isinstance(conversations[0], list): - prompts = [] for conversation in conversations: - if "chat_template" in self.processor.__dict__.keys(): - prompts.append( - self.processor.apply_chat_template(conversation, tokenize=False) - ) - - elif "tokenizer" in self.processor.__dict__.keys(): - if self.config["model_type"] != "paligemma": - prompts.append( - self.processor.tokenizer.apply_chat_template( - conversation, tokenize=False - ) - ) - else: - raise ValueError( - "Processor does not have 'chat_template' or 'tokenizer' attribute." - ) - + prompt = get_prompt(self.processor, conversation) + prompts.append(prompt) else: - if "chat_template" in self.processor.__dict__.keys(): - prompts = self.processor.apply_chat_template( - conversations, tokenize=False - ) - - elif "tokenizer" in self.processor.__dict__.keys(): - if self.config["model_type"] != "paligemma": - prompts = self.processor.tokenizer.apply_chat_template( - conversations, tokenize=False - ) - else: - raise ValueError( - "Processor does not have 'chat_template' or 'tokenizer' attribute." - ) + prompt = get_prompt(self.processor, conversations) + prompts.append(prompt) image_token_index = self.config["image_token_index"] input_ids, pixel_values, mask, image_grid_thw = prepare_inputs( - self.image_processor, self.processor, image, prompts, image_token_index + self.image_processor, self.processor, images, prompts, image_token_index ) if mask is None: @@ -248,5 +242,9 @@ def save_adapter( model: nn.Module, adapter_file: Union[str, Path], ): + path = Path(adapter_file) + if hasattr(model.config, "lora"): + with open(path.parent / "adapter_config.json", "w") as f: + json.dump(model.config.lora, f) flattened_tree = tree_flatten(model.trainable_parameters()) mx.save_safetensors(str(adapter_file), dict(flattened_tree)) From 0659d45ace930f7a53c8246bbc725167511ac66c Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 2 Oct 2024 02:17:45 +0200 Subject: [PATCH 31/72] formatting --- mlx_vlm/generate.py | 2 - mlx_vlm/prompt_utils.py | 108 +++++++++++++++++++++++++++++++++------- mlx_vlm/utils.py | 21 +++++--- 3 files changed, 102 insertions(+), 29 deletions(-) diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index 514e1bca..5f4cadbe 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -1,8 +1,6 @@ import argparse import codecs -import mlx.core as mx - from .prompt_utils import apply_chat_template from .utils import generate, get_model_path, load, load_config, load_image_processor diff --git a/mlx_vlm/prompt_utils.py b/mlx_vlm/prompt_utils.py index b47a9a1a..c2fc465e 100644 --- a/mlx_vlm/prompt_utils.py +++ b/mlx_vlm/prompt_utils.py @@ -1,4 +1,4 @@ -def get_message_json(model_name, prompt): +def get_message_json(model_name, prompt, role="user", skip_image_token=False): """ Get the appropriate JSON message based on the specified model. @@ -11,46 +11,116 @@ def get_message_json(model_name, prompt): Returns: dict: A dictionary representing the JSON message for the specified model. """ - if model_name.lower() in ["idefics2", "qwen2_vl", "llava"]: + if model_name.lower() in ["idefics2", "qwen2_vl", "llava", "llava_next"]: message = { - "role": "user", - "content": [{"type": "image"}, {"type": "text", "text": prompt}], + "role": role, + "content": [ + {"type": "text", "text": prompt}, + ], } + if role == "user" and not skip_image_token: + message["content"].append({"type": "image"}) + elif model_name.lower() in ["llava-qwen2", "bunny-llama"]: + + message = {"role": role} + if role == "user" and not skip_image_token: + message["content"] = f"\n{prompt}" + else: + message["content"] = prompt - elif model_name.lower() in ["llava-qwen2", "llava_next", "bunny-llama"]: - message = {"role": "user", "content": f"\n{prompt}"} elif model_name.lower() == "phi3_v": - message = {"role": "user", "content": f"<|image_1|>\n{prompt}"} + message = {"role": role} + if role == "user" and not skip_image_token: + message["content"] = f"<|image_1|>\n{prompt}" + else: + message["content"] = prompt + elif model_name.lower() == "multi_modality": - message = {"role": "user", "content": f"{prompt}"} + message = {"role": role} + if role == "user" and not skip_image_token: + message["content"] = f"{prompt}" + else: + message["content"] = prompt + elif model_name.lower() == "pixtral": + message = {"role": role, "content": prompt} + + if role == "user" and not skip_image_token: + message["content"] = [ + {"type": "text", "content": prompt}, + ] + message["content"].append({"type": "image"}) elif model_name.lower() == "paligemma": message = prompt - elif model_name.lower() == "pixtral": - message = { - "role": "user", - "content": [{"type": "image"}, {"type": "text", "content": prompt}], - } else: raise ValueError(f"Unsupported model: {model_name}") return message -def apply_chat_template(processor, config, prompt): - message = get_message_json(config["model_type"], prompt) +def apply_chat_template( + processor, config, prompt, add_generation_prompt=True, return_messages=False +): + messages = [] + if isinstance(prompt, list): + if isinstance(prompt[0], dict) and len(prompt) >= 1: + for i, p in enumerate(prompt): + if isinstance(p, str): + message = get_message_json( + config["model_type"], p, skip_image_token=i >= 1 + ) + elif isinstance(p, dict) and "role" in p.keys(): + message = get_message_json( + config["model_type"], + p["content"], + p["role"], + skip_image_token=i >= 1, + ) + else: + raise ValueError("Invalid prompt type") + messages.append(message) + else: + for prompts in prompt: + for i, p in enumerate(prompts): + if isinstance(p, str): + message = get_message_json( + config["model_type"], p, skip_image_token=i >= 1 + ) + elif isinstance(p, dict) and "role" in p.keys(): + message = get_message_json( + config["model_type"], + p["content"], + p["role"], + skip_image_token=i >= 1, + ) + else: + raise ValueError("Invalid prompt type") + messages.append(message) + else: + if isinstance(prompt, str): + message = get_message_json(config["model_type"], prompt) + elif isinstance(prompt, dict) and "role" in prompt.keys(): + message = get_message_json( + config["model_type"], prompt["content"], prompt["role"] + ) + else: + raise ValueError("Invalid prompt type") + messages.append(message) + + if return_messages: + return messages if "chat_template" in processor.__dict__.keys(): return processor.apply_chat_template( - [message], + messages, tokenize=False, - add_generation_prompt=True, + add_generation_prompt=add_generation_prompt, ) elif "tokenizer" in processor.__dict__.keys(): return processor.tokenizer.apply_chat_template( - [message], + messages, tokenize=False, - add_generation_prompt=True, + add_generation_prompt=add_generation_prompt, ) else: diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index ec60ec09..9c46fbde 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -720,14 +720,16 @@ def prepare_inputs(image_processor, processor, images, prompts, image_token_inde if not isinstance(prompts, list): prompts = [prompts] - assert len(images) == len( - prompts - ), f"Number of images ({len(images)}) and prompts ({len(prompts)}) must match" + if len(images) != len(prompts): + print( + f"Number of images ({len(images)}) and prompts ({len(prompts)}) don't match" + ) - images = [img for img in images] + images = [load_image(img) if isinstance(img, str) else img for img in images] image_grid_thw = None if image_processor is not None: + processor.pad_token = processor.eos_token text_chunks = [ [processor(chunk).input_ids for chunk in prompt.split("")] @@ -747,6 +749,7 @@ def prepare_inputs(image_processor, processor, images, prompts, image_token_inde input_ids.append(mx.array(ids + padding)) input_ids = mx.array(input_ids) + pixel_values = image_processor.preprocess(images=images) pixel_values = mx.array(np.stack(pixel_values)) @@ -756,9 +759,11 @@ def prepare_inputs(image_processor, processor, images, prompts, image_token_inde else: processor.tokenizer.pad_token = processor.tokenizer.eos_token try: - inputs = processor(text=prompts, images=images, return_tensors="mlx") + inputs = processor( + text=prompts, images=images, padding=True, return_tensors="mlx" + ) if isinstance(inputs["pixel_values"], list): - pixel_values = mx.array(inputs["pixel_values"][0][0])[None, :] + pixel_values = inputs["pixel_values"] else: pixel_values = mx.array(inputs["pixel_values"]) input_ids = mx.array(inputs["input_ids"]) @@ -769,11 +774,11 @@ def prepare_inputs(image_processor, processor, images, prompts, image_token_inde except Exception as e: inputs = [] - for i in range(len(images)): + for i, image in enumerate(images): inputs.append( processor( text=str(prompts[i]), - images=images[i], + images=image, padding=True, return_tensors="mlx", ) From 1880162c35559df8837fe925ada2dbafb09218cc Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 2 Oct 2024 02:18:22 +0200 Subject: [PATCH 32/72] fix pixtral pixel loading --- mlx_vlm/models/pixtral/pixtral.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/mlx_vlm/models/pixtral/pixtral.py b/mlx_vlm/models/pixtral/pixtral.py index f3857c17..8cf041a6 100644 --- a/mlx_vlm/models/pixtral/pixtral.py +++ b/mlx_vlm/models/pixtral/pixtral.py @@ -76,23 +76,17 @@ def get_input_embeddings( inputs_embeds = self.language_model.model.embed_tokens(input_ids) # Get the ouptut hidden states from the vision model + if isinstance(pixel_values, list): + pixel_values = mx.array(pixel_values[0][0])[None, ...] + if pixel_values.ndim == 3: + pixel_values = pixel_values[None, ...] + *_, hidden_states = self.vision_tower( pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True ) - # Select the hidden states from the desired layer selected_image_feature = hidden_states[self.vision_feature_layer] - if self.vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif self.vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - else: - raise ValueError( - "Unexpected feature selection strategy: " - f"{self.vision_feature_select_strategy}" - ) - # Pass image features through the multi-modal projector image_features = self.multi_modal_projector(selected_image_feature) From 858caab46205101e0e6d43b42fd6b9a5be83fe18 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 2 Oct 2024 02:20:30 +0200 Subject: [PATCH 33/72] fix lora and dataset --- mlx_vlm/lora.py | 117 ++++++++++++++++++++++++++---------- mlx_vlm/trainer/__init__.py | 2 +- mlx_vlm/trainer/trainer.py | 25 +++++++- mlx_vlm/trainer/utils.py | 59 ++++++++++++------ 4 files changed, 148 insertions(+), 55 deletions(-) diff --git a/mlx_vlm/lora.py b/mlx_vlm/lora.py index 6648946b..1acf8381 100644 --- a/mlx_vlm/lora.py +++ b/mlx_vlm/lora.py @@ -1,60 +1,98 @@ import argparse +import json +import logging import mlx.optimizers as optim +from datasets import load_dataset +from tqdm import tqdm -from mlx_vlm.trainer import Dataset, Trainer -from mlx_vlm.trainer.lora import * -from mlx_vlm.trainer.utils import * +from mlx_vlm.prompt_utils import apply_chat_template +from mlx_vlm.trainer import Dataset, Trainer, save_adapter +from mlx_vlm.trainer.utils import find_all_linear_names, get_peft_model from mlx_vlm.utils import load, load_image_processor - -def add_image_token(items, image_token=""): - conversations = [] - for item in items["conversations"]: - if item["role"] == "user": - if item["content"].startswith(image_token): - conversations.append({"role": "user", "content": item["content"]}) - else: - conversations.append( - {"role": "user", "content": image_token + "\n" + item["content"]} - ) - else: - conversations.append({"role": "assistant", "content": item["content"]}) - return {"conversations": conversations} +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) def main(args): + logger.info(f"\033[32mLoading model from {args.model_path}\033[0m") model, processor = load( args.model_path, processor_config={"trust_remote_code": True} ) + config = model.config.__dict__ image_processor = load_image_processor(args.model_path) + logger.info(f"\033[32mLoading dataset from {args.dataset}\033[0m") + dataset = load_dataset(args.dataset, split=args.split) + + if "messages" not in dataset.column_names: + raise ValueError("Dataset must have a 'messages' column") + if "images" not in dataset.column_names: + raise ValueError("Dataset must have an 'images' column") + + if args.apply_chat_template: + logger.info(f"\033[32mApplying chat template to the dataset\033[0m") + + def process_data(examples): + if config["model_type"] == "pixtral": + conversations = apply_chat_template( + config=config, + processor=processor, + prompt=examples["messages"], + return_messages=True, + ) + examples["messages"] = [ + json.dumps(item, ensure_ascii=False) for item in conversations + ] + else: + examples["messages"] = apply_chat_template( + config=config, + processor=processor, + prompt=examples["messages"], + return_messages=True, + ) + return examples + + dataset = dataset.map(process_data) + dataset = Dataset( - args.dataset, - model.config.__dict__, + dataset, + config, processor, image_processor=image_processor, take=None, split=None, ) - dataset = dataset.map(add_image_token) + logger.info(f"\033[32mSetting up LoRA\033[0m") + list_of_modules = find_all_linear_names(model.language_model) + model = get_peft_model( + model, + list_of_modules, + rank=args.lora_rank, + alpha=args.lora_alpha, + dropout=args.lora_dropout, + ) + + logger.info(f"\033[32mSetting up optimizer\033[0m") optimizer = optim.Adam(learning_rate=args.learning_rate) - trainer = Trainer(model, optimizer) - list_of_modules = find_all_linear_names(model.language_model.model) - model = get_peft_model(model, list_of_modules) + logger.info(f"\033[32mSetting up trainer\033[0m") + trainer = Trainer(model, optimizer) - model.vision_tower.freeze() model.train() for epoch in range(args.epochs): - for i in range(args.steps): + if args.steps == 0: + args.steps = len(dataset) // args.batch_size + + for i in tqdm(range(args.steps)): loss = trainer.train_step( dataset[i * args.batch_size : (i + 1) * args.batch_size] ) if i % args.print_every == 0: - print(f"Epoch {epoch} Step {i} Loss {loss}") + print(f"Epoch {epoch} Step {i} Loss {loss.item():.4f}") save_adapter(model, args.output_path) @@ -62,7 +100,7 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Train NanoLLaVA model") parser.add_argument( - "--model_path", + "--model-path", type=str, default="mlx-community/nanoLLaVA-1.5-bf16", help="Path to the pre-trained model", @@ -71,27 +109,40 @@ def main(args): "--dataset", type=str, required=True, help="Path to the dataset" ) parser.add_argument( - "--learning_rate", + "--split", type=str, default="train", help="Split to use for training" + ) + parser.add_argument( + "--apply-chat-template", + action="store_false", + help="Apply chat template to the dataset", + ) + parser.add_argument( + "--learning-rate", type=float, default=1e-4, help="Learning rate for the optimizer", ) parser.add_argument( - "--batch_size", type=int, default=2, help="Batch size for training" + "--batch-size", type=int, default=1, help="Batch size for training" ) parser.add_argument( "--epochs", type=int, default=1, help="Number of epochs to train" ) parser.add_argument( - "--steps", type=int, default=100, help="Number of steps per epoch" + "--steps", type=int, default=10, help="Number of steps per epoch" + ) + parser.add_argument( + "--print-every", type=int, default=10, help="Print loss every n steps" ) parser.add_argument( - "--print_every", type=int, default=10, help="Print loss every n steps" + "--lora-alpha", type=int, default=0.1, help="LoRA alpha parameter" ) + parser.add_argument("--lora-rank", type=int, default=10, help="LoRA rank") + parser.add_argument("--lora-dropout", type=float, default=0.1, help="LoRA dropout") parser.add_argument( - "--output_path", + "--output-path", type=str, - default="nanollava_lora_adapter.safetensors", + default="adapters", help="Path to save the trained adapter", ) diff --git a/mlx_vlm/trainer/__init__.py b/mlx_vlm/trainer/__init__.py index 9afca41c..33630b92 100644 --- a/mlx_vlm/trainer/__init__.py +++ b/mlx_vlm/trainer/__init__.py @@ -1,5 +1,5 @@ from .lora import LoRaLayer, replace_lora_with_linear -from .trainer import Dataset, Trainer +from .trainer import * from .utils import ( apply_lora_layers, collate_fn, diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index eea35834..4fa7bb3d 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -1,6 +1,7 @@ import json import os import time +import warnings from dataclasses import dataclass, field from pathlib import Path from typing import Union @@ -69,9 +70,19 @@ def __getitem__(self, idx): if isinstance(conversations, list) and isinstance(conversations[0], list): for conversation in conversations: + if self.config["model_type"] == "pixtral": + conversation = [json.loads(i) for i in conversation] + if len(conversations) > 1: + warnings.warn( + "Pixtral batch processing is not supported yet. Set batch size to 1." + ) + prompt = get_prompt(self.processor, conversation) prompts.append(prompt) + else: + if self.config["model_type"] == "pixtral": + conversations = [json.loads(i) for i in conversations] prompt = get_prompt(self.processor, conversations) prompts.append(prompt) @@ -160,6 +171,7 @@ def __init__( self.model = model self.optimizer = optimizer self.train_on_completions = train_on_completions + self.assistant_id = assistant_id def loss_fn(self, model, batch): pixel_values = batch["pixel_values"] @@ -176,7 +188,7 @@ def loss_fn(self, model, batch): if self.train_on_completions: weight_mask = mx.ones_like(attention_mask) - assistant_response_index = np.where(input_ids == assistant_id)[1] + assistant_response_index = np.where(input_ids == self.assistant_id)[1] range_matrix = mx.repeat( mx.expand_dims(mx.arange(seq_length), 0), batch_size, axis=0 ) @@ -204,6 +216,17 @@ def loss_fn(self, model, batch): # Ensure logits and labels have the same sequence length if logits.shape[1] != labels.shape[1]: logits = logits[:, -labels.shape[1] :, :] + if logits.shape[1] != labels.shape[1]: + # pad logits with -100 + pad_length = labels.shape[1] - logits.shape[1] + pad_width = ( + (0, 0), + (0, pad_length), + (0, 0), + ) # Padding for each dimension + logits = mx.pad( + logits, pad_width, mode="constant", constant_values=-100 + ) length_mask = mx.arange(input_ids.shape[1])[None, :] < lengths[:, None] diff --git a/mlx_vlm/trainer/utils.py b/mlx_vlm/trainer/utils.py index 9f36ec17..892d1761 100644 --- a/mlx_vlm/trainer/utils.py +++ b/mlx_vlm/trainer/utils.py @@ -33,29 +33,32 @@ def set_module_by_name(model, name, new_module): setattr(module, parts[-1], new_module) -def get_peft_model(model, linear_layers, freeze=True, verbose=True): - source_model_trainable = count_parameters( - model.language_model.trainable_parameters() - ) - +def get_peft_model( + model, linear_layers, rank=10, alpha=0.1, dropout=0.1, freeze=True, verbose=True +): if freeze: freeze_model(model) - for name, module in model.named_modules(): + for name, module in model.language_model.named_modules(): if isinstance(module, nn.Linear) or isinstance(module, nn.QuantizedLinear): if name.split(".")[-1] in linear_layers: - lora_layer = LoRaLayer(module, 10, 0.1, 0.1) - set_module_by_name(model, name, lora_layer) + lora_layer = LoRaLayer(module, rank, alpha, dropout) + set_module_by_name(model.language_model, name, lora_layer) + + model.config.lora = {} + model.config.lora["rank"] = rank + model.config.lora["alpha"] = alpha + model.config.lora["dropout"] = dropout - lora_model_trainable = count_parameters(model.language_model.trainable_parameters()) if verbose: - print_trainable_parameters(source_model_trainable, lora_model_trainable) + print_trainable_parameters(model.language_model) return model def freeze_model(model): for name, module in model.named_modules(): + name = name.split(".")[0] if name in [ "language_model", "vision_model", @@ -107,20 +110,36 @@ def collate_fn(processor, examples): return tokens -def count_parameters(trainable_params_dict): - total_params = 0 - for modules in tree_flatten(trainable_params_dict): - mx_array = modules[-1] - if hasattr(mx_array, "shape"): - total_params += np.prod(mx_array.shape) +def count_parameters(model): + def nparams(m): + if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)): + return m.weight.size * (32 // m.bits) + return sum(v.size for _, v in tree_flatten(m.parameters())) + + leaf_modules = tree_flatten( + model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module) + ) + total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6 + + return total_p + - return total_params +def print_trainable_parameters(model): + def nparams(m): + if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)): + return m.weight.size * (32 // m.bits) + return sum(v.size for _, v in tree_flatten(m.parameters())) + leaf_modules = tree_flatten( + model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module) + ) + total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6 + trainable_p = ( + sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6 + ) -def print_trainable_parameters(source_model_trainable, lora_model_trainable): - lora_trainable_percent = (lora_model_trainable / source_model_trainable) * 100 print( - f"#trainable params: {lora_model_trainable} || all params: {source_model_trainable} || trainable%: {lora_trainable_percent}" + f"#trainable params: {trainable_p} M || all params: {total_p} M || trainable%: {(trainable_p * 100 / total_p):.3f}%" ) From 5e5ab716e9bdb74759d2d779800a108f964893bc Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 2 Oct 2024 02:21:02 +0200 Subject: [PATCH 34/72] add batch processing suppor for qwen2_vl --- mlx_vlm/models/qwen2_vl/qwen2_vl.py | 24 +++++------------------- mlx_vlm/models/qwen2_vl/vision.py | 13 ++++++++++++- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/mlx_vlm/models/qwen2_vl/qwen2_vl.py b/mlx_vlm/models/qwen2_vl/qwen2_vl.py index 54efba29..47d264e3 100644 --- a/mlx_vlm/models/qwen2_vl/qwen2_vl.py +++ b/mlx_vlm/models/qwen2_vl/qwen2_vl.py @@ -62,6 +62,9 @@ def get_input_embeddings( pixel_values, image_grid_thw, output_hidden_states=False ) + if hidden_states.ndim == 2: + hidden_states = hidden_states[None, :, :] + # Insert special image tokens in the input_ids final_inputs_embeds = self._merge_input_ids_with_image_features( hidden_states, inputs_embeds, input_ids @@ -73,27 +76,10 @@ def _merge_input_ids_with_image_features( ): image_token_index = self.config.image_token_index - # Positions of tokens in input_ids + # Positions of tokens in input_ids, assuming batch size is 1 image_positions = input_ids == image_token_index - - # Convert inputs_embeds to numpy array if it's not already inputs_embeds = np.array(inputs_embeds.astype(mx.float32)) - - # Reshape image_features to match the batch size and number of image tokens - batch_size, seq_length, hidden_size = inputs_embeds.shape - num_images = image_positions.sum(axis=1) - max_images = num_images.max() - - if max_images > 0: - image_features = image_features.reshape(batch_size, max_images, -1) - - # Create a mask for valid image positions - valid_image_mask = np.arange(max_images) < num_images[:, None] - - # Use broadcasting to assign image features to the correct positions - inputs_embeds[image_positions] = ( - image_features * valid_image_mask[:, :, None] - ).reshape(-1, hidden_size) + inputs_embeds[image_positions] = image_features # TODO: Add video features diff --git a/mlx_vlm/models/qwen2_vl/vision.py b/mlx_vlm/models/qwen2_vl/vision.py index f07f880e..48b01c56 100644 --- a/mlx_vlm/models/qwen2_vl/vision.py +++ b/mlx_vlm/models/qwen2_vl/vision.py @@ -323,7 +323,18 @@ def __call__( hidden_states = self.patch_embed(hidden_states) rotary_pos_emb = self.rot_pos_emb(grid_thw) - cu_seqlens = mx.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]) + # Assuming grid_thw has shape (batch_size, 3) + batch_size = grid_thw.shape[0] + + # Calculate cu_seqlens for each item in the batch + cu_seqlens = [] + for i in range(batch_size): + seq_len = grid_thw[i, 1] * grid_thw[i, 2] + cu_seqlens.append(mx.repeat(seq_len, grid_thw[i, 0])) + + # Concatenate the cu_seqlens for all items in the batch + cu_seqlens = mx.concatenate(cu_seqlens) + cu_seqlens = mx.cumsum(cu_seqlens.astype(mx.int32)) cu_seqlens = mx.pad(cu_seqlens, (1, 0), mode="constant", constant_values=0) From 86baba3cdeda203a91eedfbfbd44ed36cce545f5 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 2 Oct 2024 02:21:20 +0200 Subject: [PATCH 35/72] update lora docs --- mlx_vlm/LORA.MD | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/mlx_vlm/LORA.MD b/mlx_vlm/LORA.MD index 62b366b9..e847e4b2 100644 --- a/mlx_vlm/LORA.MD +++ b/mlx_vlm/LORA.MD @@ -10,6 +10,19 @@ - MLX VLM library - Required Python packages: `argparse`, `mlx_vlm`, `mlx` +## Supported Models +- Qwen2 +- LLaVA (except for LLaVA-Next) +- Pixtral +- Idefics 2 + +## Coming Soon +- LLaVA-Next +- Phi3_vision +- Paligemma + +Note: The script only works with model in full or half precision. Quantized models are not supported at the moment. + ## Usage To use the script, run it from the command line with the desired arguments: From 4ce361d72dd96be517871278f9c04d1a49b33008 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Thu, 3 Oct 2024 00:42:34 +0200 Subject: [PATCH 36/72] add unit tests --- mlx_vlm/tests/test_trainer.py | 133 ++++++++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 mlx_vlm/tests/test_trainer.py diff --git a/mlx_vlm/tests/test_trainer.py b/mlx_vlm/tests/test_trainer.py new file mode 100644 index 00000000..519bfe41 --- /dev/null +++ b/mlx_vlm/tests/test_trainer.py @@ -0,0 +1,133 @@ +import unittest +from unittest.mock import MagicMock, patch + +import mlx.core as mx +import mlx.nn as nn + +from mlx_vlm.trainer.trainer import Dataset, Trainer, TrainingArgs +from mlx_vlm.utils import prepare_inputs + + +class TestDataset(unittest.TestCase): + def setUp(self): + self.mock_hf_dataset = MagicMock() + self.mock_config = {"model_type": "test_model", "image_token_index": 1} + self.mock_processor = MagicMock() + self.mock_image_processor = MagicMock() + + @patch("mlx_vlm.utils.prepare_inputs") + def test_dataset_initialization(self, mock_prepare_inputs): + dataset = Dataset( + self.mock_hf_dataset, + self.mock_config, + self.mock_processor, + self.mock_image_processor, + take=10, + split="train", + ) + + self.assertEqual(len(dataset), len(self.mock_hf_dataset["train"].take(10))) + self.assertEqual(dataset.config, self.mock_config) + self.assertEqual(dataset.processor, self.mock_processor) + self.assertEqual(dataset.image_processor, self.mock_image_processor) + + @patch("mlx_vlm.trainer.trainer.get_prompt") + @patch("mlx_vlm.utils.prepare_inputs") + def test_dataset_getitem(self, mock_prepare_inputs, mock_get_prompt): + dataset = Dataset( + self.mock_hf_dataset, + self.mock_config, + self.mock_processor, + self.mock_image_processor, + ) + + mock_item = { + "images": ["image1.jpg"], + "messages": [{"role": "user", "content": "Hello"}], + } + self.mock_hf_dataset.__getitem__.return_value = mock_item + + mock_get_prompt.return_value = "Mocked prompt" + + mock_prepare_inputs.return_value = ( + mx.array([1, 2, 3]), # input_ids + mx.array( + [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]] + ), # pixel_values + mx.array([1, 1, 1]), # mask + (1, 1, 1), # image_grid_thw + [224, 224], # image_sizes + ) + + result = dataset[0] + + mock_prepare_inputs.assert_called_once() + self.assertIn("pixel_values", result) + self.assertIn("input_ids", result) + self.assertIn("attention_mask", result) + self.assertIn("image_grid_thw", result) + self.assertIn("image_sizes", result) + + # Check if the returned values match the mocked input + self.assertTrue(mx.array_equal(result["input_ids"], mx.array([1, 2, 3]))) + self.assertTrue( + mx.array_equal( + result["pixel_values"], + mx.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]), + ) + ) + self.assertTrue(mx.array_equal(result["attention_mask"], mx.array([1, 1, 1]))) + self.assertEqual(result["image_grid_thw"], (1, 1, 1)) + self.assertEqual(result["image_sizes"], [224, 224]) + + +class TestTrainer(unittest.TestCase): + def setUp(self): + self.mock_model = MagicMock(spec=nn.Module) + self.mock_optimizer = MagicMock() + self.trainer = Trainer(self.mock_model, self.mock_optimizer) + + def test_trainer_initialization(self): + self.assertEqual(self.trainer.model, self.mock_model) + self.assertEqual(self.trainer.optimizer, self.mock_optimizer) + self.assertFalse(self.trainer.train_on_completions) + self.assertEqual(self.trainer.assistant_id, 77091) + + @patch("mlx.nn.losses.cross_entropy") + def test_loss_fn(self, mock_cross_entropy): + batch = { + "pixel_values": mx.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]), + "input_ids": mx.array([[1, 2, 3], [4, 5, 6]]), + "attention_mask": mx.array([[1, 1, 1], [1, 1, 0]]), + "image_grid_thw": (1, 1, 1), + "image_sizes": [224, 224], + } + + self.mock_model.return_value = mx.array([[[0.1, 0.2, 0.3]], [[0.4, 0.5, 0.6]]]) + mock_cross_entropy.return_value = mx.array([[0.1, 0.2], [0.3, 0.4]]) + + loss = self.trainer.loss_fn(self.mock_model, batch) + + self.assertIsInstance(loss, mx.array) + self.assertEqual(loss.shape, ()) # Scalar value + + @patch.object(Trainer, "loss_fn") + @patch("mlx.nn.value_and_grad") + def test_train_step(self, mock_value_and_grad, mock_loss_fn): + mock_batch = MagicMock() + mock_loss = mx.array(0.5) + mock_grads = {"param1": mx.array([0.1, 0.2]), "param2": mx.array([0.3, 0.4])} + + mock_value_and_grad.return_value = lambda *args, **kwargs: ( + mock_loss, + mock_grads, + ) + + loss = self.trainer.train_step(mock_batch) + + self.mock_optimizer.update.assert_called_once_with(self.mock_model, mock_grads) + self.assertEqual(loss, mock_loss) + + +if __name__ == "__main__": + unittest.main() From eac9ee1b14f30a08ce901eec8c98ff6692a45cc4 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Thu, 3 Oct 2024 00:44:20 +0200 Subject: [PATCH 37/72] set stage for phi3_v support --- mlx_vlm/models/phi3_v/phi3_v.py | 3 +- mlx_vlm/trainer/trainer.py | 49 +++++++++++++++++++-------------- mlx_vlm/utils.py | 13 ++++++--- 3 files changed, 39 insertions(+), 26 deletions(-) diff --git a/mlx_vlm/models/phi3_v/phi3_v.py b/mlx_vlm/models/phi3_v/phi3_v.py index 770f0018..01285f5f 100644 --- a/mlx_vlm/models/phi3_v/phi3_v.py +++ b/mlx_vlm/models/phi3_v/phi3_v.py @@ -209,9 +209,10 @@ def __call__( pixel_values=None, mask=None, cache=None, + image_sizes=None, **kwargs, ): - out = self.model(inputs, pixel_values, mask, cache) + out = self.model(inputs, pixel_values, image_sizes, cache) return self.lm_head(out).astype(self.lm_head.weight.dtype) @property diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index 4fa7bb3d..e8c39040 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -15,7 +15,9 @@ from ..prompt_utils import apply_chat_template -def get_prompt(processor, conversation): +def get_prompt(model_type, processor, conversation): + if model_type == "paligemma": + return conversation if "chat_template" in processor.__dict__.keys(): prompt = processor.apply_chat_template( conversation, @@ -65,9 +67,6 @@ def __getitem__(self, idx): conversations = item["messages"] prompts = [] - if self.config["model_type"] == "paligemma": - raise NotImplementedError("Paligemma is not supported yet") - if isinstance(conversations, list) and isinstance(conversations[0], list): for conversation in conversations: if self.config["model_type"] == "pixtral": @@ -77,17 +76,21 @@ def __getitem__(self, idx): "Pixtral batch processing is not supported yet. Set batch size to 1." ) - prompt = get_prompt(self.processor, conversation) + prompt = get_prompt( + self.config["model_type"], self.processor, conversation + ) prompts.append(prompt) else: if self.config["model_type"] == "pixtral": conversations = [json.loads(i) for i in conversations] - prompt = get_prompt(self.processor, conversations) + prompt = get_prompt( + self.config["model_type"], self.processor, conversations + ) prompts.append(prompt) image_token_index = self.config["image_token_index"] - input_ids, pixel_values, mask, image_grid_thw = prepare_inputs( + input_ids, pixel_values, mask, image_grid_thw, image_sizes = prepare_inputs( self.image_processor, self.processor, images, prompts, image_token_index ) @@ -99,6 +102,7 @@ def __getitem__(self, idx): "input_ids": input_ids, "attention_mask": mask, "image_grid_thw": image_grid_thw, + "image_sizes": image_sizes, } @@ -205,28 +209,31 @@ def loss_fn(self, model, batch): input_ids = input_ids[:, :-1] kwargs = ( - {"image_grid_thw": batch["image_grid_thw"]} - if "image_grid_thw" in batch + { + "image_grid_thw": batch["image_grid_thw"], + "image_sizes": batch["image_sizes"], + } + if "image_grid_thw" in batch or "image_sizes" in batch else {} ) + + # Forward pass logits = model(input_ids, pixel_values, attention_mask, **kwargs) + # Cast to float32 logits.astype(mx.float32) # Ensure logits and labels have the same sequence length - if logits.shape[1] != labels.shape[1]: - logits = logits[:, -labels.shape[1] :, :] - if logits.shape[1] != labels.shape[1]: - # pad logits with -100 + def align_logits_with_labels(logits, labels): + if logits.shape[1] < labels.shape[1]: pad_length = labels.shape[1] - logits.shape[1] - pad_width = ( - (0, 0), - (0, pad_length), - (0, 0), - ) # Padding for each dimension - logits = mx.pad( - logits, pad_width, mode="constant", constant_values=-100 - ) + pad_width = ((0, 0), (0, pad_length), (0, 0)) + return mx.pad(logits, pad_width, mode="constant", constant_values=-100) + elif logits.shape[1] > labels.shape[1]: + return logits[:, -labels.shape[1] :, :] + return logits + + logits = align_logits_with_labels(logits, labels) length_mask = mx.arange(input_ids.shape[1])[None, :] < lengths[:, None] diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index 9c46fbde..e0528007 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -728,6 +728,7 @@ def prepare_inputs(image_processor, processor, images, prompts, image_token_inde images = [load_image(img) if isinstance(img, str) else img for img in images] image_grid_thw = None + image_sizes = None if image_processor is not None: processor.pad_token = processor.eos_token @@ -772,7 +773,12 @@ def prepare_inputs(image_processor, processor, images, prompts, image_token_inde if image_grid_thw is not None: image_grid_thw = mx.array(image_grid_thw) + image_sizes = inputs.get("image_sizes", None) + if image_sizes is not None: + image_sizes = mx.array(image_sizes) + except Exception as e: + inputs = [] for i, image in enumerate(images): inputs.append( @@ -796,9 +802,7 @@ def prepare_inputs(image_processor, processor, images, prompts, image_token_inde [mx.array(i["image_sizes"]) for i in inputs], axis=0 ) - return input_ids, pixel_values, image_sizes, image_grid_thw - - return input_ids, pixel_values, mask, image_grid_thw + return input_ids, pixel_values, mask, image_grid_thw, image_sizes def generate_step( @@ -999,12 +1003,13 @@ def generate( tokenizer = processor.tokenizer image_token_index = model.config.image_token_index - input_ids, pixel_values, mask, image_grid_thw = prepare_inputs( + input_ids, pixel_values, mask, image_grid_thw, image_sizes = prepare_inputs( image_processor, processor, image, prompt, image_token_index ) kwargs = { "image_grid_thw": image_grid_thw, + "image_sizes": image_sizes, } tic = time.perf_counter() From 4ebac210db2de3ea52a41f6d8f9f0a54b6805c78 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Thu, 3 Oct 2024 00:45:29 +0200 Subject: [PATCH 38/72] update logs and readme --- mlx_vlm/LORA.MD | 3 ++- mlx_vlm/lora.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mlx_vlm/LORA.MD b/mlx_vlm/LORA.MD index e847e4b2..fbc0f3a9 100644 --- a/mlx_vlm/LORA.MD +++ b/mlx_vlm/LORA.MD @@ -15,11 +15,12 @@ - LLaVA (except for LLaVA-Next) - Pixtral - Idefics 2 +- Deepseek-VL +- Paligemma ## Coming Soon - LLaVA-Next - Phi3_vision -- Paligemma Note: The script only works with model in full or half precision. Quantized models are not supported at the moment. diff --git a/mlx_vlm/lora.py b/mlx_vlm/lora.py index 1acf8381..f637b146 100644 --- a/mlx_vlm/lora.py +++ b/mlx_vlm/lora.py @@ -92,7 +92,7 @@ def process_data(examples): dataset[i * args.batch_size : (i + 1) * args.batch_size] ) if i % args.print_every == 0: - print(f"Epoch {epoch} Step {i} Loss {loss.item():.4f}") + print({"Epoch": epoch, "Step": i, "Loss": f"{loss.item():.4f}"}) save_adapter(model, args.output_path) From 336f4237e04cc405c49d63a2a9614199263f89f1 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Thu, 3 Oct 2024 01:20:26 +0200 Subject: [PATCH 39/72] add utils tests and remove unused collate fn --- mlx_vlm/tests/test_trainer_utils.py | 59 +++++++++++++++++++++++++++++ mlx_vlm/trainer/__init__.py | 1 - mlx_vlm/trainer/utils.py | 17 --------- 3 files changed, 59 insertions(+), 18 deletions(-) create mode 100644 mlx_vlm/tests/test_trainer_utils.py diff --git a/mlx_vlm/tests/test_trainer_utils.py b/mlx_vlm/tests/test_trainer_utils.py new file mode 100644 index 00000000..c7e344ca --- /dev/null +++ b/mlx_vlm/tests/test_trainer_utils.py @@ -0,0 +1,59 @@ +import unittest +from unittest.mock import MagicMock, patch + +import mlx.nn as nn + +from mlx_vlm.trainer.utils import ( + find_all_linear_names, + get_module_by_name, + get_peft_model, + set_module_by_name, +) + + +class TestTrainerUtils(unittest.TestCase): + + def test_get_module_by_name(self): + model = MagicMock() + model.layer1.layer2.layer3 = "test_module" + + result = get_module_by_name(model, "layer1.layer2.layer3") + self.assertEqual(result, "test_module") + + def test_set_module_by_name(self): + model = MagicMock() + new_module = MagicMock() + + set_module_by_name(model, "layer1.layer2.layer3", new_module) + self.assertEqual(model.layer1.layer2.layer3, new_module) + + @patch("mlx_vlm.trainer.utils.freeze_model") + @patch("mlx_vlm.trainer.utils.print_trainable_parameters") + def test_get_peft_model(self, mock_print, mock_freeze): + model = MagicMock() + model.language_model.named_modules.return_value = [ + ("layer1", nn.Linear(256, 512)), + ("layer2", nn.QuantizedLinear(256, 512, 8)), + ] + + result = get_peft_model(model, ["layer1", "layer2"]) + + self.assertTrue(mock_freeze.called) + self.assertTrue(mock_print.called) + self.assertTrue(hasattr(model.config, "lora")) + + def test_find_all_linear_names(self): + model = MagicMock() + model.named_modules.return_value = [ + ("layer1", nn.Linear(256, 512)), + ("layer2", nn.QuantizedLinear(256, 512, 8)), + ("mm_projector", nn.Linear(256, 512)), + ("lm_head", nn.Linear(256, 512)), + ] + + result = find_all_linear_names(model) + self.assertEqual(set(result), {"layer1", "layer2"}) + + +if __name__ == "__main__": + unittest.main() diff --git a/mlx_vlm/trainer/__init__.py b/mlx_vlm/trainer/__init__.py index 33630b92..92b14d31 100644 --- a/mlx_vlm/trainer/__init__.py +++ b/mlx_vlm/trainer/__init__.py @@ -2,7 +2,6 @@ from .trainer import * from .utils import ( apply_lora_layers, - collate_fn, count_parameters, find_all_linear_names, get_peft_model, diff --git a/mlx_vlm/trainer/utils.py b/mlx_vlm/trainer/utils.py index 892d1761..930a8771 100644 --- a/mlx_vlm/trainer/utils.py +++ b/mlx_vlm/trainer/utils.py @@ -93,23 +93,6 @@ def find_all_linear_names(model): return list(lora_module_names) -def collate_fn(processor, examples): - texts = ["answer " + example["question"] for example in examples] - labels = [example["multiple_choice_answer"] for example in examples] - images = [example["image"].convert("RGB") for example in examples] - tokens = processor( - text=texts, - images=images, - suffix=labels, - return_tensors="np", - padding="longest", - tokenize_newline_separately=False, - ) - - tokens = tokens.to(mx.float16) - return tokens - - def count_parameters(model): def nparams(m): if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)): From c0cd42d53e1ae8ad09f0ab51a6409a5c6e1e7a30 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 4 Oct 2024 12:33:31 +0200 Subject: [PATCH 40/72] refactor prompt utils and add multi-image support for pixtral --- mlx_vlm/models/pixtral/pixtral.py | 9 +- mlx_vlm/prompt_utils.py | 131 ++++++++++++++++++++---------- 2 files changed, 95 insertions(+), 45 deletions(-) diff --git a/mlx_vlm/models/pixtral/pixtral.py b/mlx_vlm/models/pixtral/pixtral.py index 8cf041a6..032b43f2 100644 --- a/mlx_vlm/models/pixtral/pixtral.py +++ b/mlx_vlm/models/pixtral/pixtral.py @@ -77,7 +77,14 @@ def get_input_embeddings( # Get the ouptut hidden states from the vision model if isinstance(pixel_values, list): - pixel_values = mx.array(pixel_values[0][0])[None, ...] + if input_ids.shape[0] == 1: # Batch size is 1 + pixel_values = mx.concatenate( + [mx.array(pv) for pv in pixel_values[0]], axis=1 + )[None, ...] + else: # Batch size is greater than 1 + pixel_values = mx.concatenate( + [mx.array(pv) for pv in pixel_values], axis=0 + ) if pixel_values.ndim == 3: pixel_values = pixel_values[None, ...] diff --git a/mlx_vlm/prompt_utils.py b/mlx_vlm/prompt_utils.py index c2fc465e..623ab58a 100644 --- a/mlx_vlm/prompt_utils.py +++ b/mlx_vlm/prompt_utils.py @@ -1,64 +1,97 @@ -def get_message_json(model_name, prompt, role="user", skip_image_token=False): +def get_message_json( + model_name, prompt, role="user", skip_image_token=False, num_images=1 +): """ Get the appropriate JSON message based on the specified model. Args: model_name (str): The model for which to generate the message. prompt (str): The text prompt to be included in the message. - *args: Additional positional arguments (unused). - **kwargs: Additional keyword arguments (unused). + role (str): The role of the message (default: "user"). + skip_image_token (bool): Whether to skip adding image tokens (default: False). + num_images (int): Number of image tokens to add (default: 1). Returns: dict: A dictionary representing the JSON message for the specified model. """ - if model_name.lower() in ["idefics2", "qwen2_vl", "llava", "llava_next"]: - message = { - "role": role, - "content": [ - {"type": "text", "text": prompt}, - ], - } - if role == "user" and not skip_image_token: - message["content"].append({"type": "image"}) - elif model_name.lower() in ["llava-qwen2", "bunny-llama"]: + model_name = model_name.lower() - message = {"role": role} - if role == "user" and not skip_image_token: - message["content"] = f"\n{prompt}" - else: - message["content"] = prompt + def create_message(role, prompt): + return {"role": role, "content": prompt} - elif model_name.lower() == "phi3_v": - message = {"role": role} + def add_image_tokens(message, token_format): if role == "user" and not skip_image_token: - message["content"] = f"<|image_1|>\n{prompt}" - else: - message["content"] = prompt + if isinstance(message["content"], list): + message["content"].extend([{"type": "image"}] * num_images) + else: + if model_name == "phi3_v": + message["content"] = f"{token_format}{message['content']}" + else: + message["content"] = ( + f"{token_format * num_images}{message['content']}" + ) + return message - elif model_name.lower() == "multi_modality": - message = {"role": role} - if role == "user" and not skip_image_token: - message["content"] = f"{prompt}" - else: - message["content"] = prompt - elif model_name.lower() == "pixtral": - message = {"role": role, "content": prompt} + message_formats = { + "message_list_with_image": lambda: add_image_tokens( + {"role": role, "content": [{"type": "text", "text": prompt}]}, "" + ), + "message_list_with_image_type": lambda: add_image_tokens( + {"role": role, "content": [{"type": "text", "content": prompt}]}, "" + ), + "message_with_image_token": lambda: add_image_tokens( + create_message(role, prompt), "" + ), + "message_with_image_token_new_line": lambda: add_image_tokens( + create_message(role, prompt), "\n" + ), + "message_with_numbered_image_tokens": lambda: add_image_tokens( + create_message(role, prompt), + " ".join([f"<|image_{i+1}|>" for i in range(num_images)]), + ), + "prompt_only": lambda: prompt, + } - if role == "user" and not skip_image_token: - message["content"] = [ - {"type": "text", "content": prompt}, - ] - message["content"].append({"type": "image"}) - elif model_name.lower() == "paligemma": - message = prompt + model_to_format = { + "idefics2": "message_list_with_image", + "qwen2_vl": "message_list_with_image", + "llava": "message_list_with_image", + "llava_next": "message_list_with_image", + "llava-qwen2": "message_with_image_token_new_line", + "bunny-llama": "message_with_image_token_new_line", + "phi3_v": "message_with_numbered_image_tokens", + "multi_modality": "message_with_image_token", + "pixtral": "message_list_with_image_type", + "paligemma": "prompt_only", + } + + if num_images > 1 and model_name in [ + "llava", + "llava_next", + "llava-qwen2", + "bunny-llama", + "paligemma", + "multi_modality", + ]: + raise ValueError( + f"Model {model_name} does not support multi-image chat. Please only use 1 image." + ) + + format_key = model_to_format.get(model_name) + + if format_key: + return message_formats[format_key]() else: raise ValueError(f"Unsupported model: {model_name}") - return message - def apply_chat_template( - processor, config, prompt, add_generation_prompt=True, return_messages=False + processor, + config, + prompt, + add_generation_prompt=True, + return_messages=False, + num_images=1, ): messages = [] if isinstance(prompt, list): @@ -66,7 +99,10 @@ def apply_chat_template( for i, p in enumerate(prompt): if isinstance(p, str): message = get_message_json( - config["model_type"], p, skip_image_token=i >= 1 + config["model_type"], + p, + skip_image_token=i >= 1, + num_images=num_images, ) elif isinstance(p, dict) and "role" in p.keys(): message = get_message_json( @@ -74,6 +110,7 @@ def apply_chat_template( p["content"], p["role"], skip_image_token=i >= 1, + num_images=num_images, ) else: raise ValueError("Invalid prompt type") @@ -91,16 +128,22 @@ def apply_chat_template( p["content"], p["role"], skip_image_token=i >= 1, + num_images=num_images, ) else: raise ValueError("Invalid prompt type") messages.append(message) else: if isinstance(prompt, str): - message = get_message_json(config["model_type"], prompt) + message = get_message_json( + config["model_type"], prompt, num_images=num_images + ) elif isinstance(prompt, dict) and "role" in prompt.keys(): message = get_message_json( - config["model_type"], prompt["content"], prompt["role"] + config["model_type"], + prompt["content"], + prompt["role"], + num_images=num_images, ) else: raise ValueError("Invalid prompt type") From 5f263748ff22a21aa2d569494fb6e9f57a4453f8 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 4 Oct 2024 14:23:39 +0200 Subject: [PATCH 41/72] add llava interleave support --- mlx_vlm/models/llava/language.py | 23 ++++++++++++++----- mlx_vlm/models/llava/vision.py | 38 +++++++++++++++++++++++--------- mlx_vlm/trainer/trainer.py | 8 +++---- 3 files changed, 48 insertions(+), 21 deletions(-) diff --git a/mlx_vlm/models/llava/language.py b/mlx_vlm/models/llava/language.py index 732b636a..a7f11b40 100644 --- a/mlx_vlm/models/llava/language.py +++ b/mlx_vlm/models/llava/language.py @@ -21,6 +21,7 @@ class TextConfig: rope_theta: float = 10000 rope_traditional: bool = False rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = False @classmethod def from_dict(cls, params): @@ -58,9 +59,14 @@ def __init__(self, config: TextConfig): head_dim = config.hidden_size // n_heads self.scale = head_dim**-0.5 - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + if config.model_type == "qwen2": + attention_bias = True + else: + attention_bias = False + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) rope_scale = ( @@ -184,12 +190,13 @@ def __init__(self, config: TextConfig): super().__init__() self.config = config self.model_type = config.model_type - if self.model_type != "llama": + if self.model_type not in ["llama", "qwen2"]: raise ValueError( f"Model type {self.model_type} not supported. Currently only 'llama' is supported" ) self.model = Llama(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + if not config.tie_word_embeddings: + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) def __call__( self, @@ -199,7 +206,11 @@ def __call__( mask: Optional[mx.array] = None, ): out = self.model(inputs, cache, inputs_embeds) - return self.lm_head(out) + if self.config.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out @staticmethod def sanitize(weights): diff --git a/mlx_vlm/models/llava/vision.py b/mlx_vlm/models/llava/vision.py index 5a5ec42e..31c27340 100644 --- a/mlx_vlm/models/llava/vision.py +++ b/mlx_vlm/models/llava/vision.py @@ -151,31 +151,44 @@ def __init__(self, config: VisionConfig): self.image_size = config.image_size self.patch_size = config.patch_size - self.class_embedding = mx.zeros((config.hidden_size,)) + if config.model_type == "siglip_vision_model": + bias = True + self.class_embedding = None + else: + bias = False + self.class_embedding = mx.zeros((config.hidden_size,)) self.patch_embedding = nn.Conv2d( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, - bias=False, + bias=bias, ) self.num_patches = (self.image_size // self.patch_size) ** 2 - self.num_positions = self.num_patches + 1 + self.num_positions = ( + self.num_patches + 1 + if config.model_type == "clip_vision_model" + else self.num_patches + ) self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) def __call__(self, x: mx.array) -> mx.array: batch_size = x.shape[0] patch_embeddings = self.patch_embedding(x) patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) - embed_dim = patch_embeddings.shape[-1] - cls_embeddings = mx.broadcast_to( - self.class_embedding, (batch_size, 1, embed_dim) - ) + if self.config.model_type == "siglip_vision_model": + embeddings = patch_embeddings + else: + embed_dim = patch_embeddings.shape[-1] + cls_embeddings = mx.broadcast_to( + self.class_embedding, (batch_size, 1, embed_dim) + ) + embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1) + position_ids = mx.array(np.arange(self.num_positions)[None, :]) - embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1) embeddings += self.position_embedding(position_ids) return embeddings @@ -183,8 +196,10 @@ def __call__(self, x: mx.array) -> mx.array: class ClipVisionModel(nn.Module): def __init__(self, config: VisionConfig): super().__init__() + self.config = config self.embeddings = VisionEmbeddings(config) - self.pre_layrnorm = nn.LayerNorm(config.hidden_size) + if self.config.model_type == "clip_vision_model": + self.pre_layrnorm = nn.LayerNorm(config.hidden_size) self.encoder = Encoder(config) self.post_layernorm = nn.LayerNorm(config.hidden_size) @@ -194,7 +209,8 @@ def __call__( output_hidden_states: Optional[bool] = None, ) -> mx.array: x = self.embeddings(x) - x = self.pre_layrnorm(x) + if self.config.model_type == "clip_vision_model": + x = self.pre_layrnorm(x) encoder_states = (x,) if output_hidden_states else None @@ -212,7 +228,7 @@ def __init__(self, config: VisionConfig): super().__init__() self.model_type = config.model_type - if self.model_type != "clip_vision_model": + if self.model_type not in ["clip_vision_model", "siglip_vision_model"]: raise ValueError(f"Unsupported model type: {self.model_type}") self.vision_model = ClipVisionModel(config) diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index e8c39040..1cae4fff 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -182,10 +182,10 @@ def loss_fn(self, model, batch): input_ids = batch["input_ids"] attention_mask = batch["attention_mask"] lengths = mx.sum(attention_mask, axis=1) - labels = mx.where( - attention_mask == 1, input_ids, -100 - ) # Only compute loss on non-padded tokens - labels = labels[:, 1:] + # labels = mx.where( + # attention_mask == 1, input_ids, -100 + # ) # Only compute loss on non-padded tokens + labels = input_ids[:, 1:] batch_size, seq_length = input_ids.shape From efa26e6cc4fe0bbe517c1f4c0196bb1eef358243 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 4 Oct 2024 14:24:09 +0200 Subject: [PATCH 42/72] multi image support --- mlx_vlm/generate.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index 5f4cadbe..66eb9d25 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -32,6 +32,7 @@ def parse_arguments(): parser.add_argument( "--image", type=str, + nargs="+", default=DEFAULT_IMAGE, help="URL or path of the image to process.", ) @@ -71,9 +72,12 @@ def main(): ) prompt = codecs.decode(args.prompt, "unicode_escape") + # prompt = "mô tả bức tranh này" if model.config.model_type != "paligemma": - prompt = apply_chat_template(processor, config, prompt) + prompt = apply_chat_template( + processor, config, prompt, num_images=len(args.image) + ) output = generate( model, From e7447223c6d56b8941ecce079558e1eccaa001f2 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 4 Oct 2024 15:17:12 +0200 Subject: [PATCH 43/72] add image resizing --- mlx_vlm/utils.py | 128 ++++++++++++++++++++++------------------------- 1 file changed, 61 insertions(+), 67 deletions(-) diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index e0528007..2aba8d93 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -711,25 +711,45 @@ def load_image(image_source: Union[str, Path, BytesIO]): ) -def prepare_inputs(image_processor, processor, images, prompts, image_token_index): +def resize_image(img, max_size): + ratio = min(max_size[0] / img.width, max_size[1] / img.height) + new_size = (int(img.width * ratio), int(img.height * ratio)) + return img.resize(new_size) + + +def process_image(img, resize_shape): + if isinstance(img, str): + img = load_image(img) + if resize_shape is not None: + img = resize_image(img, resize_shape) + return img + + +def prepare_inputs( + image_processor, processor, images, prompts, image_token_index, resize_shape=None +): from transformers.image_utils import load_image mask = None if not isinstance(images, list): images = [images] - if not isinstance(prompts, list): - prompts = [prompts] if len(images) != len(prompts): print( f"Number of images ({len(images)}) and prompts ({len(prompts)}) don't match" ) - images = [load_image(img) if isinstance(img, str) else img for img in images] + # Process images + images = [ + process_image(img, resize_shape) if isinstance(img, str) else img + for img in images + ] image_grid_thw = None image_sizes = None if image_processor is not None: + if not isinstance(prompts, list): + prompts = [prompts] processor.pad_token = processor.eos_token text_chunks = [ @@ -759,48 +779,21 @@ def prepare_inputs(image_processor, processor, images, prompts, image_token_inde ) else: processor.tokenizer.pad_token = processor.tokenizer.eos_token - try: - inputs = processor( - text=prompts, images=images, padding=True, return_tensors="mlx" - ) - if isinstance(inputs["pixel_values"], list): - pixel_values = inputs["pixel_values"] - else: - pixel_values = mx.array(inputs["pixel_values"]) - input_ids = mx.array(inputs["input_ids"]) - mask = mx.array(inputs["attention_mask"]) - image_grid_thw = inputs.get("image_grid_thw", None) - if image_grid_thw is not None: - image_grid_thw = mx.array(image_grid_thw) - - image_sizes = inputs.get("image_sizes", None) - if image_sizes is not None: - image_sizes = mx.array(image_sizes) - - except Exception as e: - - inputs = [] - for i, image in enumerate(images): - inputs.append( - processor( - text=str(prompts[i]), - images=image, - padding=True, - return_tensors="mlx", - ) - ) - input_ids = mx.concatenate( - [mx.array(i["input_ids"]) for i in inputs], axis=0 - ) - pixel_values = mx.concatenate( - [mx.array(i["pixel_values"]) for i in inputs], axis=0 - ) - mask = mx.concatenate( - [mx.array(i["attention_mask"]) for i in inputs], axis=0 - ) - image_sizes = mx.concatenate( - [mx.array(i["image_sizes"]) for i in inputs], axis=0 - ) + inputs = processor( + text=prompts, images=images, padding=True, return_tensors="mlx" + ) + if isinstance(inputs["pixel_values"], list): + pixel_values = inputs["pixel_values"] + else: + pixel_values = mx.array(inputs["pixel_values"]) + input_ids = mx.array(inputs["input_ids"]) + mask = mx.array(inputs["attention_mask"]) + image_sizes = inputs.get("image_sizes", None) + if image_sizes is not None: + image_sizes = mx.array(image_sizes) + image_grid_thw = inputs.get("image_grid_thw", None) + if image_grid_thw is not None: + image_grid_thw = mx.array(image_grid_thw) return input_ids, pixel_values, mask, image_grid_thw, image_sizes @@ -937,9 +930,11 @@ def stream_generate( tokenizer = processor.tokenizer image_token_index = model.config.image_token_index - input_ids, pixel_values, mask = prepare_inputs( + inputs = prepare_inputs( image_processor, processor, image, prompt, image_token_index ) + input_ids, pixel_values, mask = inputs[:3] + kwargs = {k: v for k, v in zip(["image_grid_thw", "image_sizes"], inputs[3:])} detokenizer = processor.detokenizer @@ -1003,33 +998,32 @@ def generate( tokenizer = processor.tokenizer image_token_index = model.config.image_token_index - input_ids, pixel_values, mask, image_grid_thw, image_sizes = prepare_inputs( + # Prepare inputs + inputs = prepare_inputs( image_processor, processor, image, prompt, image_token_index ) + input_ids, pixel_values, mask = inputs[:3] + kwargs = {k: v for k, v in zip(["image_grid_thw", "image_sizes"], inputs[3:])} - kwargs = { - "image_grid_thw": image_grid_thw, - "image_sizes": image_sizes, - } - + # Initialize timing and detokenizer tic = time.perf_counter() detokenizer = processor.detokenizer detokenizer.reset() - for (token, prob), n in zip( - generate_step( - input_ids, - model, - pixel_values, - mask, - temp, - repetition_penalty, - repetition_context_size, - top_p, - **kwargs, - ), - range(max_tokens), - ): + # Generate tokens + generator = generate_step( + input_ids, + model, + pixel_values, + mask, + temp, + repetition_penalty, + repetition_context_size, + top_p, + **kwargs, + ) + + for (token, prob), n in zip(generator, range(max_tokens)): if n == 0: prompt_time = time.perf_counter() - tic From 49d16f67bbabd355e4d877f3e6cce9c9d9375433 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 4 Oct 2024 15:17:48 +0200 Subject: [PATCH 44/72] refactor data loading --- mlx_vlm/trainer/trainer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index 1cae4fff..de3b9018 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -90,10 +90,12 @@ def __getitem__(self, idx): prompts.append(prompt) image_token_index = self.config["image_token_index"] - input_ids, pixel_values, mask, image_grid_thw, image_sizes = prepare_inputs( + + inputs = prepare_inputs( self.image_processor, self.processor, images, prompts, image_token_index ) - + input_ids, pixel_values, mask = inputs[:3] + kwargs = {k: v for k, v in zip(["image_grid_thw", "image_sizes"], inputs[3:])} if mask is None: mask = mx.ones_like(input_ids) @@ -101,8 +103,7 @@ def __getitem__(self, idx): "pixel_values": pixel_values, "input_ids": input_ids, "attention_mask": mask, - "image_grid_thw": image_grid_thw, - "image_sizes": image_sizes, + **kwargs, } From df9962745bff3627872a49df834afe737dd9ced5 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 4 Oct 2024 15:18:38 +0200 Subject: [PATCH 45/72] update data procesing and tqdm --- mlx_vlm/lora.py | 77 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 51 insertions(+), 26 deletions(-) diff --git a/mlx_vlm/lora.py b/mlx_vlm/lora.py index f637b146..d035e921 100644 --- a/mlx_vlm/lora.py +++ b/mlx_vlm/lora.py @@ -15,6 +15,31 @@ logger = logging.getLogger(__name__) +def custom_print(*args, **kwargs): + tqdm.write(" ".join(map(str, args)), **kwargs) + + +def process_data(examples, config, processor): + if config["model_type"] == "pixtral": + conversations = apply_chat_template( + config=config, + processor=processor, + prompt=examples["messages"], + return_messages=True, + ) + examples["messages"] = [ + json.dumps(item, ensure_ascii=False) for item in conversations + ] + else: + examples["messages"] = apply_chat_template( + config=config, + processor=processor, + prompt=examples["messages"], + return_messages=True, + ) + return examples + + def main(args): logger.info(f"\033[32mLoading model from {args.model_path}\033[0m") model, processor = load( @@ -24,7 +49,9 @@ def main(args): image_processor = load_image_processor(args.model_path) logger.info(f"\033[32mLoading dataset from {args.dataset}\033[0m") - dataset = load_dataset(args.dataset, split=args.split) + dataset = load_dataset(args.dataset, split=args.split + "[:20%]") + + dataset = dataset.rename_columns({"image": "images", "conversations": "messages"}) if "messages" not in dataset.column_names: raise ValueError("Dataset must have a 'messages' column") @@ -33,28 +60,10 @@ def main(args): if args.apply_chat_template: logger.info(f"\033[32mApplying chat template to the dataset\033[0m") - - def process_data(examples): - if config["model_type"] == "pixtral": - conversations = apply_chat_template( - config=config, - processor=processor, - prompt=examples["messages"], - return_messages=True, - ) - examples["messages"] = [ - json.dumps(item, ensure_ascii=False) for item in conversations - ] - else: - examples["messages"] = apply_chat_template( - config=config, - processor=processor, - prompt=examples["messages"], - return_messages=True, - ) - return examples - - dataset = dataset.map(process_data) + dataset = dataset.map( + lambda example: process_data(example, config, processor), + desc="Applying chat template", + ) dataset = Dataset( dataset, @@ -83,16 +92,32 @@ def process_data(examples): model.train() + # Training loop + logger.info(f"\033[32mTraining model\033[0m") for epoch in range(args.epochs): if args.steps == 0: args.steps = len(dataset) // args.batch_size - for i in tqdm(range(args.steps)): + progress_bar = tqdm(range(args.steps), position=0, leave=True) + for i in progress_bar: loss = trainer.train_step( dataset[i * args.batch_size : (i + 1) * args.batch_size] ) + # Update progress bar + progress_bar.update(1) + progress_bar.set_postfix( + {"Epoch": epoch, "Step": i, "Loss": f"{loss.item():.4f}"} + ) + if i % args.print_every == 0: - print({"Epoch": epoch, "Step": i, "Loss": f"{loss.item():.4f}"}) + # Log additional information + custom_print( + { + "Epoch": epoch, + "Step": i, + "Loss": f"{loss.item():.4f}", + } + ) save_adapter(model, args.output_path) @@ -129,7 +154,7 @@ def process_data(examples): "--epochs", type=int, default=1, help="Number of epochs to train" ) parser.add_argument( - "--steps", type=int, default=10, help="Number of steps per epoch" + "--steps", type=int, default=20, help="Number of steps per epoch" ) parser.add_argument( "--print-every", type=int, default=10, help="Print loss every n steps" From b0a5bdaf29831d72ac53514da723964473c4c349 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 4 Oct 2024 15:19:24 +0200 Subject: [PATCH 46/72] add llava interleave --- mlx_vlm/prompt_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mlx_vlm/prompt_utils.py b/mlx_vlm/prompt_utils.py index 623ab58a..f263ed19 100644 --- a/mlx_vlm/prompt_utils.py +++ b/mlx_vlm/prompt_utils.py @@ -66,7 +66,6 @@ def add_image_tokens(message, token_format): } if num_images > 1 and model_name in [ - "llava", "llava_next", "llava-qwen2", "bunny-llama", From 941ebf859546cd3e0db9ef4f116a28a8c7b54c09 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 4 Oct 2024 15:19:42 +0200 Subject: [PATCH 47/72] formatting --- mlx_vlm/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx_vlm/__init__.py b/mlx_vlm/__init__.py index 50494f86..63e9873e 100644 --- a/mlx_vlm/__init__.py +++ b/mlx_vlm/__init__.py @@ -1,3 +1,3 @@ -from .prompt_utils import get_message_json +from .prompt_utils import apply_chat_template, get_message_json from .utils import convert, generate, load, prepare_inputs from .version import __version__ From 7a58b967a80159250df932ee13e2b07f7c27f5d0 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 4 Oct 2024 15:20:13 +0200 Subject: [PATCH 48/72] add list of models with multi-image support --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index c4c271dd..1da367d5 100644 --- a/README.md +++ b/README.md @@ -41,3 +41,10 @@ prompt = processor.tokenizer.apply_chat_template( output = generate(model, processor, "http://images.cocodataset.org/val2017/000000039769.jpg", prompt, verbose=False) ``` + +Models with multi-image chat support +- Idefics2 +- LlaVA (Interleave) +- Qwen2-vl +- Phi3-v +- Pixtral From ca80b6cee6efb3aae97b160da49bd4545057de31 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 5 Oct 2024 14:40:51 +0200 Subject: [PATCH 49/72] remove trimmed labels --- mlx_vlm/trainer/trainer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index de3b9018..be3edd6d 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -183,9 +183,6 @@ def loss_fn(self, model, batch): input_ids = batch["input_ids"] attention_mask = batch["attention_mask"] lengths = mx.sum(attention_mask, axis=1) - # labels = mx.where( - # attention_mask == 1, input_ids, -100 - # ) # Only compute loss on non-padded tokens labels = input_ids[:, 1:] batch_size, seq_length = input_ids.shape From 349e4d15c05be1ce8759391f9d6102589a6f252e Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 5 Oct 2024 15:03:00 +0200 Subject: [PATCH 50/72] remove warning --- mlx_vlm/utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index 2aba8d93..186bd010 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -734,11 +734,6 @@ def prepare_inputs( if not isinstance(images, list): images = [images] - if len(images) != len(prompts): - print( - f"Number of images ({len(images)}) and prompts ({len(prompts)}) don't match" - ) - # Process images images = [ process_image(img, resize_shape) if isinstance(img, str) else img From 028e32cfc8a6d1b9d5e67897ff9e3166c64de2fe Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 5 Oct 2024 15:42:33 +0200 Subject: [PATCH 51/72] pin reqs --- requirements.txt | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index d27c1b4f..2723e00c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,9 @@ mlx>=0.18.0 -numpy +datasets>=2.19.1 +tqdm>=4.66.2 +numpy>=1.23.4 transformers>=4.45.1 scipy==1.13.1 gradio>=4.44.0 -Pillow -requests +Pillow>=10.3.0 +requests>=2.31.0 From cd5ecf5ce50c24aed31104a736886403bfe2bfda Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 5 Oct 2024 15:43:15 +0200 Subject: [PATCH 52/72] add config dict condition --- mlx_vlm/prompt_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlx_vlm/prompt_utils.py b/mlx_vlm/prompt_utils.py index f263ed19..f1b77a80 100644 --- a/mlx_vlm/prompt_utils.py +++ b/mlx_vlm/prompt_utils.py @@ -92,6 +92,9 @@ def apply_chat_template( return_messages=False, num_images=1, ): + if not isinstance(config, dict): + config = config.__dict__ + messages = [] if isinstance(prompt, list): if isinstance(prompt[0], dict) and len(prompt) >= 1: From a116169db6ca0e9b25d203fbb87d8983d10cc95d Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 5 Oct 2024 17:51:06 +0200 Subject: [PATCH 53/72] fix pixtral FT prompt --- mlx_vlm/prompt_utils.py | 84 ++++++++++++++++------------------------- 1 file changed, 32 insertions(+), 52 deletions(-) diff --git a/mlx_vlm/prompt_utils.py b/mlx_vlm/prompt_utils.py index f1b77a80..1289fede 100644 --- a/mlx_vlm/prompt_utils.py +++ b/mlx_vlm/prompt_utils.py @@ -30,6 +30,8 @@ def add_image_tokens(message, token_format): message["content"] = ( f"{token_format * num_images}{message['content']}" ) + if role == "assistant" and model_name == "pixtral": + message["content"] = message["content"][0]["content"] return message message_formats = { @@ -77,7 +79,6 @@ def add_image_tokens(message, token_format): ) format_key = model_to_format.get(model_name) - if format_key: return message_formats[format_key]() else: @@ -92,68 +93,47 @@ def apply_chat_template( return_messages=False, num_images=1, ): - if not isinstance(config, dict): - config = config.__dict__ + config = config if isinstance(config, dict) else config.__dict__ - messages = [] - if isinstance(prompt, list): - if isinstance(prompt[0], dict) and len(prompt) >= 1: - for i, p in enumerate(prompt): - if isinstance(p, str): - message = get_message_json( - config["model_type"], - p, - skip_image_token=i >= 1, - num_images=num_images, - ) - elif isinstance(p, dict) and "role" in p.keys(): - message = get_message_json( - config["model_type"], - p["content"], - p["role"], - skip_image_token=i >= 1, - num_images=num_images, - ) - else: - raise ValueError("Invalid prompt type") - messages.append(message) - else: - for prompts in prompt: - for i, p in enumerate(prompts): - if isinstance(p, str): - message = get_message_json( - config["model_type"], p, skip_image_token=i >= 1 - ) - elif isinstance(p, dict) and "role" in p.keys(): - message = get_message_json( - config["model_type"], - p["content"], - p["role"], - skip_image_token=i >= 1, - num_images=num_images, - ) - else: - raise ValueError("Invalid prompt type") - messages.append(message) - else: - if isinstance(prompt, str): - message = get_message_json( - config["model_type"], prompt, num_images=num_images + def process_single_prompt(p, is_first=True): + if isinstance(p, str): + return get_message_json( + config["model_type"], + p, + skip_image_token=not is_first, + num_images=num_images, ) - elif isinstance(prompt, dict) and "role" in prompt.keys(): - message = get_message_json( + elif isinstance(p, dict) and "role" in p: + return get_message_json( config["model_type"], - prompt["content"], - prompt["role"], + p["content"], + p["role"], + skip_image_token=not is_first, num_images=num_images, ) else: raise ValueError("Invalid prompt type") - messages.append(message) + + messages = [] + if isinstance(prompt, list): + if isinstance(prompt[0], dict): + messages = [process_single_prompt(p, i == 0) for i, p in enumerate(prompt)] + else: + messages = [ + msg + for prompts in prompt + for i, p in enumerate(prompts) + for msg in [process_single_prompt(p, i == 0)] + ] + else: + messages = [process_single_prompt(prompt)] if return_messages: return messages + if config["model_type"] == "paligemma": + return messages[-1] + if "chat_template" in processor.__dict__.keys(): return processor.apply_chat_template( messages, From d791dff8a8cef2ee44707f1e6afd5fe5b0cc8b10 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 5 Oct 2024 18:10:30 +0200 Subject: [PATCH 54/72] formatting images --- mlx_vlm/generate.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index 66eb9d25..3e85b30d 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -5,7 +5,7 @@ from .utils import generate, get_model_path, load, load_config, load_image_processor DEFAULT_MODEL_PATH = "mlx-community/nanoLLaVA-1.5-8bit" -DEFAULT_IMAGE = "http://images.cocodataset.org/val2017/000000039769.jpg" +DEFAULT_IMAGE = ["http://images.cocodataset.org/val2017/000000039769.jpg"] DEFAULT_PROMPT = "What are these?" DEFAULT_MAX_TOKENS = 100 DEFAULT_TEMP = 0.5 @@ -67,17 +67,16 @@ def get_model_and_processors(model_path, adapter_path): def main(): args = parse_arguments() + if isinstance(args.image, str): + args.image = [args.image] + model, processor, image_processor, config = get_model_and_processors( args.model, args.adapter_path ) prompt = codecs.decode(args.prompt, "unicode_escape") - # prompt = "mô tả bức tranh này" - if model.config.model_type != "paligemma": - prompt = apply_chat_template( - processor, config, prompt, num_images=len(args.image) - ) + prompt = apply_chat_template(processor, config, prompt, num_images=len(args.image)) output = generate( model, From c16c048e36314a236a016111dcb7033ab993d5d3 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 5 Oct 2024 18:10:49 +0200 Subject: [PATCH 55/72] remove unused --- mlx_vlm/trainer/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mlx_vlm/trainer/utils.py b/mlx_vlm/trainer/utils.py index 930a8771..9873ed7e 100644 --- a/mlx_vlm/trainer/utils.py +++ b/mlx_vlm/trainer/utils.py @@ -1,8 +1,6 @@ from pathlib import Path -import mlx.core as mx import mlx.nn as nn -import numpy as np from mlx.utils import tree_flatten from .lora import LoRaLayer From 5a9c3db15a3341936c0d4d25c51c285d00b2d8c8 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 5 Oct 2024 18:11:02 +0200 Subject: [PATCH 56/72] update trainer init --- mlx_vlm/trainer/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx_vlm/trainer/__init__.py b/mlx_vlm/trainer/__init__.py index 92b14d31..813f6d1d 100644 --- a/mlx_vlm/trainer/__init__.py +++ b/mlx_vlm/trainer/__init__.py @@ -1,5 +1,5 @@ from .lora import LoRaLayer, replace_lora_with_linear -from .trainer import * +from .trainer import Dataset, Trainer, save_adapter from .utils import ( apply_lora_layers, count_parameters, From 97a4255956bde2d35582ec3967795ac4438fc0ba Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 5 Oct 2024 18:11:37 +0200 Subject: [PATCH 57/72] update lora --- mlx_vlm/lora.py | 55 +++++++++++++++++++++++-------------------------- 1 file changed, 26 insertions(+), 29 deletions(-) diff --git a/mlx_vlm/lora.py b/mlx_vlm/lora.py index d035e921..869efd9e 100644 --- a/mlx_vlm/lora.py +++ b/mlx_vlm/lora.py @@ -6,10 +6,10 @@ from datasets import load_dataset from tqdm import tqdm -from mlx_vlm.prompt_utils import apply_chat_template -from mlx_vlm.trainer import Dataset, Trainer, save_adapter -from mlx_vlm.trainer.utils import find_all_linear_names, get_peft_model -from mlx_vlm.utils import load, load_image_processor +from .prompt_utils import apply_chat_template +from .trainer import Dataset, Trainer, save_adapter +from .trainer.utils import find_all_linear_names, get_peft_model +from .utils import load, load_image_processor logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -19,27 +19,6 @@ def custom_print(*args, **kwargs): tqdm.write(" ".join(map(str, args)), **kwargs) -def process_data(examples, config, processor): - if config["model_type"] == "pixtral": - conversations = apply_chat_template( - config=config, - processor=processor, - prompt=examples["messages"], - return_messages=True, - ) - examples["messages"] = [ - json.dumps(item, ensure_ascii=False) for item in conversations - ] - else: - examples["messages"] = apply_chat_template( - config=config, - processor=processor, - prompt=examples["messages"], - return_messages=True, - ) - return examples - - def main(args): logger.info(f"\033[32mLoading model from {args.model_path}\033[0m") model, processor = load( @@ -60,10 +39,28 @@ def main(args): if args.apply_chat_template: logger.info(f"\033[32mApplying chat template to the dataset\033[0m") - dataset = dataset.map( - lambda example: process_data(example, config, processor), - desc="Applying chat template", - ) + + def process_data(examples): + if config["model_type"] == "pixtral": + conversations = apply_chat_template( + config=config, + processor=processor, + prompt=examples["messages"], + return_messages=True, + ) + examples["messages"] = [ + json.dumps(item, ensure_ascii=False) for item in conversations + ] + else: + examples["messages"] = apply_chat_template( + config=config, + processor=processor, + prompt=examples["messages"], + return_messages=True, + ) + return examples + + dataset = dataset.map(process_data) dataset = Dataset( dataset, From 0159020b2e2f4444f4a8d583fa2fc199ea1215d9 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 5 Oct 2024 18:25:13 +0200 Subject: [PATCH 58/72] update md and formatting --- README.md | 115 +++++++++++++++++++++++++++++-------- mlx_vlm/LORA.MD | 17 ++---- mlx_vlm/trainer/trainer.py | 5 +- 3 files changed, 99 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index 1da367d5..d4116782 100644 --- a/README.md +++ b/README.md @@ -1,50 +1,119 @@ # MLX-VLM -MLX-VLM a package for running Vision LLMs on your Mac using MLX. +MLX-VLM is a package for inference and fine-tuning of Vision Language Models (VLMs) on your Mac using MLX. +## Table of Contents +- [Installation](#installation) +- [Usage](#usage) + - [Command Line Interface (CLI)](#command-line-interface-cli) + - [Chat UI with Gradio](#chat-ui-with-gradio) + - [Python Script](#python-script) +- [Multi-Image Chat Support](#multi-image-chat-support) + - [Supported Models](#supported-models) + - [Usage Examples](#usage-examples) +- [Fine-tuning](#fine-tuning) -## Get started +## Installation -The easiest way to get started is to install the `mlx-vlm` package: - -**With `pip`**: +The easiest way to get started is to install the `mlx-vlm` package using pip: ```sh pip install mlx-vlm ``` -## Inference +## Usage + +### Command Line Interface (CLI) + +Generate output from a model using the CLI: -**CLI** ```sh -python -m mlx_vlm.generate --model qnguyen3/nanoLLaVA --max-tokens 100 --temp 0.0 +python -m mlx_vlm.generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --temp 0.0 --image http://images.cocodataset.org/val2017/000000039769.jpg ``` -**Chat UI with Gradio** +### Chat UI with Gradio + +Launch a chat interface using Gradio: + ```sh -python -m mlx_vlm.chat_ui --model qnguyen3/nanoLLaVA +python -m mlx_vlm.chat_ui --model mlx-community/Qwen2-VL-2B-Instruct-4bit ``` -**Script** +### Python Script + +Here's an example of how to use MLX-VLM in a Python script: + ```python import mlx.core as mx from mlx_vlm import load, generate +from mlx_vlm.prompt_utils import apply_chat_template -model_path = "mlx-community/llava-1.5-7b-4bit" +# Load the model +model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit" model, processor = load(model_path) -prompt = processor.tokenizer.apply_chat_template( - [{"role": "user", "content": f"\nWhat are these?"}], - tokenize=False, - add_generation_prompt=True, +# Prepare input +image = ["http://images.cocodataset.org/val2017/000000039769.jpg"] +prompt = "Describe this image." + +# Apply chat template +formatted_prompt = apply_chat_template( + processor, config, prompt, num_images=len(image) ) -output = generate(model, processor, "http://images.cocodataset.org/val2017/000000039769.jpg", prompt, verbose=False) +# Generate output +output = generate(model, processor, image, formatted_prompt, verbose=False) +print(output) ``` -Models with multi-image chat support -- Idefics2 -- LlaVA (Interleave) -- Qwen2-vl -- Phi3-v -- Pixtral +## Multi-Image Chat Support + +MLX-VLM supports analyzing multiple images simultaneously with select models. This feature enables more complex visual reasoning tasks and comprehensive analysis across multiple images in a single conversation. + +### Supported Models + +The following models support multi-image chat: + +1. Idefics 2 +2. LLaVA (Interleave) +3. Qwen2-VL +4. Phi3-Vision +5. Pixtral + +### Usage Examples + +#### Python Script + +```python +from mlx_vlm import load, generate +from mlx_vlm.prompt_utils import apply_chat_template + +model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit" +model, processor = load(model_path) + +images = ["path/to/image1.jpg", "path/to/image2.jpg"] +prompt = "Compare these two images." + +formatted_prompt = apply_chat_template( + processor, config, prompt, num_images=len(images) +) + +output = generate(model, processor, images, formatted_prompt, verbose=False) +print(output) +``` + +#### Command Line + +```sh +python -m mlx_vlm.generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --prompt "Compare these images" --image path/to/image1.jpg path/to/image2.jpg +``` + +These examples demonstrate how to use multiple images with MLX-VLM for more complex visual reasoning tasks. + +# Fine-tuning + +MLX-VLM supports fine-tuning models with LoRA and QLoRA. + +## LoRA & QLoRA + +To learn more about LoRA, please refer to the [LoRA.md](./mlx_vlm/LoRA.md) file. diff --git a/mlx_vlm/LORA.MD b/mlx_vlm/LORA.MD index fbc0f3a9..2e89d705 100644 --- a/mlx_vlm/LORA.MD +++ b/mlx_vlm/LORA.MD @@ -1,14 +1,13 @@ -# lora.py - NanoLLaVA LoRA Training Script +# LoRA Training Script ## Overview -`lora.py` is a Python script for fine-tuning a NanoLLaVA model using Low-Rank Adaptation (LoRA). This script allows you to train the model on your custom dataset, adjusting various parameters through command-line arguments. +`lora.py` is a Python script for fine-tuning a vision language models (VLMs) using Low-Rank Adaptation (LoRA or QLoRA). This script allows you to train the model on your custom dataset, adjusting various parameters through command-line arguments. ## Requirements - Python 3.7+ -- MLX VLM library -- Required Python packages: `argparse`, `mlx_vlm`, `mlx` +- Required Python packages: `mlx-vlm`, `numpy`, `transformers`, `datasets`, `PIL` ## Supported Models - Qwen2 @@ -22,8 +21,6 @@ - LLaVA-Next - Phi3_vision -Note: The script only works with model in full or half precision. Quantized models are not supported at the moment. - ## Usage To use the script, run it from the command line with the desired arguments: @@ -50,11 +47,12 @@ The script accepts the following command-line arguments: Here's an example of how to run the script with custom parameters: ``` -python lora.py --dataset /path/to/your/dataset --epochs 2 --steps 200 --batch_size 4 --learning_rate 5e-5 +python lora.py --dataset /path/to/your/dataset --model_path /path/to/your/model --epochs 2 --steps 200 --batch_size 4 --learning_rate 5e-5 ``` This command will: - Use the dataset at `/path/to/your/dataset` +- Use the model at `/path/to/your/model` - Train for 2 epochs - Perform 200 steps per epoch - Use a batch size of 4 @@ -66,12 +64,9 @@ The script will print the training loss at regular intervals (defined by `--prin ## Note +If you want to use QLoRA, you need to pass a pre-quantized model to the script using the `--model_path` argument (i.e. `mlx-community/nanoLLaVA-1.5-4bit`). Make sure you have the necessary permissions to read the dataset and write the output file. Also, ensure that your system has sufficient computational resources to handle the specified batch size and model. ## Contributing Feel free to submit issues or pull requests if you find any bugs or have suggestions for improvements. - -## License - -[Specify the license here, e.g., MIT, Apache 2.0, etc.] \ No newline at end of file diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index be3edd6d..02e81ceb 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -10,21 +10,18 @@ import mlx.nn as nn import numpy as np from mlx.utils import tree_flatten -from PIL import Image - -from ..prompt_utils import apply_chat_template def get_prompt(model_type, processor, conversation): if model_type == "paligemma": return conversation + if "chat_template" in processor.__dict__.keys(): prompt = processor.apply_chat_template( conversation, tokenize=False, add_generation_prompt=False, ) - elif "tokenizer" in processor.__dict__.keys(): prompt = processor.tokenizer.apply_chat_template( conversation, From 0ec2412c613a609cabd177ee9537f13e1efe0280 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 5 Oct 2024 18:31:42 +0200 Subject: [PATCH 59/72] bump version --- mlx_vlm/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx_vlm/version.py b/mlx_vlm/version.py index 6561790f..3dc1f76b 100644 --- a/mlx_vlm/version.py +++ b/mlx_vlm/version.py @@ -1 +1 @@ -__version__ = "0.0.15" +__version__ = "0.1.0" From 608adfc021abb2353e9e5cac8409914a7057590a Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 6 Oct 2024 14:38:49 +0200 Subject: [PATCH 60/72] add tests for pixtral and qwen2_vl --- mlx_vlm/models/qwen2_vl/language.py | 6 +- mlx_vlm/models/qwen2_vl/qwen2_vl.py | 1 - mlx_vlm/models/qwen2_vl/vision.py | 3 +- mlx_vlm/tests/test_models.py | 147 +++++++++++++++++++++++++++- 4 files changed, 150 insertions(+), 7 deletions(-) diff --git a/mlx_vlm/models/qwen2_vl/language.py b/mlx_vlm/models/qwen2_vl/language.py index 614bc3d6..481590bd 100644 --- a/mlx_vlm/models/qwen2_vl/language.py +++ b/mlx_vlm/models/qwen2_vl/language.py @@ -253,7 +253,7 @@ def __call__( inputs_embeds: Optional[mx.array] = None, ): if inputs_embeds is None: - h = self.embed_tokens(inputs).astype(mx.float32) + h = self.embed_tokens(inputs) else: h = inputs_embeds @@ -274,6 +274,10 @@ def __init__(self, args: TextConfig): self.args = args self.model_type = args.model_type self.model = Qwen2Model(args) + + if args.model_type != "qwen2_vl": + raise ValueError(f"Unsupported model type: {args.model_type}") + if not args.tie_word_embeddings: self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) diff --git a/mlx_vlm/models/qwen2_vl/qwen2_vl.py b/mlx_vlm/models/qwen2_vl/qwen2_vl.py index 47d264e3..030f7cec 100644 --- a/mlx_vlm/models/qwen2_vl/qwen2_vl.py +++ b/mlx_vlm/models/qwen2_vl/qwen2_vl.py @@ -18,7 +18,6 @@ class ModelConfig: text_config: TextConfig vision_config: VisionConfig - rope_scaling: dict model_type: str ignore_index: int = -100 image_token_index: int = 151655 diff --git a/mlx_vlm/models/qwen2_vl/vision.py b/mlx_vlm/models/qwen2_vl/vision.py index 48b01c56..7b784477 100644 --- a/mlx_vlm/models/qwen2_vl/vision.py +++ b/mlx_vlm/models/qwen2_vl/vision.py @@ -19,7 +19,7 @@ class VisionConfig: mlp_ratio: float = 4.0 in_channels: int = 3 layer_norm_eps: float = 1e-6 - spatial_patch_size = 14 + spatial_patch_size: int = 14 spatial_merge_size: int = 2 temporal_patch_size: int = 2 @@ -320,6 +320,7 @@ def __call__( grid_thw: mx.array, output_hidden_states: Optional[bool] = None, ) -> mx.array: + hidden_states = self.patch_embed(hidden_states) rotary_pos_emb = self.rot_pos_emb(grid_thw) diff --git a/mlx_vlm/tests/test_models.py b/mlx_vlm/tests/test_models.py index db92cb62..385a9ef6 100644 --- a/mlx_vlm/tests/test_models.py +++ b/mlx_vlm/tests/test_models.py @@ -52,6 +52,7 @@ def vision_test_runner( num_channels, image_size: tuple, vision_feature_layer=-2, + **kwargs, ): self.assertEqual(vision_tower.model_type, model_type) @@ -61,12 +62,20 @@ def vision_test_runner( shape=(batch_size, image_size[0], image_size[1], num_channels) ) + if kwargs.get("grid_thw", None) is not None: + input_tensor = mx.random.uniform(shape=(1380, 1176)) + # Perform a forward pass - *_, hidden_states = vision_tower(input_tensor, output_hidden_states=True) - # Check the output tensor shape - self.assertEqual( - hidden_states[vision_feature_layer][-1][-1].shape, (vision_hidden_size,) + *_, hidden_states = vision_tower( + input_tensor, output_hidden_states=True, **kwargs ) + # Check the output tensor shape + if kwargs.get("grid_thw", None) is not None: + self.assertEqual(hidden_states.shape, (vision_hidden_size,)) + else: + self.assertEqual( + hidden_states[vision_feature_layer][-1][-1].shape, (vision_hidden_size,) + ) def test_llava_bunny(self): from mlx_vlm.models import llava_bunny @@ -618,6 +627,136 @@ def test_phi3_v(self): (config.vision_config.image_size, config.vision_config.image_size), ) + def test_pixtral(self): + from mlx_vlm.models import pixtral + + text_config = pixtral.TextConfig( + model_type="mistral", + hidden_size=4096, + num_hidden_layers=32, + intermediate_size=11008, + num_attention_heads=32, + rms_norm_eps=1e-5, + vocab_size=32000, + num_key_value_heads=32, + rope_theta=10000.0, + rope_traditional=False, + rope_scaling=None, + ) + + vision_config = pixtral.VisionConfig( + model_type="pixtral", + num_hidden_layers=24, + hidden_size=1024, + intermediate_size=4096, + num_attention_heads=16, + image_size=336, + patch_size=14, + projection_dim=768, + vocab_size=32000, + num_channels=3, + rms_norm_eps=1e-6, + ) + + config = pixtral.ModelConfig( + text_config=text_config, + vision_config=vision_config, + model_type="pixtral", + ignore_index=-100, + image_token_index=32000, + vocab_size=32000, + vision_feature_layer=-2, + vision_feature_select_strategy="default", + ) + + model = pixtral.Model(config) + + self.language_test_runner( + model.language_model, + config.text_config.model_type, + config.text_config.vocab_size, + config.text_config.num_hidden_layers, + ) + + self.mm_projector_test_runner( + model.multi_modal_projector, + config.vision_config.hidden_size, + config.text_config.hidden_size, + ) + + self.vision_test_runner( + model.vision_tower, + config.vision_config.model_type, + config.vision_config.hidden_size, + config.vision_config.num_channels, + (config.vision_config.image_size, config.vision_config.image_size), + ) + + def test_qwen2_vl(self): + from mlx_vlm.models import qwen2_vl + + text_config = qwen2_vl.TextConfig( + model_type="qwen2_vl", + hidden_size=4096, + num_hidden_layers=32, + intermediate_size=11008, + num_attention_heads=32, + rms_norm_eps=1e-5, + vocab_size=32000, + num_key_value_heads=32, + max_position_embeddings=32768, + rope_theta=10000.0, + rope_traditional=False, + rope_scaling={"mrope_section": [16, 24, 24]}, + tie_word_embeddings=True, + ) + + vision_config = qwen2_vl.VisionConfig( + model_type="qwen2_vl", + depth=32, + embed_dim=1280, + hidden_size=1536, + num_heads=16, + image_size=384, + patch_size=14, + vocab_size=32000, + mlp_ratio=4.0, + in_channels=3, + layer_norm_eps=1e-6, + spatial_patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + ) + + config = qwen2_vl.ModelConfig( + text_config=text_config, + vision_config=vision_config, + model_type="qwen2_vl", + ignore_index=-100, + image_token_index=32000, + vocab_size=32000, + vision_feature_layer=-2, + vision_feature_select_strategy="default", + ) + + model = qwen2_vl.Model(config) + + self.language_test_runner( + model.language_model, + config.text_config.model_type, + config.text_config.vocab_size, + config.text_config.num_hidden_layers, + ) + kwargs = {"grid_thw": mx.array([[1, 30, 46]])} + self.vision_test_runner( + model.vision_tower, + config.vision_config.model_type, + config.vision_config.hidden_size, + config.vision_config.in_channels, + (config.vision_config.image_size, config.vision_config.image_size), + **kwargs, + ) + if __name__ == "__main__": unittest.main() From 15962ec8b78c0566e4641aac608c7c979e4aa657 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 6 Oct 2024 15:38:44 +0200 Subject: [PATCH 61/72] add tests for pixtral --- mlx_vlm/models/qwen2_vl/language.py | 6 ++- mlx_vlm/models/qwen2_vl/qwen2_vl.py | 1 - mlx_vlm/models/qwen2_vl/vision.py | 3 +- mlx_vlm/tests/test_models.py | 65 +++++++++++++++++++++++++++++ 4 files changed, 72 insertions(+), 3 deletions(-) diff --git a/mlx_vlm/models/qwen2_vl/language.py b/mlx_vlm/models/qwen2_vl/language.py index 614bc3d6..481590bd 100644 --- a/mlx_vlm/models/qwen2_vl/language.py +++ b/mlx_vlm/models/qwen2_vl/language.py @@ -253,7 +253,7 @@ def __call__( inputs_embeds: Optional[mx.array] = None, ): if inputs_embeds is None: - h = self.embed_tokens(inputs).astype(mx.float32) + h = self.embed_tokens(inputs) else: h = inputs_embeds @@ -274,6 +274,10 @@ def __init__(self, args: TextConfig): self.args = args self.model_type = args.model_type self.model = Qwen2Model(args) + + if args.model_type != "qwen2_vl": + raise ValueError(f"Unsupported model type: {args.model_type}") + if not args.tie_word_embeddings: self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) diff --git a/mlx_vlm/models/qwen2_vl/qwen2_vl.py b/mlx_vlm/models/qwen2_vl/qwen2_vl.py index 47d264e3..030f7cec 100644 --- a/mlx_vlm/models/qwen2_vl/qwen2_vl.py +++ b/mlx_vlm/models/qwen2_vl/qwen2_vl.py @@ -18,7 +18,6 @@ class ModelConfig: text_config: TextConfig vision_config: VisionConfig - rope_scaling: dict model_type: str ignore_index: int = -100 image_token_index: int = 151655 diff --git a/mlx_vlm/models/qwen2_vl/vision.py b/mlx_vlm/models/qwen2_vl/vision.py index 48b01c56..7b784477 100644 --- a/mlx_vlm/models/qwen2_vl/vision.py +++ b/mlx_vlm/models/qwen2_vl/vision.py @@ -19,7 +19,7 @@ class VisionConfig: mlp_ratio: float = 4.0 in_channels: int = 3 layer_norm_eps: float = 1e-6 - spatial_patch_size = 14 + spatial_patch_size: int = 14 spatial_merge_size: int = 2 temporal_patch_size: int = 2 @@ -320,6 +320,7 @@ def __call__( grid_thw: mx.array, output_hidden_states: Optional[bool] = None, ) -> mx.array: + hidden_states = self.patch_embed(hidden_states) rotary_pos_emb = self.rot_pos_emb(grid_thw) diff --git a/mlx_vlm/tests/test_models.py b/mlx_vlm/tests/test_models.py index db92cb62..7dcb2f25 100644 --- a/mlx_vlm/tests/test_models.py +++ b/mlx_vlm/tests/test_models.py @@ -618,6 +618,71 @@ def test_phi3_v(self): (config.vision_config.image_size, config.vision_config.image_size), ) + def test_pixtral(self): + from mlx_vlm.models import pixtral + + text_config = pixtral.TextConfig( + model_type="mistral", + hidden_size=4096, + num_hidden_layers=32, + intermediate_size=11008, + num_attention_heads=32, + rms_norm_eps=1e-5, + vocab_size=32000, + num_key_value_heads=32, + rope_theta=10000.0, + rope_traditional=False, + rope_scaling=None, + ) + + vision_config = pixtral.VisionConfig( + model_type="pixtral", + num_hidden_layers=24, + hidden_size=1024, + intermediate_size=4096, + num_attention_heads=16, + image_size=336, + patch_size=14, + projection_dim=768, + vocab_size=32000, + num_channels=3, + rms_norm_eps=1e-6, + ) + + config = pixtral.ModelConfig( + text_config=text_config, + vision_config=vision_config, + model_type="pixtral", + ignore_index=-100, + image_token_index=32000, + vocab_size=32000, + vision_feature_layer=-2, + vision_feature_select_strategy="default", + ) + + model = pixtral.Model(config) + + self.language_test_runner( + model.language_model, + config.text_config.model_type, + config.text_config.vocab_size, + config.text_config.num_hidden_layers, + ) + + self.mm_projector_test_runner( + model.multi_modal_projector, + config.vision_config.hidden_size, + config.text_config.hidden_size, + ) + + self.vision_test_runner( + model.vision_tower, + config.vision_config.model_type, + config.vision_config.hidden_size, + config.vision_config.num_channels, + (config.vision_config.image_size, config.vision_config.image_size), + ) + if __name__ == "__main__": unittest.main() From b135eea46ae24d437f5c93f011cdd1c9147c056d Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 6 Oct 2024 15:41:52 +0200 Subject: [PATCH 62/72] Merge branch 'pc/tuner' of https://github.com/Blaizzy/mlx-vlm into pc/tuner --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 405a7e80..0e197bcc 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,4 @@ var/ .installed.cfg *.egg .DS_Store +*.log From b7daf4605eb88f50a4f0fad60c785f191dae31af Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 6 Oct 2024 16:03:42 +0200 Subject: [PATCH 63/72] fix test --- mlx_vlm/tests/test_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx_vlm/tests/test_models.py b/mlx_vlm/tests/test_models.py index 5e365d82..3c24ab3e 100644 --- a/mlx_vlm/tests/test_models.py +++ b/mlx_vlm/tests/test_models.py @@ -67,7 +67,7 @@ def vision_test_runner( # Check vision hidden feature layer's shape matches the expected hidden size self.assertEqual( hidden_states[vision_feature_layer].shape[-1], vision_hidden_size - + ) def test_llava_bunny(self): from mlx_vlm.models import llava_bunny @@ -662,7 +662,7 @@ def test_pixtral(self): ) model = pixtral.Model(config) - + def test_qwen2_vl(self): from mlx_vlm.models import qwen2_vl From 726faca460573c47d810d0e64a63bc24a74f4ad9 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 6 Oct 2024 16:08:39 +0200 Subject: [PATCH 64/72] remove rope scaling --- mlx_vlm/tests/test_models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mlx_vlm/tests/test_models.py b/mlx_vlm/tests/test_models.py index 3c24ab3e..79b13bab 100644 --- a/mlx_vlm/tests/test_models.py +++ b/mlx_vlm/tests/test_models.py @@ -699,7 +699,6 @@ def test_qwen2_vl(self): model_type="qwen2_vl", text_config=text_config, vision_config=vision_config, - rope_scaling=text_config.rope_scaling, image_token_index=151655, vocab_size=32000, ) From a53fa13cf855eebeb42db60595115e45f305596c Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 6 Oct 2024 16:23:43 +0200 Subject: [PATCH 65/72] remove test args and update MD --- mlx_vlm/LORA.MD | 6 +++--- mlx_vlm/lora.py | 7 +++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/mlx_vlm/LORA.MD b/mlx_vlm/LORA.MD index 2e89d705..22e7a2fe 100644 --- a/mlx_vlm/LORA.MD +++ b/mlx_vlm/LORA.MD @@ -36,9 +36,9 @@ The script accepts the following command-line arguments: - `--model_path`: Path to the pre-trained model (default: "mlx-community/nanoLLaVA-1.5-bf16") - `--dataset`: Path to your dataset (required) - `--learning_rate`: Learning rate for the optimizer (default: 1e-4) -- `--batch_size`: Batch size for training (default: 2) +- `--batch_size`: Batch size for training (default: 1) - `--epochs`: Number of epochs to train (default: 1) -- `--steps`: Number of steps per epoch (default: 100) +- `--steps`: Number of steps per epoch (default: 0) - `--print_every`: Print loss every n steps (default: 10) - `--output_path`: Path to save the trained adapter (default: "nanollava_lora_adapter.safetensors") @@ -47,7 +47,7 @@ The script accepts the following command-line arguments: Here's an example of how to run the script with custom parameters: ``` -python lora.py --dataset /path/to/your/dataset --model_path /path/to/your/model --epochs 2 --steps 200 --batch_size 4 --learning_rate 5e-5 +python lora.py --dataset /path/to/your/dataset --model_path /path/to/your/model --epochs 2 --batch_size 4 --learning_rate 5e-5 ``` This command will: diff --git a/mlx_vlm/lora.py b/mlx_vlm/lora.py index 869efd9e..2b0e6051 100644 --- a/mlx_vlm/lora.py +++ b/mlx_vlm/lora.py @@ -28,9 +28,7 @@ def main(args): image_processor = load_image_processor(args.model_path) logger.info(f"\033[32mLoading dataset from {args.dataset}\033[0m") - dataset = load_dataset(args.dataset, split=args.split + "[:20%]") - - dataset = dataset.rename_columns({"image": "images", "conversations": "messages"}) + dataset = load_dataset(args.dataset, split=args.split) if "messages" not in dataset.column_names: raise ValueError("Dataset must have a 'messages' column") @@ -116,6 +114,7 @@ def process_data(examples): } ) + # Save the adapter save_adapter(model, args.output_path) @@ -151,7 +150,7 @@ def process_data(examples): "--epochs", type=int, default=1, help="Number of epochs to train" ) parser.add_argument( - "--steps", type=int, default=20, help="Number of steps per epoch" + "--steps", type=int, default=0, help="Number of steps per epoch" ) parser.add_argument( "--print-every", type=int, default=10, help="Print loss every n steps" From 31cdd67fe69e246ff09097146dfd24c028ce556a Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 9 Oct 2024 14:12:16 +0200 Subject: [PATCH 66/72] format dataset defaults --- mlx_vlm/lora.py | 2 -- mlx_vlm/trainer/trainer.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/mlx_vlm/lora.py b/mlx_vlm/lora.py index 2b0e6051..7932ff45 100644 --- a/mlx_vlm/lora.py +++ b/mlx_vlm/lora.py @@ -65,8 +65,6 @@ def process_data(examples): config, processor, image_processor=image_processor, - take=None, - split=None, ) logger.info(f"\033[32mSetting up LoRA\033[0m") diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index 02e81ceb..f02674e1 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -40,7 +40,7 @@ def __init__( processor, image_processor=None, take=None, - split="train", + split=None, ): if split is not None: self.dataset = hf_dataset[split] From e33c0d289ce9a0cd1e75d57f45544b3d0eca669a Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 9 Oct 2024 14:28:11 +0200 Subject: [PATCH 67/72] add dataset formatting info --- mlx_vlm/LORA.MD | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/mlx_vlm/LORA.MD b/mlx_vlm/LORA.MD index 22e7a2fe..ef59a1f2 100644 --- a/mlx_vlm/LORA.MD +++ b/mlx_vlm/LORA.MD @@ -29,6 +29,19 @@ To use the script, run it from the command line with the desired arguments: python lora.py --dataset /path/to/your/dataset [other options] ``` +## Dataset format + +The dataset should be a Hugging Face dataset with a `images` column and a `messages` column. + +``` +{ + "images": ..., + "messages": ..., +} +``` + +Support for other formats and column names will be added soon. + ## Arguments The script accepts the following command-line arguments: @@ -50,14 +63,6 @@ Here's an example of how to run the script with custom parameters: python lora.py --dataset /path/to/your/dataset --model_path /path/to/your/model --epochs 2 --batch_size 4 --learning_rate 5e-5 ``` -This command will: -- Use the dataset at `/path/to/your/dataset` -- Use the model at `/path/to/your/model` -- Train for 2 epochs -- Perform 200 steps per epoch -- Use a batch size of 4 -- Set the learning rate to 5e-5 - ## Output The script will print the training loss at regular intervals (defined by `--print_every`). After training, it will save the LoRA adapter to the specified output path. From 1f3eabd312d58dca8c01b3a229f0f0a9e9356700 Mon Sep 17 00:00:00 2001 From: hiima234 <98786318+hiima234@users.noreply.github.com> Date: Fri, 11 Oct 2024 08:54:36 -0700 Subject: [PATCH 68/72] Fix issues with multiple image handling (#78) 1. [IMG_BREAK] and [IMG_END] are lost after embedding 2. image position encode should be done per image base https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py#L85 https://github.com/huggingface/transformers/blob/main/src/transformers/models/pixtral/modeling_pixtral.py#L492 Co-authored-by: Roger Xu --- mlx_vlm/models/pixtral/pixtral.py | 15 +++++++++++++-- mlx_vlm/models/pixtral/vision.py | 10 +++++----- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/mlx_vlm/models/pixtral/pixtral.py b/mlx_vlm/models/pixtral/pixtral.py index 032b43f2..1257d531 100644 --- a/mlx_vlm/models/pixtral/pixtral.py +++ b/mlx_vlm/models/pixtral/pixtral.py @@ -75,6 +75,9 @@ def get_input_embeddings( # Get the input embeddings from the language model inputs_embeds = self.language_model.model.embed_tokens(input_ids) + # Get number of images + num_images = len(pixel_values[0]) + # Get the ouptut hidden states from the vision model if isinstance(pixel_values, list): if input_ids.shape[0] == 1: # Batch size is 1 @@ -88,8 +91,13 @@ def get_input_embeddings( if pixel_values.ndim == 3: pixel_values = pixel_values[None, ...] + pixel_values = mx.split(pixel_values, num_images, axis=2) + + # pass pixel_values as list of images, as each image is individually run through conv2d and position encoding + # reference code from transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/models/pixtral/modeling_pixtral.py#L479C9-L479C21 + # and mistral_inference: https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py#L85 *_, hidden_states = self.vision_tower( - pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True + [pv.transpose(0, 2, 3, 1) for pv in pixel_values], output_hidden_states=True ) # Select the hidden states from the desired layer selected_image_feature = hidden_states[self.vision_feature_layer] @@ -119,7 +127,10 @@ def _merge_input_ids_with_image_features( text_segments.append(inputs_embeds[:, start_idx:position]) start_idx = position + 1 - image_embeddings = mx.split(image_features, image_features.shape[0]) + # [IMG_BREAK] and [IMG_END] are missing with existing implementation + # image_embeddings = mx.split(image_features, image_features.shape[0]) + + image_embeddings = mx.split(image_features, num_image_patches, axis=1) final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] final_embeddings += [inputs_embeds[:, start_idx:]] diff --git a/mlx_vlm/models/pixtral/vision.py b/mlx_vlm/models/pixtral/vision.py index 2db77f4d..ce5bce8c 100644 --- a/mlx_vlm/models/pixtral/vision.py +++ b/mlx_vlm/models/pixtral/vision.py @@ -1,6 +1,6 @@ import inspect from dataclasses import dataclass -from typing import Optional +from typing import Optional, List import mlx.core as mx import mlx.nn as nn @@ -253,11 +253,11 @@ def __init__(self, config: VisionConfig): def __call__( self, - x: mx.array, + x: List[mx.array], output_hidden_states: Optional[bool] = None, ) -> mx.array: - B, H, W, C = x.shape - patch_embeds_list = [self.patch_conv(img[None, :]) for img in x] + B, H, W, C = x[0].shape + patch_embeds_list = [self.patch_conv(img) for img in x] patch_embeds = mx.concatenate( [p.reshape(B, -1, p.shape[-1]) for p in patch_embeds_list], axis=1 @@ -299,7 +299,7 @@ def __init__(self, config: VisionConfig): self.vision_model = PixtralVisionModel(config) def __call__( - self, x: mx.array, output_hidden_states: Optional[bool] = None + self, x: List[mx.array], output_hidden_states: Optional[bool] = None ) -> mx.array: return self.vision_model(x, output_hidden_states) From a9488bb86e6d5579bb77992af97ccc06f02b1819 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 11 Oct 2024 17:57:31 +0200 Subject: [PATCH 69/72] fix styling --- mlx_vlm/models/pixtral/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx_vlm/models/pixtral/vision.py b/mlx_vlm/models/pixtral/vision.py index ce5bce8c..1f015ba0 100644 --- a/mlx_vlm/models/pixtral/vision.py +++ b/mlx_vlm/models/pixtral/vision.py @@ -1,6 +1,6 @@ import inspect from dataclasses import dataclass -from typing import Optional, List +from typing import List, Optional import mlx.core as mx import mlx.nn as nn From 87e598f09d7a093bfe012939b4c699429498f3d2 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 11 Oct 2024 18:41:49 +0200 Subject: [PATCH 70/72] update model --- mlx_vlm/LORA.MD | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlx_vlm/LORA.MD b/mlx_vlm/LORA.MD index ef59a1f2..48338461 100644 --- a/mlx_vlm/LORA.MD +++ b/mlx_vlm/LORA.MD @@ -46,14 +46,14 @@ Support for other formats and column names will be added soon. The script accepts the following command-line arguments: -- `--model_path`: Path to the pre-trained model (default: "mlx-community/nanoLLaVA-1.5-bf16") +- `--model_path`: Path to the pre-trained model (default: "mlx-community/Qwen2-VL-2B-Instruct-bf16") - `--dataset`: Path to your dataset (required) - `--learning_rate`: Learning rate for the optimizer (default: 1e-4) - `--batch_size`: Batch size for training (default: 1) - `--epochs`: Number of epochs to train (default: 1) - `--steps`: Number of steps per epoch (default: 0) - `--print_every`: Print loss every n steps (default: 10) -- `--output_path`: Path to save the trained adapter (default: "nanollava_lora_adapter.safetensors") +- `--output_path`: Path to save the trained adapter (default: "adapters.safetensors") ## Example @@ -69,7 +69,7 @@ The script will print the training loss at regular intervals (defined by `--prin ## Note -If you want to use QLoRA, you need to pass a pre-quantized model to the script using the `--model_path` argument (i.e. `mlx-community/nanoLLaVA-1.5-4bit`). +If you want to use QLoRA, you need to pass a pre-quantized model to the script using the `--model_path` argument (i.e. `mlx-community/Qwen2-VL-2B-Instruct-4bit`). Make sure you have the necessary permissions to read the dataset and write the output file. Also, ensure that your system has sufficient computational resources to handle the specified batch size and model. ## Contributing From dde7390fbe825fe46d0e809c815dbef042329e24 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 11 Oct 2024 18:42:59 +0200 Subject: [PATCH 71/72] update default model --- mlx_vlm/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx_vlm/lora.py b/mlx_vlm/lora.py index 7932ff45..99822ed3 100644 --- a/mlx_vlm/lora.py +++ b/mlx_vlm/lora.py @@ -121,7 +121,7 @@ def process_data(examples): parser.add_argument( "--model-path", type=str, - default="mlx-community/nanoLLaVA-1.5-bf16", + default="mlx-community/Qwen2-VL-2B-Instruct-bf16", help="Path to the pre-trained model", ) parser.add_argument( From abbe83fb2a21e97bd9a9b54173a461616394db2a Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 11 Oct 2024 18:53:21 +0200 Subject: [PATCH 72/72] rewrite comments --- mlx_vlm/models/pixtral/pixtral.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/mlx_vlm/models/pixtral/pixtral.py b/mlx_vlm/models/pixtral/pixtral.py index 1257d531..fb51a84f 100644 --- a/mlx_vlm/models/pixtral/pixtral.py +++ b/mlx_vlm/models/pixtral/pixtral.py @@ -93,9 +93,9 @@ def get_input_embeddings( pixel_values = mx.split(pixel_values, num_images, axis=2) - # pass pixel_values as list of images, as each image is individually run through conv2d and position encoding - # reference code from transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/models/pixtral/modeling_pixtral.py#L479C9-L479C21 - # and mistral_inference: https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py#L85 + # Pass pixel_values as list of images, as each image is individually run through conv2d and position encoding + # Reference code from transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/models/pixtral/modeling_pixtral.py#L479C9-L479C21 + # and mistral_inference: https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py#L85 *_, hidden_states = self.vision_tower( [pv.transpose(0, 2, 3, 1) for pv in pixel_values], output_hidden_states=True ) @@ -127,9 +127,7 @@ def _merge_input_ids_with_image_features( text_segments.append(inputs_embeds[:, start_idx:position]) start_idx = position + 1 - # [IMG_BREAK] and [IMG_END] are missing with existing implementation - # image_embeddings = mx.split(image_features, image_features.shape[0]) - + # Split image features into separate embeddings for each image image_embeddings = mx.split(image_features, num_image_patches, axis=1) final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] final_embeddings += [inputs_embeds[:, start_idx:]]