diff --git a/sgl-kernel/python/sgl_kernel/load_utils.py b/sgl-kernel/python/sgl_kernel/load_utils.py index 45f06707dfc..d0b18d3fc83 100644 --- a/sgl-kernel/python/sgl_kernel/load_utils.py +++ b/sgl-kernel/python/sgl_kernel/load_utils.py @@ -205,20 +205,28 @@ def _find_cuda_home(): def _preload_cuda_library(): + """Preload the CUDA runtime library to help avoid 'libcudart.so.12 not found' issues.""" cuda_home = Path(_find_cuda_home()) - if (cuda_home / "lib").is_dir(): - cuda_path = cuda_home / "lib" - elif (cuda_home / "lib64").is_dir(): - cuda_path = cuda_home / "lib64" - else: - # Search for 'libcudart.so.12' in subdirectories - for path in cuda_home.rglob("libcudart.so.12"): - cuda_path = path.parent - break - else: - raise RuntimeError("Could not find CUDA lib directory.") - - cuda_include = (cuda_path / "libcudart.so.12").resolve() - if cuda_include.exists(): - ctypes.CDLL(str(cuda_include), mode=ctypes.RTLD_GLOBAL) + candidate_dirs = [ + cuda_home / "lib", + cuda_home / "lib64", + Path("/usr/lib/x86_64-linux-gnu"), + Path("/usr/lib/aarch64-linux-gnu"), + Path("/usr/lib64"), + Path("/usr/lib"), + ] + + for base in candidate_dirs: + candidate = base / "libcudart.so.12" + if candidate.exists(): + try: + cuda_runtime_lib = candidate.resolve() + ctypes.CDLL(str(cuda_runtime_lib), mode=ctypes.RTLD_GLOBAL) + logger.debug(f"Preloaded CUDA runtime under {cuda_runtime_lib}") + return + except Exception as e: + logger.debug(f"Failed to load {candidate}: {e}") + continue + + logger.debug("[sgl_kernel] Could not preload CUDA runtime library")