Skip to content
17 changes: 17 additions & 0 deletions docker/Dockerfile.cpu
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#
# Build targets:
# vllm-openai (default): used for serving deployment
# vllm-openai-zen: vLLM from source + zentorch from PyPI via vllm[zen]
# vllm-test: used for CI tests
# vllm-dev: used for development
#
Expand Down Expand Up @@ -222,3 +223,19 @@ LABEL ai.vllm.build.cpu-arm-bf16="${VLLM_CPU_ARM_BF16:-false}"
LABEL ai.vllm.build.python-version="${PYTHON_VERSION:-3.12}"

ENTRYPOINT ["vllm", "serve"]


######################### ZEN CPU PYPI IMAGE #########################
# vLLM OpenAI server image with AMD Zen (zentorch) CPU optimizations,
# built on top of the default serving image.
FROM vllm-openai AS vllm-openai-zen

ARG TARGETARCH

# zentorch ships x86-64 wheels only; fail fast on any other architecture.
# Fix: the error message previously named a non-existent target
# ("vllm-openai-amd") instead of this stage, "vllm-openai-zen".
RUN if [ "$TARGETARCH" != "amd64" ]; then \
        echo "ERROR: vllm-openai-zen only supports --platform=linux/amd64"; \
        exit 1; \
    fi

# Pull in the zentorch extra declared in setup.py (vllm[zen]).
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install "vllm[zen]"

ENTRYPOINT ["vllm", "serve"]
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,8 @@ def _read_requirements(filename: str) -> list[str]:
ext_modules=ext_modules,
install_requires=get_requirements(),
extras_require={
# AMD Zen CPU optimizations via zentorch
"zen": ["zentorch"],
"bench": ["pandas", "matplotlib", "seaborn", "datasets", "scipy", "plotly"],
"tensorizer": ["tensorizer==2.10.1"],
"fastsafetensors": ["fastsafetensors >= 0.2.2"],
Expand Down
68 changes: 68 additions & 0 deletions tests/model_executor/test_cpu_unquantized_gemm_dispatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for CPU unquantized GEMM dispatch behavior."""

import pytest
import torch

from vllm.model_executor.layers import utils
from vllm.platforms import current_platform


@pytest.fixture(scope="module")
def _mock_zentorch_linear_unary():
    """Provide a stand-in ``zentorch_linear_unary`` op for dispatch tests.

    If a real zentorch build has already registered the op, the fixture is a
    no-op. Otherwise it registers a CPU implementation backed by
    ``torch.nn.functional.linear`` for the lifetime of the module and tears
    it down afterwards, so the dispatch tests can run in CI without zentorch.
    """
    if hasattr(torch.ops.zentorch, "zentorch_linear_unary"):
        # Real zentorch is installed: nothing to mock.
        yield
        return

    def_lib = torch.library.Library("zentorch", "DEF")
    def_lib.define(
        "zentorch_linear_unary("
        "Tensor input, "
        "Tensor weight, "
        "Tensor? bias, "
        "bool is_weight_prepacked=False"
        ") -> Tensor"
    )

    def _fallback(input, weight, bias, is_weight_prepacked=False):
        # Prepacking is irrelevant for the mock; a plain linear suffices.
        return torch.nn.functional.linear(input, weight, bias)

    impl_lib = torch.library.Library("zentorch", "IMPL", "CPU")
    impl_lib.impl("zentorch_linear_unary", _fallback)

    yield

    # Tear down in reverse registration order (impl before schema).
    impl_lib._destroy()
    def_lib._destroy()


@pytest.mark.usefixtures("_mock_zentorch_linear_unary")
def test_dispatch_cpu_unquantized_gemm_uses_zentorch_on_zen(monkeypatch):
    """On a Zen CPU the dispatched linear must match F.linear numerically."""
    monkeypatch.setattr(current_platform, "is_zen_cpu", lambda: True)

    linear = torch.nn.Linear(16, 8, bias=True)
    inputs = torch.randn(4, 16)
    reference = torch.nn.functional.linear(inputs, linear.weight, linear.bias)

    utils.dispatch_cpu_unquantized_gemm(linear, remove_weight=False)
    actual = linear.cpu_linear(inputs, linear.weight, linear.bias)

    torch.testing.assert_close(actual, reference)


@pytest.mark.usefixtures("_mock_zentorch_linear_unary")
def test_dispatch_cpu_unquantized_gemm_zen_remove_weight(monkeypatch):
    """remove_weight=True must leave the layer with an empty weight tensor."""
    monkeypatch.setattr(current_platform, "is_zen_cpu", lambda: True)

    linear = torch.nn.Linear(16, 8, bias=True)
    utils.dispatch_cpu_unquantized_gemm(linear, remove_weight=True)

    assert linear.weight.numel() == 0
37 changes: 37 additions & 0 deletions tests/test_zen_cpu_platform_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import mock_open, patch

from vllm.platforms import _is_amd_zen_cpu


def test_is_amd_zen_cpu_detects_amd_with_avx512():
    """An AMD vendor string combined with AVX-512 flags is detected as Zen."""
    fake_cpuinfo = "vendor_id: AuthenticAMD\nflags: avx avx2 avx512f avx512bw"
    with patch("os.path.exists", return_value=True), patch(
        "builtins.open", mock_open(read_data=fake_cpuinfo)
    ):
        assert _is_amd_zen_cpu()


def test_is_amd_zen_cpu_returns_false_for_amd_without_avx512():
    """AMD CPUs without AVX-512 flags must not be treated as Zen."""
    fake_cpuinfo = "vendor_id: AuthenticAMD\nflags: avx avx2"
    with patch("os.path.exists", return_value=True), patch(
        "builtins.open", mock_open(read_data=fake_cpuinfo)
    ):
        result = _is_amd_zen_cpu()
    assert not result


def test_is_amd_zen_cpu_returns_false_for_intel_with_avx512():
    """Non-AMD vendors must be rejected even when AVX-512 is present."""
    fake_cpuinfo = "vendor_id: GenuineIntel\nflags: avx avx2 avx512f"
    with patch("os.path.exists", return_value=True), patch(
        "builtins.open", mock_open(read_data=fake_cpuinfo)
    ):
        result = _is_amd_zen_cpu()
    assert not result


def test_is_amd_zen_cpu_returns_false_when_cpuinfo_missing():
    """Without /proc/cpuinfo (e.g. non-Linux hosts) detection returns False."""
    with patch("os.path.exists", return_value=False):
        result = _is_amd_zen_cpu()
    assert not result
7 changes: 7 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
VLLM_CPU_OMP_THREADS_BIND: str = "auto"
VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None
VLLM_CPU_SGL_KERNEL: bool = False
VLLM_ZENTORCH_WEIGHT_PREPACK: bool = True
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
VLLM_XLA_CHECK_RECOMPILATION: bool = False
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
Expand Down Expand Up @@ -709,6 +710,11 @@ def _get_or_set_default() -> str:
else None,
# (CPU backend only) whether to use SGL kernels, optimized for small batch.
"VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))),
# (Zen CPU backend) eagerly prepack weights into ZenDNN blocked layout
# at model load time. Eliminates per-inference layout conversion overhead.
"VLLM_ZENTORCH_WEIGHT_PREPACK": lambda: bool(
int(os.getenv("VLLM_ZENTORCH_WEIGHT_PREPACK", "1"))
),
Comment on lines +713 to +717
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're trying to keep the number of environment variables minimal in vLLM.

Why wouldn't someone want to pre-pack the weights? If there's no compelling reason, could we remove the env?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi @tlrmchlsmth since this environment variable is enabled by default it should be transparent for most users. Currently, this is mostly a debug feature for enabling other kernel variants from the zendnn library backend

# If the env var is set, Ray Compiled Graph uses the specified
# channel type to communicate between workers belonging to
# different pipeline-parallel stages.
Expand Down Expand Up @@ -1768,6 +1774,7 @@ def compile_factors() -> dict[str, object]:
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE",
"VLLM_CPU_KVCACHE_SPACE",
"VLLM_CPU_MOE_PREPACK",
"VLLM_ZENTORCH_WEIGHT_PREPACK",
"VLLM_TEST_FORCE_LOAD_FORMAT",
"VLLM_ENABLE_CUDA_COMPATIBILITY",
"VLLM_CUDA_COMPATIBILITY_PATH",
Expand Down
24 changes: 24 additions & 0 deletions vllm/model_executor/layers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,30 @@ def dispatch_cpu_unquantized_gemm(
N, K = layer.weight.size()
dtype = layer.weight.dtype

# Zen CPU path: zentorch_linear_unary with optional eager weight prepacking.
if current_platform.is_zen_cpu() and hasattr(
torch.ops.zentorch, "zentorch_linear_unary"
):
zen_weight = layer.weight.detach()
is_prepacked = False

if envs.VLLM_ZENTORCH_WEIGHT_PREPACK and hasattr(
torch.ops.zentorch, "zentorch_weight_prepack_for_linear"
):
zen_weight = torch.ops.zentorch.zentorch_weight_prepack_for_linear(
zen_weight
)
is_prepacked = True

layer.cpu_linear = lambda x, weight, bias, _p=is_prepacked: (
torch.ops.zentorch.zentorch_linear_unary(
x, zen_weight, bias, is_weight_prepacked=_p
)
)
if remove_weight:
layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
return

if envs.VLLM_CPU_SGL_KERNEL and check_cpu_sgl_kernel(N, K, dtype):
packed_weight = torch.ops._C.convert_weight_packed(layer.weight)
if getattr(layer, "bias", None) is not None:
Expand Down
38 changes: 36 additions & 2 deletions vllm/platforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import os
import traceback
from itertools import chain
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -150,6 +151,15 @@ def xpu_platform_plugin() -> str | None:
return "vllm.platforms.xpu.XPUPlatform" if is_xpu else None


def _is_amd_zen_cpu() -> bool:
    """Detect an AMD CPU with AVX-512 support by inspecting /proc/cpuinfo.

    Returns False on systems without /proc/cpuinfo (e.g. non-Linux hosts).
    The "avx512" substring matches any AVX-512 feature flag (avx512f,
    avx512bw, ...), which only AMD Zen 4 and newer parts report.
    """
    cpuinfo_path = "/proc/cpuinfo"
    if not os.path.exists(cpuinfo_path):
        return False
    with open(cpuinfo_path) as f:
        contents = f.read()
    is_amd = "AuthenticAMD" in contents
    has_avx512 = "avx512" in contents
    return is_amd and has_avx512


def cpu_platform_plugin() -> str | None:
is_cpu = False
logger.debug("Checking if CPU platform is available.")
Expand All @@ -171,7 +181,24 @@ def cpu_platform_plugin() -> str | None:
except Exception as e:
logger.debug("CPU platform is not available because: %s", str(e))

return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None
if not is_cpu:
return None

if _is_amd_zen_cpu():
try:
import zentorch # noqa: F401

logger.debug(
"AMD Zen CPU detected with zentorch installed, using ZenCpuPlatform."
)
return "vllm.platforms.zen_cpu.ZenCpuPlatform"
except ImportError:
logger.debug(
"AMD Zen CPU detected but zentorch not installed, "
"falling back to CpuPlatform."
)

return "vllm.platforms.cpu.CpuPlatform"


builtin_platform_plugins = {
Expand Down Expand Up @@ -269,4 +296,11 @@ def __setattr__(name: str, value):
raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")


__all__ = ["Platform", "PlatformEnum", "current_platform", "CpuArchEnum", "_init_trace"]
# Public API of ``vllm.platforms``.
# NOTE(review): underscore-prefixed names (_init_trace, _is_amd_zen_cpu) are
# conventionally private; exporting them via __all__ is unusual — confirm the
# intent (tests import _is_amd_zen_cpu directly and do not require this).
__all__ = [
    "Platform",
    "PlatformEnum",
    "current_platform",
    "CpuArchEnum",
    "_init_trace",
    "_is_amd_zen_cpu",  # exposed for Zen CPU detection tests
]
3 changes: 3 additions & 0 deletions vllm/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ def is_xpu(self) -> bool:
def is_cpu(self) -> bool:
return self._enum == PlatformEnum.CPU

def is_zen_cpu(self) -> bool:
    """Return True when this platform applies AMD Zen CPU optimizations.

    Base default is False; ZenCpuPlatform overrides this to return True.
    """
    return False

def is_out_of_tree(self) -> bool:
return self._enum == PlatformEnum.OOT

Expand Down
67 changes: 67 additions & 0 deletions vllm/platforms/zen_cpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TYPE_CHECKING

from vllm.logger import init_logger
from vllm.platforms.cpu import CpuPlatform
from vllm.utils.torch_utils import is_torch_equal_or_newer

logger = init_logger(__name__)

if TYPE_CHECKING:
from vllm.config import VllmConfig


class ZenCpuPlatform(CpuPlatform):
    """CPU platform with AMD Zen (ZenDNN/zentorch) optimizations.

    Selected by ``cpu_platform_plugin`` when an AMD CPU with AVX-512 is
    detected and the ``zentorch`` package is importable.

    Model-load time (dispatch_cpu_unquantized_gemm in layers/utils.py):
    - Routes linear ops to zentorch_linear_unary.
    - When VLLM_ZENTORCH_WEIGHT_PREPACK=1 (default), eagerly prepacks
      weights via zentorch_weight_prepack_for_linear.
    """

    device_name: str = "cpu"
    device_type: str = "cpu"

    def is_zen_cpu(self) -> bool:
        """Return True: Zen-specific kernels apply on this platform."""
        # is_cpu() also returns True for this platform (inherited from CpuPlatform).
        return True

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Run the standard CPU config updates, then Zen-specific backports."""
        super().check_and_update_config(vllm_config)
        cls._apply_pytorch_backports()

    @classmethod
    def _apply_pytorch_backports(cls):
        """Backport PyTorch mainline fixes missing in 2.10.

        PyTorch 2.10 has a bug in FxGraphCachePickler.dumps that doesn't
        catch ValueError, causing torch.compile cache misses. Remove this
        once we drop PyTorch 2.10 support. PT mainline already has this fix.
        """
        # Only the 2.10.x series needs the patch: skip for <2.10 and for
        # >=2.11, where the upstream fix is expected to be present.
        if not is_torch_equal_or_newer("2.10.0") or is_torch_equal_or_newer("2.11.0"):
            return

        cls._patch_fxgraphcache_pickle()

    @classmethod
    def _patch_fxgraphcache_pickle(cls):
        """Backport mainline ValueError fix to FxGraphCachePickler.dumps()."""
        from torch._inductor.codecache import BypassFxGraphCache, FxGraphCachePickler

        original_dumps = FxGraphCachePickler.dumps
        # Idempotence guard: a previously applied patch carries _zen_patched,
        # so repeated calls (e.g. multiple config checks) patch only once.
        if hasattr(original_dumps, "_zen_patched"):
            return

        def patched_dumps(self, obj):
            # Delegate to the original dumps, converting pickling ValueErrors
            # into a cache bypass instead of letting them propagate.
            try:
                return original_dumps(self, obj)
            except ValueError as e:
                raise BypassFxGraphCache("Failed to pickle cache key") from e

        patched_dumps._zen_patched = True  # type: ignore[attr-defined]
        # NOTE(review): assumes FxGraphCachePickler.dumps is a plain function
        # with a (self, obj) signature, not a classmethod/staticmethod —
        # confirm against the targeted torch 2.10 implementation.
        FxGraphCachePickler.dumps = patched_dumps
        logger.info("[zen_cpu] Patched FxGraphCachePickler.dumps (ValueError fix)")
Loading