
Merge pull request #5 from exo-explore/main
merge with fork main
risingsunomi authored Aug 24, 2024
2 parents c365749 + f46d077 commit 0b8221f
Showing 28 changed files with 1,636 additions and 161 deletions.
38 changes: 19 additions & 19 deletions .gitignore
@@ -14,24 +14,24 @@ __pycache__/
 *.so

 # Distribution / packaging
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-share/python-wheels/
-*.egg-info/
-.installed.cfg
-*.egg
-MANIFEST
+/.Python
+/build/
+/develop-eggs/
+/dist/
+/downloads/
+/eggs/
+/.eggs/
+/lib/
+/lib64/
+/parts/
+/sdist/
+/var/
+/wheels/
+/share/python-wheels/
+/*.egg-info/
+/.installed.cfg
+/*.egg
+/MANIFEST

 # PyInstaller
 # Usually these files are written by a python script from a template
@@ -169,4 +169,4 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/

-**/*.xcodeproj/*
+**/*.xcodeproj/*
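
Note on the change above: a leading slash anchors a gitignore pattern to the repository root, so the new patterns ignore only top-level build artifacts instead of matching at any depth. A minimal illustration (paths are hypothetical):

# unanchored: matches at any depth
lib/        # ignores ./lib/ but also ./exo/lib/

# anchored to the repo root (as in this diff)
/lib/       # ignores only ./lib/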
10 changes: 10 additions & 0 deletions exo/download/hf/hf_helpers.py
@@ -17,6 +17,16 @@

 T = TypeVar("T")

+async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
+  refs_dir = get_repo_root(repo_id)/"refs"
+  refs_file = refs_dir/revision
+  if await aios.path.exists(refs_file):
+    async with aiofiles.open(refs_file, 'r') as f:
+      commit_hash = (await f.read()).strip()
+    snapshot_dir = get_repo_root(repo_id)/"snapshots"/commit_hash
+    return snapshot_dir
+  return None
+

 def filter_repo_objects(
   items: Iterable[T],
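For context, a minimal sketch of how the new helper could be called; the model id is illustrative, not taken from this diff:

# Hypothetical usage of get_local_snapshot_dir (model id is an example)
import asyncio
from exo.download.hf.hf_helpers import get_local_snapshot_dir

async def main():
  snapshot_dir = await get_local_snapshot_dir("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
  if snapshot_dir is not None:
    print(f"local snapshot available at {snapshot_dir}")
  else:
    print("no local snapshot; files must be downloaded first")

asyncio.run(main())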
6 changes: 4 additions & 2 deletions exo/download/hf/hf_shard_download.py
@@ -24,8 +24,10 @@ async def ensure_shard(self, shard: Shard) -> Path:
     repo_root = get_repo_root(shard.model_id)
     snapshots_dir = repo_root/"snapshots"
     if snapshots_dir.exists():
-      most_recent_dir = max(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime)
-      return most_recent_dir
+      visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')]
+      if visible_dirs:
+        most_recent_dir = max(visible_dirs, key=lambda x: x.stat().st_mtime)
+        return most_recent_dir

     # If a download on this shard is already in progress, keep that one
     for active_shard in self.active_downloads:
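The new guard skips dot-prefixed entries (e.g. .DS_Store) so they can no longer be picked as the "most recent snapshot", and falls through to the download path when no visible snapshot exists. A standalone sketch of the selection logic, assuming plain pathlib semantics:

# Standalone sketch of the snapshot-selection fix (function name is hypothetical)
from pathlib import Path
from typing import Optional

def pick_latest_visible_snapshot(snapshots_dir: Path) -> Optional[Path]:
  # Hidden entries such as .DS_Store are not real snapshots; skip them
  visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')]
  if not visible_dirs:
    return None  # caller falls back to downloading the shard
  # The most recently modified snapshot wins
  return max(visible_dirs, key=lambda d: d.stat().st_mtime)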
3 changes: 2 additions & 1 deletion exo/inference/tinygrad/inference.py
@@ -3,6 +3,7 @@
 import os
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from exo.inference.shard import Shard
+from exo.inference.tokenizers import resolve_tokenizer
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict
 from tinygrad import Tensor, dtypes, nn, Context
 from transformers import AutoTokenizer
@@ -90,5 +91,5 @@ async def ensure_shard(self, shard: Shard):

     model_path = await self.shard_downloader.ensure_shard(shard)
     self.model = build_transformer(model_path, shard, model_size="8B" if "8b" in shard.model_id.lower() else "70B")
-    self.tokenizer = AutoTokenizer.from_pretrained(str((model_path if model_path.is_dir() else model_path.parent)))
+    self.tokenizer = await resolve_tokenizer(str((model_path if model_path.is_dir() else model_path.parent)))
     self.shard = shard
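
This swaps the direct AutoTokenizer.from_pretrained call for the shared async resolve_tokenizer helper introduced in exo/inference/tokenizers.py below, so the tinygrad engine gets the same local-snapshot lookup and processor/tokenizer fallback behavior as the rest of the codebase. (The now-unused AutoTokenizer import above is left in place by this commit.)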
29 changes: 21 additions & 8 deletions exo/inference/tokenizers.py
@@ -1,12 +1,25 @@
 import traceback
+from aiofiles import os as aios
 from transformers import AutoTokenizer, AutoProcessor
+from exo.download.hf.hf_helpers import get_local_snapshot_dir
 from exo.helpers import DEBUG


 async def resolve_tokenizer(model_id: str):
+  local_path = await get_local_snapshot_dir(model_id)
+  if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}")
+  try:
+    if await aios.path.exists(local_path):
+      if DEBUG >= 2: print(f"Resolving tokenizer for {model_id=} from {local_path=}")
+      return await _resolve_tokenizer(local_path)
+  except:
+    if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {model_id=} normally...")
+    if DEBUG >= 5: traceback.print_exc()
+  return await _resolve_tokenizer(model_id)
+
+async def _resolve_tokenizer(model_id_or_local_path: str):
   try:
-    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id}")
-    processor = AutoProcessor.from_pretrained(model_id, use_fast=True if "Mistral-Large" in model_id else False)
+    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
+    processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in model_id_or_local_path else False)
     if not hasattr(processor, 'eos_token_id'):
       processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
     if not hasattr(processor, 'encode'):
@@ -15,14 +28,14 @@ async def resolve_tokenizer(model_id: str):
     processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
     return processor
   except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load processor for {model_id}. Error: {e}")
+    if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}")
     if DEBUG >= 4: print(traceback.format_exc())

   try:
-    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id}")
-    return AutoTokenizer.from_pretrained(model_id)
+    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}")
+    return AutoTokenizer.from_pretrained(model_id_or_local_path)
   except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
+    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
     if DEBUG >= 4: print(traceback.format_exc())

-  raise ValueError(f"[TODO] Unsupported model: {model_id}")
+  raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}")
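
Taken together, the new resolution order is: local HF snapshot (when the refs file exists), then AutoProcessor, then AutoTokenizer, then the ValueError above. A hedged usage sketch; the model id is an example, not from this diff:

# Illustrative call into the new resolver (model id is hypothetical)
import asyncio
from exo.inference.tokenizers import resolve_tokenizer

async def main():
  # Tries the local snapshot dir first, then AutoProcessor, then AutoTokenizer;
  # raises ValueError if every strategy fails.
  tokenizer = await resolve_tokenizer("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
  tokens = tokenizer.encode("hello exo")
  print(tokens, tokenizer.decode(tokens))

asyncio.run(main())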