-
Notifications
You must be signed in to change notification settings - Fork 833
bugfix: Fix flashinfer download-cubin #1729
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -33,7 +33,9 @@ | |||||||||||
| ) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def download_file(source, local_path, retries=3, delay=5, timeout=10, lock_timeout=30): | ||||||||||||
| def download_file( | ||||||||||||
| source, local_path, retries=3, delay=5, timeout=10, lock_timeout=30, session=None | ||||||||||||
| ): | ||||||||||||
| """ | ||||||||||||
| Downloads a file from a URL or copies from a local path to a destination. | ||||||||||||
|
|
||||||||||||
|
|
@@ -51,6 +53,9 @@ def download_file(source, local_path, retries=3, delay=5, timeout=10, lock_timeo | |||||||||||
|
|
||||||||||||
| import requests # type: ignore[import-untyped] | ||||||||||||
|
|
||||||||||||
| if session is None: | ||||||||||||
| session = requests.Session() | ||||||||||||
|
|
||||||||||||
| lock_path = f"{local_path}.lock" # Lock file path | ||||||||||||
| lock = filelock.FileLock(lock_path, timeout=lock_timeout) | ||||||||||||
|
|
||||||||||||
|
|
@@ -71,7 +76,7 @@ def download_file(source, local_path, retries=3, delay=5, timeout=10, lock_timeo | |||||||||||
| # Handle URL downloads | ||||||||||||
| for attempt in range(1, retries + 1): | ||||||||||||
| try: | ||||||||||||
| response = requests.get(source, timeout=timeout) | ||||||||||||
| response = session.get(source, timeout=timeout) | ||||||||||||
| response.raise_for_status() | ||||||||||||
|
|
||||||||||||
| with open(local_path, "wb") as file: | ||||||||||||
|
|
@@ -133,7 +138,7 @@ def load_cubin(cubin_path, sha256) -> bytes: | |||||||||||
| return b"" | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def get_cubin(name, sha256, file_extension=".cubin"): | ||||||||||||
| def get_cubin(name, sha256, file_extension=".cubin", session=None): | ||||||||||||
| """ | ||||||||||||
| Load a cubin from the local cache directory with {name} and | ||||||||||||
| ensure that the sha256 signature matches. | ||||||||||||
|
|
@@ -151,7 +156,7 @@ def get_cubin(name, sha256, file_extension=".cubin"): | |||||||||||
| # either the file does not exist or it is corrupted, we'll download a new one. | ||||||||||||
| uri = FLASHINFER_CUBINS_REPOSITORY + "/" + cubin_fname | ||||||||||||
| logger.info(f"Fetching cubin {name} from {uri}") | ||||||||||||
| download_file(uri, cubin_path) | ||||||||||||
| download_file(uri, cubin_path, session=session) | ||||||||||||
| return load_cubin(cubin_path, sha256) | ||||||||||||
|
Comment on lines
+159
to
160
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||
|
|
||||||||||||
|
|
||||||||||||
|
|
||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
better to also document session