Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions vllm/platforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import traceback
from contextlib import suppress
from itertools import chain
from typing import TYPE_CHECKING, Optional

Expand All @@ -14,6 +15,21 @@
logger = logging.getLogger(__name__)


def vllm_version_matches_substr(substr: str) -> bool:
"""
Check to see if the vLLM version matches a substring.
"""
from importlib.metadata import PackageNotFoundError, version
try:
vllm_version = version("vllm")
except PackageNotFoundError as e:
logger.warning(
"The vLLM package was not found, so its version could not be "
"inspected. This may cause platform detection to fail.")
raise e
return substr in vllm_version


def tpu_platform_plugin() -> Optional[str]:
is_tpu = False
try:
Expand All @@ -33,8 +49,6 @@ def cuda_platform_plugin() -> Optional[str]:
is_cuda = False

try:
from importlib.metadata import version

from vllm.utils import import_pynvml
pynvml = import_pynvml()
pynvml.nvmlInit()
Expand All @@ -45,7 +59,7 @@ def cuda_platform_plugin() -> Optional[str]:
# Otherwise, vllm will always activate cuda plugin
# on a GPU machine, even if in a cpu build.
is_cuda = (pynvml.nvmlDeviceGetCount() > 0
and "cpu" not in version("vllm"))
and not vllm_version_matches_substr("cpu"))
finally:
pynvml.nvmlShutdown()
except Exception as e:
Expand Down Expand Up @@ -113,8 +127,7 @@ def xpu_platform_plugin() -> Optional[str]:
def cpu_platform_plugin() -> Optional[str]:
is_cpu = False
try:
from importlib.metadata import version
is_cpu = "cpu" in version("vllm")
is_cpu = vllm_version_matches_substr("cpu")
if not is_cpu:
import platform
is_cpu = platform.machine().lower().startswith("arm")
Expand All @@ -138,11 +151,8 @@ def neuron_platform_plugin() -> Optional[str]:

def openvino_platform_plugin() -> Optional[str]:
is_openvino = False
try:
from importlib.metadata import version
is_openvino = "openvino" in version("vllm")
except Exception:
pass
with suppress(Exception):
is_openvino = vllm_version_matches_substr("openvino")

return "vllm.platforms.openvino.OpenVinoPlatform" if is_openvino else None

Expand Down