Merge pull request #88 from varshith15/main
Support for LLaVA
AlexCheema authored Jul 30, 2024
2 parents 1426826 + af1c7ce commit 0ec77e1
Showing 20 changed files with 1,083 additions and 97 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -2,6 +2,7 @@ __pycache__/
.venv
test_weights.npz
.exo_used_ports
.idea

# Byte-compiled / optimized / DLL files
__pycache__/
@@ -82,6 +83,7 @@ target/

# Jupyter Notebook
.ipynb_checkpoints
Untitled.ipynb

# IPython
profile_default/
32 changes: 29 additions & 3 deletions README.md
@@ -27,7 +27,7 @@ Forget expensive NVIDIA GPUs, unify your existing devices into one powerful GPU:
<div align="center">
<h2>Update: Exo Supports Llama 3.1</h2>
<p>Now the default models, run 8B, 70B and 405B parameter models on your own devices</p>
<p><a href="https://github.com/exo-explore/exo/blob/main/exo/inference/mlx/models/sharded_llama.py">See the code</a></p>
<p><a href="https://github.com/exo-explore/exo/blob/main/exo/inference/mlx/models/llama.py">See the code</a></p>
</div>

## Get Involved
@@ -40,7 +40,7 @@ We also welcome contributions from the community. We have a list of bounties in

### Wide Model Support

exo supports LLaMA ([MLX](exo/inference/mlx/models/sharded_llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)) and other popular models.
exo supports LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)) and other popular models.

### Dynamic Model Partitioning

@@ -111,7 +111,7 @@ The native way to access models running on exo is using the exo library with pee

exo starts a ChatGPT-like WebUI (powered by [tinygrad tinychat](https://github.com/tinygrad/tinygrad/tree/master/examples/tinychat)) on http://localhost:8000

For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000/v1/chat/completions. Example with curl:
For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000/v1/chat/completions. Examples with curl:

```sh
curl http://localhost:8000/v1/chat/completions \
@@ -123,6 +123,32 @@ curl http://localhost:8000/v1/chat/completions \
}'
```

```sh
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "llava-1.5-7b-hf",
    "messages": [
      {
        "role": "user",
        "content": [
          {
            "type": "text",
            "text": "What are these?"
          },
          {
            "type": "image_url",
            "image_url": {
              "url": "http://images.cocodataset.org/val2017/000000039769.jpg"
            }
          }
        ]
      }
    ],
    "temperature": 0.0
  }'
```
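
The same request can be sent from Python. The sketch below is illustrative rather than part of exo: it assumes the `requests` package, an exo instance serving on localhost:8000, and a reply in the usual ChatGPT-style response shape; the payload mirrors the curl example above.

```python
import requests

# Mirrors the LLaVA curl example above (illustrative sketch).
response = requests.post(
  "http://localhost:8000/v1/chat/completions",
  json={
    "model": "llava-1.5-7b-hf",
    "messages": [{
      "role": "user",
      "content": [
        {"type": "text", "text": "What are these?"},
        {"type": "image_url", "image_url": {"url": "http://images.cocodataset.org/val2017/000000039769.jpg"}},
      ],
    }],
    "temperature": 0.0,
  },
)
# Assumes the standard ChatGPT-style response shape.
print(response.json()["choices"][0]["message"]["content"])
```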

## Debugging

Enable debug logs with the DEBUG environment variable (0-9).
81 changes: 69 additions & 12 deletions exo/api/chatgpt_api.py
@@ -3,7 +3,7 @@
import asyncio
import json
from pathlib import Path
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoProcessor
from typing import List, Literal, Union, Dict
from aiohttp import web
import aiohttp_cors
@@ -42,11 +42,15 @@
"deepseek-coder-v2-lite": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),
},
### llava
"llava-1.5-7b-hf": {
"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),
},
}
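
# Illustrative lookup (mirrors how handle_post_chat_completions resolves a model name below):
#   shard_mappings["llava-1.5-7b-hf"]["MLXDynamicShardInferenceEngine"].model_id
#   # -> "llava-hf/llava-1.5-7b-hf" (n_layers=32)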


class Message:
  def __init__(self, role: str, content: str):
  def __init__(self, role: str, content: Union[str, list]):
    self.role = role
    self.content = content

@@ -68,6 +72,18 @@ def resolve_tinygrad_tokenizer(model_id: str):


async def resolve_tokenizer(model_id: str):
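  # LLaVA-style models ship an AutoProcessor (tokenizer plus image processor); try that first
  # and fall back to a plain AutoTokenizer for text-only models.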
  try:
    if DEBUG >= 2: print(f"Trying AutoProcessor for {model_id}")
    processor = AutoProcessor.from_pretrained(model_id)
    processor.eos_token_id = processor.tokenizer.eos_token_id
    processor.encode = processor.tokenizer.encode
    return processor
  except Exception as e:
    if DEBUG >= 2: print(f"Failed to load processor for {model_id}. Error: {e}")
    import traceback

    if DEBUG >= 2: print(traceback.format_exc())

  try:
    if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}")
    return AutoTokenizer.from_pretrained(model_id)
@@ -137,8 +153,50 @@ def generate_completion(
  return completion


def build_prompt(tokenizer, messages: List[Message]):
  return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
def remap_messages(messages: List[Message]) -> List[Message]:
  remapped_messages = []
  last_image = None
  for message in messages:
    remapped_content = []
    for content in message.content:
      if isinstance(content, dict):
        if content.get("type") in ["image_url", "image"]:
          image_url = content.get("image_url", {}).get("url") or content.get("image")
          if image_url:
            last_image = {"type": "image", "image": image_url}
          remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"})
        else:
          remapped_content.append(content)
      else:
        remapped_content.append({"type": "text", "text": content})
    remapped_messages.append(Message(role=message.role, content=remapped_content))

  if last_image:
    # Replace the last image placeholder with the actual image content
    for message in reversed(remapped_messages):
      for i, content in enumerate(message.content):
        if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
          message.content[i] = last_image
          return remapped_messages

  return remapped_messages

def build_prompt(tokenizer, _messages: List[Message]):
  messages = remap_messages(_messages)
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
  image_str = None
  for message in messages:
    if not isinstance(message.content, list):
      continue

    for content in message.content:
      # note: we only support one image at a time right now. Supporting multiple images is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41
      # follows the convention in https://platform.openai.com/docs/guides/vision
      if content.get("type", None) == "image":
        image_str = content.get("image", None)
        break

  return prompt, image_str
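
# Illustration only (values borrowed from the README example above): remap_messages plus
# build_prompt turn an OpenAI-style vision message into a (prompt, image_str) pair, which is
# exactly how handle_post_chat_completions uses them below.
#
#   messages = [Message(role="user", content=[
#     {"type": "text", "text": "What are these?"},
#     {"type": "image_url", "image_url": {"url": "http://images.cocodataset.org/val2017/000000039769.jpg"}},
#   ])]
#   prompt, image_str = build_prompt(tokenizer, messages)
#   # prompt: the chat-templated text for the remapped messages
#   # image_str: the URL (or inline image value) of the most recent image, or None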


def parse_message(data: dict):
@@ -160,7 +218,7 @@ def __init__(self, node: Node, inference_engine_classname: str, response_timeout
    self.node = node
    self.inference_engine_classname = inference_engine_classname
    self.response_timeout_secs = response_timeout_secs
    self.app = web.Application()
    self.app = web.Application(client_max_size=100 * 1024 * 1024)  # 100MB to support image upload
    self.prev_token_lens: Dict[str, int] = {}
    self.stream_tasks: Dict[str, asyncio.Task] = {}
    cors = aiohttp_cors.setup(self.app)
@@ -187,15 +245,14 @@ async def middleware(request):
    return middleware

  async def handle_root(self, request):
    print(f"Handling root request from {request.remote}")
    return web.FileResponse(self.static_dir / "index.html")

  async def handle_post_chat_token_encode(self, request):
    data = await request.json()
    shard = shard_mappings.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
    messages = [parse_message(msg) for msg in data.get("messages", [])]
    tokenizer = await resolve_tokenizer(shard.model_id)
    return web.json_response({"length": len(build_prompt(tokenizer, messages))})
    return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})

  async def handle_post_chat_completions(self, request):
    data = await request.json()
@@ -219,13 +276,13 @@ async def handle_post_chat_completions(self, request):
    tokenizer = await resolve_tokenizer(shard.model_id)
    if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")

    prompt = build_prompt(tokenizer, chat_request.messages)
    prompt, image_str = build_prompt(tokenizer, chat_request.messages)
    callback_id = f"chatgpt-api-wait-response-{request_id}"
    callback = self.node.on_token.register(callback_id)

    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}")
    try:
      await self.node.process_prompt(shard, prompt, request_id=request_id)
      await self.node.process_prompt(shard, prompt, image_str, request_id=request_id)
    except Exception as e:
      if DEBUG >= 2:
        import traceback
@@ -252,7 +309,7 @@ async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
      self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
      new_tokens = tokens[prev_last_tokens_len:]
      finish_reason = None
      eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
      eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer, AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
      if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
        new_tokens = new_tokens[:-1]
        if is_finished:
@@ -294,7 +351,7 @@ def on_result(_request_id: str, tokens: List[int], is_finished: bool):
      )

      finish_reason = "length"
      eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
      eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
      if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
      if tokens[-1] == eos_token_id:
        tokens = tokens[:-1]