Skip to content

Commit bece3dc

Browse files
EduardDurechHanHan009527
authored andcommitted
CUDA Arch Independent (sgl-project#8813)
1 parent 879622e commit bece3dc

File tree

1 file changed

+36
-4
lines changed

1 file changed

+36
-4
lines changed

sgl-kernel/python/sgl_kernel/__init__.py

100755100644
Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,46 @@
11
import ctypes
22
import os
33
import platform
4+
import shutil
5+
from pathlib import Path
46

57
import torch
68

7-
SYSTEM_ARCH = platform.machine()
89

9-
cuda_path = f"/usr/local/cuda/targets/{SYSTEM_ARCH}-linux/lib/libcudart.so.12"
10-
if os.path.exists(cuda_path):
11-
ctypes.CDLL(cuda_path, mode=ctypes.RTLD_GLOBAL)
10+
# copy & modify from torch/utils/cpp_extension.py
11+
def _find_cuda_home():
12+
"""Find the CUDA install path."""
13+
# Guess #1
14+
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
15+
if cuda_home is None:
16+
# Guess #2
17+
nvcc_path = shutil.which("nvcc")
18+
if nvcc_path is not None:
19+
cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
20+
else:
21+
# Guess #3
22+
cuda_home = "/usr/local/cuda"
23+
return cuda_home
24+
25+
26+
if torch.version.hip is None:
27+
cuda_home = Path(_find_cuda_home())
28+
29+
if (cuda_home / "lib").is_dir():
30+
cuda_path = cuda_home / "lib"
31+
elif (cuda_home / "lib64").is_dir():
32+
cuda_path = cuda_home / "lib64"
33+
else:
34+
# Search for 'libcudart.so.12' in subdirectories
35+
for path in cuda_home.rglob("libcudart.so.12"):
36+
cuda_path = path.parent
37+
break
38+
else:
39+
raise RuntimeError("Could not find CUDA lib directory.")
40+
41+
cuda_include = (cuda_path / "libcudart.so.12").resolve()
42+
if cuda_include.exists():
43+
ctypes.CDLL(str(cuda_include), mode=ctypes.RTLD_GLOBAL)
1244

1345
from sgl_kernel import common_ops
1446
from sgl_kernel.allreduce import *

0 commit comments

Comments
 (0)