diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index 985b97c69ca4..6a9f2ec147aa 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -1742,16 +1742,41 @@ def test_get_kv_cache_config_one_worker():
         ],
     )
 
-    # different hidden size that cannot be aligned by using different block size
+    # different hidden size and different type: the page-size guard converts
+    # SlidingWindowSpec → FullAttentionSpec, then UniformTypeKVCacheSpecs
+    # handles the two FullAttentionSpecs with different head sizes.
     kv_cache_specs_hybrid = {
         "layer_1": new_kv_cache_spec(head_size=64),
         "layer_2": new_sliding_window_spec(head_size=96),
     }
-
-    with pytest.raises(NotImplementedError):
-        get_kv_cache_configs(
-            vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32]
-        )[0]
+    kv_cache_config_hybrid = get_kv_cache_configs(
+        vllm_config,
+        [kv_cache_specs_hybrid],
+        [mem_per_block_per_layer * 2 * 32],
+    )[0]
+    expected_specs = {
+        "layer_1": new_kv_cache_spec(head_size=64),
+        "layer_2": new_kv_cache_spec(head_size=96, sliding_window=1),
+    }
+    assert kv_cache_config_hybrid == KVCacheConfig(
+        num_blocks=25,
+        kv_cache_tensors=[
+            KVCacheTensor(
+                size=mem_per_block_per_layer * 25,
+                shared_by=["layer_1"],
+            ),
+            KVCacheTensor(
+                size=new_kv_cache_spec(head_size=96).page_size_bytes * 25,
+                shared_by=["layer_2"],
+            ),
+        ],
+        kv_cache_groups=[
+            KVCacheGroupSpec(
+                ["layer_1", "layer_2"],
+                UniformTypeKVCacheSpecs(block_size=16, kv_cache_specs=expected_specs),
+            ),
+        ],
+    )
 
     # Test num_gpu_blocks_override
     vllm_config.cache_config.num_gpu_blocks_override = 16
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 22a1f6a42981..cf982bb8a4a0 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -695,8 +695,10 @@ def __post_init__(
 
         if self.disable_sliding_window:
             # Set after get_and_verify_max_len to ensure that max_model_len
-            # can be correctly capped to sliding window size
-            self.hf_text_config.sliding_window = None
+            # can be correctly capped to sliding window size.
+            # Use object.__setattr__ to bypass huggingface_hub strict
+            # dataclass validation which rejects None for int-typed fields.
+            object.__setattr__(self.hf_text_config, "sliding_window", None)
 
         # Avoid running try_verify_and_update_config multiple times
         self.config_updated = False
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 1b3803139217..761a38cf3469 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1704,10 +1704,24 @@ def create_engine_config(
                 TurboQuantConfig,
             )
 
+            num_layers = model_config.hf_text_config.num_hidden_layers
             boundary = TurboQuantConfig.get_boundary_skip_layers(model_config)
             existing = set(cache_config.kv_cache_dtype_skip_layers)
-            cache_config.kv_cache_dtype_skip_layers = sorted(
-                existing | set(boundary), key=int
+            merged = sorted(existing | set(boundary), key=lambda x: int(x))
+
+            hf_cfg = model_config.hf_text_config
+            merged = TurboQuantConfig.apply_yoco_skip_alignment(
+                merged=merged,
+                num_layers=num_layers,
+                layer_types=getattr(hf_cfg, "layer_types", None) or [],
+                num_kv_shared=getattr(hf_cfg, "num_kv_shared_layers", 0),
+            )
+
+            cache_config.kv_cache_dtype_skip_layers = merged
+            logger.info(
+                "TQ: skipping layers %s for boundary protection (num_layers=%d)",
+                merged,
+                num_layers,
             )
 
         ray_runtime_env = None
diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py
index db9ae2bbda34..b28a41a02075 100644
--- a/vllm/model_executor/layers/attention/attention.py
+++ b/vllm/model_executor/layers/attention/attention.py
@@ -541,20 +541,7 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
         # Should not be called for enc-dec or encoder-only attention.
         assert self.attn_type == AttentionType.DECODER
         quant_mode = get_kv_quant_mode(self.kv_cache_dtype)
-        if self.sliding_window is not None:
-            assert not vllm_config.model_config.use_mla, (
-                "MLA is not supported for slidingwindow"
-            )
-            return SlidingWindowSpec(
-                block_size=block_size,
-                num_kv_heads=self.num_kv_heads,
-                head_size=self.head_size,
-                head_size_v=self.head_size_v,
-                dtype=self.kv_cache_torch_dtype,
-                kv_quant_mode=quant_mode,
-                sliding_window=self.sliding_window,
-            )
-        elif self.kv_cache_dtype.startswith("turboquant_"):
+        if self.kv_cache_dtype.startswith("turboquant_"):
             from vllm.model_executor.layers.quantization.turboquant.config import (
                 TurboQuantConfig,
             )
@@ -570,6 +557,20 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
                 head_size_v=self.head_size,
                 dtype=self.kv_cache_torch_dtype,
                 tq_slot_size=tq_config.slot_size_aligned,
+                sliding_window=self.sliding_window,
+            )
+        elif self.sliding_window is not None:
+            assert not vllm_config.model_config.use_mla, (
+                "MLA is not supported for slidingwindow"
+            )
+            return SlidingWindowSpec(
+                block_size=block_size,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                head_size_v=self.head_size_v,
+                dtype=self.kv_cache_torch_dtype,
+                kv_quant_mode=quant_mode,
+                sliding_window=self.sliding_window,
             )
         else:
             return FullAttentionSpec(
diff --git a/vllm/model_executor/layers/quantization/turboquant/config.py b/vllm/model_executor/layers/quantization/turboquant/config.py
index 50beb8d1d9bf..2fda3464a04f 100644
--- a/vllm/model_executor/layers/quantization/turboquant/config.py
+++ b/vllm/model_executor/layers/quantization/turboquant/config.py
@@ -205,6 +205,82 @@ def get_boundary_skip_layers(
         indices = sorted(set(first + last))
         return [str(i) for i in indices]
 
+    @staticmethod
+    def apply_yoco_skip_alignment(
+        merged: list[str],
+        num_layers: int,
+        layer_types: list,
+        num_kv_shared: int,
+    ) -> list[str]:
+        """Align the TQ skip list for YOCO (You Only Cache Once) architectures.
+
+        KV-shared layers reuse their target's cache tensor, so the
+        kv_cache_dtype of a shared layer MUST match its target's.
+        This method:
+        1. Skips all KV-sharing target layers (to prevent quantization
+           error amplification across every consumer layer).
+        2. Propagates the skip/no-skip decision from each target to its
+           corresponding shared layers so the layouts stay compatible.
+
+        Args:
+            merged: Current sorted skip-layer list (strings of layer indices).
+            num_layers: Total number of hidden layers.
+            layer_types: Per-layer type list from hf_text_config.layer_types.
+            num_kv_shared: Number of KV-sharing layers
+                (hf_text_config.num_kv_shared_layers).
+
+        Returns:
+            Updated sorted skip-layer list as strings.
+        """
+        import logging
+
+        _logger = logging.getLogger(__name__)
+
+        if num_kv_shared <= 0 or not layer_types:
+            return merged
+
+        first_shared = num_layers - num_kv_shared
+        skip_set = set(merged)
+
+        # 1) Find all unique KV-sharing target layers and skip them
+        # to prevent error amplification through YOCO sharing.
+        target_set: set[str] = set()
+        for shared_idx in range(first_shared, num_layers):
+            current_type = layer_types[shared_idx]
+            for t in range(first_shared - 1, -1, -1):
+                if layer_types[t] == current_type:
+                    target_set.add(str(t))
+                    break
+        new_targets = target_set - skip_set
+        if new_targets:
+            skip_set |= new_targets
+            _logger.info(
+                "TQ: skipping KV-sharing target layers %s to "
+                "prevent error amplification in YOCO architecture",
+                sorted(new_targets, key=lambda x: int(x)),
+            )
+
+        # 2) Propagate skip/no-skip from target → shared layer.
+        for shared_idx in range(first_shared, num_layers):
+            current_type = layer_types[shared_idx]
+            target_idx = None
+            for t in range(first_shared - 1, -1, -1):
+                if layer_types[t] == current_type:
+                    target_idx = t
+                    break
+            if target_idx is None:
+                continue
+            target_skipped = str(target_idx) in skip_set
+            shared_skipped = str(shared_idx) in skip_set
+            if target_skipped and not shared_skipped:
+                skip_set.add(str(shared_idx))
+            elif not target_skipped and shared_skipped:
+                skip_set.discard(str(shared_idx))
+
+        result = sorted(skip_set, key=lambda x: int(x))
+        _logger.info("TQ: after KV-sharing alignment, skip list: %s", result)
+        return result
+
     @staticmethod
     def from_cache_dtype(cache_dtype: str, head_dim: int) -> TurboQuantConfig:
         """Create config from a named preset.
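The skip-list alignment added above is easiest to follow on a small example. The sketch below uses made-up layer counts, layer types, and an initial skip list, and assumes the `apply_yoco_skip_alignment` signature introduced in this patch:

```python
from vllm.model_executor.layers.quantization.turboquant.config import TurboQuantConfig

# Hypothetical 6-layer model: the last 2 layers are YOCO KV-sharing layers,
# each reusing the cache of the nearest earlier layer with the same type.
layer_types = ["full", "sliding", "full", "sliding", "full", "sliding"]

aligned = TurboQuantConfig.apply_yoco_skip_alignment(
    merged=["0", "5"],   # boundary skip list before alignment (made up)
    num_layers=6,
    layer_types=layer_types,
    num_kv_shared=2,     # layers 4 and 5 reuse earlier caches
)
# Step 1 skips the sharing targets: layer 4 ("full") maps to layer 2 and
# layer 5 ("sliding") maps to layer 3, so 2 and 3 join the skip list.
# Step 2 propagates each target's decision to its shared layer, which
# additionally pulls in layer 4 (layer 5 was already skipped).
assert aligned == ["0", "2", "3", "4", "5"]
```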
diff --git a/vllm/v1/attention/backends/turboquant_attn.py b/vllm/v1/attention/backends/turboquant_attn.py
index af2d0fb0830f..b9087bc87fd3 100644
--- a/vllm/v1/attention/backends/turboquant_attn.py
+++ b/vllm/v1/attention/backends/turboquant_attn.py
@@ -262,6 +262,12 @@ def __init__(
         self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
         self.num_kv_groups = num_heads // self.num_kv_heads
         self.kv_cache_dtype = kv_cache_dtype
+        self.sliding_window = sliding_window
+        # window_size for flash_attn: [left, right]
+        if sliding_window is None:
+            self._fa_window_size: list[int] = [-1, -1]
+        else:
+            self._fa_window_size = [sliding_window - 1, 0]
 
         from vllm.model_executor.layers.quantization.turboquant.config import (
             TurboQuantConfig,
@@ -312,6 +318,7 @@ def _flash_attn_varlen(
                 max_seqlen_k=max_seqlen_k,
                 softmax_scale=self.scale,
                 causal=True,
+                window_size=self._fa_window_size,
             )
         return flash_attn_varlen_func(
             q=q,
@@ -323,6 +330,7 @@
             max_seqlen_k=max_seqlen_k,
             softmax_scale=self.scale,
             causal=True,
+            window_size=self._fa_window_size,
             fa_version=self.fa_version,
         )
 
@@ -627,14 +635,31 @@ def _prefill_attention(
                 q_t = q_seq.transpose(0, 1).contiguous()
                 k_t = k_seq.transpose(0, 1).contiguous()
                 v_t = v_seq.transpose(0, 1).contiguous()
-                out = F.scaled_dot_product_attention(
-                    q_t,
-                    k_t,
-                    v_t,
-                    is_causal=True,
-                    scale=self.scale,
-                    enable_gqa=use_gqa,
-                ).transpose(0, 1)
+                # Build sliding-window causal mask if needed
+                sw = self.sliding_window
+                if sw is not None:
+                    q_pos = torch.arange(q_len, device=query.device)
+                    k_pos = torch.arange(q_len, device=query.device)
+                    mask = (k_pos.unsqueeze(0) <= q_pos.unsqueeze(1)) & (
+                        q_pos.unsqueeze(1) - k_pos.unsqueeze(0) < sw
+                    )
+                    out = F.scaled_dot_product_attention(
+                        q_t,
+                        k_t,
+                        v_t,
+                        attn_mask=mask,
+                        scale=self.scale,
+                        enable_gqa=use_gqa,
+                    ).transpose(0, 1)
+                else:
+                    out = F.scaled_dot_product_attention(
+                        q_t,
+                        k_t,
+                        v_t,
+                        is_causal=True,
+                        scale=self.scale,
+                        enable_gqa=use_gqa,
+                    ).transpose(0, 1)
                 output[q_start:q_end] = out.to(query.dtype)
             else:
                 # Continuation chunk: tokens already stored to TQ cache
@@ -662,6 +687,7 @@
                         key_fp8=self.tq_config.key_fp8,
                         norm_correction=self.tq_config.norm_correction,
                         PiT=PiT,
+                        sliding_window=self.sliding_window,
                     )
                 else:
                     # Large continuation: dequant cached K/V and use
@@ -814,6 +840,10 @@ def _continuation_prefill(
         q_pos = torch.arange(q_len, device=device).unsqueeze(1) + cached_len
         k_pos = torch.arange(seq_len, device=device).unsqueeze(0)
         mask = k_pos <= q_pos  # (q_len, seq_len)
+        # Apply sliding window constraint
+        sw = self.sliding_window
+        if sw is not None:
+            mask = mask & (q_pos - k_pos < sw)
         out = F.scaled_dot_product_attention(
             q_t,
             k_t,
@@ -874,5 +904,6 @@ def _decode_attention(
             lse_buf=lse_buf,
             buf_holder=layer,
             max_num_kv_splits=self.max_num_kv_splits,
+            sliding_window=self.sliding_window,
         )
         return result
diff --git a/vllm/v1/attention/ops/triton_turboquant_decode.py b/vllm/v1/attention/ops/triton_turboquant_decode.py
index 3adaf2610d8d..b94743367b53 100644
--- a/vllm/v1/attention/ops/triton_turboquant_decode.py
+++ b/vllm/v1/attention/ops/triton_turboquant_decode.py
@@ -83,6 +83,7 @@ def _tq_decode_stage1(
     KEY_FP8: tl.constexpr,  # 1 if K is stored as FP8
     NORM_CORRECTION: tl.constexpr = 0,  # 1 = re-normalize centroids
     FP8_E4B15: tl.constexpr = 0,  # 1 = use e4b15 (Ampere/Ada), 0 = e4nv (Hopper+)
+    SLIDING_WINDOW: tl.constexpr = 0,  # 0 = full attention, >0 = window size
 ):
     bid = tl.program_id(0)  # batch index
     hid = tl.program_id(1)  # q_head index
@@ -93,9 +94,16 @@
     # Sequence length for this batch
     seq_len = tl.load(Seq_lens_ptr + bid)
 
-    # KV split range
-    split_len = tl.cdiv(seq_len, NUM_KV_SPLITS)
-    split_start = split_len * sid
+    # Sliding window: only attend to the last SLIDING_WINDOW tokens
+    if SLIDING_WINDOW > 0:
+        effective_start = tl.maximum(0, seq_len - SLIDING_WINDOW)
+    else:
+        effective_start = 0
+    effective_len = seq_len - effective_start
+
+    # KV split range (over the effective window only)
+    split_len = tl.cdiv(effective_len, NUM_KV_SPLITS)
+    split_start = effective_start + split_len * sid
     split_end = tl.minimum(split_start + split_len, seq_len)
 
     if split_start >= split_end:
@@ -503,6 +511,7 @@ def triton_turboquant_decode_attention(
     lse_buf: torch.Tensor | None = None,
     buf_holder: Any = None,
     max_num_kv_splits: int = 32,  # fixed split count (must be constant for cudagraph)
+    sliding_window: int | None = None,
 ) -> torch.Tensor:
     """Launch fused TQ decode attention (Triton stage1 + stage2).
 
@@ -550,6 +559,7 @@
     # Stage 1: split-KV tiled attention scoring + value accumulation
     fp8_e4b15 = _use_fp8_e4b15(device.index or 0)
     BLOCK_KV = 4
+    _sliding_window = sliding_window if sliding_window is not None else 0
     grid = (B, Hq, NUM_KV_SPLITS)
     _tq_decode_stage1[grid](
         q_rot,
@@ -583,6 +593,7 @@
        KEY_FP8=1 if key_fp8 else 0,
        NORM_CORRECTION=1 if norm_correction else 0,
        FP8_E4B15=fp8_e4b15,
+       SLIDING_WINDOW=_sliding_window,
        num_warps=1,
        num_stages=1,
     )
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index b57e10b67faa..0f62fff9346f 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -10,6 +10,7 @@
 from collections.abc import Callable, Iterable, Iterator, Sequence
 from dataclasses import dataclass, replace
 from functools import partial
+from math import gcd
 from typing import Any, NewType, TypeAlias, cast, overload
 
 from vllm import envs
@@ -1024,22 +1025,30 @@ def unify_kv_cache_spec_page_size(
         # All layers have the same page size, no need to unify.
         return kv_cache_spec
 
-    max_page_size = max(page_sizes)
+    # Use LCM of all page sizes as the target so that every layer's page
+    # size divides the target evenly. The previous approach used the
+    # maximum page size which fails when sizes aren't exact multiples of
+    # each other (e.g. TurboQuant + heterogeneous head dims in Gemma 4).
+    target_page_size = 1
+    for ps in page_sizes:
+        target_page_size = ps * target_page_size // gcd(ps, target_page_size)
+
     new_kv_cache_spec = {}
     for layer_name, layer_spec in kv_cache_spec.items():
-        if layer_spec.page_size_bytes == max_page_size:
+        layer_page_size = layer_spec.page_size_bytes
+        if layer_page_size == target_page_size:
             new_kv_cache_spec[layer_name] = layer_spec
         else:
-            layer_page_size = layer_spec.page_size_bytes
-            if max_page_size % layer_page_size != 0:
+            if target_page_size % layer_page_size != 0:
                 raise NotImplementedError(
                     "The page size of the layer is not divisible by the "
-                    "maximum page size. Cannot unify by adjusting block_size."
+                    "target (LCM) page size. Cannot unify by adjusting "
+                    "block_size."
                 )
-            ratio = max_page_size // layer_page_size
+            ratio = target_page_size // layer_page_size
             new_block_size = layer_spec.block_size * ratio
             new_spec = replace(layer_spec, block_size=new_block_size)
-            assert new_spec.page_size_bytes == max_page_size
+            assert new_spec.page_size_bytes == target_page_size
             new_kv_cache_spec[layer_name] = new_spec
     return new_kv_cache_spec
 
@@ -1650,6 +1659,23 @@ def get_kv_cache_groups(
         _annotate_eagle_groups_deepseek_v4(vllm_config, kv_cache_spec, kv_cache_groups)
         return kv_cache_groups
 
+    # When page sizes aren't clean multiples of each other, the LCM-based
+    # unification below creates excessively large blocks. Try converting
+    # SlidingWindowSpec / ChunkedLocalAttentionSpec → FullAttentionSpec
+    # first: if that collapses all specs into one uniform type, the
+    # single-group path avoids the LCM blow-up entirely.
+    page_sizes = {s.page_size_bytes for s in kv_cache_spec.values()}
+    if len(page_sizes) > 1 and max(page_sizes) % min(page_sizes) != 0:
+        try:
+            unify_hybrid_kv_cache_specs(kv_cache_spec)
+        except ValueError:
+            pass  # Could not fully unify; fall through to LCM path
+        else:
+            if is_kv_cache_spec_uniform(kv_cache_spec):
+                return _get_kv_cache_groups_uniform_spec(kv_cache_spec)
+            elif uniform_spec := UniformTypeKVCacheSpecs.from_specs(kv_cache_spec):
+                return _get_kv_cache_groups_uniform_type(uniform_spec)
+
     # As KVCacheManager can only allocate memory of one size, we need to unify
     # the page size of the layers. For cases cannot be unified, this function
     # will raise an error.
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index b693adbf2771..f1094fba176f 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -450,6 +450,15 @@ def add_kv_sharing_layers_to_kv_cache_groups(
 
         tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name]
         tgt_kv_cache_group.layer_names.append(layer_name)
+        # When the group uses UniformTypeKVCacheSpecs, also register the
+        # shared layer so that per-layer spec lookups (e.g. in
+        # get_attn_backends_for_group) can find it.
+        group_spec = tgt_kv_cache_group.kv_cache_spec
+        if isinstance(group_spec, UniformTypeKVCacheSpecs):
+            group_spec.kv_cache_specs[layer_name] = group_spec.kv_cache_specs[
+                target_layer_name
+            ]
+
         if runner_only_attn_layers is not None:
             runner_only_attn_layers.add(layer_name)
 
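For reference, the LCM page-size rule introduced in `unify_kv_cache_spec_page_size` can be checked in isolation. This is a minimal standalone sketch of the arithmetic only; the byte sizes are made up (a 2:3 ratio, roughly what head sizes 64 vs 96 produce at the same block size) and are not taken from a real model:

```python
from math import gcd

# Illustrative per-layer page sizes in bytes (2:3 ratio); hypothetical values.
page_sizes = [2048, 3072]

# Running LCM, mirroring the loop added to unify_kv_cache_spec_page_size.
target_page_size = 1
for ps in page_sizes:
    target_page_size = ps * target_page_size // gcd(ps, target_page_size)
assert target_page_size == 6144

# Each layer's block_size is scaled so its page size reaches the target:
# the 2048-byte layer triples its block_size, the 3072-byte layer doubles it.
ratios = [target_page_size // ps for ps in page_sizes]
assert ratios == [3, 2]
```

With the old max-based rule the target would have been 3072, which the 2048-byte layer cannot reach by any integer block-size multiplier; that is the case the new LCM target (and the FullAttentionSpec conversion guard in `get_kv_cache_groups`) is meant to handle.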