From cec0639edabc889d0c42ccea35178fb5e82c627e Mon Sep 17 00:00:00 2001
From: anchen
Date: Fri, 23 Feb 2024 20:43:40 +1100
Subject: [PATCH 1/5] feat: add image processor for llava processor
---
llava/.gitignore | 163 +++++++++++++++++++++++++++++++-
llava/download.py | 54 +++++++++++
llava/image_processor.py | 93 ++++++++++++++++++
llava/mlx_model/mlx_config.json | 36 -------
llava/processing_llava.py | 23 +++++
llava/test.py | 50 ++++++++++
6 files changed, 382 insertions(+), 37 deletions(-)
create mode 100644 llava/download.py
create mode 100644 llava/image_processor.py
delete mode 100644 llava/mlx_model/mlx_config.json
create mode 100644 llava/processing_llava.py
create mode 100644 llava/test.py
diff --git a/llava/.gitignore b/llava/.gitignore
index 857540df8..bc0a54fe8 100644
--- a/llava/.gitignore
+++ b/llava/.gitignore
@@ -1 +1,162 @@
-**mlx_model
\ No newline at end of file
+**mlx_model
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+models
\ No newline at end of file
diff --git a/llava/download.py b/llava/download.py
new file mode 100644
index 000000000..e755896bb
--- /dev/null
+++ b/llava/download.py
@@ -0,0 +1,54 @@
+import argparse
+import os
+
+import requests
+from tqdm import tqdm
+
+
+def download_file(url, path):
+ response = requests.get(url, stream=True)
+ total_size_in_bytes = int(response.headers.get("content-length", 0))
+ block_size = 1024 # 1 Kbyte
+ progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
+
+ with open(path, "wb") as file:
+ for data in response.iter_content(block_size):
+ progress_bar.update(len(data))
+ file.write(data)
+
+ progress_bar.close()
+
+
+def download_model(model_name, destination_folder="models"):
+ # Define the base URL and headers for the Hugging Face API
+ base_url = f"https://huggingface.co/{model_name}/resolve/main"
+ headers = {"User-Agent": "Hugging Face Python"}
+
+ # Send a GET request to the Hugging Face API to get a list of all files
+ response = requests.get(
+ f"https://huggingface.co/api/models/{model_name}", headers=headers
+ )
+ response.raise_for_status()
+
+ # Extract the list of files from the response JSON
+ files_to_download = [
+ file["rfilename"]
+ for file in response.json()["siblings"]
+ if not file["rfilename"].endswith(".bin")
+ ]
+
+ # Ensure the directory exists
+ os.makedirs(f"{destination_folder}/{model_name}", exist_ok=True)
+
+ # Download each file
+ for file in files_to_download:
+ print(f"Downloading {file}...")
+ download_file(f"{base_url}/{file}", f"{destination_folder}/{model_name}/{file}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("model_name", type=str, help="Name of the model to download.")
+ args = parser.parse_args()
+
+ download_model(args.model_name)
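For reference, a minimal sketch of how this helper might be used, from the command line or programmatically; the repo id below is taken from the test paths later in this series and is otherwise an assumption:

    # python download.py llava-hf/llava-1.5-7b-hf
    from download import download_model

    # Fetches every non-*.bin file of the repo into models/llava-hf/llava-1.5-7b-hf/.
    download_model("llava-hf/llava-1.5-7b-hf")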
diff --git a/llava/image_processor.py b/llava/image_processor.py
new file mode 100644
index 000000000..5f5be8484
--- /dev/null
+++ b/llava/image_processor.py
@@ -0,0 +1,93 @@
+# Copyright © 2023-2024 Apple Inc.
+
+import json
+from pathlib import Path
+from typing import List, Tuple
+
+import mlx.core as mx
+import numpy as np
+from PIL.Image import Image
+
+
+class CLIPImageProcessor:
+ """
+ A simple port of
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/image_processing_clip.py.
+ """
+
+ def __init__(
+ self,
+ crop_size: int = 336,
+ do_center_crop: bool = True,
+ do_normalize: bool = True,
+ do_resize: bool = True,
+ image_mean: List[float] = [0.48145466, 0.4578275, 0.40821073],
+ image_std: List[float] = [0.26862954, 0.26130258, 0.27577711],
+ size: int = 336,
+ **kwargs
+ ) -> None:
+ self.crop_size = crop_size
+ self.do_center_crop = do_center_crop
+ self.do_normalize = do_normalize
+ self.do_resize = do_resize
+ self.image_mean = mx.array(image_mean)
+ self.image_std = mx.array(image_std)
+ self.size = size
+
+ def __call__(self, images: List[Image]) -> mx.array:
+ return mx.concatenate(
+ [self._preprocess(image)[None] for image in images], axis=0
+ )
+
+ def _preprocess(self, image: Image) -> mx.array:
+ if self.do_resize:
+ image = resize(image, self.size)
+ if self.do_center_crop:
+ image = center_crop(image, (self.crop_size, self.crop_size))
+ image = mx.array(np.array(image))
+ image = rescale(image)
+ if self.do_normalize:
+ image = normalize(image, self.image_mean, self.image_std)
+ return image
+
+ @staticmethod
+ def from_pretrained(path: str):
+ path = Path(path)
+ with open(path / "preprocessor_config.json", encoding="utf-8") as f:
+ config = json.load(f)
+ return CLIPImageProcessor(**config)
+
+
+def resize(image: Image, short_size: int) -> Image:
+ """
+ Resize the image so that its shorter side equals short_size, preserving aspect ratio.
+ """
+ width, height = image.size
+ short = min(width, height)
+ long = max(width, height)
+ if short == short_size:
+ return image
+ new_short = short_size
+ new_long = int(short_size * long / short)
+ new_size = (new_short, new_long) if width <= height else (new_long, new_short)
+ return image.resize(new_size)
+
+
+def center_crop(image: Image, size: Tuple[int, int]) -> Image:
+ if size[0] % 2 != 0 or size[1] % 2 != 0:
+ raise ValueError("Only even crop sizes supported.")
+ original_width, original_height = image.size
+ crop_height, crop_width = size
+ top = (original_height - crop_height) // 2
+ bottom = top + crop_height
+ left = (original_width - crop_width) // 2
+ right = left + crop_width
+ return image.crop((left, top, right, bottom))
+
+
+def rescale(image: mx.array) -> mx.array:
+ return image.astype(mx.float32) * (1 / 255.0)
+
+
+def normalize(image: mx.array, mean: mx.array, std: mx.array) -> mx.array:
+ return (image - mean) / std
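A minimal usage sketch of the processor above (the image path is an assumption); the output is a float32 array in NHWC layout:

    from PIL import Image
    from image_processor import CLIPImageProcessor

    processor = CLIPImageProcessor()   # defaults: resize shorter side to 336, center-crop 336x336
    pixels = processor([Image.open("example.jpg").convert("RGB")])
    print(pixels.shape)                # (1, 336, 336, 3), normalized with the CLIP mean/std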
diff --git a/llava/mlx_model/mlx_config.json b/llava/mlx_model/mlx_config.json
deleted file mode 100644
index 482ec26ca..000000000
--- a/llava/mlx_model/mlx_config.json
+++ /dev/null
@@ -1,36 +0,0 @@
-{
- "language_model": {
- "hidden_size": 4096,
- "num_hidden_layers": 32,
- "intermediate_size": 11008,
- "num_attention_heads": 32,
- "rms_norm_eps": 1e-5,
- "vocab_size": 32064,
- "num_key_value_heads": 32,
- "rope_theta": 0,
- "rope_traditional": false,
- "rope_scaling": null
- },
-
- "vision_tower": {
- "num_hidden_layers": 24,
- "hidden_size": 1024,
- "intermediate_size": 4096,
- "num_attention_heads": 16,
- "num_channels": 3,
- "image_size": 336,
- "patch_size": 14
- },
-
- "multi_modal_projector": {
- "in_features": 1024,
- "out_features": 4096
- },
-
- "vision_feature_layer": -2,
- "vision_feature_selection_strategy": "default",
- "image_token_index": 32000,
- "pad_token_id": 32001,
- "tie_word_embeddings": false,
- "vocab_size": 32064
-}
\ No newline at end of file
diff --git a/llava/processing_llava.py b/llava/processing_llava.py
new file mode 100644
index 000000000..705d1ccf6
--- /dev/null
+++ b/llava/processing_llava.py
@@ -0,0 +1,23 @@
+from image_processor import CLIPImageProcessor
+
+
+class LlavaProcessor:
+ def __init__(self, image_processor=None, tokenizer=None):
+ self.image_processor = image_processor or CLIPImageProcessor()
+ self.tokenizer = tokenizer
+
+ def __call__(
+ self,
+ text=None,
+ images=None,
+ padding=False,
+ truncation=None,
+ max_length=None,
+ return_tensors=None,
+ ):
+ if images is not None:
+ pixel_values = self.image_processor(images)
+ else:
+ pixel_values = None
+
+ # NOTE: text inputs are not tokenized yet; only pixel values are returned.
+ return {"pixel_values": pixel_values}
diff --git a/llava/test.py b/llava/test.py
new file mode 100644
index 000000000..2ad128f9f
--- /dev/null
+++ b/llava/test.py
@@ -0,0 +1,50 @@
+import unittest
+
+import mlx.core as mx
+import numpy as np
+import requests
+import torch
+from PIL import Image
+from processing_llava import LlavaProcessor
+from transformers import AutoProcessor, LlavaForConditionalGeneration
+
+MLX_PATH = "models/llava-hf/llava-1.5-7b-hf"
+HF_PATH = "models/llava-hf/llava-1.5-7b-hf"
+
+
+def load_mlx_models(path):
+ processor = LlavaProcessor()
+ return processor, None
+
+
+def load_hf_models(path):
+ processor = AutoProcessor.from_pretrained(path)
+ model = LlavaForConditionalGeneration.from_pretrained(path)
+
+ return processor, model
+
+
+class TestCLIP(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.mx_proc, cls.mx_llava = load_mlx_models(MLX_PATH)
+ cls.hf_proc, cls.hf_llava = load_hf_models(HF_PATH)
+
+ def test_processor(self):
+ prompt = "USER: \nWhat are these?\nASSISTANT:"
+ image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ raw_image = Image.open(requests.get(image_file, stream=True).raw)
+
+ hf_data = mx.array(
+ np.array(
+ self.hf_proc(prompt, raw_image, return_tensors="pt")["pixel_values"]
+ )
+ ).transpose(0, 2, 3, 1)
+
+ mx_data = self.mx_proc(prompt, [raw_image])["pixel_values"]
+
+ self.assertTrue(mx.allclose(mx_data, hf_data, atol=1e-5))
+
+
+if __name__ == "__main__":
+ unittest.main()
From 4dd8bca0279a6548fe495df0eba34fa36f429c30 Mon Sep 17 00:00:00 2001
From: anchen
Date: Sat, 24 Feb 2024 15:19:15 +1100
Subject: [PATCH 2/5] wip
---
llava/clip.py | 409 +++++++++++++++++++++----------------------------
llava/llama.py | 58 +++----
llava/llava.py | 134 +++++++---------
llava/test.py | 29 ++--
4 files changed, 267 insertions(+), 363 deletions(-)
diff --git a/llava/clip.py b/llava/clip.py
index 1568858a7..736307f31 100644
--- a/llava/clip.py
+++ b/llava/clip.py
@@ -1,65 +1,39 @@
-# Copyright © 2023-2024 Apple Inc.
-
+import glob
+import inspect
import json
+import logging
+import math
from dataclasses import dataclass
from pathlib import Path
-from typing import Any, Optional
+from typing import Optional
import mlx.core as mx
import mlx.nn as nn
-from mlx.core import linalg as LA
-from mlx.nn.losses import cross_entropy
-from mlx.utils import tree_flatten
-
-
-@dataclass
-class CLIPVisionOutput:
- pooler_output: mx.array
- last_hidden_state: mx.array
- llava_hidden_state: mx.array
-
-
-@dataclass
-class CLIPTextOutput:
- pooler_output: mx.array
- last_hidden_state: mx.array
-
-
-@dataclass
-class CLIPModelOutput:
- loss: Optional[mx.array]
- text_embeds: Optional[mx.array]
- image_embeds: Optional[mx.array]
- text_model_output: CLIPTextOutput
- vision_model_output: CLIPVisionOutput
-
-
-@dataclass
-class CLIPTextConfig:
- num_hidden_layers: int
- hidden_size: int
- intermediate_size: int
- num_attention_heads: int
- max_position_embeddings: int
- vocab_size: int
-
-
-@dataclass
-class CLIPVisionConfig:
- num_hidden_layers: int
- hidden_size: int
- intermediate_size: int
- num_attention_heads: int
- num_channels: int
- image_size: int
- patch_size: int
@dataclass
-class CLIPConfig:
- text_config: CLIPTextConfig
- vision_config: CLIPVisionConfig
- projection_dim: int
+class VisionConfig:
+ model_type: str
+ num_hidden_layers: int = 24
+ hidden_size: int = 1024
+ intermediate_size: int = 4096
+ num_attention_heads: int = 16
+ image_size: int = 336
+ patch_size: int = 14
+ projection_dim: int = 768
+ vocab_size: int = 32000
+ num_channels: int = 3
+ layer_norm_eps: float = 1e-5
+
+ @classmethod
+ def from_dict(cls, params):
+ return cls(
+ **{
+ k: v
+ for k, v in params.items()
+ if k in inspect.signature(cls).parameters
+ }
+ )
def quick_gelu(x: mx.array) -> mx.array:
@@ -69,227 +43,196 @@ def quick_gelu(x: mx.array) -> mx.array:
return x * mx.sigmoid(1.702 * x)
-def clip_loss(logits: mx.array) -> mx.array:
- N, M = logits.shape
- caption_loss = cross_entropy(logits, mx.arange(N), reduction="mean")
- image_loss = cross_entropy(logits.T, mx.arange(M), reduction="mean")
- return (caption_loss + image_loss) / 2.0
-
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dims: int,
+ num_heads: int,
+ query_input_dims: Optional[int] = None,
+ key_input_dims: Optional[int] = None,
+ value_input_dims: Optional[int] = None,
+ value_dims: Optional[int] = None,
+ value_output_dims: Optional[int] = None,
+ bias: bool = False,
+ ):
+ super().__init__()
-class CLIPEncoderLayer(nn.TransformerEncoderLayer):
- """The transformer encoder layer from CLIP."""
+ if (dims % num_heads) != 0:
+ raise ValueError(
+ "The input feature dimensions should be divisible by the "
+ f"number of heads ({dims} % {num_heads}) != 0"
+ )
- def __init__(self, hidden_dim: int, intermediate_dim: int, num_heads: int):
- super().__init__(
- dims=hidden_dim,
- mlp_dims=intermediate_dim,
- num_heads=num_heads,
- activation=quick_gelu,
- norm_first=True,
- )
- # Add biases to the attention projections
- self.attention = nn.MultiHeadAttention(
- hidden_dim, num_heads, bias=True)
+ query_input_dims = query_input_dims or dims
+ key_input_dims = key_input_dims or dims
+ value_input_dims = value_input_dims or key_input_dims
+ value_dims = value_dims or dims
+ value_output_dims = value_output_dims or dims
+
+ self.num_heads = num_heads
+ self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
+ self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
+ self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
+ self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
+
+ def __call__(self, queries, keys, values, mask=None):
+ queries = self.q_proj(queries)
+ keys = self.k_proj(keys)
+ values = self.v_proj(values)
+
+ num_heads = self.num_heads
+ B, L, D = queries.shape
+ _, S, _ = keys.shape
+ queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+ keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
+ values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+ scale = math.sqrt(1 / queries.shape[-1])
+ scores = (queries * scale) @ keys
+ if mask is not None:
+ scores = scores + mask.astype(scores.dtype)
+ scores = mx.softmax(scores, axis=-1)
+ values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+ return self.out_proj(values_hat)
+
+
+class MLP(nn.Module):
+ def __init__(self, config: VisionConfig):
+ super().__init__()
+ self.activation_fn = quick_gelu
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+ def __call__(self, x: mx.array) -> mx.array:
+ x = self.activation_fn(self.fc1(x))
+ x = self.fc2(x)
+ return x
-class CLIPTextModel(nn.Module):
- """Implements the text encoder transformer from CLIP."""
- def __init__(self, config: CLIPTextConfig):
+class EncoderLayer(nn.Module):
+ def __init__(self, config: VisionConfig):
super().__init__()
-
- self.token_embedding = nn.Embedding(
- config.vocab_size, config.hidden_size)
- self.position_embedding = mx.zeros(
- (config.max_position_embeddings, config.hidden_size)
+ self.embed_dim = config.hidden_size
+ self.self_attn = Attention(
+ config.hidden_size, config.num_attention_heads, bias=True
)
- self.layers = [
- CLIPEncoderLayer(
- config.hidden_size, config.intermediate_size, config.num_attention_heads
- )
- for _ in range(config.num_hidden_layers)
- ]
- self.final_layer_norm = nn.LayerNorm(config.hidden_size)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = MLP(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
- def _embed(self, x: mx.array) -> mx.array:
- embeddings = self.token_embedding(x)
- embeddings += self.position_embedding[: x.shape[1]]
- return embeddings
+ def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+ y = self.layer_norm1(x)
+ y = self.self_attn(y, y, y, mask)
+ x = x + y
+ y = self.layer_norm2(x)
+ y = self.mlp(y)
+ return x + y
- def __call__(self, x: mx.array) -> CLIPTextOutput:
- B, N = x.shape
- eot_tokens = mx.argmax(x, axis=-1)
- x = self._embed(x)
- mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype)
- for l in self.layers:
- x = l(x, mask)
- last_hidden_state = self.final_layer_norm(x)
- pooler_output = last_hidden_state[mx.arange(B), eot_tokens]
-
- return CLIPTextOutput(
- pooler_output=pooler_output, last_hidden_state=last_hidden_state
- )
+class Encoder(nn.Module):
+ def __init__(self, config: VisionConfig):
+ super().__init__()
+ self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
-class CLIPVisionModel(nn.Module):
- """Implements the vision encoder transformer from CLIP."""
- def __init__(self, config: CLIPVisionConfig):
+class VisionEmbeddings(nn.Module):
+ def __init__(self, config: VisionConfig):
super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
self.class_embedding = mx.zeros((config.hidden_size,))
+
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
- out_channels=config.hidden_size,
- kernel_size=config.patch_size,
- stride=config.patch_size,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
bias=False,
)
- num_patches = (config.image_size // config.patch_size) ** 2
- num_positions = num_patches + 1
- self.position_embedding = mx.zeros((num_positions, config.hidden_size))
- self.pre_layernorm = nn.LayerNorm(config.hidden_size)
- self.layers = [
- CLIPEncoderLayer(
- config.hidden_size, config.intermediate_size, config.num_attention_heads
- )
- for _ in range(config.num_hidden_layers)
- ]
- self.post_layernorm = nn.LayerNorm(config.hidden_size)
- def _embed(self, x: mx.array) -> mx.array:
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches + 1
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+ def __call__(self, x: mx.array) -> mx.array:
batch_size = x.shape[0]
- # Patchify using conv:
- # [batch_size, sqrt(num_patches), sqrt(num_patches), embed_dim]
patch_embeddings = self.patch_embedding(x)
- # [batch_size, num_patches, embed_dim]
- patch_embeddings = mx.flatten(
- patch_embeddings, start_axis=1, end_axis=2)
+ patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
embed_dim = patch_embeddings.shape[-1]
- # Prepend embeddings
- # [batch_size, 1, embed_dim]
cls_embeddings = mx.broadcast_to(
self.class_embedding, (batch_size, 1, embed_dim)
)
- # [batch_size, num_patches + 1, embed_dim]
embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
- # Add positional encoding
- embeddings += self.position_embedding
+ embeddings += self.position_embedding.weight
return embeddings
- def __call__(self, x: mx.array) -> CLIPVisionOutput:
- x = self._embed(x)
- x = self.pre_layernorm(x)
-
- for l in self.layers:
- x = l(x, mask=None)
-
- # Extract token embedding
- pooler_output = self.post_layernorm(x[:, 0, :])
-
- llava_hidden_state = x
- return CLIPVisionOutput(pooler_output=pooler_output, last_hidden_state=x, llava_hidden_state=llava_hidden_state)
-
-class CLIPModel(nn.Module):
- def __init__(self, config: CLIPConfig):
- self.text_model = CLIPTextModel(config.text_config)
- self.vision_model = CLIPVisionModel(config.vision_config)
-
- text_embed_dim = config.text_config.hidden_size
- vision_embed_dim = config.vision_config.hidden_size
- projection_dim = config.projection_dim
+class ClipVisionModel(nn.Module):
+ def __init__(self, config: VisionConfig):
+ super().__init__()
+ self.embeddings = VisionEmbeddings(config)
+ self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
+ self.encoder = Encoder(config)
+ self.post_layernorm = nn.LayerNorm(config.hidden_size)
- self.visual_projection = nn.Linear(
- vision_embed_dim, projection_dim, bias=False)
- self.text_projection = nn.Linear(
- text_embed_dim, projection_dim, bias=False)
- self.logit_scale = mx.array(0.0)
+ def __call__(
+ self,
+ x: mx.array,
+ output_hidden_states: Optional[bool] = None,
+ ) -> mx.array:
+ x = self.embeddings(x)
+ x = self.pre_layrnorm(x)
- def get_text_features(self, x: mx.array) -> mx.array:
- return self.text_projection(self.text_model(x).pooler_output)
+ encoder_states = (x,) if output_hidden_states else None
- def get_image_features(self, x: mx.array) -> mx.array:
- return self.visual_projection(self.vision_model(x).pooler_output)
+ for l in self.encoder.layers:
+ x = l(x, mask=None)
+ if output_hidden_states:
+ encoder_states = encoder_states + (x,)
- def __call__(
- self,
- input_ids: Optional[mx.array] = None,
- pixel_values: Optional[mx.array] = None,
- return_loss=False,
- ) -> CLIPModelOutput:
- if input_ids is not None:
- text_model_output = self.text_model(input_ids)
- text_embeds = self.text_projection(text_model_output.pooler_output)
- text_embeds = text_embeds / \
- LA.norm(text_embeds, axis=-1, keepdims=True)
- else:
- text_embeds = None
- text_model_output = None
-
- if pixel_values is not None:
- vision_model_output = self.vision_model(pixel_values)
- image_embeds = self.visual_projection(
- vision_model_output.pooler_output)
- image_embeds = image_embeds / \
- LA.norm(image_embeds, axis=-1, keepdims=True)
- else:
- image_embeds = None
- vision_model_output = None
-
- if return_loss and (input_ids is None or pixel_values is None):
- raise ValueError(
- "Must provide text and image inputs to compute loss.")
-
- if return_loss:
- logit_scale = mx.exp(self.logit_scale)
- logits = (text_embeds @ image_embeds.T) * logit_scale
- loss = clip_loss(logits)
- else:
- loss = None
-
- return CLIPModelOutput(
- loss=loss,
- text_embeds=text_embeds,
- image_embeds=image_embeds,
- vision_model_output=vision_model_output,
- text_model_output=text_model_output,
- )
+ pooler_output = self.post_layernorm(x[:, 0, :])
+ return pooler_output, x, encoder_states
@staticmethod
def from_pretrained(path: str):
path = Path(path)
with open(path / "config.json", "r") as fid:
- config = json.load(fid)
-
- text_config = config["text_config"]
- text_config = CLIPTextConfig(
- num_hidden_layers=text_config["num_hidden_layers"],
- hidden_size=text_config["hidden_size"],
- intermediate_size=text_config["intermediate_size"],
- num_attention_heads=text_config["num_attention_heads"],
- max_position_embeddings=text_config["max_position_embeddings"],
- vocab_size=text_config["vocab_size"],
- )
+ config_dict = json.load(fid)
+ vision_config = VisionConfig(**config_dict["vision_config"])
- vision_config = config["vision_config"]
+ model = ClipVisionModel(vision_config)
- vision_config = CLIPVisionConfig(
- num_hidden_layers=vision_config["num_hidden_layers"],
- hidden_size=vision_config["hidden_size"],
- intermediate_size=vision_config["intermediate_size"],
- num_attention_heads=vision_config["num_attention_heads"],
- num_channels=3,
- image_size=vision_config["image_size"],
- patch_size=vision_config["patch_size"],
- )
+ weight_files = glob.glob(str(path / "*.safetensors"))
+ if not weight_files:
+ logging.error(f"No safetensors found in {path}")
+ raise FileNotFoundError(f"No safetensors found in {path}")
- config = CLIPConfig(
- text_config=text_config,
- vision_config=vision_config,
- projection_dim=config["projection_dim"],
- )
- model = CLIPModel(config)
- model.load_weights(str(path / "weights.npz"))
+ weights = {}
+ for wf in weight_files:
+ weights.update(mx.load(wf))
+
+ weights = model.sanitize(weights)
+ model.load_weights(list(weights.items()))
return model
+
+ @staticmethod
+ def sanitize(weights):
+ sanitized_weights = {}
+ for k, v in weights.items():
+ if "position_ids" in k:
+ # Remove unused position_ids
+ continue
+ elif "patch_embedding.weight" in k:
+ # pytorch conv2d expects the weight tensor to be of shape [out_channels, in_channels, kH, kW]
+ # mlx conv2d expects the weight tensor to be of shape [out_channels, kH, KW, in_channels]
+ sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+ else:
+ sanitized_weights[k] = v
+
+ return sanitized_weights
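To illustrate the layout change sanitize() applies to patch_embedding.weight, a standalone sketch with shapes matching the vision defaults above (1024 output channels, 3 input channels, 14x14 patches):

    import mlx.core as mx

    w_torch_layout = mx.zeros((1024, 3, 14, 14))          # PyTorch: [out_channels, in_channels, kH, kW]
    w_mlx_layout = w_torch_layout.transpose(0, 2, 3, 1)   # MLX:     [out_channels, kH, kW, in_channels]
    print(w_mlx_layout.shape)                             # (1024, 14, 14, 3)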
diff --git a/llava/llama.py b/llava/llama.py
index ba5baeda5..242251a39 100644
--- a/llava/llama.py
+++ b/llava/llama.py
@@ -1,14 +1,25 @@
+import inspect
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
-import inspect
-
@dataclass
-class BaseModelArgs:
+class TextConfig:
+ model_type: str
+ hidden_size: int = 4096
+ num_hidden_layers: int = 32
+ intermediate_size: int = 11008
+ num_attention_heads: int = 32
+ rms_norm_eps: float = 1e-6
+ vocab_size: int = 32000
+ num_key_value_heads: Optional[int] = None
+ rope_theta: float = 10000
+ rope_traditional: bool = False
+ rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+
@classmethod
def from_dict(cls, params):
return cls(
@@ -19,21 +30,6 @@ def from_dict(cls, params):
}
)
-
-@dataclass
-class ModelArgs(BaseModelArgs):
- model_type: str
- hidden_size: int
- num_hidden_layers: int
- intermediate_size: int
- num_attention_heads: int
- rms_norm_eps: float
- vocab_size: int
- num_key_value_heads: int = None
- rope_theta: float = 10000
- rope_traditional: bool = False
- rope_scaling: Optional[Dict[str, Union[float, str]]] = None
-
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
@@ -41,12 +37,10 @@ def __post_init__(self):
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
- raise ValueError(
- f"rope_scaling must contain keys {required_keys}")
+ raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] != "linear":
- raise ValueError(
- "rope_scaling 'type' currently only supports 'linear'")
+ raise ValueError("rope_scaling 'type' currently only supports 'linear'")
class RMSNorm(nn.Module):
@@ -64,7 +58,7 @@ def __call__(self, x):
class Attention(nn.Module):
- def __init__(self, args: ModelArgs):
+ def __init__(self, args: TextConfig):
super().__init__()
dim = args.hidden_size
@@ -106,8 +100,7 @@ def __call__(
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
- values = values.reshape(
- B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+ values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if self.repeats > 1:
keys = mx.repeat(keys, self.repeats, axis=1)
@@ -126,8 +119,7 @@ def __call__(
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores += mask
- scores = mx.softmax(scores.astype(mx.float32),
- axis=-1).astype(scores.dtype)
+ scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)
@@ -144,15 +136,14 @@ def __call__(self, x) -> mx.array:
class TransformerBlock(nn.Module):
- def __init__(self, args: ModelArgs):
+ def __init__(self, args: TextConfig):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
- self.post_attention_layernorm = RMSNorm(
- args.hidden_size, eps=args.rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.args = args
def __call__(
@@ -169,7 +160,7 @@ def __call__(
class Llama(nn.Module):
- def __init__(self, args: ModelArgs):
+ def __init__(self, args: TextConfig):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
@@ -190,8 +181,7 @@ def __call__(
mask = None
if h.shape[1] > 1:
- mask = nn.MultiHeadAttention.create_additive_causal_mask(
- h.shape[1])
+ mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
if cache is None:
@@ -204,7 +194,7 @@ def __call__(
class LlamaModel(nn.Module):
- def __init__(self, args: ModelArgs):
+ def __init__(self, args: TextConfig):
super().__init__()
self.model_type = args.model_type
self.model = Llama(args)
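A small sketch of how TextConfig.from_dict behaves (the extra key is an assumption for illustration): unknown keys are dropped, and __post_init__ backfills num_key_value_heads:

    from llama import TextConfig

    cfg = TextConfig.from_dict({
        "model_type": "llama",
        "hidden_size": 4096,
        "num_attention_heads": 32,
        "some_unrelated_key": True,   # not a dataclass field, silently ignored
    })
    print(cfg.num_key_value_heads)    # 32, copied from num_attention_heads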
diff --git a/llava/llava.py b/llava/llava.py
index 34eabb9c6..4edfbeaca 100644
--- a/llava/llava.py
+++ b/llava/llava.py
@@ -1,56 +1,41 @@
-from clip import CLIPVisionModel
-from llama import LlamaModel
-from pathlib import Path
+import glob
+import inspect
import json
-import mlx.nn as nn
-import mlx.core as mx
-from typing import Any, Optional, Dict, Union
-
-
+import logging
from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Optional
+import mlx.core as mx
+import mlx.nn as nn
+from llama import LlamaModel, TextConfig
-@dataclass
-class VisionConfig:
- num_hidden_layers: int
- hidden_size: int
- intermediate_size: int
- num_attention_heads: int
- num_channels: int
- image_size: int
- patch_size: int
-
-
-@dataclass
-class LLMConfig:
- model_type: str
- hidden_size: int
- num_hidden_layers: int
- intermediate_size: int
- num_attention_heads: int
- rms_norm_eps: float
- vocab_size: int
- num_key_value_heads: int
- rope_theta: float = 10000
- rope_traditional: bool = False
- rope_scaling: Optional[Dict[str, Union[float, str]]] = None
-
-
-@dataclass
-class ProjectionConfig:
- in_features: int
- out_features: int
+from clip import ClipVisionModel, VisionConfig
@dataclass
class LlaVAConfig:
- llm_config: Any
+ text_config: TextConfig
vision_config: VisionConfig
- projection_config: ProjectionConfig
+ ignore_index: int = -100
+ image_token_index: int = 32000
+ vision_feature_select_strategy: str = "default"
+ vision_feature_layer: int = -2
+ vocab_size: int = 32000
+
+ @classmethod
+ def from_dict(cls, params):
+ return cls(
+ **{
+ k: v
+ for k, v in params.items()
+ if k in inspect.signature(cls).parameters
+ }
+ )
class LlavaMultiModalProjector(nn.Module):
- def __init__(self, config: Any):
+ def __init__(self, config: LlaVAConfig):
super().__init__()
self.linear_1 = nn.Linear(config.in_features, config.out_features)
self.gelu = nn.GELU()
@@ -65,14 +50,19 @@ def forward(self, x: mx.array) -> mx.array:
class LlavaModel(nn.Module):
def __init__(self, config: LlaVAConfig):
- self.vision_tower = CLIPVisionModel(config=config.vision_config)
- self.language_model = LlamaModel(args=config.llm_config)
+ self.vision_tower = ClipVisionModel(
+ config=VisionConfig.from_dict(config.vision_config)
+ )
+ self.language_model = LlamaModel(args=TextConfig.from_dict(config.text_config))
self.multi_modal_projector = LlavaMultiModalProjector(
- config=config.projection_config)
+ config=config.projection_config
+ )
- def __call__(self,
- input_ids: Optional[mx.array] = None,
- pixel_values: Optional[mx.array] = None):
+ def __call__(
+ self,
+ input_ids: Optional[mx.array] = None,
+ pixel_values: Optional[mx.array] = None,
+ ):
# TODO: add the forward pass
if pixel_values is not None and input_ids.shape[1] != 1:
@@ -81,10 +71,11 @@ def __call__(self,
# TODO: this is not the correct output layer, but it's a placeholder
selected_image_feature = image_outputs.pooler_output
- image_features = self.multi_modal_projector(
- selected_image_feature)
+ image_features = self.multi_modal_projector(selected_image_feature)
- def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
+ def _merge_input_ids_with_image_features(
+ self, image_features, inputs_embeds, input_ids, attention_mask, labels
+ ):
# TODO: https://github.com/huggingface/transformers/blob/4f09d0fd888dbf2660313f9715992822acfb99ce/src/transformers/models/llava/modeling_llava.py#L279
special_image_token_mask = input_ids == self.config.special_tokens.image
@@ -97,39 +88,20 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
def from_pretrained(path: str):
path = Path(path)
- with open(path / "mlx_config.json", "r") as f:
+ with open(path / "config.json", "r") as f:
model_config = json.load(f)
- llava_mlx_config = LlaVAConfig(
- llm_config=LLMConfig(
- model_type='vicuna',
- hidden_size=model_config['language_model']['hidden_size'],
- num_hidden_layers=model_config['language_model']['num_hidden_layers'],
- intermediate_size=model_config['language_model']['intermediate_size'],
- num_attention_heads=model_config['language_model']['num_attention_heads'],
- rms_norm_eps=model_config['language_model']['rms_norm_eps'],
- vocab_size=model_config['language_model']['vocab_size'],
- num_key_value_heads=model_config['language_model']['num_key_value_heads'],
- rope_theta=model_config['language_model']['rope_theta'],
- rope_traditional=model_config['language_model']['rope_traditional'],
- rope_scaling=model_config['language_model']['rope_scaling'],
- ),
- vision_config=VisionConfig(
- num_hidden_layers=model_config['vision_tower']['num_hidden_layers'],
- hidden_size=model_config['vision_tower']['hidden_size'],
- intermediate_size=model_config['vision_tower']['intermediate_size'],
- num_attention_heads=model_config['vision_tower']['num_attention_heads'],
- num_channels=model_config['vision_tower']['num_channels'],
- image_size=model_config['vision_tower']['image_size'],
- patch_size=model_config['vision_tower']['patch_size'],
- ),
- projection_config=ProjectionConfig(
- in_features=model_config['multi_modal_projector']['in_features'],
- out_features=model_config['multi_modal_projector']['out_features'],
- )
- )
+ model_config = LlaVAConfig.from_dict(model_config)
+ model = LlavaModel(model_config)
+ weight_files = glob.glob(str(path / "*.safetensors"))
+ if not weight_files:
+ logging.error(f"No safetensors found in {path}")
+ raise FileNotFoundError(f"No safetensors found in {path}")
- model = LlavaModel(llava_mlx_config)
- model.load_weights(str(path / "weights.npz"))
+ weights = {}
+ for wf in weight_files:
+ weights.update(mx.load(wf))
+ weights = ClipVisionModel.sanitize(weights)
+ model.load_weights(list(weights.items()))
return model
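The projector wiring above is still incomplete at this point (it is finished in the next patch), but the intended entry point is the loader sketched here; the local path is an assumption and must contain config.json plus *.safetensors shards:

    from llava import LlavaModel

    model = LlavaModel.from_pretrained("models/llava-hf/llava-1.5-7b-hf")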
diff --git a/llava/test.py b/llava/test.py
index 2ad128f9f..0d693677e 100644
--- a/llava/test.py
+++ b/llava/test.py
@@ -3,32 +3,31 @@
import mlx.core as mx
import numpy as np
import requests
-import torch
from PIL import Image
from processing_llava import LlavaProcessor
from transformers import AutoProcessor, LlavaForConditionalGeneration
-MLX_PATH = "models/llava-hf/llava-1.5-7b-hf"
-HF_PATH = "models/llava-hf/llava-1.5-7b-hf"
+from llava import LlavaModel
+
+MODEL_PATH = "models/llava-hf/llava-1.5-7b-hf"
def load_mlx_models(path):
- processor = LlavaProcessor()
- return processor, None
+ model = LlavaModel.from_pretrained(path)
+ return model
def load_hf_models(path):
- processor = AutoProcessor.from_pretrained(path)
model = LlavaForConditionalGeneration.from_pretrained(path)
-
- return processor, model
+ return model
class TestCLIP(unittest.TestCase):
@classmethod
def setUpClass(cls):
- cls.mx_proc, cls.mx_llava = load_mlx_models(MLX_PATH)
- cls.hf_proc, cls.hf_llava = load_hf_models(HF_PATH)
+ cls.mx_llava = load_mlx_models(MODEL_PATH)
+ cls.hf_llava = load_hf_models(MODEL_PATH)
+ cls.proc = AutoProcessor.from_pretrained(MODEL_PATH)
def test_processor(self):
prompt = "USER: \nWhat are these?\nASSISTANT:"
@@ -36,12 +35,12 @@ def test_processor(self):
raw_image = Image.open(requests.get(image_file, stream=True).raw)
hf_data = mx.array(
- np.array(
- self.hf_proc(prompt, raw_image, return_tensors="pt")["pixel_values"]
- )
- ).transpose(0, 2, 3, 1)
+ self.proc(prompt, raw_image, return_tensors="np")["pixel_values"]
+ )
- mx_data = self.mx_proc(prompt, [raw_image])["pixel_values"]
+ mx_data = mx.array(
+ self.proc(prompt, raw_image, return_tensors="np")["pixel_values"]
+ )
self.assertTrue(mx.allclose(mx_data, hf_data, atol=1e-5))
From c4ea94fdec3ee62a38b51f7408def4e129172561 Mon Sep 17 00:00:00 2001
From: anchen
Date: Sun, 25 Feb 2024 00:17:13 +1100
Subject: [PATCH 3/5] feat: llava working example
---
llava/Local LLava.ipynb | 1121 -------------------------------
llava/config.py | 37 -
llava/convert.py | 87 ---
llava/download.py | 54 --
llava/generate.py | 58 ++
llava/image_processor.py | 93 ---
llava/{llama.py => language.py} | 12 +-
llava/llava.py | 112 ++-
llava/processing_llava.py | 23 -
llava/test.py | 103 ++-
llava/utils.py | 70 --
llava/{clip.py => vision.py} | 30 +-
12 files changed, 255 insertions(+), 1545 deletions(-)
delete mode 100644 llava/Local LLava.ipynb
delete mode 100644 llava/config.py
delete mode 100644 llava/convert.py
delete mode 100644 llava/download.py
create mode 100644 llava/generate.py
delete mode 100644 llava/image_processor.py
rename llava/{llama.py => language.py} (95%)
delete mode 100644 llava/processing_llava.py
delete mode 100644 llava/utils.py
rename llava/{clip.py => vision.py} (90%)
diff --git a/llava/Local LLava.ipynb b/llava/Local LLava.ipynb
deleted file mode 100644
index 69c760103..000000000
--- a/llava/Local LLava.ipynb
+++ /dev/null
@@ -1,1121 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Goal: Download and convert the weights of LlaVA into MLX, and test the forward pass of this model on example data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "import shutil\n",
- "from pathlib import Path\n",
- "import os\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "mlx_path = Path('mlx_model')\n",
- "\n",
- "if not os.path.exists(mlx_path):\n",
- " os.makedirs(mlx_path)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Users/noahkasmanoff/anaconda3/envs/mlx/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n",
- "Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 214177.23it/s]\n",
- "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
- ]
- }
- ],
- "source": [
- "import mlx.core as mx\n",
- "from convert import get_model_path, fetch_from_hub, hf_repo\n",
- "\n",
- "\n",
- "model_path = get_model_path(hf_repo)\n",
- "model_config, model_weights, model_weight_files, config, tokenizer = fetch_from_hub(model_path)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[INFO] Converting\n",
- "[INFO] Saving\n"
- ]
- }
- ],
- "source": [
- "from utils import map_weights, should_keep_weight\n",
- "do_convert = True\n",
- "if do_convert:\n",
- "\n",
- " print(\"[INFO] Converting\")\n",
- " mlx_weights = dict(map_weights(k, v) for (k, v) in model_weights.items())\n",
- " mlx_weights = {k: v for (k, v) in mlx_weights.items() if should_keep_weight(k)}\n",
- " print(\"[INFO] Saving\")\n",
- " mx.savez(str(mlx_path / \"weights.npz\"), **mlx_weights)\n",
- " for fn in [\"config.json\", \"merges.txt\", \"vocab.json\", \"preprocessor_config.json\"]:\n",
- " if fn in os.listdir(model_path):\n",
- " shutil.copyfile(\n",
- " str(model_path / f\"{fn}\"),\n",
- " str(mlx_path / f\"{fn}\"),\n",
- " )\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "from llava import LlavaModel\n",
- "mlx_model = LlavaModel.from_pretrained(path='mlx_model')\n",
- "\n",
- "\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "mlx_model = LlavaModel.from_pretrained(path='mlx_model')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "LlavaModel(\n",
- " (vision_tower): CLIPVisionModel(\n",
- " (patch_embedding): Conv2d(3, 1024, kernel_size=(14,), stride=(14, 14), padding=(0, 0), bias=False)\n",
- " (pre_layernorm): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (layers.0): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.1): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.2): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.3): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.4): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.5): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.6): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.7): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.8): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.9): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.10): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.11): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.12): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.13): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.14): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.15): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.16): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.17): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.18): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.19): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.20): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.21): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.22): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (layers.23): CLIPEncoderLayer(\n",
- " (attention): MultiHeadAttention(\n",
- " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n",
- " )\n",
- " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n",
- " (dropout1): Dropout(p=0.0)\n",
- " (dropout2): Dropout(p=0.0)\n",
- " )\n",
- " (post_layernorm): LayerNorm(1024, eps=1e-05, affine=True)\n",
- " )\n",
- " (language_model): LlamaModel(\n",
- " (model): Llama(\n",
- " (embed_tokens): Embedding(32064, 4096)\n",
- " (layers.0): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.1): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.2): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.3): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.4): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.5): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.6): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.7): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.8): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.9): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.10): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.11): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.12): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.13): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.14): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.15): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.16): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.17): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.18): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.19): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.20): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.21): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.22): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.23): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.24): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.25): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.26): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.27): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.28): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.29): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.30): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (layers.31): TransformerBlock(\n",
- " (self_attn): Attention(\n",
- " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n",
- " (rope): RoPE(128, traditional=False)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n",
- " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n",
- " )\n",
- " (input_layernorm): RMSNorm()\n",
- " (post_attention_layernorm): RMSNorm()\n",
- " )\n",
- " (norm): RMSNorm()\n",
- " )\n",
- " (lm_head): Linear(input_dims=4096, output_dims=32064, bias=False)\n",
- " )\n",
- " (multi_modal_projector): LlavaMultiModalProjector(\n",
- " (linear_1): Linear(input_dims=1024, output_dims=4096, bias=True)\n",
- " (gelu): GELU()\n",
- " (linear_2): Linear(input_dims=4096, output_dims=4096, bias=True)\n",
- " )\n",
- ")"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "mlx_model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 27,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
- ]
- }
- ],
- "source": [
- "# Now that model weights are loaded in, now we can try and run inference code / set that up.\n",
- "\n",
- "# load the processor\n",
- "from transformers import AutoProcessor\n",
- "import requests\n",
- "from PIL import Image\n",
- "processor = AutoProcessor.from_pretrained(\"llava-hf/llava-1.5-7b-hf\")\n",
- "\n",
- "prompt = \"\\nUSER: What's the content of the image?\\nASSISTANT:\"\n",
- "url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n",
- "image = Image.open(requests.get(url, stream=True).raw)\n",
- "\n",
- "inputs = processor(text=prompt, images=image, return_tensors=\"pt\")\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 31,
- "metadata": {},
- "outputs": [],
- "source": [
- "\n",
- "input_ids = mx.array(inputs[\"input_ids\"].numpy())\n",
- "pixel_values = mx.array(inputs[\"pixel_values\"].numpy())\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 37,
- "metadata": {},
- "outputs": [],
- "source": [
- "vision_model_output = mlx_model.vision_tower(pixel_values.transpose(0,2,3,1))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 55,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(1, 577, 1024)"
- ]
- },
- "execution_count": 55,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": 57,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "CLIPVisionOutput(pooler_output=array([[-0.721487, -0.476275, 0.0173661, ..., 0.190072, -1.71528, 1.36224]], dtype=float32), last_hidden_state=array([[[-0.333623, -0.269844, 0.025435, ..., -0.0516554, -0.729696, 0.542679],\n",
- " [0.208684, 0.92752, 0.0233985, ..., 1.59934, -0.024813, 0.879629],\n",
- " [0.550235, 0.45201, 0.80935, ..., 1.63056, -0.37727, 0.699322],\n",
- " ...,\n",
- " [0.740987, 0.445616, 0.893172, ..., 0.523529, 0.0230118, -0.457155],\n",
- " [0.49297, 0.0680847, 0.79401, ..., 0.476083, 0.274526, -0.284749],\n",
- " [-0.0411091, 0.290756, 0.518906, ..., 0.242572, 0.40785, 0.420446]]], dtype=float32))"
- ]
- },
- "execution_count": 57,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "vision_model_output"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "mlx",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/llava/config.py b/llava/config.py
deleted file mode 100644
index f8617ac14..000000000
--- a/llava/config.py
+++ /dev/null
@@ -1,37 +0,0 @@
-model_config = {
- 'language_model': {
- 'hidden_size': 4096,
- 'num_hidden_layers': 32,
- 'intermediate_size': 11008,
- 'num_attention_heads': 32,
- 'rms_norm_eps': 1e-5,
- 'vocab_size': 32000,
- 'num_key_value_heads': 32,
- 'rope_theta': 0,
- 'rope_traditional': False,
- 'rope_scaling': None},
-
- 'vision_tower': {
- 'num_hidden_layers': 24,
- 'hidden_size': 1024,
- 'intermediate_size': 4096,
- 'num_attention_heads': 16,
- 'num_channels': 3,
- 'image_size': 336,
- 'patch_size': 14
- },
-
- 'multi_modal_projector': {
- 'in_features': 1024,
- 'out_features': 4096
- },
-
- 'vision_feature_layer': -2,
- 'vision_feature_selection_strategy': 'default',
- 'image_token_index': 32000,
- 'pad_token_id': 32001,
- 'tie_word_embeddings': False,
- 'vocab_size': 32064, # TODO: confirm this value
-
-
-}
diff --git a/llava/convert.py b/llava/convert.py
deleted file mode 100644
index af9973850..000000000
--- a/llava/convert.py
+++ /dev/null
@@ -1,87 +0,0 @@
-
-from safetensors.torch import load_file
-from pathlib import Path
-import glob
-import json
-import logging
-import mlx.nn as nn
-from huggingface_hub import snapshot_download
-from typing import Dict, Tuple
-from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
-
-
-hf_repo = "llava-hf/llava-1.5-7b-hf"
-
-
-def get_model_path(path_or_hf_repo: str) -> Path:
- """
- Ensures the model is available locally. If the path does not exist locally,
- it is downloaded from the Hugging Face Hub.
-
- Args:
- path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
-
- Returns:
- Path: The path to the model.
- """
- model_path = Path(path_or_hf_repo)
- if not model_path.exists():
- model_path = Path(
- snapshot_download(
- repo_id=path_or_hf_repo,
- allow_patterns=[
- "*.json",
- "*.safetensors",
- "*.py",
- "tokenizer.model",
- "*.tiktoken",
- ],
- )
- )
- return model_path
-
-
-def load_model(model_path: Path) -> nn.Module:
- """
- Load and initialize the model from a given path.
-
- Args:
- model_path (Path): The path to load the model from.
-
- Returns:
- nn.Module: The loaded and initialized model.
-
- Raises:
- FileNotFoundError: If the weight files (.safetensors) are not found.
- ValueError: If the model class or args class are not found or cannot be instantiated.
- """
- try:
- with open(model_path / "config.json", "r") as f:
- config = json.load(f)
- except FileNotFoundError:
- logging.error(f"Config file not found in {model_path}")
- raise
-
- weight_files = glob.glob(str(model_path / "*.safetensors"))
- if not weight_files:
- logging.error(f"No safetensors found in {model_path}")
- raise FileNotFoundError(f"No safetensors found in {model_path}")
-
- weights = {}
- for wf in weight_files:
- weights.update(load_file(wf))
-
- return config, weights, weight_files
-
-
-def fetch_from_hub(
- model_path: Path,
-) -> Tuple[Dict, dict, PreTrainedTokenizer]:
- model_config, model_weights, model_weight_files = load_model(model_path)
-
- config = AutoConfig.from_pretrained(model_path)
- tokenizer = AutoTokenizer.from_pretrained(
- model_path) # TODO: should this be the processor?
-
- # TODO: replace outputs with the model alone once conversion is complete
- return model_config, model_weights, model_weight_files, config, tokenizer
diff --git a/llava/download.py b/llava/download.py
deleted file mode 100644
index e755896bb..000000000
--- a/llava/download.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import argparse
-import os
-
-import requests
-from tqdm import tqdm
-
-
-def download_file(url, path):
- response = requests.get(url, stream=True)
- total_size_in_bytes = int(response.headers.get("content-length", 0))
- block_size = 1024 # 1 Kbyte
- progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
-
- with open(path, "wb") as file:
- for data in response.iter_content(block_size):
- progress_bar.update(len(data))
- file.write(data)
-
- progress_bar.close()
-
-
-def download_model(model_name, destination_folder="models"):
- # Define the base URL and headers for the Hugging Face API
- base_url = f"https://huggingface.co/{model_name}/resolve/main"
- headers = {"User-Agent": "Hugging Face Python"}
-
- # Send a GET request to the Hugging Face API to get a list of all files
- response = requests.get(
- f"https://huggingface.co/api/models/{model_name}", headers=headers
- )
- response.raise_for_status()
-
- # Extract the list of files from the response JSON
- files_to_download = [
- file["rfilename"]
- for file in response.json()["siblings"]
- if not file["rfilename"].endswith(".bin")
- ]
-
- # Ensure the directory exists
- os.makedirs(f"{destination_folder}/{model_name}", exist_ok=True)
-
- # Download each file
- for file in files_to_download:
- print(f"Downloading {file}...")
- download_file(f"{base_url}/{file}", f"{destination_folder}/{model_name}/{file}")
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("model_name", type=str, help="Name of the model to download.")
- args = parser.parse_args()
-
- download_model(args.model_name)
diff --git a/llava/generate.py b/llava/generate.py
new file mode 100644
index 000000000..a51fe2cc9
--- /dev/null
+++ b/llava/generate.py
@@ -0,0 +1,58 @@
+import mlx.core as mx
+import mlx.nn as nn
+import requests
+from PIL import Image
+from transformers import AutoProcessor
+
+from llava import LlavaModel
+
+MODEL_PATH = "models/llava-hf/llava-1.5-7b-hf"
+
+prompt = "USER: \nWhat are these?\nASSISTANT:"
+image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
+raw_image = Image.open(requests.get(image_file, stream=True).raw)
+
+
+processor = AutoProcessor.from_pretrained(MODEL_PATH)
+model = LlavaModel.from_pretrained(MODEL_PATH)
+
+values = processor(prompt, raw_image, return_tensors="np")
+pixel_values = mx.array(values["pixel_values"])
+input_ids = mx.array(values["input_ids"])
+
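+# model() returns the prompt embeddings with the projected image patch features merged in at the <image> token positions.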
+input_embeds = model(input_ids, pixel_values)
+max_tokens = 100
+temperature = 0.3
+
+
+def sample(logits, temp=0.0):
+ if temp == 0:
+ return mx.argmax(logits, axis=-1)
+ else:
+ return mx.random.categorical(logits * (1 / temp))
+
+
+def generate(y: mx.array, model: nn.Module, temp: float = 0.0, cache=None):
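+ # Keep feeding the last sampled token back through the language model, reusing the KV cache, and yield one token per step.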
+ while True:
+ logits, cache = model(y[None], cache=cache)
+ logits = logits[:, -1, :]
+
+ y = sample(logits, temp=temp)
+ token = y.item()
+
+ yield token
+
+
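+# Prime the language model with the merged embeddings to build the cache, then sample the remaining tokens autoregressively.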
+logits, cache = model.language_model(input_ids, cache=None, inputs_embeds=input_embeds)
+logits = logits[:, -1, :]
+y = sample(logits, temp=temperature)
+tokens = [y.item()]
+for token, _ in zip(
+ generate(y, model.language_model, temperature, cache=cache),
+ range(max_tokens),
+):
+ if token == processor.tokenizer.eos_token_id:
+ break
+ tokens.append(token)
+
+print(processor.tokenizer.decode(tokens))
diff --git a/llava/image_processor.py b/llava/image_processor.py
deleted file mode 100644
index 5f5be8484..000000000
--- a/llava/image_processor.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# Copyright © 2023-2024 Apple Inc.
-
-import json
-from pathlib import Path
-from typing import List, Tuple
-
-import mlx.core as mx
-import numpy as np
-from PIL.Image import Image
-
-
-class CLIPImageProcessor:
- """
- A simple port of
- https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/image_processing_clip.py.
- """
-
- def __init__(
- self,
- crop_size: int = 336,
- do_center_crop: bool = True,
- do_normalize: bool = True,
- do_resize: bool = True,
- image_mean: List[float] = [0.48145466, 0.4578275, 0.40821073],
- image_std: List[float] = [0.26862954, 0.26130258, 0.27577711],
- size: int = 336,
- **kwargs
- ) -> None:
- self.crop_size = crop_size
- self.do_center_crop = do_center_crop
- self.do_normalize = do_normalize
- self.do_resize = do_resize
- self.image_mean = mx.array(image_mean)
- self.image_std = mx.array(image_std)
- self.size = size
-
- def __call__(self, images: List[Image]) -> mx.array:
- return mx.concatenate(
- [self._preprocess(image)[None] for image in images], axis=0
- )
-
- def _preprocess(self, image: Image) -> mx.array:
- if self.do_resize:
- image = resize(image, self.size)
- if self.do_center_crop:
- image = center_crop(image, (self.crop_size, self.crop_size))
- image = mx.array(np.array(image))
- image = rescale(image)
- if self.do_normalize:
- image = normalize(image, self.image_mean, self.image_std)
- return image
-
- @staticmethod
- def from_pretrained(path: str):
- path = Path(path)
- with open(path / "preprocessor_config.json", encoding="utf-8") as f:
- config = json.load(f)
- return CLIPImageProcessor(**config)
-
-
-def resize(image: Image, short_size: int) -> Image:
- """
- Resize so that the shorter side of the image equals short_size.
- """
- width, height = image.size
- short = min(width, height)
- long = max(width, height)
- if short == short_size:
- return image
- new_short = short_size
- new_long = int(short_size * long / short)
- new_size = (new_short, new_long) if width <= height else (new_long, new_short)
- return image.resize(new_size)
-
-
-def center_crop(image: Image, size: Tuple[int, int]) -> Image:
- if size[0] % 2 != 0 or size[1] % 2 != 0:
- raise ValueError("Only even crop sizes supported.")
- original_width, original_height = image.size
- crop_height, crop_width = size
- top = (original_height - crop_height) // 2
- bottom = top + crop_height
- left = (original_width - crop_width) // 2
- right = left + crop_width
- return image.crop((left, top, right, bottom))
-
-
-def rescale(image: mx.array) -> mx.array:
- return image.astype(mx.float32) * (1 / 255.0)
-
-
-def normalize(image: mx.array, mean: mx.array, std: mx.array) -> mx.array:
- return (image - mean) / std
diff --git a/llava/llama.py b/llava/language.py
similarity index 95%
rename from llava/llama.py
rename to llava/language.py
index 242251a39..bb9078155 100644
--- a/llava/llama.py
+++ b/llava/language.py
@@ -176,8 +176,13 @@ def __call__(
self,
inputs: mx.array,
cache=None,
+ inputs_embeds=None,
):
- h = self.embed_tokens(inputs)
+ # for passing merged input embeddings
+ if inputs_embeds is None:
+ h = self.embed_tokens(inputs)
+ else:
+ h = inputs_embeds
mask = None
if h.shape[1] > 1:
@@ -193,7 +198,7 @@ def __call__(
return self.norm(h), cache
-class LlamaModel(nn.Module):
+class LanguageModel(nn.Module):
def __init__(self, args: TextConfig):
super().__init__()
self.model_type = args.model_type
@@ -204,8 +209,9 @@ def __call__(
self,
inputs: mx.array,
cache=None,
+ inputs_embeds=None,
):
- out, cache = self.model(inputs, cache)
+ out, cache = self.model(inputs, cache, inputs_embeds)
return self.lm_head(out), cache
@staticmethod
diff --git a/llava/llava.py b/llava/llava.py
index 4edfbeaca..01e76122a 100644
--- a/llava/llava.py
+++ b/llava/llava.py
@@ -8,9 +8,9 @@
import mlx.core as mx
import mlx.nn as nn
-from llama import LlamaModel, TextConfig
-
-from clip import ClipVisionModel, VisionConfig
+import numpy as np
+from language import LanguageModel, TextConfig
+from vision import VisionConfig, VisionModel
@dataclass
@@ -37,11 +37,15 @@ def from_dict(cls, params):
class LlavaMultiModalProjector(nn.Module):
def __init__(self, config: LlaVAConfig):
super().__init__()
- self.linear_1 = nn.Linear(config.in_features, config.out_features)
+ self.linear_1 = nn.Linear(
+ config.vision_config.hidden_size, config.text_config.hidden_size, bias=True
+ )
self.gelu = nn.GELU()
- self.linear_2 = nn.Linear(config.out_features, config.out_features)
+ self.linear_2 = nn.Linear(
+ config.text_config.hidden_size, config.text_config.hidden_size, bias=True
+ )
- def forward(self, x: mx.array) -> mx.array:
+ def __call__(self, x: mx.array) -> mx.array:
x = self.linear_1(x)
x = self.gelu(x)
x = self.linear_2(x)
@@ -50,39 +54,81 @@ def forward(self, x: mx.array) -> mx.array:
class LlavaModel(nn.Module):
def __init__(self, config: LlaVAConfig):
- self.vision_tower = ClipVisionModel(
- config=VisionConfig.from_dict(config.vision_config)
- )
- self.language_model = LlamaModel(args=TextConfig.from_dict(config.text_config))
- self.multi_modal_projector = LlavaMultiModalProjector(
- config=config.projection_config
- )
+ self.config = config
+ self.vision_tower = VisionModel(config.vision_config)
+ self.language_model = LanguageModel(config.text_config)
+ self.multi_modal_projector = LlavaMultiModalProjector(config)
+ self.vision_feature_layer = config.vision_feature_layer
+ self.vision_feature_select_strategy = config.vision_feature_select_strategy
def __call__(
self,
input_ids: Optional[mx.array] = None,
pixel_values: Optional[mx.array] = None,
):
- # TODO: add the forward pass
-
- if pixel_values is not None and input_ids.shape[1] != 1:
- image_outputs = self.vision_tower(pixel_values)
+ if pixel_values is None:
+ return self.language_model(input_ids)
- # TODO: this is not the correct output layer, but it's a placeholder
- selected_image_feature = image_outputs.pooler_output
+ inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+ _, _, hidden_states = self.vision_tower(
+ pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
+ )
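+ # LLaVA takes image features from an intermediate layer of the vision tower (vision_feature_layer, -2 by default).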
+ selected_image_feature = hidden_states[self.vision_feature_layer]
+
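+ # The "default" strategy drops the CLS token so that only the patch embeddings are passed on.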
+ if self.vision_feature_select_strategy == "default":
+ selected_image_feature = selected_image_feature[:, 1:]
+ elif self.vision_feature_select_strategy == "full":
+ selected_image_feature = selected_image_feature
+ else:
+ raise ValueError(
+ f"Unexpected select feature strategy: {self.vision_feature_select_strategy}"
+ )
+
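+ # Project the vision features into the language model's embedding space, then splice them in at the image token positions.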
+ image_features = self.multi_modal_projector(selected_image_feature)
+ final_inputs_embeds = self._merge_input_ids_with_image_features(
+ image_features, inputs_embeds, input_ids
+ )
- image_features = self.multi_modal_projector(selected_image_feature)
+ return final_inputs_embeds
def _merge_input_ids_with_image_features(
- self, image_features, inputs_embeds, input_ids, attention_mask, labels
+ self, image_features, inputs_embeds, input_ids
):
- # TODO: https://github.com/huggingface/transformers/blob/4f09d0fd888dbf2660313f9715992822acfb99ce/src/transformers/models/llava/modeling_llava.py#L279
+ image_features = np.array(image_features)
+ inputs_embeds = np.array(inputs_embeds)
+ input_ids = np.array(input_ids)
+
+ _, num_image_patches, embed_dim = image_features.shape
+ batch_size, sequence_length = input_ids.shape
+ special_image_token_mask = input_ids == self.config.image_token_index
+ num_special_image_tokens = np.sum(special_image_token_mask, axis=-1)
+ max_embed_dim = (
+ np.max(num_special_image_tokens) * (num_image_patches - 1)
+ ) + sequence_length
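+ # Every image token in the prompt expands into num_image_patches embeddings, so the merged sequence grows by (num_image_patches - 1) per image token.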
+
+ non_image_indices = np.where(input_ids != self.config.image_token_index)
+
+ new_token_positions = (
+ np.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), axis=-1)
+ - 1
+ )
+ text_to_overwrite = new_token_positions[non_image_indices]
+
+ final_embedding = np.zeros(
+ (batch_size, max_embed_dim, embed_dim), dtype=inputs_embeds.dtype
+ )
- special_image_token_mask = input_ids == self.config.special_tokens.image
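+ # Scatter the text embeddings into their shifted positions first; the rows left all-zero are then filled with image features.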
+ final_embedding[non_image_indices[0], text_to_overwrite, :] = inputs_embeds[
+ non_image_indices
+ ]
- num_image_tokens = special_image_token_mask.sum()
+ image_to_overwrite = np.all(final_embedding == 0, axis=-1)
+ reshaped_image_features = image_features.reshape(-1, embed_dim)
+ final_embedding[image_to_overwrite, :] = reshaped_image_features[
+ : np.sum(image_to_overwrite)
+ ]
- pass
+ return mx.array(final_embedding)
@staticmethod
def from_pretrained(path: str):
@@ -92,6 +138,15 @@ def from_pretrained(path: str):
model_config = json.load(f)
model_config = LlaVAConfig.from_dict(model_config)
+
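+ # The nested configs may still be plain dicts when read from config.json; promote them to their dataclasses before building the model.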
+ if isinstance(model_config.vision_config, dict):
+ model_config.vision_config = VisionConfig.from_dict(
+ model_config.vision_config
+ )
+
+ if isinstance(model_config.text_config, dict):
+ model_config.text_config = TextConfig.from_dict(model_config.text_config)
+
model = LlavaModel(model_config)
weight_files = glob.glob(str(path / "*.safetensors"))
if not weight_files:
@@ -102,6 +157,11 @@ def from_pretrained(path: str):
for wf in weight_files:
weights.update(mx.load(wf))
- weights = ClipVisionModel.sanitize(weights)
+ if hasattr(VisionModel, "sanitize"):
+ weights = VisionModel.sanitize(weights)
+
+ if hasattr(VisionModel, "sanitize"):
+ weights = LanguageModel.sanitize(weights)
+
model.load_weights(list(weights.items()))
return model
diff --git a/llava/processing_llava.py b/llava/processing_llava.py
deleted file mode 100644
index 705d1ccf6..000000000
--- a/llava/processing_llava.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from image_processor import CLIPImageProcessor
-
-
-class LlavaProcessor:
- def __init__(self, image_processor=None, tokenizer=None):
- self.image_processor = CLIPImageProcessor()
- self.tokenizer = tokenizer
-
- def __call__(
- self,
- text=None,
- images=None,
- padding=False,
- truncation=None,
- max_length=None,
- return_tensors=None,
- ):
- if images is not None:
- pixel_values = self.image_processor(images)
- else:
- pixel_values = None
-
- return {"pixel_values": pixel_values}
diff --git a/llava/test.py b/llava/test.py
index 0d693677e..4cca0c9af 100644
--- a/llava/test.py
+++ b/llava/test.py
@@ -3,8 +3,8 @@
import mlx.core as mx
import numpy as np
import requests
+import torch
from PIL import Image
-from processing_llava import LlavaProcessor
from transformers import AutoProcessor, LlavaForConditionalGeneration
from llava import LlavaModel
@@ -14,11 +14,13 @@
def load_mlx_models(path):
model = LlavaModel.from_pretrained(path)
+ model.eval()
return model
def load_hf_models(path):
model = LlavaForConditionalGeneration.from_pretrained(path)
+ model.eval()
return model
@@ -29,20 +31,103 @@ def setUpClass(cls):
cls.hf_llava = load_hf_models(MODEL_PATH)
cls.proc = AutoProcessor.from_pretrained(MODEL_PATH)
- def test_processor(self):
+ def test_image_features(self):
prompt = "USER: \nWhat are these?\nASSISTANT:"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
+ vision_feature_layer = -2
+ with torch.no_grad():
+ pixel_values = self.proc(prompt, raw_image, return_tensors="pt")[
+ "pixel_values"
+ ]
- hf_data = mx.array(
- self.proc(prompt, raw_image, return_tensors="np")["pixel_values"]
- )
+ hf_pixel_values = pixel_values
+ mx_pixel_values = mx.array(pixel_values.numpy()).transpose(0, 2, 3, 1)
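+ # MLX convolutions expect NHWC input, so the NCHW tensor from the HF processor is transposed before the vision tower.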
- mx_data = mx.array(
- self.proc(prompt, raw_image, return_tensors="np")["pixel_values"]
- )
+ _, _, hidden_states = self.mx_llava.vision_tower(
+ mx_pixel_values,
+ output_hidden_states=True,
+ )
- self.assertTrue(mx.allclose(mx_data, hf_data, atol=1e-5))
+ mx_elected_image_feature = hidden_states[vision_feature_layer]
+ mx_image_features = self.mx_llava.multi_modal_projector(
+ mx_elected_image_feature
+ )
+
+ hf_image_outputs = self.hf_llava.vision_tower(
+ hf_pixel_values, output_hidden_states=True
+ )
+ hf_elected_image_feature = hf_image_outputs.hidden_states[
+ vision_feature_layer
+ ]
+ hf_image_features = self.hf_llava.multi_modal_projector(
+ hf_elected_image_feature
+ )
+
+ self.assertTrue(
+ mx.allclose(
+ mx_image_features,
+ mx.array(hf_image_features.numpy()),
+ atol=1e-2,
+ )
+ )
+
+ def test_merge_input_ids_with_image_features(self):
+ prompt = "USER: \nWhat are these?\nASSISTANT:"
+ image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ raw_image = Image.open(requests.get(image_file, stream=True).raw)
+ vision_feature_layer = -2
+ with torch.no_grad():
+ values = self.proc(prompt, raw_image, return_tensors="pt")
+ pixel_values = values["pixel_values"]
+ input_ids = values["input_ids"]
+
+ hf_pixel_values = pixel_values
+ mx_pixel_values = mx.array(pixel_values.numpy()).transpose(0, 2, 3, 1)
+
+ _, _, hidden_states = self.mx_llava.vision_tower(
+ mx_pixel_values,
+ output_hidden_states=True,
+ )
+ mx_input_ids = mx.array(input_ids.numpy())
+ mx_elected_image_feature = hidden_states[vision_feature_layer]
+ mx_image_features = self.mx_llava.multi_modal_projector(
+ mx_elected_image_feature
+ )
+ mx_inputs_embeds = self.mx_llava.language_model.model.embed_tokens(
+ mx_input_ids
+ )
+ mx_final_embedding = self.mx_llava._merge_input_ids_with_image_features(
+ mx_image_features, mx_inputs_embeds, mx_input_ids
+ )
+
+ hf_image_outputs = self.hf_llava.vision_tower(
+ hf_pixel_values, output_hidden_states=True
+ )
+ hf_elected_image_feature = hf_image_outputs.hidden_states[
+ vision_feature_layer
+ ]
+ hf_image_features = self.hf_llava.multi_modal_projector(
+ hf_elected_image_feature
+ )
+ hf_inputs_embeds = self.hf_llava.get_input_embeddings()(input_ids)
+ hf_final_embedding, _, _, _ = (
+ self.hf_llava._merge_input_ids_with_image_features(
+ hf_image_features,
+ hf_inputs_embeds,
+ input_ids,
+ attention_mask=input_ids,
+ labels=torch.ones_like(input_ids),
+ )
+ )
+
+ self.assertTrue(
+ mx.allclose(
+ mx_final_embedding,
+ mx.array(hf_final_embedding.numpy()),
+ atol=1e-1,
+ )
+ )
if __name__ == "__main__":
diff --git a/llava/utils.py b/llava/utils.py
deleted file mode 100644
index 62c5d0ef1..000000000
--- a/llava/utils.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import mlx.core as mx
-import torch
-from typing import Tuple
-
-
-def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
- # bfloat16 is not numpy convertible. Upcast to float32 to avoid precision loss
- a = a.to(torch.float32) if dtype == "bfloat16" else a.to(
- getattr(torch, dtype))
- return mx.array(a.numpy(), getattr(mx, dtype))
-
-
-def should_keep_weight(key: str):
- return not ("position_ids" in key)
-
-
-def map_vision_tower_weights(key: str, value: torch.Tensor) -> Tuple[str, torch.Tensor]:
- key = key.replace("embeddings.", "")
- key = key.replace("encoder.", "")
- key = key.replace("position_embedding.weight", "position_embedding")
-
- key = key.replace('vision_model.', '')
-
- # Map attention layers
- if "self_attn." in key:
- key = key.replace("self_attn.", "attention.")
- if "q_proj." in key:
- key = key.replace("q_proj.", "query_proj.")
- if "k_proj." in key:
- key = key.replace("k_proj.", "key_proj.")
- if "v_proj." in key:
- key = key.replace("v_proj.", "value_proj.")
- if "layer_norm1." in key:
- key = key.replace("layer_norm1.", "ln1.")
- if "layer_norm2." in key:
- key = key.replace("layer_norm2.", "ln2.")
- # Map ffn layers
- if "mlp.fc1" in key:
- key = key.replace("mlp.fc1", "linear1")
- if "mlp.fc2" in key:
- key = key.replace("mlp.fc2", "linear2")
- # Fix layernorm typo
- if "pre_layrnorm" in key:
- # Fix typo in weights :)
- key = key.replace("pre_layrnorm", "pre_layernorm")
- if "patch_embedding.weight" in key:
- # Initially, value: [out_channels, in_channels, kH, KW].
- # We want [out_channels, kH, KW, in_channels]
- value = value.permute(0, 2, 3, 1)
- return (key, value)
-
-
-def map_language_model_weights(key: str, value: torch.Tensor) -> Tuple[str, torch.Tensor]:
- return (key, value)
-
-
-def map_multi_modal_projector_weights(key: str, value: torch.Tensor) -> Tuple[str, torch.Tensor]:
- return (key, value)
-
-
-def map_weights(key: str, value: torch.Tensor) -> Tuple[str, mx.array]:
-
- if 'vision_tower' in key:
- key, value = map_vision_tower_weights(key, value)
- elif 'language_model' in key:
- key, value = map_language_model_weights(key, value)
- elif 'multi_modal_projector' in key:
- key, value = map_multi_modal_projector_weights(key, value)
-
- return (key, torch_to_mx(value, dtype=str(value.dtype).replace("torch.", "")))
diff --git a/llava/clip.py b/llava/vision.py
similarity index 90%
rename from llava/clip.py
rename to llava/vision.py
index 736307f31..ed7f7a46e 100644
--- a/llava/clip.py
+++ b/llava/vision.py
@@ -4,7 +4,6 @@
import logging
import math
from dataclasses import dataclass
-from pathlib import Path
from typing import Optional
import mlx.core as mx
@@ -197,29 +196,16 @@ def __call__(
pooler_output = self.post_layernorm(x[:, 0, :])
return pooler_output, x, encoder_states
- @staticmethod
- def from_pretrained(path: str):
- path = Path(path)
-
- with open(path / "config.json", "r") as fid:
- config_dict = json.load(fid)
- vision_config = VisionConfig(**config_dict["vision_config"])
-
- model = ClipVisionModel(vision_config)
-
- weight_files = glob.glob(str(path / "*.safetensors"))
- if not weight_files:
- logging.error(f"No safetensors found in {path}")
- raise FileNotFoundError(f"No safetensors found in {path}")
- weights = {}
- for wf in weight_files:
- weights.update(mx.load(wf))
+class VisionModel(nn.Module):
+ def __init__(self, config: VisionConfig):
+ super().__init__()
+ self.vision_model = ClipVisionModel(config)
- weights = model.sanitize(weights)
- model.load_weights(list(weights.items()))
- model.load_weights(weights)
- return model
+ def __call__(
+ self, x: mx.array, output_hidden_states: Optional[bool] = None
+ ) -> mx.array:
+ return self.vision_model(x, output_hidden_states)
@staticmethod
def sanitize(weights):
From b9aeadea28929f712213fa9db8327d05c84da42a Mon Sep 17 00:00:00 2001
From: anchen
Date: Sun, 25 Feb 2024 00:40:40 +1100
Subject: [PATCH 4/5] chore: refactor generate script
---
llava/generate.py | 132 ++++++++++++++++++++++++++++++++++------------
llava/llava.py | 2 +-
2 files changed, 100 insertions(+), 34 deletions(-)
diff --git a/llava/generate.py b/llava/generate.py
index a51fe2cc9..df92c97f7 100644
--- a/llava/generate.py
+++ b/llava/generate.py
@@ -1,3 +1,6 @@
+import argparse
+import os
+
import mlx.core as mx
import mlx.nn as nn
import requests
@@ -6,53 +9,116 @@
from llava import LlavaModel
-MODEL_PATH = "models/llava-hf/llava-1.5-7b-hf"
-prompt = "USER: \nWhat are these?\nASSISTANT:"
-image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
-raw_image = Image.open(requests.get(image_file, stream=True).raw)
+def parse_arguments():
+ parser = argparse.ArgumentParser(
+ description="Generate text from an image using a model."
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="models/llava-hf/llava-1.5-7b-hf",
+ help="Path to the model directory.",
+ )
+ parser.add_argument(
+ "--image",
+ type=str,
+ default="http://images.cocodataset.org/val2017/000000039769.jpg",
+ help="URL or path of the image to process.",
+ )
+ parser.add_argument(
+ "--prompt",
+ type=str,
+ default="USER: \nWhat are these?\nASSISTANT:",
+ help="Prompt to use for the model.",
+ )
+ parser.add_argument(
+ "--max-tokens",
+ type=int,
+ default=100,
+ help="Maximum number of tokens to generate.",
+ )
+ parser.add_argument(
+ "--temperature", type=float, default=0.3, help="Temperature for sampling."
+ )
+ return parser.parse_args()
+
+
+def load_image(image_source):
+ if image_source.startswith(("http://", "https://")):
+ try:
+ response = requests.get(image_source, stream=True)
+ response.raise_for_status()
+ return Image.open(response.raw)
+ except requests.HTTPError as e:
+ print(f"Failed to load image from URL: {e}")
+ return None
+ elif os.path.isfile(image_source):
+ try:
+ return Image.open(image_source)
+ except IOError as e:
+ print(f"Failed to load image from path: {e}")
+ return None
+ else:
+ print("The image source is neither a valid URL nor a file path.")
+ return None
-processor = AutoProcessor.from_pretrained(MODEL_PATH)
-model = LlavaModel.from_pretrained(MODEL_PATH)
+def initialize_model(model_path):
+ processor = AutoProcessor.from_pretrained(model_path)
+ model = LlavaModel.from_pretrained(model_path)
+ return processor, model
-values = processor(prompt, raw_image, return_tensors="np")
-pixel_values = mx.array(values["pixel_values"])
-input_ids = mx.array(values["input_ids"])
-input_embeds = model(input_ids, pixel_values)
-max_tokens = 100
-temperature = 0.3
+def prepare_inputs(processor, image, prompt):
+ inputs = processor(prompt, image, return_tensors="np")
+ pixel_values = mx.array(inputs["pixel_values"])
+ input_ids = mx.array(inputs["input_ids"])
+ return input_ids, pixel_values
-def sample(logits, temp=0.0):
- if temp == 0:
+def sample(logits, temperature=0.0):
+ if temperature == 0:
return mx.argmax(logits, axis=-1)
else:
- return mx.random.categorical(logits * (1 / temp))
+ return mx.random.categorical(logits * (1 / temperature))
-def generate(y: mx.array, model: nn.Module, temp: float = 0.0, cache=None):
- while True:
- logits, cache = model(y[None], cache=cache)
- logits = logits[:, -1, :]
+def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature):
+ input_embeds = model.get_input_embeddings(input_ids, pixel_values)
+ logits, cache = model.language_model(
+ input_ids, cache=None, inputs_embeds=input_embeds
+ )
+ logits = logits[:, -1, :]
+ y = sample(logits, temperature=temperature)
+ tokens = [y.item()]
- y = sample(logits, temp=temp)
+ for _ in range(max_tokens):
+ logits, cache = model.language_model(y[None], cache=cache)
+ logits = logits[:, -1, :]
+ y = sample(logits, temperature)
token = y.item()
+ if token == processor.tokenizer.eos_token_id:
+ break
+ tokens.append(token)
+
+ return processor.tokenizer.decode(tokens)
+
- yield token
+def main():
+ args = parse_arguments()
+ raw_image = load_image(args.image)
+ if raw_image is None:
+ return
+ processor, model = initialize_model(args.model)
+ input_ids, pixel_values = prepare_inputs(processor, raw_image, args.prompt)
+ print(args.prompt)
+ generated_text = generate_text(
+ input_ids, pixel_values, model, processor, args.max_tokens, args.temperature
+ )
+ print(generated_text)
-logits, cache = model.language_model(input_ids, cache=None, inputs_embeds=input_embeds)
-logits = logits[:, -1, :]
-y = sample(logits, temp=temperature)
-tokens = [y.item()]
-for token, _ in zip(
- generate(y, model.language_model, temperature, cache=cache),
- range(max_tokens),
-):
- if token == processor.tokenizer.eos_token_id:
- break
- tokens.append(token)
-print(processor.tokenizer.decode(tokens))
+if __name__ == "__main__":
+ main()
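The refactored generate_text runs the decode in two phases: the first language_model call consumes the fused image and text embeddings through inputs_embeds and returns the KV cache, and every later step feeds only the newly sampled token together with that cache. Sampling follows the small helper above: argmax when temperature is 0, otherwise a categorical draw over temperature-scaled logits. A tiny, self-contained check of that rule (the values are illustrative only):

    import mlx.core as mx

    logits = mx.array([[2.0, 0.5, -1.0]])
    greedy = mx.argmax(logits, axis=-1)                  # the temperature == 0 branch
    sampled = mx.random.categorical(logits * (1 / 0.3))  # the temperature > 0 branch
    print(greedy.item(), sampled.item())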
diff --git a/llava/llava.py b/llava/llava.py
index 01e76122a..4af5783d7 100644
--- a/llava/llava.py
+++ b/llava/llava.py
@@ -61,7 +61,7 @@ def __init__(self, config: LlaVAConfig):
self.vision_feature_layer = config.vision_feature_layer
self.vision_feature_select_strategy = config.vision_feature_select_strategy
- def __call__(
+ def get_input_embeddings(
self,
input_ids: Optional[mx.array] = None,
pixel_values: Optional[mx.array] = None,
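Renaming __call__ to get_input_embeddings makes the multimodal fusion a single explicit call: the vision tower encodes the image, the projector maps the selected hidden-state layer into the language model's embedding space, and the patch embeddings are spliced in at the <image> token position; the generation loop then talks to model.language_model directly, as shown above. A hypothetical interactive check (the shapes are illustrative for llava-1.5-7b, where one <image> token expands to 576 patch embeddings):

    embeds = model.get_input_embeddings(input_ids, pixel_values)
    print(input_ids.shape, embeds.shape)
    # e.g. (1, 22) -> (1, 597, 4096): 22 - 1 + 576 tokens, hidden size 4096 (illustrative)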
From d8f7b895e1bf197cbfe965e209d5b559da593f37 Mon Sep 17 00:00:00 2001
From: anchen
Date: Sun, 25 Feb 2024 01:00:56 +1100
Subject: [PATCH 5/5] chore: clean up
---
llava/.gitignore | 162 ----------------------------------------------
llava/generate.py | 14 ++--
llava/test.py | 20 +++---
llava/utils.py | 31 +++++++++
4 files changed, 49 insertions(+), 178 deletions(-)
delete mode 100644 llava/.gitignore
create mode 100644 llava/utils.py
diff --git a/llava/.gitignore b/llava/.gitignore
deleted file mode 100644
index bc0a54fe8..000000000
--- a/llava/.gitignore
+++ /dev/null
@@ -1,162 +0,0 @@
[162 removed lines omitted: the boilerplate Python .gitignore introduced in PATCH 1/5, deleted here verbatim; it ended with "models" and no trailing newline]
diff --git a/llava/generate.py b/llava/generate.py
index df92c97f7..c52645818 100644
--- a/llava/generate.py
+++ b/llava/generate.py
@@ -6,6 +6,7 @@
import requests
from PIL import Image
from transformers import AutoProcessor
+from utils import get_model_path
from llava import LlavaModel
@@ -17,8 +18,8 @@ def parse_arguments():
parser.add_argument(
"--model",
type=str,
- default="models/llava-hf/llava-1.5-7b-hf",
- help="Path to the model directory.",
+ default="llava-hf/llava-1.5-7b-hf",
+ help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
"--image",
@@ -30,7 +31,7 @@ def parse_arguments():
"--prompt",
type=str,
default="USER: \nWhat are these?\nASSISTANT:",
- help="Prompt to use for the model.",
+ help="Message to be processed by the model.",
)
parser.add_argument(
"--max-tokens",
@@ -39,7 +40,7 @@ def parse_arguments():
help="Maximum number of tokens to generate.",
)
parser.add_argument(
- "--temperature", type=float, default=0.3, help="Temperature for sampling."
+ "--temp", type=float, default=0.3, help="Temperature for sampling."
)
return parser.parse_args()
@@ -66,7 +67,8 @@ def load_image(image_source):
def initialize_model(model_path):
processor = AutoProcessor.from_pretrained(model_path)
- model = LlavaModel.from_pretrained(model_path)
+
+ model = LlavaModel.from_pretrained(get_model_path(model_path))
return processor, model
@@ -115,7 +117,7 @@ def main():
input_ids, pixel_values = prepare_inputs(processor, raw_image, args.prompt)
print(args.prompt)
generated_text = generate_text(
- input_ids, pixel_values, model, processor, args.max_tokens, args.temperature
+ input_ids, pixel_values, model, processor, args.max_tokens, args.temp
)
print(generated_text)
diff --git a/llava/test.py b/llava/test.py
index 4cca0c9af..64652324f 100644
--- a/llava/test.py
+++ b/llava/test.py
@@ -6,14 +6,18 @@
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
+from utils import get_model_path
from llava import LlavaModel
-MODEL_PATH = "models/llava-hf/llava-1.5-7b-hf"
+MODEL_PATH = "llava-hf/llava-1.5-7b-hf"
+PROMPT = "USER: \nWhat are these?\nASSISTANT:"
+IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg"
def load_mlx_models(path):
- model = LlavaModel.from_pretrained(path)
+ model_path = get_model_path(path)
+ model = LlavaModel.from_pretrained(model_path)
model.eval()
return model
@@ -32,12 +36,10 @@ def setUpClass(cls):
cls.proc = AutoProcessor.from_pretrained(MODEL_PATH)
def test_image_features(self):
- prompt = "USER: \nWhat are these?\nASSISTANT:"
- image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
- raw_image = Image.open(requests.get(image_file, stream=True).raw)
+ raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw)
vision_feature_layer = -2
with torch.no_grad():
- pixel_values = self.proc(prompt, raw_image, return_tensors="pt")[
+ pixel_values = self.proc(PROMPT, raw_image, return_tensors="pt")[
"pixel_values"
]
@@ -73,12 +75,10 @@ def test_image_features(self):
)
def test_merge_input_ids_with_image_features(self):
- prompt = "USER: \nWhat are these?\nASSISTANT:"
- image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
- raw_image = Image.open(requests.get(image_file, stream=True).raw)
+ raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw)
vision_feature_layer = -2
with torch.no_grad():
- values = self.proc(prompt, raw_image, return_tensors="pt")
+ values = self.proc(PROMPT, raw_image, return_tensors="pt")
pixel_values = values["pixel_values"]
input_ids = values["input_ids"]
diff --git a/llava/utils.py b/llava/utils.py
new file mode 100644
index 000000000..0514b12d5
--- /dev/null
+++ b/llava/utils.py
@@ -0,0 +1,31 @@
+from pathlib import Path
+
+from huggingface_hub import snapshot_download
+
+
+def get_model_path(path_or_hf_repo: str) -> Path:
+ """
+ Ensures the model is available locally. If the path does not exist locally,
+ it is downloaded from the Hugging Face Hub.
+
+ Args:
+ path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
+
+ Returns:
+ Path: The path to the model.
+ """
+ model_path = Path(path_or_hf_repo)
+ if not model_path.exists():
+ model_path = Path(
+ snapshot_download(
+ repo_id=path_or_hf_repo,
+ allow_patterns=[
+ "*.json",
+ "*.safetensors",
+ "*.py",
+ "tokenizer.model",
+ "*.tiktoken",
+ ],
+ )
+ )
+ return model_path
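get_model_path mirrors the loader helpers used elsewhere in mlx-examples: an existing local directory is returned as-is, anything else is treated as a Hugging Face repo id, and allow_patterns limits the snapshot to configs, safetensors shards, and tokenizer files so the PyTorch *.bin weights are not downloaded. Note that initialize_model still passes the raw argument to AutoProcessor.from_pretrained, which resolves the Hub on its own. Example wiring, matching the pattern generate.py now uses:

    from utils import get_model_path
    from llava import LlavaModel

    model_path = get_model_path("llava-hf/llava-1.5-7b-hf")  # resolved to the HF cache on first use
    model = LlavaModel.from_pretrained(model_path)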