diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 985b97c69ca4..d1e45e7dae55 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib import importlib +import math from collections.abc import Callable from typing import Any @@ -44,6 +45,7 @@ KVCacheGroupSpec, KVCacheSpec, KVCacheTensor, + KVQuantMode, MambaSpec, MLAAttentionSpec, SlidingWindowSpec, @@ -2189,6 +2191,44 @@ def test_unify_hybrid_kv_cache_specs(): kv_cache_utils.unify_hybrid_kv_cache_specs(kv_cache_spec) +def test_unify_kv_cache_spec_page_size_uses_common_multiple_for_int8_hybrid(): + kv_cache_spec = { + "full": FullAttentionSpec( + block_size=16, + num_kv_heads=2, + head_size=512, + head_size_v=512, + dtype=torch.float16, + kv_quant_mode=KVQuantMode.INT8_PER_TOKEN_HEAD, + ), + "sliding": SlidingWindowSpec( + block_size=16, + num_kv_heads=8, + head_size=256, + dtype=torch.float16, + kv_quant_mode=KVQuantMode.INT8_PER_TOKEN_HEAD, + sliding_window=1024, + ), + } + + original_page_sizes = { + name: spec.page_size_bytes for name, spec in kv_cache_spec.items() + } + unified = kv_cache_utils.unify_kv_cache_spec_page_size(kv_cache_spec) + expected_page_size = math.lcm(*original_page_sizes.values()) + + assert unified["full"].page_size_bytes == unified["sliding"].page_size_bytes + assert unified["full"].page_size_bytes == expected_page_size + assert unified["full"].block_size == ( + 16 * expected_page_size // original_page_sizes["full"] + ) + assert unified["sliding"].block_size == ( + 16 * expected_page_size // original_page_sizes["sliding"] + ) + assert isinstance(unified["sliding"], SlidingWindowSpec) + assert unified["sliding"].sliding_window == 1024 + + def test_hma_not_disabled_when_kv_events_enabled(): """ Test enabling KV events must not force disable_hybrid_kv_cache_manager to True. diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index b57e10b67faa..68475cec47e8 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -8,7 +8,7 @@ import os from collections import defaultdict from collections.abc import Callable, Iterable, Iterator, Sequence -from dataclasses import dataclass, replace +from dataclasses import dataclass from functools import partial from typing import Any, NewType, TypeAlias, cast, overload @@ -1024,22 +1024,28 @@ def unify_kv_cache_spec_page_size( # All layers have the same page size, no need to unify. return kv_cache_spec - max_page_size = max(page_sizes) + unified_page_size = max(page_sizes) + if any(unified_page_size % page_size != 0 for page_size in page_sizes): + unified_page_size = math.lcm(*page_sizes) + new_kv_cache_spec = {} for layer_name, layer_spec in kv_cache_spec.items(): - if layer_spec.page_size_bytes == max_page_size: + if layer_spec.page_size_bytes == unified_page_size: new_kv_cache_spec[layer_name] = layer_spec else: layer_page_size = layer_spec.page_size_bytes - if max_page_size % layer_page_size != 0: + if unified_page_size % layer_page_size != 0: raise NotImplementedError( "The page size of the layer is not divisible by the " - "maximum page size. Cannot unify by adjusting block_size." + "unified page size. Cannot unify by adjusting block_size." ) - ratio = max_page_size // layer_page_size + ratio = unified_page_size // layer_page_size new_block_size = layer_spec.block_size * ratio - new_spec = replace(layer_spec, block_size=new_block_size) - assert new_spec.page_size_bytes == max_page_size + new_spec = layer_spec.copy_with_new_block_size(new_block_size) + if new_spec.page_size_bytes != unified_page_size: + raise NotImplementedError( + "Failed to unify KV cache page size after adjusting block_size." + ) new_kv_cache_spec[layer_name] = new_spec return new_kv_cache_spec @@ -1398,6 +1404,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): num_kv_heads=spec.num_kv_heads, head_size=spec.head_size, dtype=spec.dtype, + kv_quant_mode=spec.kv_quant_mode, attention_chunk_size=spec.attention_chunk_size, page_size_padded=spec.page_size_padded, ) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 11c5ee19a664..491dfbea610e 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -11,9 +11,10 @@ from concurrent.futures import Future from contextlib import ExitStack, contextmanager from enum import IntEnum -from functools import partial +from functools import partial, reduce from inspect import isclass, signature from logging import DEBUG +from math import gcd from multiprocessing.queues import Queue from typing import Any, TypeVar, cast @@ -273,8 +274,9 @@ def _initialize_kv_caches(self, vllm_config: VllmConfig) -> KVCacheConfig: vllm_config.cache_config.num_gpu_blocks = scheduler_kv_cache_config.num_blocks kv_cache_groups = scheduler_kv_cache_config.kv_cache_groups if kv_cache_groups: - vllm_config.cache_config.block_size = min( - g.kv_cache_spec.block_size for g in kv_cache_groups + vllm_config.cache_config.block_size = reduce( + gcd, + (group.kv_cache_spec.block_size for group in kv_cache_groups), ) vllm_config.validate_block_size()