Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 20 additions & 54 deletions flashinfer/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
)


import logging
from contextlib import contextmanager


Expand All @@ -47,45 +46,6 @@ def temp_env_var(key, value):
os.environ[key] = old_value


@contextmanager
def patch_logger_for_tqdm(logger):
"""
Context manager to patch the logger so that log messages are displayed using tqdm.write,
preventing interference with tqdm progress bars.
"""
import tqdm

class TqdmLoggingHandler(logging.Handler):
def emit(self, record):
try:
msg = self.format(record)
tqdm.write(msg, end="\n")
except Exception:
self.handleError(record)

# Save original handlers and level
original_handlers = logger.handlers[:]
original_level = logger.level

# Remove all existing handlers to prevent duplicate output
for h in original_handlers:
logger.removeHandler(h)

# Add our tqdm-aware handler
handler = TqdmLoggingHandler()
handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
logger.addHandler(handler)
logger.setLevel(logging.INFO)
try:
yield
finally:
# Remove tqdm handler and restore original handlers and level
logger.removeHandler(handler)
for h in original_handlers:
logger.addHandler(h)
logger.setLevel(original_level)


def get_available_cubin_files(source, retries=3, delay=5, timeout=10):
for attempt in range(1, retries + 1):
try:
Expand Down Expand Up @@ -155,25 +115,31 @@ def get_cubin_file_list():


def download_artifacts():
import tqdm
from tqdm.contrib.logging import tqdm_logging_redirect

# use a shared session to make use of HTTP keep-alive and reuse of
# HTTPS connections.
session = requests.Session()

with temp_env_var("FLASHINFER_CUBIN_CHECKSUM_DISABLED", "1"):
cubin_files = get_cubin_file_list()
num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4"))
pool = ThreadPoolExecutor(num_threads)
futures = []
for name, extension in cubin_files:
ret = pool.submit(get_cubin, name, "", extension)
futures.append(ret)
results = []
with (
patch_logger_for_tqdm(logger),
tqdm(total=len(futures), desc="Downloading cubins") as pbar,
):
for ret in as_completed(futures):
result = ret.result()
results.append(result)
with tqdm_logging_redirect(
total=len(cubin_files), desc="Downloading cubins"
) as pbar:

def update_pbar_cb(_) -> None:
pbar.update(1)

with ThreadPoolExecutor(num_threads) as pool:
futures = []
for name, extension in cubin_files:
fut = pool.submit(get_cubin, name, "", extension, session)
fut.add_done_callback(update_pbar_cb)
futures.append(fut)

results = [fut.result() for fut in as_completed(futures)]

all_success = all(results)
if not all_success:
raise RuntimeError("Failed to download cubins")
Expand Down
13 changes: 9 additions & 4 deletions flashinfer/jit/cubin_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
)


def download_file(source, local_path, retries=3, delay=5, timeout=10, lock_timeout=30):
def download_file(
source, local_path, retries=3, delay=5, timeout=10, lock_timeout=30, session=None
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better to also document session

):
"""
Downloads a file from a URL or copies from a local path to a destination.

Expand All @@ -51,6 +53,9 @@ def download_file(source, local_path, retries=3, delay=5, timeout=10, lock_timeo

import requests # type: ignore[import-untyped]

if session is None:
session = requests.Session()

lock_path = f"{local_path}.lock" # Lock file path
lock = filelock.FileLock(lock_path, timeout=lock_timeout)

Expand All @@ -71,7 +76,7 @@ def download_file(source, local_path, retries=3, delay=5, timeout=10, lock_timeo
# Handle URL downloads
for attempt in range(1, retries + 1):
try:
response = requests.get(source, timeout=timeout)
response = session.get(source, timeout=timeout)
response.raise_for_status()

with open(local_path, "wb") as file:
Expand Down Expand Up @@ -133,7 +138,7 @@ def load_cubin(cubin_path, sha256) -> bytes:
return b""


def get_cubin(name, sha256, file_extension=".cubin"):
def get_cubin(name, sha256, file_extension=".cubin", session=None):
"""
Load a cubin from the local cache directory with {name} and
ensure that the sha256 signature matches.
Expand All @@ -151,7 +156,7 @@ def get_cubin(name, sha256, file_extension=".cubin"):
# either the file does not exist or it is corrupted, we'll download a new one.
uri = FLASHINFER_CUBINS_REPOSITORY + "/" + cubin_fname
logger.info(f"Fetching cubin {name} from {uri}")
download_file(uri, cubin_path)
download_file(uri, cubin_path, session=session)
return load_cubin(cubin_path, sha256)
Comment on lines +159 to 160
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The download_file function returns a boolean indicating success or failure. It's better to check this return value explicitly rather than relying on load_cubin to fail implicitly. This makes the control flow clearer and more robust, as download_file already handles logging on failure.

Suggested change
download_file(uri, cubin_path, session=session)
return load_cubin(cubin_path, sha256)
if download_file(uri, cubin_path, session=session):
return load_cubin(cubin_path, sha256)
return b""



Expand Down