Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 210 additions & 0 deletions tests/test_triton_utils.py
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please note: this script acts as a stopgap for when Triton updates lag behind CUDA updates. This workaround may become unnecessary after a future Triton update.

Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,213 @@ def test_no_triton_fallback():
assert triton.__class__.__name__ == "TritonPlaceholder"
assert triton.language.__class__.__name__ == "TritonLanguagePlaceholder"
assert tl.__class__.__name__ == "TritonLanguagePlaceholder"


def test_configure_triton_ptxas_respects_existing_env():
    """Test that _configure_triton_ptxas_for_new_gpus doesn't override
    user-set TRITON_PTXAS_PATH."""
    import os
    from unittest import mock

    from vllm.triton_utils.importing import _configure_triton_ptxas_for_new_gpus

    # patch.dict snapshots os.environ and restores it on exit, replacing
    # the error-prone manual save/try/finally/restore dance.
    with mock.patch.dict(os.environ, {"TRITON_PTXAS_PATH": "/custom/path/to/ptxas"}):
        # Call the function - it should not override the user-set value
        _configure_triton_ptxas_for_new_gpus()

        # Verify it wasn't changed
        assert os.environ.get("TRITON_PTXAS_PATH") == "/custom/path/to/ptxas"


def test_configure_triton_ptxas_detects_new_gpu():
    """Test that _configure_triton_ptxas_for_new_gpus sets TRITON_PTXAS_PATH
    for GPUs with compute capability >= 11.0 using Triton's native detection."""
    import os
    import tempfile

    from vllm.triton_utils.importing import _configure_triton_ptxas_for_new_gpus

    with tempfile.TemporaryDirectory() as tmpdir:
        # Create a mock ptxas executable at <tmpdir>/bin/ptxas
        mock_ptxas = os.path.join(tmpdir, "bin", "ptxas")
        os.makedirs(os.path.dirname(mock_ptxas))
        with open(mock_ptxas, "w") as f:
            f.write("#!/bin/sh\necho 'ptxas mock'\n")
        os.chmod(mock_ptxas, 0o755)

        # Mock Triton's native GPU detection to return arch=110 (Thor, CC 11.0)
        mock_target = mock.MagicMock(arch=110)
        mock_driver_instance = mock.MagicMock()
        mock_driver_instance.get_current_target.return_value = mock_target
        mock_driver_class = mock.MagicMock(return_value=mock_driver_instance)
        mock_driver_class.is_active.return_value = True
        mock_nvidia_backend = mock.MagicMock(driver=mock_driver_class)
        mock_backends = mock.MagicMock()
        mock_backends.get.return_value = mock_nvidia_backend

        # patch.dict snapshots os.environ and restores it on exit, so the
        # pop of TRITON_PTXAS_PATH below is also undone automatically.
        with mock.patch.dict(os.environ, {"CUDA_HOME": tmpdir}):
            os.environ.pop("TRITON_PTXAS_PATH", None)

            with mock.patch("vllm.triton_utils.importing.backends", mock_backends):
                _configure_triton_ptxas_for_new_gpus()

            # Verify TRITON_PTXAS_PATH was set
            assert os.environ.get("TRITON_PTXAS_PATH") == mock_ptxas


def test_configure_triton_ptxas_skips_older_gpus():
    """Test that _configure_triton_ptxas_for_new_gpus does not set
    TRITON_PTXAS_PATH for GPUs with compute capability < 11.0."""
    import os
    import tempfile

    from vllm.triton_utils.importing import _configure_triton_ptxas_for_new_gpus

    with tempfile.TemporaryDirectory() as tmpdir:
        # Create a mock ptxas executable at <tmpdir>/bin/ptxas
        mock_ptxas = os.path.join(tmpdir, "bin", "ptxas")
        os.makedirs(os.path.dirname(mock_ptxas))
        with open(mock_ptxas, "w") as f:
            f.write("#!/bin/sh\necho 'ptxas mock'\n")
        os.chmod(mock_ptxas, 0o755)

        # Mock Triton's native GPU detection to return arch=90 (Hopper, CC 9.0)
        mock_target = mock.MagicMock(arch=90)
        mock_driver_instance = mock.MagicMock()
        mock_driver_instance.get_current_target.return_value = mock_target
        mock_driver_class = mock.MagicMock(return_value=mock_driver_instance)
        mock_driver_class.is_active.return_value = True
        mock_nvidia_backend = mock.MagicMock(driver=mock_driver_class)
        mock_backends = mock.MagicMock()
        mock_backends.get.return_value = mock_nvidia_backend

        # patch.dict snapshots os.environ and restores it on exit, so the
        # pop of TRITON_PTXAS_PATH below is also undone automatically.
        with mock.patch.dict(os.environ, {"CUDA_HOME": tmpdir}):
            os.environ.pop("TRITON_PTXAS_PATH", None)

            with mock.patch("vllm.triton_utils.importing.backends", mock_backends):
                _configure_triton_ptxas_for_new_gpus()

            # Verify TRITON_PTXAS_PATH was NOT set for the older GPU
            assert os.environ.get("TRITON_PTXAS_PATH") is None


def test_configure_triton_ptxas_detects_gb10():
    """Test that _configure_triton_ptxas_for_new_gpus sets TRITON_PTXAS_PATH
    for NVIDIA GB10 (DGX Spark) with compute capability 12.1 (arch=121)."""
    import os
    import tempfile

    from vllm.triton_utils.importing import _configure_triton_ptxas_for_new_gpus

    with tempfile.TemporaryDirectory() as tmpdir:
        # Create a mock ptxas executable at <tmpdir>/bin/ptxas
        mock_ptxas = os.path.join(tmpdir, "bin", "ptxas")
        os.makedirs(os.path.dirname(mock_ptxas))
        with open(mock_ptxas, "w") as f:
            f.write("#!/bin/sh\necho 'ptxas mock'\n")
        os.chmod(mock_ptxas, 0o755)

        # Mock Triton's native GPU detection to return arch=121
        # (GB10 / DGX Spark, CC 12.1)
        mock_target = mock.MagicMock(arch=121)
        mock_driver_instance = mock.MagicMock()
        mock_driver_instance.get_current_target.return_value = mock_target
        mock_driver_class = mock.MagicMock(return_value=mock_driver_instance)
        mock_driver_class.is_active.return_value = True
        mock_nvidia_backend = mock.MagicMock(driver=mock_driver_class)
        mock_backends = mock.MagicMock()
        mock_backends.get.return_value = mock_nvidia_backend

        # patch.dict snapshots os.environ and restores it on exit, so the
        # pop of TRITON_PTXAS_PATH below is also undone automatically.
        with mock.patch.dict(os.environ, {"CUDA_HOME": tmpdir}):
            os.environ.pop("TRITON_PTXAS_PATH", None)

            with mock.patch("vllm.triton_utils.importing.backends", mock_backends):
                _configure_triton_ptxas_for_new_gpus()

            # Verify TRITON_PTXAS_PATH was set
            assert os.environ.get("TRITON_PTXAS_PATH") == mock_ptxas
95 changes: 95 additions & 0 deletions vllm/triton_utils/importing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,108 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import shutil
import subprocess
import types
from importlib.util import find_spec

from vllm.logger import init_logger

logger = init_logger(__name__)


def _configure_triton_ptxas_for_new_gpus():
"""
Configure TRITON_PTXAS_PATH for GPUs that may not be supported by
Triton's bundled ptxas (e.g., Jetson Thor sm_110a, DGX Spark sm_121a).

Triton bundles a ptxas binary (currently CUDA 12.8) that may not support
the newest GPU architectures. When running on such GPUs, Triton kernel
compilation fails with errors like:
ptxas fatal: Value 'sm_121a' is not defined for option 'gpu-name'

This function uses Triton's native GPU detection to check the architecture
and configures Triton to use the system's CUDA toolkit ptxas instead,
which typically has broader architecture support (e.g., CUDA 13.0+).
"""
# Don't override if already set by user
if os.environ.get("TRITON_PTXAS_PATH"):
return

# Try to find system ptxas
cuda_home = os.environ.get("CUDA_HOME", "/usr/local/cuda")
system_ptxas_paths = [
os.path.join(cuda_home, "bin", "ptxas"),
"/usr/local/cuda/bin/ptxas",
shutil.which("ptxas"), # Check PATH
]

system_ptxas = None
for path in system_ptxas_paths:
if path and os.path.isfile(path) and os.access(path, os.X_OK):
system_ptxas = path
break

if not system_ptxas:
# No system ptxas found, can't help
return

# Use Triton's native GPU detection to get the architecture.
# This is how Triton itself determines the target GPU.
try:
from triton.backends import backends

nvidia_backend = backends.get("nvidia")
if nvidia_backend is None or nvidia_backend.driver is None:
return

if not nvidia_backend.driver.is_active():
return

# Get the current GPU target using Triton's driver
driver_instance = nvidia_backend.driver()
target = driver_instance.get_current_target()
arch = target.arch # e.g., 121 for sm_121a (CC 12.1)

# GPUs with arch >= 110 (compute capability >= 11.0) may need system ptxas
# - arch 110: Jetson Thor (sm_110a, CC 11.0)
# - arch 120: Blackwell B100/B200 (sm_120, CC 12.0)
# - arch 121: DGX Spark GB10 (sm_121a, CC 12.1)
if arch >= 110:
# Check if system ptxas is functional
try:
result = subprocess.run(
[system_ptxas, "--version"],
capture_output=True,
text=True,
timeout=5,
)
if result.returncode == 0:
# System ptxas is available, use it
os.environ["TRITON_PTXAS_PATH"] = system_ptxas
major, minor = divmod(arch, 10)
logger.info(
"Detected GPU with compute capability %d.%d (arch=%d). "
"Configuring TRITON_PTXAS_PATH=%s to ensure "
"Triton kernel compilation compatibility.",
major,
minor,
arch,
system_ptxas,
)
except (subprocess.TimeoutExpired, FileNotFoundError, OSError) as e:
logger.debug("Cannot use system ptxas: %s", e)

except Exception as e:
# Don't fail if detection doesn't work - user can still set
# TRITON_PTXAS_PATH manually
logger.debug("Failed to auto-configure TRITON_PTXAS_PATH: %s", e)


# Configure ptxas before importing Triton to ensure kernels can compile
# on new GPU architectures (Thor, GB10, etc.)
_configure_triton_ptxas_for_new_gpus()

HAS_TRITON = (
find_spec("triton") is not None
or find_spec("pytorch-triton-xpu") is not None # Not compatible
Expand Down