Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dn/exo logging #401

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -169,5 +169,8 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# vscode
.vscode/

**/*.xcodeproj/*
.aider*
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,29 @@ exo supports the following inference engines:
- ✅ [GRPC](exo/networking/grpc)
- 🚧 [Radio](TODO)
- 🚧 [Bluetooth](TODO)

## CLI Options

- `<command>` (run): Command to run
- `<model_name>`: Model name to run
- `--node-id` (str): Node ID
- `--node-host` (str) (default: '0.0.0.0'): Node host
- `--node-port` (int): Node port
- `--listen-port` (int) (default: 5678): Listening port for discovery
- `--download-quick-check` (flag) (default: False): Quick check local path for model shards download
- `--max-parallel-downloads` (int) (default: 4): Max parallel downloads for model shards download
- `--prometheus-client-port` (int): Prometheus client port
- `--broadcast-port` (int) (default: 5678): Broadcast port for discovery
- `--discovery-module` (str) (default: 'udp'): Discovery module to use
- `--discovery-timeout` (int) (default: 30): Discovery timeout in seconds
- `--discovery-config-path` (str): Path to discovery config json file
- `--wait-for-peers` (int) (default: 0): Number of peers to wait to connect to before starting
- `--chatgpt-api-port` (int) (default: 8000): ChatGPT API port
- `--chatgpt-api-response-timeout` (int) (default: 90): ChatGPT API response timeout in seconds
- `--max-generate-tokens` (int) (default: 10000): Max tokens to generate in each request
- `--inference-engine` (str): Inference engine to use (mlx, tinygrad, or dummy)
- `--disable-tui`, `--no-disable-tui` (flag) (default: False): Disable (or re-enable) the terminal UI
- `--run-model` (str): Specify a model to run directly
- `--prompt` (str) (default: 'Who are you?'): Prompt for the model when using --run-model
- `--tailscale-api-key` (str): Tailscale API key
- `--tailnet-name` (str): Tailnet name
5 changes: 5 additions & 0 deletions exo/api/chatgpt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from exo.models import model_base_shards
from typing import Callable

from exo.telemetry.constants import TelemetryAction
from exo.telemetry.logger import Logger


class Message:
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
Expand Down Expand Up @@ -224,6 +227,8 @@ async def handle_get_download_progress(self, request):
return web.json_response(progress_data)

async def handle_post_chat_completions(self, request):
logger = Logger()
logger.write_log(TelemetryAction.REQUEST_RECEIVED)
data = await request.json()
if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
stream = data.get("stream", False)
Expand Down
108 changes: 108 additions & 0 deletions exo/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import json
from pathlib import Path
from platformdirs import user_data_dir
import uuid

from exo.helpers import get_git_hash

class PersistentConfig:
  """
  Persistent configuration that is saved between sessions.

  Note that this syncs with the SessionConfig instance and is a
  subset of the data stored there, so modules should prefer to use
  SessionConfig for reading any data, but for writing data that they
  need to persist beyond the current session, they should use this
  class.
  """
  CONFIG_FILE_NAME = "config.json"
  _instance = None
  _initialized = False

  def __new__(cls):
    # Classic singleton: every construction returns the same instance.
    if cls._instance is None:
      cls._instance = super().__new__(cls)
    return cls._instance

  def __init__(self):
    if not self._initialized:
      self._config_file = self.initialize()
      # Mark initialized *before* calling set(): set() re-enters
      # SessionConfig(), whose init constructs PersistentConfig() again.
      self._initialized = True

      # Seed default identifiers, keeping any previously stored values.
      self.set("device_id", str(uuid.uuid4()), replace_if_exists=False)
      self.set("node_id", str(uuid.uuid4()), replace_if_exists=False)

  def initialize(self):
    """Create the app-data directory and config file if missing.

    Returns the Path of the JSON config file.
    """
    app_data = Path(user_data_dir("exo", appauthor="exo_labs"))
    print(f"Using app data directory: {app_data}")
    app_data.mkdir(parents=True, exist_ok=True)
    config_file = app_data / self.CONFIG_FILE_NAME

    if not config_file.exists():
      with config_file.open('w') as f:
        json.dump({}, f)

    return config_file

  def set(self, key: str, value, replace_if_exists: bool = True):
    """Store key=value in the config file and mirror it to SessionConfig.

    With replace_if_exists=False an existing stored value wins. The
    session is then synced with the value actually *persisted*, not the
    proposed one — the previous implementation always mirrored the
    proposed value, which could leave the session holding a freshly
    generated id (e.g. device_id) different from the one on disk.
    """
    print(f"Setting {key}={value} in config file")

    with self._config_file.open('r') as f:
      config = json.load(f)

    # Update config and write back to file
    if replace_if_exists or key not in config:
      config[key] = value
      with self._config_file.open('w') as f:
        json.dump(config, f, indent=4)

    # Sync the persisted value (which may be the pre-existing one)
    # to the session config.
    SessionConfig().set(key, config[key])

  def get(self, key: str):
    """Return the stored value for key, or None if absent."""
    with self._config_file.open('r') as f:
      return json.load(f).get(key)

class SessionConfig:
  """
  Temporary configuration scoped to the current session.

  Note that this syncs with the PersistentConfig instance and is a
  superset of the data there.
  """
  _instance = None
  _initialized = False

  def __new__(cls):
    # Singleton: reuse the existing instance when there is one.
    if cls._instance is not None:
      return cls._instance
    cls._instance = super().__new__(cls)
    return cls._instance

  def __init__(self):
    if self._initialized:
      return
    # Pull in the persisted values first, then layer per-session ids on top.
    self._session_data = {}
    self.sync_with_persistent()
    self.set("session_id", str(uuid.uuid4()))
    self.set("commit_id", get_git_hash())
    self._initialized = True

  def sync_with_persistent(self):
    """Sync session data with persistent config"""
    config_path = PersistentConfig()._config_file
    with config_path.open('r') as f:
      stored = json.load(f)
    self._session_data.update(stored)

  def set(self, key: str, value: any, replace_if_exists: bool = True):
    # Only skip the write when the key exists and overwriting is disabled.
    already_present = key in self._session_data
    if replace_if_exists or not already_present:
      self._session_data[key] = value

  def get(self, key: str):
    return self._session_data.get(key)

# Expose singleton instances.
# NOTE(review): these run at import time, so merely importing exo.config
# creates the config file on disk as a side effect (SessionConfig() also
# constructs PersistentConfig during its sync) — confirm this is intended.
session_config = SessionConfig()
persistent_config = PersistentConfig()
35 changes: 9 additions & 26 deletions exo/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import uuid
import netifaces
from pathlib import Path
import subprocess
import tempfile

DEBUG = int(os.getenv("DEBUG", default="0"))
Expand Down Expand Up @@ -169,32 +170,6 @@ def is_valid_uuid(val):
return False


def get_or_create_node_id():
  """Return a stable node id, creating and caching one if needed.

  Removed by this PR in favor of PersistentConfig's "node_id" key.
  Reads a UUID from a file in the system temp dir; if the file is
  missing or holds an invalid UUID, generates a new UUID, stores it,
  and returns it. On any I/O failure, falls back to a fresh,
  unpersisted UUID so callers always get an id.
  """
  NODE_ID_FILE = Path(tempfile.gettempdir())/".exo_node_id"
  try:
    if NODE_ID_FILE.is_file():
      with open(NODE_ID_FILE, "r") as f:
        stored_id = f.read().strip()
      # Only trust the cached value if it parses as a UUID.
      if is_valid_uuid(stored_id):
        if DEBUG >= 2: print(f"Retrieved existing node ID: {stored_id}")
        return stored_id
      else:
        if DEBUG >= 2: print("Stored ID is not a valid UUID. Generating a new one.")

    new_id = str(uuid.uuid4())
    with open(NODE_ID_FILE, "w") as f:
      f.write(new_id)

    if DEBUG >= 2: print(f"Generated and stored new node ID: {new_id}")
    return new_id
  except IOError as e:
    # Temp dir unwritable/unreadable: fall back to an ephemeral id.
    if DEBUG >= 2: print(f"IO error creating node_id: {e}")
    return str(uuid.uuid4())
  except Exception as e:
    # Best-effort: never let id creation crash startup.
    if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}")
    return str(uuid.uuid4())


def pretty_print_bytes(size_in_bytes: int) -> str:
if size_in_bytes < 1024:
return f"{size_in_bytes} B"
Expand Down Expand Up @@ -234,3 +209,11 @@ def get_all_ip_addresses():
except:
if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
return ["localhost"]


def get_git_hash():
  """Return the HEAD commit hash of the repository at the current cwd.

  Returns None when the hash cannot be determined — e.g. the process is
  not inside a git repository, or git is not installed at all. The
  original only caught CalledProcessError, so a missing `git` binary
  raised FileNotFoundError and crashed the caller.

  NOTE(review): the result depends on the process's working directory,
  not on where the exo package itself is checked out.
  """
  try:
    git_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("utf-8")
    return git_hash
  except (subprocess.CalledProcessError, OSError):
    # CalledProcessError: not a git repository (or other git failure).
    # OSError (incl. FileNotFoundError): git executable not available.
    return None
52 changes: 37 additions & 15 deletions exo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,20 @@
from exo.networking.udp.udp_discovery import UDPDiscovery
from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
from exo.telemetry.constants import TelemetryAction
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
from exo.api import ChatGPTAPI
from exo.config import session_config
from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
from exo.download.hf.hf_shard_download import HFShardDownloader
from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_all_ip_addresses, terminal_link
from exo.inference.shard import Shard
from exo.inference.inference_engine import get_inference_engine, InferenceEngine
from exo.inference.dummy_inference_engine import DummyInferenceEngine
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration.node import Node
from exo.models import model_base_shards
from exo.telemetry.logger import Logger
from exo.viz.topology_viz import TopologyViz

# parse args
Expand Down Expand Up @@ -51,6 +54,7 @@
parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
parser.add_argument("--logging-url", type=str, default="https://exo-logs.foobar.dev/api/v1/logs/bulk", help="Logging URL")
args = parser.parse_args()
print(f"Selected inference engine: {args.inference_engine}")

Expand All @@ -71,7 +75,9 @@
args.node_port = find_available_port(args.node_host)
if DEBUG >= 1: print(f"Using available port: {args.node_port}")

args.node_id = args.node_id or get_or_create_node_id()
if args.node_id:
session_config.set("node_id", args.node_id)

chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
web_chat_urls = [f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()]
if DEBUG >= 0:
Expand All @@ -84,7 +90,7 @@

if args.discovery_module == "udp":
discovery = UDPDiscovery(
args.node_id,
session_config.get("node_id"),
args.node_port,
args.listen_port,
args.broadcast_port,
Expand All @@ -93,7 +99,7 @@
)
elif args.discovery_module == "tailscale":
discovery = TailscaleDiscovery(
args.node_id,
session_config.get("node_id"),
args.node_port,
lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
discovery_timeout=args.discovery_timeout,
Expand All @@ -103,10 +109,10 @@
elif args.discovery_module == "manual":
if not args.discovery_config_path:
raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.")
discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
discovery = ManualDiscovery(args.discovery_config_path, session_config.get("node_id"), create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
node = StandardNode(
args.node_id,
session_config.get("node_id"),
None,
inference_engine,
discovery,
Expand All @@ -127,6 +133,9 @@
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
)

# Initialize the logger singleton
logger = Logger(args.logging_url, session_config.get("node_id"))
logger.write_log(TelemetryAction.START)

def preemptively_start_download(request_id: str, opaque_status: str):
try:
Expand Down Expand Up @@ -198,6 +207,7 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
except Exception as e:
print(f"Error processing prompt: {str(e)}")
traceback.print_exc()
raise e
finally:
node.on_token.deregister(callback_id)

Expand Down Expand Up @@ -226,15 +236,27 @@ def handle_exit():


def run():
# NOTE(review): this span is rendered diff residue — it contains BOTH the
# implementation removed by this PR and its replacement, concatenated.
# --- removed implementation (pre-PR): plain run loop, no error reporting ---
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())
except KeyboardInterrupt:
print("Received keyboard interrupt. Shutting down...")
finally:
loop.run_until_complete(shutdown(signal.SIGTERM, loop))
loop.close()
# --- replacement (this PR): tears down the TUI on error, then offers
# opt-in telemetry error reporting via the Logger singleton ---
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())
except KeyboardInterrupt:
print("Received keyboard interrupt. Shutting down...")
except Exception as e:
# First tear down the TUI if it exists
# (otherwise the traceback/prompt below would be swallowed by the TUI).
if topology_viz and not args.disable_tui:
topology_viz.stop()

print(f"\n\nError in run: {str(e)}\n")
print("--------------------------------")
report = input("\nThere was an error that got triggered in the code, would you like to report it? (y/N): ").strip()
if report.lower() == 'y' or report.lower() == 'yes':
logger.report_error(e)
finally:
print("Shutting down")
loop.run_until_complete(shutdown(signal.SIGTERM, loop))
loop.close()
print("Exiting")


if __name__ == "__main__":
Expand Down
5 changes: 5 additions & 0 deletions exo/orchestration/standard_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from exo.topology.device_capabilities import device_capabilities
from exo.topology.partitioning_strategy import Partition, PartitioningStrategy, map_partitions_to_shards
from exo import DEBUG
from exo.config import session_config
from exo.helpers import AsyncCallbackSystem
from exo.viz.topology_viz import TopologyViz
from exo.download.hf.hf_helpers import RepoProgressEvent
Expand Down Expand Up @@ -415,6 +416,10 @@ async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4)
self.topology = next_topology
if self.topology_viz:
self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id)

serialized_topology = next_topology.serialize()
session_config.set("topology", serialized_topology)

return next_topology

@property
Expand Down
Loading