Commit e934664

implement dynamic inference engine selection
implement the system detection and inference engine selection

remove inconsistency
itsknk committed Jul 22, 2024
1 parent e49924e commit e934664
Showing 2 changed files with 32 additions and 27 deletions.
26 changes: 25 additions & 1 deletion exo/helpers.py
@@ -3,6 +3,8 @@
 from typing import Any, Callable, Coroutine, TypeVar, Optional, Dict, Generic, Tuple
 import socket
 import random
+import platform
+import psutil
 
 DEBUG = int(os.getenv("DEBUG", default="0"))
 DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
@@ -15,6 +17,28 @@
  \___/_/\_\___/
 """
 
+def get_system_info():
+    if psutil.MACOS:
+        if platform.machine() == 'arm64':
+            return "Apple Silicon Mac"
+        elif platform.machine() in ['x86_64', 'i386']:
+            return "Intel Mac"
+        else:
+            return "Unknown Mac architecture"
+    elif psutil.LINUX:
+        return "Linux"
+    else:
+        return "Non-Mac, non-Linux system"
+
+def get_inference_engine():
+    system_info = get_system_info()
+    if system_info == "Apple Silicon Mac":
+        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+        return MLXDynamicShardInferenceEngine()
+    else:
+        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+        return TinygradDynamicShardInferenceEngine()
+
 def find_available_port(host: str = '', min_port: int = 49152, max_port: int = 65535) -> int:
     used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.exo_used_ports')
 
@@ -113,4 +137,4 @@ def trigger(self, name: K, *args: T) -> None:
 
     def trigger_all(self, *args: T) -> None:
         for callback in self.callbacks.values():
-            callback.set(*args)
\ No newline at end of file
+            callback.set(*args)
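
For reference, a minimal sketch of how the two new helpers might be exercised on their own, assuming exo is importable and the platform-appropriate backend (MLX or tinygrad) is installed; this mirrors what main.py now does at startup:

from exo.helpers import get_system_info, get_inference_engine

# Report the detected platform, e.g. "Apple Silicon Mac", "Intel Mac", or "Linux".
print(f"Detected system: {get_system_info()}")

# Pick the matching engine: MLX on Apple Silicon, tinygrad everywhere else.
engine = get_inference_engine()
print(f"Using inference engine: {engine.__class__.__name__}")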
33 changes: 7 additions & 26 deletions main.py
@@ -2,16 +2,13 @@
 import asyncio
 import signal
 import uuid
-import platform
-import psutil
-import os
 from typing import List
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -27,33 +24,17 @@
 args = parser.parse_args()
 
 print_yellow_exo()
-print(f"Starting exo {platform.system()=} {psutil.virtual_memory()=}")
-if args.inference_engine is None:
-    if psutil.MACOS:
-        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-        inference_engine = MLXDynamicShardInferenceEngine()
-    else:
-        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-        import tinygrad.helpers
-        tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-        inference_engine = TinygradDynamicShardInferenceEngine()
-else:
-    if args.inference_engine == "mlx":
-        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-        inference_engine = MLXDynamicShardInferenceEngine()
-    elif args.inference_engine == "tinygrad":
-        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-        import tinygrad.helpers
-        tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-        inference_engine = TinygradDynamicShardInferenceEngine()
-    else:
-        raise ValueError(f"Inference engine {args.inference_engine} not supported")
-print(f"Using inference engine {inference_engine.__class__.__name__}")
+
+system_info = get_system_info()
+print(f"Detected system: {system_info}")
+
+inference_engine = get_inference_engine()
+print(f"Using inference engine: {inference_engine.__class__.__name__}")
 
 if args.node_port is None:
     args.node_port = find_available_port(args.node_host)
     if DEBUG >= 1: print(f"Using available port: {args.node_port}")
 
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
 node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions", web_chat_url=f"http://localhost:{args.chatgpt_api_port}")
 server = GRPCServer(node, args.node_host, args.node_port)
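The selection logic rests on two standard signals: psutil's module-level OS flags and platform.machine(). A small illustrative snippet, independent of exo (printed values are examples, not output from this commit):

import platform
import psutil

# Boolean OS flags resolved by psutil at import time.
print(psutil.MACOS, psutil.LINUX)   # e.g. True, False on any Mac

# Architecture string that separates Apple Silicon ('arm64')
# from Intel Macs ('x86_64' or 'i386').
print(platform.machine())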
