11from __future__ import annotations
22
3+ import logging
34import os
45import re
56import shutil
67import sys
78from importlib import resources
89from pathlib import Path
9- from typing import Dict , List
10+ from typing import Dict , List , Optional
1011
1112from flashinfer_bench .compile .builder import (
1213 Builder ,
1617)
1718from flashinfer_bench .compile .runnable import Runnable
1819from flashinfer_bench .data import Definition , Solution , SourceFile , SupportedLanguages
20+ from flashinfer_bench .utils import is_cuda_available
1921
2022CUDA_ALLOWED_EXTS = [".cu" , ".cpp" , ".cc" , ".cxx" , ".c" ]
2123
24+ logger = logging .getLogger (__name__ )
2225
23- def _verify_cuda () -> bool :
24- try :
25- import torch
26- import torch .utils .cpp_extension
27-
28- return torch .cuda .is_available ()
29- except ImportError :
30- return False
3126
32-
33- def _get_package_paths (pkg_name : str , lib_names : List [str ] = None ):
27+ def _get_package_paths (pkg_name : str , lib_names : Optional [List [str ]] = None ):
3428 include_path = None
3529 ldflags = []
3630
@@ -64,8 +58,11 @@ def _get_package_paths(pkg_name: str, lib_names: List[str] = None):
6458 ldflags = [f"/LIBPATH:{ lib_path } " ] + lib_names
6559
6660 except Exception :
67- # TODO(shanli): add logger to print warning
68- pass
61+ logger .warning (
62+ "Failed to discover resources for CUDA package '%s'; continuing without it." ,
63+ pkg_name ,
64+ exc_info = True ,
65+ )
6966
7067 return include_path , ldflags
7168
@@ -125,7 +122,7 @@ class CUDABuilder(Builder):
125122 @classmethod
126123 def _get_cuda_available (cls ) -> bool :
127124 if cls ._cuda_available is None :
128- cls ._cuda_available = _verify_cuda ()
125+ cls ._cuda_available = is_cuda_available ()
129126 return cls ._cuda_available
130127
131128 def __init__ (self ) -> None :
@@ -142,16 +139,19 @@ def _make_key(self, solution: Solution) -> str:
142139 return f"cuda::{ create_pkg_name (solution )} "
143140
144141 def _make_closer (self ):
145- # We keep build dirs for torch extension caching. The temp dirs can be cleaned by calling `clear_cache` on program exit.
142+ # We keep build dirs for torch extension caching. The temp dirs can be cleaned by
143+ # calling `clear_cache` on program exit.
146144 return lambda : None
147145
148146 def _build (self , defn : Definition , sol : Solution ) -> Runnable :
149147 # CUDA solutions must provide a C/CUDA symbol as entry point.
150- # If user prefer a Python wrapper, set language to `python` and ensure compilation and binding are properly handled.
148+ # If user prefer a Python wrapper, set language to `python` and ensure compilation and
149+ # binding are properly handled.
151150 entry_file_extension = "." + sol .spec .entry_point .split ("::" )[0 ].split ("." )[- 1 ]
152151 if entry_file_extension not in CUDA_ALLOWED_EXTS :
153152 raise BuildError (
154- f"Entry file type not recognized. Must be one of { CUDA_ALLOWED_EXTS } , got { entry_file_extension } ."
153+ f"Entry file type not recognized. Must be one of { CUDA_ALLOWED_EXTS } , "
154+ f"got { entry_file_extension } ."
155155 )
156156
157157 if not self ._get_cuda_available ():
@@ -184,7 +184,8 @@ def _build(self, defn: Definition, sol: Solution) -> Runnable:
184184 inc_path = self ._extra_include_paths .get (dep )
185185 if not inc_path :
186186 raise BuildError (
187- f"{ dep } is not available in the current environment but referenced by { sol .name } "
187+ f"{ dep } is not available in the current environment but referenced "
188+ f"by { sol .name } "
188189 )
189190 extra_include_paths .append (inc_path )
190191 ldflags = self ._extra_ldflags .get (dep )
0 commit comments