Skip to content

Commit 46ca5a5

Browse files
authored
Merge branch 'flashinfer-ai:main' into main
2 parents ac2ee6f + 4a2ebc6 commit 46ca5a5

File tree

14 files changed

+1098
-59
lines changed

14 files changed

+1098
-59
lines changed

flashinfer_bench/compile/builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def create_pkg_name(sol: Solution, prefix: str = "") -> str:
4242
h.update(src.path.encode())
4343
h.update(src.content.encode())
4444

45-
return prefix + s + "_" + h.hexdigest()[:4]
45+
return prefix + s + "_" + h.hexdigest()[:6]
4646

4747

4848
class BuildError(RuntimeError):
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .cuda_builder import CUDABuilder
22
from .python_builder import PythonBuilder
33
from .triton_builder import TritonBuilder
4+
from .tvm_ffi_builder import TVMFFIBuilder
45

5-
__all__ = ["CUDABuilder", "PythonBuilder", "TritonBuilder"]
6+
__all__ = ["CUDABuilder", "PythonBuilder", "TritonBuilder", "TVMFFIBuilder"]

flashinfer_bench/compile/builders/cuda_builder.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from __future__ import annotations
22

3+
import logging
34
import os
45
import re
56
import shutil
67
import sys
78
from importlib import resources
89
from pathlib import Path
9-
from typing import Dict, List
10+
from typing import Dict, List, Optional
1011

1112
from flashinfer_bench.compile.builder import (
1213
Builder,
@@ -16,21 +17,14 @@
1617
)
1718
from flashinfer_bench.compile.runnable import Runnable
1819
from flashinfer_bench.data import Definition, Solution, SourceFile, SupportedLanguages
20+
from flashinfer_bench.utils import is_cuda_available
1921

2022
CUDA_ALLOWED_EXTS = [".cu", ".cpp", ".cc", ".cxx", ".c"]
2123

24+
logger = logging.getLogger(__name__)
2225

23-
def _verify_cuda() -> bool:
24-
try:
25-
import torch
26-
import torch.utils.cpp_extension
27-
28-
return torch.cuda.is_available()
29-
except ImportError:
30-
return False
3126

32-
33-
def _get_package_paths(pkg_name: str, lib_names: List[str] = None):
27+
def _get_package_paths(pkg_name: str, lib_names: Optional[List[str]] = None):
3428
include_path = None
3529
ldflags = []
3630

@@ -64,8 +58,11 @@ def _get_package_paths(pkg_name: str, lib_names: List[str] = None):
6458
ldflags = [f"/LIBPATH:{lib_path}"] + lib_names
6559

6660
except Exception:
67-
# TODO(shanli): add logger to print warning
68-
pass
61+
logger.warning(
62+
"Failed to discover resources for CUDA package '%s'; continuing without it.",
63+
pkg_name,
64+
exc_info=True,
65+
)
6966

7067
return include_path, ldflags
7168

@@ -125,7 +122,7 @@ class CUDABuilder(Builder):
125122
@classmethod
126123
def _get_cuda_available(cls) -> bool:
127124
if cls._cuda_available is None:
128-
cls._cuda_available = _verify_cuda()
125+
cls._cuda_available = is_cuda_available()
129126
return cls._cuda_available
130127

131128
def __init__(self) -> None:
@@ -142,16 +139,19 @@ def _make_key(self, solution: Solution) -> str:
142139
return f"cuda::{create_pkg_name(solution)}"
143140

144141
def _make_closer(self):
145-
# We keep build dirs for torch extension caching. The temp dirs can be cleaned by calling `clear_cache` on program exit.
142+
# We keep build dirs for torch extension caching. The temp dirs can be cleaned by
143+
# calling `clear_cache` on program exit.
146144
return lambda: None
147145

148146
def _build(self, defn: Definition, sol: Solution) -> Runnable:
149147
# CUDA solutions must provide a C/CUDA symbol as entry point.
150-
# If user prefer a Python wrapper, set language to `python` and ensure compilation and binding are properly handled.
148+
# If user prefer a Python wrapper, set language to `python` and ensure compilation and
149+
# binding are properly handled.
151150
entry_file_extension = "." + sol.spec.entry_point.split("::")[0].split(".")[-1]
152151
if entry_file_extension not in CUDA_ALLOWED_EXTS:
153152
raise BuildError(
154-
f"Entry file type not recognized. Must be one of {CUDA_ALLOWED_EXTS}, got {entry_file_extension}."
153+
f"Entry file type not recognized. Must be one of {CUDA_ALLOWED_EXTS}, "
154+
f"got {entry_file_extension}."
155155
)
156156

157157
if not self._get_cuda_available():
@@ -184,7 +184,8 @@ def _build(self, defn: Definition, sol: Solution) -> Runnable:
184184
inc_path = self._extra_include_paths.get(dep)
185185
if not inc_path:
186186
raise BuildError(
187-
f"{dep} is not available in the current environment but referenced by {sol.name}"
187+
f"{dep} is not available in the current environment but referenced "
188+
f"by {sol.name}"
188189
)
189190
extra_include_paths.append(inc_path)
190191
ldflags = self._extra_ldflags.get(dep)

0 commit comments

Comments
 (0)