Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,10 @@ ENV UV_LINK_MODE=copy
# Verify GCC version
RUN gcc --version

# NOTE(review): the RUN below is the pre-change behavior (unconditionally
# registering the compat libs with the dynamic linker); it is superseded by
# the opt-in env var that follows.
# Ensure CUDA compatibility library is loaded
RUN echo "/usr/local/cuda-$(echo "$CUDA_VERSION" | cut -d. -f1,2)/compat/" > /etc/ld.so.conf.d/cuda-compat.conf && ldconfig
# Enable CUDA forward compatibility by setting '-e VLLM_ENABLE_CUDA_COMPATIBILITY=1'
# at 'docker run' time. Only needed for datacenter/professional GPUs with older
# drivers; disabled (0) by default because consumer GPUs do not support
# forward compatibility and fail when the compat libs are loaded.
# See: https://docs.nvidia.com/deploy/cuda-compatibility/
ENV VLLM_ENABLE_CUDA_COMPATIBILITY=0

# ============================================================
# SLOW-CHANGING DEPENDENCIES BELOW
Expand Down Expand Up @@ -560,8 +562,10 @@ ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"
ENV UV_LINK_MODE=copy

# NOTE(review): the RUN below is the pre-change behavior (unconditionally
# registering the compat libs with the dynamic linker); it is superseded by
# the opt-in env var that follows.
# Ensure CUDA compatibility library is loaded
RUN echo "/usr/local/cuda-$(echo "$CUDA_VERSION" | cut -d. -f1,2)/compat/" > /etc/ld.so.conf.d/cuda-compat.conf && ldconfig
# Enable CUDA forward compatibility by setting '-e VLLM_ENABLE_CUDA_COMPATIBILITY=1'
# at 'docker run' time. Only needed for datacenter/professional GPUs with older
# drivers; disabled (0) by default because consumer GPUs do not support
# forward compatibility and fail when the compat libs are loaded.
# See: https://docs.nvidia.com/deploy/cuda-compatibility/
ENV VLLM_ENABLE_CUDA_COMPATIBILITY=0

# ============================================================
# SLOW-CHANGING DEPENDENCIES BELOW
Expand Down
190 changes: 190 additions & 0 deletions tests/cuda/test_cuda_compatibility_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for CUDA forward compatibility path logic in env_override.py.

Verifies the opt-in LD_LIBRARY_PATH manipulation for CUDA compat libs,
including env var parsing, path detection, and deduplication.
"""

import os
from unittest.mock import patch

import pytest

# Import the functions directly (they're module-level in env_override)
# We must import them without triggering the module-level side effects,
# so we import the functions by name after the module is already loaded.
from vllm.env_override import (
_get_torch_cuda_version,
_maybe_set_cuda_compatibility_path,
)


class TestCudaCompatibilityEnvParsing:
    """Parsing of the VLLM_ENABLE_CUDA_COMPATIBILITY toggle."""

    def test_disabled_by_default(self, monkeypatch):
        """With the toggle absent, LD_LIBRARY_PATH must stay unset or empty."""
        monkeypatch.delenv("VLLM_ENABLE_CUDA_COMPATIBILITY", raising=False)
        monkeypatch.delenv("LD_LIBRARY_PATH", raising=False)
        _maybe_set_cuda_compatibility_path()
        assert os.environ.get("LD_LIBRARY_PATH", "") == ""

    @pytest.mark.parametrize("value", ["0", "false", "False", "no", ""])
    def test_disabled_values(self, monkeypatch, value):
        """Falsy spellings of the toggle must not activate the compat path."""
        monkeypatch.setenv("VLLM_ENABLE_CUDA_COMPATIBILITY", value)
        monkeypatch.delenv("LD_LIBRARY_PATH", raising=False)
        _maybe_set_cuda_compatibility_path()
        # No compat entry may have been injected.
        assert "compat" not in os.environ.get("LD_LIBRARY_PATH", "")

    @pytest.mark.parametrize("value", ["1", "true", "True", " 1 ", " TRUE "])
    def test_enabled_values_with_valid_path(self, monkeypatch, tmp_path, value):
        """Truthy spellings activate the compat path when the dir exists."""
        target = tmp_path / "compat"
        target.mkdir()
        monkeypatch.setenv("VLLM_ENABLE_CUDA_COMPATIBILITY", value)
        monkeypatch.setenv("VLLM_CUDA_COMPATIBILITY_PATH", str(target))
        monkeypatch.delenv("LD_LIBRARY_PATH", raising=False)
        _maybe_set_cuda_compatibility_path()
        assert str(target) in os.environ.get("LD_LIBRARY_PATH", "")


class TestCudaCompatibilityPathDetection:
    """Resolution order for the compat dir: override, conda, torch default."""

    def test_custom_path_override(self, monkeypatch, tmp_path):
        """An explicit VLLM_CUDA_COMPATIBILITY_PATH wins over everything."""
        override = tmp_path / "my-compat"
        override.mkdir()
        monkeypatch.setenv("VLLM_ENABLE_CUDA_COMPATIBILITY", "1")
        monkeypatch.setenv("VLLM_CUDA_COMPATIBILITY_PATH", str(override))
        monkeypatch.delenv("LD_LIBRARY_PATH", raising=False)
        _maybe_set_cuda_compatibility_path()
        assert os.environ.get("LD_LIBRARY_PATH", "").startswith(str(override))

    def test_conda_prefix_fallback(self, monkeypatch, tmp_path):
        """$CONDA_PREFIX/cuda-compat is used when no explicit path is given."""
        prefix = tmp_path / "conda-env"
        expected = prefix / "cuda-compat"
        expected.mkdir(parents=True)
        monkeypatch.setenv("VLLM_ENABLE_CUDA_COMPATIBILITY", "1")
        monkeypatch.delenv("VLLM_CUDA_COMPATIBILITY_PATH", raising=False)
        monkeypatch.setenv("CONDA_PREFIX", str(prefix))
        monkeypatch.delenv("LD_LIBRARY_PATH", raising=False)
        _maybe_set_cuda_compatibility_path()
        assert str(expected) in os.environ.get("LD_LIBRARY_PATH", "")

    def test_no_valid_path_does_nothing(self, monkeypatch):
        """Enabled but with no candidate dir, LD_LIBRARY_PATH stays untouched."""
        monkeypatch.setenv("VLLM_ENABLE_CUDA_COMPATIBILITY", "1")
        monkeypatch.setenv("VLLM_CUDA_COMPATIBILITY_PATH", "/nonexistent/path")
        monkeypatch.delenv("CONDA_PREFIX", raising=False)
        monkeypatch.delenv("LD_LIBRARY_PATH", raising=False)
        with patch("vllm.env_override._get_torch_cuda_version", return_value=None):
            _maybe_set_cuda_compatibility_path()
        assert os.environ.get("LD_LIBRARY_PATH", "") == ""

    def test_default_cuda_path_fallback(self, monkeypatch, tmp_path):
        """/usr/local/cuda-{ver}/compat is derived from torch's CUDA version."""
        # Keep a real dir around even though isdir is mocked below.
        (tmp_path / "cuda-12.8" / "compat").mkdir(parents=True)
        monkeypatch.setenv("VLLM_ENABLE_CUDA_COMPATIBILITY", "1")
        monkeypatch.delenv("VLLM_CUDA_COMPATIBILITY_PATH", raising=False)
        monkeypatch.delenv("CONDA_PREFIX", raising=False)
        monkeypatch.delenv("LD_LIBRARY_PATH", raising=False)

        def fake_isdir(path):
            # Pretend the default compat dir exists; defer to the real
            # filesystem for anything else.
            return path == "/usr/local/cuda-12.8/compat" or os.path.isdir(path)

        with patch(
            "vllm.env_override._get_torch_cuda_version", return_value="12.8"
        ), patch("vllm.env_override.os.path.isdir", side_effect=fake_isdir):
            _maybe_set_cuda_compatibility_path()
        ld_path = os.environ.get("LD_LIBRARY_PATH", "")
        assert "/usr/local/cuda-12.8/compat" in ld_path


class TestCudaCompatibilityLdPathManipulation:
    """Prepend and deduplication behaviour for LD_LIBRARY_PATH."""

    def test_prepends_to_empty_ld_path(self, monkeypatch, tmp_path):
        """An unset LD_LIBRARY_PATH ends up equal to the compat dir alone."""
        target = tmp_path / "compat"
        target.mkdir()
        monkeypatch.setenv("VLLM_ENABLE_CUDA_COMPATIBILITY", "1")
        monkeypatch.setenv("VLLM_CUDA_COMPATIBILITY_PATH", str(target))
        monkeypatch.delenv("LD_LIBRARY_PATH", raising=False)
        _maybe_set_cuda_compatibility_path()
        assert os.environ["LD_LIBRARY_PATH"] == str(target)

    def test_prepends_to_existing_ld_path(self, monkeypatch, tmp_path):
        """Existing entries are preserved behind the prepended compat dir."""
        target = tmp_path / "compat"
        target.mkdir()
        monkeypatch.setenv("VLLM_ENABLE_CUDA_COMPATIBILITY", "1")
        monkeypatch.setenv("VLLM_CUDA_COMPATIBILITY_PATH", str(target))
        monkeypatch.setenv("LD_LIBRARY_PATH", "/usr/lib:/other/lib")
        _maybe_set_cuda_compatibility_path()
        entries = os.environ["LD_LIBRARY_PATH"].split(os.pathsep)
        assert entries[0] == str(target)
        assert "/usr/lib" in entries
        assert "/other/lib" in entries

    def test_deduplicates_existing_compat_path(self, monkeypatch, tmp_path):
        """A duplicate compat entry is hoisted to the front, not repeated."""
        target = tmp_path / "compat"
        target.mkdir()
        monkeypatch.setenv("VLLM_ENABLE_CUDA_COMPATIBILITY", "1")
        monkeypatch.setenv("VLLM_CUDA_COMPATIBILITY_PATH", str(target))
        monkeypatch.setenv(
            "LD_LIBRARY_PATH",
            f"/usr/lib:{target}:/other/lib",
        )
        _maybe_set_cuda_compatibility_path()
        entries = os.environ["LD_LIBRARY_PATH"].split(os.pathsep)
        assert entries[0] == str(target)
        # Exactly one occurrence survives.
        assert entries.count(str(target)) == 1

    def test_already_at_front_is_noop(self, monkeypatch, tmp_path):
        """Nothing changes when the compat dir already leads the path."""
        target = tmp_path / "compat"
        target.mkdir()
        unchanged = f"{target}:/usr/lib"
        monkeypatch.setenv("VLLM_ENABLE_CUDA_COMPATIBILITY", "1")
        monkeypatch.setenv("VLLM_CUDA_COMPATIBILITY_PATH", str(target))
        monkeypatch.setenv("LD_LIBRARY_PATH", unchanged)
        _maybe_set_cuda_compatibility_path()
        assert os.environ["LD_LIBRARY_PATH"] == unchanged


class TestGetTorchCudaVersion:
    """Behaviour of the _get_torch_cuda_version() helper."""

    def test_returns_string_when_torch_available(self):
        """Result is None or a CUDA version string such as '12.8'."""
        result = _get_torch_cuda_version()
        # torch may or may not carry a CUDA build in this environment.
        assert result is None or isinstance(result, str)

    def test_returns_none_when_torch_missing(self):
        """A missing torch spec yields None."""
        with patch(
            "vllm.env_override.importlib.util.find_spec",
            return_value=None,
        ):
            assert _get_torch_cuda_version() is None
76 changes: 76 additions & 0 deletions vllm/env_override.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,83 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E402
import importlib.util
import os


def _get_torch_cuda_version():
"""Get torch's CUDA version without importing torch (avoids CUDA init)."""
try:
spec = importlib.util.find_spec("torch")
if not spec:
return None
if spec.origin:
torch_root = os.path.dirname(spec.origin)
elif spec.submodule_search_locations:
torch_root = spec.submodule_search_locations[0]
else:
return None
version_path = os.path.join(torch_root, "version.py")
if not os.path.exists(version_path):
return None
ver_spec = importlib.util.spec_from_file_location("torch.version", version_path)
if not ver_spec or not ver_spec.loader:
return None
module = importlib.util.module_from_spec(ver_spec)
ver_spec.loader.exec_module(module)
return getattr(module, "cuda", None)
except Exception:
return None


def _maybe_set_cuda_compatibility_path():
"""Set LD_LIBRARY_PATH for CUDA forward compatibility if enabled.

Must run before 'import torch' since torch loads CUDA shared libraries
at import time and the dynamic linker only consults LD_LIBRARY_PATH when
a library is first loaded.

CUDA forward compatibility is only supported on select professional and
datacenter NVIDIA GPUs. Consumer GPUs (GeForce, RTX) do not support it
and will get Error 803 if compat libs are loaded.
"""
enable = os.environ.get("VLLM_ENABLE_CUDA_COMPATIBILITY", "0").strip().lower() in (
"1",
"true",
)
if not enable:
return

cuda_compat_path = os.environ.get("VLLM_CUDA_COMPATIBILITY_PATH", "")
if not cuda_compat_path or not os.path.isdir(cuda_compat_path):
conda_prefix = os.environ.get("CONDA_PREFIX", "")
conda_compat = os.path.join(conda_prefix, "cuda-compat")
if conda_prefix and os.path.isdir(conda_compat):
cuda_compat_path = conda_compat
if not cuda_compat_path or not os.path.isdir(cuda_compat_path):
torch_cuda_version = _get_torch_cuda_version()
if torch_cuda_version:
default_path = f"/usr/local/cuda-{torch_cuda_version}/compat"
if os.path.isdir(default_path):
cuda_compat_path = default_path
if not cuda_compat_path or not os.path.isdir(cuda_compat_path):
return

norm_path = os.path.normpath(cuda_compat_path)
existing = os.environ.get("LD_LIBRARY_PATH", "")
ld_paths = existing.split(os.pathsep) if existing else []

if ld_paths and ld_paths[0] and os.path.normpath(ld_paths[0]) == norm_path:
return # Already at the front

new_paths = [norm_path] + [
p for p in ld_paths if not p or os.path.normpath(p) != norm_path
]
os.environ["LD_LIBRARY_PATH"] = os.pathsep.join(new_paths)


# Runs at module import time, before 'import torch' below, so the dynamic
# linker sees the compat path when torch first loads its CUDA libraries.
_maybe_set_cuda_compatibility_path()

import torch

from vllm.logger import init_logger
Expand Down