 import ctypes
 import os
 import platform
+import shutil
+from pathlib import Path

 import torch

-SYSTEM_ARCH = platform.machine()

-cuda_path = f"/usr/local/cuda/targets/{SYSTEM_ARCH}-linux/lib/libcudart.so.12"
-if os.path.exists(cuda_path):
-    ctypes.CDLL(cuda_path, mode=ctypes.RTLD_GLOBAL)
+# Copied and modified from torch/utils/cpp_extension.py.
+def _find_cuda_home():
+    """Find the CUDA install path."""
+    # Guess #1: the CUDA_HOME / CUDA_PATH environment variables
+    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
+    if cuda_home is None:
+        # Guess #2: the toolkit that provides the nvcc found on PATH
+        nvcc_path = shutil.which("nvcc")
+        if nvcc_path is not None:
+            cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
+        else:
+            # Guess #3: the conventional default install location
+            cuda_home = "/usr/local/cuda"
+    return cuda_home
+
+
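+# torch.version.hip is only set on ROCm builds, which do not ship libcudart,
+# so the preload below runs only for non-ROCm (CUDA) builds of torch.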
+if torch.version.hip is None:
+    cuda_home = Path(_find_cuda_home())
+
+    if (cuda_home / "lib").is_dir():
+        cuda_path = cuda_home / "lib"
+    elif (cuda_home / "lib64").is_dir():
+        cuda_path = cuda_home / "lib64"
+    else:
+        # Fall back to searching for 'libcudart.so.12' anywhere under cuda_home;
+        # the for/else only raises when the search finds no match.
+        for path in cuda_home.rglob("libcudart.so.12"):
+            cuda_path = path.parent
+            break
+        else:
+            raise RuntimeError("Could not find CUDA lib directory.")
+
+    cudart_path = (cuda_path / "libcudart.so.12").resolve()
+    if cudart_path.exists():
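+        # Pre-load libcudart.so.12 with its symbols made globally visible
+        # (RTLD_GLOBAL) before the compiled kernels are imported below.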
+        ctypes.CDLL(str(cudart_path), mode=ctypes.RTLD_GLOBAL)

 from sgl_kernel import common_ops
 from sgl_kernel.allreduce import *
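
As a rough illustration of how the new lookup behaves at runtime (not part of the change itself): assuming this file is the package's __init__.py, so that `import sgl_kernel` executes the preload, and a Linux host where /proc/self/maps exists, a sketch like the following can confirm which libcudart.so.12 ended up in the process. The /opt/cuda-12.4 path is only a placeholder for a real toolkit install.

import os

# Optional: CUDA_HOME takes precedence over nvcc-on-PATH and /usr/local/cuda.
# Placeholder path -- point it at a toolkit that actually exists, or omit it.
os.environ["CUDA_HOME"] = "/opt/cuda-12.4"

import sgl_kernel  # noqa: F401  # runs the preload logic in __init__.py

# Linux-only sanity check: list every libcudart mapped into this process.
with open("/proc/self/maps") as f:
    cudart_maps = sorted({line.split()[-1] for line in f if "libcudart" in line})
print("\n".join(cudart_maps))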