fix regression introduced by image_str for tinygrad
AlexCheema committed Jul 31, 2024
1 parent 1d54f10 commit 7676625
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions exo/inference/inference_engine.py
@@ -7,11 +7,11 @@
 
 class InferenceEngine(ABC):
   @abstractmethod
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     pass
 
   @abstractmethod
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     pass
 
   @abstractmethod
2 changes: 1 addition & 1 deletion exo/inference/tinygrad/inference.py
@@ -198,7 +198,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self):
     self.shard = None
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     # TODO: we need to refactor models/llamaa to handle per-request-kv-cache. right now it's shared between requests.
     await self.ensure_shard(shard)
     start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
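
For context, here is a minimal sketch of why the signature mismatch was a regression: a caller written against the updated infer_prompt interface passes image_str, which the old tinygrad signature rejected. The OldTinygradEngine and FixedTinygradEngine classes below are illustrative stand-ins, not code from the repository; the Shard and np.ndarray parameters are dropped for brevity.

  import asyncio
  from typing import Optional, Tuple


  class OldTinygradEngine:
    # Stand-in for the pre-commit tinygrad signature: no image_str parameter.
    async def infer_prompt(self, request_id: str, prompt: str, inference_state: Optional[str] = None) -> Tuple[str, bool]:
      return ("ok", True)


  class FixedTinygradEngine:
    # Stand-in for the post-commit signature: image_str is accepted (unused here).
    async def infer_prompt(self, request_id: str, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[str, bool]:
      return ("ok", True)


  async def main():
    # A caller written against the updated interface passes image_str explicitly.
    try:
      await OldTinygradEngine().infer_prompt("req-1", "hello", image_str=None)
    except TypeError as e:
      # Before the fix: infer_prompt() got an unexpected keyword argument 'image_str'
      print("regression:", e)
    # After the fix, the same call succeeds for the tinygrad engine as well.
    print(await FixedTinygradEngine().infer_prompt("req-1", "hello", image_str=None))


  asyncio.run(main())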
