Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions download-model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@

class ModelDownloader:
def __init__(self, max_retries=5):
self.session = requests.Session()
if max_retries:
self.session.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries))
self.session.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries))
self.max_retries = max_retries

def get_session(self):
session = requests.Session()
if self.max_retries:
session.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=self.max_retries))
session.mount('https://huggingface.co', HTTPAdapter(max_retries=self.max_retries))

if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None:
self.session.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS'))
session.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS'))

try:
from huggingface_hub import get_token
Expand All @@ -41,7 +44,9 @@ def __init__(self, max_retries=5):
token = os.getenv("HF_TOKEN")

if token is not None:
self.session.headers = {'authorization': f'Bearer {token}'}
session.headers = {'authorization': f'Bearer {token}'}

return session

def sanitize_model_and_branch_names(self, model, branch):
if model[-1] == '/':
Expand All @@ -65,6 +70,7 @@ def sanitize_model_and_branch_names(self, model, branch):
return model, branch

def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None):
session = self.get_session()
page = f"/api/models/{model}/tree/{branch}"
cursor = b""

Expand All @@ -78,7 +84,7 @@ def get_download_links_from_huggingface(self, model, branch, text_only=False, sp
is_lora = False
while True:
url = f"{base}{page}" + (f"?cursor={cursor.decode()}" if cursor else "")
r = self.session.get(url, timeout=10)
r = session.get(url, timeout=10)
r.raise_for_status()
content = r.content

Expand Down Expand Up @@ -171,14 +177,15 @@ def get_output_folder(self, model, branch, is_lora, is_llamacpp=False):
return output_folder

def get_single_file(self, url, output_folder, start_from_scratch=False):
session = self.get_session()
filename = Path(url.rsplit('/', 1)[1])
output_path = output_folder / filename
headers = {}
mode = 'wb'
if output_path.exists() and not start_from_scratch:

# Check if the file has already been downloaded completely
r = self.session.get(url, stream=True, timeout=10)
r = session.get(url, stream=True, timeout=10)
total_size = int(r.headers.get('content-length', 0))
if output_path.stat().st_size >= total_size:
return
Expand All @@ -187,7 +194,7 @@ def get_single_file(self, url, output_folder, start_from_scratch=False):
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
mode = 'ab'

with self.session.get(url, stream=True, headers=headers, timeout=10) as r:
with session.get(url, stream=True, headers=headers, timeout=10) as r:
r.raise_for_status() # Do not continue the download if the request was unsuccessful
total_size = int(r.headers.get('content-length', 0))
block_size = 1024 * 1024 # 1MB
Expand Down