Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Show downloaded models, improve error handling #456

Open
wants to merge 36 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
c7dd312
adding logic to check which models are downloaded
cadenmackenzie Nov 13, 2024
de09e2a
reusing helper function to get cached directory
cadenmackenzie Nov 13, 2024
7d7bdd8
removing uneccesary console logs and fixing order of variables in ind…
cadenmackenzie Nov 13, 2024
fb32a85
removing error separtation so I can put in different PR
cadenmackenzie Nov 13, 2024
59f5b6d
adding back in set error message
cadenmackenzie Nov 13, 2024
25d67f5
cleaning up logging in index.js
cadenmackenzie Nov 13, 2024
95ce665
removing unneccesary css
cadenmackenzie Nov 13, 2024
3eb726c
removing sorting of models by name
cadenmackenzie Nov 13, 2024
cbeb1b3
fix safari issue
dtnewman Nov 14, 2024
372d873
Merge pull request #1 from dtnewman/dn/downloadModelsV2
cadenmackenzie Nov 14, 2024
d9aabd7
working versions
cadenmackenzie Nov 14, 2024
dfcf513
removing is_model_downloaded method and changing how downloaded varia…
cadenmackenzie Nov 14, 2024
972074e
reducing redundent checks
cadenmackenzie Nov 14, 2024
dd38924
removing checking of percentage for models that are not found locally
cadenmackenzie Nov 14, 2024
bd2985a
Merge pull request #2 from cadenmackenzie/downloadedModelsV2Revisions
cadenmackenzie Nov 14, 2024
649157d
creating HFShardDownloader with quick_check true so it doesnt start d…
cadenmackenzie Nov 17, 2024
c923ef6
modifying how its being displayed becuase now calculating overall per…
cadenmackenzie Nov 18, 2024
c61f40c
adding helper funciton to check file download. also modifying downloa…
cadenmackenzie Nov 18, 2024
dec79ac
modify get_shard_download_status to use helper function
cadenmackenzie Nov 18, 2024
4c6fda7
modifying helper fucntion checking size to follow redirect for .safet…
cadenmackenzie Nov 18, 2024
3ac8687
adding redirect for all requests
cadenmackenzie Nov 18, 2024
3256051
comment
cadenmackenzie Nov 18, 2024
db610f5
removing traceback
cadenmackenzie Nov 18, 2024
6a7de04
removing path update
cadenmackenzie Nov 18, 2024
fad0591
Merge pull request #4 from cadenmackenzie/hf_helperRefactor
cadenmackenzie Nov 18, 2024
b77362b
moving os import
cadenmackenzie Nov 18, 2024
695ab34
removing import get_hf_home
cadenmackenzie Nov 18, 2024
8135437
fixing formatting
cadenmackenzie Nov 19, 2024
91276cc
fixing formatting
cadenmackenzie Nov 19, 2024
8ee6cc3
yapf formatting
cadenmackenzie Nov 19, 2024
0d50167
yapf in download_file
cadenmackenzie Nov 19, 2024
2cdd55d
Merge branch 'main' into downloadedModelsV2
cadenmackenzie Nov 21, 2024
1ca11ea
defining optional
cadenmackenzie Nov 21, 2024
7a8c722
Merge pull request #5 from cadenmackenzie/main
cadenmackenzie Nov 21, 2024
7e6c69f
remvoing console log
cadenmackenzie Nov 21, 2024
39139c1
fixiing required engines definition
cadenmackenzie Nov 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 49 additions & 8 deletions exo/api/chatgpt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
import asyncio
import json
import os
from pathlib import Path
from transformers import AutoTokenizer
from typing import List, Literal, Union, Dict
Expand All @@ -15,8 +16,9 @@
from exo.helpers import PrefixDict, shutdown
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration import Node
from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
from exo.models import build_base_shard, model_cards, get_repo, pretty_name
from typing import Callable, Optional
from exo.download.hf.hf_shard_download import HFShardDownloader

class Message:
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
Expand Down Expand Up @@ -213,13 +215,52 @@ async def handle_healthcheck(self, request):
return web.json_response({"status": "ok"})

async def handle_model_support(self, request):
return web.json_response({
"model pool": {
model_name: pretty_name.get(model_name, model_name)
for model_name in get_supported_models(self.node.topology_inference_engines_pool)
}
})

try:
model_pool = {}

for model_name, pretty in pretty_name.items():
if model_name in model_cards:
model_info = model_cards[model_name]

# Get required engines from the node's topology directly
required_engines = list(dict.fromkeys(
[engine_name for engine_list in self.node.topology_inference_engines_pool
for engine_name in engine_list
if engine_name is not None] +
[self.inference_engine_classname]
))
# Check if model supports required engines
if all(map(lambda engine: engine in model_info["repo"], required_engines)):
shard = build_base_shard(model_name, self.inference_engine_classname)
if shard:
# Use HFShardDownloader to check status without initiating download
downloader = HFShardDownloader(quick_check=True) # quick_check=True prevents downloads
downloader.current_shard = shard
downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
status = await downloader.get_shard_download_status()
if DEBUG >= 2:
print(f"Download status for {model_name}: {status}")

# Get overall percentage from status
download_percentage = status.get("overall") if status else None
if DEBUG >= 2 and download_percentage is not None:
print(f"Overall download percentage for {model_name}: {download_percentage}")

model_pool[model_name] = {
"name": pretty,
"downloaded": download_percentage == 100 if download_percentage is not None else False,
"download_percentage": download_percentage
}

return web.json_response({"model pool": model_pool})
except Exception as e:
print(f"Error in handle_model_support: {str(e)}")
traceback.print_exc()
return web.json_response(
{"detail": f"Server error: {str(e)}"},
status=500
)

async def handle_get_models(self, request):
return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])

Expand Down
64 changes: 62 additions & 2 deletions exo/download/hf/hf_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,18 @@ async def download_file(
downloaded_size = local_file_size
downloaded_this_session = 0
mode = 'ab' if use_range_request else 'wb'
if downloaded_size == total_size:
percentage = await get_file_download_percentage(
session,
repo_id,
revision,
file_path,
Path(save_directory)
)

if percentage == 100:
if DEBUG >= 2: print(f"File already downloaded: {file_path}")
if progress_callback:
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, total_size, 0, total_size, 0, timedelta(0), "complete"))
return

if response.status == 200:
Expand Down Expand Up @@ -429,6 +437,57 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
return list(default_patterns | shard_specific_patterns)

async def get_file_download_percentage(
session: aiohttp.ClientSession,
repo_id: str,
revision: str,
file_path: str,
snapshot_dir: Path,
) -> float:
"""
Calculate the download percentage for a file by comparing local and remote sizes.
"""
try:
local_path = snapshot_dir / file_path
if not await aios.path.exists(local_path):
return 0

# Get local file size first
local_size = await aios.path.getsize(local_path)
if local_size == 0:
return 0

# Check remote size
base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
url = urljoin(base_url, file_path)
headers = await get_auth_headers()

# Use HEAD request with redirect following for all files
async with session.head(url, headers=headers, allow_redirects=True) as response:
if response.status != 200:
if DEBUG >= 2:
print(f"Failed to get remote file info for {file_path}: {response.status}")
return 0

remote_size = int(response.headers.get('Content-Length', 0))

if remote_size == 0:
if DEBUG >= 2:
print(f"Remote size is 0 for {file_path}")
return 0

# Only return 100% if sizes match exactly
if local_size == remote_size:
return 100.0

# Calculate percentage based on sizes
return (local_size / remote_size) * 100 if remote_size > 0 else 0

except Exception as e:
if DEBUG >= 2:
print(f"Error checking file download status for {file_path}: {e}")
return 0

async def has_hf_home_read_access() -> bool:
hf_home = get_hf_home()
try: return await aios.access(hf_home, os.R_OK)
Expand All @@ -438,3 +497,4 @@ async def has_hf_home_write_access() -> bool:
hf_home = get_hf_home()
try: return await aios.access(hf_home, os.W_OK)
except OSError: return False

86 changes: 84 additions & 2 deletions exo/download/hf/hf_shard_download.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import asyncio
import traceback
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Optional
from exo.inference.shard import Shard
from exo.download.shard_download import ShardDownloader
from exo.download.download_progress import RepoProgressEvent
from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
from exo.download.hf.hf_helpers import (
download_repo_files, RepoProgressEvent, get_weight_map,
get_allow_patterns, get_repo_root, fetch_file_list,
get_local_snapshot_dir, get_file_download_percentage,
filter_repo_objects
)
from exo.helpers import AsyncCallbackSystem, DEBUG
from exo.models import model_cards, get_repo
import aiohttp
from aiofiles import os as aios


class HFShardDownloader(ShardDownloader):
Expand All @@ -17,8 +24,13 @@ def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
self.active_downloads: Dict[Shard, asyncio.Task] = {}
self.completed_downloads: Dict[Shard, Path] = {}
self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
self.current_shard: Optional[Shard] = None
self.current_repo_id: Optional[str] = None
self.revision: str = "main"

async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
self.current_shard = shard
self.current_repo_id = get_repo(shard.model_id, inference_engine_name)
repo_name = get_repo(shard.model_id, inference_engine_name)
if shard in self.completed_downloads:
return self.completed_downloads[shard]
Expand Down Expand Up @@ -77,3 +89,73 @@ async def wrapped_progress_callback(event: RepoProgressEvent):
@property
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
return self._on_progress

async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally there seems to be a lot of duplication between this and other functions. I think a refactor should be done here.

if not self.current_shard or not self.current_repo_id:
if DEBUG >= 2:
print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}")
return None

try:
# If no snapshot directory exists, return None - no need to check remote files
snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision)
if not snapshot_dir:
if DEBUG >= 2:
print(f"No snapshot directory found for {self.current_repo_id}")
return None

# Get the weight map to know what files we need
weight_map = await get_weight_map(self.current_repo_id, self.revision)
if not weight_map:
if DEBUG >= 2:
print(f"No weight map found for {self.current_repo_id}")
return None

# Get all files needed for this shard
patterns = get_allow_patterns(weight_map, self.current_shard)

# Check download status for all relevant files
status = {}
total_bytes = 0
downloaded_bytes = 0

async with aiohttp.ClientSession() as session:
file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
relevant_files = list(
filter_repo_objects(
file_list, allow_patterns=patterns, key=lambda x: x["path"]))

for file in relevant_files:
file_size = file["size"]
total_bytes += file_size

percentage = await get_file_download_percentage(
session,
self.current_repo_id,
self.revision,
file["path"],
snapshot_dir,
)
status[file["path"]] = percentage
downloaded_bytes += (file_size * (percentage / 100))

# Add overall progress weighted by file size
if total_bytes > 0:
status["overall"] = (downloaded_bytes / total_bytes) * 100
else:
status["overall"] = 0

if DEBUG >= 2:
print(f"Download calculation for {self.current_repo_id}:")
print(f"Total bytes: {total_bytes}")
print(f"Downloaded bytes: {downloaded_bytes}")
for file in relevant_files:
print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")

return status

except Exception as e:
if DEBUG >= 2:
print(f"Error getting shard download status: {e}")
traceback.print_exc()
return None
12 changes: 11 additions & 1 deletion exo/download/shard_download.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Optional, Tuple
from typing import Optional, Tuple, Dict
from pathlib import Path
from exo.inference.shard import Shard
from exo.download.download_progress import RepoProgressEvent
Expand All @@ -26,6 +26,16 @@ async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
pass

@abstractmethod
async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
"""Get the download status of shards.

Returns:
Optional[Dict[str, float]]: A dictionary mapping shard IDs to their download percentage (0-100),
or None if status cannot be determined
"""
pass


class NoopShardDownloader(ShardDownloader):
async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
Expand Down
10 changes: 5 additions & 5 deletions exo/tinychat/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@
<body>
<main x-data="state" x-init="console.log(endpoint)">
<!-- Error Toast -->
<div x-show="errorMessage" x-transition.opacity class="toast">
<div x-show="errorMessage !== null" x-transition.opacity class="toast">
<div class="toast-header">
<span class="toast-error-message" x-text="errorMessage.basic"></span>
<span class="toast-error-message" x-text="errorMessage?.basic || ''"></span>
<div class="toast-header-buttons">
<button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }"
class="toast-expand-button"
x-show="errorMessage.stack">
x-show="errorMessage?.stack">
<span x-text="errorExpanded ? 'Hide Details' : 'Show Details'"></span>
</button>
<button @click="errorMessage = null; errorExpanded = false;" class="toast-close-button">
Expand All @@ -41,11 +41,11 @@
</div>
</div>
<div class="toast-content" x-show="errorExpanded" x-transition>
<span x-text="errorMessage.stack"></span>
<span x-text="errorMessage?.stack || ''"></span>
</div>
</div>
<div class="model-selector">
<select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel" x-init="await populateSelector()" class='model-select'>
<select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel" class='model-select'>
</select>
</div>
<div @popstate.window="
Expand Down
Loading