1
1
import asyncio
2
2
import aiohttp
3
+ import json
3
4
import os
4
5
from urllib .parse import urljoin
5
6
from typing import Callable , Optional , Coroutine , Any , Dict , List , Union , Literal
6
7
from datetime import datetime , timedelta
7
8
from fnmatch import fnmatch
8
9
from pathlib import Path
9
10
from typing import Generator , Iterable , TypeVar , TypedDict
10
- from dataclasses import dataclass
11
11
from tenacity import retry , stop_after_attempt , wait_exponential , retry_if_exception_type
12
12
from exo .helpers import DEBUG
13
+ from exo .download .download_progress import RepoProgressEvent , RepoFileProgressEvent , RepoProgressCallback , RepoFileProgressCallback
13
14
14
15
T = TypeVar ("T" )
15
16
def filter_repo_objects (
@@ -21,10 +22,8 @@ def filter_repo_objects(
21
22
) -> Generator [T , None , None ]:
22
23
if isinstance (allow_patterns , str ):
23
24
allow_patterns = [allow_patterns ]
24
-
25
25
if isinstance (ignore_patterns , str ):
26
26
ignore_patterns = [ignore_patterns ]
27
-
28
27
if allow_patterns is not None :
29
28
allow_patterns = [_add_wildcard_to_directories (p ) for p in allow_patterns ]
30
29
if ignore_patterns is not None :
@@ -37,18 +36,14 @@ def _identity(item: T) -> str:
37
36
if isinstance (item , Path ):
38
37
return str (item )
39
38
raise ValueError (f"Please provide `key` argument in `filter_repo_objects`: `{ item } ` is not a string." )
40
-
41
39
key = _identity
42
40
43
41
for item in items :
44
42
path = key (item )
45
-
46
43
if allow_patterns is not None and not any (fnmatch (path , r ) for r in allow_patterns ):
47
44
continue
48
-
49
45
if ignore_patterns is not None and any (fnmatch (path , r ) for r in ignore_patterns ):
50
46
continue
51
-
52
47
yield item
53
48
54
49
def _add_wildcard_to_directories (pattern : str ) -> str :
@@ -99,84 +94,13 @@ async def fetch_file_list(session, repo_id, revision, path=""):
99
94
raise Exception (f"Failed to fetch file list: { response .status } " )
100
95
101
96
102
- @dataclass
103
- class HFRepoFileProgressEvent :
104
- file_path : str
105
- downloaded : int
106
- downloaded_this_session : int
107
- total : int
108
- speed : int
109
- eta : timedelta
110
- status : Literal ["not_started" , "in_progress" , "complete" ]
111
-
112
- def to_dict (self ):
113
- return {
114
- "file_path" : self .file_path ,
115
- "downloaded" : self .downloaded ,
116
- "downloaded_this_session" : self .downloaded_this_session ,
117
- "total" : self .total ,
118
- "speed" : self .speed ,
119
- "eta" : self .eta .total_seconds (),
120
- "status" : self .status
121
- }
122
-
123
- @classmethod
124
- def from_dict (cls , data ):
125
- # Convert eta from seconds back to timedelta
126
- if 'eta' in data :
127
- data ['eta' ] = timedelta (seconds = data ['eta' ])
128
- return cls (** data )
129
-
130
- @dataclass
131
- class HFRepoProgressEvent :
132
- completed_files : int
133
- total_files : int
134
- downloaded_bytes : int
135
- downloaded_bytes_this_session : int
136
- total_bytes : int
137
- overall_speed : int
138
- overall_eta : timedelta
139
- file_progress : Dict [str , HFRepoFileProgressEvent ]
140
- status : Literal ["not_started" , "in_progress" , "complete" ]
141
-
142
- def to_dict (self ):
143
- return {
144
- "completed_files" : self .completed_files ,
145
- "total_files" : self .total_files ,
146
- "downloaded_bytes" : self .downloaded_bytes ,
147
- "downloaded_bytes_this_session" : self .downloaded_bytes_this_session ,
148
- "total_bytes" : self .total_bytes ,
149
- "overall_speed" : self .overall_speed ,
150
- "overall_eta" : self .overall_eta .total_seconds (),
151
- "file_progress" : {k : v .to_dict () for k , v in self .file_progress .items ()},
152
- "status" : self .status
153
- }
154
-
155
- @classmethod
156
- def from_dict (cls , data ):
157
- # Convert overall_eta from seconds back to timedelta
158
- if 'overall_eta' in data :
159
- data ['overall_eta' ] = timedelta (seconds = data ['overall_eta' ])
160
-
161
- # Parse file_progress
162
- if 'file_progress' in data :
163
- data ['file_progress' ] = {
164
- k : HFRepoFileProgressEvent .from_dict (v )
165
- for k , v in data ['file_progress' ].items ()
166
- }
167
-
168
- return cls (** data )
169
-
170
- HFRepoFileProgressCallback = Callable [[HFRepoFileProgressEvent ], Coroutine [Any , Any , None ]]
171
- HFRepoProgressCallback = Callable [[HFRepoProgressEvent ], Coroutine [Any , Any , None ]]
172
-
173
97
@retry (
174
98
stop = stop_after_attempt (5 ),
175
99
wait = wait_exponential (multiplier = 1 , min = 4 , max = 60 ),
176
100
retry = retry_if_exception_type ((aiohttp .ClientError , asyncio .TimeoutError , aiohttp .ClientResponseError )),
177
101
reraise = True
178
102
)
179
- async def download_file (session : aiohttp .ClientSession , repo_id : str , revision : str , file_path : str , save_directory : str , progress_callback : Optional [HFRepoFileProgressCallback ] = None , use_range_request : bool = True ):
103
+ async def download_file (session : aiohttp .ClientSession , repo_id : str , revision : str , file_path : str , save_directory : str , progress_callback : Optional [RepoFileProgressCallback ] = None , use_range_request : bool = True ):
180
104
base_url = f"https://huggingface.co/{ repo_id } /resolve/{ revision } /"
181
105
url = urljoin (base_url , file_path )
182
106
local_path = os .path .join (save_directory , file_path )
@@ -198,7 +122,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
198
122
if downloaded_size == total_size :
199
123
if DEBUG >= 2 : print (f"File already downloaded: { file_path } " )
200
124
if progress_callback :
201
- await progress_callback (HFRepoFileProgressEvent (file_path , downloaded_size , downloaded_this_session , total_size , 0 , timedelta (0 ), "complete" ))
125
+ await progress_callback (RepoFileProgressEvent (file_path , downloaded_size , downloaded_this_session , total_size , 0 , timedelta (0 ), "complete" ))
202
126
return
203
127
204
128
if response .status == 200 :
@@ -221,7 +145,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
221
145
if downloaded_size == total_size :
222
146
if DEBUG >= 2 : print (f"File fully downloaded on first pass: { file_path } " )
223
147
if progress_callback :
224
- await progress_callback (HFRepoFileProgressEvent (file_path , downloaded_size , downloaded_this_session , total_size , 0 , timedelta (0 ), "complete" ))
148
+ await progress_callback (RepoFileProgressEvent (file_path , downloaded_size , downloaded_this_session , total_size , 0 , timedelta (0 ), "complete" ))
225
149
return
226
150
except ValueError :
227
151
if DEBUG >= 1 : print (f"Failed to parse Content-Range header: { content_range } . Starting download from scratch..." )
@@ -232,7 +156,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
232
156
if downloaded_size == total_size :
233
157
print (f"File already downloaded: { file_path } " )
234
158
if progress_callback :
235
- await progress_callback (HFRepoFileProgressEvent (file_path , downloaded_size , downloaded_this_session , total_size , 0 , timedelta (0 ), "complete" ))
159
+ await progress_callback (RepoFileProgressEvent (file_path , downloaded_size , downloaded_this_session , total_size , 0 , timedelta (0 ), "complete" ))
236
160
return
237
161
238
162
DOWNLOAD_CHUNK_SIZE = 32768
@@ -249,10 +173,10 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
249
173
eta = timedelta (seconds = remaining_size / speed ) if speed > 0 else timedelta (0 )
250
174
status = "in_progress" if downloaded_size < total_size else "complete"
251
175
if DEBUG >= 8 : print (f"HF repo file download progress: { file_path = } { elapsed_time = } { speed = } Downloaded={ downloaded_size } /{ total_size } { remaining_size = } { eta = } { status = } " )
252
- await progress_callback (HFRepoFileProgressEvent (file_path , downloaded_size , downloaded_this_session , total_size , speed , eta , status ))
176
+ await progress_callback (RepoFileProgressEvent (file_path , downloaded_size , downloaded_this_session , total_size , speed , eta , status ))
253
177
if DEBUG >= 2 : print (f"Downloaded: { file_path } " )
254
178
255
- async def download_all_files (repo_id : str , revision : str = "main" , progress_callback : Optional [HFRepoProgressCallback ] = None , allow_patterns : Optional [Union [List [str ], str ]] = None , ignore_patterns : Optional [Union [List [str ], str ]] = None ):
179
+ async def download_repo_files (repo_id : str , revision : str = "main" , progress_callback : Optional [RepoProgressCallback ] = None , allow_patterns : Optional [Union [List [str ], str ]] = None , ignore_patterns : Optional [Union [List [str ], str ]] = None ) -> Path :
256
180
repo_root = get_repo_root (repo_id )
257
181
refs_dir = repo_root / "refs"
258
182
snapshots_dir = repo_root / "snapshots"
@@ -283,11 +207,11 @@ async def download_all_files(repo_id: str, revision: str = "main", progress_call
283
207
filtered_file_list = list (filter_repo_objects (file_list , allow_patterns = allow_patterns , ignore_patterns = ignore_patterns , key = lambda x : x ["path" ]))
284
208
total_files = len (filtered_file_list )
285
209
total_bytes = sum (file ["size" ] for file in filtered_file_list )
286
- file_progress : Dict [str , HFRepoFileProgressEvent ] = {file ["path" ]: HFRepoFileProgressEvent (file ["path" ], 0 , 0 , file ["size" ], 0 , timedelta (0 ), "not_started" ) for file in filtered_file_list }
210
+ file_progress : Dict [str , RepoFileProgressEvent ] = {file ["path" ]: RepoFileProgressEvent (file ["path" ], 0 , 0 , file ["size" ], 0 , timedelta (0 ), "not_started" ) for file in filtered_file_list }
287
211
start_time = datetime .now ()
288
212
289
213
async def download_with_progress (file_info , progress_state ):
290
- async def file_progress_callback (event : HFRepoFileProgressEvent ):
214
+ async def file_progress_callback (event : RepoFileProgressEvent ):
291
215
progress_state ['downloaded_bytes' ] += event .downloaded - file_progress [event .file_path ].downloaded
292
216
progress_state ['downloaded_bytes_this_session' ] += event .downloaded_this_session - file_progress [event .file_path ].downloaded_this_session
293
217
file_progress [event .file_path ] = event
@@ -297,21 +221,60 @@ async def file_progress_callback(event: HFRepoFileProgressEvent):
297
221
remaining_bytes = total_bytes - progress_state ['downloaded_bytes' ]
298
222
overall_eta = timedelta (seconds = remaining_bytes / overall_speed ) if overall_speed > 0 else timedelta (seconds = 0 )
299
223
status = "in_progress" if progress_state ['downloaded_bytes' ] < total_bytes else "complete"
300
- await progress_callback (HFRepoProgressEvent (progress_state ['completed_files' ], total_files , progress_state ['downloaded_bytes' ], progress_state ['downloaded_bytes_this_session' ], total_bytes , overall_speed , overall_eta , file_progress , status ))
224
+ await progress_callback (RepoProgressEvent (progress_state ['completed_files' ], total_files , progress_state ['downloaded_bytes' ], progress_state ['downloaded_bytes_this_session' ], total_bytes , overall_speed , overall_eta , file_progress , status ))
301
225
302
226
await download_file (session , repo_id , revision , file_info ["path" ], snapshot_dir , file_progress_callback )
303
227
progress_state ['completed_files' ] += 1
304
- file_progress [file_info ["path" ]] = HFRepoFileProgressEvent (file_info ["path" ], file_info ["size" ], file_progress [file_info ["path" ]].downloaded_this_session , file_info ["size" ], 0 , timedelta (0 ), "complete" )
228
+ file_progress [file_info ["path" ]] = RepoFileProgressEvent (file_info ["path" ], file_info ["size" ], file_progress [file_info ["path" ]].downloaded_this_session , file_info ["size" ], 0 , timedelta (0 ), "complete" )
305
229
if progress_callback :
306
230
elapsed_time = (datetime .now () - start_time ).total_seconds ()
307
231
overall_speed = int (progress_state ['downloaded_bytes_this_session' ] / elapsed_time ) if elapsed_time > 0 else 0
308
232
remaining_bytes = total_bytes - progress_state ['downloaded_bytes' ]
309
233
overall_eta = timedelta (seconds = remaining_bytes / overall_speed ) if overall_speed > 0 else timedelta (seconds = 0 )
310
234
status = "in_progress" if progress_state ['completed_files' ] < total_files else "complete"
311
- await progress_callback (HFRepoProgressEvent (progress_state ['completed_files' ], total_files , progress_state ['downloaded_bytes' ], progress_state ['downloaded_bytes_this_session' ], total_bytes , overall_speed , overall_eta , file_progress , status ))
235
+ await progress_callback (RepoProgressEvent (progress_state ['completed_files' ], total_files , progress_state ['downloaded_bytes' ], progress_state ['downloaded_bytes_this_session' ], total_bytes , overall_speed , overall_eta , file_progress , status ))
312
236
313
237
progress_state = {'completed_files' : 0 , 'downloaded_bytes' : 0 , 'downloaded_bytes_this_session' : 0 }
314
238
tasks = [download_with_progress (file_info , progress_state ) for file_info in filtered_file_list ]
315
239
await asyncio .gather (* tasks )
316
240
317
241
return snapshot_dir
242
+
243
+ async def get_weight_map (repo_id : str , revision : str = "main" ) -> Optional [Dict [str , str ]]:
244
+ """
245
+ Retrieve the weight map from the model.safetensors.index.json file.
246
+
247
+ Args:
248
+ repo_id (str): The Hugging Face repository ID.
249
+ revision (str): The revision of the repository to use.
250
+
251
+ Returns:
252
+ Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
253
+ """
254
+
255
+ # Download the index file
256
+ await download_repo_files (
257
+ repo_id = repo_id ,
258
+ revision = revision ,
259
+ allow_patterns = "model.safetensors.index.json"
260
+ )
261
+
262
+ # Check if the file exists
263
+ repo_root = get_repo_root (repo_id )
264
+ snapshot_dir = repo_root / "snapshots"
265
+ index_file = next (snapshot_dir .glob ("*/model.safetensors.index.json" ), None )
266
+
267
+ if index_file and index_file .exists ():
268
+ with open (index_file , 'r' ) as f :
269
+ index_data = json .load (f )
270
+ return index_data .get ("weight_map" )
271
+
272
+ return None
273
+
274
+ def extract_layer_num (tensor_name : str ) -> Optional [int ]:
275
+ # This is a simple example and might need to be adjusted based on the actual naming convention
276
+ parts = tensor_name .split ('.' )
277
+ for part in parts :
278
+ if part .isdigit ():
279
+ return int (part )
280
+ return None
0 commit comments