
Commit

Merge pull request #472 from exo-explore/pyver
Update some versions to support Python >= 3.9 and fix tinygrad thread issues
AlexCheema authored Nov 19, 2024
2 parents a018de7 + 312602f commit 2dafa9c
Showing 3 changed files with 9 additions and 14 deletions.
2 changes: 1 addition & 1 deletion exo/download/hf/hf_helpers.py
@@ -409,7 +409,7 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
   elif shard.is_last_layer():
     shard_specific_patterns.add(sorted_file_names[-1])
   else:
-    shard_specific_patterns = set("*.safetensors")
+    shard_specific_patterns = set(["*.safetensors"])
   if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   return list(default_patterns | shard_specific_patterns)
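
Note on the hf_helpers.py fix: set() iterates its argument, so a bare string is split into single characters and the intended "*.safetensors" glob never appears in the allow patterns. A quick illustration (plain Python, not code from the repository):

# Illustration only: why set("*.safetensors") was wrong.
old = set("*.safetensors")    # single characters: {'*', '.', 's', 'a', 'f', 'e', 't', 'n', 'o', 'r'}
new = set(["*.safetensors"])  # {'*.safetensors'} -- one glob pattern, as intended
assert "*.safetensors" not in old
assert "*.safetensors" in new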

14 changes: 5 additions & 9 deletions exo/inference/tinygrad/inference.py
@@ -7,7 +7,6 @@
 from tinygrad.nn.state import load_state_dict
 from tinygrad import Tensor, nn, Context
 from exo.inference.inference_engine import InferenceEngine
-from typing import Optional, Tuple
 import numpy as np
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.download.shard_download import ShardDownloader
@@ -68,24 +67,21 @@ def __init__(self, shard_downloader: ShardDownloader):
   async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
     logits = x[:, -1, :]
     def sample_wrapper():
-      return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize()
-    out = await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
-    return out.numpy().astype(int)
+      return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)

   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    return np.array(tokens)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, np.array, tokens)

   async def decode(self, shard: Shard, tokens) -> str:
     await self.ensure_shard(shard)
-    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
-    return tokens
+    return await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)

   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
-    output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize())
-    return output_data.numpy()
+    return await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize().numpy())

   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
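Note on the inference.py changes: they all apply one pattern. Every call that touches a tinygrad tensor, including the final .realize() and .numpy() conversion, now happens inside the callable handed to the single-thread executor, so the coroutine only ever awaits plain numpy results and no live tensor crosses back into the event-loop thread. A minimal sketch of that pattern with hypothetical names and a framework-agnostic model object (not the project's actual class):

import asyncio
from concurrent.futures import ThreadPoolExecutor
import numpy as np

# Hypothetical sketch: one worker thread owns all tensor work.
_executor = ThreadPoolExecutor(max_workers=1)

def _run_model(model, input_data: np.ndarray) -> np.ndarray:
  # Build, realize, and convert to numpy entirely on the executor thread.
  return model(input_data).realize().numpy()

async def infer(model, input_data: np.ndarray) -> np.ndarray:
  loop = asyncio.get_running_loop()
  # Only the finished numpy array is handed back to the event loop.
  return await loop.run_in_executor(_executor, _run_model, model, input_data)

The real sample, encode, decode, and infer_tensor methods above follow the same shape, just with tinygrad's Tensor and the engine's own self.executor.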
7 changes: 3 additions & 4 deletions setup.py
@@ -8,8 +8,8 @@
   "aiohttp==3.10.11",
   "aiohttp_cors==0.7.0",
   "aiofiles==24.1.0",
-  "grpcio==1.64.1",
-  "grpcio-tools==1.64.1",
+  "grpcio==1.68.0",
+  "grpcio-tools==1.68.0",
   "Jinja2==3.1.4",
   "netifaces==0.11.0",
   "numpy==2.0.0",
@@ -21,10 +21,9 @@
   "pydantic==2.9.2",
   "requests==2.32.3",
   "rich==13.7.1",
-  "safetensors==0.4.3",
   "tenacity==9.0.0",
   "tqdm==4.66.4",
-  "transformers==4.43.3",
+  "transformers==4.46.3",
   "uuid==1.30",
   "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad",
 ]
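
Note on the setup.py changes: straight version bumps (grpcio and grpcio-tools to 1.68.0, transformers to 4.46.3) plus dropping the explicit safetensors pin, in line with the commit message about supporting Python >= 3.9. A small, hypothetical check script (not part of the PR) for confirming an environment picked up the new pins:

# Hypothetical helper, not part of the PR: report installed versions of the bumped pins.
from importlib.metadata import PackageNotFoundError, version

for dist, pinned in [("grpcio", "1.68.0"), ("grpcio-tools", "1.68.0"), ("transformers", "4.46.3")]:
  try:
    installed = version(dist)
  except PackageNotFoundError:
    installed = "not installed"
  status = "OK" if installed == pinned else "MISMATCH"
  print(f"{dist}: installed={installed}, pinned={pinned} [{status}]")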
