Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions tests/v1/core/test_kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import importlib
import math
from collections.abc import Callable
from typing import Any

Expand Down Expand Up @@ -44,6 +45,7 @@
KVCacheGroupSpec,
KVCacheSpec,
KVCacheTensor,
KVQuantMode,
MambaSpec,
MLAAttentionSpec,
SlidingWindowSpec,
Expand Down Expand Up @@ -2189,6 +2191,44 @@ def test_unify_hybrid_kv_cache_specs():
kv_cache_utils.unify_hybrid_kv_cache_specs(kv_cache_spec)


def test_unify_kv_cache_spec_page_size_uses_common_multiple_for_int8_hybrid():
    """Unifying page sizes of an int8-quantized hybrid model uses the LCM.

    With INT8_PER_TOKEN_HEAD quantization the two layers' page sizes are not
    multiples of one another, so the unified page size must be the least
    common multiple of the per-layer page sizes, and each layer's block_size
    must be scaled up by the corresponding ratio.
    """
    specs = {
        "full": FullAttentionSpec(
            block_size=16,
            num_kv_heads=2,
            head_size=512,
            head_size_v=512,
            dtype=torch.float16,
            kv_quant_mode=KVQuantMode.INT8_PER_TOKEN_HEAD,
        ),
        "sliding": SlidingWindowSpec(
            block_size=16,
            num_kv_heads=8,
            head_size=256,
            dtype=torch.float16,
            kv_quant_mode=KVQuantMode.INT8_PER_TOKEN_HEAD,
            sliding_window=1024,
        ),
    }

    # Record per-layer page sizes before unification so the expected values
    # can be derived independently of the function under test.
    page_size_before = {layer: s.page_size_bytes for layer, s in specs.items()}
    lcm_page_size = math.lcm(*page_size_before.values())

    unified = kv_cache_utils.unify_kv_cache_spec_page_size(specs)

    # Both layers must agree on the LCM page size, with block_size scaled
    # from the original 16 by the per-layer page-size ratio.
    for layer in ("full", "sliding"):
        assert unified[layer].page_size_bytes == lcm_page_size
        assert unified[layer].block_size == (
            16 * lcm_page_size // page_size_before[layer]
        )
    # The sliding-window spec must keep its type and window after rescaling.
    assert isinstance(unified["sliding"], SlidingWindowSpec)
    assert unified["sliding"].sliding_window == 1024


def test_hma_not_disabled_when_kv_events_enabled():
"""
Test enabling KV events must not force disable_hybrid_kv_cache_manager to True.
Expand Down
23 changes: 15 additions & 8 deletions vllm/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import os
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass, replace
from dataclasses import dataclass
from functools import partial
from typing import Any, NewType, TypeAlias, cast, overload

Expand Down Expand Up @@ -1024,22 +1024,28 @@ def unify_kv_cache_spec_page_size(
# All layers have the same page size, no need to unify.
return kv_cache_spec

max_page_size = max(page_sizes)
unified_page_size = max(page_sizes)
if any(unified_page_size % page_size != 0 for page_size in page_sizes):
unified_page_size = math.lcm(*page_sizes)

new_kv_cache_spec = {}
for layer_name, layer_spec in kv_cache_spec.items():
if layer_spec.page_size_bytes == max_page_size:
if layer_spec.page_size_bytes == unified_page_size:
new_kv_cache_spec[layer_name] = layer_spec
else:
layer_page_size = layer_spec.page_size_bytes
if max_page_size % layer_page_size != 0:
if unified_page_size % layer_page_size != 0:
raise NotImplementedError(
"The page size of the layer is not divisible by the "
"maximum page size. Cannot unify by adjusting block_size."
"unified page size. Cannot unify by adjusting block_size."
)
ratio = max_page_size // layer_page_size
ratio = unified_page_size // layer_page_size
new_block_size = layer_spec.block_size * ratio
new_spec = replace(layer_spec, block_size=new_block_size)
assert new_spec.page_size_bytes == max_page_size
new_spec = layer_spec.copy_with_new_block_size(new_block_size)
if new_spec.page_size_bytes != unified_page_size:
raise NotImplementedError(
"Failed to unify KV cache page size after adjusting block_size."
)
new_kv_cache_spec[layer_name] = new_spec
return new_kv_cache_spec

Expand Down Expand Up @@ -1398,6 +1404,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size,
dtype=spec.dtype,
kv_quant_mode=spec.kv_quant_mode,
attention_chunk_size=spec.attention_chunk_size,
page_size_padded=spec.page_size_padded,
)
Expand Down
8 changes: 5 additions & 3 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
from concurrent.futures import Future
from contextlib import ExitStack, contextmanager
from enum import IntEnum
from functools import partial
from functools import partial, reduce
from inspect import isclass, signature
from logging import DEBUG
from math import gcd
from multiprocessing.queues import Queue
from typing import Any, TypeVar, cast

Expand Down Expand Up @@ -273,8 +274,9 @@ def _initialize_kv_caches(self, vllm_config: VllmConfig) -> KVCacheConfig:
vllm_config.cache_config.num_gpu_blocks = scheduler_kv_cache_config.num_blocks
kv_cache_groups = scheduler_kv_cache_config.kv_cache_groups
if kv_cache_groups:
vllm_config.cache_config.block_size = min(
g.kv_cache_spec.block_size for g in kv_cache_groups
vllm_config.cache_config.block_size = reduce(
gcd,
(group.kv_cache_spec.block_size for group in kv_cache_groups),
)

vllm_config.validate_block_size()
Expand Down
Loading