matching main #2

Merged 3 commits on Aug 11, 2024

Changes from all commits
129 changes: 90 additions & 39 deletions exo/download/hf/hf_helpers.py
@@ -12,6 +12,8 @@
 from exo.helpers import DEBUG
 from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
 from exo.inference.shard import Shard
+import aiofiles
+from aiofiles import os as aios
 
 T = TypeVar("T")
 def filter_repo_objects(
@@ -56,16 +58,17 @@ def get_hf_home() -> Path:
   """Get the Hugging Face home directory."""
   return Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface"))
 
-def get_hf_token():
+async def get_hf_token():
   """Retrieve the Hugging Face token from the user's HF_HOME directory."""
   token_path = get_hf_home() / "token"
-  if token_path.exists():
-    return token_path.read_text().strip()
+  if await aios.path.exists(token_path):
+    async with aiofiles.open(token_path, 'r') as f:
+      return (await f.read()).strip()
   return None
 
-def get_auth_headers():
+async def get_auth_headers():
   """Get authentication headers if a token is available."""
-  token = get_hf_token()
+  token = await get_hf_token()
   if token:
     return {"Authorization": f"Bearer {token}"}
   return {}
@@ -79,7 +82,7 @@ async def fetch_file_list(session, repo_id, revision, path=""):
   api_url = f"https://huggingface.co/api/models/{repo_id}/tree/{revision}"
   url = f"{api_url}/{path}" if path else api_url
 
-  headers = get_auth_headers()
+  headers = await get_auth_headers()
   async with session.get(url, headers=headers) as response:
     if response.status == 200:
       data = await response.json()
@@ -106,12 +109,12 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
   url = urljoin(base_url, file_path)
   local_path = os.path.join(save_directory, file_path)
 
-  os.makedirs(os.path.dirname(local_path), exist_ok=True)
+  await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
 
   # Check if file already exists and get its size
-  local_file_size = os.path.getsize(local_path) if os.path.exists(local_path) else 0
+  local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
 
-  headers = get_auth_headers()
+  headers = await get_auth_headers()
   if use_range_request:
     headers["Range"] = f"bytes={local_file_size}-"
@@ -162,9 +165,9 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
 
   DOWNLOAD_CHUNK_SIZE = 32768
   start_time = datetime.now()
-  with open(local_path, mode) as f:
+  async with aiofiles.open(local_path, mode) as f:
     async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
-      f.write(chunk)
+      await f.write(chunk)
       downloaded_size += len(chunk)
       downloaded_this_session += len(chunk)
       if progress_callback and total_size:
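
Note on the hunk above: download_file now streams response chunks to disk through aiofiles, so neither the HTTP read nor the file write blocks the event loop. A minimal, self-contained sketch of that pattern (the URL, destination path, and stream_to_disk name are illustrative, not part of this diff):

import asyncio
import aiohttp
import aiofiles

async def stream_to_disk(url: str, local_path: str, chunk_size: int = 32768) -> None:
  # Stream the body in fixed-size chunks instead of buffering it in memory.
  async with aiohttp.ClientSession() as session:
    async with session.get(url) as response:
      async with aiofiles.open(local_path, 'wb') as f:
        async for chunk in response.content.iter_chunked(chunk_size):
          await f.write(chunk)  # aiofiles runs the write in a worker thread

asyncio.run(stream_to_disk("https://example.com/file.bin", "/tmp/file.bin"))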
@@ -177,41 +180,82 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
         await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
   if DEBUG >= 2: print(f"Downloaded: {file_path}")
 
-async def download_repo_files(repo_id: str, revision: str = "main", progress_callback: Optional[RepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None) -> Path:
+async def download_repo_files(repo_id: str, revision: str = "main", progress_callback: Optional[RepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, max_parallel_downloads: int = 4) -> Path:
   repo_root = get_repo_root(repo_id)
   refs_dir = repo_root / "refs"
   snapshots_dir = repo_root / "snapshots"
+  cachedreqs_dir = repo_root / "cachedreqs"
 
   # Ensure directories exist
-  refs_dir.mkdir(parents=True, exist_ok=True)
-  snapshots_dir.mkdir(parents=True, exist_ok=True)
+  await aios.makedirs(refs_dir, exist_ok=True)
+  await aios.makedirs(snapshots_dir, exist_ok=True)
+  await aios.makedirs(cachedreqs_dir, exist_ok=True)
+
+  # Check if we have a cached commit hash
+  refs_file = refs_dir / revision
+  if await aios.path.exists(refs_file):
+    async with aiofiles.open(refs_file, 'r') as f:
+      commit_hash = (await f.read()).strip()
+    if DEBUG >= 2: print(f"Commit hash is already hashed at {refs_file}: {commit_hash}")
+  else:
+    async with aiohttp.ClientSession() as session:
+      # Fetch the commit hash for the given revision
+      api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
+      headers = await get_auth_headers()
+      async with session.get(api_url, headers=headers) as response:
+        if response.status != 200:
+          raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
+        revision_info = await response.json()
+        commit_hash = revision_info['sha']
+
+      # Cache the commit hash
+      async with aiofiles.open(refs_file, 'w') as f:
+        await f.write(commit_hash)
+
+  # Set up the snapshot directory
+  snapshot_dir = snapshots_dir / commit_hash
+  await aios.makedirs(snapshot_dir, exist_ok=True)
+
+  # Set up the cached file list directory
+  cached_file_list_dir = cachedreqs_dir / commit_hash
+  await aios.makedirs(cached_file_list_dir, exist_ok=True)
+  cached_file_list_path = cached_file_list_dir / "fetch_file_list.json"
 
   async with aiohttp.ClientSession() as session:
-    # Fetch the commit hash for the given revision
-    api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
-    headers = get_auth_headers()
-    async with session.get(api_url, headers=headers) as response:
-      if response.status != 200:
-        raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
-      revision_info = await response.json()
-      commit_hash = revision_info['sha']
-
-    # Write the commit hash to the refs file
-    refs_file = refs_dir / revision
-    refs_file.write_text(commit_hash)
-
-    # Set up the snapshot directory
-    snapshot_dir = snapshots_dir / commit_hash
-    snapshot_dir.mkdir(exist_ok=True)
-
-    file_list = await fetch_file_list(session, repo_id, revision)
+    # Check if we have a cached file list
+    if await aios.path.exists(cached_file_list_path):
+      async with aiofiles.open(cached_file_list_path, 'r') as f:
+        file_list = json.loads(await f.read())
+      if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
+    else:
+      file_list = await fetch_file_list(session, repo_id, revision)
+      # Cache the file list
+      async with aiofiles.open(cached_file_list_path, 'w') as f:
+        await f.write(json.dumps(file_list))
+      if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
 
     filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
     total_files = len(filtered_file_list)
     total_bytes = sum(file["size"] for file in filtered_file_list)
     file_progress: Dict[str, RepoFileProgressEvent] = {file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
     start_time = datetime.now()
 
     async def download_with_progress(file_info, progress_state):
+      local_path = snapshot_dir / file_info["path"]
+      if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
+        if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
+        progress_state['completed_files'] += 1
+        progress_state['downloaded_bytes'] += file_info["size"]
+        file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
+        if progress_callback:
+          elapsed_time = (datetime.now() - start_time).total_seconds()
+          overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+          overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+          status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
+          await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
+        return
+
       async def file_progress_callback(event: RepoFileProgressEvent):
         progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
         progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
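
Note on the hunk above: the commit hash and the file list both follow the same read-through cache shape, i.e. return the on-disk copy when present, otherwise fetch and persist it. A condensed sketch of that shape (cached_json and the fetch parameter are our names, not this PR's):

import json
import aiofiles
from aiofiles import os as aios

async def cached_json(path, fetch):
  # Cache hit: skip the network entirely.
  if await aios.path.exists(path):
    async with aiofiles.open(path, 'r') as f:
      return json.loads(await f.read())
  # Cache miss: fetch, persist for the next run, then return.
  data = await fetch()
  async with aiofiles.open(path, 'w') as f:
    await f.write(json.dumps(data))
  return data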
@@ -236,7 +280,12 @@ async def file_progress_callback(event: RepoFileProgressEvent):
       await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
 
     progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
-    tasks = [download_with_progress(file_info, progress_state) for file_info in filtered_file_list]
+
+    semaphore = asyncio.Semaphore(max_parallel_downloads)
+    async def download_with_semaphore(file_info):
+      async with semaphore:
+        await download_with_progress(file_info, progress_state)
+    tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
     await asyncio.gather(*tasks)
 
     return snapshot_dir
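
Note on the hunk above: this is the standard bounded-fan-out pattern. A task is still created per file, but each one waits on a shared asyncio.Semaphore, so at most max_parallel_downloads download bodies run concurrently. A runnable sketch with a stand-in download coroutine (download_one and the file names are placeholders):

import asyncio

async def download_one(path: str) -> None:
  await asyncio.sleep(0.1)  # stand-in for the real per-file download

async def download_all(paths, max_parallel_downloads: int = 4) -> None:
  semaphore = asyncio.Semaphore(max_parallel_downloads)

  async def bounded(path):
    async with semaphore:  # at most max_parallel_downloads holders at once
      await download_one(path)

  await asyncio.gather(*(asyncio.create_task(bounded(p)) for p in paths))

asyncio.run(download_all([f"shard_{i}.safetensors" for i in range(10)]))

Tasks are created eagerly and only their bodies are gated, which keeps the progress bookkeeping simple while capping concurrent connections.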
@@ -263,12 +312,14 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[
   # Check if the file exists
   repo_root = get_repo_root(repo_id)
   snapshot_dir = repo_root / "snapshots"
-  index_file = next(snapshot_dir.glob("*/model.safetensors.index.json"), None)
-
-  if index_file and index_file.exists():
-    with open(index_file, 'r') as f:
-      index_data = json.load(f)
-      return index_data.get("weight_map")
+  index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
+
+  if index_file:
+    index_file_path = snapshot_dir / index_file
+    if await aios.path.exists(index_file_path):
+      async with aiofiles.open(index_file_path, 'r') as f:
+        index_data = json.loads(await f.read())
+        return index_data.get("weight_map")
 
   return None

6 changes: 4 additions & 2 deletions exo/download/hf/hf_shard_download.py
@@ -9,8 +9,9 @@
 from exo.helpers import AsyncCallbackSystem, DEBUG
 
 class HFShardDownloader(ShardDownloader):
-  def __init__(self, quick_check: bool = False):
+  def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
     self.quick_check = quick_check
+    self.max_parallel_downloads = max_parallel_downloads
     self.active_downloads: Dict[Shard, asyncio.Task] = {}
     self.completed_downloads: Dict[Shard, Path] = {}
     self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
@@ -69,7 +70,8 @@ async def wrapped_progress_callback(event: RepoProgressEvent):
     return await download_repo_files(
       repo_id=shard.model_id,
       progress_callback=wrapped_progress_callback,
-      allow_patterns=allow_patterns
+      allow_patterns=allow_patterns,
+      max_parallel_downloads=self.max_parallel_downloads
     )
 
   @property
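
For callers, the new knob rides along on the constructor. A hypothetical usage sketch (the value 8 and the surrounding setup are ours; only the import path and constructor signature come from this diff):

from exo.download.hf.hf_shard_download import HFShardDownloader

downloader = HFShardDownloader(quick_check=True, max_parallel_downloads=8)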
5 changes: 3 additions & 2 deletions main.py
@@ -20,7 +20,8 @@
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
-parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for shard download")
+parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
+parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
 parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
 parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
@@ -37,7 +38,7 @@
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 
-shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check)
+shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, max_parallel_downloads=args.max_parallel_downloads)
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
 inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
 print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")