From 76766253cde775f58f5637344dae2656f1c1447a Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Wed, 31 Jul 2024 22:53:46 +0100 Subject: [PATCH] fix regression introduced by image_str for tinygrad --- exo/inference/inference_engine.py | 4 ++-- exo/inference/tinygrad/inference.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/exo/inference/inference_engine.py b/exo/inference/inference_engine.py index 827f429bd..264dd5e92 100644 --- a/exo/inference/inference_engine.py +++ b/exo/inference/inference_engine.py @@ -7,11 +7,11 @@ class InferenceEngine(ABC): @abstractmethod - async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: + async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool): pass @abstractmethod - async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: + async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: pass @abstractmethod diff --git a/exo/inference/tinygrad/inference.py b/exo/inference/tinygrad/inference.py index 545b335f7..6a3e8ce4e 100644 --- a/exo/inference/tinygrad/inference.py +++ b/exo/inference/tinygrad/inference.py @@ -198,7 +198,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine): def __init__(self): self.shard = None - async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool): + async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool): # TODO: we need to refactor models/llamaa to handle per-request-kv-cache. right now it's shared between requests. await self.ensure_shard(shard) start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0