Merged

35 commits
4d964bd
add: llava mlx first draft
nkasmanoff Feb 19, 2024
0e2a054
add: weights comparision
nkasmanoff Feb 19, 2024
6e4a7ee
add forward pass skeleton
nkasmanoff Feb 19, 2024
ed9d376
update: now imports weights correctly
nkasmanoff Feb 22, 2024
b83b1e5
delete base
nkasmanoff Feb 22, 2024
6e23847
latest
nkasmanoff Feb 22, 2024
bb5b898
adding config
nkasmanoff Feb 22, 2024
95f9df1
fix: use config
nkasmanoff Feb 22, 2024
a1c6fe6
add mlx config
nkasmanoff Feb 22, 2024
cec0639
feat: add image processor for llava processor
mzbac Feb 23, 2024
4dd8bca
wip
mzbac Feb 24, 2024
c4ea94f
feat: llava working example
mzbac Feb 24, 2024
b9aeade
chore: refactor generate script
mzbac Feb 24, 2024
d8f7b89
chore: clean up
mzbac Feb 24, 2024
7fb1a39
Merge pull request #1 from mzbac/llava
nkasmanoff Feb 24, 2024
371a807
add: warning to user if no <image> token despite using one
nkasmanoff Feb 24, 2024
449f7d0
add: __call__ to LlavaModel
nkasmanoff Feb 24, 2024
a1cab2b
add: call to LlavaModel
nkasmanoff Feb 24, 2024
8e6b2f5
update fp
nkasmanoff Feb 26, 2024
823411c
clean up var names
nkasmanoff Feb 26, 2024
6bc06c8
update: native GeLU
nkasmanoff Feb 26, 2024
feec5ec
Cleanup
nkasmanoff Feb 28, 2024
d76fd40
update generate and readme
nkasmanoff Feb 28, 2024
49f928a
remove todo comment
nkasmanoff Feb 28, 2024
c2b8463
rearrange tests
nkasmanoff Feb 28, 2024
25a65cf
fix example code
nkasmanoff Feb 28, 2024
c2c9411
nits in README
awni Feb 28, 2024
8301c43
update readme
nkasmanoff Feb 28, 2024
5c8f67d
nit in readme
awni Feb 28, 2024
cd77bcf
nits in README
awni Feb 28, 2024
b39c251
chore(llava): refactor image embedding merging logic
mzbac Feb 28, 2024
935ebb5
min mlx version
awni Mar 1, 2024
683b7c4
nits in readmes
awni Mar 1, 2024
b37891d
fix cli prompt, some nits
awni Mar 1, 2024
7ace6ea
updates, slight simplify
awni Mar 1, 2024
1 change: 1 addition & 0 deletions README.md
@@ -31,6 +31,7 @@ Some more useful examples are listed below.
### Multimodal models

- Joint text and image embeddings with [CLIP](clip).
- Text generation from image and text inputs with [LLaVA](llava).

### Other Models

1 change: 1 addition & 0 deletions llava/.gitignore
@@ -0,0 +1 @@
**.ipynb
61 changes: 61 additions & 0 deletions llava/README.md
@@ -0,0 +1,61 @@
# LLaVA

An example of LLaVA: Large Language and Vision Assistant in MLX.[^1] LLaVA is
a multimodal model that can generate text given combined image and text inputs.

## Setup

Install the dependencies:

```bash
pip install -r requirements.txt
```

## Run

You can use LLaVA to ask questions about images.

For example, using the command line:

```bash
python generate.py \
  --model llava-hf/llava-1.5-7b-hf \
  --image "http://images.cocodataset.org/val2017/000000039769.jpg" \
  --prompt "USER: <image>\nWhat are these?\nASSISTANT:" \
  --max-tokens 128 \
  --temp 0
```
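
Note the literal `\n` sequences in the prompt: `generate.py` decodes
shell-escaped sequences with `codecs.decode(prompt, "unicode_escape")`, so a
`\n` typed on the command line becomes a real newline before tokenization.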

This uses the following image:

![alt text](http://images.cocodataset.org/val2017/000000039769.jpg)

And generates the output:

```
These are two cats lying on a pink couch.
```

You can also use LLaVA in Python:

```python
from generate import load_model, prepare_inputs, generate_text

processor, model = load_model("llava-hf/llava-1.5-7b-hf")

max_tokens, temperature = 128, 0.0

prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
image = "http://images.cocodataset.org/val2017/000000039769.jpg"
input_ids, pixel_values = prepare_inputs(processor, image, prompt)

reply = generate_text(
    input_ids, pixel_values, model, processor, max_tokens, temperature
)

print(reply)
```
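
Since `prepare_inputs` accepts either a URL or a local path (strings are
routed through `load_image`), the same session works with an image on disk.
A minimal sketch, continuing from the example above and assuming a
hypothetical local file `cats.jpg`:

```python
# "cats.jpg" is a hypothetical local file; any image PIL can open works.
prompt = "USER: <image>\nDescribe the image.\nASSISTANT:"
input_ids, pixel_values = prepare_inputs(processor, "cats.jpg", prompt)

# A temperature of 0.0 makes sampling greedy (argmax), so output is deterministic.
reply = generate_text(input_ids, pixel_values, model, processor, 128, 0.0)
print(reply)
```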

[^1]:
Refer to [LLaVA project webpage](https://llava-vl.github.io/) for more
information.
130 changes: 130 additions & 0 deletions llava/generate.py
@@ -0,0 +1,130 @@
# Copyright © 2024 Apple Inc.

import argparse
import codecs
from pathlib import Path

import mlx.core as mx
import requests
from PIL import Image
from transformers import AutoProcessor

from llava import LlavaModel


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Generate text from an image using a model."
    )
    parser.add_argument(
        "--model",
        type=str,
        default="llava-hf/llava-1.5-7b-hf",
        help="The path to the local model directory or Hugging Face repo.",
    )
    parser.add_argument(
        "--image",
        type=str,
        default="http://images.cocodataset.org/val2017/000000039769.jpg",
        help="URL or path of the image to process.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="USER: <image>\nWhat are these?\nASSISTANT:",
        help="Message to be processed by the model.",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=100,
        help="Maximum number of tokens to generate.",
    )
    parser.add_argument(
        "--temp", type=float, default=0.3, help="Temperature for sampling."
    )
    return parser.parse_args()


def load_image(image_source):
    """
    Helper function to load an image from either a URL or file.
    """
    if image_source.startswith(("http://", "https://")):
        try:
            response = requests.get(image_source, stream=True)
            response.raise_for_status()
            return Image.open(response.raw)
        except Exception as e:
            raise ValueError(
                f"Failed to load image from URL: {image_source} with error {e}"
            )
    elif Path(image_source).is_file():
        try:
            return Image.open(image_source)
        except IOError as e:
            raise ValueError(f"Failed to load image {image_source} with error: {e}")
    else:
        raise ValueError(
            f"The image {image_source} must be a valid URL or existing file."
        )


def prepare_inputs(processor, image, prompt):
    if isinstance(image, str):
        image = load_image(image)
    inputs = processor(prompt, image, return_tensors="np")
    pixel_values = mx.array(inputs["pixel_values"])
    input_ids = mx.array(inputs["input_ids"])
    return input_ids, pixel_values


def load_model(model_path):
    processor = AutoProcessor.from_pretrained(model_path)
    model = LlavaModel.from_pretrained(model_path)
    return processor, model


def sample(logits, temperature=0.0):
    if temperature == 0:
        return mx.argmax(logits, axis=-1)
    else:
        return mx.random.categorical(logits * (1 / temperature))


def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature):
    # Prefill: run the full prompt (text and image) once to build the KV cache.
    logits, cache = model(input_ids, pixel_values)
    logits = logits[:, -1, :]
    y = sample(logits, temperature=temperature)
    tokens = [y.item()]

    # Decode autoregressively, feeding back one token at a time with the cache.
    for n in range(max_tokens - 1):
        logits, cache = model.language_model(y[None], cache=cache)
        logits = logits[:, -1, :]
        y = sample(logits, temperature)
        token = y.item()
        if token == processor.tokenizer.eos_token_id:
            break
        tokens.append(token)

    return processor.tokenizer.decode(tokens)


def main():
    args = parse_arguments()
    processor, model = load_model(args.model)

    # Unescape sequences like "\n" passed literally from the shell.
    prompt = codecs.decode(args.prompt, "unicode_escape")

    input_ids, pixel_values = prepare_inputs(processor, args.image, prompt)

    print(prompt)
    generated_text = generate_text(
        input_ids, pixel_values, model, processor, args.max_tokens, args.temp
    )
    print(generated_text)


if __name__ == "__main__":
    main()