Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions sgl-kernel/python/sgl_kernel/__init__.py
100755 → 100644
Original file line number Diff line number Diff line change
@@ -1,14 +1,46 @@
import ctypes
import os
import platform
import shutil
from pathlib import Path

import torch

SYSTEM_ARCH = platform.machine()

cuda_path = f"/usr/local/cuda/targets/{SYSTEM_ARCH}-linux/lib/libcudart.so.12"
if os.path.exists(cuda_path):
ctypes.CDLL(cuda_path, mode=ctypes.RTLD_GLOBAL)
# copy & modify from torch/utils/cpp_extension.py
def _find_cuda_home():
"""Find the CUDA install path."""
# Guess #1
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
if cuda_home is None:
# Guess #2
nvcc_path = shutil.which("nvcc")
if nvcc_path is not None:
cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
else:
# Guess #3
cuda_home = "/usr/local/cuda"
return cuda_home


if torch.version.hip is None:
cuda_home = Path(_find_cuda_home())

if (cuda_home / "lib").is_dir():
cuda_path = cuda_home / "lib"
elif (cuda_home / "lib64").is_dir():
cuda_path = cuda_home / "lib64"
else:
# Search for 'libcudart.so.12' in subdirectories
for path in cuda_home.rglob("libcudart.so.12"):
cuda_path = path.parent
break
else:
raise RuntimeError("Could not find CUDA lib directory.")

cuda_include = (cuda_path / "libcudart.so.12").resolve()
if cuda_include.exists():
ctypes.CDLL(str(cuda_include), mode=ctypes.RTLD_GLOBAL)

from sgl_kernel import common_ops
from sgl_kernel.allreduce import *
Expand Down
Loading