-
-
Notifications
You must be signed in to change notification settings - Fork 14.8k
In-Tree AMD Zen CPU Backend via zentorch [1/N] #35970
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
tlrmchlsmth
merged 11 commits into
vllm-project:main
from
amd-lalithnc:zentorch-upstreaming
Mar 15, 2026
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
a23c37c
zentorch platform upstreaming
amd-lalithnc ad93740
Address Review Comments
amd-lalithnc f207ab7
fix pre-commit failures
amd-lalithnc f6c2a80
Address review comments
amd-lalithnc d73c2b1
Address review comments
amd-lalithnc 65e8a56
Fix Linter Issues
amd-lalithnc 9b77dfc
support pypi based dockerfile build with zentorch
amd-lalithnc 43d1664
Merge branch 'main' into zentorch-upstreaming
tlrmchlsmth b3a1785
Address review comments in zen_cpu.py
tlrmchlsmth 60f0c74
Merge branch 'main' into zentorch-upstreaming
tlrmchlsmth 04b4684
Rename vllm-openai-amd Docker target to vllm-openai-zen
tlrmchlsmth File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
68 changes: 68 additions & 0 deletions
68
tests/model_executor/test_cpu_unquantized_gemm_dispatch.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,68 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """Tests for CPU unquantized GEMM dispatch behavior.""" | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from vllm.model_executor.layers import utils | ||
| from vllm.platforms import current_platform | ||
|
|
||
|
|
||
@pytest.fixture(scope="module")
def _mock_zentorch_linear_unary():
    """Provide a stand-in ``zentorch_linear_unary`` op when zentorch is absent.

    Lets the dispatch tests execute in CI without a real zentorch build.
    If the op is already registered (real zentorch present), registration
    is skipped entirely and nothing is torn down afterwards.
    """
    if hasattr(torch.ops.zentorch, "zentorch_linear_unary"):
        # The real (or a previously registered) op exists -- nothing to do.
        yield
        return

    def _fallback_linear(input, weight, bias, is_weight_prepacked=False):
        # Plain eager linear stands in for the zentorch kernel; the extra
        # prepack flag is accepted to match the real op's schema.
        return torch.nn.functional.linear(input, weight, bias)

    schema = (
        "zentorch_linear_unary("
        "Tensor input, "
        "Tensor weight, "
        "Tensor? bias, "
        "bool is_weight_prepacked=False"
        ") -> Tensor"
    )
    lib_def = torch.library.Library("zentorch", "DEF")
    lib_def.define(schema)

    lib_impl = torch.library.Library("zentorch", "IMPL", "CPU")
    lib_impl.impl("zentorch_linear_unary", _fallback_linear)

    yield

    # Tear down in reverse order of registration.
    lib_impl._destroy()
    lib_def._destroy()
|
|
||
|
|
||
@pytest.mark.usefixtures("_mock_zentorch_linear_unary")
def test_dispatch_cpu_unquantized_gemm_uses_zentorch_on_zen(monkeypatch):
    """On a Zen CPU the dispatched GEMM must match eager F.linear output."""
    # Force the platform to report Zen so the zentorch path is selected.
    monkeypatch.setattr(current_platform, "is_zen_cpu", lambda: True)

    layer = torch.nn.Linear(16, 8, bias=True)
    inputs = torch.randn(4, 16)
    reference = torch.nn.functional.linear(inputs, layer.weight, layer.bias)

    utils.dispatch_cpu_unquantized_gemm(layer, remove_weight=False)
    actual = layer.cpu_linear(inputs, layer.weight, layer.bias)

    torch.testing.assert_close(actual, reference)
|
|
||
|
|
||
@pytest.mark.usefixtures("_mock_zentorch_linear_unary")
def test_dispatch_cpu_unquantized_gemm_zen_remove_weight(monkeypatch):
    """With remove_weight=True the layer's weight tensor must be emptied."""
    monkeypatch.setattr(current_platform, "is_zen_cpu", lambda: True)

    linear = torch.nn.Linear(16, 8, bias=True)
    utils.dispatch_cpu_unquantized_gemm(linear, remove_weight=True)

    # Dispatch is expected to release the original weight storage.
    assert linear.weight.numel() == 0
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| from unittest.mock import mock_open, patch | ||
|
|
||
| from vllm.platforms import _is_amd_zen_cpu | ||
|
|
||
|
|
||
def test_is_amd_zen_cpu_detects_amd_with_avx512():
    """An AMD vendor string combined with AVX-512 flags is classified as Zen."""
    fake_cpuinfo = "vendor_id: AuthenticAMD\nflags: avx avx2 avx512f avx512bw"
    with patch("os.path.exists", return_value=True):
        with patch("builtins.open", mock_open(read_data=fake_cpuinfo)):
            assert _is_amd_zen_cpu()
|
|
||
|
|
||
def test_is_amd_zen_cpu_returns_false_for_amd_without_avx512():
    """AMD hardware lacking AVX-512 support must not be classified as Zen."""
    fake_cpuinfo = "vendor_id: AuthenticAMD\nflags: avx avx2"
    with patch("os.path.exists", return_value=True):
        with patch("builtins.open", mock_open(read_data=fake_cpuinfo)):
            assert not _is_amd_zen_cpu()
|
|
||
|
|
||
def test_is_amd_zen_cpu_returns_false_for_intel_with_avx512():
    """A non-AMD vendor must not be classified as Zen even with AVX-512."""
    fake_cpuinfo = "vendor_id: GenuineIntel\nflags: avx avx2 avx512f"
    with patch("os.path.exists", return_value=True):
        with patch("builtins.open", mock_open(read_data=fake_cpuinfo)):
            assert not _is_amd_zen_cpu()
|
|
||
|
|
||
def test_is_amd_zen_cpu_returns_false_when_cpuinfo_missing():
    """Detection must return False when the cpuinfo file does not exist."""
    with patch("os.path.exists", return_value=False):
        assert not _is_amd_zen_cpu()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| from vllm.logger import init_logger | ||
| from vllm.platforms.cpu import CpuPlatform | ||
| from vllm.utils.torch_utils import is_torch_equal_or_newer | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
| if TYPE_CHECKING: | ||
| from vllm.config import VllmConfig | ||
|
|
||
|
|
||
class ZenCpuPlatform(CpuPlatform):
    """CPU platform with AMD Zen (ZenDNN/zentorch) optimizations.

    Model-load time (dispatch_cpu_unquantized_gemm in layers/utils.py):
    - Routes linear ops to zentorch_linear_unary.
    - When VLLM_ZENTORCH_WEIGHT_PREPACK=1 (default), eagerly prepacks
      weights via zentorch_weight_prepack_for_linear.
    """

    # Reported device identity is plain "cpu"; Zen-specific behavior is
    # selected via is_zen_cpu(), not via a distinct device type.
    device_name: str = "cpu"
    device_type: str = "cpu"

    def is_zen_cpu(self) -> bool:
        """Return True: this platform is an AMD Zen CPU.

        Callers use this to pick zentorch-specific code paths.
        """
        # is_cpu() also returns True for this platform (inherited from CpuPlatform).
        return True

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Run the base CPU config checks, then apply PyTorch backports."""
        super().check_and_update_config(vllm_config)
        cls._apply_pytorch_backports()

    @classmethod
    def _apply_pytorch_backports(cls):
        """Backport PyTorch mainline fixes missing in 2.10.

        PyTorch 2.10 has a bug in FxGraphCachePickler.dumps that doesn't
        catch ValueError, causing torch.compile cache misses. Remove this
        once we drop PyTorch 2.10 support. PT mainline already has this fix.
        """
        # Apply only on the 2.10.x series: skip for < 2.10.0 (bug absent
        # in the targeted form) and for >= 2.11.0 (fixed upstream).
        if not is_torch_equal_or_newer("2.10.0") or is_torch_equal_or_newer("2.11.0"):
            return

        cls._patch_fxgraphcache_pickle()

    @classmethod
    def _patch_fxgraphcache_pickle(cls):
        """Backport mainline ValueError fix to FxGraphCachePickler.dumps()."""
        # Deferred import: torch._inductor is a private module and only
        # needed when the patch actually applies.
        from torch._inductor.codecache import BypassFxGraphCache, FxGraphCachePickler

        original_dumps = FxGraphCachePickler.dumps
        # Idempotence guard: the sentinel attribute marks an already-patched
        # dumps, so repeated calls don't stack wrappers.
        if hasattr(original_dumps, "_zen_patched"):
            return

        def patched_dumps(self, obj):
            # Translate ValueError into BypassFxGraphCache so the inductor
            # cache machinery degrades to a cache bypass instead of failing.
            try:
                return original_dumps(self, obj)
            except ValueError as e:
                raise BypassFxGraphCache("Failed to pickle cache key") from e

        patched_dumps._zen_patched = True  # type: ignore[attr-defined]
        # NOTE(review): this replaces the attribute with a plain function;
        # presumably dumps is an ordinary method in torch 2.10 -- confirm the
        # original is not a classmethod/staticmethod before relying on this.
        FxGraphCachePickler.dumps = patched_dumps
        logger.info("[zen_cpu] Patched FxGraphCachePickler.dumps (ValueError fix)")
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We're trying to keep the number of environment variables minimal in vLLM.
Why wouldn't someone want to pre-pack the weights? If there's no compelling reason, could we remove the env?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hi @tlrmchlsmth since this environment variable is enabled by default it should be transparent for most users. Currently, this is mostly a debug feature for enabling other kernel variants from the zendnn library backend