From 1d54f105149f909c4720217354b18f46cc4d860b Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Wed, 31 Jul 2024 22:50:04 +0100 Subject: [PATCH] pass on tinygrad set_on_download_progress --- exo/inference/inference_engine.py | 6 +++++- exo/inference/tinygrad/inference.py | 5 ++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/exo/inference/inference_engine.py b/exo/inference/inference_engine.py index 013560a55..827f429bd 100644 --- a/exo/inference/inference_engine.py +++ b/exo/inference/inference_engine.py @@ -1,6 +1,6 @@ import numpy as np -from typing import Tuple, Optional +from typing import Tuple, Optional, Callable from abc import ABC, abstractmethod from .shard import Shard @@ -13,3 +13,7 @@ async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarr @abstractmethod async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: pass + + @abstractmethod + def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]): + pass diff --git a/exo/inference/tinygrad/inference.py b/exo/inference/tinygrad/inference.py index c0e352773..545b335f7 100644 --- a/exo/inference/tinygrad/inference.py +++ b/exo/inference/tinygrad/inference.py @@ -1,7 +1,7 @@ import asyncio from functools import partial from pathlib import Path -from typing import List, Optional, Union +from typing import List, Optional, Union, Callable import json import tiktoken from tiktoken.load import load_tiktoken_bpe @@ -294,3 +294,6 @@ async def ensure_shard(self, shard: Shard): self.shard = shard self.model = model self.tokenizer = tokenizer + + def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]): + pass \ No newline at end of file