Skip to content

Commit

Permalink
Merge pull request exo-explore#58 from jakobdylanc/main
Browse files Browse the repository at this point in the history
Inference engine selection improvements
  • Loading branch information
AlexCheema authored Jul 22, 2024
2 parents 8dd20cd + 2949637 commit 4ab4474
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
9 changes: 5 additions & 4 deletions exo/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,15 @@ def get_system_info():
else:
return "Non-Mac, non-Linux system"

-def get_inference_engine():
-    system_info = get_system_info()
-    if system_info == "Apple Silicon Mac":
+def get_inference_engine(inference_engine_name):
+    if inference_engine_name == "mlx":
         from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
         return MLXDynamicShardInferenceEngine()
-    else:
+    elif inference_engine_name == "tinygrad":
         from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
         return TinygradDynamicShardInferenceEngine()
+    else:
+        raise ValueError(f"Inference engine {inference_engine_name} not supported")

def find_available_port(host: str = '', min_port: int = 49152, max_port: int = 65535) -> int:
used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.exo_used_ports')
Expand Down
3 changes: 2 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
system_info = get_system_info()
print(f"Detected system: {system_info}")

-inference_engine = get_inference_engine()
+inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
+inference_engine = get_inference_engine(inference_engine_name)
print(f"Using inference engine: {inference_engine.__class__.__name__}")

if args.node_port is None:
Expand Down

0 comments on commit 4ab4474

Please sign in to comment.