Skip to content

Commit fad0591

Browse files
Merge pull request #4 from cadenmackenzie/hf_helperRefactor
Hf helper refactor
2 parents 649157d + 6a7de04 commit fad0591

File tree

3 files changed

+103
-35
lines changed

3 files changed

+103
-35
lines changed

exo/api/chatgpt_api.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -231,14 +231,10 @@ async def handle_model_support(self, request):
231231
if DEBUG >= 2:
232232
print(f"Download status for {model_name}: {status}")
233233

234-
# Calculate overall percentage if we have status
235-
download_percentage = None
236-
if status:
237-
percentages = list(status.values())
238-
if percentages:
239-
download_percentage = sum(percentages) / len(percentages)
240-
if DEBUG >= 2:
241-
print(f"Calculated download percentage for {model_name}: {download_percentage}")
234+
# Get overall percentage from status
235+
download_percentage = status.get("overall") if status else None
236+
if DEBUG >= 2 and download_percentage is not None:
237+
print(f"Overall download percentage for {model_name}: {download_percentage}")
242238

243239
model_pool[model_name] = {
244240
"name": pretty,

exo/download/hf/hf_helpers.py

+62-5
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,19 @@ async def download_file(
147147
downloaded_size = local_file_size
148148
downloaded_this_session = 0
149149
mode = 'ab' if use_range_request else 'wb'
150-
if downloaded_size == total_size:
151-
if DEBUG >= 2: print(f"File already downloaded: {file_path}")
152-
if progress_callback:
153-
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
154-
return
150+
percentage = await get_file_download_percentage(
151+
session,
152+
repo_id,
153+
revision,
154+
file_path,
155+
Path(save_directory)
156+
)
157+
158+
if percentage == 100:
159+
if DEBUG >= 2: print(f"File already downloaded: {file_path}")
160+
if progress_callback:
161+
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, total_size, 0, total_size, 0, timedelta(0), "complete"))
162+
return
155163

156164
if response.status == 200:
157165
# File doesn't support range requests or we're not using them, start from beginning
@@ -412,3 +420,52 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
412420
shard_specific_patterns = set("*.safetensors")
413421
if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
414422
return list(default_patterns | shard_specific_patterns)
423+
424+
425+
async def get_file_download_percentage(
  session: aiohttp.ClientSession,
  repo_id: str,
  revision: str,
  file_path: str,
  snapshot_dir: Path
) -> float:
  """
  Return how much of a repo file is present locally, as a percentage of its remote size.

  The size of ``snapshot_dir / file_path`` on disk is compared with the
  ``Content-Length`` reported by a HEAD request (following redirects) to the
  file's ``resolve`` URL on the configured endpoint.

  Args:
    session: aiohttp client session used to issue the HEAD request.
    repo_id: Repository identifier used to build the resolve URL.
    revision: Revision used to build the resolve URL.
    file_path: Path of the file relative to the repository root / snapshot dir.
    snapshot_dir: Local directory that holds the downloaded snapshot.

  Returns:
    ``0.0`` when the local file is missing or empty, when the HEAD request
    does not return status 200, when the remote size is missing or not
    positive, or when any exception is raised while checking.
    ``100.0`` only when the local and remote sizes match exactly.
    Otherwise ``local_size / remote_size * 100``. The value is not clamped,
    so it exceeds 100 if the local file is larger than the remote one.
  """
  try:
    local_path = snapshot_dir / file_path
    if not await aios.path.exists(local_path):
      return 0.0

    # Check the local size first: an empty file needs no network round-trip.
    local_size = await aios.path.getsize(local_path)
    if local_size == 0:
      return 0.0

    # Build the remote URL for this file.
    base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
    url = urljoin(base_url, file_path)
    headers = await get_auth_headers()

    # Use a HEAD request with redirect following for all files; only the
    # headers are needed, so the response is released as soon as they are read.
    async with session.head(url, headers=headers, allow_redirects=True) as response:
      if response.status != 200:
        if DEBUG >= 2: print(f"Failed to get remote file info for {file_path}: {response.status}")
        return 0.0
      remote_size = int(response.headers.get('Content-Length', 0))

    # A missing header defaults to 0; a non-positive size cannot be used as a divisor.
    if remote_size <= 0:
      if DEBUG >= 2: print(f"Remote size is {remote_size} for {file_path}")
      return 0.0

    # Only report 100% if the sizes match exactly.
    if local_size == remote_size:
      return 100.0

    return (local_size / remote_size) * 100

  except Exception as e:
    # Deliberately best-effort: any failure while checking is reported as "nothing downloaded".
    if DEBUG >= 2: print(f"Error checking file download status for {file_path}: {e}")
    return 0.0

exo/download/hf/hf_shard_download.py

+37-22
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
from exo.download.download_progress import RepoProgressEvent
88
from exo.download.hf.hf_helpers import (
99
download_repo_files, RepoProgressEvent, get_weight_map,
10-
get_allow_patterns, get_repo_root, fetch_file_list, get_local_snapshot_dir
10+
get_allow_patterns, get_repo_root, fetch_file_list,
11+
get_local_snapshot_dir, get_file_download_percentage,
12+
filter_repo_objects
1113
)
1214
from exo.helpers import AsyncCallbackSystem, DEBUG
1315
from exo.models import model_cards, get_repo
@@ -94,6 +96,7 @@ async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
9496
return None
9597

9698
try:
99+
# If no snapshot directory exists, return None - no need to check remote files
97100
snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision)
98101
if not snapshot_dir:
99102
if DEBUG >= 2: print(f"No snapshot directory found for {self.current_repo_id}")
@@ -105,32 +108,44 @@ async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
105108
if DEBUG >= 2: print(f"No weight map found for {self.current_repo_id}")
106109
return None
107110

108-
# Get the patterns for this shard
111+
# Get all files needed for this shard
109112
patterns = get_allow_patterns(weight_map, self.current_shard)
110113

111-
# First check which files exist locally
114+
# Check download status for all relevant files
112115
status = {}
113-
local_files = []
114-
local_sizes = {}
116+
total_bytes = 0
117+
downloaded_bytes = 0
115118

116-
for pattern in patterns:
117-
if pattern.endswith('safetensors') or pattern.endswith('mlx'):
118-
file_path = snapshot_dir / pattern
119-
if await aios.path.exists(file_path):
120-
local_size = await aios.path.getsize(file_path)
121-
local_files.append(pattern)
122-
local_sizes[pattern] = local_size
123-
124-
# Only fetch remote info if we found local files
125-
if local_files:
126-
async with aiohttp.ClientSession() as session:
127-
file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
119+
async with aiohttp.ClientSession() as session:
120+
file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
121+
relevant_files = list(filter_repo_objects(file_list, allow_patterns=patterns, key=lambda x: x["path"]))
122+
123+
for file in relevant_files:
124+
file_size = file["size"]
125+
total_bytes += file_size
128126

129-
for pattern in local_files:
130-
for file in file_list:
131-
if file["path"].endswith(pattern):
132-
status[pattern] = (local_sizes[pattern] / file["size"]) * 100
133-
break
127+
percentage = await get_file_download_percentage(
128+
session,
129+
self.current_repo_id,
130+
self.revision,
131+
file["path"],
132+
snapshot_dir
133+
)
134+
status[file["path"]] = percentage
135+
downloaded_bytes += (file_size * (percentage / 100))
136+
137+
# Add overall progress weighted by file size
138+
if total_bytes > 0:
139+
status["overall"] = (downloaded_bytes / total_bytes) * 100
140+
else:
141+
status["overall"] = 0
142+
143+
if DEBUG >= 2:
144+
print(f"Download calculation for {self.current_repo_id}:")
145+
print(f"Total bytes: {total_bytes}")
146+
print(f"Downloaded bytes: {downloaded_bytes}")
147+
for file in relevant_files:
148+
print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
134149

135150
return status
136151

0 commit comments

Comments
 (0)