Commit 9fe3ec6
cleaning up code, added pytorch engine to llama 3.2 1b model shard in models.py, removed old 3.2 1b model shard, moving to test server for more vram

risingsunomi committed Oct 6, 2024
1 parent 57e14e8 commit 9fe3ec6
Showing 5 changed files with 15 additions and 15 deletions.
21 changes: 12 additions & 9 deletions exo/inference/pytorch/inference.py
@@ -12,6 +12,9 @@
 from exo.helpers import DEBUG
 from exo.download.hf.hf_shard_download import HFShardDownloader
 
+# llama
+from transformers.models.llama.modeling_llama import LlamaModel
+
 # model value options
 TOP_K = 20
 TEMP = 0.6
@@ -52,7 +55,7 @@ def __init__(self, shard_downloader: HFShardDownloader):
 
     if torch.cuda.is_available():
       self.device = torch.device("cuda")
-      self.torch_dtype = torch.float32
+      self.torch_dtype = torch.float16
     elif torch.backends.mps.is_available():
       self.device = torch.device("mps")
       self.torch_dtype = torch.float32
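Note: the dtype switch above halves the memory footprint of model weights on CUDA, consistent with the commit message about needing more VRAM. A minimal sketch of the full selection logic, assuming a CPU fallback branch that the truncated hunk does not show:

import torch

# Pick the fastest available backend and a dtype it handles well.
if torch.cuda.is_available():
  device = torch.device("cuda")
  torch_dtype = torch.float16  # fp16 halves VRAM use versus fp32
elif torch.backends.mps.is_available():
  device = torch.device("mps")
  torch_dtype = torch.float32  # fp16 on MPS is less reliable, keep fp32
else:
  device = torch.device("cpu")  # assumed fallback, not shown in the hunk
  torch_dtype = torch.float32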
@@ -105,10 +108,10 @@ async def infer_prompt(
       print(f"prompt: {prompt}")
       print(f"shard: {shard}")
       print(f"inference_state: {inference_state}")
 
     await self.ensure_shard(shard)
-    # setup prompt input
 
+    # setup prompt input
     messages = [{"role": "user", "content": prompt}]
     txt = self.tokenizer.apply_chat_template(
       messages,
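The prompt path wraps the raw prompt in a single user message and renders it with the tokenizer's chat template. A self-contained sketch of that call; the tokenize and add_generation_prompt arguments are assumptions, since the hunk truncates the call:

from transformers import AutoTokenizer

# Model id matches the shard added in exo/models.py below.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
messages = [{"role": "user", "content": "Why is the sky blue?"}]
txt = tokenizer.apply_chat_template(
  messages,
  tokenize=False,             # return the rendered prompt string
  add_generation_prompt=True  # append the assistant header for generation
)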
@@ -174,9 +177,9 @@ async def infer_prompt(
 
   async def infer_tensor(
     self,
-    request_id: str,
-    shard: Shard,
-    input_data: np.ndarray,
+    request_id: str,
+    shard: Shard,
+    input_data: np.ndarray,
     inference_state: Optional[str] = None
   ) -> Tuple[np.ndarray, str, bool]:
     if DEBUG >= 4:
@@ -192,13 +195,13 @@ async def infer_tensor(
     # get cache from inference_state
     past_iids, cached_iids = self.infer_caching(inference_state)
 
-    # detect if hidden_states or not
+    # detect if hidden_states or not
     hidden_states = None
     self.past_input_ids = None
     if input_ids.size()[-1] > 1:
       hidden_states = input_ids
     else:
-      if past_iids is not None:
+      if past_iids is not None:
         self.past_input_ids = past_iids
       else:
         self.past_input_ids = input_ids
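infer_tensor uses the size of the incoming tensor to decide whether it received hidden states from an upstream shard or token ids to continue decoding. A standalone sketch of that dispatch, using the names from the hunk; treating it as a free function is an assumption for illustration:

import numpy as np
import torch

def classify_input(input_data: np.ndarray, past_iids: torch.Tensor | None):
  input_ids = torch.tensor(input_data)
  hidden_states = None
  past_input_ids = None
  if input_ids.size()[-1] > 1:
    # more than one value in the last dimension: an activation tensor
    hidden_states = input_ids
  else:
    # a single sampled token id: continue from the cached sequence if any
    past_input_ids = past_iids if past_iids is not None else input_ids
  return hidden_states, past_input_ids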
2 changes: 1 addition & 1 deletion exo/inference/pytorch/model/hf.py
@@ -16,7 +16,7 @@
   TemperatureLogitsWarper
 )
 
-# llama
+# llama
 from transformers.models.llama.modeling_llama import LlamaModel
 
 class ShardedHuggingFaceModel:
4 changes: 1 addition & 3 deletions exo/models.py
@@ -4,6 +4,7 @@
   ### llama
   "llama-3.2-1b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16),
+    "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=16),
   },
   "llama-3.2-3b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
@@ -36,9 +37,6 @@
   "llama-3-1B-Base": {
     "PyTorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-1B-Base", start_layer=0, end_layer=0, n_layers=3),
   },
-  "meta-llama/Llama-3.2-1B-Instruct": {
-    "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=24),
-  },
   ### mistral
   "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),},
   "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),},
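Each entry in exo/models.py maps an engine name to a Shard describing the model and its layer split, so this commit makes "llama-3.2-1b" resolvable by both the MLX and PyTorch engines instead of keeping a separate, wrongly sized entry (the removed shard claimed 24 layers; Llama 3.2 1B has 16). A sketch of the Shard shape implied by these calls; the exact dataclass body is an assumption:

from dataclasses import dataclass

@dataclass
class Shard:
  model_id: str     # HuggingFace repo id, e.g. "meta-llama/Llama-3.2-1B-Instruct"
  start_layer: int  # first transformer layer assigned to this node
  end_layer: int    # last transformer layer assigned to this node
  n_layers: int     # total layer count of the model (16 for Llama 3.2 1B)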
1 change: 0 additions & 1 deletion exo/tinychat/index.html
@@ -38,7 +38,6 @@
 <option value="llama-3.1-405b">Llama 3.1 405B</option>
 <option value="llama-3-8b">Llama 3 8B</option>
 <option value="llama-3-70b">Llama 3 70B</option>
-<option value="Llama-3.2-1B-Instruct">Llama-3.2-1B-Instruct</option>
 <option value="mistral-nemo">Mistral Nemo</option>
 <option value="mistral-large">Mistral Large</option>
 <option value="deepseek-coder-v2-lite">Deepseek Coder V2 Lite</option>
2 changes: 1 addition & 1 deletion setup.py
@@ -27,7 +27,7 @@
   "uuid==1.30",
   "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad",
   "torch==2.4.0+cu124",
-  "accelerate=0.33.0"
+  "accelerate"
 ]
 
 # Add macOS-specific packages if on Darwin (macOS)
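The removed "accelerate=0.33.0" was never a valid requirement: PEP 508 version pins use ==, so a single = makes pip reject the specifier. This commit unpins the package entirely; a pinned alternative, keeping the version number from the original line, would be:

install_requires = [
  # ...
  "accelerate==0.33.0",  # '==' is the valid pin operator; bare "accelerate" floats to the latest release
]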
