Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ repos:
rev: 'v1.17.1' # Use the sha / tag you want to point at
hooks:
- id: mypy
args: ["--config-file", "pyproject.toml", "--exclude", "flashinfer-cubin/"]
files: ^flashinfer/
exclude: ^(flashinfer-cubin/|3rdparty/|build/)

- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
Expand Down
2 changes: 2 additions & 0 deletions flashinfer-cubin/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
build/
flashinfer_cubin/cubins/
3 changes: 3 additions & 0 deletions flashinfer-cubin/MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
include README.md
include LICENSE
recursive-include flashinfer_cubin/cubins *
66 changes: 66 additions & 0 deletions flashinfer-cubin/build_wheel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#!/usr/bin/env python3
"""
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import subprocess
import sys
from pathlib import Path


def build_wheel():
    """Build the flashinfer-cubin wheel.

    Changes the working directory to the folder containing this script,
    removes any existing ``dist``, ``build`` and
    ``flashinfer_cubin.egg-info`` directories there, and then runs
    ``setup.py bdist_wheel`` with the current interpreter.

    On success the wheels found in ``dist/`` are printed. If the build
    command exits with a non-zero status, the error is printed and the
    process exits with status 1.
    """
    # Imported once here rather than inside the cleanup loop; kept local so
    # the function does not depend on an additional module-level import.
    import shutil

    # Resolve to an absolute path *before* changing directory. On
    # Python < 3.9 ``__file__`` of a script can be relative, and relative
    # paths derived from it would point at the wrong location after chdir.
    script_dir = Path(__file__).resolve().parent
    os.chdir(script_dir)

    print("Building flashinfer-cubin wheel...")
    print(f"Working directory: {script_dir}")

    # Clean previous builds
    dist_dir = script_dir / "dist"
    build_dir = script_dir / "build"
    egg_info_dir = script_dir / "flashinfer_cubin.egg-info"

    for dir_to_clean in (dist_dir, build_dir, egg_info_dir):
        if dir_to_clean.exists():
            print(f"Cleaning {dir_to_clean}")
            shutil.rmtree(dir_to_clean)

    # Build wheel; only the subprocess call can raise CalledProcessError,
    # so it is the only statement guarded by the try block.
    try:
        subprocess.run([sys.executable, "setup.py", "bdist_wheel"], check=True)
    except subprocess.CalledProcessError as e:
        print(f"Failed to build wheel: {e}")
        sys.exit(1)

    print("Wheel built successfully!")

    # List built wheels (all of them, in a stable order)
    if dist_dir.exists():
        wheels = sorted(dist_dir.glob("*.whl"))
        if wheels:
            for wheel in wheels:
                print(f"Built wheel: {wheel}")
        else:
            print("No wheel files found in dist/")


if __name__ == "__main__":
build_wheel()
80 changes: 80 additions & 0 deletions flashinfer-cubin/download_cubins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#!/usr/bin/env python3
"""
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import sys
import argparse
from pathlib import Path

# Add parent directory to path to import flashinfer modules
sys.path.insert(0, str(Path(__file__).parent.parent))

from flashinfer.artifacts import download_artifacts
from flashinfer.jit.cubin_loader import FLASHINFER_CUBINS_REPOSITORY


def main():
    """Command-line entry point: download the FlashInfer cubins.

    Parses ``--output-dir``, ``--threads`` and ``--repository``, exports the
    chosen values through the ``FLASHINFER_CUBIN_DIR``,
    ``FLASHINFER_CUBIN_DOWNLOAD_THREADS`` and (only when given)
    ``FLASHINFER_CUBINS_REPOSITORY`` environment variables, and then calls
    ``download_artifacts()``. Any exception from the download is reported
    and turned into exit status 1.
    """
    cli = argparse.ArgumentParser(
        description="Download FlashInfer cubins from artifactory"
    )
    cli.add_argument(
        "--output-dir",
        "-o",
        type=str,
        default="flashinfer_cubin/cubins",
        help="Output directory for cubins (default: flashinfer_cubin/cubins)",
    )
    cli.add_argument(
        "--threads",
        "-t",
        type=int,
        default=4,
        help="Number of download threads (default: 4)",
    )
    cli.add_argument(
        "--repository",
        "-r",
        type=str,
        default=None,
        help="Override the cubins repository URL",
    )
    options = cli.parse_args()

    # The download code is configured through environment variables.
    # NOTE(review): the flashinfer modules are imported at module level,
    # before these variables are set; if they read the environment at import
    # time, the overrides below may not take effect -- confirm.
    if options.repository:
        os.environ["FLASHINFER_CUBINS_REPOSITORY"] = options.repository

    absolute_output = Path(options.output_dir).absolute()
    os.environ["FLASHINFER_CUBIN_DIR"] = str(absolute_output)
    os.environ["FLASHINFER_CUBIN_DOWNLOAD_THREADS"] = str(options.threads)

    print(f"Downloading cubins to {options.output_dir}")
    active_repository = os.environ.get(
        "FLASHINFER_CUBINS_REPOSITORY", FLASHINFER_CUBINS_REPOSITORY
    )
    print(f"Repository: {active_repository}")

    # Delegate to the existing download routine; this is the top-level
    # boundary of the script, so every failure becomes a non-zero exit.
    try:
        download_artifacts()
        print("Download complete!")
    except Exception as e:
        print(f"Download failed: {e}")
        sys.exit(1)


if __name__ == "__main__":
main()
67 changes: 67 additions & 0 deletions flashinfer-cubin/flashinfer_cubin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
from pathlib import Path

# Location of the packaged cubins: the "cubins" folder next to this module.
CUBIN_DIR = Path(__file__).parent.joinpath("cubins")


def get_cubin_dir():
    """Return the cubins directory of this package as a string path."""
    return os.fspath(CUBIN_DIR)


def list_cubins():
    """Return the sorted relative paths of every ``.cubin`` file.

    Paths are relative to ``CUBIN_DIR``. An empty list is returned when the
    directory does not exist.
    """
    if not CUBIN_DIR.exists():
        return []

    found = [
        os.path.relpath(os.path.join(dirpath, filename), CUBIN_DIR)
        for dirpath, _subdirs, filenames in os.walk(CUBIN_DIR)
        for filename in filenames
        if filename.endswith(".cubin")
    ]
    found.sort()
    return found
Comment on lines +29 to +40
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The `list_cubins` function can be made more concise and idiomatic by using `pathlib.Path.rglob` to find all `.cubin` files recursively. This also improves consistency by using `pathlib` features instead of mixing them with `os.walk` and `os.path`.

Suggested change
def list_cubins():
"""List all available cubin files."""
if not CUBIN_DIR.exists():
return []
cubins = []
for root, _, files in os.walk(CUBIN_DIR):
for file in files:
if file.endswith(".cubin"):
rel_path = os.path.relpath(os.path.join(root, file), CUBIN_DIR)
cubins.append(rel_path)
return sorted(cubins)
def list_cubins():
"""List all available cubin files."""
if not CUBIN_DIR.exists():
return []
return sorted([str(p.relative_to(CUBIN_DIR)) for p in CUBIN_DIR.rglob("*.cubin")])



def get_cubin_path(relative_path):
    """Return the string path of ``relative_path`` joined onto ``CUBIN_DIR``."""
    target = CUBIN_DIR / relative_path
    return str(target)


# Resolve the package version: build metadata first, then the repository's
# version.txt, and finally a fixed placeholder.
def _get_version():
    """Return the version string for this package.

    Tries, in order: the ``__version__`` of the ``_build_meta`` module inside
    this package (present in wheel distributions), the stripped contents of
    ``version.txt`` three directory levels above this file (development
    checkouts), and finally the literal ``"0.0.0"``.
    """
    try:
        from . import _build_meta

        return _build_meta.__version__
    except ImportError:
        # No build metadata available; fall through to the version file.
        pass

    repo_version_file = Path(__file__).parent.parent.parent / "version.txt"
    if not repo_version_file.exists():
        return "0.0.0"
    return repo_version_file.read_text().strip()


# Package version, computed once at import time by _get_version().
__version__ = _get_version()
# Public names exported by a star-import of this package.
__all__ = ["get_cubin_dir", "list_cubins", "get_cubin_path", "CUBIN_DIR"]
54 changes: 54 additions & 0 deletions flashinfer-cubin/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
[build-system]
requires = ["setuptools>=61.0", "wheel", "requests", "filelock", "torch", "tqdm"] # NOTE(Zihao): we should remove torch once https://github.com/flashinfer-ai/flashinfer/pull/1641 is merged
build-backend = "setuptools.build_meta"

[project]
name = "flashinfer-cubin"
dynamic = ["version"]
description = "Pre-compiled cubins for FlashInfer"
readme = {text = "This package contains pre-compiled CUDA kernels (cubins) for FlashInfer. It provides all necessary cubin files downloaded from the FlashInfer artifactory.", content-type = "text/plain"}
requires-python = ">=3.8"
license = {text = "Apache-2.0"}
authors = [
{name = "FlashInfer team"},
]
maintainers = [
{name = "FlashInfer team"},
]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Software Development :: Libraries :: Python Modules",
]
dependencies = [
"requests",
"filelock",
]

[project.urls]
Homepage = "https://github.com/flashinfer-ai/flashinfer"
Documentation = "https://github.com/flashinfer-ai/flashinfer"
Repository = "https://github.com/flashinfer-ai/flashinfer"
"Issue Tracker" = "https://github.com/flashinfer-ai/flashinfer/issues"

[tool.setuptools]
packages = ["flashinfer_cubin"]
include-package-data = true

[tool.setuptools.dynamic]
version = {attr = "flashinfer_cubin.__version__"}

[tool.setuptools.package-data]
flashinfer_cubin = ["cubins/**/*"]

[tool.setuptools.cmdclass]
build_py = "setup.DownloadAndBuildPy"
sdist = "setup.CustomSdist"
Loading