From 1ba52fc47babc2f4f3fdcd14ccbf8535c51e6844 Mon Sep 17 00:00:00 2001 From: Christian Heimes Date: Fri, 19 Sep 2025 09:37:28 +0200 Subject: [PATCH 1/2] fix: improve and fix download_artifacts The previous version attempted to call the `tqdm` module itself, which is not callable. The new version replaces the custom logging handler with the `tqdm.contrib.logging` helper. Fixes `'module' object is not callable. Did you mean: 'tqdm.tqdm(...)'?` The `ThreadPoolExecutor` is now correctly wrapped in a context manager. The progress bar is updated by a done callback attached to each future. Signed-off-by: Christian Heimes --- flashinfer/artifacts.py | 74 +++++++++++------------------------------ 1 file changed, 20 insertions(+), 54 deletions(-) diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index e7f5a3bb60..7e0aeb50a9 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -30,7 +30,6 @@ ) -import logging from contextlib import contextmanager @@ -47,45 +46,6 @@ def temp_env_var(key, value): os.environ[key] = old_value -@contextmanager -def patch_logger_for_tqdm(logger): - """ - Context manager to patch the logger so that log messages are displayed using tqdm.write, - preventing interference with tqdm progress bars.
- """ - import tqdm - - class TqdmLoggingHandler(logging.Handler): - def emit(self, record): - try: - msg = self.format(record) - tqdm.write(msg, end="\n") - except Exception: - self.handleError(record) - - # Save original handlers and level - original_handlers = logger.handlers[:] - original_level = logger.level - - # Remove all existing handlers to prevent duplicate output - for h in original_handlers: - logger.removeHandler(h) - - # Add our tqdm-aware handler - handler = TqdmLoggingHandler() - handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) - logger.addHandler(handler) - logger.setLevel(logging.INFO) - try: - yield - finally: - # Remove tqdm handler and restore original handlers and level - logger.removeHandler(handler) - for h in original_handlers: - logger.addHandler(h) - logger.setLevel(original_level) - - def get_available_cubin_files(source, retries=3, delay=5, timeout=10): for attempt in range(1, retries + 1): try: @@ -155,25 +115,31 @@ def get_cubin_file_list(): def download_artifacts(): - import tqdm + from tqdm.contrib.logging import tqdm_logging_redirect + + # use a shared session to make use of HTTP keep-alive and reuse of + # HTTPS connections. 
+ session = requests.Session() with temp_env_var("FLASHINFER_CUBIN_CHECKSUM_DISABLED", "1"): cubin_files = get_cubin_file_list() num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4")) - pool = ThreadPoolExecutor(num_threads) - futures = [] - for name, extension in cubin_files: - ret = pool.submit(get_cubin, name, "", extension) - futures.append(ret) - results = [] - with ( - patch_logger_for_tqdm(logger), - tqdm(total=len(futures), desc="Downloading cubins") as pbar, - ): - for ret in as_completed(futures): - result = ret.result() - results.append(result) + with tqdm_logging_redirect( + total=len(cubin_files), desc="Downloading cubins" + ) as pbar: + + def update_pbar_cb(_) -> None: pbar.update(1) + + with ThreadPoolExecutor(num_threads) as pool: + futures = [] + for name, extension in cubin_files: + fut = pool.submit(get_cubin, name, "", extension, session) + fut.add_done_callback(update_pbar_cb) + futures.append(fut) + + results = [fut.result() for fut in as_completed(futures)] + all_success = all(results) if not all_success: raise RuntimeError("Failed to download cubins") From 4de6fc3bce02f96fd4f3a0b4661a4262a6ded53c Mon Sep 17 00:00:00 2001 From: Christian Heimes Date: Fri, 19 Sep 2025 09:49:50 +0200 Subject: [PATCH 2/2] feat: reuse requests' session `requests.get` creates a new session object for each GET request. This is pretty inefficient, because each request has to perform DNS lookup, TCP handshake, and HTTPS handshake including certificate validation. A `requests.Session` can be shared between requests and across threads to make use of HTTP keep-alive. This change more than doubles the download speed and reduces the load on the server. 
Signed-off-by: Christian Heimes --- flashinfer/jit/cubin_loader.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/flashinfer/jit/cubin_loader.py b/flashinfer/jit/cubin_loader.py index bc2bd84e4e..a6d0df57c8 100644 --- a/flashinfer/jit/cubin_loader.py +++ b/flashinfer/jit/cubin_loader.py @@ -33,7 +33,9 @@ ) -def download_file(source, local_path, retries=3, delay=5, timeout=10, lock_timeout=30): +def download_file( + source, local_path, retries=3, delay=5, timeout=10, lock_timeout=30, session=None +): """ Downloads a file from a URL or copies from a local path to a destination. @@ -51,6 +53,9 @@ def download_file(source, local_path, retries=3, delay=5, timeout=10, lock_timeo import requests # type: ignore[import-untyped] + if session is None: + session = requests.Session() + lock_path = f"{local_path}.lock" # Lock file path lock = filelock.FileLock(lock_path, timeout=lock_timeout) @@ -71,7 +76,7 @@ def download_file(source, local_path, retries=3, delay=5, timeout=10, lock_timeo # Handle URL downloads for attempt in range(1, retries + 1): try: - response = requests.get(source, timeout=timeout) + response = session.get(source, timeout=timeout) response.raise_for_status() with open(local_path, "wb") as file: @@ -133,7 +138,7 @@ def load_cubin(cubin_path, sha256) -> bytes: return b"" -def get_cubin(name, sha256, file_extension=".cubin"): +def get_cubin(name, sha256, file_extension=".cubin", session=None): """ Load a cubin from the local cache directory with {name} and ensure that the sha256 signature matches. @@ -151,7 +156,7 @@ def get_cubin(name, sha256, file_extension=".cubin"): # either the file does not exist or it is corrupted, we'll download a new one. uri = FLASHINFER_CUBINS_REPOSITORY + "/" + cubin_fname logger.info(f"Fetching cubin {name} from {uri}") - download_file(uri, cubin_path) + download_file(uri, cubin_path, session=session) return load_cubin(cubin_path, sha256)