diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py
index 2f87c1b10..8264aae83 100644
--- a/exo/inference/pytorch/inference.py
+++ b/exo/inference/pytorch/inference.py
@@ -12,6 +12,9 @@
 from exo.helpers import DEBUG
 from exo.download.hf.hf_shard_download import HFShardDownloader
+# llama
+from transformers.models.llama.modeling_llama import LlamaModel
+
 # model value options
 TOP_K = 20
 TEMP = 0.6
@@ -52,7 +55,7 @@ def __init__(self, shard_downloader: HFShardDownloader):
     if torch.cuda.is_available():
       self.device = torch.device("cuda")
-      self.torch_dtype = torch.float32
+      self.torch_dtype = torch.float16  # half precision on CUDA to cut memory use
     elif torch.backends.mps.is_available():
       self.device = torch.device("mps")
       self.torch_dtype = torch.float32
@@ -105,10 +108,10 @@ async def infer_prompt(
       print(f"prompt: {prompt}")
       print(f"shard: {shard}")
       print(f"inference_state: {inference_state}")
-
+
     await self.ensure_shard(shard)
-
-    # setup prompt input
+
+    # setup prompt input
     messages = [{"role": "user", "content": prompt}]
     txt = self.tokenizer.apply_chat_template(
       messages,
@@ -174,9 +177,9 @@ async def infer_prompt(
   async def infer_tensor(
     self,
-    request_id: str,
-    shard: Shard,
-    input_data: np.ndarray,
+    request_id: str,
+    shard: Shard,
+    input_data: np.ndarray,
     inference_state: Optional[str] = None
   ) -> Tuple[np.ndarray, str, bool]:
     if DEBUG >= 4:
@@ -192,13 +195,13 @@ async def infer_tensor(
     # get cache from inference_state
     past_iids, cached_iids = self.infer_caching(inference_state)
-    # detect if hidden_states or not
+    # detect if hidden_states or not
     hidden_states = None
     self.past_input_ids = None
     if input_ids.size()[-1] > 1:
       hidden_states = input_ids
     else:
-      if past_iids is not None:
+      if past_iids is not None:
         self.past_input_ids = past_iids
       else:
         self.past_input_ids = input_ids
diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py
index 38cd85c20..57a1590b0 100644
--- a/exo/inference/pytorch/model/hf.py
+++ b/exo/inference/pytorch/model/hf.py
@@ -16,7 +16,7 @@
   TemperatureLogitsWarper
 )
-# llama
+# llama
 from transformers.models.llama.modeling_llama import LlamaModel
 class ShardedHuggingFaceModel:
diff --git a/exo/models.py b/exo/models.py
index 6f69960ea..2f1e7d10a 100644
--- a/exo/models.py
+++ b/exo/models.py
@@ -4,6 +4,7 @@
   ### llama
   "llama-3.2-1b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16),
+    "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=16),
   },
   "llama-3.2-3b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
@@ -36,9 +37,6 @@
   "llama-3-1B-Base": {
     "PyTorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-1B-Base", start_layer=0, end_layer=0, n_layers=3),
   },
-  "meta-llama/Llama-3.2-1B-Instruct": {
-    "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=24),
-  },
   ### mistral
   "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),},
   "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),},
diff --git a/exo/tinychat/index.html b/exo/tinychat/index.html
index c00d2b0a6..9cad69d58 100644
--- a/exo/tinychat/index.html
+++ b/exo/tinychat/index.html
@@ -38,7 +38,6 @@
-
diff --git a/setup.py b/setup.py
index 8401167be..b23485a7f 100644
--- a/setup.py
+++ b/setup.py
@@ -27,7 +27,7 @@
   "uuid==1.30",
   "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad",
   "torch==2.4.0+cu124",
-  "accelerate=0.33.0"
+  "accelerate"
 ]
 # Add macOS-specific packages if on Darwin (macOS)