diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index 129ec210f546..5f819acc6aea 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -9,6 +9,7 @@ # # Build targets: # vllm-openai (default): used for serving deployment +# vllm-openai-zen: vLLM from source + zentorch from PyPI via vllm[zen] # vllm-test: used for CI tests # vllm-dev: used for development # @@ -222,3 +223,19 @@ LABEL ai.vllm.build.cpu-arm-bf16="${VLLM_CPU_ARM_BF16:-false}" LABEL ai.vllm.build.python-version="${PYTHON_VERSION:-3.12}" ENTRYPOINT ["vllm", "serve"] + + +######################### ZEN CPU PYPI IMAGE ######################### +FROM vllm-openai AS vllm-openai-zen + +ARG TARGETARCH + +RUN if [ "$TARGETARCH" != "amd64" ]; then \ + echo "ERROR: vllm-openai-zen only supports --platform=linux/amd64"; \ + exit 1; \ + fi + +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install "vllm[zen]" + +ENTRYPOINT ["vllm", "serve"] diff --git a/setup.py b/setup.py index 829552fba320..d5782a81d853 100644 --- a/setup.py +++ b/setup.py @@ -966,6 +966,8 @@ def _read_requirements(filename: str) -> list[str]: ext_modules=ext_modules, install_requires=get_requirements(), extras_require={ + # AMD Zen CPU optimizations via zentorch + "zen": ["zentorch"], "bench": ["pandas", "matplotlib", "seaborn", "datasets", "scipy", "plotly"], "tensorizer": ["tensorizer==2.10.1"], "fastsafetensors": ["fastsafetensors >= 0.2.2"], diff --git a/tests/model_executor/test_cpu_unquantized_gemm_dispatch.py b/tests/model_executor/test_cpu_unquantized_gemm_dispatch.py new file mode 100644 index 000000000000..322897c02468 --- /dev/null +++ b/tests/model_executor/test_cpu_unquantized_gemm_dispatch.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for CPU unquantized GEMM dispatch behavior.""" + 
+@pytest.fixture(scope="module") +def _mock_zentorch_linear_unary(): + """Register a mock zentorch_linear_unary op when zentorch is not installed. + + Allows the dispatch tests to run in CI without a real zentorch build. + Skips registration when zentorch is already available. + """ + if hasattr(torch.ops.zentorch, "zentorch_linear_unary"): + yield + return + + lib_def = torch.library.Library("zentorch", "DEF") + lib_def.define( + "zentorch_linear_unary(" + "Tensor input, " + "Tensor weight, " + "Tensor? bias, " + "bool is_weight_prepacked=False" + ") -> Tensor" + ) + + lib_impl = torch.library.Library("zentorch", "IMPL", "CPU") + lib_impl.impl( + "zentorch_linear_unary", + lambda input, weight, bias, is_weight_prepacked=False: ( + torch.nn.functional.linear(input, weight, bias) + ), + ) + + yield + + lib_impl._destroy() + lib_def._destroy() + + +@pytest.mark.usefixtures("_mock_zentorch_linear_unary") +def test_dispatch_cpu_unquantized_gemm_uses_zentorch_on_zen(monkeypatch): + monkeypatch.setattr(current_platform, "is_zen_cpu", lambda: True) + + layer = torch.nn.Linear(16, 8, bias=True) + x = torch.randn(4, 16) + expected = torch.nn.functional.linear(x, layer.weight, layer.bias) + + utils.dispatch_cpu_unquantized_gemm(layer, remove_weight=False) + output = layer.cpu_linear(x, layer.weight, layer.bias) + + torch.testing.assert_close(output, expected) + + +@pytest.mark.usefixtures("_mock_zentorch_linear_unary") +def test_dispatch_cpu_unquantized_gemm_zen_remove_weight(monkeypatch): + monkeypatch.setattr(current_platform, "is_zen_cpu", lambda: True) + + layer = torch.nn.Linear(16, 8, bias=True) + utils.dispatch_cpu_unquantized_gemm(layer, remove_weight=True) + + assert layer.weight.numel() == 0 diff --git a/tests/test_zen_cpu_platform_detection.py b/tests/test_zen_cpu_platform_detection.py new file mode 100644 index 000000000000..a1798d2b52a3 --- /dev/null +++ b/tests/test_zen_cpu_platform_detection.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from unittest.mock import mock_open, patch + +from vllm.platforms import _is_amd_zen_cpu + + +def test_is_amd_zen_cpu_detects_amd_with_avx512(): + cpuinfo = "vendor_id: AuthenticAMD\nflags: avx avx2 avx512f avx512bw" + with ( + patch("os.path.exists", return_value=True), + patch("builtins.open", mock_open(read_data=cpuinfo)), + ): + assert _is_amd_zen_cpu() + + +def test_is_amd_zen_cpu_returns_false_for_amd_without_avx512(): + cpuinfo = "vendor_id: AuthenticAMD\nflags: avx avx2" + with ( + patch("os.path.exists", return_value=True), + patch("builtins.open", mock_open(read_data=cpuinfo)), + ): + assert not _is_amd_zen_cpu() + + +def test_is_amd_zen_cpu_returns_false_for_intel_with_avx512(): + cpuinfo = "vendor_id: GenuineIntel\nflags: avx avx2 avx512f" + with ( + patch("os.path.exists", return_value=True), + patch("builtins.open", mock_open(read_data=cpuinfo)), + ): + assert not _is_amd_zen_cpu() + + +def test_is_amd_zen_cpu_returns_false_when_cpuinfo_missing(): + with patch("os.path.exists", return_value=False): + assert not _is_amd_zen_cpu() diff --git a/vllm/envs.py b/vllm/envs.py index d310e9e1307d..caa2fb38afb6 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -51,6 +51,7 @@ VLLM_CPU_OMP_THREADS_BIND: str = "auto" VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None VLLM_CPU_SGL_KERNEL: bool = False + VLLM_ZENTORCH_WEIGHT_PREPACK: bool = True VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") VLLM_XLA_CHECK_RECOMPILATION: bool = False VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto" @@ -709,6 +710,11 @@ def _get_or_set_default() -> str: else None, # (CPU backend only) whether to use SGL kernels, optimized for small batch. "VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))), + # (Zen CPU backend) eagerly prepack weights into ZenDNN blocked layout + # at model load time. Eliminates per-inference layout conversion overhead. 
+ "VLLM_ZENTORCH_WEIGHT_PREPACK": lambda: bool( + int(os.getenv("VLLM_ZENTORCH_WEIGHT_PREPACK", "1")) + ), # If the env var is set, Ray Compiled Graph uses the specified # channel type to communicate between workers belonging to # different pipeline-parallel stages. @@ -1768,6 +1774,7 @@ def compile_factors() -> dict[str, object]: "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "VLLM_CPU_KVCACHE_SPACE", "VLLM_CPU_MOE_PREPACK", + "VLLM_ZENTORCH_WEIGHT_PREPACK", "VLLM_TEST_FORCE_LOAD_FORMAT", "VLLM_ENABLE_CUDA_COMPATIBILITY", "VLLM_CUDA_COMPATIBILITY_PATH", diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index e46e4fd39a69..5a526f12776f 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -231,6 +231,30 @@ def dispatch_cpu_unquantized_gemm( N, K = layer.weight.size() dtype = layer.weight.dtype + # Zen CPU path: zentorch_linear_unary with optional eager weight prepacking. + if current_platform.is_zen_cpu() and hasattr( + torch.ops.zentorch, "zentorch_linear_unary" + ): + zen_weight = layer.weight.detach() + is_prepacked = False + + if envs.VLLM_ZENTORCH_WEIGHT_PREPACK and hasattr( + torch.ops.zentorch, "zentorch_weight_prepack_for_linear" + ): + zen_weight = torch.ops.zentorch.zentorch_weight_prepack_for_linear( + zen_weight + ) + is_prepacked = True + + layer.cpu_linear = lambda x, weight, bias, _p=is_prepacked: ( + torch.ops.zentorch.zentorch_linear_unary( + x, zen_weight, bias, is_weight_prepacked=_p + ) + ) + if remove_weight: + layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False) + return + if envs.VLLM_CPU_SGL_KERNEL and check_cpu_sgl_kernel(N, K, dtype): packed_weight = torch.ops._C.convert_weight_packed(layer.weight) if getattr(layer, "bias", None) is not None: diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 2630df62d334..af344acfcbc7 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: 
Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import logging +import os import traceback from itertools import chain from typing import TYPE_CHECKING @@ -150,6 +151,15 @@ def xpu_platform_plugin() -> str | None: return "vllm.platforms.xpu.XPUPlatform" if is_xpu else None +def _is_amd_zen_cpu() -> bool: + """Detect AMD CPU with AVX-512 via /proc/cpuinfo.""" + if not os.path.exists("/proc/cpuinfo"): + return False + with open("/proc/cpuinfo") as f: + cpuinfo = f.read() + return "AuthenticAMD" in cpuinfo and "avx512" in cpuinfo + + def cpu_platform_plugin() -> str | None: is_cpu = False logger.debug("Checking if CPU platform is available.") @@ -171,7 +181,24 @@ def cpu_platform_plugin() -> str | None: except Exception as e: logger.debug("CPU platform is not available because: %s", str(e)) - return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None + if not is_cpu: + return None + + if _is_amd_zen_cpu(): + try: + import zentorch # noqa: F401 + + logger.debug( + "AMD Zen CPU detected with zentorch installed, using ZenCpuPlatform." + ) + return "vllm.platforms.zen_cpu.ZenCpuPlatform" + except ImportError: + logger.debug( + "AMD Zen CPU detected but zentorch not installed, " + "falling back to CpuPlatform." 
+ ) + + return "vllm.platforms.cpu.CpuPlatform" builtin_platform_plugins = { @@ -269,4 +296,11 @@ def __setattr__(name: str, value): raise AttributeError(f"No attribute named '{name}' exists in {__name__}.") -__all__ = ["Platform", "PlatformEnum", "current_platform", "CpuArchEnum", "_init_trace"] +__all__ = [ + "Platform", + "PlatformEnum", + "current_platform", + "CpuArchEnum", + "_init_trace", + "_is_amd_zen_cpu", +] diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index b538524995a5..619b403ba4c1 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -167,6 +167,9 @@ def is_xpu(self) -> bool: def is_cpu(self) -> bool: return self._enum == PlatformEnum.CPU + def is_zen_cpu(self) -> bool: + return False + def is_out_of_tree(self) -> bool: return self._enum == PlatformEnum.OOT diff --git a/vllm/platforms/zen_cpu.py b/vllm/platforms/zen_cpu.py new file mode 100644 index 000000000000..62ba37a74c8d --- /dev/null +++ b/vllm/platforms/zen_cpu.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import TYPE_CHECKING + +from vllm.logger import init_logger +from vllm.platforms.cpu import CpuPlatform +from vllm.utils.torch_utils import is_torch_equal_or_newer + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from vllm.config import VllmConfig + + +class ZenCpuPlatform(CpuPlatform): + """CPU platform with AMD Zen (ZenDNN/zentorch) optimizations. + + Model-load time (dispatch_cpu_unquantized_gemm in layers/utils.py): + - Routes linear ops to zentorch_linear_unary. + - When VLLM_ZENTORCH_WEIGHT_PREPACK=1 (default), eagerly prepacks + weights via zentorch_weight_prepack_for_linear. + """ + + device_name: str = "cpu" + device_type: str = "cpu" + + def is_zen_cpu(self) -> bool: + # is_cpu() also returns True for this platform (inherited from CpuPlatform). 
+ return True + + @classmethod + def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: + super().check_and_update_config(vllm_config) + cls._apply_pytorch_backports() + + @classmethod + def _apply_pytorch_backports(cls): + """Backport PyTorch mainline fixes missing in 2.10. + + PyTorch 2.10 has a bug in FxGraphCachePickler.dumps that doesn't + catch ValueError, causing torch.compile cache misses. Remove this + once we drop PyTorch 2.10 support. PT mainline already has this fix. + """ + if not is_torch_equal_or_newer("2.10.0") or is_torch_equal_or_newer("2.11.0"): + return + + cls._patch_fxgraphcache_pickle() + + @classmethod + def _patch_fxgraphcache_pickle(cls): + """Backport mainline ValueError fix to FxGraphCachePickler.dumps().""" + from torch._inductor.codecache import BypassFxGraphCache, FxGraphCachePickler + + original_dumps = FxGraphCachePickler.dumps + if hasattr(original_dumps, "_zen_patched"): + return + + def patched_dumps(self, obj): + try: + return original_dumps(self, obj) + except ValueError as e: + raise BypassFxGraphCache("Failed to pickle cache key") from e + + patched_dumps._zen_patched = True # type: ignore[attr-defined] + FxGraphCachePickler.dumps = patched_dumps + logger.info("[zen_cpu] Patched FxGraphCachePickler.dumps (ValueError fix)")