
Commit 2d17066

merging files

bayedieng committed Sep 11, 2024
2 parents 12ae92c + e0ed917

Showing 23 changed files with 355 additions and 201 deletions.
6 changes: 3 additions & 3 deletions .circleci/config.yml
@@ -17,11 +17,11 @@ commands:
source env/bin/activate
# Start first instance
HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --inference-engine <<parameters.inference_engine>> --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout-secs 900 > output1.log 2>&1 &
HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --inference-engine <<parameters.inference_engine>> --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 2>&1 | tee output1.log &
PID1=$!
# Start second instance
HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --inference-engine <<parameters.inference_engine>> --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout-secs 900 > output2.log 2>&1 &
HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --inference-engine <<parameters.inference_engine>> --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout 900 2>&1 | tee output2.log &
PID2=$!
# Wait for discovery
@@ -144,7 +144,7 @@ jobs:
PID2=$!
sleep 10
kill $PID1 $PID2
if grep -q "Connected to peer" output1.log && grep -q "Connected to peer" output2.log; then
if grep -q "Successfully connected peers: \['node2@.*:.*'\]" output1.log && ! grep -q "Failed to connect peers:" output1.log && grep -q "Successfully connected peers: \['node1@.*:.*'\]" output2.log && ! grep -q "Failed to connect peers:" output2.log; then
echo "Test passed: Both instances discovered each other"
exit 0
else
4 changes: 1 addition & 3 deletions README.md
@@ -52,7 +52,7 @@ exo will [automatically discover](https://github.com/exo-explore/exo/blob/945f90

### ChatGPT-compatible API

exo provides a [ChatGPT-compatible API](exo/api/chatgpt_api.py) for running models. It's a [one-line change](examples/chatgpt_api.py) in your application to run models on your own hardware using exo.
exo provides a [ChatGPT-compatible API](exo/api/chatgpt_api.py) for running models. It's a [one-line change](examples/chatgpt_api.sh) in your application to run models on your own hardware using exo.

### Device Equality

@@ -108,8 +108,6 @@ python3 main.py

That's it! No configuration required - exo will automatically discover the other device(s).

The native way to access models running on exo is using the exo library with peer handles. See how in [this example for Llama 3](examples/llama3_distributed.py).

exo starts a ChatGPT-like WebUI (powered by [tinygrad tinychat](https://github.com/tinygrad/tinygrad/tree/master/examples/tinychat)) on http://localhost:8000

For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000/v1/chat/completions. Example with curl:
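
An equivalent request from Python (a minimal sketch, not taken from the repo; it assumes the default port 8000, the "llama-3-8b" model id listed in exo/models.py, and the standard chat-completions response shape):

import json
import urllib.request

payload = {
    "model": "llama-3-8b",
    "messages": [{"role": "user", "content": "What is the meaning of exo?"}],
    "temperature": 0.7,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/chat/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["choices"][0]["message"]["content"])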
10 changes: 5 additions & 5 deletions exo/api/chatgpt_api.py
@@ -153,10 +153,10 @@ def __init__(self, request_id: str, timestamp: int, prompt: str):


class ChatGPTAPI:
def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
self.node = node
self.inference_engine_classname = inference_engine_classname
self.response_timeout_secs = response_timeout_secs
self.response_timeout = response_timeout
self.on_chat_completion_request = on_chat_completion_request
self.app = web.Application(client_max_size=100*1024*1024) # 100MB to support image upload
self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
@@ -257,7 +257,7 @@ async def handle_post_chat_completions(self, request):
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)

try:
if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout_secs}s")
if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")

if stream:
response = web.StreamResponse(
@@ -306,7 +306,7 @@ def on_result(_request_id: str, tokens: List[int], is_finished: bool):

return _request_id == request_id and is_finished

_, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout_secs)
_, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout)
if request_id in self.stream_tasks: # in case there is still a stream task running, wait for it to complete
if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
try:
@@ -318,7 +318,7 @@ def on_result(_request_id: str, tokens: List[int], is_finished: bool):
else:
_, tokens, _ = await callback.wait(
lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
timeout=self.response_timeout_secs,
timeout=self.response_timeout,
)

finish_reason = "length"
Empty file added exo/download/__init__.py
Empty file.
Empty file added exo/download/hf/__init__.py
Empty file.
78 changes: 42 additions & 36 deletions exo/download/hf/hf_helpers.py
@@ -200,6 +200,36 @@ async def download_file(
if DEBUG >= 2: print(f"Downloaded: {file_path}")


async def resolve_revision_to_commit_hash(repo_id: str, revision: str) -> str:
repo_root = get_repo_root(repo_id)
refs_dir = repo_root/"refs"
refs_file = refs_dir/revision

# Check if we have a cached commit hash
if await aios.path.exists(refs_file):
async with aiofiles.open(refs_file, 'r') as f:
commit_hash = (await f.read()).strip()
if DEBUG >= 2: print(f"Commit hash is already cached at {refs_file}: {commit_hash}")
return commit_hash

# Fetch the commit hash for the given revision
async with aiohttp.ClientSession() as session:
api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
headers = await get_auth_headers()
async with session.get(api_url, headers=headers) as response:
if response.status != 200:
raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
revision_info = await response.json()
commit_hash = revision_info['sha']

# Cache the commit hash
await aios.makedirs(refs_dir, exist_ok=True)
async with aiofiles.open(refs_file, 'w') as f:
await f.write(commit_hash)

return commit_hash


async def download_repo_files(
repo_id: str,
revision: str = "main",
@@ -209,35 +239,15 @@ async def download_repo_files(
max_parallel_downloads: int = 4
) -> Path:
repo_root = get_repo_root(repo_id)
refs_dir = repo_root/"refs"
snapshots_dir = repo_root/"snapshots"
cachedreqs_dir = repo_root/"cachedreqs"

# Ensure directories exist
await aios.makedirs(refs_dir, exist_ok=True)
await aios.makedirs(snapshots_dir, exist_ok=True)
await aios.makedirs(cachedreqs_dir, exist_ok=True)

# Check if we have a cached commit hash
refs_file = refs_dir/revision
if await aios.path.exists(refs_file):
async with aiofiles.open(refs_file, 'r') as f:
commit_hash = (await f.read()).strip()
if DEBUG >= 2: print(f"Commit hash is already hashed at {refs_file}: {commit_hash}")
else:
async with aiohttp.ClientSession() as session:
# Fetch the commit hash for the given revision
api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
headers = await get_auth_headers()
async with session.get(api_url, headers=headers) as response:
if response.status != 200:
raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
revision_info = await response.json()
commit_hash = revision_info['sha']

# Cache the commit hash
async with aiofiles.open(refs_file, 'w') as f:
await f.write(commit_hash)
# Resolve revision to commit hash
commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)

# Set up the snapshot directory
snapshot_dir = snapshots_dir/commit_hash
@@ -357,7 +367,8 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[

# Check if the file exists
repo_root = get_repo_root(repo_id)
snapshot_dir = repo_root/"snapshots"
commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
snapshot_dir = repo_root/"snapshots"/commit_hash
index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)

if index_file:
@@ -380,24 +391,19 @@ def extract_layer_num(tensor_name: str) -> Optional[int]:


def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
default_patterns = [
"*.json",
"*.py",
"tokenizer.model",
"*.tiktoken",
"*.txt",
]
shard_specific_patterns = []
default_patterns = set(["*.json","*.py","tokenizer.model","*.tiktoken","*.txt"])
shard_specific_patterns = set()
if weight_map:
for tensor_name, filename in weight_map.items():
layer_num = extract_layer_num(tensor_name)
if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer:
shard_specific_patterns.append(filename)
shard_specific_patterns.add(filename)
sorted_file_names = sorted(weight_map.values())
if shard.is_first_layer():
shard_specific_patterns.append(sorted_file_names[0])
shard_specific_patterns.add(sorted_file_names[0])
elif shard.is_last_layer():
shard_specific_patterns.append(sorted_file_names[-1])
shard_specific_patterns.add(sorted_file_names[-1])
else:
shard_specific_patterns = ["*.safetensors", "*Q4_K_M*.gguf"]
return list(set(default_patterns + shard_specific_patterns)) # Remove duplicates
shard_specific_patterns = ["*.safetensors"]
if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
return list(default_patterns | shard_specific_patterns)
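
A quick usage sketch of the set-based pattern logic above (the weight map here is hypothetical, and the Shard import path is assumed to be exo.inference.shard):

from exo.download.hf.hf_helpers import get_allow_patterns
from exo.inference.shard import Shard  # assumed import path

# Hypothetical two-file weight map for a 32-layer model.
weight_map = {
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
    "model.layers.31.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
}

# A shard covering the first half of the model should match only the first file
# (plus the default json/tokenizer patterns); set order is not guaranteed.
shard = Shard(model_id="example", start_layer=0, end_layer=15, n_layers=32)
print(get_allow_patterns(weight_map, shard))
# e.g. ['model-00001-of-00002.safetensors', '*.json', '*.py', 'tokenizer.model', '*.tiktoken', '*.txt']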
3 changes: 2 additions & 1 deletion exo/helpers.py
@@ -8,6 +8,7 @@
import uuid
import netifaces
from pathlib import Path
import tempfile

DEBUG = int(os.getenv("DEBUG", default="0"))
DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
@@ -34,7 +35,7 @@ def get_system_info():


def find_available_port(host: str = "", min_port: int = 49152, max_port: int = 65535) -> int:
used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".exo_used_ports")
used_ports_file = os.path.join(tempfile.gettempdir(), "exo_used_ports")

def read_used_ports():
if os.path.exists(used_ports_file):
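
The used-ports file now lives in the system temp directory instead of inside the package. A two-line check of where that resolves (the exact path is platform-dependent):

import os, tempfile
print(os.path.join(tempfile.gettempdir(), "exo_used_ports"))  # e.g. /tmp/exo_used_ports on Linux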
24 changes: 16 additions & 8 deletions exo/inference/mlx/sharded_inference_engine.py
@@ -6,35 +6,43 @@
from ..shard import Shard
from typing import Optional
from exo.download.shard_download import ShardDownloader

import asyncio
from concurrent.futures import ThreadPoolExecutor

class MLXDynamicShardInferenceEngine(InferenceEngine):
def __init__(self, shard_downloader: ShardDownloader):
self.shard = None
self.shard_downloader = shard_downloader
self.executor = ThreadPoolExecutor(max_workers=1)

async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
await self.ensure_shard(shard)
loop = asyncio.get_running_loop()
if image_str:
image = await get_image_from_str(image_str)
inputs = self.tokenizer(prompt, image, return_tensors="np")
inputs = await loop.run_in_executor(self.executor, self.tokenizer, prompt, image, return_tensors="np")
pixel_values = mx.array(inputs["pixel_values"])
input_ids = mx.array(inputs["input_ids"])
output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, input_ids, pixel_values))
output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values))
else:
output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(self.tokenizer.encode(prompt))))
input_ids = mx.array(await loop.run_in_executor(self.executor, self.tokenizer.encode, prompt))
output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids))
return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id

async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
await self.ensure_shard(shard)
output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(input_data)))
output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, mx.array(input_data)))
return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id

async def ensure_shard(self, shard: Shard):
if self.shard == shard:
return

model_path = await self.shard_downloader.ensure_shard(shard)
model_shard, self.tokenizer = await load_shard(model_path, shard)
self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
self.shard = shard

if self.shard != shard:
loop = asyncio.get_running_loop()
def load_shard_wrapper(): return asyncio.run(load_shard(model_path, shard))
model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
self.stateful_sharded_model = await loop.run_in_executor(self.executor, StatefulShardedModel, shard, model_shard)
self.shard = shard
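
The recurring pattern in this file is to hand blocking tokenizer and model calls to a single-worker ThreadPoolExecutor so the asyncio event loop stays responsive. In isolation, the pattern looks like this (a generic sketch, not exo code):

import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)  # one worker serialises access to non-thread-safe model state

def blocking_step(x: int) -> int:
    time.sleep(0.1)  # stand-in for a tokenizer or model call
    return x * 2

async def main():
    loop = asyncio.get_running_loop()
    # Other coroutines keep running while the worker thread does the heavy call.
    result = await loop.run_in_executor(executor, blocking_step, 21)
    print(result)  # 42

asyncio.run(main())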
2 changes: 1 addition & 1 deletion exo/inference/mlx/sharded_model.py
@@ -8,7 +8,7 @@

from ..shard import Shard


# TODO: support a speculative model so we can parallelise compute across devices
class StatefulShardedModel:
def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2):
self.shard = shard
21 changes: 15 additions & 6 deletions exo/inference/tinygrad/inference.py
@@ -12,6 +12,10 @@
import numpy as np
from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
from exo.download.shard_download import ShardDownloader
from concurrent.futures import ThreadPoolExecutor
import asyncio
import threading
from functools import partial

Tensor.no_grad = True
# default settings
@@ -52,14 +56,15 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
def __init__(self, shard_downloader: ShardDownloader):
self.shard = None
self.shard_downloader = shard_downloader
self.executor = ThreadPoolExecutor(max_workers=1)

async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
await self.ensure_shard(shard)
start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)

toks = self.tokenizer.encode(prompt)
h = self.model(Tensor([toks]), start_pos, TEMPERATURE).realize()
toks = await asyncio.get_event_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor([toks]), start_pos, TEMPERATURE).realize())

if h.shape == (1,):
start_pos += len(toks)
@@ -75,7 +80,7 @@ async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarr
start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)

h = self.model(Tensor(input_data), start_pos, TEMPERATURE).realize()
h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), start_pos, TEMPERATURE).realize())

if h.shape == (1,):
start_pos += n_captured_toks
@@ -90,6 +95,10 @@ async def ensure_shard(self, shard: Shard):
return

model_path = await self.shard_downloader.ensure_shard(shard)
self.model = build_transformer(model_path, shard, model_size="8B" if "8b" in shard.model_id.lower() else "70B")
self.tokenizer = await resolve_tokenizer(str((model_path if model_path.is_dir() else model_path.parent)))
self.shard = shard

if self.shard != shard:
self.model = await asyncio.get_event_loop().run_in_executor(self.executor, build_transformer, model_path, shard, "8B" if "8b" in shard.model_id.lower() else "70B")

tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
self.tokenizer = await resolve_tokenizer(tokenizer_path)
self.shard = shard
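
run_in_executor only forwards positional arguments, so calls that need keyword arguments or method chaining are wrapped in a lambda here (functools.partial, imported above, works as well). A generic sketch:

import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial

executor = ThreadPoolExecutor(max_workers=1)

def encode(text: str, add_bos: bool = False) -> list:
    return [ord(c) for c in text]  # stand-in for a real tokenizer

async def main():
    loop = asyncio.get_event_loop()
    via_partial = await loop.run_in_executor(executor, partial(encode, "hi", add_bos=True))
    via_lambda = await loop.run_in_executor(executor, lambda: encode("hi", add_bos=True))
    print(via_partial, via_lambda)

asyncio.run(main())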
9 changes: 6 additions & 3 deletions exo/inference/tokenizers.py
@@ -1,5 +1,8 @@
import traceback
from aiofiles import os as aios
from os import PathLike
from pathlib import Path
from typing import Union
from transformers import AutoTokenizer, AutoProcessor
from exo.download.hf.hf_helpers import get_local_snapshot_dir
from exo.helpers import DEBUG
@@ -8,18 +11,18 @@ async def resolve_tokenizer(model_id: str):
local_path = await get_local_snapshot_dir(model_id)
if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}")
try:
if await aios.path.exists(local_path):
if local_path and await aios.path.exists(local_path):
if DEBUG >= 2: print(f"Resolving tokenizer for {model_id=} from {local_path=}")
return await _resolve_tokenizer(local_path)
except:
if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {model_id=} normally...")
if DEBUG >= 5: traceback.print_exc()
return await _resolve_tokenizer(model_id)

async def _resolve_tokenizer(model_id_or_local_path: str):
async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
try:
if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in model_id_or_local_path else False)
processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in f"{model_id_or_local_path}" else False)
if not hasattr(processor, 'eos_token_id'):
processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
if not hasattr(processor, 'encode'):
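
The f-string wrapper around model_id_or_local_path matters because the argument may now be a Path, and substring membership tests against a PosixPath raise a TypeError. A minimal illustration:

from pathlib import Path

p = Path("/models/Mistral-Large-Instruct-2407")
# "Mistral-Large" in p          # TypeError: argument of type 'PosixPath' is not iterable
print("Mistral-Large" in f"{p}")  # True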
4 changes: 4 additions & 0 deletions exo/models.py
@@ -12,6 +12,10 @@
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
"TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
},
"llama-3.1-70b-bf16": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-bf16", start_layer=0, end_layer=0, n_layers=80),
"TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
},
"llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),},
"llama-3-8b": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
4 changes: 4 additions & 0 deletions exo/networking/grpc/grpc_peer_handle.py
@@ -23,12 +23,16 @@ def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabiliti
def id(self) -> str:
return self._id

def addr(self) -> str:
return self.address

def device_capabilities(self) -> DeviceCapabilities:
return self._device_capabilities

async def connect(self):
self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32*1024*1024)])
self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
await self.channel.channel_ready()

async def is_connected(self) -> bool:
return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY
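
The added channel_ready() await makes connect() resolve only once the channel is actually usable, rather than returning while gRPC is still connecting lazily. The grpc.aio pattern in isolation (a sketch against a generic address):

import grpc

async def connect(address: str) -> grpc.aio.Channel:
    channel = grpc.aio.insecure_channel(address, options=[("grpc.max_metadata_size", 32 * 1024 * 1024)])
    await channel.channel_ready()  # waits until the channel reaches the READY state
    return channel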
(diff for the remaining 10 changed files not shown)
