7
7
from exo .download .download_progress import RepoProgressEvent
8
8
from exo .download .hf .hf_helpers import (
9
9
download_repo_files , RepoProgressEvent , get_weight_map ,
10
- get_allow_patterns , get_repo_root , fetch_file_list , get_local_snapshot_dir
10
+ get_allow_patterns , get_repo_root , fetch_file_list ,
11
+ get_local_snapshot_dir , get_file_download_percentage ,
12
+ filter_repo_objects
11
13
)
12
14
from exo .helpers import AsyncCallbackSystem , DEBUG
13
15
from exo .models import model_cards , get_repo
@@ -94,6 +96,7 @@ async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
94
96
return None
95
97
96
98
try :
99
+ # If no snapshot directory exists, return None - no need to check remote files
97
100
snapshot_dir = await get_local_snapshot_dir (self .current_repo_id , self .revision )
98
101
if not snapshot_dir :
99
102
if DEBUG >= 2 : print (f"No snapshot directory found for { self .current_repo_id } " )
@@ -105,32 +108,44 @@ async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
105
108
if DEBUG >= 2 : print (f"No weight map found for { self .current_repo_id } " )
106
109
return None
107
110
108
- # Get the patterns for this shard
111
+ # Get all files needed for this shard
109
112
patterns = get_allow_patterns (weight_map , self .current_shard )
110
113
111
- # First check which files exist locally
114
+ # Check download status for all relevant files
112
115
status = {}
113
- local_files = []
114
- local_sizes = {}
116
+ total_bytes = 0
117
+ downloaded_bytes = 0
115
118
116
- for pattern in patterns :
117
- if pattern .endswith ('safetensors' ) or pattern .endswith ('mlx' ):
118
- file_path = snapshot_dir / pattern
119
- if await aios .path .exists (file_path ):
120
- local_size = await aios .path .getsize (file_path )
121
- local_files .append (pattern )
122
- local_sizes [pattern ] = local_size
123
-
124
- # Only fetch remote info if we found local files
125
- if local_files :
126
- async with aiohttp .ClientSession () as session :
127
- file_list = await fetch_file_list (session , self .current_repo_id , self .revision )
119
+ async with aiohttp .ClientSession () as session :
120
+ file_list = await fetch_file_list (session , self .current_repo_id , self .revision )
121
+ relevant_files = list (filter_repo_objects (file_list , allow_patterns = patterns , key = lambda x : x ["path" ]))
122
+
123
+ for file in relevant_files :
124
+ file_size = file ["size" ]
125
+ total_bytes += file_size
128
126
129
- for pattern in local_files :
130
- for file in file_list :
131
- if file ["path" ].endswith (pattern ):
132
- status [pattern ] = (local_sizes [pattern ] / file ["size" ]) * 100
133
- break
127
+ percentage = await get_file_download_percentage (
128
+ session ,
129
+ self .current_repo_id ,
130
+ self .revision ,
131
+ file ["path" ],
132
+ snapshot_dir
133
+ )
134
+ status [file ["path" ]] = percentage
135
+ downloaded_bytes += (file_size * (percentage / 100 ))
136
+
137
+ # Add overall progress weighted by file size
138
+ if total_bytes > 0 :
139
+ status ["overall" ] = (downloaded_bytes / total_bytes ) * 100
140
+ else :
141
+ status ["overall" ] = 0
142
+
143
+ if DEBUG >= 2 :
144
+ print (f"Download calculation for { self .current_repo_id } :" )
145
+ print (f"Total bytes: { total_bytes } " )
146
+ print (f"Downloaded bytes: { downloaded_bytes } " )
147
+ for file in relevant_files :
148
+ print (f"File { file ['path' ]} : size={ file ['size' ]} , percentage={ status [file ['path' ]]} " )
134
149
135
150
return status
136
151
0 commit comments