Merge branch 'main' into HEAD
AlexCheema committed Oct 6, 2024
2 parents f95942f + 7b2a523 commit c3ea732
Showing 10 changed files with 235 additions and 42 deletions.
42 changes: 31 additions & 11 deletions exo/api/chatgpt_api.py
@@ -9,6 +9,7 @@
import aiohttp_cors
import traceback
from exo import DEBUG, VERSION
from exo.download.download_progress import RepoProgressEvent
from exo.helpers import PrefixDict
from exo.inference.shard import Shard
from exo.inference.tokenizers import resolve_tokenizer
@@ -184,14 +185,23 @@ def __init__(self, node: Node, inference_engine_classname: str, response_timeout
cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})

self.static_dir = Path(__file__).parent.parent/"tinychat"
self.static_dir = Path(__file__).parent.parent / "tinychat"
self.app.router.add_get("/", self.handle_root)
self.app.router.add_static("/", self.static_dir, name="static")

# Add middleware to log every request
self.app.middlewares.append(self.timeout_middleware)
self.app.middlewares.append(self.log_request)

async def timeout_middleware(self, app, handler):
async def middleware(request):
try:
return await asyncio.wait_for(handler(request), timeout=self.response_timeout)
except asyncio.TimeoutError:
return web.json_response({"detail": "Request timed out"}, status=408)
return middleware

async def log_request(self, app, handler):
async def middleware(request):
if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
@@ -212,6 +222,16 @@ async def handle_post_chat_token_encode(self, request):
tokenizer = await resolve_tokenizer(shard.model_id)
return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})

async def handle_get_download_progress(self, request):
progress_data = {}
for node_id, progress_event in self.node.node_download_progress.items():
if isinstance(progress_event, RepoProgressEvent):
progress_data[node_id] = progress_event.to_dict()
else:
print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
return web.json_response(progress_data)


async def handle_post_chat_completions(self, request):
data = await request.json()
if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
@@ -257,13 +277,10 @@ async def handle_post_chat_completions(self, request):
callback = self.node.on_token.register(callback_id)

if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}")
try:
await self.node.process_prompt(shard, prompt, image_str, request_id=request_id)
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)

try:
await asyncio.wait_for(self.node.process_prompt(shard, prompt, image_str, request_id=request_id), timeout=self.response_timeout)

if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")

if stream:
@@ -277,9 +294,9 @@
)
await response.prepare(request)

async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
prev_last_tokens_len = self.prev_token_lens.get(request_id, 0)
self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
async def stream_result(_request_id: str, tokens: List[int], is_finished: bool):
prev_last_tokens_len = self.prev_token_lens.get(_request_id, 0)
self.prev_token_lens[_request_id] = max(prev_last_tokens_len, len(tokens))
new_tokens = tokens[prev_last_tokens_len:]
finish_reason = None
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
@@ -309,7 +326,7 @@ async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
if DEBUG >= 2: traceback.print_exc()

def on_result(_request_id: str, tokens: List[int], is_finished: bool):
self.stream_tasks[request_id] = asyncio.create_task(stream_result(request_id, tokens, is_finished))
if _request_id == request_id: self.stream_tasks[_request_id] = asyncio.create_task(stream_result(_request_id, tokens, is_finished))

return _request_id == request_id and is_finished

@@ -338,6 +355,9 @@ def on_result(_request_id: str, tokens: List[int], is_finished: bool):
return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
except asyncio.TimeoutError:
return web.json_response({"detail": "Response generation timed out"}, status=408)
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
finally:
deregistered_callback = self.node.on_token.deregister(callback_id)
if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
15 changes: 9 additions & 6 deletions exo/main.py
@@ -120,12 +120,15 @@ def preemptively_start_download(request_id: str, opaque_status: str):
last_broadcast_time = 0

def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
global last_broadcast_time
current_time = time.time()
if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
last_broadcast_time = current_time
asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))

global last_broadcast_time
current_time = time.time()
if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
last_broadcast_time = current_time
asyncio.create_task(node.broadcast_opaque_status("", json.dumps({
"type": "download_progress",
"node_id": node.id,
"progress": event.to_dict()
})))

shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
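
The throttle bounds broadcast traffic to at most one progress message per 100 ms while always forwarding "complete" events. The same pattern as a reusable helper, a sketch rather than part of this commit:

import time

def make_throttled(callback, interval: float = 0.1):
  # Wrap `callback` so it fires at most once per `interval` seconds,
  # except for terminal events, which always go through.
  last_time = 0.0
  def throttled(event, is_terminal: bool = False):
    nonlocal last_time
    now = time.time()
    if is_terminal or now - last_time >= interval:
      last_time = now
      callback(event)
  return throttled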

2 changes: 1 addition & 1 deletion exo/models.py
@@ -18,7 +18,7 @@
"TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
},
"llama-3.1-70b-bf16": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-bf16", start_layer=0, end_layer=0, n_layers=80),
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED", start_layer=0, end_layer=0, n_layers=80),
"TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
},
"llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),},
12 changes: 6 additions & 6 deletions exo/networking/tailscale/tailscale_discovery.py
@@ -2,12 +2,11 @@
import time
import traceback
from typing import List, Dict, Callable, Tuple
from tailscale import Tailscale, Device
from exo.networking.discovery import Discovery
from exo.networking.peer_handle import PeerHandle
from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
from exo.helpers import DEBUG, DEBUG_DISCOVERY
from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, update_device_attributes
from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, update_device_attributes, get_tailscale_devices, Device

class TailscaleDiscovery(Discovery):
def __init__(
Expand All @@ -32,7 +31,8 @@ def __init__(
self.known_peers: Dict[str, Tuple[PeerHandle, float, float]] = {}
self.discovery_task = None
self.cleanup_task = None
self.tailscale = Tailscale(api_key=tailscale_api_key, tailnet=tailnet)
self.tailscale_api_key = tailscale_api_key
self.tailnet = tailnet
self._device_id = None
self.update_task = None

@@ -61,12 +61,12 @@ async def get_device_id(self):
return self._device_id

async def update_device_posture_attributes(self):
await update_device_attributes(await self.get_device_id(), self.tailscale.api_key, self.node_id, self.node_port, self.device_capabilities)
await update_device_attributes(await self.get_device_id(), self.tailscale_api_key, self.node_id, self.node_port, self.device_capabilities)

async def task_discover_peers(self):
while True:
try:
devices: dict[str, Device] = await self.tailscale.devices()
devices: dict[str, Device] = await get_tailscale_devices(self.tailscale_api_key, self.tailnet)
current_time = time.time()

active_devices = {
@@ -81,7 +81,7 @@
for device in active_devices.values():
if device.name == self.node_id: continue
peer_host = device.addresses[0]
peer_id, peer_port, device_capabilities = await get_device_attributes(device.device_id, self.tailscale.api_key)
peer_id, peer_port, device_capabilities = await get_device_attributes(device.device_id, self.tailscale_api_key)
if not peer_id:
if DEBUG_DISCOVERY >= 4: print(f"{device.device_id} does not have exo node attributes. skipping.")
continue
42 changes: 41 additions & 1 deletion exo/networking/tailscale/tailscale_helpers.py
@@ -2,9 +2,32 @@
import asyncio
import aiohttp
import re
from typing import Dict, Any, Tuple
from typing import Dict, Any, Tuple, List, Optional
from exo.helpers import DEBUG_DISCOVERY
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
from datetime import datetime, timezone

class Device:
def __init__(self, device_id: str, name: str, addresses: List[str], last_seen: Optional[datetime] = None):
self.device_id = device_id
self.name = name
self.addresses = addresses
self.last_seen = last_seen

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'Device':
return cls(
device_id=data.get('id', ''),
name=data.get('name', ''),
addresses=data.get('addresses', []),
last_seen=cls.parse_datetime(data.get('lastSeen'))
)

@staticmethod
def parse_datetime(date_string: Optional[str]) -> Optional[datetime]:
if not date_string:
return None
return datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc)

async def get_device_id() -> str:
try:
@@ -94,3 +117,20 @@ def sanitize_attribute(value: str) -> str:
sanitized_value = re.sub(r'[^a-zA-Z0-9_.]', '_', value)
# Truncate to 50 characters
return sanitized_value[:50]

async def get_tailscale_devices(api_key: str, tailnet: str) -> Dict[str, Device]:
async with aiohttp.ClientSession() as session:
url = f"https://api.tailscale.com/api/v2/tailnet/{tailnet}/devices"
headers = {"Authorization": f"Bearer {api_key}"}

async with session.get(url, headers=headers) as response:
response.raise_for_status()
data = await response.json()

devices = {}
for device_data in data.get("devices", []):
print("Device data: ", device_data)
device = Device.from_dict(device_data)
devices[device.name] = device

return devices
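
A usage sketch of the new helper; the API key and tailnet below are placeholders, and the printed fields follow the Device class defined above:

import asyncio
from exo.networking.tailscale.tailscale_helpers import get_tailscale_devices

async def main():
  # Placeholder credentials; supply your own API key and tailnet name.
  devices = await get_tailscale_devices(api_key="tskey-api-...", tailnet="example.com")
  for name, device in devices.items():
    print(f"{name}: {device.addresses} (last seen {device.last_seen})")

asyncio.run(main())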
35 changes: 21 additions & 14 deletions exo/networking/udp/udp_discovery.py
@@ -53,7 +53,7 @@ def __init__(
self.broadcast_interval = broadcast_interval
self.discovery_timeout = discovery_timeout
self.device_capabilities = device_capabilities
self.known_peers: Dict[str, Tuple[PeerHandle, float, float]] = {}
self.known_peers: Dict[str, Tuple[PeerHandle, float, float, int]] = {}
self.broadcast_task = None
self.listen_task = None
self.cleanup_task = None
@@ -76,24 +76,25 @@ async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
while len(self.known_peers) < wait_for_peers:
if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
await asyncio.sleep(0.1)
return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
return [peer_handle for peer_handle, _, _, _ in self.known_peers.values()]

async def task_broadcast_presence(self):
message = json.dumps({
"type": "discovery",
"node_id": self.node_id,
"grpc_port": self.node_port,
"device_capabilities": self.device_capabilities.to_dict(),
})

if DEBUG_DISCOVERY >= 2:
print("Starting task_broadcast_presence...")
print(f"\nBroadcast message: {message}")

while True:
      # Explicitly broadcasting on all assigned IPs, since broadcasting on `0.0.0.0` on macOS does not broadcast over
      # the Thunderbolt bridge when other connection modalities such as WiFi or Ethernet exist
for addr in get_all_ip_addresses():
message = json.dumps({
"type": "discovery",
"node_id": self.node_id,
"grpc_port": self.node_port,
"device_capabilities": self.device_capabilities.to_dict(),
"priority": 1, # For now, every interface has the same priority. We can make this better by prioriting interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
})
if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr}): {message}")

transport = None
try:
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
@@ -138,21 +139,27 @@ async def on_listen_message(self, data, addr):
peer_id = message["node_id"]
peer_host = addr[0]
peer_port = message["grpc_port"]
peer_prio = message["priority"]
device_capabilities = DeviceCapabilities(**message["device_capabilities"])

if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
if peer_id in self.known_peers:
existing_peer_prio = self.known_peers[peer_id][3]
if existing_peer_prio >= peer_prio:
if DEBUG >= 1: print(f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}")
return
new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
if not await new_peer_handle.health_check():
if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.")
return
if DEBUG >= 1: print(f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
self.known_peers[peer_id] = (new_peer_handle, time.time(), time.time())
self.known_peers[peer_id] = (new_peer_handle, time.time(), time.time(), peer_prio)
else:
if not await self.known_peers[peer_id][0].health_check():
if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Removing.")
if peer_id in self.known_peers: del self.known_peers[peer_id]
return
self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())
self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time(), peer_prio)

async def task_listen_for_peers(self):
await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message),
@@ -164,13 +171,13 @@ async def task_cleanup_peers(self):
try:
current_time = time.time()
peers_to_remove = []
for peer_id, (peer_handle, connected_at, last_seen) in self.known_peers.items():
for peer_id, (peer_handle, connected_at, last_seen, prio) in self.known_peers.items():
if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or \
(current_time - last_seen > self.discovery_timeout) or \
(not await peer_handle.health_check()):
peers_to_remove.append(peer_id)

if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, {connected_at=}, {last_seen=}, {prio=}" for peer_handle, connected_at, last_seen, prio in self.known_peers.values()})

for peer_id in peers_to_remove:
if peer_id in self.known_peers: del self.known_peers[peer_id]
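
With the priority threaded through the known_peers tuples, the replacement rule in on_listen_message reduces to the following, a sketch assuming the (handle, connected_at, last_seen, prio) tuple order used above:

def should_replace(known_peers: dict, peer_id: str, new_addr: str, new_prio: int) -> bool:
  # A new announcement wins if the peer is unknown; for a known peer at a
  # different address it wins only with a strictly higher priority.
  if peer_id not in known_peers:
    return True
  handle, _connected_at, _last_seen, existing_prio = known_peers[peer_id]
  if handle.addr() == new_addr:
    return False  # same address: the entry is just refreshed, not replaced
  return new_prio > existing_prio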
4 changes: 3 additions & 1 deletion exo/tinychat/index.css
@@ -164,7 +164,9 @@ main {
border-right: 2px solid var(--secondary-color);
box-shadow: 10px 10px 20px 2px var(--secondary-color-transparent);
}

.download-progress {
margin-bottom: 20em;
}
.message > pre {
white-space: pre-wrap;
}
15 changes: 15 additions & 0 deletions exo/tinychat/index.html
@@ -149,6 +149,21 @@ <h3 x-text="new Date(_state.time).toLocaleString()"></h3>
$el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
" x-ref="messages" x-show="home === 2" x-transition="">
</div>

<!-- Download Progress Section -->
<template x-if="downloadProgress">
<div class="download-progress message message-role-assistant">
<h2>Download Progress</h2>
<div class="download-progress-node">
<p><strong>Model:</strong> <span x-text="downloadProgress.repo_id + '@' + downloadProgress.repo_revision"></span></p>
<p><strong>Progress:</strong> <span x-text="`${downloadProgress.downloaded_bytes_display} / ${downloadProgress.total_bytes_display} (${downloadProgress.percentage}%)`"></span></p>
<p><strong>Speed:</strong> <span x-text="downloadProgress.overall_speed_display || 'N/A'"></span></p>
<p><strong>ETA:</strong> <span x-text="downloadProgress.overall_eta_display || 'N/A'"></span></p>
</div>
</div>
</template>


<div class="input-container">
<div class="input-performance">
<span class="input-performance-point">