From 9fe3ec63dd26b78d9c27e3bcb17f72a79c7ee977 Mon Sep 17 00:00:00 2001
From: risingsunomi
Date: Sun, 6 Oct 2024 11:51:33 -0800
Subject: [PATCH] cleaning up code, added pytorch engine to llama 3.2 1b model
 shard in models.py, removed old 3.2 1b model shard, moving to test server for
 more vram

---
 exo/inference/pytorch/inference.py | 21 ++++++++++++---------
 exo/inference/pytorch/model/hf.py  |  2 +-
 exo/models.py                      |  4 +---
 exo/tinychat/index.html            |  1 -
 setup.py                           |  2 +-
 5 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py
index 2f87c1b10..8264aae83 100644
--- a/exo/inference/pytorch/inference.py
+++ b/exo/inference/pytorch/inference.py
@@ -12,6 +12,9 @@
 from exo.helpers import DEBUG
 from exo.download.hf.hf_shard_download import HFShardDownloader
 
+# llama
+from transformers.models.llama.modeling_llama import LlamaModel
+
 # model value options
 TOP_K = 20
 TEMP = 0.6
@@ -52,7 +55,7 @@ def __init__(self, shard_downloader: HFShardDownloader):
 
     if torch.cuda.is_available():
       self.device = torch.device("cuda")
-      self.torch_dtype = torch.float32
+      self.torch_dtype = torch.float16
     elif torch.backends.mps.is_available():
       self.device = torch.device("mps")
       self.torch_dtype = torch.float32
@@ -105,10 +108,10 @@ async def infer_prompt(
       print(f"prompt: {prompt}")
       print(f"shard: {shard}")
       print(f"inference_state: {inference_state}")
-    
+
     await self.ensure_shard(shard)
-    
-    # setup prompt input 
+
+    # setup prompt input
     messages = [{"role": "user", "content": prompt}]
     txt = self.tokenizer.apply_chat_template(
       messages,
@@ -174,9 +177,9 @@ async def infer_prompt(
 
   async def infer_tensor(
     self,
-    request_id: str, 
-    shard: Shard, 
-    input_data: np.ndarray, 
+    request_id: str,
+    shard: Shard,
+    input_data: np.ndarray,
     inference_state: Optional[str] = None
   ) -> Tuple[np.ndarray, str, bool]:
     if DEBUG >= 4:
@@ -192,13 +195,13 @@ async def infer_tensor(
     # get cache from inference_state
     past_iids, cached_iids = self.infer_caching(inference_state)
 
-    # detect if hidden_states or not 
+    # detect if hidden_states or not
     hidden_states = None
     self.past_input_ids = None
     if input_ids.size()[-1] > 1:
       hidden_states = input_ids
     else:
-      if past_iids is not None: 
+      if past_iids is not None:
         self.past_input_ids = past_iids
       else:
         self.past_input_ids = input_ids
diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py
index 38cd85c20..57a1590b0 100644
--- a/exo/inference/pytorch/model/hf.py
+++ b/exo/inference/pytorch/model/hf.py
@@ -16,7 +16,7 @@
   TemperatureLogitsWarper
 )
 
-# llama 
+# llama
 from transformers.models.llama.modeling_llama import LlamaModel
 
 class ShardedHuggingFaceModel:
diff --git a/exo/models.py b/exo/models.py
index 6f69960ea..2f1e7d10a 100644
--- a/exo/models.py
+++ b/exo/models.py
@@ -4,6 +4,7 @@
   ### llama
   "llama-3.2-1b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16),
+    "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=16),
   },
   "llama-3.2-3b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
@@ -36,9 +37,6 @@
   "llama-3-1B-Base": {
     "PyTorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-1B-Base", start_layer=0, end_layer=0, n_layers=3),
   },
-  "meta-llama/Llama-3.2-1B-Instruct": {
-    "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=24),
-  },
   ### mistral
   "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),},
   "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),},
diff --git a/exo/tinychat/index.html b/exo/tinychat/index.html
index c00d2b0a6..9cad69d58 100644
--- a/exo/tinychat/index.html
+++ b/exo/tinychat/index.html
@@ -38,7 +38,6 @@
-
diff --git a/setup.py b/setup.py
index 8401167be..b23485a7f 100644
--- a/setup.py
+++ b/setup.py
@@ -27,7 +27,7 @@
   "uuid==1.30",
   "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad",
   "torch==2.4.0+cu124",
-  "accelerate=0.33.0"
+  "accelerate"
 ]
 
 # Add macOS-specific packages if on Darwin (macOS)
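
As a quick reference for the dtype change in inference.py above: CUDA hosts now load
weights in half precision, while MPS and CPU stay at float32. Below is a minimal
standalone sketch of that policy; the helper function name is illustrative only,
since in the patch the engine assigns self.device and self.torch_dtype inline in
__init__:

import torch

def pick_device_and_dtype():
  # Branch order mirrors the patch: CUDA first, then Apple MPS, then CPU.
  if torch.cuda.is_available():
    # float16 halves weight memory on CUDA compared to float32, which is
    # what makes larger shards fit on a VRAM-limited GPU.
    return torch.device("cuda"), torch.float16
  elif torch.backends.mps.is_available():
    # The patch leaves MPS at float32.
    return torch.device("mps"), torch.float32
  return torch.device("cpu"), torch.float32

device, dtype = pick_device_and_dtype()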