diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py old mode 100755 new mode 100644 index 8d7053bbd9d..4e92e6ebfae --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -1,14 +1,46 @@ import ctypes import os import platform +import shutil +from pathlib import Path import torch -SYSTEM_ARCH = platform.machine() -cuda_path = f"/usr/local/cuda/targets/{SYSTEM_ARCH}-linux/lib/libcudart.so.12" -if os.path.exists(cuda_path): - ctypes.CDLL(cuda_path, mode=ctypes.RTLD_GLOBAL) +# copy & modify from torch/utils/cpp_extension.py +def _find_cuda_home(): + """Find the CUDA install path.""" + # Guess #1 + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") + if cuda_home is None: + # Guess #2 + nvcc_path = shutil.which("nvcc") + if nvcc_path is not None: + cuda_home = os.path.dirname(os.path.dirname(nvcc_path)) + else: + # Guess #3 + cuda_home = "/usr/local/cuda" + return cuda_home + + +if torch.version.hip is None: + cuda_home = Path(_find_cuda_home()) + + if (cuda_home / "lib").is_dir(): + cuda_path = cuda_home / "lib" + elif (cuda_home / "lib64").is_dir(): + cuda_path = cuda_home / "lib64" + else: + # Search for 'libcudart.so.12' in subdirectories + for path in cuda_home.rglob("libcudart.so.12"): + cuda_path = path.parent + break + else: + raise RuntimeError("Could not find CUDA lib directory.") + + cuda_include = (cuda_path / "libcudart.so.12").resolve() + if cuda_include.exists(): + ctypes.CDLL(str(cuda_include), mode=ctypes.RTLD_GLOBAL) from sgl_kernel import common_ops from sgl_kernel.allreduce import *