Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 5 additions & 15 deletions python/sglang/check_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,22 +194,12 @@ def _get_cuda_driver_version(self):
"""
Get CUDA driver version.
"""
versions = set()
try:
output = subprocess.check_output(
[
"nvidia-smi",
"--query-gpu=driver_version",
"--format=csv,noheader,nounits",
]
)
versions = set(output.decode().strip().split("\n"))
if len(versions) == 1:
return {"CUDA Driver Version": versions.pop()}
else:
return {"CUDA Driver Versions": ", ".join(sorted(versions))}
except subprocess.SubprocessError:
from sglang.srt.utils.common import get_nvidia_driver_version_str

ver = get_nvidia_driver_version_str()
if ver is None:
return {"CUDA Driver Version": "Not Available"}
return {"CUDA Driver Version": ver}

def get_topology(self):
"""
Expand Down
18 changes: 5 additions & 13 deletions python/sglang/cli/killall.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,11 @@ def _run_smi(query, query_type="gpu"):


def _get_smi_version():
"""Return nvidia-smi driver version and CUDA version, or None on failure."""
try:
out = subprocess.check_output(
[
"nvidia-smi",
"--query-gpu=driver_version",
"--format=csv,noheader,nounits",
],
text=True,
timeout=10,
)
driver = out.strip().splitlines()[0].strip() if out.strip() else "unknown"
except (subprocess.SubprocessError, FileNotFoundError, IndexError):
"""Return nvidia-smi driver version and GPU name, or None on failure."""
from sglang.srt.utils.common import get_nvidia_driver_version_str

driver = get_nvidia_driver_version_str()
if driver is None:
return None
try:
out = subprocess.check_output(
Expand Down
20 changes: 16 additions & 4 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
get_device_name,
get_device_sm,
get_int_env_var,
get_nvidia_driver_version,
get_quantization_config,
human_readable_int,
is_blackwell_supported,
Expand Down Expand Up @@ -1751,10 +1752,21 @@ def _handle_model_specific_adjustments(self):
and is_triton_kernels_available()
and self.quantization is None
):
self.moe_runner_backend = "triton_kernel"
logger.warning(
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
)
# The triton_kernels package segfaults on Blackwell (B200)
# with NVIDIA driver >= 595. Fall back to triton backend.
if is_blackwell_supported() and get_nvidia_driver_version() >= (
595,
):
self.moe_runner_backend = "triton"
logger.warning(
"Detected GPT-OSS model on Blackwell with driver >= 595, "
"using triton MOE kernel to avoid triton_kernels SIGSEGV."
)
else:
self.moe_runner_backend = "triton_kernel"
logger.warning(
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
)

if self.moe_runner_backend == "triton_kernel":
assert (
Expand Down
35 changes: 35 additions & 0 deletions python/sglang/srt/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3399,6 +3399,41 @@ def is_triton_kernels_available() -> bool:
return importlib.util.find_spec("triton_kernels") is not None


@lru_cache(maxsize=1)
def get_nvidia_driver_version() -> tuple:
    """Return the NVIDIA driver version parsed into a tuple of ints.

    The dotted string from ``get_nvidia_driver_version_str()`` is split on
    ``"."`` and every component is converted with ``int``, so
    ``'595.58.03'`` becomes ``(595, 58, 3)``.

    Returns:
        The parsed version tuple, or ``(0,)`` when no version string is
        available or any component is not an integer. The result is
        memoized by ``lru_cache``.
    """
    fallback = (0,)

    raw = get_nvidia_driver_version_str()
    if raw is None:
        return fallback

    components = []
    for piece in raw.split("."):
        try:
            components.append(int(piece))
        except ValueError:
            # A single non-numeric component invalidates the whole version.
            return fallback
    return tuple(components)


@lru_cache(maxsize=1)
def get_nvidia_driver_version_str() -> "str | None":
    """Return the NVIDIA driver version string, e.g. '595.58.03'.

    Runs ``nvidia-smi --query-gpu=driver_version`` and takes the first line
    of its output.

    Returns:
        The stripped first output line, or ``None`` when ``nvidia-smi``
        cannot be run (missing, not executable, non-zero exit, or exceeding
        the 10 second timeout) or prints nothing. The result, including
        ``None``, is memoized by ``lru_cache``.
    """
    try:
        result = subprocess.run(
            [
                "nvidia-smi",
                "--query-gpu=driver_version",
                "--format=csv,noheader,nounits",
            ],
            capture_output=True,
            text=True,
            check=True,
            timeout=10,
        )
    except (subprocess.SubprocessError, OSError, ValueError):
        # SubprocessError covers both CalledProcessError and TimeoutExpired
        # (the call sets a timeout); OSError covers FileNotFoundError as well
        # as PermissionError and other launch failures. This probe is
        # best-effort, so every such failure maps to None.
        return None

    lines = result.stdout.strip().splitlines()
    version_str = lines[0].strip() if lines else ""
    return version_str or None


def check_cuda_result(raw_output):
import cuda.bindings.runtime as cuda_rt

Expand Down
Loading