diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 985b97c69ca4..f49d30c5e7e0 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -47,6 +47,8 @@ MambaSpec, MLAAttentionSpec, SlidingWindowSpec, + TQFullAttentionSpec, + TQSlidingWindowSpec, UniformTypeKVCacheSpecs, ) from vllm.v1.metrics.stats import CachingMetrics, PrefixCacheStats @@ -141,6 +143,26 @@ def new_sliding_window_spec( ) +def new_tq_sliding_window_spec( + block_size=16, + num_kv_heads=2, + head_size=64, + dtype=torch.float32, + page_size_padded=None, + sliding_window=1, + tq_slot_size=80, +): + return TQSlidingWindowSpec( + block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + page_size_padded=page_size_padded, + sliding_window=sliding_window, + tq_slot_size=tq_slot_size, + ) + + def new_chunked_local_attention_spec( block_size=16, num_kv_heads=2, @@ -2214,3 +2236,37 @@ def test_hma_not_disabled_when_kv_events_enabled(): assert vllm_config.scheduler_config.disable_hybrid_kv_cache_manager is False, ( "kv_events_config must not force-disable the hybrid KV cache manager." ) + + +def test_unify_hybrid_kv_cache_specs_preserves_tq_page_size(): + before_spec_1 = new_kv_cache_spec() + before_spec_2 = new_tq_sliding_window_spec( + page_size_padded=32 * 1024, + sliding_window=1024, + tq_slot_size=80, + ) + kv_cache_spec = { + "layer_1": before_spec_1, + "layer_2": before_spec_2, + } + + kv_cache_utils.unify_hybrid_kv_cache_specs(kv_cache_spec) + + expected_spec_2 = TQFullAttentionSpec( + block_size=before_spec_2.block_size, + num_kv_heads=before_spec_2.num_kv_heads, + head_size=before_spec_2.head_size, + head_size_v=before_spec_2.head_size_v, + dtype=before_spec_2.dtype, + kv_quant_mode=before_spec_2.kv_quant_mode, + sliding_window=before_spec_2.sliding_window, + page_size_padded=before_spec_2.page_size_padded, + tq_slot_size=before_spec_2.tq_slot_size, + ) + assert kv_cache_spec["layer_1"] == before_spec_1 + assert kv_cache_spec["layer_2"] == expected_spec_2 + assert kv_cache_spec["layer_2"].real_page_size_bytes == ( + before_spec_2.block_size + * before_spec_2.num_kv_heads + * before_spec_2.tq_slot_size + ) diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index f59830dcd741..16b3db0171fc 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -481,3 +481,33 @@ def test_predictor_matches_allocator_blocks_calculation_with_admission_cap(): f"but allocator pulled {len(new_blocks)}" ) total_computed = num_tokens + + +def test_tq_sliding_window_uses_sliding_window_manager(): + from vllm.v1 import kv_cache_interface + from vllm.v1.core import single_type_kv_cache_manager as manager_utils + + spec = kv_cache_interface.TQSlidingWindowSpec( + block_size=2, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=4, + tq_slot_size=1, + ) + block_pool = BlockPool( + num_gpu_blocks=10, + enable_caching=False, + hash_block_size=spec.block_size, + ) + + manager = manager_utils.get_manager_for_kv_cache_spec( + spec, + max_num_batched_tokens=4, + max_model_len=16, + block_pool=block_pool, + enable_caching=False, + kv_cache_group_id=0, + ) + + assert isinstance(manager, SlidingWindowManager) diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index db9ae2bbda34..b3c92a1701f9 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -541,28 +541,32 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: # Should not be called for enc-dec or encoder-only attention. assert self.attn_type == AttentionType.DECODER quant_mode = get_kv_quant_mode(self.kv_cache_dtype) - if self.sliding_window is not None: - assert not vllm_config.model_config.use_mla, ( - "MLA is not supported for slidingwindow" - ) - return SlidingWindowSpec( - block_size=block_size, - num_kv_heads=self.num_kv_heads, - head_size=self.head_size, - head_size_v=self.head_size_v, - dtype=self.kv_cache_torch_dtype, - kv_quant_mode=quant_mode, - sliding_window=self.sliding_window, - ) - elif self.kv_cache_dtype.startswith("turboquant_"): + if self.kv_cache_dtype.startswith("turboquant_"): from vllm.model_executor.layers.quantization.turboquant.config import ( TurboQuantConfig, ) - from vllm.v1.kv_cache_interface import TQFullAttentionSpec + from vllm.v1.kv_cache_interface import ( + TQFullAttentionSpec, + TQSlidingWindowSpec, + ) tq_config = TurboQuantConfig.from_cache_dtype( self.kv_cache_dtype, self.head_size ) + if self.sliding_window is not None: + assert not vllm_config.model_config.use_mla, ( + "MLA is not supported for slidingwindow" + ) + return TQSlidingWindowSpec( + block_size=block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + head_size_v=self.head_size_v, + dtype=self.kv_cache_torch_dtype, + kv_quant_mode=quant_mode, + sliding_window=self.sliding_window, + tq_slot_size=tq_config.slot_size_aligned, + ) return TQFullAttentionSpec( block_size=block_size, num_kv_heads=self.num_kv_heads, @@ -571,6 +575,19 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: dtype=self.kv_cache_torch_dtype, tq_slot_size=tq_config.slot_size_aligned, ) + elif self.sliding_window is not None: + assert not vllm_config.model_config.use_mla, ( + "MLA is not supported for slidingwindow" + ) + return SlidingWindowSpec( + block_size=block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + head_size_v=self.head_size_v, + dtype=self.kv_cache_torch_dtype, + kv_quant_mode=quant_mode, + sliding_window=self.sliding_window, + ) else: return FullAttentionSpec( block_size=block_size, diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index b57e10b67faa..91e72deac771 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -29,6 +29,8 @@ MLAAttentionSpec, SlidingWindowMLASpec, SlidingWindowSpec, + TQFullAttentionSpec, + TQSlidingWindowSpec, UniformTypeKVCacheSpecs, ) from vllm.v1.request import Request @@ -1381,6 +1383,18 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): compress_ratio=spec.compress_ratio, model_version=spec.model_version, ) + elif isinstance(spec, TQSlidingWindowSpec): + kv_cache_spec[layer_name] = TQFullAttentionSpec( + block_size=spec.block_size, + num_kv_heads=spec.num_kv_heads, + head_size=spec.head_size, + head_size_v=spec.head_size_v, + dtype=spec.dtype, + kv_quant_mode=spec.kv_quant_mode, + sliding_window=spec.sliding_window, + page_size_padded=spec.page_size_padded, + tq_slot_size=spec.tq_slot_size, + ) elif isinstance(spec, SlidingWindowSpec): kv_cache_spec[layer_name] = FullAttentionSpec( block_size=spec.block_size, diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index e8d3a6f75688..61c6ca9204e2 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -23,6 +23,7 @@ SlidingWindowMLASpec, SlidingWindowSpec, TQFullAttentionSpec, + TQSlidingWindowSpec, ) from vllm.v1.request import Request @@ -1144,6 +1145,7 @@ def __init__( TQFullAttentionSpec: FullAttentionManager, MLAAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, + TQSlidingWindowSpec: SlidingWindowManager, SlidingWindowMLASpec: SlidingWindowManager, ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, MambaSpec: MambaManager, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 19438fb1e42d..c2dd0c871811 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -462,6 +462,19 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: return max_blocks * self.page_size_bytes +@dataclass(frozen=True, kw_only=True) +class TQSlidingWindowSpec(SlidingWindowSpec): + """SlidingWindowSpec with TQ-aware page size.""" + + tq_slot_size: int = 0 + + @property + def real_page_size_bytes(self) -> int: + if self.tq_slot_size > 0: + return self.block_size * self.num_kv_heads * self.tq_slot_size + return super().real_page_size_bytes + + @dataclass(frozen=True, kw_only=True) class SlidingWindowMLASpec(SlidingWindowSpec): """Sliding window attention with MLA cache format."""