
Commit f95942f

Merge branch 'main' of github.com:risingsunomi/exo-nvidia

2 parents: cee3e31 + c861f30


46 files changed (+95 −57 lines)

.circleci/config.yml (+4 −4)

@@ -17,11 +17,11 @@ commands:
             source env/bin/activate

             # Start first instance
-            HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --inference-engine <<parameters.inference_engine>> --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 2>&1 | tee output1.log &
+            HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 2>&1 | tee output1.log &
             PID1=$!

             # Start second instance
-            HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --inference-engine <<parameters.inference_engine>> --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout 900 2>&1 | tee output2.log &
+            HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout 900 2>&1 | tee output2.log &
             PID2=$!

             # Wait for discovery

@@ -138,9 +138,9 @@ jobs:
           name: Run discovery integration test
           command: |
             source env/bin/activate
-            DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 > output1.log 2>&1 &
+            DEBUG_DISCOVERY=7 DEBUG=7 exo --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 > output1.log 2>&1 &
             PID1=$!
-            DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 > output2.log 2>&1 &
+            DEBUG_DISCOVERY=7 DEBUG=7 exo --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 > output2.log 2>&1 &
             PID2=$!
             sleep 10
             kill $PID1 $PID2
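
The CI commands above switch from `python3 main.py` to the `exo` console script, matching the main.py → exo/main.py move later in this commit. The packaging change itself is not part of this diff; below is a minimal sketch of the setup.py wiring that would make the `exo` command work, assuming it maps to the new exo.main:run entry point.

# Hypothetical setup.py excerpt (not shown in this diff): maps the `exo`
# command used in CI to the run() function added to exo/main.py below.
from setuptools import setup, find_packages

setup(
  name="exo",
  packages=find_packages(),
  entry_points={"console_scripts": ["exo = exo.main:run"]},
)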

README.md (+47 −12) (diff collapsed)

exo/api/chatgpt_api.py (+2 −14)

@@ -124,19 +124,7 @@ def remap_messages(messages: List[Message]) -> List[Message]:

 def build_prompt(tokenizer, _messages: List[Message]):
   messages = remap_messages(_messages)
-  if DEBUG >= 3:
-    print(f"messages: {messages}")
-  prompt = tokenizer.apply_chat_template(
-    messages,
-    tokenize=False,
-    add_generation_prompt=True
-  )
-
-  if DEBUG >= 3:
-    print(f"prompt: {str(prompt)}")
-    for msg in messages:
-      print(f"chat role: {msg.role}\ncontent: {msg.content}")
-
+  prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
   image_str = None
   for message in messages:
     if not isinstance(message.content, list):

@@ -197,7 +185,7 @@ def __init__(self, node: Node, inference_engine_classname: str, response_timeout
     cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})

-    self.static_dir = Path(__file__).parent.parent.parent/"tinychat/examples/tinychat"
+    self.static_dir = Path(__file__).parent.parent/"tinychat"
     self.app.router.add_get("/", self.handle_root)
     self.app.router.add_static("/", self.static_dir, name="static")
exo/download/hf/hf_helpers.py (+6 −3)

@@ -17,6 +17,7 @@

 T = TypeVar("T")

+
 async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
   refs_dir = get_repo_root(repo_id)/"refs"
   refs_file = refs_dir/revision

@@ -69,6 +70,8 @@ def _add_wildcard_to_directories(pattern: str) -> str:
     return pattern + "*"
   return pattern

+def get_hf_endpoint() -> str:
+  return os.environ.get('HF_ENDPOINT', "https://huggingface.co")

 def get_hf_home() -> Path:
   """Get the Hugging Face home directory."""

@@ -99,7 +102,7 @@ def get_repo_root(repo_id: str) -> Path:


 async def fetch_file_list(session, repo_id, revision, path=""):
-  api_url = f"https://huggingface.co/api/models/{repo_id}/tree/{revision}"
+  api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
   url = f"{api_url}/{path}" if path else api_url

   headers = await get_auth_headers()

@@ -124,7 +127,7 @@ async def fetch_file_list(session, repo_id, revision, path=""):
 async def download_file(
   session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True
 ):
-  base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
+  base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
   url = urljoin(base_url, file_path)
   local_path = os.path.join(save_directory, file_path)

@@ -214,7 +217,7 @@ async def resolve_revision_to_commit_hash(repo_id: str, revision: str) -> str:

   # Fetch the commit hash for the given revision
   async with aiohttp.ClientSession() as session:
-    api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/revision/{revision}"
+    api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/revision/{revision}"
     headers = await get_auth_headers()
     async with session.get(api_url, headers=headers) as response:
       if response.status != 200:
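
With every hard-coded huggingface.co URL routed through get_hf_endpoint(), downloads can be pointed at a mirror by setting HF_ENDPOINT before launch. A small sketch; the mirror URL is purely illustrative.

# Illustrative only: any HTTPS host that mirrors the Hugging Face API works.
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.example.com"  # hypothetical mirror

from exo.download.hf.hf_helpers import get_hf_endpoint
assert get_hf_endpoint() == "https://hf-mirror.example.com"
# fetch_file_list, download_file and resolve_revision_to_commit_hash now build
# their URLs from this endpoint instead of https://huggingface.co.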

exo/helpers.py (+1 −1)

@@ -169,7 +169,7 @@ def is_valid_uuid(val):


 def get_or_create_node_id():
-  NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__)))/".exo_node_id"
+  NODE_ID_FILE = Path(tempfile.gettempdir()) / ".exo_node_id"
   try:
     if NODE_ID_FILE.is_file():
       with open(NODE_ID_FILE, "r") as f:
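
Moving the node-ID file from the package directory to the system temp dir presumably avoids writing into site-packages, which may not be writable once exo is installed as a console script. A sketch of the resulting persistence logic; the uuid4 fallback is an assumption, since only the path change appears in this hunk.

# Sketch of the relocated node-ID persistence; the uuid4 fallback is assumed,
# as only the path change is shown in the diff above.
import tempfile, uuid
from pathlib import Path

NODE_ID_FILE = Path(tempfile.gettempdir()) / ".exo_node_id"
if NODE_ID_FILE.is_file():
  node_id = NODE_ID_FILE.read_text().strip()
else:
  node_id = str(uuid.uuid4())
  NODE_ID_FILE.write_text(node_id)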

exo/inference/tinygrad/__init__.py

Whitespace-only changes.

exo/inference/tinygrad/inference.py (+2 −5)

@@ -4,18 +4,15 @@
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
-from tinygrad.nn.state import safe_load, torch_load, load_state_dict
-from tinygrad import Tensor, dtypes, nn, Context
-from transformers import AutoTokenizer
+from tinygrad.nn.state import load_state_dict
+from tinygrad import Tensor, nn, Context
 from exo.inference.inference_engine import InferenceEngine
 from typing import Optional, Tuple
 import numpy as np
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.download.shard_download import ShardDownloader
 from concurrent.futures import ThreadPoolExecutor
 import asyncio
-import threading
-from functools import partial

 Tensor.no_grad = True
 # default settings

exo/inference/tinygrad/models/__init__.py

Whitespace-only changes.

main.py renamed to exo/main.py (+15 −5)

@@ -5,10 +5,11 @@
 import time
 import traceback
 import uuid
+import sys
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
-from exo.networking.udp_discovery import UDPDiscovery
-from exo.networking.tailscale_discovery import TailscaleDiscovery
+from exo.networking.udp.udp_discovery import UDPDiscovery
+from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
 from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI

@@ -24,6 +25,8 @@

 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
+parser.add_argument("command", nargs="?", choices=["run"], help="Command to run")
+parser.add_argument("model_name", nargs="?", help="Model name to run")
 parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")

@@ -180,14 +183,18 @@ def handle_exit():

   await node.start(wait_for_peers=args.wait_for_peers)

-  if args.run_model:
-    await run_model_cli(node, inference_engine, args.run_model, args.prompt)
+  if args.command == "run" or args.run_model:
+    model_name = args.model_name or args.run_model
+    if not model_name:
+      print("Error: Model name is required when using 'run' command or --run-model")
+      return
+    await run_model_cli(node, inference_engine, model_name, args.prompt)
   else:
     asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
     await asyncio.Event().wait()


-if __name__ == "__main__":
+def run():
   loop = asyncio.new_event_loop()
   asyncio.set_event_loop(loop)
   try:

@@ -197,3 +204,6 @@ def handle_exit():
   finally:
     loop.run_until_complete(shutdown(signal.SIGTERM, loop))
   loop.close()
+
+if __name__ == "__main__":
+  run()
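
The two new positional arguments make `exo run <model>` equivalent to the pre-existing --run-model flag. A minimal sketch of how the parser resolves the two spellings; the --run-model declaration below is an assumption, since its definition sits outside the hunks shown here.

# Sketch of the new CLI resolution (--run-model definition assumed).
import argparse

parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
parser.add_argument("command", nargs="?", choices=["run"], help="Command to run")
parser.add_argument("model_name", nargs="?", help="Model name to run")
parser.add_argument("--run-model", type=str, default=None)

args = parser.parse_args(["run", "qwen-2.5-coder-1.5b"])
model_name = args.model_name or args.run_model  # -> "qwen-2.5-coder-1.5b"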

exo/models.py (+6 −0)

@@ -53,6 +53,12 @@
   },

   ### qwen
+  "qwen-2.5-coder-1.5b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
+  },
+  "qwen-2.5-coder-7b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
+  },
   "qwen-2.5-7b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
   },
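
The new shard entries become selectable by model name, both in the tinychat dropdown below and through the ChatGPT-compatible API. A sketch of a request against a local node, assuming it was started with --chatgpt-api-port 8000 as in the CI config above.

# Sketch: querying a newly added model through the ChatGPT-compatible API.
# Assumes a node is already running with --chatgpt-api-port 8000.
import json
import urllib.request

req = urllib.request.Request(
  "http://localhost:8000/v1/chat/completions",
  data=json.dumps({
    "model": "qwen-2.5-coder-1.5b",
    "messages": [{"role": "user", "content": "Write a binary search in Python."}],
  }).encode(),
  headers={"Content-Type": "application/json"},
)
print(urllib.request.urlopen(req).read().decode())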

exo/networking/tailscale/__init__.py

Whitespace-only changes.

exo/networking/tailscale_discovery.py renamed to exo/networking/tailscale/tailscale_discovery.py (+2 −2)

@@ -3,8 +3,8 @@
 import traceback
 from typing import List, Dict, Callable, Tuple
 from tailscale import Tailscale, Device
-from .discovery import Discovery
-from .peer_handle import PeerHandle
+from exo.networking.discovery import Discovery
+from exo.networking.peer_handle import PeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo.helpers import DEBUG, DEBUG_DISCOVERY
 from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, update_device_attributes

exo/networking/test_tailscale_discovery.py renamed to exo/networking/tailscale/test_tailscale_discovery.py (+1 −1)

@@ -2,7 +2,7 @@
 import asyncio
 import unittest
 from unittest import mock
-from exo.networking.tailscale_discovery import TailscaleDiscovery
+from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
 from exo.networking.peer_handle import PeerHandle

 class TestTailscaleDiscovery(unittest.IsolatedAsyncioTestCase):

exo/networking/udp/__init__.py

Whitespace-only changes.

exo/networking/test_udp_discovery.py renamed to exo/networking/udp/test_udp_discovery.py (+1 −1)

@@ -1,7 +1,7 @@
 import asyncio
 import unittest
 from unittest import mock
-from exo.networking.udp_discovery import UDPDiscovery
+from exo.networking.udp.udp_discovery import UDPDiscovery
 from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.orchestration.node import Node

exo/networking/udp_discovery.py renamed to exo/networking/udp/udp_discovery.py (+2 −2)

@@ -4,8 +4,8 @@
 import time
 import traceback
 from typing import List, Dict, Callable, Tuple, Coroutine
-from .discovery import Discovery
-from .peer_handle import PeerHandle
+from exo.networking.discovery import Discovery
+from exo.networking.peer_handle import PeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo.helpers import DEBUG, DEBUG_DISCOVERY, get_all_ip_addresses

File renamed without changes.
File renamed without changes.
File renamed without changes.

tinychat/examples/tinychat/index.html renamed to exo/tinychat/index.html (+2 −0)

@@ -43,6 +43,8 @@
 <option value="deepseek-coder-v2-lite">Deepseek Coder V2 Lite</option>
 <option value="deepseek-coder-v2.5">Deepseek Coder V2.5</option>
 <option value="llava-1.5-7b-hf">LLaVa 1.5 7B (Vision Model)</option>
+<option value="qwen-2.5-coder-1.5b">Qwen 2.5 Coder 1.5B</option>
+<option value="qwen-2.5-coder-7b">Qwen 2.5 Coder 7B</option>
 <option value="qwen-2.5-7b">Qwen 2.5 7B</option>
 <option value="qwen-2.5-math-7b">Qwen 2.5 7B (Math)</option>
 <option value="qwen-2.5-14b">Qwen 2.5 14B</option>
File renamed without changes.

install.sh (+1 −1)

@@ -2,4 +2,4 @@

 python3 -m venv .venv
 source .venv/bin/activate
-pip install .
+pip install -e .
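
The switch to `pip install -e .` installs exo in editable mode, so the new `exo` console script resolves imports from the working tree rather than a copied snapshot, keeping the CI entry point in sync with the checkout.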
