Commit 6c05536

NickLucche authored and FeiDaLI committed
[XPU] Set consistent default KV cache layout (vllm-project#24745)
Signed-off-by: NickLucche <[email protected]>
1 parent d87f284 commit 6c05536

File tree

3 files changed (+23, -16 lines)


vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 9 additions & 6 deletions

@@ -56,9 +56,9 @@
     logger.warning("NIXL is not available")
     NixlWrapper = None

-# Supported xPUs and types of kv transfer buffer.
-# {xPU: tuple of supported kv buffer types}
-_NIXL_SUPPORTED_XPUS = {
+# Supported platforms and types of kv transfer buffer.
+# {device: tuple of supported kv buffer types}
+_NIXL_SUPPORTED_DEVICE = {
     "cuda": ("cuda", ),
     "tpu": ("cpu", ),
     "xpu": ("cpu", ),

@@ -458,17 +458,17 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.device_type = current_platform.device_type
         self.kv_buffer_device: str = \
             vllm_config.kv_transfer_config.kv_buffer_device
-        if self.device_type not in _NIXL_SUPPORTED_XPUS:
+        if self.device_type not in _NIXL_SUPPORTED_DEVICE:
             raise RuntimeError(f"{self.device_type} is not supported.")
-        elif self.kv_buffer_device not in _NIXL_SUPPORTED_XPUS[
+        elif self.kv_buffer_device not in _NIXL_SUPPORTED_DEVICE[
                 self.device_type]:
             raise RuntimeError(
                 f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
                 "is not supported.")
         self.device_kv_caches: dict[str, torch.Tensor] = {}

         # cpu kv buffer for xfer
-        # used when xPU memory can not be registered under nixl
+        # used when device memory can not be registered under nixl
         self.host_xfer_buffers: dict[str, torch.Tensor] = {}
         self.use_host_buffer = self.kv_buffer_device == "cpu"
         if self.kv_buffer_device == "cuda":

@@ -927,6 +927,9 @@ def add_remote_agent(self,
         if tp_ratio > 1:
             # Heterogeneous TP expects same kv_cache_layout.
             assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
+            if self.device_type == "xpu":
+                raise ValueError(
+                    "Heterogeneous TP is not supported on XPU")

         assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
             "Remote P worker KV layer cache must be of shape [2, N, "

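For reference, the renamed _NIXL_SUPPORTED_DEVICE mapping drives the same validation as before in NixlConnector.__init__. Below is a minimal standalone sketch of that check; the helper name check_kv_buffer is illustrative only and is not part of this commit.

# Minimal standalone sketch of the device / kv-buffer validation that
# _NIXL_SUPPORTED_DEVICE drives in the diff above. check_kv_buffer is a
# hypothetical helper used here just to show the rule.
_NIXL_SUPPORTED_DEVICE = {
    "cuda": ("cuda", ),
    "tpu": ("cpu", ),
    "xpu": ("cpu", ),
}


def check_kv_buffer(device_type: str, kv_buffer_device: str) -> None:
    """Reject unsupported device / kv-buffer combinations."""
    if device_type not in _NIXL_SUPPORTED_DEVICE:
        raise RuntimeError(f"{device_type} is not supported.")
    if kv_buffer_device not in _NIXL_SUPPORTED_DEVICE[device_type]:
        raise RuntimeError(
            f"{device_type} with {kv_buffer_device} kv_buffer is not supported.")


check_kv_buffer("xpu", "cpu")       # OK: XPU stages transfers through a host buffer
try:
    check_kv_buffer("xpu", "cuda")
except RuntimeError as e:
    print(e)                        # xpu with cuda kv_buffer is not supported.

Because "cpu" is the only buffer type listed for XPU, the connector routes XPU transfers through host_xfer_buffers (use_host_buffer is True when kv_buffer_device == "cpu").
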
vllm/platforms/xpu.py

Lines changed: 4 additions & 6 deletions

@@ -9,6 +9,7 @@
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
+from vllm.v1.attention.backends.utils import set_kv_cache_layout

 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

@@ -164,12 +165,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 vllm_config.scheduler_config.max_model_len,
                 DEFAULT_MAX_NUM_BATCHED_TOKENS)

-        if (envs.VLLM_KV_CACHE_LAYOUT is None
-                or envs.VLLM_KV_CACHE_LAYOUT != "NHD"):
-            os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD"
-            logger.info(
-                "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
-                "only NHD layout is supported by XPU attention kernels.")
+        set_kv_cache_layout("NHD")
+        logger.info("Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
+                    "only NHD layout is supported by XPU attention kernels.")

     @classmethod
     def is_pin_memory_available(cls):
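For context, the XPU hook now pins the layout through the shared helper instead of mutating os.environ. A hedged usage sketch follows; it assumes a working vLLM installation, while the import path and call are taken verbatim from the diff.

# Sketch: pin the KV cache layout programmatically, as the XPU platform
# hook above now does. Precedence relative to the VLLM_KV_CACHE_LAYOUT
# environment variable is handled inside the utils module, not shown here.
from vllm.v1.attention.backends.utils import set_kv_cache_layout

set_kv_cache_layout("NHD")  # records an in-process override; no env-var mutation

Routing the default through set_kv_cache_layout keeps the XPU setting on the same code path other readers of the layout already use, which is what makes the default consistent per the commit title.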

vllm/v1/attention/backends/utils.py

Lines changed: 10 additions & 4 deletions

@@ -5,8 +5,8 @@
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass, fields, make_dataclass
-from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Optional, Protocol,
-                    TypeVar)
+from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Literal, Optional,
+                    Protocol, TypeVar, Union, get_args)

 import numpy as np
 import torch

@@ -30,7 +30,12 @@
     from vllm.v1.kv_cache_interface import AttentionSpec

 logger = init_logger(__name__)
-_KV_CACHE_LAYOUT_OVERRIDE = None
+KVCacheLayoutType = Literal["NHD", "HND"]
+_KV_CACHE_LAYOUT_OVERRIDE: Union[KVCacheLayoutType, None] = None
+
+
+def is_valid_kv_cache_layout(value: str) -> bool:
+    return value in get_args(KVCacheLayoutType)


 @dataclass

@@ -296,12 +301,13 @@ def get_kv_cache_layout():
     if cache_layout is None:
         cache_layout = get_kv_connector_cache_layout()
     else:
+        assert is_valid_kv_cache_layout(cache_layout)
         logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
                          "detected. Setting KV cache layout to %s.", cache_layout)
     return cache_layout


-def set_kv_cache_layout(cache_layout: str):
+def set_kv_cache_layout(cache_layout: KVCacheLayoutType):
     global _KV_CACHE_LAYOUT_OVERRIDE
     _KV_CACHE_LAYOUT_OVERRIDE = cache_layout
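The new KVCacheLayoutType alias narrows accepted layouts to the two vLLM supports, and is_valid_kv_cache_layout checks a string against it at runtime. A self-contained sketch of the pattern, using only the standard library and names that mirror the diff:

# Self-contained sketch of the Literal + get_args() validation pattern
# introduced above; only the typing module is needed.
from typing import Literal, get_args

KVCacheLayoutType = Literal["NHD", "HND"]


def is_valid_kv_cache_layout(value: str) -> bool:
    # get_args(Literal["NHD", "HND"]) returns the tuple ("NHD", "HND")
    return value in get_args(KVCacheLayoutType)


assert is_valid_kv_cache_layout("NHD")
assert is_valid_kv_cache_layout("HND")
assert not is_valid_kv_cache_layout("DHN")   # misspelled layouts now fail fast

Combined with the assert added in get_kv_cache_layout, an invalid VLLM_KV_CACHE_LAYOUT value now fails early instead of propagating an unrecognized layout string.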
