diff --git a/.gitignore b/.gitignore index 814af48c7..cceeb76d0 100644 --- a/.gitignore +++ b/.gitignore @@ -169,5 +169,8 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +# vscode +.vscode/ + **/*.xcodeproj/* .aider* diff --git a/README.md b/README.md index 0b12c7c52..374dab028 100644 --- a/README.md +++ b/README.md @@ -260,3 +260,29 @@ exo supports the following inference engines: - ✅ [GRPC](exo/networking/grpc) - 🚧 [Radio](TODO) - 🚧 [Bluetooth](TODO) + +## CLI Options + +- `<command>` (run): Command to run +- `<model_name>`: Model name to run +- `--node-id` (str): Node ID +- `--node-host` (str) (default: '0.0.0.0'): Node host +- `--node-port` (int): Node port +- `--listen-port` (int) (default: 5678): Listening port for discovery +- `--download-quick-check` (flag) (default: False): Quick check local path for model shards download +- `--max-parallel-downloads` (int) (default: 4): Max parallel downloads for model shards download +- `--prometheus-client-port` (int): Prometheus client port +- `--broadcast-port` (int) (default: 5678): Broadcast port for discovery +- `--discovery-module` (str) (default: 'udp'): Discovery module to use +- `--discovery-timeout` (int) (default: 30): Discovery timeout in seconds +- `--discovery-config-path` (str): Path to discovery config json file +- `--wait-for-peers` (int) (default: 0): Number of peers to wait to connect to before starting +- `--chatgpt-api-port` (int) (default: 8000): ChatGPT API port +- `--chatgpt-api-response-timeout` (int) (default: 90): ChatGPT API response timeout in seconds +- `--max-generate-tokens` (int) (default: 10000): Max tokens to generate in each request +- `--inference-engine` (str): Inference engine to use (mlx, tinygrad, or dummy) +- `--disable-tui`, `--no-disable-tui`: Disable TUI +- `--run-model` (str): Specify a model to run directly +- `--prompt` (str) (default: 'Who are you?'): Prompt for the model when using --run-model +- `--tailscale-api-key` 
class PersistentConfig:
    """Process-wide singleton for settings that must survive between sessions.

    Values are stored as a JSON object in ``config.json`` inside the per-user
    application data directory.  Every write is mirrored into
    ``SessionConfig``, which holds a superset of this data, so modules should
    read through ``SessionConfig`` and only write through this class when the
    value has to outlive the current session.
    """

    CONFIG_FILE_NAME = "config.json"
    _instance = None
    _initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        self._config_file = self.initialize()
        # Flip the flag before writing defaults: set() reaches SessionConfig,
        # whose sync calls PersistentConfig() again.
        self._initialized = True

        # Add default values (but don't overwrite if already set).
        self.set("device_id", str(uuid.uuid4()), replace_if_exists=False)
        self.set("node_id", str(uuid.uuid4()), replace_if_exists=False)

    def initialize(self):
        """Create the app data directory and an empty config file if missing.

        Returns:
            Path of the JSON config file.
        """
        app_data = Path(user_data_dir("exo", appauthor="exo_labs"))
        print(f"Using app data directory: {app_data}")
        app_data.mkdir(parents=True, exist_ok=True)
        config_file = app_data / self.CONFIG_FILE_NAME

        if not config_file.exists():
            with config_file.open('w', encoding="utf-8") as f:
                json.dump({}, f)

        return config_file

    def _load(self) -> dict:
        """Return the stored config, or ``{}`` if the file is missing or unusable.

        A truncated or hand-edited file must not take the whole application
        down, so anything that is not a JSON object is treated as empty.
        """
        try:
            with self._config_file.open('r', encoding="utf-8") as f:
                stored = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            return {}
        return stored if isinstance(stored, dict) else {}

    def set(self, key: str, value: object, replace_if_exists: bool = True):
        """Persist ``key`` and mirror the effective value into the session.

        Args:
            key: Name of the setting.
            value: JSON-serialisable value to store.
            replace_if_exists: When False, an already stored value is kept.
        """
        print(f"Setting {key}={value} in config file")

        config = self._load()

        # Update config and write back to file.
        if replace_if_exists or key not in config:
            config[key] = value
            with self._config_file.open('w', encoding="utf-8") as f:
                json.dump(config, f, indent=4)

        # Mirror what is actually stored, not what the caller offered:
        # with replace_if_exists=False the stored value may have won.
        SessionConfig().set(key, config[key])

    def get(self, key: str):
        """Return the persisted value for ``key``, or None if it is not set."""
        return self._load().get(key)


class SessionConfig:
    """Process-wide singleton holding configuration for the current session.

    Starts from a copy of the persistent config and adds session-only values
    (``session_id``, ``commit_id``); it is therefore a superset of the data
    in ``PersistentConfig``.  Nothing set here is written to disk.
    """

    _instance = None
    _initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        self._session_data = {}
        # Mark as initialised before syncing: first-time construction of
        # PersistentConfig calls SessionConfig() again, and that nested call
        # must reuse this state rather than run __init__ a second time.
        self._initialized = True
        self.sync_with_persistent()
        self.set("session_id", str(uuid.uuid4()))
        self.set("commit_id", get_git_hash())

    def sync_with_persistent(self):
        """Copy every persisted value into the session data."""
        self._session_data.update(PersistentConfig()._load())

    def set(self, key: str, value: object, replace_if_exists: bool = True):
        """Store ``value`` under ``key`` for this session only.

        Args:
            key: Name of the setting.
            value: Value to store; not persisted to disk.
            replace_if_exists: When False, an existing value is kept.
        """
        if replace_if_exists or key not in self._session_data:
            self._session_data[key] = value

    def get(self, key: str):
        """Return the session value for ``key``, or None if it is not set."""
        return self._session_data.get(key)


# Expose singleton instances
session_config = SessionConfig()
persistent_config = PersistentConfig()
def get_git_hash():
    """Return the commit hash of the checkout containing this file, or None.

    None is returned whenever the hash cannot be determined: the source tree
    is not a git repository, git exits with an error, or the git executable
    is missing or cannot be run at all.

    Returns:
        The full hash of ``HEAD`` as a string, or None.
    """
    try:
        raw_hash = subprocess.check_output(
            ["git", "rev-parse", "HEAD"],
            # Ask about the source tree, not whatever directory the user
            # happened to start the process from.
            cwd=Path(__file__).resolve().parent,
            # Keep git's "fatal: not a git repository" off the terminal.
            stderr=subprocess.DEVNULL,
        )
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError (git not installed) and
        # PermissionError, which CalledProcessError alone lets escape.
        return None
    return raw_hash.strip().decode("utf-8")
Shard from exo.inference.inference_engine import get_inference_engine, InferenceEngine from exo.inference.dummy_inference_engine import DummyInferenceEngine from exo.inference.tokenizers import resolve_tokenizer from exo.orchestration.node import Node from exo.models import model_base_shards +from exo.telemetry.logger import Logger from exo.viz.topology_viz import TopologyViz # parse args @@ -51,6 +54,7 @@ parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?") parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key") parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name") +parser.add_argument("--logging-url", type=str, default="https://exo-logs.foobar.dev/api/v1/logs/bulk", help="Logging URL") args = parser.parse_args() print(f"Selected inference engine: {args.inference_engine}") @@ -71,7 +75,9 @@ args.node_port = find_available_port(args.node_host) if DEBUG >= 1: print(f"Using available port: {args.node_port}") -args.node_id = args.node_id or get_or_create_node_id() +if args.node_id: + session_config.set("node_id", args.node_id) + chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()] web_chat_urls = [f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()] if DEBUG >= 0: @@ -84,7 +90,7 @@ if args.discovery_module == "udp": discovery = UDPDiscovery( - args.node_id, + session_config.get("node_id"), args.node_port, args.listen_port, args.broadcast_port, @@ -93,7 +99,7 @@ ) elif args.discovery_module == "tailscale": discovery = TailscaleDiscovery( - args.node_id, + session_config.get("node_id"), args.node_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout, @@ -103,10 +109,10 @@ elif args.discovery_module == "manual": if not args.discovery_config_path: raise 
ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.") - discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities)) + discovery = ManualDiscovery(args.discovery_config_path, session_config.get("node_id"), create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities)) topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None node = StandardNode( - args.node_id, + session_config.get("node_id"), None, inference_engine, discovery, @@ -127,6 +133,9 @@ lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None ) +# Initialize the logger singleton +logger = Logger(args.logging_url, session_config.get("node_id")) +logger.write_log(TelemetryAction.START) def preemptively_start_download(request_id: str, opaque_status: str): try: @@ -198,6 +207,7 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam except Exception as e: print(f"Error processing prompt: {str(e)}") traceback.print_exc() + raise e finally: node.on_token.deregister(callback_id) @@ -226,15 +236,27 @@ def handle_exit(): def run(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(main()) - except KeyboardInterrupt: - print("Received keyboard interrupt. Shutting down...") - finally: - loop.run_until_complete(shutdown(signal.SIGTERM, loop)) - loop.close() + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(main()) + except KeyboardInterrupt: + print("Received keyboard interrupt. 
def generate_cli_options_doc(parser: argparse.ArgumentParser) -> str:
    """Generate markdown documentation for CLI options from an argparse parser.

    One bullet is produced per argument in the form
    ``- `names` (type) (default: value): help``.  Help actions are skipped.
    Positional arguments are rendered as ``<dest>``.

    Args:
        parser: The parser whose registered actions are documented.

    Returns:
        A markdown section starting with the ``## CLI Options`` heading.
    """
    docs = ["## CLI Options\n\n"]

    for action in parser._actions:
        # Skip help actions
        if isinstance(action, argparse._HelpAction):
            continue

        if action.option_strings:
            option_names = action.option_strings
        elif action.dest and action.dest != argparse.SUPPRESS:
            # Positional argument.  Every positional is documented rather than
            # a fixed list of names, so new ones are not silently dropped.
            option_names = [f"<{action.dest}>"]
        else:
            continue

        option_str = ', '.join(f"`{name}`" for name in option_names)

        # Get type hint
        type_str = ''
        if action.type:
            # Not every converter has __name__ (functools.partial objects,
            # callable instances); fall back to repr for those.
            type_name = getattr(action.type, "__name__", None) or repr(action.type)
            type_str = f" ({type_name})"
        elif isinstance(action, argparse._StoreTrueAction):
            type_str = " (flag)"
        elif isinstance(action, argparse._StoreAction) and action.choices:
            # Choices may be ints or other non-strings; str.join needs str.
            type_str = f" ({' | '.join(str(choice) for choice in action.choices)})"

        # Get default value if exists
        default_str = ''
        if action.default is not None and action.default != argparse.SUPPRESS:
            if isinstance(action.default, str):
                default_str = f" (default: '{action.default}')"
            else:
                default_str = f" (default: {action.default})"

        help_text = action.help or ''

        docs.append(f"- {option_str}{type_str}{default_str}: {help_text}\n")

    return ''.join(docs)
class Logger:
    """Process-wide singleton that records telemetry events to a local file.

    Events are appended as JSON lines to ``logs/<node_id>/logs.txt`` inside
    the per-user application data directory; only the most recent
    ``MAX_LOG_ENTRIES`` lines are kept.  Nothing is sent over the network
    until ``send_logs`` (or ``report_error``) is called.

    The first construction must supply ``logging_url`` and ``node_id``;
    later ``Logger()`` calls return the same, already configured instance.
    """

    MAX_LOG_ENTRIES = 50
    REQUEST_TIMEOUT_SECONDS = 10

    _instance = None
    _initialized = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, logging_url: str = None, node_id: str = None):
        """Configure the singleton on first use; later calls are no-ops.

        Raises:
            ValueError: If the first construction omits either argument.
        """
        if self._initialized:
            return
        # Raised explicitly instead of asserted so the check survives ``-O``.
        if logging_url is None or node_id is None:
            raise ValueError("logging_url and node_id are required for first initialization")
        self.logging_url = logging_url
        self.node_id = node_id
        self.log_file = self._init_log_file(node_id)
        self._initialized = True

    def _init_log_file(self, node_id: str) -> Path:
        """Create the per-node log directory and return an emptied log file path."""
        app_data = Path(user_data_dir("exo", appauthor="exo_labs"))
        logs_dir = app_data / "logs" / node_id
        logs_dir.mkdir(parents=True, exist_ok=True)
        log_file = logs_dir / "logs.txt"
        # Open in write mode to clear/create the file.
        with open(log_file, 'w', encoding="utf-8"):
            pass
        return log_file

    @staticmethod
    def _json_default(value):
        """Serialise sets as lists; the session topology carries sets of peer ids."""
        if isinstance(value, (set, frozenset)):
            try:
                return sorted(value)
            except TypeError:
                # Elements that cannot be ordered: keep them, unordered.
                return list(value)
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    def write_log(self, action: TelemetryAction, data: dict = None) -> None:
        """Append one event to the log file, keeping only the newest entries.

        Args:
            action: The kind of event being recorded.
            data: Optional event payload; defaults to an empty dict.
        """
        log_entry = {
            "commit_id": session_config.get("commit_id"),
            "action": action,
            "device_id": session_config.get("device_id"),
            "session_id": session_config.get("session_id"),
            "node_id": self.node_id,
            "topology": session_config.get("topology") or {},
            "data": {} if data is None else data,
        }

        # Filter out None values
        log_entry = {k: v for k, v in log_entry.items() if v is not None}

        # Read existing lines if file exists
        lines = []
        if self.log_file.exists():
            with open(self.log_file, "r", encoding="utf-8") as f:
                lines = f.readlines()

        lines.append(json.dumps(log_entry, default=self._json_default) + "\n")

        # Write back only the most recent entries.
        with open(self.log_file, "w", encoding="utf-8") as f:
            f.writelines(lines[-self.MAX_LOG_ENTRIES:])

    def send_logs(self) -> None:
        """POST the contents of the log file to ``logging_url`` as a JSON array.

        Failures are printed rather than raised: this runs while the process
        is already handling an error and must not mask it with a new one.
        """
        headers = {
            "accept": "application/json",
            "Content-Type": "application/json"
        }
        # Read the log file into a JSON array, ignoring blank lines.
        with open(self.log_file, "r", encoding="utf-8") as f:
            log_entries = [json.loads(line) for line in f if line.strip()]

        try:
            # Without a timeout an unreachable server would block shutdown indefinitely.
            r = requests.post(
                self.logging_url,
                headers=headers,
                json=log_entries,
                timeout=self.REQUEST_TIMEOUT_SECONDS,
            )
        except requests.RequestException as exc:
            print(f"Error reporting logs: {exc}")
            return

        if r.status_code != 200:
            print(f"Error reporting logs: {r.status_code} {r.text}")
        else:
            print("Logs reported successfully\n")

    def report_error(self, error: Exception) -> None:
        """Record ``error`` with its stack trace, then upload the log file.

        Args:
            error: The exception to report.
        """
        error_data = {
            "error": str(error),
            # Format from the exception object itself, so the trace is correct
            # even when this is called outside the ``except`` block.
            "stacktrace": "".join(traceback.format_exception(type(error), error, error.__traceback__)),
        }
        self.write_log(TelemetryAction.ERROR, error_data)
        self.send_logs()