diff --git a/exo/helpers.py b/exo/helpers.py index 97b1d8e8f..5e3cbbc12 100644 --- a/exo/helpers.py +++ b/exo/helpers.py @@ -30,14 +30,15 @@ def get_system_info(): else: return "Non-Mac, non-Linux system" -def get_inference_engine(): - system_info = get_system_info() - if system_info == "Apple Silicon Mac": +def get_inference_engine(inference_engine_name): + if inference_engine_name == "mlx": from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine return MLXDynamicShardInferenceEngine() - else: + elif inference_engine_name == "tinygrad": from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine return TinygradDynamicShardInferenceEngine() + else: + raise ValueError(f"Inference engine {inference_engine_name} not supported") def find_available_port(host: str = '', min_port: int = 49152, max_port: int = 65535) -> int: used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.exo_used_ports') diff --git a/main.py b/main.py index cda988880..d0ae3f522 100644 --- a/main.py +++ b/main.py @@ -31,7 +31,8 @@ system_info = get_system_info() print(f"Detected system: {system_info}") -inference_engine = get_inference_engine() +inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad") +inference_engine = get_inference_engine(inference_engine_name) print(f"Using inference engine: {inference_engine.__class__.__name__}") if args.node_port is None: