From cec0639edabc889d0c42ccea35178fb5e82c627e Mon Sep 17 00:00:00 2001
From: anchen
Date: Fri, 23 Feb 2024 20:43:40 +1100
Subject: [PATCH 1/5] feat: add image processor for llava processor

---
 llava/.gitignore                | 163 +++++++++++++++++++++++++++++++-
 llava/download.py               |  54 +++++++++++
 llava/image_processor.py        |  93 ++++++++++++++++++
 llava/mlx_model/mlx_config.json |  36 -------
 llava/processing_llava.py       |  23 +++++
 llava/test.py                   |  50 ++++++++++
 6 files changed, 382 insertions(+), 37 deletions(-)
 create mode 100644 llava/download.py
 create mode 100644 llava/image_processor.py
 delete mode 100644 llava/mlx_model/mlx_config.json
 create mode 100644 llava/processing_llava.py
 create mode 100644 llava/test.py

diff --git a/llava/.gitignore b/llava/.gitignore
index 857540df8..bc0a54fe8 100644
--- a/llava/.gitignore
+++ b/llava/.gitignore
@@ -1 +1,162 @@
-**mlx_model
\ No newline at end of file
+**mlx_model
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +models \ No newline at end of file diff --git a/llava/download.py b/llava/download.py new file mode 100644 index 000000000..e755896bb --- /dev/null +++ b/llava/download.py @@ -0,0 +1,54 @@ +import argparse +import os + +import requests +from tqdm import tqdm + + +def download_file(url, path): + response = requests.get(url, stream=True) + total_size_in_bytes = int(response.headers.get("content-length", 0)) + block_size = 1024 # 1 Kbyte + progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + + with open(path, "wb") as file: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + file.write(data) + + progress_bar.close() + + +def download_model(model_name, destination_folder="models"): + # Define the base URL and headers for the Hugging Face API + base_url = f"https://huggingface.co/{model_name}/resolve/main" + headers = {"User-Agent": "Hugging Face Python"} + + # Send a GET request to the Hugging Face API to get a list of all files + response = requests.get( + f"https://huggingface.co/api/models/{model_name}", headers=headers + ) + response.raise_for_status() + + # Extract the list of files from the response JSON + files_to_download = [ + file["rfilename"] + for file in response.json()["siblings"] + if not file["rfilename"].endswith(".bin") + ] + + # Ensure the directory exists + os.makedirs(f"{destination_folder}/{model_name}", exist_ok=True) + + # Download each file + for file in files_to_download: + print(f"Downloading {file}...") + download_file(f"{base_url}/{file}", f"{destination_folder}/{model_name}/{file}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("model_name", type=str, help="Name of the model to download.") + args = parser.parse_args() + + download_model(args.model_name) diff --git a/llava/image_processor.py b/llava/image_processor.py new file mode 100644 index 000000000..5f5be8484 --- /dev/null +++ b/llava/image_processor.py @@ -0,0 +1,93 @@ +# Copyright © 2023-2024 Apple Inc. + +import json +from pathlib import Path +from typing import List, Tuple + +import mlx.core as mx +import numpy as np +from PIL.Image import Image + + +class CLIPImageProcessor: + """ + A simple port of + https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/image_processing_clip.py. 
+ """ + + def __init__( + self, + crop_size: int = 336, + do_center_crop: bool = True, + do_normalize: bool = True, + do_resize: bool = True, + image_mean: List[float] = [0.48145466, 0.4578275, 0.40821073], + image_std: List[float] = [0.26862954, 0.26130258, 0.27577711], + size: int = 336, + **kwargs + ) -> None: + self.crop_size = crop_size + self.do_center_crop = do_center_crop + self.do_normalize = do_normalize + self.do_resize = do_resize + self.image_mean = mx.array(image_mean) + self.image_std = mx.array(image_std) + self.size = size + + def __call__(self, images: List[Image]) -> mx.array: + return mx.concatenate( + [self._preprocess(image)[None] for image in images], axis=0 + ) + + def _preprocess(self, image: Image) -> mx.array: + if self.do_resize: + image = resize(image, self.size) + if self.do_center_crop: + image = center_crop(image, (self.crop_size, self.crop_size)) + image = mx.array(np.array(image)) + image = rescale(image) + if self.do_normalize: + image = normalize(image, self.image_mean, self.image_std) + return image + + @staticmethod + def from_pretrained(path: str): + path = Path(path) + with open(path / "preprocessor_config.json", encoding="utf-8") as f: + config = json.load(f) + return CLIPImageProcessor(**config) + + +def resize(image: Image, short_size: int) -> Image: + """ + Resize so small size to short_size + """ + width, height = image.size + short = min(width, height) + long = max(width, height) + if short == short_size: + return image + new_short = short_size + new_long = int(short_size * long / short) + new_size = (new_short, new_long) if width <= height else (new_long, new_short) + return image.resize(new_size) + + +def center_crop(image: Image, size: Tuple[int, int]) -> Image: + if size[0] % 2 != 0 or size[1] % 2 != 0: + raise ValueError("Only even crop sizes supported.") + original_width, original_height = image.size + crop_height, crop_width = size + top = (original_height - crop_height) // 2 + bottom = top + crop_height + left = (original_width - crop_width) // 2 + right = left + crop_width + return image.crop((left, top, right, bottom)) + + +def rescale(image: mx.array) -> mx.array: + return image.astype(mx.float32) * (1 / 255.0) + + +def normalize(image: mx.array, mean: mx.array, std: mx.array) -> mx.array: + return (image - mean) / std diff --git a/llava/mlx_model/mlx_config.json b/llava/mlx_model/mlx_config.json deleted file mode 100644 index 482ec26ca..000000000 --- a/llava/mlx_model/mlx_config.json +++ /dev/null @@ -1,36 +0,0 @@ -{ - "language_model": { - "hidden_size": 4096, - "num_hidden_layers": 32, - "intermediate_size": 11008, - "num_attention_heads": 32, - "rms_norm_eps": 1e-5, - "vocab_size": 32064, - "num_key_value_heads": 32, - "rope_theta": 0, - "rope_traditional": false, - "rope_scaling": null - }, - - "vision_tower": { - "num_hidden_layers": 24, - "hidden_size": 1024, - "intermediate_size": 4096, - "num_attention_heads": 16, - "num_channels": 3, - "image_size": 336, - "patch_size": 14 - }, - - "multi_modal_projector": { - "in_features": 1024, - "out_features": 4096 - }, - - "vision_feature_layer": -2, - "vision_feature_selection_strategy": "default", - "image_token_index": 32000, - "pad_token_id": 32001, - "tie_word_embeddings": false, - "vocab_size": 32064 -} \ No newline at end of file diff --git a/llava/processing_llava.py b/llava/processing_llava.py new file mode 100644 index 000000000..705d1ccf6 --- /dev/null +++ b/llava/processing_llava.py @@ -0,0 +1,23 @@ +from image_processor import CLIPImageProcessor + + +class LlavaProcessor: 
+ def __init__(self, image_processor=None, tokenizer=None): + self.image_processor = CLIPImageProcessor() + self.tokenizer = tokenizer + + def __call__( + self, + text=None, + images=None, + padding=False, + truncation=None, + max_length=None, + return_tensors=None, + ): + if images is not None: + pixel_values = self.image_processor(images) + else: + pixel_values = None + + return {"pixel_values": pixel_values} diff --git a/llava/test.py b/llava/test.py new file mode 100644 index 000000000..2ad128f9f --- /dev/null +++ b/llava/test.py @@ -0,0 +1,50 @@ +import unittest + +import mlx.core as mx +import numpy as np +import requests +import torch +from PIL import Image +from processing_llava import LlavaProcessor +from transformers import AutoProcessor, LlavaForConditionalGeneration + +MLX_PATH = "models/llava-hf/llava-1.5-7b-hf" +HF_PATH = "models/llava-hf/llava-1.5-7b-hf" + + +def load_mlx_models(path): + processor = LlavaProcessor() + return processor, None + + +def load_hf_models(path): + processor = AutoProcessor.from_pretrained(path) + model = LlavaForConditionalGeneration.from_pretrained(path) + + return processor, model + + +class TestCLIP(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.mx_proc, cls.mx_llava = load_mlx_models(MLX_PATH) + cls.hf_proc, cls.hf_llava = load_hf_models(HF_PATH) + + def test_processor(self): + prompt = "USER: \nWhat are these?\nASSISTANT:" + image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" + raw_image = Image.open(requests.get(image_file, stream=True).raw) + + hf_data = mx.array( + np.array( + self.hf_proc(prompt, raw_image, return_tensors="pt")["pixel_values"] + ) + ).transpose(0, 2, 3, 1) + + mx_data = self.mx_proc(prompt, [raw_image])["pixel_values"] + + self.assertTrue(mx.allclose(mx_data, hf_data, atol=1e-5)) + + +if __name__ == "__main__": + unittest.main() From 4dd8bca0279a6548fe495df0eba34fa36f429c30 Mon Sep 17 00:00:00 2001 From: anchen Date: Sat, 24 Feb 2024 15:19:15 +1100 Subject: [PATCH 2/5] wip --- llava/clip.py | 409 +++++++++++++++++++++---------------------------- llava/llama.py | 58 +++---- llava/llava.py | 134 +++++++--------- llava/test.py | 29 ++-- 4 files changed, 267 insertions(+), 363 deletions(-) diff --git a/llava/clip.py b/llava/clip.py index 1568858a7..736307f31 100644 --- a/llava/clip.py +++ b/llava/clip.py @@ -1,65 +1,39 @@ -# Copyright © 2023-2024 Apple Inc. 
- +import glob +import inspect import json +import logging +import math from dataclasses import dataclass from pathlib import Path -from typing import Any, Optional +from typing import Optional import mlx.core as mx import mlx.nn as nn -from mlx.core import linalg as LA -from mlx.nn.losses import cross_entropy -from mlx.utils import tree_flatten - - -@dataclass -class CLIPVisionOutput: - pooler_output: mx.array - last_hidden_state: mx.array - llava_hidden_state: mx.array - - -@dataclass -class CLIPTextOutput: - pooler_output: mx.array - last_hidden_state: mx.array - - -@dataclass -class CLIPModelOutput: - loss: Optional[mx.array] - text_embeds: Optional[mx.array] - image_embeds: Optional[mx.array] - text_model_output: CLIPTextOutput - vision_model_output: CLIPVisionOutput - - -@dataclass -class CLIPTextConfig: - num_hidden_layers: int - hidden_size: int - intermediate_size: int - num_attention_heads: int - max_position_embeddings: int - vocab_size: int - - -@dataclass -class CLIPVisionConfig: - num_hidden_layers: int - hidden_size: int - intermediate_size: int - num_attention_heads: int - num_channels: int - image_size: int - patch_size: int @dataclass -class CLIPConfig: - text_config: CLIPTextConfig - vision_config: CLIPVisionConfig - projection_dim: int +class VisionConfig: + model_type: str + num_hidden_layers: int = 24 + hidden_size: int = 1024 + intermediate_size: int = 4096 + num_attention_heads: int = 16 + image_size: int = 336 + patch_size: int = 14 + projection_dim: int = 768 + vocab_size: int = 32000 + num_channels: int = 3 + layer_norm_eps: float = 1e-5 + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) def quick_gelu(x: mx.array) -> mx.array: @@ -69,227 +43,196 @@ def quick_gelu(x: mx.array) -> mx.array: return x * mx.sigmoid(1.702 * x) -def clip_loss(logits: mx.array) -> mx.array: - N, M = logits.shape - caption_loss = cross_entropy(logits, mx.arange(N), reduction="mean") - image_loss = cross_entropy(logits.T, mx.arange(M), reduction="mean") - return (caption_loss + image_loss) / 2.0 - +class Attention(nn.Module): + def __init__( + self, + dims: int, + num_heads: int, + query_input_dims: Optional[int] = None, + key_input_dims: Optional[int] = None, + value_input_dims: Optional[int] = None, + value_dims: Optional[int] = None, + value_output_dims: Optional[int] = None, + bias: bool = False, + ): + super().__init__() -class CLIPEncoderLayer(nn.TransformerEncoderLayer): - """The transformer encoder layer from CLIP.""" + if (dims % num_heads) != 0: + raise ValueError( + "The input feature dimensions should be divisible by the " + f"number of heads ({dims} % {num_heads}) != 0" + ) - def __init__(self, hidden_dim: int, intermediate_dim: int, num_heads: int): - super().__init__( - dims=hidden_dim, - mlp_dims=intermediate_dim, - num_heads=num_heads, - activation=quick_gelu, - norm_first=True, - ) - # Add biases to the attention projections - self.attention = nn.MultiHeadAttention( - hidden_dim, num_heads, bias=True) + query_input_dims = query_input_dims or dims + key_input_dims = key_input_dims or dims + value_input_dims = value_input_dims or key_input_dims + value_dims = value_dims or dims + value_output_dims = value_output_dims or dims + + self.num_heads = num_heads + self.q_proj = nn.Linear(query_input_dims, dims, bias=bias) + self.k_proj = nn.Linear(key_input_dims, dims, bias=bias) + self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias) + self.out_proj = 
nn.Linear(value_dims, value_output_dims, bias=bias) + + def __call__(self, queries, keys, values, mask=None): + queries = self.q_proj(queries) + keys = self.k_proj(keys) + values = self.v_proj(values) + + num_heads = self.num_heads + B, L, D = queries.shape + _, S, _ = keys.shape + queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1) + values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + + scale = math.sqrt(1 / queries.shape[-1]) + scores = (queries * scale) @ keys + if mask is not None: + scores = scores + mask.astype(scores.dtype) + scores = mx.softmax(scores, axis=-1) + values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.out_proj(values_hat) + + +class MLP(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.activation_fn = quick_gelu + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + def __call__(self, x: mx.array) -> mx.array: + x = self.activation_fn(self.fc1(x)) + x = self.fc2(x) + return x -class CLIPTextModel(nn.Module): - """Implements the text encoder transformer from CLIP.""" - def __init__(self, config: CLIPTextConfig): +class EncoderLayer(nn.Module): + def __init__(self, config: VisionConfig): super().__init__() - - self.token_embedding = nn.Embedding( - config.vocab_size, config.hidden_size) - self.position_embedding = mx.zeros( - (config.max_position_embeddings, config.hidden_size) + self.embed_dim = config.hidden_size + self.self_attn = Attention( + config.hidden_size, config.num_attention_heads, bias=True ) - self.layers = [ - CLIPEncoderLayer( - config.hidden_size, config.intermediate_size, config.num_attention_heads - ) - for _ in range(config.num_hidden_layers) - ] - self.final_layer_norm = nn.LayerNorm(config.hidden_size) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = MLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - def _embed(self, x: mx.array) -> mx.array: - embeddings = self.token_embedding(x) - embeddings += self.position_embedding[: x.shape[1]] - return embeddings + def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: + y = self.layer_norm1(x) + y = self.self_attn(y, y, y, mask) + x = x + y + y = self.layer_norm2(x) + y = self.mlp(y) + return x + y - def __call__(self, x: mx.array) -> CLIPTextOutput: - B, N = x.shape - eot_tokens = mx.argmax(x, axis=-1) - x = self._embed(x) - mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype) - for l in self.layers: - x = l(x, mask) - last_hidden_state = self.final_layer_norm(x) - pooler_output = last_hidden_state[mx.arange(B), eot_tokens] - - return CLIPTextOutput( - pooler_output=pooler_output, last_hidden_state=last_hidden_state - ) +class Encoder(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)] -class CLIPVisionModel(nn.Module): - """Implements the vision encoder transformer from CLIP.""" - def __init__(self, config: CLIPVisionConfig): +class VisionEmbeddings(nn.Module): + def __init__(self, config: VisionConfig): super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size self.class_embedding = mx.zeros((config.hidden_size,)) + self.patch_embedding = 
nn.Conv2d( in_channels=config.num_channels, - out_channels=config.hidden_size, - kernel_size=config.patch_size, - stride=config.patch_size, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, bias=False, ) - num_patches = (config.image_size // config.patch_size) ** 2 - num_positions = num_patches + 1 - self.position_embedding = mx.zeros((num_positions, config.hidden_size)) - self.pre_layernorm = nn.LayerNorm(config.hidden_size) - self.layers = [ - CLIPEncoderLayer( - config.hidden_size, config.intermediate_size, config.num_attention_heads - ) - for _ in range(config.num_hidden_layers) - ] - self.post_layernorm = nn.LayerNorm(config.hidden_size) - def _embed(self, x: mx.array) -> mx.array: + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + def __call__(self, x: mx.array) -> mx.array: batch_size = x.shape[0] - # Patchify using conv: - # [batch_size, sqrt(num_patches), sqrt(num_patches), embed_dim] patch_embeddings = self.patch_embedding(x) - # [batch_size, num_patches, embed_dim] - patch_embeddings = mx.flatten( - patch_embeddings, start_axis=1, end_axis=2) + patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) embed_dim = patch_embeddings.shape[-1] - # Prepend embeddings - # [batch_size, 1, embed_dim] cls_embeddings = mx.broadcast_to( self.class_embedding, (batch_size, 1, embed_dim) ) - # [batch_size, num_patches + 1, embed_dim] embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1) - # Add positional encoding - embeddings += self.position_embedding + embeddings += self.position_embedding.weight return embeddings - def __call__(self, x: mx.array) -> CLIPVisionOutput: - x = self._embed(x) - x = self.pre_layernorm(x) - - for l in self.layers: - x = l(x, mask=None) - - # Extract token embedding - pooler_output = self.post_layernorm(x[:, 0, :]) - - llava_hidden_state = x - return CLIPVisionOutput(pooler_output=pooler_output, last_hidden_state=x, llava_hidden_state=llava_hidden_state) - -class CLIPModel(nn.Module): - def __init__(self, config: CLIPConfig): - self.text_model = CLIPTextModel(config.text_config) - self.vision_model = CLIPVisionModel(config.vision_config) - - text_embed_dim = config.text_config.hidden_size - vision_embed_dim = config.vision_config.hidden_size - projection_dim = config.projection_dim +class ClipVisionModel(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.embeddings = VisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(config.hidden_size) + self.encoder = Encoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size) - self.visual_projection = nn.Linear( - vision_embed_dim, projection_dim, bias=False) - self.text_projection = nn.Linear( - text_embed_dim, projection_dim, bias=False) - self.logit_scale = mx.array(0.0) + def __call__( + self, + x: mx.array, + output_hidden_states: Optional[bool] = None, + ) -> mx.array: + x = self.embeddings(x) + x = self.pre_layrnorm(x) - def get_text_features(self, x: mx.array) -> mx.array: - return self.text_projection(self.text_model(x).pooler_output) + encoder_states = (x,) if output_hidden_states else None - def get_image_features(self, x: mx.array) -> mx.array: - return self.visual_projection(self.vision_model(x).pooler_output) + for l in self.encoder.layers: + x = l(x, mask=None) + if output_hidden_states: + encoder_states = encoder_states + (x,) - def __call__( - self, 
- input_ids: Optional[mx.array] = None, - pixel_values: Optional[mx.array] = None, - return_loss=False, - ) -> CLIPModelOutput: - if input_ids is not None: - text_model_output = self.text_model(input_ids) - text_embeds = self.text_projection(text_model_output.pooler_output) - text_embeds = text_embeds / \ - LA.norm(text_embeds, axis=-1, keepdims=True) - else: - text_embeds = None - text_model_output = None - - if pixel_values is not None: - vision_model_output = self.vision_model(pixel_values) - image_embeds = self.visual_projection( - vision_model_output.pooler_output) - image_embeds = image_embeds / \ - LA.norm(image_embeds, axis=-1, keepdims=True) - else: - image_embeds = None - vision_model_output = None - - if return_loss and (input_ids is None or pixel_values is None): - raise ValueError( - "Must provide text and image inputs to compute loss.") - - if return_loss: - logit_scale = mx.exp(self.logit_scale) - logits = (text_embeds @ image_embeds.T) * logit_scale - loss = clip_loss(logits) - else: - loss = None - - return CLIPModelOutput( - loss=loss, - text_embeds=text_embeds, - image_embeds=image_embeds, - vision_model_output=vision_model_output, - text_model_output=text_model_output, - ) + pooler_output = self.post_layernorm(x[:, 0, :]) + return pooler_output, x, encoder_states @staticmethod def from_pretrained(path: str): path = Path(path) with open(path / "config.json", "r") as fid: - config = json.load(fid) - - text_config = config["text_config"] - text_config = CLIPTextConfig( - num_hidden_layers=text_config["num_hidden_layers"], - hidden_size=text_config["hidden_size"], - intermediate_size=text_config["intermediate_size"], - num_attention_heads=text_config["num_attention_heads"], - max_position_embeddings=text_config["max_position_embeddings"], - vocab_size=text_config["vocab_size"], - ) + config_dict = json.load(fid) + vision_config = VisionConfig(**config_dict["vision_config"]) - vision_config = config["vision_config"] + model = ClipVisionModel(vision_config) - vision_config = CLIPVisionConfig( - num_hidden_layers=vision_config["num_hidden_layers"], - hidden_size=vision_config["hidden_size"], - intermediate_size=vision_config["intermediate_size"], - num_attention_heads=vision_config["num_attention_heads"], - num_channels=3, - image_size=vision_config["image_size"], - patch_size=vision_config["patch_size"], - ) + weight_files = glob.glob(str(path / "*.safetensors")) + if not weight_files: + logging.error(f"No safetensors found in {path}") + raise FileNotFoundError(f"No safetensors found in {path}") - config = CLIPConfig( - text_config=text_config, - vision_config=vision_config, - projection_dim=config["projection_dim"], - ) - model = CLIPModel(config) - model.load_weights(str(path / "weights.npz")) + weights = {} + for wf in weight_files: + weights.update(mx.load(wf)) + + weights = model.sanitize(weights) + model.load_weights(list(weights.items())) + model.load_weights(weights) return model + + @staticmethod + def sanitize(weights): + sanitized_weights = {} + for k, v in weights.items(): + if "position_ids" in k: + # Remove unused position_ids + continue + elif "patch_embedding.weight" in k: + # pytorch conv2d expects the weight tensor to be of shape [out_channels, in_channels, kH, KW] + # mlx conv2d expects the weight tensor to be of shape [out_channels, kH, KW, in_channels] + sanitized_weights[k] = v.transpose(0, 2, 3, 1) + else: + sanitized_weights[k] = v + + return sanitized_weights diff --git a/llava/llama.py b/llava/llama.py index ba5baeda5..242251a39 100644 --- 
a/llava/llama.py +++ b/llava/llama.py @@ -1,14 +1,25 @@ +import inspect from dataclasses import dataclass from typing import Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -import inspect - @dataclass -class BaseModelArgs: +class TextConfig: + model_type: str + hidden_size: int = 4096 + num_hidden_layers: int = 32 + intermediate_size: int = 11008 + num_attention_heads: int = 32 + rms_norm_eps: float = 1e-6 + vocab_size: int = 32000 + num_key_value_heads: int = None + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + @classmethod def from_dict(cls, params): return cls( @@ -19,21 +30,6 @@ def from_dict(cls, params): } ) - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - rms_norm_eps: float - vocab_size: int - num_key_value_heads: int = None - rope_theta: float = 10000 - rope_traditional: bool = False - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - def __post_init__(self): if self.num_key_value_heads is None: self.num_key_value_heads = self.num_attention_heads @@ -41,12 +37,10 @@ def __post_init__(self): if self.rope_scaling: required_keys = {"factor", "type"} if not all(key in self.rope_scaling for key in required_keys): - raise ValueError( - f"rope_scaling must contain keys {required_keys}") + raise ValueError(f"rope_scaling must contain keys {required_keys}") if self.rope_scaling["type"] != "linear": - raise ValueError( - "rope_scaling 'type' currently only supports 'linear'") + raise ValueError("rope_scaling 'type' currently only supports 'linear'") class RMSNorm(nn.Module): @@ -64,7 +58,7 @@ def __call__(self, x): class Attention(nn.Module): - def __init__(self, args: ModelArgs): + def __init__(self, args: TextConfig): super().__init__() dim = args.hidden_size @@ -106,8 +100,7 @@ def __call__( # Prepare the queries, keys and values for the attention computation queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape( - B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) if self.repeats > 1: keys = mx.repeat(keys, self.repeats, axis=1) @@ -126,8 +119,7 @@ def __call__( scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) if mask is not None: scores += mask - scores = mx.softmax(scores.astype(mx.float32), - axis=-1).astype(scores.dtype) + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output), (keys, values) @@ -144,15 +136,14 @@ def __call__(self, x) -> mx.array: class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): + def __init__(self, args: TextConfig): super().__init__() self.num_attention_heads = args.num_attention_heads self.hidden_size = args.hidden_size self.self_attn = Attention(args) self.mlp = MLP(args.hidden_size, args.intermediate_size) self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = RMSNorm( - args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) self.args = args def __call__( @@ -169,7 +160,7 @@ def __call__( class Llama(nn.Module): - def __init__(self, args: ModelArgs): + def __init__(self, args: 
TextConfig): super().__init__() self.args = args self.vocab_size = args.vocab_size @@ -190,8 +181,7 @@ def __call__( mask = None if h.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask( - h.shape[1]) + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) mask = mask.astype(h.dtype) if cache is None: @@ -204,7 +194,7 @@ def __call__( class LlamaModel(nn.Module): - def __init__(self, args: ModelArgs): + def __init__(self, args: TextConfig): super().__init__() self.model_type = args.model_type self.model = Llama(args) diff --git a/llava/llava.py b/llava/llava.py index 34eabb9c6..4edfbeaca 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -1,56 +1,41 @@ -from clip import CLIPVisionModel -from llama import LlamaModel -from pathlib import Path +import glob +import inspect import json -import mlx.nn as nn -import mlx.core as mx -from typing import Any, Optional, Dict, Union - - +import logging from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional +import mlx.core as mx +import mlx.nn as nn +from llama import LlamaModel, TextConfig -@dataclass -class VisionConfig: - num_hidden_layers: int - hidden_size: int - intermediate_size: int - num_attention_heads: int - num_channels: int - image_size: int - patch_size: int - - -@dataclass -class LLMConfig: - model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - rms_norm_eps: float - vocab_size: int - num_key_value_heads: int - rope_theta: float = 10000 - rope_traditional: bool = False - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - - -@dataclass -class ProjectionConfig: - in_features: int - out_features: int +from clip import ClipVisionModel, VisionConfig @dataclass class LlaVAConfig: - llm_config: Any + text_config: TextConfig vision_config: VisionConfig - projection_config: ProjectionConfig + ignore_index: int = -100 + image_token_index: int = 32000 + vision_feature_select_strategy: str = "default" + vision_feature_layer: int = -2 + vocab_size: int = 32000 + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) class LlavaMultiModalProjector(nn.Module): - def __init__(self, config: Any): + def __init__(self, config: LlaVAConfig): super().__init__() self.linear_1 = nn.Linear(config.in_features, config.out_features) self.gelu = nn.GELU() @@ -65,14 +50,19 @@ def forward(self, x: mx.array) -> mx.array: class LlavaModel(nn.Module): def __init__(self, config: LlaVAConfig): - self.vision_tower = CLIPVisionModel(config=config.vision_config) - self.language_model = LlamaModel(args=config.llm_config) + self.vision_tower = ClipVisionModel( + config=VisionConfig.from_dict(config.vision_config) + ) + self.language_model = LlamaModel(args=TextConfig.from_dict(config.text_config)) self.multi_modal_projector = LlavaMultiModalProjector( - config=config.projection_config) + config=config.projection_config + ) - def __call__(self, - input_ids: Optional[mx.array] = None, - pixel_values: Optional[mx.array] = None): + def __call__( + self, + input_ids: Optional[mx.array] = None, + pixel_values: Optional[mx.array] = None, + ): # TODO: add the forward pass if pixel_values is not None and input_ids.shape[1] != 1: @@ -81,10 +71,11 @@ def __call__(self, # TODO: this is not the correct output layer, but it's a placeholder selected_image_feature = image_outputs.pooler_output - image_features = self.multi_modal_projector( - 
selected_image_feature) + image_features = self.multi_modal_projector(selected_image_feature) - def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): + def _merge_input_ids_with_image_features( + self, image_features, inputs_embeds, input_ids, attention_mask, labels + ): # TODO: https://github.com/huggingface/transformers/blob/4f09d0fd888dbf2660313f9715992822acfb99ce/src/transformers/models/llava/modeling_llava.py#L279 special_image_token_mask = input_ids == self.config.special_tokens.image @@ -97,39 +88,20 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in def from_pretrained(path: str): path = Path(path) - with open(path / "mlx_config.json", "r") as f: + with open(path / "config.json", "r") as f: model_config = json.load(f) - llava_mlx_config = LlaVAConfig( - llm_config=LLMConfig( - model_type='vicuna', - hidden_size=model_config['language_model']['hidden_size'], - num_hidden_layers=model_config['language_model']['num_hidden_layers'], - intermediate_size=model_config['language_model']['intermediate_size'], - num_attention_heads=model_config['language_model']['num_attention_heads'], - rms_norm_eps=model_config['language_model']['rms_norm_eps'], - vocab_size=model_config['language_model']['vocab_size'], - num_key_value_heads=model_config['language_model']['num_key_value_heads'], - rope_theta=model_config['language_model']['rope_theta'], - rope_traditional=model_config['language_model']['rope_traditional'], - rope_scaling=model_config['language_model']['rope_scaling'], - ), - vision_config=VisionConfig( - num_hidden_layers=model_config['vision_tower']['num_hidden_layers'], - hidden_size=model_config['vision_tower']['hidden_size'], - intermediate_size=model_config['vision_tower']['intermediate_size'], - num_attention_heads=model_config['vision_tower']['num_attention_heads'], - num_channels=model_config['vision_tower']['num_channels'], - image_size=model_config['vision_tower']['image_size'], - patch_size=model_config['vision_tower']['patch_size'], - ), - projection_config=ProjectionConfig( - in_features=model_config['multi_modal_projector']['in_features'], - out_features=model_config['multi_modal_projector']['out_features'], - ) - ) + model_config = LlaVAConfig.from_dict(model_config) + model = LlavaModel(model_config) + weight_files = glob.glob(str(path / "*.safetensors")) + if not weight_files: + logging.error(f"No safetensors found in {path}") + raise FileNotFoundError(f"No safetensors found in {path}") - model = LlavaModel(llava_mlx_config) - model.load_weights(str(path / "weights.npz")) + weights = {} + for wf in weight_files: + weights.update(mx.load(wf)) + weights = ClipVisionModel.sanitize(weights) + model.load_weights(list(weights.items())) return model diff --git a/llava/test.py b/llava/test.py index 2ad128f9f..0d693677e 100644 --- a/llava/test.py +++ b/llava/test.py @@ -3,32 +3,31 @@ import mlx.core as mx import numpy as np import requests -import torch from PIL import Image from processing_llava import LlavaProcessor from transformers import AutoProcessor, LlavaForConditionalGeneration -MLX_PATH = "models/llava-hf/llava-1.5-7b-hf" -HF_PATH = "models/llava-hf/llava-1.5-7b-hf" +from llava import LlavaModel + +MODEL_PATH = "models/llava-hf/llava-1.5-7b-hf" def load_mlx_models(path): - processor = LlavaProcessor() - return processor, None + model = LlavaModel.from_pretrained(path) + return model def load_hf_models(path): - processor = AutoProcessor.from_pretrained(path) model = 
LlavaForConditionalGeneration.from_pretrained(path) - - return processor, model + return model class TestCLIP(unittest.TestCase): @classmethod def setUpClass(cls): - cls.mx_proc, cls.mx_llava = load_mlx_models(MLX_PATH) - cls.hf_proc, cls.hf_llava = load_hf_models(HF_PATH) + cls.mx_llava = load_mlx_models(MODEL_PATH) + cls.hf_llava = load_hf_models(MODEL_PATH) + cls.proc = AutoProcessor.from_pretrained(MODEL_PATH) def test_processor(self): prompt = "USER: \nWhat are these?\nASSISTANT:" @@ -36,12 +35,12 @@ def test_processor(self): raw_image = Image.open(requests.get(image_file, stream=True).raw) hf_data = mx.array( - np.array( - self.hf_proc(prompt, raw_image, return_tensors="pt")["pixel_values"] - ) - ).transpose(0, 2, 3, 1) + self.proc(prompt, raw_image, return_tensors="np")["pixel_values"] + ) - mx_data = self.mx_proc(prompt, [raw_image])["pixel_values"] + mx_data = mx.array( + self.proc(prompt, raw_image, return_tensors="np")["pixel_values"] + ) self.assertTrue(mx.allclose(mx_data, hf_data, atol=1e-5)) From c4ea94fdec3ee62a38b51f7408def4e129172561 Mon Sep 17 00:00:00 2001 From: anchen Date: Sun, 25 Feb 2024 00:17:13 +1100 Subject: [PATCH 3/5] feat: llava working example --- llava/Local LLava.ipynb | 1121 ------------------------------- llava/config.py | 37 - llava/convert.py | 87 --- llava/download.py | 54 -- llava/generate.py | 58 ++ llava/image_processor.py | 93 --- llava/{llama.py => language.py} | 12 +- llava/llava.py | 112 ++- llava/processing_llava.py | 23 - llava/test.py | 103 ++- llava/utils.py | 70 -- llava/{clip.py => vision.py} | 30 +- 12 files changed, 255 insertions(+), 1545 deletions(-) delete mode 100644 llava/Local LLava.ipynb delete mode 100644 llava/config.py delete mode 100644 llava/convert.py delete mode 100644 llava/download.py create mode 100644 llava/generate.py delete mode 100644 llava/image_processor.py rename llava/{llama.py => language.py} (95%) delete mode 100644 llava/processing_llava.py delete mode 100644 llava/utils.py rename llava/{clip.py => vision.py} (90%) diff --git a/llava/Local LLava.ipynb b/llava/Local LLava.ipynb deleted file mode 100644 index 69c760103..000000000 --- a/llava/Local LLava.ipynb +++ /dev/null @@ -1,1121 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Goal: Download and convert the weights of LlaVA into MLX, and test the forward pass of this model on example data" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import shutil\n", - "from pathlib import Path\n", - "import os\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "mlx_path = Path('mlx_model')\n", - "\n", - "if not os.path.exists(mlx_path):\n", - " os.makedirs(mlx_path)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/noahkasmanoff/anaconda3/envs/mlx/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 214177.23it/s]\n", - "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" - ] - } - ], - "source": [ - "import mlx.core as mx\n", - "from convert import get_model_path, fetch_from_hub, hf_repo\n", - "\n", - "\n", - "model_path = get_model_path(hf_repo)\n", - "model_config, model_weights, model_weight_files, config, tokenizer = fetch_from_hub(model_path)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[INFO] Converting\n", - "[INFO] Saving\n" - ] - } - ], - "source": [ - "from utils import map_weights, should_keep_weight\n", - "do_convert = True\n", - "if do_convert:\n", - "\n", - " print(\"[INFO] Converting\")\n", - " mlx_weights = dict(map_weights(k, v) for (k, v) in model_weights.items())\n", - " mlx_weights = {k: v for (k, v) in mlx_weights.items() if should_keep_weight(k)}\n", - " print(\"[INFO] Saving\")\n", - " mx.savez(str(mlx_path / \"weights.npz\"), **mlx_weights)\n", - " for fn in [\"config.json\", \"merges.txt\", \"vocab.json\", \"preprocessor_config.json\"]:\n", - " if fn in os.listdir(model_path):\n", - " shutil.copyfile(\n", - " str(model_path / f\"{fn}\"),\n", - " str(mlx_path / f\"{fn}\"),\n", - " )\n" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from llava import LlavaModel\n", - "mlx_model = LlavaModel.from_pretrained(path='mlx_model')\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "mlx_model = LlavaModel.from_pretrained(path='mlx_model')" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "LlavaModel(\n", - " (vision_tower): CLIPVisionModel(\n", - " (patch_embedding): Conv2d(3, 1024, kernel_size=(14,), stride=(14, 14), padding=(0, 0), bias=False)\n", - " (pre_layernorm): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (layers.0): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.1): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " 
(linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.2): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.3): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.4): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.5): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.6): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): 
Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.7): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.8): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.9): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.10): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.11): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, 
eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.12): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.13): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.14): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.15): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.16): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, 
eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.17): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.18): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.19): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.20): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.21): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, 
bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.22): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.23): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (post_layernorm): LayerNorm(1024, eps=1e-05, affine=True)\n", - " )\n", - " (language_model): LlamaModel(\n", - " (model): Llama(\n", - " (embed_tokens): Embedding(32064, 4096)\n", - " (layers.0): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.1): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.2): TransformerBlock(\n", - " 
(self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.3): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.4): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.5): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.6): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): 
Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.7): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.8): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.9): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.10): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.11): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, 
output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.12): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.13): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.14): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.15): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - 
" )\n", - " (layers.16): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.17): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.18): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.19): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.20): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, 
bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.21): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.22): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.23): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.24): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.25): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): 
Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.26): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.27): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.28): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.29): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " 
(post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.30): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.31): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (norm): RMSNorm()\n", - " )\n", - " (lm_head): Linear(input_dims=4096, output_dims=32064, bias=False)\n", - " )\n", - " (multi_modal_projector): LlavaMultiModalProjector(\n", - " (linear_1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (gelu): GELU()\n", - " (linear_2): Linear(input_dims=4096, output_dims=4096, bias=True)\n", - " )\n", - ")" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "mlx_model" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" - ] - } - ], - "source": [ - "# Now that model weights are loaded in, now we can try and run inference code / set that up.\n", - "\n", - "# load the processor\n", - "from transformers import AutoProcessor\n", - "import requests\n", - "from PIL import Image\n", - "processor = AutoProcessor.from_pretrained(\"llava-hf/llava-1.5-7b-hf\")\n", - "\n", - "prompt = \"\\nUSER: What's the content of the image?\\nASSISTANT:\"\n", - "url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n", - "image = Image.open(requests.get(url, stream=True).raw)\n", - "\n", - "inputs = processor(text=prompt, images=image, return_tensors=\"pt\")\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "input_ids = mx.array(inputs[\"input_ids\"].numpy())\n", - "pixel_values = mx.array(inputs[\"pixel_values\"].numpy())\n" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": {}, - "outputs": [], - "source": [ - "vision_model_output = mlx_model.vision_tower(pixel_values.transpose(0,2,3,1))" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - 
"(1, 577, 1024)" - ] - }, - "execution_count": 55, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 57, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "CLIPVisionOutput(pooler_output=array([[-0.721487, -0.476275, 0.0173661, ..., 0.190072, -1.71528, 1.36224]], dtype=float32), last_hidden_state=array([[[-0.333623, -0.269844, 0.025435, ..., -0.0516554, -0.729696, 0.542679],\n", - " [0.208684, 0.92752, 0.0233985, ..., 1.59934, -0.024813, 0.879629],\n", - " [0.550235, 0.45201, 0.80935, ..., 1.63056, -0.37727, 0.699322],\n", - " ...,\n", - " [0.740987, 0.445616, 0.893172, ..., 0.523529, 0.0230118, -0.457155],\n", - " [0.49297, 0.0680847, 0.79401, ..., 0.476083, 0.274526, -0.284749],\n", - " [-0.0411091, 0.290756, 0.518906, ..., 0.242572, 0.40785, 0.420446]]], dtype=float32))" - ] - }, - "execution_count": 57, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "vision_model_output" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "mlx", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/llava/config.py b/llava/config.py deleted file mode 100644 index f8617ac14..000000000 --- a/llava/config.py +++ /dev/null @@ -1,37 +0,0 @@ -model_config = { - 'language_model': { - 'hidden_size': 4096, - 'num_hidden_layers': 32, - 'intermediate_size': 11008, - 'num_attention_heads': 32, - 'rms_norm_eps': 1e-5, - 'vocab_size': 32000, - 'num_key_value_heads': 32, - 'rope_theta': 0, - 'rope_traditional': False, - 'rope_scaling': None}, - - 'vision_tower': { - 'num_hidden_layers': 24, - 'hidden_size': 1024, - 'intermediate_size': 4096, - 'num_attention_heads': 16, - 'num_channels': 3, - 'image_size': 336, - 'patch_size': 14 - }, - - 'multi_modal_projector': { - 'in_features': 1024, - 'out_features': 4096 - }, - - 'vision_feature_layer': -2, - 'vision_feature_selection_strategy': 'default', - 'image_token_index': 32000, - 'pad_token_id': 32001, - 'tie_word_embeddings': False, - 'vocab_size': 32064, # TODO: confirm this value - - -} diff --git a/llava/convert.py b/llava/convert.py deleted file mode 100644 index af9973850..000000000 --- a/llava/convert.py +++ /dev/null @@ -1,87 +0,0 @@ - -from safetensors.torch import load_file -from pathlib import Path -import glob -import json -import logging -import mlx.nn as nn -from huggingface_hub import snapshot_download -from typing import Dict, Tuple -from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer - - -hf_repo = "llava-hf/llava-1.5-7b-hf" - - -def get_model_path(path_or_hf_repo: str) -> Path: - """ - Ensures the model is available locally. If the path does not exist locally, - it is downloaded from the Hugging Face Hub. - - Args: - path_or_hf_repo (str): The local path or Hugging Face repository ID of the model. - - Returns: - Path: The path to the model. 
- """ - model_path = Path(path_or_hf_repo) - if not model_path.exists(): - model_path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - allow_patterns=[ - "*.json", - "*.safetensors", - "*.py", - "tokenizer.model", - "*.tiktoken", - ], - ) - ) - return model_path - - -def load_model(model_path: Path) -> nn.Module: - """ - Load and initialize the model from a given path. - - Args: - model_path (Path): The path to load the model from. - - Returns: - nn.Module: The loaded and initialized model. - - Raises: - FileNotFoundError: If the weight files (.safetensors) are not found. - ValueError: If the model class or args class are not found or cannot be instantiated. - """ - try: - with open(model_path / "config.json", "r") as f: - config = json.load(f) - except FileNotFoundError: - logging.error(f"Config file not found in {model_path}") - raise - - weight_files = glob.glob(str(model_path / "*.safetensors")) - if not weight_files: - logging.error(f"No safetensors found in {model_path}") - raise FileNotFoundError(f"No safetensors found in {model_path}") - - weights = {} - for wf in weight_files: - weights.update(load_file(wf)) - - return config, weights, weight_files - - -def fetch_from_hub( - model_path: Path, -) -> Tuple[Dict, dict, PreTrainedTokenizer]: - model_config, model_weights, model_weight_files = load_model(model_path) - - config = AutoConfig.from_pretrained(model_path) - tokenizer = AutoTokenizer.from_pretrained( - model_path) # TODO: should this be the processor? - - # TODO: replace outputs with the model alone once conversion is complete - return model_config, model_weights, model_weight_files, config, tokenizer diff --git a/llava/download.py b/llava/download.py deleted file mode 100644 index e755896bb..000000000 --- a/llava/download.py +++ /dev/null @@ -1,54 +0,0 @@ -import argparse -import os - -import requests -from tqdm import tqdm - - -def download_file(url, path): - response = requests.get(url, stream=True) - total_size_in_bytes = int(response.headers.get("content-length", 0)) - block_size = 1024 # 1 Kbyte - progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) - - with open(path, "wb") as file: - for data in response.iter_content(block_size): - progress_bar.update(len(data)) - file.write(data) - - progress_bar.close() - - -def download_model(model_name, destination_folder="models"): - # Define the base URL and headers for the Hugging Face API - base_url = f"https://huggingface.co/{model_name}/resolve/main" - headers = {"User-Agent": "Hugging Face Python"} - - # Send a GET request to the Hugging Face API to get a list of all files - response = requests.get( - f"https://huggingface.co/api/models/{model_name}", headers=headers - ) - response.raise_for_status() - - # Extract the list of files from the response JSON - files_to_download = [ - file["rfilename"] - for file in response.json()["siblings"] - if not file["rfilename"].endswith(".bin") - ] - - # Ensure the directory exists - os.makedirs(f"{destination_folder}/{model_name}", exist_ok=True) - - # Download each file - for file in files_to_download: - print(f"Downloading {file}...") - download_file(f"{base_url}/{file}", f"{destination_folder}/{model_name}/{file}") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("model_name", type=str, help="Name of the model to download.") - args = parser.parse_args() - - download_model(args.model_name) diff --git a/llava/generate.py b/llava/generate.py new file mode 100644 index 000000000..a51fe2cc9 --- /dev/null +++ 
b/llava/generate.py @@ -0,0 +1,58 @@ +import mlx.core as mx +import mlx.nn as nn +import requests +from PIL import Image +from transformers import AutoProcessor + +from llava import LlavaModel + +MODEL_PATH = "models/llava-hf/llava-1.5-7b-hf" + +prompt = "USER: \nWhat are these?\nASSISTANT:" +image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" +raw_image = Image.open(requests.get(image_file, stream=True).raw) + + +processor = AutoProcessor.from_pretrained(MODEL_PATH) +model = LlavaModel.from_pretrained(MODEL_PATH) + +values = processor(prompt, raw_image, return_tensors="np") +pixel_values = mx.array(values["pixel_values"]) +input_ids = mx.array(values["input_ids"]) + +input_embeds = model(input_ids, pixel_values) +max_tokens = 100 +temperature = 0.3 + + +def sample(logits, temp=0.0): + if temp == 0: + return mx.argmax(logits, axis=-1) + else: + return mx.random.categorical(logits * (1 / temp)) + + +def generate(y: mx.array, model: nn.Module, temp: float = 0.0, cache=None): + while True: + logits, cache = model(y[None], cache=cache) + logits = logits[:, -1, :] + + y = sample(logits, temp=temp) + token = y.item() + + yield token + + +logits, cache = model.language_model(input_ids, cache=None, inputs_embeds=input_embeds) +logits = logits[:, -1, :] +y = sample(logits, temp=temperature) +tokens = [y.item()] +for token, _ in zip( + generate(y, model.language_model, temperature, cache=cache), + range(max_tokens), +): + if token == processor.tokenizer.eos_token_id: + break + tokens.append(token) + +print(processor.tokenizer.decode(tokens)) diff --git a/llava/image_processor.py b/llava/image_processor.py deleted file mode 100644 index 5f5be8484..000000000 --- a/llava/image_processor.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -import json -from pathlib import Path -from typing import List, Tuple - -import mlx.core as mx -import numpy as np -from PIL.Image import Image - - -class CLIPImageProcessor: - """ - A simple port of - https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/image_processing_clip.py. 
- """ - - def __init__( - self, - crop_size: int = 336, - do_center_crop: bool = True, - do_normalize: bool = True, - do_resize: bool = True, - image_mean: List[float] = [0.48145466, 0.4578275, 0.40821073], - image_std: List[float] = [0.26862954, 0.26130258, 0.27577711], - size: int = 336, - **kwargs - ) -> None: - self.crop_size = crop_size - self.do_center_crop = do_center_crop - self.do_normalize = do_normalize - self.do_resize = do_resize - self.image_mean = mx.array(image_mean) - self.image_std = mx.array(image_std) - self.size = size - - def __call__(self, images: List[Image]) -> mx.array: - return mx.concatenate( - [self._preprocess(image)[None] for image in images], axis=0 - ) - - def _preprocess(self, image: Image) -> mx.array: - if self.do_resize: - image = resize(image, self.size) - if self.do_center_crop: - image = center_crop(image, (self.crop_size, self.crop_size)) - image = mx.array(np.array(image)) - image = rescale(image) - if self.do_normalize: - image = normalize(image, self.image_mean, self.image_std) - return image - - @staticmethod - def from_pretrained(path: str): - path = Path(path) - with open(path / "preprocessor_config.json", encoding="utf-8") as f: - config = json.load(f) - return CLIPImageProcessor(**config) - - -def resize(image: Image, short_size: int) -> Image: - """ - Resize so small size to short_size - """ - width, height = image.size - short = min(width, height) - long = max(width, height) - if short == short_size: - return image - new_short = short_size - new_long = int(short_size * long / short) - new_size = (new_short, new_long) if width <= height else (new_long, new_short) - return image.resize(new_size) - - -def center_crop(image: Image, size: Tuple[int, int]) -> Image: - if size[0] % 2 != 0 or size[1] % 2 != 0: - raise ValueError("Only even crop sizes supported.") - original_width, original_height = image.size - crop_height, crop_width = size - top = (original_height - crop_height) // 2 - bottom = top + crop_height - left = (original_width - crop_width) // 2 - right = left + crop_width - return image.crop((left, top, right, bottom)) - - -def rescale(image: mx.array) -> mx.array: - return image.astype(mx.float32) * (1 / 255.0) - - -def normalize(image: mx.array, mean: mx.array, std: mx.array) -> mx.array: - return (image - mean) / std diff --git a/llava/llama.py b/llava/language.py similarity index 95% rename from llava/llama.py rename to llava/language.py index 242251a39..bb9078155 100644 --- a/llava/llama.py +++ b/llava/language.py @@ -176,8 +176,13 @@ def __call__( self, inputs: mx.array, cache=None, + inputs_embeds=None, ): - h = self.embed_tokens(inputs) + # for passing merged input embeddings + if inputs_embeds is None: + h = self.embed_tokens(inputs) + else: + h = inputs_embeds mask = None if h.shape[1] > 1: @@ -193,7 +198,7 @@ def __call__( return self.norm(h), cache -class LlamaModel(nn.Module): +class LanguageModel(nn.Module): def __init__(self, args: TextConfig): super().__init__() self.model_type = args.model_type @@ -204,8 +209,9 @@ def __call__( self, inputs: mx.array, cache=None, + inputs_embeds=None, ): - out, cache = self.model(inputs, cache) + out, cache = self.model(inputs, cache, inputs_embeds) return self.lm_head(out), cache @staticmethod diff --git a/llava/llava.py b/llava/llava.py index 4edfbeaca..01e76122a 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -8,9 +8,9 @@ import mlx.core as mx import mlx.nn as nn -from llama import LlamaModel, TextConfig - -from clip import ClipVisionModel, VisionConfig +import numpy as np 
+from language import LanguageModel, TextConfig +from vision import VisionConfig, VisionModel @dataclass @@ -37,11 +37,15 @@ def from_dict(cls, params): class LlavaMultiModalProjector(nn.Module): def __init__(self, config: LlaVAConfig): super().__init__() - self.linear_1 = nn.Linear(config.in_features, config.out_features) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, config.text_config.hidden_size, bias=True + ) self.gelu = nn.GELU() - self.linear_2 = nn.Linear(config.out_features, config.out_features) + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=True + ) - def forward(self, x: mx.array) -> mx.array: + def __call__(self, x: mx.array) -> mx.array: x = self.linear_1(x) x = self.gelu(x) x = self.linear_2(x) @@ -50,39 +54,81 @@ def forward(self, x: mx.array) -> mx.array: class LlavaModel(nn.Module): def __init__(self, config: LlaVAConfig): - self.vision_tower = ClipVisionModel( - config=VisionConfig.from_dict(config.vision_config) - ) - self.language_model = LlamaModel(args=TextConfig.from_dict(config.text_config)) - self.multi_modal_projector = LlavaMultiModalProjector( - config=config.projection_config - ) + self.config = config + self.vision_tower = VisionModel(config.vision_config) + self.language_model = LanguageModel(config.text_config) + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.vision_feature_layer = config.vision_feature_layer + self.vision_feature_select_strategy = config.vision_feature_select_strategy def __call__( self, input_ids: Optional[mx.array] = None, pixel_values: Optional[mx.array] = None, ): - # TODO: add the forward pass - - if pixel_values is not None and input_ids.shape[1] != 1: - image_outputs = self.vision_tower(pixel_values) + if pixel_values is None: + return self.language_model(input_ids) - # TODO: this is not the correct output layer, but it's a placeholder - selected_image_feature = image_outputs.pooler_output + inputs_embeds = self.language_model.model.embed_tokens(input_ids) + _, _, hidden_states = self.vision_tower( + pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True + ) + selected_image_feature = hidden_states[self.vision_feature_layer] + + if self.vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif self.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError( + f"Unexpected select feature strategy: {self.vision_feature_select_strategy}" + ) + + image_features = self.multi_modal_projector(selected_image_feature) + final_inputs_embeds = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids + ) - image_features = self.multi_modal_projector(selected_image_feature) + return final_inputs_embeds def _merge_input_ids_with_image_features( - self, image_features, inputs_embeds, input_ids, attention_mask, labels + self, image_features, inputs_embeds, input_ids ): - # TODO: https://github.com/huggingface/transformers/blob/4f09d0fd888dbf2660313f9715992822acfb99ce/src/transformers/models/llava/modeling_llava.py#L279 + image_features = np.array(image_features) + inputs_embeds = np.array(inputs_embeds) + input_ids = np.array(input_ids) + + _, num_image_patches, embed_dim = image_features.shape + batch_size, sequence_length = input_ids.shape + special_image_token_mask = input_ids == self.config.image_token_index + num_special_image_tokens = np.sum(special_image_token_mask, axis=-1) + max_embed_dim = ( + 
np.max(num_special_image_tokens) * (num_image_patches - 1) + ) + sequence_length + + non_image_indices = np.where(input_ids != self.config.image_token_index) + + new_token_positions = ( + np.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), axis=-1) + - 1 + ) + text_to_overwrite = new_token_positions[non_image_indices] + + final_embedding = np.zeros( + (batch_size, max_embed_dim, embed_dim), dtype=inputs_embeds.dtype + ) - special_image_token_mask = input_ids == self.config.special_tokens.image + final_embedding[non_image_indices[0], text_to_overwrite, :] = inputs_embeds[ + non_image_indices + ] - num_image_tokens = special_image_token_mask.sum() + image_to_overwrite = np.all(final_embedding == 0, axis=-1) + reshaped_image_features = image_features.reshape(-1, embed_dim) + final_embedding[image_to_overwrite, :] = reshaped_image_features[ + : np.sum(image_to_overwrite) + ] - pass + return mx.array(final_embedding) @staticmethod def from_pretrained(path: str): @@ -92,6 +138,15 @@ def from_pretrained(path: str): model_config = json.load(f) model_config = LlaVAConfig.from_dict(model_config) + + if isinstance(model_config.vision_config, dict): + model_config.vision_config = VisionConfig.from_dict( + model_config.vision_config + ) + + if isinstance(model_config.text_config, dict): + model_config.text_config = TextConfig.from_dict(model_config.text_config) + model = LlavaModel(model_config) weight_files = glob.glob(str(path / "*.safetensors")) if not weight_files: @@ -102,6 +157,11 @@ def from_pretrained(path: str): for wf in weight_files: weights.update(mx.load(wf)) - weights = ClipVisionModel.sanitize(weights) + if hasattr(VisionModel, "sanitize"): + weights = VisionModel.sanitize(weights) + + if hasattr(VisionModel, "sanitize"): + weights = LanguageModel.sanitize(weights) + model.load_weights(list(weights.items())) return model diff --git a/llava/processing_llava.py b/llava/processing_llava.py deleted file mode 100644 index 705d1ccf6..000000000 --- a/llava/processing_llava.py +++ /dev/null @@ -1,23 +0,0 @@ -from image_processor import CLIPImageProcessor - - -class LlavaProcessor: - def __init__(self, image_processor=None, tokenizer=None): - self.image_processor = CLIPImageProcessor() - self.tokenizer = tokenizer - - def __call__( - self, - text=None, - images=None, - padding=False, - truncation=None, - max_length=None, - return_tensors=None, - ): - if images is not None: - pixel_values = self.image_processor(images) - else: - pixel_values = None - - return {"pixel_values": pixel_values} diff --git a/llava/test.py b/llava/test.py index 0d693677e..4cca0c9af 100644 --- a/llava/test.py +++ b/llava/test.py @@ -3,8 +3,8 @@ import mlx.core as mx import numpy as np import requests +import torch from PIL import Image -from processing_llava import LlavaProcessor from transformers import AutoProcessor, LlavaForConditionalGeneration from llava import LlavaModel @@ -14,11 +14,13 @@ def load_mlx_models(path): model = LlavaModel.from_pretrained(path) + model.eval() return model def load_hf_models(path): model = LlavaForConditionalGeneration.from_pretrained(path) + model.eval() return model @@ -29,20 +31,103 @@ def setUpClass(cls): cls.hf_llava = load_hf_models(MODEL_PATH) cls.proc = AutoProcessor.from_pretrained(MODEL_PATH) - def test_processor(self): + def test_image_features(self): prompt = "USER: \nWhat are these?\nASSISTANT:" image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" raw_image = Image.open(requests.get(image_file, stream=True).raw) + vision_feature_layer = -2 + 
with torch.no_grad(): + pixel_values = self.proc(prompt, raw_image, return_tensors="pt")[ + "pixel_values" + ] - hf_data = mx.array( - self.proc(prompt, raw_image, return_tensors="np")["pixel_values"] - ) + hf_pixel_values = pixel_values + mx_pixel_values = mx.array(pixel_values.numpy()).transpose(0, 2, 3, 1) - mx_data = mx.array( - self.proc(prompt, raw_image, return_tensors="np")["pixel_values"] - ) + _, _, hidden_states = self.mx_llava.vision_tower( + mx_pixel_values, + output_hidden_states=True, + ) - self.assertTrue(mx.allclose(mx_data, hf_data, atol=1e-5)) + mx_elected_image_feature = hidden_states[vision_feature_layer] + mx_image_features = self.mx_llava.multi_modal_projector( + mx_elected_image_feature + ) + + hf_image_outputs = self.hf_llava.vision_tower( + hf_pixel_values, output_hidden_states=True + ) + hf_elected_image_feature = hf_image_outputs.hidden_states[ + vision_feature_layer + ] + hf_image_features = self.hf_llava.multi_modal_projector( + hf_elected_image_feature + ) + + self.assertTrue( + mx.allclose( + mx_image_features, + mx.array(hf_image_features.numpy()), + atol=1e-2, + ) + ) + + def test_merge_input_ids_with_image_features(self): + prompt = "USER: \nWhat are these?\nASSISTANT:" + image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" + raw_image = Image.open(requests.get(image_file, stream=True).raw) + vision_feature_layer = -2 + with torch.no_grad(): + values = self.proc(prompt, raw_image, return_tensors="pt") + pixel_values = values["pixel_values"] + input_ids = values["input_ids"] + + hf_pixel_values = pixel_values + mx_pixel_values = mx.array(pixel_values.numpy()).transpose(0, 2, 3, 1) + + _, _, hidden_states = self.mx_llava.vision_tower( + mx_pixel_values, + output_hidden_states=True, + ) + mx_input_ids = mx.array(input_ids.numpy()) + mx_elected_image_feature = hidden_states[vision_feature_layer] + mx_image_features = self.mx_llava.multi_modal_projector( + mx_elected_image_feature + ) + mx_inputs_embeds = self.mx_llava.language_model.model.embed_tokens( + mx_input_ids + ) + mx_final_embedding = self.mx_llava._merge_input_ids_with_image_features( + mx_image_features, mx_inputs_embeds, mx_input_ids + ) + + hf_image_outputs = self.hf_llava.vision_tower( + hf_pixel_values, output_hidden_states=True + ) + hf_elected_image_feature = hf_image_outputs.hidden_states[ + vision_feature_layer + ] + hf_image_features = self.hf_llava.multi_modal_projector( + hf_elected_image_feature + ) + hf_inputs_embeds = self.hf_llava.get_input_embeddings()(input_ids) + hf_final_embedding, _, _, _ = ( + self.hf_llava._merge_input_ids_with_image_features( + hf_image_features, + hf_inputs_embeds, + input_ids, + attention_mask=input_ids, + labels=torch.ones_like(input_ids), + ) + ) + + self.assertTrue( + mx.allclose( + mx_final_embedding, + mx.array(hf_final_embedding.numpy()), + atol=1e-1, + ) + ) if __name__ == "__main__": diff --git a/llava/utils.py b/llava/utils.py deleted file mode 100644 index 62c5d0ef1..000000000 --- a/llava/utils.py +++ /dev/null @@ -1,70 +0,0 @@ -import mlx.core as mx -import torch -from typing import Tuple - - -def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array: - # bfloat16 is not numpy convertible. 
Upcast to float32 to avoid precision loss - a = a.to(torch.float32) if dtype == "bfloat16" else a.to( - getattr(torch, dtype)) - return mx.array(a.numpy(), getattr(mx, dtype)) - - -def should_keep_weight(key: str): - return not ("position_ids" in key) - - -def map_vision_tower_weights(key: str, value: torch.Tensor) -> Tuple[str, torch.Tensor]: - key = key.replace("embeddings.", "") - key = key.replace("encoder.", "") - key = key.replace("position_embedding.weight", "position_embedding") - - key = key.replace('vision_model.', '') - - # Map attention layers - if "self_attn." in key: - key = key.replace("self_attn.", "attention.") - if "q_proj." in key: - key = key.replace("q_proj.", "query_proj.") - if "k_proj." in key: - key = key.replace("k_proj.", "key_proj.") - if "v_proj." in key: - key = key.replace("v_proj.", "value_proj.") - if "layer_norm1." in key: - key = key.replace("layer_norm1.", "ln1.") - if "layer_norm2." in key: - key = key.replace("layer_norm2.", "ln2.") - # Map ffn layers - if "mlp.fc1" in key: - key = key.replace("mlp.fc1", "linear1") - if "mlp.fc2" in key: - key = key.replace("mlp.fc2", "linear2") - # Fix layernorm typo - if "pre_layrnorm" in key: - # Fix typo in weights :) - key = key.replace("pre_layrnorm", "pre_layernorm") - if "patch_embedding.weight" in key: - # Initially, value: [out_channels, in_channels, kH, KW]. - # We want [out_channels, kH, KW, in_channels] - value = value.permute(0, 2, 3, 1) - return (key, value) - - -def map_language_model_weights(key: str, value: torch.Tensor) -> Tuple[str, torch.Tensor]: - return (key, value) - - -def map_multi_modal_projector_weights(key: str, value: torch.Tensor) -> Tuple[str, torch.Tensor]: - return (key, value) - - -def map_weights(key: str, value: torch.Tensor) -> Tuple[str, mx.array]: - - if 'vision_tower' in key: - key, value = map_vision_tower_weights(key, value) - elif 'language_model' in key: - key, value = map_language_model_weights(key, value) - elif 'multi_modal_projector' in key: - key, value = map_multi_modal_projector_weights(key, value) - - return (key, torch_to_mx(value, dtype=str(value.dtype).replace("torch.", ""))) diff --git a/llava/clip.py b/llava/vision.py similarity index 90% rename from llava/clip.py rename to llava/vision.py index 736307f31..ed7f7a46e 100644 --- a/llava/clip.py +++ b/llava/vision.py @@ -4,7 +4,6 @@ import logging import math from dataclasses import dataclass -from pathlib import Path from typing import Optional import mlx.core as mx @@ -197,29 +196,16 @@ def __call__( pooler_output = self.post_layernorm(x[:, 0, :]) return pooler_output, x, encoder_states - @staticmethod - def from_pretrained(path: str): - path = Path(path) - - with open(path / "config.json", "r") as fid: - config_dict = json.load(fid) - vision_config = VisionConfig(**config_dict["vision_config"]) - - model = ClipVisionModel(vision_config) - - weight_files = glob.glob(str(path / "*.safetensors")) - if not weight_files: - logging.error(f"No safetensors found in {path}") - raise FileNotFoundError(f"No safetensors found in {path}") - weights = {} - for wf in weight_files: - weights.update(mx.load(wf)) +class VisionModel(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.vision_model = ClipVisionModel(config) - weights = model.sanitize(weights) - model.load_weights(list(weights.items())) - model.load_weights(weights) - return model + def __call__( + self, x: mx.array, output_hidden_states: Optional[bool] = None + ) -> mx.array: + return self.vision_model(x, output_hidden_states) 
@staticmethod def sanitize(weights): From b9aeadea28929f712213fa9db8327d05c84da42a Mon Sep 17 00:00:00 2001 From: anchen Date: Sun, 25 Feb 2024 00:40:40 +1100 Subject: [PATCH 4/5] chore: refactor generate script --- llava/generate.py | 132 ++++++++++++++++++++++++++++++++++------------ llava/llava.py | 2 +- 2 files changed, 100 insertions(+), 34 deletions(-) diff --git a/llava/generate.py b/llava/generate.py index a51fe2cc9..df92c97f7 100644 --- a/llava/generate.py +++ b/llava/generate.py @@ -1,3 +1,6 @@ +import argparse +import os + import mlx.core as mx import mlx.nn as nn import requests @@ -6,53 +9,116 @@ from llava import LlavaModel -MODEL_PATH = "models/llava-hf/llava-1.5-7b-hf" -prompt = "USER: \nWhat are these?\nASSISTANT:" -image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" -raw_image = Image.open(requests.get(image_file, stream=True).raw) +def parse_arguments(): + parser = argparse.ArgumentParser( + description="Generate text from an image using a model." + ) + parser.add_argument( + "--model", + type=str, + default="models/llava-hf/llava-1.5-7b-hf", + help="Path to the model directory.", + ) + parser.add_argument( + "--image", + type=str, + default="http://images.cocodataset.org/val2017/000000039769.jpg", + help="URL or path of the image to process.", + ) + parser.add_argument( + "--prompt", + type=str, + default="USER: \nWhat are these?\nASSISTANT:", + help="Prompt to use for the model.", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=100, + help="Maximum number of tokens to generate.", + ) + parser.add_argument( + "--temperature", type=float, default=0.3, help="Temperature for sampling." + ) + return parser.parse_args() + + +def load_image(image_source): + if image_source.startswith(("http://", "https://")): + try: + response = requests.get(image_source, stream=True) + response.raise_for_status() + return Image.open(response.raw) + except requests.HTTPError as e: + print(f"Failed to load image from URL: {e}") + return None + elif os.path.isfile(image_source): + try: + return Image.open(image_source) + except IOError as e: + print(f"Failed to load image from path: {e}") + return None + else: + print("The image source is neither a valid URL nor a file path.") + return None -processor = AutoProcessor.from_pretrained(MODEL_PATH) -model = LlavaModel.from_pretrained(MODEL_PATH) +def initialize_model(model_path): + processor = AutoProcessor.from_pretrained(model_path) + model = LlavaModel.from_pretrained(model_path) + return processor, model -values = processor(prompt, raw_image, return_tensors="np") -pixel_values = mx.array(values["pixel_values"]) -input_ids = mx.array(values["input_ids"]) -input_embeds = model(input_ids, pixel_values) -max_tokens = 100 -temperature = 0.3 +def prepare_inputs(processor, image, prompt): + inputs = processor(prompt, image, return_tensors="np") + pixel_values = mx.array(inputs["pixel_values"]) + input_ids = mx.array(inputs["input_ids"]) + return input_ids, pixel_values -def sample(logits, temp=0.0): - if temp == 0: +def sample(logits, temperature=0.0): + if temperature == 0: return mx.argmax(logits, axis=-1) else: - return mx.random.categorical(logits * (1 / temp)) + return mx.random.categorical(logits * (1 / temperature)) -def generate(y: mx.array, model: nn.Module, temp: float = 0.0, cache=None): - while True: - logits, cache = model(y[None], cache=cache) - logits = logits[:, -1, :] +def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature): + input_embeds = 
+    logits, cache = model.language_model(
+        input_ids, cache=None, inputs_embeds=input_embeds
+    )
+    logits = logits[:, -1, :]
+    y = sample(logits, temperature=temperature)
+    tokens = [y.item()]
 
-        y = sample(logits, temp=temp)
+    for _ in range(max_tokens):
+        logits, cache = model.language_model(y[None], cache=cache)
+        logits = logits[:, -1, :]
+        y = sample(logits, temperature)
         token = y.item()
+        if token == processor.tokenizer.eos_token_id:
+            break
+        tokens.append(token)
+
+    return processor.tokenizer.decode(tokens)
+
-        yield token
 
+def main():
+    args = parse_arguments()
+    raw_image = load_image(args.image)
+    if raw_image is None:
+        return
+    processor, model = initialize_model(args.model)
+    input_ids, pixel_values = prepare_inputs(processor, raw_image, args.prompt)
+    print(args.prompt)
+    generated_text = generate_text(
+        input_ids, pixel_values, model, processor, args.max_tokens, args.temperature
+    )
+    print(generated_text)
 
-logits, cache = model.language_model(input_ids, cache=None, inputs_embeds=input_embeds)
-logits = logits[:, -1, :]
-y = sample(logits, temp=temperature)
-tokens = [y.item()]
-for token, _ in zip(
-    generate(y, model.language_model, temperature, cache=cache),
-    range(max_tokens),
-):
-    if token == processor.tokenizer.eos_token_id:
-        break
-    tokens.append(token)
-print(processor.tokenizer.decode(tokens))
+if __name__ == "__main__":
+    main()
diff --git a/llava/llava.py b/llava/llava.py
index 01e76122a..4af5783d7 100644
--- a/llava/llava.py
+++ b/llava/llava.py
@@ -61,7 +61,7 @@ def __init__(self, config: LlaVAConfig):
         self.vision_feature_layer = config.vision_feature_layer
         self.vision_feature_select_strategy = config.vision_feature_select_strategy
 
-    def __call__(
+    def get_input_embeddings(
         self,
         input_ids: Optional[mx.array] = None,
         pixel_values: Optional[mx.array] = None,
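Note: after this patch LlavaModel no longer embeds and decodes in a single __call__; generation becomes a two-step flow, a multimodal prefill through get_input_embeddings followed by cached decoding through language_model. A condensed sketch of that loop, assuming model, processor, input_ids, and pixel_values are prepared exactly as in generate.py above (the temperature and max-token settings are just example values):

    # Condensed version of generate_text from llava/generate.py above.
    input_embeds = model.get_input_embeddings(input_ids, pixel_values)
    logits, cache = model.language_model(input_ids, cache=None, inputs_embeds=input_embeds)
    y = sample(logits[:, -1, :], temperature=0.0)

    tokens = [y.item()]
    for _ in range(100):  # max_tokens
        logits, cache = model.language_model(y[None], cache=cache)
        y = sample(logits[:, -1, :], temperature=0.0)
        if y.item() == processor.tokenizer.eos_token_id:
            break
        tokens.append(y.item())
    print(processor.tokenizer.decode(tokens))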
From d8f7b895e1bf197cbfe965e209d5b559da593f37 Mon Sep 17 00:00:00 2001
From: anchen
Date: Sun, 25 Feb 2024 01:00:56 +1100
Subject: [PATCH 5/5] chore: clean up

---
 llava/.gitignore  | 162 ----------------------------------------------
 llava/generate.py |  14 ++--
 llava/test.py     |  20 +++---
 llava/utils.py    |  31 +++++++++
 4 files changed, 49 insertions(+), 178 deletions(-)
 delete mode 100644 llava/.gitignore
 create mode 100644 llava/utils.py

diff --git a/llava/.gitignore b/llava/.gitignore
deleted file mode 100644
index bc0a54fe8..000000000
--- a/llava/.gitignore
+++ /dev/null
@@ -1,162 +0,0 @@
-**mlx_model# Byte-compiled / optimized / DLL files
-__pycache__/
-*.py[cod]
-*$py.class
-
-# C extensions
-*.so
-
-# Distribution / packaging
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-share/python-wheels/
-*.egg-info/
-.installed.cfg
-*.egg
-MANIFEST
-
-# PyInstaller
-# Usually these files are written by a python script from a template
-# before PyInstaller builds the exe, so as to inject date/other infos into it.
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
-# Unit test / coverage reports
-htmlcov/
-.tox/
-.nox/
-.coverage
-.coverage.*
-.cache
-nosetests.xml
-coverage.xml
-*.cover
-*.py,cover
-.hypothesis/
-.pytest_cache/
-cover/
-
-# Translations
-*.mo
-*.pot
-
-# Django stuff:
-*.log
-local_settings.py
-db.sqlite3
-db.sqlite3-journal
-
-# Flask stuff:
-instance/
-.webassets-cache
-
-# Scrapy stuff:
-.scrapy
-
-# Sphinx documentation
-docs/_build/
-
-# PyBuilder
-.pybuilder/
-target/
-
-# Jupyter Notebook
-.ipynb_checkpoints
-
-# IPython
-profile_default/
-ipython_config.py
-
-# pyenv
-# For a library or package, you might want to ignore these files since the code is
-# intended to run in multiple environments; otherwise, check them in:
-# .python-version
-
-# pipenv
-# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
-# However, in case of collaboration, if having platform-specific dependencies or dependencies
-# having no cross-platform support, pipenv may install dependencies that don't work, or not
-# install all needed dependencies.
-#Pipfile.lock
-
-# poetry
-# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
-# This is especially recommended for binary packages to ensure reproducibility, and is more
-# commonly ignored for libraries.
-# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
-#poetry.lock
-
-# pdm
-# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
-#pdm.lock
-# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
-# in version control.
-# https://pdm.fming.dev/#use-with-ide
-.pdm.toml
-
-# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
-__pypackages__/
-
-# Celery stuff
-celerybeat-schedule
-celerybeat.pid
-
-# SageMath parsed files
-*.sage.py
-
-# Environments
-.env
-.venv
-env/
-venv/
-ENV/
-env.bak/
-venv.bak/
-
-# Spyder project settings
-.spyderproject
-.spyproject
-
-# Rope project settings
-.ropeproject
-
-# mkdocs documentation
-/site
-
-# mypy
-.mypy_cache/
-.dmypy.json
-dmypy.json
-
-# Pyre type checker
-.pyre/
-
-# pytype static type analyzer
-.pytype/
-
-# Cython debug symbols
-cython_debug/
-
-# PyCharm
-# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
-# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
-# and can be added to the global gitignore or merged into this file. For a more nuclear
-# option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
-
-models
\ No newline at end of file
diff --git a/llava/generate.py b/llava/generate.py
index df92c97f7..c52645818 100644
--- a/llava/generate.py
+++ b/llava/generate.py
@@ -6,6 +6,7 @@
 import requests
 from PIL import Image
 from transformers import AutoProcessor
+from utils import get_model_path
 
 from llava import LlavaModel
 
@@ -17,8 +18,8 @@ def parse_arguments():
     parser.add_argument(
         "--model",
         type=str,
-        default="models/llava-hf/llava-1.5-7b-hf",
-        help="Path to the model directory.",
+        default="llava-hf/llava-1.5-7b-hf",
+        help="The path to the local model directory or Hugging Face repo.",
     )
     parser.add_argument(
         "--image",
@@ -30,7 +31,7 @@ def parse_arguments():
         "--prompt",
         type=str,
         default="USER: \nWhat are these?\nASSISTANT:",
-        help="Prompt to use for the model.",
+        help="Message to be processed by the model.",
     )
     parser.add_argument(
         "--max-tokens",
@@ -39,7 +40,7 @@ def parse_arguments():
         help="Maximum number of tokens to generate.",
     )
     parser.add_argument(
-        "--temperature", type=float, default=0.3, help="Temperature for sampling."
+        "--temp", type=float, default=0.3, help="Temperature for sampling."
     )
     return parser.parse_args()
 
@@ -66,7 +67,8 @@ def load_image(image_source):
 
 def initialize_model(model_path):
     processor = AutoProcessor.from_pretrained(model_path)
-    model = LlavaModel.from_pretrained(model_path)
+
+    model = LlavaModel.from_pretrained(get_model_path(model_path))
     return processor, model
 
@@ -115,7 +117,7 @@ def main():
     input_ids, pixel_values = prepare_inputs(processor, raw_image, args.prompt)
     print(args.prompt)
     generated_text = generate_text(
-        input_ids, pixel_values, model, processor, args.max_tokens, args.temperature
+        input_ids, pixel_values, model, processor, args.max_tokens, args.temp
     )
     print(generated_text)
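Note: with get_model_path in place, the helpers in generate.py accept either a local directory or a Hugging Face repo id. A small programmatic sketch follows, assuming it is run from the llava/ directory and that the local image path shown is a hypothetical stand-in:

    # Sketch only: drives the refactored helpers from llava/generate.py directly.
    from PIL import Image

    from generate import generate_text, initialize_model, prepare_inputs

    processor, model = initialize_model("llava-hf/llava-1.5-7b-hf")
    image = Image.open("cats.jpg")  # hypothetical local image
    prompt = "USER: \nWhat are these?\nASSISTANT:"
    input_ids, pixel_values = prepare_inputs(processor, image, prompt)
    print(generate_text(input_ids, pixel_values, model, processor, 50, 0.0))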
diff --git a/llava/test.py b/llava/test.py
index 4cca0c9af..64652324f 100644
--- a/llava/test.py
+++ b/llava/test.py
@@ -6,14 +6,18 @@
 import torch
 from PIL import Image
 from transformers import AutoProcessor, LlavaForConditionalGeneration
+from utils import get_model_path
 
 from llava import LlavaModel
 
-MODEL_PATH = "models/llava-hf/llava-1.5-7b-hf"
+MODEL_PATH = "llava-hf/llava-1.5-7b-hf"
+PROMPT = "USER: \nWhat are these?\nASSISTANT:"
+IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg"
 
 
 def load_mlx_models(path):
-    model = LlavaModel.from_pretrained(path)
+    model_path = get_model_path(path)
+    model = LlavaModel.from_pretrained(model_path)
     model.eval()
     return model
 
@@ -32,12 +36,10 @@ def setUpClass(cls):
         cls.proc = AutoProcessor.from_pretrained(MODEL_PATH)
 
     def test_image_features(self):
-        prompt = "USER: \nWhat are these?\nASSISTANT:"
-        image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        raw_image = Image.open(requests.get(image_file, stream=True).raw)
+        raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw)
         vision_feature_layer = -2
         with torch.no_grad():
-            pixel_values = self.proc(prompt, raw_image, return_tensors="pt")[
+            pixel_values = self.proc(PROMPT, raw_image, return_tensors="pt")[
                 "pixel_values"
             ]
 
@@ -73,12 +75,10 @@ def test_image_features(self):
             )
 
     def test_merge_input_ids_with_image_features(self):
-        prompt = "USER: \nWhat are these?\nASSISTANT:"
-        image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        raw_image = Image.open(requests.get(image_file, stream=True).raw)
+        raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw)
         vision_feature_layer = -2
         with torch.no_grad():
-            values = self.proc(prompt, raw_image, return_tensors="pt")
+            values = self.proc(PROMPT, raw_image, return_tensors="pt")
             pixel_values = values["pixel_values"]
             input_ids = values["input_ids"]
 
diff --git a/llava/utils.py b/llava/utils.py
new file mode 100644
index 000000000..0514b12d5
--- /dev/null
+++ b/llava/utils.py
@@ -0,0 +1,31 @@
+from pathlib import Path
+
+from huggingface_hub import snapshot_download
+
+
+def get_model_path(path_or_hf_repo: str) -> Path:
+    """
+    Ensures the model is available locally. If the path does not exist locally,
+    it is downloaded from the Hugging Face Hub.
+
+    Args:
+        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
+
+    Returns:
+        Path: The path to the model.
+    """
+    model_path = Path(path_or_hf_repo)
+    if not model_path.exists():
+        model_path = Path(
+            snapshot_download(
+                repo_id=path_or_hf_repo,
+                allow_patterns=[
+                    "*.json",
+                    "*.safetensors",
+                    "*.py",
+                    "tokenizer.model",
+                    "*.tiktoken",
+                ],
+            )
+        )
+    return model_path
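Note: get_model_path is what lets both generate.py and test.py take either a local checkout or a repo id. A quick illustration of its two behaviors; the local directory below is assumed to exist, and the repo id resolves to wherever huggingface_hub keeps its snapshot cache:

    # Illustration of llava/utils.py (sketch; the paths are assumptions).
    from utils import get_model_path

    local = get_model_path("models/llava-hf/llava-1.5-7b-hf")  # existing directory is returned as-is
    remote = get_model_path("llava-hf/llava-1.5-7b-hf")        # repo id triggers snapshot_download
    print(local, remote)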