diff --git a/.gitignore b/.gitignore
index a2b1fec44..c5d644a87 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
 __pycache__/
 .venv
 test_weights.npz
+.exo_used_ports
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py
index e46e6cf5f..796101859 100644
--- a/exo/api/chatgpt_api.py
+++ b/exo/api/chatgpt_api.py
@@ -187,9 +187,6 @@ async def handle_post_chat_completions(self, request):
       headers={
         "Content-Type": "application/json",
         "Cache-Control": "no-cache",
-        # "Access-Control-Allow-Origin": "*",
-        # "Access-Control-Allow-Methods": "*",
-        # "Access-Control-Allow-Headers": "*",
       }
     )
     await response.prepare(request)
diff --git a/exo/helpers.py b/exo/helpers.py
index 857da5649..34a84035e 100644
--- a/exo/helpers.py
+++ b/exo/helpers.py
@@ -1,6 +1,8 @@
 import os
 import asyncio
 from typing import Any, Callable, Coroutine, TypeVar, Optional, Dict, Generic, Tuple
+import socket
+import random
 
 DEBUG = int(os.getenv("DEBUG", default="0"))
 DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
@@ -13,6 +15,36 @@
 \___/_/\_\___/
 """
 
+def find_available_port(host: str = '', min_port: int = 49152, max_port: int = 65535) -> int:
+  used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.exo_used_ports')
+
+  def read_used_ports():
+    if os.path.exists(used_ports_file):
+      with open(used_ports_file, 'r') as f:
+        return [int(line.strip()) for line in f if line.strip().isdigit()]
+    return []
+
+  def write_used_port(port, used_ports):
+    with open(used_ports_file, 'w') as f:
+      if DEBUG >= 2: print(used_ports[-19:])
+      for p in used_ports[-19:] + [port]:
+        f.write(f"{p}\n")
+
+  used_ports = read_used_ports()
+  available_ports = set(range(min_port, max_port + 1)) - set(used_ports)
+
+  while available_ports:
+    port = random.choice(list(available_ports))
+    if DEBUG >= 2: print(f"Trying to find available port {port=}")
+    try:
+      with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind((host, port))
+        write_used_port(port, used_ports)
+        return port
+    except socket.error:
+      available_ports.remove(port)
+
+  raise RuntimeError("No available ports in the specified range")
 
 def print_exo():
   print(exo_text)
@@ -81,4 +113,4 @@ def trigger(self, name: K, *args: T) -> None:
 
   def trigger_all(self, *args: T) -> None:
     for callback in self.callbacks.values():
-      callback.set(*args)
+      callback.set(*args)
\ No newline at end of file
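Reviewer note on the new helper: find_available_port probes random ports in the IANA ephemeral range (49152-65535) by attempting a bind, and persists the last 20 handed-out ports to .exo_used_ports so closely spaced calls on the same machine do not return the same port. (The stray unconditional print in write_used_port is guarded with DEBUG >= 2 above to match the surrounding logging style.) A minimal sanity-check sketch, hypothetical and not part of the patch:

  # Hypothetical check: two consecutive calls hand out distinct, bindable
  # ports, because the first port is recorded in .exo_used_ports and excluded.
  import socket
  from exo.helpers import find_available_port

  port_a = find_available_port()
  port_b = find_available_port()
  assert port_a != port_b

  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.bind(('', port_b))  # free at selection time

Note the bind-then-close probe is inherently racy: another process can take the port between the check and the real server start.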
diff --git a/exo/networking/grpc/grpc_discovery.py b/exo/networking/grpc/grpc_discovery.py
index f7a627ce5..f01f38867 100644
--- a/exo/networking/grpc/grpc_discovery.py
+++ b/exo/networking/grpc/grpc_discovery.py
@@ -30,8 +30,7 @@ def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_por
     self.listen_port = listen_port
     self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
     self.broadcast_interval = broadcast_interval
-    self.known_peers: Dict[str, GRPCPeerHandle] = {}
-    self.peer_last_seen: Dict[str, float] = {}
+    self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float]] = {}
     self.broadcast_task = None
     self.listen_task = None
     self.cleanup_task = None
@@ -74,7 +73,7 @@ async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
         if DEBUG_DISCOVERY >= 2: print("No new peers discovered in the last grace period. Ending discovery process.")
         break # No new peers found in the grace period, we are done
 
-    return list(self.known_peers.values())
+    return [peer_handle for peer_handle, _ in self.known_peers.values()]
 
   async def task_broadcast_presence(self):
     transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
@@ -110,9 +109,9 @@ async def on_listen_message(self, data, addr):
     peer_port = message['grpc_port']
     device_capabilities = DeviceCapabilities(**message['device_capabilities'])
     if peer_id not in self.known_peers:
-      self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
+      self.known_peers[peer_id] = (GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities), time.time())
       if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
-    self.peer_last_seen[peer_id] = time.time()
+    self.known_peers[peer_id] = (self.known_peers[peer_id][0], time.time())
 
   async def task_listen_for_peers(self):
     await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=('0.0.0.0', self.listen_port))
@@ -120,11 +119,17 @@ async def task_listen_for_peers(self):
 
   async def task_cleanup_peers(self):
     while True:
-      current_time = time.time()
-      timeout = 15 * self.broadcast_interval
-      peers_to_remove = [peer_id for peer_id, last_seen in self.peer_last_seen.items() if current_time - last_seen > timeout]
-      for peer_id in peers_to_remove:
-        del self.known_peers[peer_id]
-        del self.peer_last_seen[peer_id]
-        if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")
-      await asyncio.sleep(self.broadcast_interval)
+      try:
+        current_time = time.time()
+        timeout = 15 * self.broadcast_interval
+        peers_to_remove = [peer_handle.id() for peer_handle, last_seen in self.known_peers.values() if not await peer_handle.is_connected() or current_time - last_seen > timeout]
+        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, last_seen={last_seen}" for peer_handle, last_seen in self.known_peers.values()})
+        if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0: print(f"Cleaning up peers: {peers_to_remove}")
+        for peer_id in peers_to_remove:
+          if peer_id in self.known_peers: del self.known_peers[peer_id]
+          if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")
+        await asyncio.sleep(self.broadcast_interval)
+      except Exception as e:
+        print(f"Error in cleanup peers: {e}")
+        import traceback
+        print(traceback.format_exc())
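Reviewer note on discovery: folding peer_last_seen into known_peers gives a single Dict[str, Tuple[GRPCPeerHandle, float]], so a peer's handle and its freshness timestamp are inserted, refreshed, and deleted together and can no longer drift apart; the cleanup task additionally drops peers whose channel reports disconnected. A condensed sketch of the merged bookkeeping, simplified with assumed shapes:

  import time

  known_peers = {}  # peer_id -> (peer_handle, last_seen)

  def on_seen(peer_id, handle):
    # insert or refresh in one structure; no second dict to keep in sync
    known_peers[peer_id] = (handle, time.time())

  def is_stale(peer_id, timeout):
    _, last_seen = known_peers[peer_id]
    return time.time() - last_seen > timeout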
diff --git a/exo/networking/grpc/grpc_server.py b/exo/networking/grpc/grpc_server.py
index ad54086c8..e956ddd51 100644
--- a/exo/networking/grpc/grpc_server.py
+++ b/exo/networking/grpc/grpc_server.py
@@ -17,7 +17,9 @@ def __init__(self, node: Node, host: str, port: int):
 
   async def start(self) -> None:
     self.server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10), options=[
-      ('grpc.max_metadata_size', 32*1024*1024)
+      ('grpc.max_metadata_size', 32*1024*1024),
+      ('grpc.max_send_message_length', 128 * 1024 * 1024),
+      ('grpc.max_receive_message_length', 128 * 1024 * 1024),
     ])
     node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
     listen_addr = f'{self.host}:{self.port}'
diff --git a/exo/orchestration/standard_node.py b/exo/orchestration/standard_node.py
index d8ba21972..268663ea1 100644
--- a/exo/orchestration/standard_node.py
+++ b/exo/orchestration/standard_node.py
@@ -84,7 +84,7 @@ async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Opt
       if result.size == 1: # we got a new token out
         self.buffered_token_output[request_id][0].append(result.item())
         self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
+      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
 
       if not is_finished:
         asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
@@ -179,7 +179,8 @@ async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarr
     return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
 
   async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology:
-    self.topology.update_node(self.id, self.device_capabilities)
+    next_topology = Topology()
+    next_topology.update_node(self.id, self.device_capabilities)
 
     if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}")
 
@@ -187,8 +188,8 @@ async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4)
     visited.update(p.id() for p in self.peers)
     for peer in self.peers:
-      self.topology.update_node(peer.id(), peer.device_capabilities())
-      self.topology.add_edge(self.id, peer.id())
+      next_topology.update_node(peer.id(), peer.device_capabilities())
+      next_topology.add_edge(self.id, peer.id())
 
       if peer.id() in prev_visited:
         if DEBUG >= 2: print(f"Already visited {peer.id()}. Skipping...")
@@ -205,7 +206,8 @@ async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4)
       except Exception as e:
         print(f"Error collecting topology from {peer.id()}: {e}")
 
-    return self.topology
+    self.topology = next_topology
+    return next_topology
 
   # TODO: unify this and collect_topology as global actions
   async def global_reset(self, base_shard: Shard, visited: set[str] = set(), max_depth: int = 2) -> None:
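Reviewer note on collect_topology: the previous code mutated self.topology in place, and since update_node/add_edge only ever add entries, peers that left the cluster were never removed from the reported topology. Building into a fresh next_topology and swapping it in at the end makes each collection a snapshot of the peers that are actually reachable. A simplified sketch of the rebuild-and-swap pattern, with assumed names:

  def collect(my_id, my_caps, peers):
    fresh = Topology()  # start empty on every collection pass
    fresh.update_node(my_id, my_caps)
    for peer in peers:  # only currently known peers are included
      fresh.update_node(peer.id(), peer.device_capabilities())
      fresh.add_edge(my_id, peer.id())
    return fresh  # caller then swaps: self.topology = fresh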
diff --git a/main.py b/main.py
index 52459f319..b79fee744 100644
--- a/main.py
+++ b/main.py
@@ -11,13 +11,13 @@
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
-from exo.helpers import print_yellow_exo
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
 parser.add_argument("--node-id", type=str, default=str(uuid.uuid4()), help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
-parser.add_argument("--node-port", type=int, default=8080, help="Node port")
+parser.add_argument("--node-port", type=int, default=None, help="Node port")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
@@ -49,6 +49,10 @@
   raise ValueError(f"Inference engine {args.inference_engine} not supported")
 print(f"Using inference engine {inference_engine.__class__.__name__}")
+
+if args.node_port is None:
+  args.node_port = find_available_port(args.node_host)
+  if DEBUG >= 1: print(f"Using available port: {args.node_port}")
 
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
 node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy())
 server = GRPCServer(node, args.node_host, args.node_port)
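Usage after this change: --node-port is now optional. When omitted, a free port is picked from the ephemeral range at startup, and the choice is logged when DEBUG >= 1. For example:

  python main.py                    # node port chosen automatically
  python main.py --node-port 8080   # the old default, pinned explicitly
  DEBUG=1 python main.py            # also prints the chosen port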