Skip to content

Commit 476a714

Browse files
committed
make a separate ShardDownloader abstract class w HFShardDownloader. this opens up plugging in different methods of downloading model shards e.g. #79 / #16
1 parent d22ed12 commit 476a714

16 files changed

+428
-385
lines changed

exo/api/chatgpt_api.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,7 @@ async def resolve_tokenizer(model_id: str):
103103
if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
104104
if DEBUG >= 4: print(traceback.format_exc())
105105

106-
if DEBUG >= 4: print(f"Trying mlx tokenizer for {model_id}")
107-
from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
108-
109-
return load_tokenizer(await get_model_path(model_id))
106+
raise ValueError(f"[TODO] Unsupported model: {model_id}")
110107

111108

112109
def generate_completion(

exo/download/download_progress.py

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from typing import Dict, Callable, Coroutine, Any, Literal
2+
from dataclasses import dataclass
3+
from datetime import timedelta
4+
5+
@dataclass
6+
class RepoFileProgressEvent:
7+
file_path: str
8+
downloaded: int
9+
downloaded_this_session: int
10+
total: int
11+
speed: int
12+
eta: timedelta
13+
status: Literal["not_started", "in_progress", "complete"]
14+
15+
def to_dict(self):
16+
return {
17+
"file_path": self.file_path,
18+
"downloaded": self.downloaded,
19+
"downloaded_this_session": self.downloaded_this_session,
20+
"total": self.total,
21+
"speed": self.speed,
22+
"eta": self.eta.total_seconds(),
23+
"status": self.status
24+
}
25+
26+
@classmethod
27+
def from_dict(cls, data):
28+
# Convert eta from seconds back to timedelta
29+
if 'eta' in data:
30+
data['eta'] = timedelta(seconds=data['eta'])
31+
return cls(**data)
32+
33+
@dataclass
34+
class RepoProgressEvent:
35+
completed_files: int
36+
total_files: int
37+
downloaded_bytes: int
38+
downloaded_bytes_this_session: int
39+
total_bytes: int
40+
overall_speed: int
41+
overall_eta: timedelta
42+
file_progress: Dict[str, RepoFileProgressEvent]
43+
status: Literal["not_started", "in_progress", "complete"]
44+
45+
def to_dict(self):
46+
return {
47+
"completed_files": self.completed_files,
48+
"total_files": self.total_files,
49+
"downloaded_bytes": self.downloaded_bytes,
50+
"downloaded_bytes_this_session": self.downloaded_bytes_this_session,
51+
"total_bytes": self.total_bytes,
52+
"overall_speed": self.overall_speed,
53+
"overall_eta": self.overall_eta.total_seconds(),
54+
"file_progress": {k: v.to_dict() for k, v in self.file_progress.items()},
55+
"status": self.status
56+
}
57+
58+
@classmethod
59+
def from_dict(cls, data):
60+
# Convert overall_eta from seconds back to timedelta
61+
if 'overall_eta' in data:
62+
data['overall_eta'] = timedelta(seconds=data['overall_eta'])
63+
64+
# Parse file_progress
65+
if 'file_progress' in data:
66+
data['file_progress'] = {
67+
k: RepoFileProgressEvent.from_dict(v)
68+
for k, v in data['file_progress'].items()
69+
}
70+
71+
return cls(**data)
72+
73+
RepoFileProgressCallback = Callable[[RepoFileProgressEvent], Coroutine[Any, Any, None]]
74+
RepoProgressCallback = Callable[[RepoProgressEvent], Coroutine[Any, Any, None]]

exo/inference/hf_helpers.py renamed to exo/download/hf/hf_helpers.py

+52-89
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import asyncio
22
import aiohttp
3+
import json
34
import os
45
from urllib.parse import urljoin
56
from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
67
from datetime import datetime, timedelta
78
from fnmatch import fnmatch
89
from pathlib import Path
910
from typing import Generator, Iterable, TypeVar, TypedDict
10-
from dataclasses import dataclass
1111
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
1212
from exo.helpers import DEBUG
13+
from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
1314

1415
T = TypeVar("T")
1516
def filter_repo_objects(
@@ -21,10 +22,8 @@ def filter_repo_objects(
2122
) -> Generator[T, None, None]:
2223
if isinstance(allow_patterns, str):
2324
allow_patterns = [allow_patterns]
24-
2525
if isinstance(ignore_patterns, str):
2626
ignore_patterns = [ignore_patterns]
27-
2827
if allow_patterns is not None:
2928
allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
3029
if ignore_patterns is not None:
@@ -37,18 +36,14 @@ def _identity(item: T) -> str:
3736
if isinstance(item, Path):
3837
return str(item)
3938
raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
40-
4139
key = _identity
4240

4341
for item in items:
4442
path = key(item)
45-
4643
if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
4744
continue
48-
4945
if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
5046
continue
51-
5247
yield item
5348

5449
def _add_wildcard_to_directories(pattern: str) -> str:
@@ -99,84 +94,13 @@ async def fetch_file_list(session, repo_id, revision, path=""):
9994
raise Exception(f"Failed to fetch file list: {response.status}")
10095

10196

102-
@dataclass
103-
class HFRepoFileProgressEvent:
104-
file_path: str
105-
downloaded: int
106-
downloaded_this_session: int
107-
total: int
108-
speed: int
109-
eta: timedelta
110-
status: Literal["not_started", "in_progress", "complete"]
111-
112-
def to_dict(self):
113-
return {
114-
"file_path": self.file_path,
115-
"downloaded": self.downloaded,
116-
"downloaded_this_session": self.downloaded_this_session,
117-
"total": self.total,
118-
"speed": self.speed,
119-
"eta": self.eta.total_seconds(),
120-
"status": self.status
121-
}
122-
123-
@classmethod
124-
def from_dict(cls, data):
125-
# Convert eta from seconds back to timedelta
126-
if 'eta' in data:
127-
data['eta'] = timedelta(seconds=data['eta'])
128-
return cls(**data)
129-
130-
@dataclass
131-
class HFRepoProgressEvent:
132-
completed_files: int
133-
total_files: int
134-
downloaded_bytes: int
135-
downloaded_bytes_this_session: int
136-
total_bytes: int
137-
overall_speed: int
138-
overall_eta: timedelta
139-
file_progress: Dict[str, HFRepoFileProgressEvent]
140-
status: Literal["not_started", "in_progress", "complete"]
141-
142-
def to_dict(self):
143-
return {
144-
"completed_files": self.completed_files,
145-
"total_files": self.total_files,
146-
"downloaded_bytes": self.downloaded_bytes,
147-
"downloaded_bytes_this_session": self.downloaded_bytes_this_session,
148-
"total_bytes": self.total_bytes,
149-
"overall_speed": self.overall_speed,
150-
"overall_eta": self.overall_eta.total_seconds(),
151-
"file_progress": {k: v.to_dict() for k, v in self.file_progress.items()},
152-
"status": self.status
153-
}
154-
155-
@classmethod
156-
def from_dict(cls, data):
157-
# Convert overall_eta from seconds back to timedelta
158-
if 'overall_eta' in data:
159-
data['overall_eta'] = timedelta(seconds=data['overall_eta'])
160-
161-
# Parse file_progress
162-
if 'file_progress' in data:
163-
data['file_progress'] = {
164-
k: HFRepoFileProgressEvent.from_dict(v)
165-
for k, v in data['file_progress'].items()
166-
}
167-
168-
return cls(**data)
169-
170-
HFRepoFileProgressCallback = Callable[[HFRepoFileProgressEvent], Coroutine[Any, Any, None]]
171-
HFRepoProgressCallback = Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]
172-
17397
@retry(
17498
stop=stop_after_attempt(5),
17599
wait=wait_exponential(multiplier=1, min=4, max=60),
176100
retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)),
177101
reraise=True
178102
)
179-
async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[HFRepoFileProgressCallback] = None, use_range_request: bool = True):
103+
async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True):
180104
base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
181105
url = urljoin(base_url, file_path)
182106
local_path = os.path.join(save_directory, file_path)
@@ -198,7 +122,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
198122
if downloaded_size == total_size:
199123
if DEBUG >= 2: print(f"File already downloaded: {file_path}")
200124
if progress_callback:
201-
await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
125+
await progress_callback(RepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
202126
return
203127

204128
if response.status == 200:
@@ -221,7 +145,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
221145
if downloaded_size == total_size:
222146
if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
223147
if progress_callback:
224-
await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
148+
await progress_callback(RepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
225149
return
226150
except ValueError:
227151
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
@@ -232,7 +156,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
232156
if downloaded_size == total_size:
233157
print(f"File already downloaded: {file_path}")
234158
if progress_callback:
235-
await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
159+
await progress_callback(RepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
236160
return
237161

238162
DOWNLOAD_CHUNK_SIZE = 32768
@@ -249,10 +173,10 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
249173
eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
250174
status = "in_progress" if downloaded_size < total_size else "complete"
251175
if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
252-
await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
176+
await progress_callback(RepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
253177
if DEBUG >= 2: print(f"Downloaded: {file_path}")
254178

255-
async def download_all_files(repo_id: str, revision: str = "main", progress_callback: Optional[HFRepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None):
179+
async def download_repo_files(repo_id: str, revision: str = "main", progress_callback: Optional[RepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None) -> Path:
256180
repo_root = get_repo_root(repo_id)
257181
refs_dir = repo_root / "refs"
258182
snapshots_dir = repo_root / "snapshots"
@@ -283,11 +207,11 @@ async def download_all_files(repo_id: str, revision: str = "main", progress_call
283207
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
284208
total_files = len(filtered_file_list)
285209
total_bytes = sum(file["size"] for file in filtered_file_list)
286-
file_progress: Dict[str, HFRepoFileProgressEvent] = {file["path"]: HFRepoFileProgressEvent(file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
210+
file_progress: Dict[str, RepoFileProgressEvent] = {file["path"]: RepoFileProgressEvent(file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
287211
start_time = datetime.now()
288212

289213
async def download_with_progress(file_info, progress_state):
290-
async def file_progress_callback(event: HFRepoFileProgressEvent):
214+
async def file_progress_callback(event: RepoFileProgressEvent):
291215
progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
292216
progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
293217
file_progress[event.file_path] = event
@@ -297,21 +221,60 @@ async def file_progress_callback(event: HFRepoFileProgressEvent):
297221
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
298222
overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
299223
status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
300-
await progress_callback(HFRepoProgressEvent(progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
224+
await progress_callback(RepoProgressEvent(progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
301225

302226
await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
303227
progress_state['completed_files'] += 1
304-
file_progress[file_info["path"]] = HFRepoFileProgressEvent(file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
228+
file_progress[file_info["path"]] = RepoFileProgressEvent(file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
305229
if progress_callback:
306230
elapsed_time = (datetime.now() - start_time).total_seconds()
307231
overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
308232
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
309233
overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
310234
status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
311-
await progress_callback(HFRepoProgressEvent(progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
235+
await progress_callback(RepoProgressEvent(progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
312236

313237
progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
314238
tasks = [download_with_progress(file_info, progress_state) for file_info in filtered_file_list]
315239
await asyncio.gather(*tasks)
316240

317241
return snapshot_dir
242+
243+
async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
244+
"""
245+
Retrieve the weight map from the model.safetensors.index.json file.
246+
247+
Args:
248+
repo_id (str): The Hugging Face repository ID.
249+
revision (str): The revision of the repository to use.
250+
251+
Returns:
252+
Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
253+
"""
254+
255+
# Download the index file
256+
await download_repo_files(
257+
repo_id=repo_id,
258+
revision=revision,
259+
allow_patterns="model.safetensors.index.json"
260+
)
261+
262+
# Check if the file exists
263+
repo_root = get_repo_root(repo_id)
264+
snapshot_dir = repo_root / "snapshots"
265+
index_file = next(snapshot_dir.glob("*/model.safetensors.index.json"), None)
266+
267+
if index_file and index_file.exists():
268+
with open(index_file, 'r') as f:
269+
index_data = json.load(f)
270+
return index_data.get("weight_map")
271+
272+
return None
273+
274+
def extract_layer_num(tensor_name: str) -> Optional[int]:
275+
# This is a simple example and might need to be adjusted based on the actual naming convention
276+
parts = tensor_name.split('.')
277+
for part in parts:
278+
if part.isdigit():
279+
return int(part)
280+
return None

0 commit comments

Comments
 (0)