by default find an ephemeral node port (fixes #35); more robust topology updates (fixes #15 and #14)
AlexCheema committed Jul 19, 2024
1 parent 54c9860 commit 3517769
Showing 7 changed files with 68 additions and 25 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,6 +1,7 @@
__pycache__/
.venv
test_weights.npz
.exo_used_ports

# Byte-compiled / optimized / DLL files
__pycache__/
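A note on the new entry: .exo_used_ports is the scratch file that find_available_port in exo/helpers.py writes next to helpers.py to remember recently handed-out ports, so it should not be committed.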
3 changes: 0 additions & 3 deletions exo/api/chatgpt_api.py
@@ -187,9 +187,6 @@ async def handle_post_chat_completions(self, request):
headers={
"Content-Type": "application/json",
"Cache-Control": "no-cache",
# "Access-Control-Allow-Origin": "*",
# "Access-Control-Allow-Methods": "*",
# "Access-Control-Allow-Headers": "*",
}
)
await response.prepare(request)
34 changes: 33 additions & 1 deletion exo/helpers.py
@@ -1,6 +1,8 @@
import os
import asyncio
from typing import Any, Callable, Coroutine, TypeVar, Optional, Dict, Generic, Tuple
import socket
import random

DEBUG = int(os.getenv("DEBUG", default="0"))
DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
@@ -13,6 +15,36 @@
\___/_/\_\___/
"""

def find_available_port(host: str = '', min_port: int = 49152, max_port: int = 65535) -> int:
used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.exo_used_ports')

def read_used_ports():
if os.path.exists(used_ports_file):
with open(used_ports_file, 'r') as f:
return [int(line.strip()) for line in f if line.strip().isdigit()]
return []

def write_used_port(port, used_ports):
with open(used_ports_file, 'w') as f:
print(used_ports[-19:])
for p in used_ports[-19:] + [port]:
f.write(f"{p}\n")

used_ports = read_used_ports()
available_ports = set(range(min_port, max_port + 1)) - set(used_ports)

while available_ports:
port = random.choice(list(available_ports))
if DEBUG >= 2: print(f"Trying to find available port {port=}")
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((host, port))
write_used_port(port, used_ports)
return port
except socket.error:
available_ports.remove(port)

raise RuntimeError("No available ports in the specified range")
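
A minimal usage sketch of the helper above (hypothetical caller; note that the bind test releases the socket immediately, so another process could still grab the port before you bind it for real):

    from exo.helpers import find_available_port

    port = find_available_port()           # random free port in the ephemeral range 49152-65535
    print(f"node will bind to {port}")     # the last 20 handed-out ports are remembered in .exo_used_ports

main.py uses exactly this as the fallback when --node-port is not given.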

def print_exo():
print(exo_text)
@@ -81,4 +113,4 @@ def trigger(self, name: K, *args: T) -> None:

def trigger_all(self, *args: T) -> None:
for callback in self.callbacks.values():
callback.set(*args)
31 changes: 18 additions & 13 deletions exo/networking/grpc/grpc_discovery.py
@@ -30,8 +30,7 @@ def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_por
self.listen_port = listen_port
self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
self.broadcast_interval = broadcast_interval
self.known_peers: Dict[str, GRPCPeerHandle] = {}
self.peer_last_seen: Dict[str, float] = {}
self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float]] = {}
self.broadcast_task = None
self.listen_task = None
self.cleanup_task = None
@@ -74,7 +73,7 @@ async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
if DEBUG_DISCOVERY >= 2: print("No new peers discovered in the last grace period. Ending discovery process.")
break # No new peers found in the grace period, we are done

return list(self.known_peers.values())
return [peer_handle for peer_handle, _ in self.known_peers.values()]

async def task_broadcast_presence(self):
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
@@ -110,21 +109,27 @@ async def on_listen_message(self, data, addr):
peer_port = message['grpc_port']
device_capabilities = DeviceCapabilities(**message['device_capabilities'])
if peer_id not in self.known_peers:
self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
self.known_peers[peer_id] = (GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities), time.time())
if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
self.peer_last_seen[peer_id] = time.time()
self.known_peers[peer_id] = (self.known_peers[peer_id][0], time.time())

async def task_listen_for_peers(self):
await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=('0.0.0.0', self.listen_port))
if DEBUG_DISCOVERY >= 2: print("Started listen task")

async def task_cleanup_peers(self):
while True:
current_time = time.time()
timeout = 15 * self.broadcast_interval
peers_to_remove = [peer_id for peer_id, last_seen in self.peer_last_seen.items() if current_time - last_seen > timeout]
for peer_id in peers_to_remove:
del self.known_peers[peer_id]
del self.peer_last_seen[peer_id]
if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")
await asyncio.sleep(self.broadcast_interval)
try:
current_time = time.time()
timeout = 15 * self.broadcast_interval
peers_to_remove = [peer_handle.id() for peer_handle, last_seen in self.known_peers.values() if not await peer_handle.is_connected() or current_time - last_seen > timeout]
if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, last_seen={last_seen}" for peer_handle, last_seen in self.known_peers.values()})
if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0: print(f"Cleaning up peers: {peers_to_remove}")
for peer_id in peers_to_remove:
if peer_id in self.known_peers: del self.known_peers[peer_id]
if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")
await asyncio.sleep(self.broadcast_interval)
except Exception as e:
print(f"Error in cleanup peers: {e}")
import traceback
print(traceback.format_exc())
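
The shape of this refactor, in isolation: the separate peer_last_seen dict is folded into known_peers, so a peer's handle and its freshness travel together and can never drift out of sync. A minimal sketch (types taken from the diff; the touch helper is hypothetical):

    import time
    from typing import Dict, Tuple

    known_peers: Dict[str, Tuple["GRPCPeerHandle", float]] = {}

    def touch(peer_id: str) -> None:
        # Refresh last_seen without disturbing the stored handle.
        handle, _ = known_peers[peer_id]
        known_peers[peer_id] = (handle, time.time())

The cleanup task then drops a peer either when it reports not connected or when its last_seen is older than 15 broadcast intervals.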
4 changes: 3 additions & 1 deletion exo/networking/grpc/grpc_server.py
@@ -17,7 +17,9 @@ def __init__(self, node: Node, host: str, port: int):

async def start(self) -> None:
self.server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10), options=[
('grpc.max_metadata_size', 32*1024*1024)
('grpc.max_metadata_size', 32*1024*1024),
('grpc.max_send_message_length', 128 * 1024 * 1024),
('grpc.max_receive_message_length', 128 * 1024 * 1024),
])
node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
listen_addr = f'{self.host}:{self.port}'
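The two new options raise the per-message cap (gRPC's receive default is 4 MB) to 128 MB so large tensors fit in a single message. A hedged sketch of the matching client side, using standard gRPC channel options with a placeholder address:

    import grpc

    # Client channels should mirror the server's limits, otherwise
    # oversized replies fail with RESOURCE_EXHAUSTED.
    channel = grpc.aio.insecure_channel("10.0.0.2:49999", options=[
        ("grpc.max_send_message_length", 128 * 1024 * 1024),
        ("grpc.max_receive_message_length", 128 * 1024 * 1024),
    ])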
12 changes: 7 additions & 5 deletions exo/orchestration/standard_node.py
@@ -84,7 +84,7 @@ async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Opt
if result.size == 1: # we got a new token out
self.buffered_token_output[request_id][0].append(result.item())
self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")

if not is_finished:
asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
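
The fix above indexes the buffer correctly: judging from the surrounding code, buffered_token_output maps a request id to a (tokens, is_finished) pair, so the old len(...) always reported 2. A tiny illustration with made-up values:

    buffered_token_output = {"req-1": ([101, 202, 303], False)}
    tokens, is_finished = buffered_token_output["req-1"]
    len(buffered_token_output["req-1"])  # 2, the pair's arity: the old, wrong count
    len(tokens)                          # 3, the actual number of buffered tokens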
@@ -179,16 +179,17 @@ async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarr
return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]

async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology:
self.topology.update_node(self.id, self.device_capabilities)
next_topology = Topology()
next_topology.update_node(self.id, self.device_capabilities)

if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}")

prev_visited = visited.copy()
visited.update(p.id() for p in self.peers)

for peer in self.peers:
self.topology.update_node(peer.id(), peer.device_capabilities())
self.topology.add_edge(self.id, peer.id())
next_topology.update_node(peer.id(), peer.device_capabilities())
next_topology.add_edge(self.id, peer.id())

if peer.id() in prev_visited:
if DEBUG >= 2: print(f"Already visited {peer.id()}. Skipping...")
@@ -205,7 +206,8 @@ async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4)
except Exception as e:
print(f"Error collecting topology from {peer.id()}: {e}")

return self.topology
self.topology = next_topology
return next_topology

# TODO: unify this and collect_topology as global actions
async def global_reset(self, base_shard: Shard, visited: set[str] = set(), max_depth: int = 2) -> None:
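The collect_topology change follows a build-then-swap pattern: the walk accumulates into a fresh Topology, and self.topology is replaced only once the walk completes, so stale nodes and edges from earlier rounds cannot linger and a failed walk never leaves a half-updated view. A runnable toy version of the pattern, with a plain dict standing in for Topology:

    class Node:
        def __init__(self):
            self.topology = {}           # node_id -> set of neighbour ids

        def collect(self, edges):
            nxt = {}                     # scratch snapshot, invisible to readers
            for a, b in edges:
                nxt.setdefault(a, set()).add(b)
            self.topology = nxt          # publish only the complete snapshot
            return nxt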
8 changes: 6 additions & 2 deletions main.py
@@ -11,13 +11,13 @@
from exo.networking.grpc.grpc_discovery import GRPCDiscovery
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
from exo.api import ChatGPTAPI
from exo.helpers import print_yellow_exo
from exo.helpers import print_yellow_exo, find_available_port, DEBUG

# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
parser.add_argument("--node-id", type=str, default=str(uuid.uuid4()), help="Node ID")
parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
parser.add_argument("--node-port", type=int, default=8080, help="Node port")
parser.add_argument("--node-port", type=int, default=None, help="Node port")
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
@@ -49,6 +49,10 @@
raise ValueError(f"Inference engine {args.inference_engine} not supported")
print(f"Using inference engine {inference_engine.__class__.__name__}")


if args.node_port is None:
args.node_port = find_available_port(args.node_host)
if DEBUG >= 1: print(f"Using available port: {args.node_port}")
discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy())
server = GRPCServer(node, args.node_host, args.node_port)
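Net effect in main.py: running python main.py with no --node-port now grabs a free ephemeral port via find_available_port (printed when DEBUG >= 1), while passing an explicit --node-port 8080 restores the previous fixed-port behavior.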
