
Commit 170c3d7

Merge pull request exo-explore#88 from varshith15/main
Support for LLaVA
2 parents: c1fdb98 + f548ef4

20 files changed: +1083 −97 lines

.gitignore

+2
@@ -2,6 +2,7 @@ __pycache__/
 .venv
 test_weights.npz
 .exo_used_ports
+.idea

 # Byte-compiled / optimized / DLL files
 __pycache__/
@@ -82,6 +83,7 @@ target/

 # Jupyter Notebook
 .ipynb_checkpoints
+Untitled.ipynb

 # IPython
 profile_default/

README.md

+29 −3
@@ -27,7 +27,7 @@ Forget expensive NVIDIA GPUs, unify your existing devices into one powerful GPU:
 <div align="center">
   <h2>Update: Exo Supports Llama 3.1</h2>
   <p>Now the default models, run 8B, 70B and 405B parameter models on your own devices</p>
-  <p><a href="https://github.com/exo-explore/exo/blob/main/exo/inference/mlx/models/sharded_llama.py">See the code</a></p>
+  <p><a href="https://github.com/exo-explore/exo/blob/main/exo/inference/mlx/models/llama.py">See the code</a></p>
 </div>

 ## Get Involved
@@ -40,7 +40,7 @@ We also welcome contributions from the community. We have a list of bounties in

 ### Wide Model Support

-exo supports LLaMA ([MLX](exo/inference/mlx/models/sharded_llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)) and other popular models.
+exo supports LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)) and other popular models.

 ### Dynamic Model Partitioning

@@ -111,7 +111,7 @@ The native way to access models running on exo is using the exo library with pee

 exo starts a ChatGPT-like WebUI (powered by [tinygrad tinychat](https://github.com/tinygrad/tinygrad/tree/master/examples/tinychat)) on http://localhost:8000

-For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000/v1/chat/completions. Example with curl:
+For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000/v1/chat/completions. Examples with curl:

 ```sh
 curl http://localhost:8000/v1/chat/completions \
@@ -123,6 +123,32 @@ curl http://localhost:8000/v1/chat/completions \
   }'
 ```

+```sh
+curl http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+     "model": "llava-1.5-7b-hf",
+     "messages": [
+       {
+         "role": "user",
+         "content": [
+           {
+             "type": "text",
+             "text": "What are these?"
+           },
+           {
+             "type": "image_url",
+             "image_url": {
+               "url": "http://images.cocodataset.org/val2017/000000039769.jpg"
+             }
+           }
+         ]
+       }
+     ],
+     "temperature": 0.0
+   }'
+```
+
 ## Debugging

 Enable debug logs with the DEBUG environment variable (0-9).
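
For reference, the same LLaVA request added above can also be sent from Python. This is a minimal sketch, assuming an exo node is serving the default http://localhost:8000 endpoint, that the `requests` package is installed, and that the response follows the usual OpenAI-style `choices[0].message.content` shape; none of that is shown in this commit beyond the curl example.

```python
# Minimal sketch: POST the LLaVA vision request from Python instead of curl.
# Assumes an exo node is serving the ChatGPT-compatible API on localhost:8000.
import requests

resp = requests.post(
  "http://localhost:8000/v1/chat/completions",
  json={
    "model": "llava-1.5-7b-hf",
    "messages": [
      {
        "role": "user",
        "content": [
          {"type": "text", "text": "What are these?"},
          {"type": "image_url", "image_url": {"url": "http://images.cocodataset.org/val2017/000000039769.jpg"}},
        ],
      }
    ],
    "temperature": 0.0,
  },
  timeout=300,
)
resp.raise_for_status()
# Assumes the standard ChatGPT-style response payload.
print(resp.json()["choices"][0]["message"]["content"])
```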

exo/api/chatgpt_api.py

+69 −12
@@ -3,7 +3,7 @@
 import asyncio
 import json
 from pathlib import Path
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoProcessor
 from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
@@ -42,11 +42,15 @@
   "deepseek-coder-v2-lite": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),
   },
+  ### llava
+  "llava-1.5-7b-hf": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),
+  },
 }


 class Message:
-  def __init__(self, role: str, content: str):
+  def __init__(self, role: str, content: Union[str, list]):
     self.role = role
     self.content = content

@@ -68,6 +72,18 @@ def resolve_tinygrad_tokenizer(model_id: str):


 async def resolve_tokenizer(model_id: str):
+  try:
+    if DEBUG >= 2: print(f"Trying AutoProcessor for {model_id}")
+    processor = AutoProcessor.from_pretrained(model_id)
+    processor.eos_token_id = processor.tokenizer.eos_token_id
+    processor.encode = processor.tokenizer.encode
+    return processor
+  except Exception as e:
+    if DEBUG >= 2: print(f"Failed to load processor for {model_id}. Error: {e}")
+    import traceback
+
+    if DEBUG >= 2: print(traceback.format_exc())
+
   try:
     if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}")
     return AutoTokenizer.from_pretrained(model_id)
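
As a quick illustration of what this new fallback chain yields for a LLaVA model, here is a hedged usage sketch. It assumes the module is importable as `exo.api.chatgpt_api` and that the processor files for `llava-hf/llava-1.5-7b-hf` can be downloaded or are already cached; the class name in the comment is indicative only.

```python
# Usage sketch: resolve_tokenizer tries AutoProcessor first, so multimodal
# models come back as a processor patched to look like a tokenizer.
import asyncio
from exo.api.chatgpt_api import resolve_tokenizer

tok = asyncio.run(resolve_tokenizer("llava-hf/llava-1.5-7b-hf"))
print(type(tok).__name__)       # a processor class (e.g. LlavaProcessor), not a tokenizer
print(tok.eos_token_id)         # patched through from tok.tokenizer.eos_token_id
print(tok.encode("hello")[:3])  # patched through from tok.tokenizer.encode
```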
@@ -137,8 +153,50 @@ def generate_completion(
   return completion


-def build_prompt(tokenizer, messages: List[Message]):
-  return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+def remap_messages(messages: List[Message]) -> List[Message]:
+  remapped_messages = []
+  last_image = None
+  for message in messages:
+    remapped_content = []
+    for content in message.content:
+      if isinstance(content, dict):
+        if content.get("type") in ["image_url", "image"]:
+          image_url = content.get("image_url", {}).get("url") or content.get("image")
+          if image_url:
+            last_image = {"type": "image", "image": image_url}
+          remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"})
+        else:
+          remapped_content.append(content)
+      else:
+        remapped_content.append({"type": "text", "text": content})
+    remapped_messages.append(Message(role=message.role, content=remapped_content))
+
+  if last_image:
+    # Replace the last image placeholder with the actual image content
+    for message in reversed(remapped_messages):
+      for i, content in enumerate(message.content):
+        if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
+          message.content[i] = last_image
+          return remapped_messages
+
+  return remapped_messages
+
+def build_prompt(tokenizer, _messages: List[Message]):
+  messages = remap_messages(_messages)
+  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+  image_str = None
+  for message in messages:
+    if not isinstance(message.content, list):
+      continue
+
+    for content in message.content:
+      # note: we only support one image at a time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41
+      # follows the convention in https://platform.openai.com/docs/guides/vision
+      if content.get("type", None) == "image":
+        image_str = content.get("image", None)
+        break
+
+  return prompt, image_str


 def parse_message(data: dict):
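
A hedged usage sketch of the new remapping, assuming the module is importable as `exo.api.chatgpt_api`; the message payload below is hypothetical and mirrors the README example. Every `image_url` entry is rewritten to `{"type": "image", "image": <url>}`, and only the most recent image survives in place; any earlier image entries are left as the text placeholder so a single image reaches the model.

```python
# Sketch: what remap_messages / build_prompt do to an OpenAI-style vision message.
from exo.api.chatgpt_api import Message, remap_messages

msgs = [
  Message(role="user", content=[
    {"type": "text", "text": "What are these?"},
    {"type": "image_url", "image_url": {"url": "http://images.cocodataset.org/val2017/000000039769.jpg"}},
  ]),
]

remapped = remap_messages(msgs)
# The image_url entry becomes {"type": "image", "image": "<url>"} and, being the
# last image seen, it is kept; earlier images would remain as the placeholder text.
print(remapped[0].content)
# build_prompt(tokenizer, msgs) then returns (prompt, image_str), where image_str
# is that URL and prompt comes from tokenizer.apply_chat_template.
```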
@@ -160,7 +218,7 @@ def __init__(self, node: Node, inference_engine_classname: str, response_timeout
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout_secs = response_timeout_secs
-    self.app = web.Application()
+    self.app = web.Application(client_max_size=100 * 1024 * 1024)  # 100MB to support image upload
     self.prev_token_lens: Dict[str, int] = {}
     self.stream_tasks: Dict[str, asyncio.Task] = {}
     cors = aiohttp_cors.setup(self.app)
@@ -187,15 +245,14 @@ async def middleware(request):
     return middleware

   async def handle_root(self, request):
-    print(f"Handling root request from {request.remote}")
     return web.FileResponse(self.static_dir / "index.html")

   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
     shard = shard_mappings.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
     tokenizer = await resolve_tokenizer(shard.model_id)
-    return web.json_response({"length": len(build_prompt(tokenizer, messages))})
+    return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})

   async def handle_post_chat_completions(self, request):
     data = await request.json()
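
The `[0]` in handle_post_chat_token_encode is needed because build_prompt now returns a tuple rather than a bare string; a tiny sketch with a hypothetical return value:

```python
# Why the [0] matters: build_prompt now returns (prompt, image_str).
result = ("<s>[INST] What are these? [/INST]", None)  # hypothetical return value
print(len(result))      # 2  -> the tuple length, which the old code would have reported
print(len(result[0]))   # 33 -> the prompt length, as intended
```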
@@ -219,13 +276,13 @@ async def handle_post_chat_completions(self, request):
     tokenizer = await resolve_tokenizer(shard.model_id)
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")

-    prompt = build_prompt(tokenizer, chat_request.messages)
+    prompt, image_str = build_prompt(tokenizer, chat_request.messages)
     callback_id = f"chatgpt-api-wait-response-{request_id}"
     callback = self.node.on_token.register(callback_id)

-    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
+    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}")
     try:
-      await self.node.process_prompt(shard, prompt, request_id=request_id)
+      await self.node.process_prompt(shard, prompt, image_str, request_id=request_id)
     except Exception as e:
       if DEBUG >= 2:
         import traceback
@@ -252,7 +309,7 @@ async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
         self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
         new_tokens = tokens[prev_last_tokens_len:]
         finish_reason = None
-        eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
+        eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer, AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
         if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
           new_tokens = new_tokens[:-1]
           if is_finished:
@@ -294,7 +351,7 @@ def on_result(_request_id: str, tokens: List[int], is_finished: bool):
       )

       finish_reason = "length"
-      eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
+      eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
       if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
       if tokens[-1] == eos_token_id:
         tokens = tokens[:-1]
