diff --git a/tests/v1/core/test_kv_sharing.py b/tests/v1/core/test_kv_sharing.py index db0e8dae3e78..8d6680b7bbfd 100644 --- a/tests/v1/core/test_kv_sharing.py +++ b/tests/v1/core/test_kv_sharing.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from types import SimpleNamespace + import pytest import torch +from vllm.v1.core.kv_cache_utils import get_kv_cache_config_from_groups from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheGroupSpec from vllm.v1.worker.utils import add_kv_sharing_layers_to_kv_cache_groups @@ -103,3 +106,28 @@ def test_initialize_kv_cache_for_kv_sharing_no_attn_groups(): assert len(kv_cache_groups) == 2 assert kv_cache_groups[0].layer_names == ["model.layers.0", "model.layers.2"] assert kv_cache_groups[1].layer_names == ["model.layers.1", "model.layers.3"] + + +def test_dflash_draft_kv_groups_keep_hybrid_tensor_sharing(): + spec = new_kv_cache_spec() + num_blocks = 8 + vllm_config = SimpleNamespace( + speculative_config=SimpleNamespace(method="dflash"), + cache_config=SimpleNamespace(num_gpu_blocks_override=None), + ) + kv_cache_groups = [ + KVCacheGroupSpec(["model.layers.0", "model.layers.1"], spec), + KVCacheGroupSpec(["model.layers.2", "model.layers.3"], spec), + ] + + kv_cache_config = get_kv_cache_config_from_groups( + vllm_config=vllm_config, + kv_cache_groups=kv_cache_groups, + available_memory=spec.page_size_bytes * 2 * num_blocks, + ) + + assert kv_cache_config.num_blocks == num_blocks + assert [tensor.shared_by for tensor in kv_cache_config.kv_cache_tensors] == [ + ["model.layers.0", "model.layers.2"], + ["model.layers.1", "model.layers.3"], + ] diff --git a/tests/v1/spec_decode/test_dflash_swa.py b/tests/v1/spec_decode/test_dflash_swa.py new file mode 100644 index 000000000000..af445daf3d9b --- /dev/null +++ b/tests/v1/spec_decode/test_dflash_swa.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from types import SimpleNamespace + +import torch + +from vllm.config import SpeculativeConfig +from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.models.qwen3_dflash import DFlashAttention +from vllm.transformers_utils.configs.speculators import SpeculatorsConfig +from vllm.v1.attention.backend import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + SlidingWindowSpec, +) +from vllm.v1.spec_decode.dflash import DFlashProposer + + +class _FakeBuilder: + def __init__( + self, kv_cache_spec=None, layer_names=None, vllm_config=None, device=None + ): + self.kv_cache_spec = kv_cache_spec + self.layer_names = layer_names + + def build_for_drafting(self, common_attn_metadata, draft_index): + return SimpleNamespace( + causal=common_attn_metadata.causal, + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + ) + + +class _FakeAttentionGroup: + def __init__(self, layer_names, kv_cache_group_id=0): + self.layer_names = layer_names + self.kv_cache_group_id = kv_cache_group_id + self._builder = _FakeBuilder() + + def get_metadata_builder(self): + return self._builder + + +def _make_cad(block_table, slot_mapping) -> CommonAttentionMetadata: + return CommonAttentionMetadata( + query_start_loc=torch.tensor([0, 2], dtype=torch.int32), + query_start_loc_cpu=torch.tensor([0, 2], dtype=torch.int32), + seq_lens=torch.tensor([2], dtype=torch.int32), + num_reqs=1, + num_actual_tokens=2, 
+ max_query_len=2, + max_seq_len=2, + block_table_tensor=block_table, + slot_mapping=slot_mapping, + causal=False, + ) + + +def test_dflash_speculators_preserves_swa_config(): + layer_types = [ + "sliding_attention", + "sliding_attention", + "full_attention", + ] + config = { + "speculators_model_type": "dflash", + "transformer_layer_config": { + "num_hidden_layers": len(layer_types), + "sliding_window": None, + }, + "draft_vocab_size": 100, + "target_hidden_size": 64, + "aux_hidden_state_layer_ids": [0, 1, 2], + "mask_token_id": 99, + "layer_types": layer_types, + "use_sliding_window": True, + "sliding_window": 2048, + "max_window_layers": len(layer_types), + } + + hf_config = SpeculatorsConfig.extract_transformers_pre_trained_config(config) + + assert hf_config["layer_types"] == layer_types + assert hf_config["use_sliding_window"] is True + assert hf_config["sliding_window"] == 2048 + assert hf_config["max_window_layers"] == len(layer_types) + assert hf_config["eagle_aux_hidden_state_layer_ids"] == [1, 2, 3] + assert hf_config["dflash_config"]["target_layer_ids"] == [0, 1, 2] + + +def _compute_dflash_hash(hf_config: SimpleNamespace) -> str: + config = object.__new__(SpeculativeConfig) + config.method = "dflash" + config.draft_model_config = SimpleNamespace(hf_config=hf_config) + return config.compute_hash() + + +def test_dflash_compile_hash_uses_checkpoint_layer_id_semantics(): + dflash_hash = _compute_dflash_hash( + SimpleNamespace(dflash_config={"target_layer_ids": [0, 2]}) + ) + shifted_aux_hash = _compute_dflash_hash( + SimpleNamespace(eagle_aux_hidden_state_layer_ids=[1, 3]) + ) + different_hash = _compute_dflash_hash( + SimpleNamespace(dflash_config={"target_layer_ids": [0, 3]}) + ) + + assert dflash_hash == shifted_aux_hash + assert dflash_hash != different_hash + + +def test_dflash_swa_layers_use_full_kv_cache_spec(monkeypatch): + sliding_spec = SlidingWindowSpec( + block_size=16, + num_kv_heads=1, + head_size=8, + dtype=torch.float16, + sliding_window=4, + ) + monkeypatch.setattr( + Attention, + "get_kv_cache_spec", + lambda self, vllm_config: sliding_spec, + ) + + spec = DFlashAttention.get_kv_cache_spec( + object.__new__(DFlashAttention), SimpleNamespace() + ) + + assert isinstance(spec, FullAttentionSpec) + assert spec.block_size == sliding_spec.block_size + assert spec.num_kv_heads == sliding_spec.num_kv_heads + assert spec.head_size == sliding_spec.head_size + assert spec.sliding_window is None + + +def test_dflash_swa_layers_use_causal_metadata(): + proposer = object.__new__(DFlashProposer) + proposer.model = SimpleNamespace(sliding_attention_layer_names={"layer.sw"}) + proposer.draft_attn_groups = [_FakeAttentionGroup(["layer.sw", "layer.full"])] + proposer.kv_cache_gid = 0 + proposer._draft_kv_cache_group_ids = [0] + proposer._draft_layer_to_kv_cache_gid = { + "layer.sw": 0, + "layer.full": 0, + } + proposer._draft_block_tables = {} + cad = _make_cad( + torch.empty(1, 1, dtype=torch.int32), + torch.empty(2, dtype=torch.int64), + ) + proposer._slot_mapping_buffers_by_gid = {0: (cad.slot_mapping, cad.slot_mapping)} + + per_group, per_layer = DFlashProposer.build_per_group_and_layer_attn_metadata( + proposer, cad + ) + + assert per_group[0].causal is False + assert per_layer["layer.sw"].causal is True + assert per_layer["layer.full"].causal is False + + +def test_dflash_metadata_uses_per_kv_group_slot_mapping(): + proposer = object.__new__(DFlashProposer) + proposer.model = SimpleNamespace(sliding_attention_layer_names={"layer.sw"}) + proposer.draft_attn_groups = [ + 
_FakeAttentionGroup(["layer.full"], kv_cache_group_id=1), + _FakeAttentionGroup(["layer.sw"], kv_cache_group_id=2), + ] + proposer.kv_cache_gid = 1 + proposer._draft_kv_cache_group_ids = [1, 2] + proposer._draft_layer_to_kv_cache_gid = { + "layer.full": 1, + "layer.sw": 2, + } + + full_block_table = torch.tensor([[11, 12]], dtype=torch.int32) + sw_block_table = torch.tensor([[21, 22]], dtype=torch.int32) + full_slots = torch.tensor([111, 112], dtype=torch.int64) + sw_slots = torch.tensor([211, 212], dtype=torch.int64) + + base_cad = _make_cad(full_block_table, full_slots) + proposer._draft_block_tables = { + 1: full_block_table, + 2: sw_block_table, + } + proposer._slot_mapping_buffers_by_gid = { + 1: (full_slots, full_slots), + 2: (sw_slots, sw_slots), + } + + _, per_layer = DFlashProposer.build_per_group_and_layer_attn_metadata( + proposer, base_cad + ) + + assert per_layer["layer.full"].block_table_tensor is full_block_table + torch.testing.assert_close(per_layer["layer.full"].slot_mapping, full_slots) + assert per_layer["layer.full"].causal is False + assert per_layer["layer.sw"].block_table_tensor is sw_block_table + torch.testing.assert_close(per_layer["layer.sw"].slot_mapping, sw_slots) + assert per_layer["layer.sw"].causal is True diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 4942390cdd38..b8dedae9e479 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -716,6 +716,41 @@ def rnd_stride_order( assert all(not kv.is_contiguous() for kv in model_runner.kv_caches) +@pytest.mark.parametrize( + "physical_order", + [ + ("block", "kv", "token", "head", "dim"), + ("block", "kv", "head", "token", "dim"), + ], +) +def test_kv_major_cache_can_share_block_major_raw_tensor(physical_order): + kv_cache_shape = (2, 3, 4, 2, 8) + _, num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape + block_elems = block_size * num_kv_heads * head_size + raw_tensor = torch.arange(2 * num_blocks * block_elems) + public_order = ("kv", "block", "token", "head", "dim") + dim_sizes = dict(zip(public_order, kv_cache_shape)) + expected_strides = {} + stride = 1 + for dim in reversed(physical_order): + expected_strides[dim] = stride + stride *= dim_sizes[dim] + + kv_cache = GPUModelRunner._view_kv_cache_with_physical_order( + raw_tensor, + kv_cache_shape, + public_order, + physical_order, + ) + + assert kv_cache.shape == kv_cache_shape + assert kv_cache.stride() == tuple(expected_strides[dim] for dim in public_order) + assert kv_cache[0, 0, 0, 0, 0] == raw_tensor[0] + assert kv_cache[1, 0, 0, 0, 0] == raw_tensor[block_elems] + assert kv_cache[0, 1, 0, 0, 0] == raw_tensor[2 * block_elems] + assert kv_cache[1, 1, 0, 0, 0] == raw_tensor[3 * block_elems] + + def test_update_config(model_runner): # Simple update model_runner.update_config({"load_config": {"load_format": "dummy"}}) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 4589820c9fdc..3511c6ce012a 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -289,6 +289,14 @@ def compute_hash(self) -> str: "eagle_aux_hidden_state_layer_ids", None, ) + if not layer_ids: + dflash_config = getattr( + self.draft_model_config.hf_config, "dflash_config", None + ) + if dflash_config and isinstance(dflash_config, dict): + layer_ids = [ + i + 1 for i in dflash_config.get("target_layer_ids", []) + ] if layer_ids is not None: # Convert to tuple to make it hashable factors.append(tuple(layer_ids)) @@ -1047,6 +1055,10 @@ def 
use_gemma4_mtp(self) -> bool: def use_eagle(self) -> bool: return self.method in ("eagle", "eagle3", "mtp", "dflash") + def requires_eagle_cache_drop(self) -> bool: + """Whether prefix cache hits must drop one block for hidden states.""" + return self.use_eagle() and not self.use_dflash() + def use_dflash(self) -> bool: return self.method == "dflash" diff --git a/vllm/model_executor/models/qwen3_dflash.py b/vllm/model_executor/models/qwen3_dflash.py index cffe6267a4b3..0587fa62bfb4 100644 --- a/vllm/model_executor/models/qwen3_dflash.py +++ b/vllm/model_executor/models/qwen3_dflash.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Iterable +from collections.abc import Iterable, Mapping import torch import torch.nn.functional as F @@ -34,6 +34,11 @@ from vllm.multimodal.inputs import NestedTensors from vllm.transformers_utils.config import set_default_rope_theta from vllm.v1.attention.backend import AttentionType +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheSpec, + SlidingWindowSpec, +) from .qwen2 import Qwen2MLP as Qwen3MLP from .qwen3 import Qwen3ForCausalLM @@ -47,6 +52,53 @@ logger = init_logger(__name__) +_DFLASH_VALID_LAYER_TYPES = frozenset({"full_attention", "sliding_attention"}) + + +def _get_dflash_layer_types(config: Qwen3Config) -> tuple[str, ...]: + layer_types = getattr(config, "layer_types", None) + if layer_types is None: + return ("full_attention",) * config.num_hidden_layers + if len(layer_types) != config.num_hidden_layers: + raise ValueError( + f"DFlash layer_types length {len(layer_types)} does not match " + f"num_hidden_layers {config.num_hidden_layers}." + ) + invalid = set(layer_types) - _DFLASH_VALID_LAYER_TYPES + if invalid: + raise ValueError(f"Invalid DFlash layer_type(s): {sorted(invalid)}.") + if "sliding_attention" in layer_types and not getattr( + config, "sliding_window", None + ): + raise ValueError( + "DFlash sliding_attention layers require `sliding_window` in config." + ) + return tuple(layer_types) + + +class DFlashAttention(Attention): + """Attention with DFlash-specific KV allocation semantics. + + The compute path keeps the layer's configured sliding window. The KV cache + spec is widened to full attention because DFlash writes every context KV + before drafting and cannot evict old context blocks from draft layers. + """ + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: + spec = super().get_kv_cache_spec(vllm_config) + if isinstance(spec, SlidingWindowSpec): + return FullAttentionSpec( + block_size=spec.block_size, + num_kv_heads=spec.num_kv_heads, + head_size=spec.head_size, + head_size_v=getattr(spec, "head_size_v", spec.head_size), + dtype=spec.dtype, + kv_quant_mode=spec.kv_quant_mode, + page_size_padded=spec.page_size_padded, + ) + return spec + + class DFlashQwen3Attention(nn.Module): """Attention for DFlash speculative decoding. 
@@ -66,6 +118,7 @@ def __init__( attention_bias: bool = False, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, + sliding_window: int | None = None, prefix: str = "", attn_type: str = AttentionType.DECODER, ) -> None: @@ -109,13 +162,14 @@ def __init__( max_position=max_position, rope_parameters=rope_parameters, ) - self.attn = Attention( + self.attn = DFlashAttention( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, + per_layer_sliding_window=sliding_window, prefix=f"{prefix}.attn", attn_type=attn_type, ) @@ -158,12 +212,17 @@ def __init__( config: Qwen3Config, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, + layer_type: str = "full_attention", prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size + self.layer_type = layer_type set_default_rope_theta(config, default_theta=1000000) attn_type = AttentionType.DECODER + sliding_window = ( + config.sliding_window if layer_type == "sliding_attention" else None + ) self.self_attn = DFlashQwen3Attention( hidden_size=self.hidden_size, @@ -175,6 +234,7 @@ def __init__( head_dim=getattr(config, "head_dim", None), cache_config=cache_config, quant_config=quant_config, + sliding_window=sliding_window, rope_parameters=config.rope_parameters, prefix=f"{prefix}.self_attn", attn_type=attn_type, @@ -243,16 +303,23 @@ def __init__( prefix=maybe_prefix(prefix, "embed_tokens"), ) + self.layer_types = _get_dflash_layer_types(self.config) self.layers = nn.ModuleList( [ DFlashQwen3DecoderLayer( current_vllm_config, prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"), config=self.config, + layer_type=self.layer_types[layer_idx], ) for layer_idx in range(self.config.num_hidden_layers) ] ) + self.sliding_attention_layer_names = { + layer.self_attn.attn.layer_name + for layer in self.layers + if layer.layer_type == "sliding_attention" + } if self.use_aux_hidden_state: num_features_to_use = self.config.num_hidden_layers if "target_layer_ids" in drafter_config: @@ -343,7 +410,7 @@ def precompute_and_store_context_kv( self, context_states: torch.Tensor, context_positions: torch.Tensor, - context_slot_mapping: torch.Tensor | None = None, + context_slot_mapping: torch.Tensor | Mapping[str, torch.Tensor] | None = None, ) -> None: """Precompute K/V for context states write them into each layer's KV cache. 
@@ -424,13 +491,18 @@ def precompute_and_store_context_kv( all_k_final = all_k_flat.view(L, num_ctx, nkv, hd) for i in range(L): attn = self._attn_layers[i] + layer_slot_mapping = ( + context_slot_mapping[attn.layer_name] + if isinstance(context_slot_mapping, Mapping) + else context_slot_mapping + ) kv_cache = attn.kv_cache attn.impl.do_kv_cache_update( attn, all_k_final[i], all_v[i], kv_cache, - context_slot_mapping, + layer_slot_mapping, ) def forward( @@ -521,7 +593,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor( - self.config.draft_vocab_size, scale=logit_scale + self.config.draft_vocab_size, + scale=logit_scale, ) target_vocab_size = vllm_config.model_config.get_vocab_size() if self.config.draft_vocab_size != target_vocab_size: @@ -569,13 +642,17 @@ def precompute_and_store_context_kv( self, context_states: torch.Tensor, context_positions: torch.Tensor, - context_slot_mapping: torch.Tensor | None = None, + context_slot_mapping: torch.Tensor | Mapping[str, torch.Tensor] | None = None, ) -> None: """Precompute projected + RoPE'd K/V and write to cache.""" self.model.precompute_and_store_context_kv( context_states, context_positions, context_slot_mapping ) + @property + def sliding_attention_layer_names(self) -> set[str]: + return self.model.sliding_attention_layer_names + def combine_hidden_states( self, hidden_states: torch.Tensor, diff --git a/vllm/transformers_utils/configs/speculators/algos.py b/vllm/transformers_utils/configs/speculators/algos.py index 405d5f5de1d1..83f5ab07dfa7 100644 --- a/vllm/transformers_utils/configs/speculators/algos.py +++ b/vllm/transformers_utils/configs/speculators/algos.py @@ -54,19 +54,29 @@ def update_dflash(config_dict: dict, pre_trained_config: dict) -> None: - target_hidden_size: Hidden size of the target model - mask_token_id (required): Token ID used for parallel drafting mask placeholders - - aux_hidden_state_layer_ids (required): Layer indices from the target - model whose intermediate hidden states are used as context for the - DFlash drafter. Mapped to both eagle_aux_hidden_state_layer_ids - (for gpu_model_runner) and dflash_config.target_layer_ids (for the - DFlash model). + - aux_hidden_state_layer_ids (required): DFlash target layer indices whose + intermediate hidden states are used as context for the DFlash drafter. + Mapped to dflash_config.target_layer_ids for the DFlash model. The + runner-facing eagle_aux_hidden_state_layer_ids are shifted by one to + match vLLM's hidden-state extraction semantics. 
""" pre_trained_config["architectures"] = ["DFlashDraftModel"] pre_trained_config["draft_vocab_size"] = config_dict.get("draft_vocab_size") if config_dict.get("target_hidden_size") is not None: pre_trained_config["target_hidden_size"] = config_dict["target_hidden_size"] + for key in ( + "layer_types", + "use_sliding_window", + "sliding_window", + "max_window_layers", + ): + if key in config_dict: + pre_trained_config[key] = config_dict[key] aux_layer_ids = config_dict["aux_hidden_state_layer_ids"] - pre_trained_config["eagle_aux_hidden_state_layer_ids"] = aux_layer_ids + pre_trained_config["eagle_aux_hidden_state_layer_ids"] = [ + i + 1 for i in aux_layer_ids + ] pre_trained_config["dflash_config"] = { "mask_token_id": config_dict["mask_token_id"], diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index b57e10b67faa..71f308e25cf5 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1302,9 +1302,9 @@ def get_kv_cache_config_from_groups( kv_cache_tensors = [] for i in range(group_size): shared_by = [] - for j in range(len(kv_cache_groups)): - if i < len(kv_cache_groups[j].layer_names): - shared_by.append(kv_cache_groups[j].layer_names[i]) + for group in kv_cache_groups: + if i < len(group.layer_names): + shared_by.append(group.layer_names[i]) kv_cache_tensors.append( KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by) ) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 8aaeb3970079..e46c9a80f137 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -208,11 +208,15 @@ def __init__( speculative_config = vllm_config.speculative_config self.use_eagle = False + self.requires_eagle_cache_drop = False self.num_spec_tokens = self.num_lookahead_tokens = 0 if speculative_config: self.num_spec_tokens = speculative_config.num_speculative_tokens if speculative_config.use_eagle(): self.use_eagle = True + self.requires_eagle_cache_drop = ( + speculative_config.requires_eagle_cache_drop() + ) self.num_lookahead_tokens = self.num_spec_tokens if speculative_config.uses_draft_model(): self.num_lookahead_tokens = self.num_spec_tokens @@ -225,7 +229,7 @@ def __init__( max_model_len=self.max_model_len, max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens, enable_caching=self.cache_config.enable_prefix_caching, - use_eagle=self.use_eagle, + use_eagle=self.requires_eagle_cache_drop, log_stats=self.log_stats, enable_kv_cache_events=self.enable_kv_cache_events, dcp_world_size=self.dcp_world_size, @@ -283,13 +287,13 @@ def _mamba_block_aligned_split( # must be a multiple of `block_size`. # As an exception, if `num_new_tokens` is less than `block_size`, the # state is simply not cached, requiring no special handling. - # Additionally, when Eagle mode is enabled, FullAttn prunes the last - # matching block. To prevent this from causing a Mamba cache miss, the - # last chunk must be not smaller than `block_size`. + # Additionally, when Eagle-style cache drop is enabled, FullAttn + # prunes the last matching block. To prevent this from causing a + # Mamba cache miss, the last chunk must be not smaller than + # `block_size`. 
block_size = self.cache_config.block_size last_cache_position = request.num_tokens - request.num_tokens % block_size - # eagle prune - if self.use_eagle: + if self.requires_eagle_cache_drop: last_cache_position = max(last_cache_position - block_size, 0) num_computed_tokens_after_sched = num_computed_tokens + num_new_tokens if num_computed_tokens_after_sched < last_cache_position: @@ -393,7 +397,7 @@ def schedule(self) -> SchedulerOutput: request.num_computed_tokens, num_new_tokens, encoder_compute_budget, - shift_computed_tokens=1 if self.use_eagle else 0, + shift_computed_tokens=1 if self.requires_eagle_cache_drop else 0, ) if self.need_mamba_block_aligned_split: @@ -662,7 +666,9 @@ def schedule(self) -> SchedulerOutput: num_computed_tokens, num_new_tokens, encoder_compute_budget, - shift_computed_tokens=1 if self.use_eagle else 0, + shift_computed_tokens=1 + if self.requires_eagle_cache_drop + else 0, ) if num_new_tokens == 0: # The request cannot be scheduled. diff --git a/vllm/v1/spec_decode/dflash.py b/vllm/v1/spec_decode/dflash.py index db74044f4fde..eb3f825cde16 100644 --- a/vllm/v1/spec_decode/dflash.py +++ b/vllm/v1/spec_decode/dflash.py @@ -1,17 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from dataclasses import replace from typing import Any import torch from typing_extensions import override -from vllm.config import VllmConfig +from vllm.config import VllmConfig, replace from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.triton_utils import triton from vllm.v1.attention.backend import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.spec_decode.llm_base_proposer import SpecDecodeBaseProposer from vllm.v1.spec_decode.utils import copy_and_expand_dflash_inputs_kernel @@ -30,6 +30,7 @@ def __init__( super().__init__( vllm_config=vllm_config, device=device, + # Request aux hidden states; DFlash turns them into context K/V. pass_hidden_states_to_model=True, runner=runner, ) @@ -50,6 +51,11 @@ def __init__( dtype=torch.int64, device=device, ) + self._slot_mapping_buffers_by_gid: dict[ + int, tuple[torch.Tensor, torch.Tensor] + ] = {} + self._draft_block_size_by_gid: dict[int, int] = {} + self._draft_block_tables: dict[int, torch.Tensor] = {} self._context_positions_buffer = torch.zeros( self.max_num_tokens, dtype=torch.int64, @@ -65,9 +71,41 @@ def __init__( self.max_positions + 1, device=device, dtype=torch.int32 ) - # For DFlash we use the input embeddings to embed the mask token + # DFlash embeds mask tokens directly. 
self.parallel_drafting_hidden_state_tensor = None + @override + def allow_multiple_draft_kv_cache_groups(self) -> bool: + return True + + @override + def initialize_attn_backend( + self, + kv_cache_config: KVCacheConfig, + kernel_block_sizes: list[int] | None = None, + ) -> None: + super().initialize_attn_backend(kv_cache_config, kernel_block_sizes) + self._draft_block_size_by_gid.clear() + for attn_group in self.draft_attn_groups: + gid = attn_group.kv_cache_group_id + self._draft_block_size_by_gid[gid] = ( + kernel_block_sizes[gid] + if kernel_block_sizes is not None and gid < len(kernel_block_sizes) + else attn_group.get_metadata_builder().kv_cache_spec.block_size + ) + self._ensure_slot_mapping_buffers() + + def clear_draft_block_tables(self) -> None: + self._draft_block_tables.clear() + + def set_draft_block_table( + self, + kv_cache_gid: int, + block_table: torch.Tensor, + ) -> None: + if kv_cache_gid in self._draft_kv_cache_group_ids: + self._draft_block_tables[kv_cache_gid] = block_table + @override def _create_draft_vllm_config(self) -> VllmConfig: base = super()._create_draft_vllm_config() @@ -82,8 +120,85 @@ def _create_draft_vllm_config(self) -> VllmConfig: @override def _warn_if_multimodal(self): # Override to allow multimodal inputs since DFlash supports Qwen3.5 models + # Support for multimodal inputs has not been tested. pass + def _ensure_slot_mapping_buffers(self) -> None: + gids = self._draft_kv_gids() + + first_gid = gids[0] + for gid in gids: + if gid in self._slot_mapping_buffers_by_gid: + continue + if gid == first_gid: + self._slot_mapping_buffers_by_gid[gid] = ( + self._context_slot_mapping_buffer, + self._slot_mapping_buffer, + ) + else: + self._slot_mapping_buffers_by_gid[gid] = ( + torch.zeros( + self.max_num_tokens, + dtype=torch.int64, + device=self.device, + ), + torch.zeros( + self.max_query_tokens, + dtype=torch.int64, + device=self.device, + ), + ) + + def _draft_kv_gids(self) -> list[int]: + return self._draft_kv_cache_group_ids or [ + self.kv_cache_gid if self.kv_cache_gid >= 0 else 0 + ] + + def _get_dflash_block_table( + self, + kv_cache_gid: int, + cad: CommonAttentionMetadata, + ) -> torch.Tensor: + block_table = self._draft_block_tables.get(kv_cache_gid) + if block_table is not None: + return block_table + if kv_cache_gid == self.kv_cache_gid or self.kv_cache_gid < 0: + return cad.block_table_tensor + raise RuntimeError( + "Missing DFlash KV metadata for draft KV cache group " + f"{kv_cache_gid}. This is required when DFlash draft layers span " + "multiple KV cache groups." 
+ ) + + def _get_dflash_context_slot_mapping( + self, + num_context: int, + ) -> torch.Tensor | dict[str, torch.Tensor]: + if not self._draft_layer_to_kv_cache_gid: + return self._context_slot_mapping_buffer[:num_context] + return { + layer_name: self._slot_mapping_buffers_by_gid[ + self._draft_layer_to_kv_cache_gid[layer_name] + ][0][:num_context] + for layer_name in self._draft_attn_layer_names + } + + @override + def _get_slot_mapping( + self, + num_tokens: int, + slot_mapping: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + self._ensure_slot_mapping_buffers() + if self._draft_layer_to_kv_cache_gid: + return { + layer_name: self._slot_mapping_buffers_by_gid[ + self._draft_layer_to_kv_cache_gid[layer_name] + ][1][:num_tokens] + for layer_name in self._draft_attn_layer_names + } + return super()._get_slot_mapping(num_tokens, slot_mapping) + @override def set_inputs_first_pass( self, @@ -102,11 +217,9 @@ def set_inputs_first_pass( num_query_per_req = 1 + self.num_speculative_tokens num_query_total = batch_size * num_query_per_req - # Store for build_model_inputs_first_pass to use self._dflash_num_context = num_context - # We don't need to copy into a buffer here since the context preprocessing - # does not run in a CUDA graph + # Context preprocessing does not run in a CUDA graph. self._dflash_hidden_states = target_hidden_states token_indices_to_sample = torch.empty( @@ -115,8 +228,7 @@ def set_inputs_first_pass( device=self.device, ) - # Launch fused triton kernel for input_ids, positions, slot_mapping, - # and token_indices_to_sample + # Fill query inputs and per-KV-group slot mappings. max_ctx_per_req = cad.max_query_len max_tokens_per_req = max_ctx_per_req + num_query_per_req BLOCK_SIZE = min(256, triton.next_power_of_2(max_tokens_per_req)) @@ -124,36 +236,48 @@ def set_inputs_first_pass( grid = (batch_size, num_blocks) has_num_rejected = num_rejected_tokens_gpu is not None - copy_and_expand_dflash_inputs_kernel[grid]( - # Inputs - next_token_ids_ptr=next_token_ids, - target_positions_ptr=target_positions, - # Outputs - out_input_ids_ptr=self.input_ids, - out_context_positions_ptr=self._context_positions_buffer, - out_query_positions_ptr=self.positions, - out_context_slot_mapping_ptr=self._context_slot_mapping_buffer, - out_query_slot_mapping_ptr=self._slot_mapping_buffer, - out_token_indices_ptr=token_indices_to_sample, - # Block table - block_table_ptr=cad.block_table_tensor, - block_table_stride=cad.block_table_tensor.stride(0), - # Metadata - query_start_loc_ptr=cad.query_start_loc, - num_rejected_tokens_ptr=( - num_rejected_tokens_gpu if has_num_rejected else 0 - ), - # Scalars - parallel_drafting_token_id=self.parallel_drafting_token_id, - block_size=self.block_size, - num_query_per_req=num_query_per_req, - num_speculative_tokens=self.num_speculative_tokens, - total_input_tokens=num_context, - BLOCK_SIZE=BLOCK_SIZE, - HAS_NUM_REJECTED=has_num_rejected, - ) + self._ensure_slot_mapping_buffers() + draft_kv_group_ids = self._draft_kv_gids() + for kv_cache_gid in draft_kv_group_ids: + context_slot_mapping_buffer, query_slot_mapping_buffer = ( + self._slot_mapping_buffers_by_gid[kv_cache_gid] + ) + block_table = self._get_dflash_block_table(kv_cache_gid, cad) + copy_and_expand_dflash_inputs_kernel[grid]( + # Inputs + next_token_ids_ptr=next_token_ids, + target_positions_ptr=target_positions, + # Outputs + out_input_ids_ptr=self.input_ids, + out_context_positions_ptr=self._context_positions_buffer, + out_query_positions_ptr=self.positions, + 
out_context_slot_mapping_ptr=context_slot_mapping_buffer, + out_query_slot_mapping_ptr=query_slot_mapping_buffer, + out_token_indices_ptr=token_indices_to_sample, + # Block table + block_table_ptr=block_table, + block_table_stride=block_table.stride(0), + # Metadata + query_start_loc_ptr=cad.query_start_loc, + num_rejected_tokens_ptr=( + num_rejected_tokens_gpu if has_num_rejected else 0 + ), + # Scalars + parallel_drafting_token_id=self.parallel_drafting_token_id, + block_size=self._draft_block_size_by_gid.get( + kv_cache_gid, self.block_size + ), + num_query_per_req=num_query_per_req, + num_speculative_tokens=self.num_speculative_tokens, + total_input_tokens=num_context, + BLOCK_SIZE=BLOCK_SIZE, + HAS_NUM_REJECTED=has_num_rejected, + ) - query_slot_mapping = self._slot_mapping_buffer[:num_query_total] + primary_kv_cache_gid = draft_kv_group_ids[0] + query_slot_mapping = self._slot_mapping_buffers_by_gid[primary_kv_cache_gid][1][ + :num_query_total + ] new_query_start_loc = self.arange[: batch_size + 1] * num_query_per_req # In padded mode, cad.seq_lens includes rejected tokens. Subtract @@ -162,12 +286,6 @@ def set_inputs_first_pass( if has_num_rejected: effective_seq_lens = effective_seq_lens - num_rejected_tokens_gpu - # Skip num_rejected_tokens (GPU-only); overestimating is fine here. - new_seq_lens_cpu_upper_bound = ( - cad.seq_lens_cpu_upper_bound + num_query_per_req - if cad.seq_lens_cpu_upper_bound is not None - else None - ) new_cad = CommonAttentionMetadata( query_start_loc=new_query_start_loc, seq_lens=effective_seq_lens + num_query_per_req, @@ -177,12 +295,11 @@ def set_inputs_first_pass( ), _seq_lens_cpu=None, _num_computed_tokens_cpu=None, - seq_lens_cpu_upper_bound=new_seq_lens_cpu_upper_bound, num_reqs=cad.num_reqs, num_actual_tokens=num_query_total, max_query_len=num_query_per_req, max_seq_len=cad.max_seq_len + num_query_per_req, - block_table_tensor=cad.block_table_tensor, + block_table_tensor=self._get_dflash_block_table(primary_kv_cache_gid, cad), slot_mapping=query_slot_mapping, causal=False, # Non-causal attention is required for DFlash ) @@ -255,15 +372,13 @@ def build_model_inputs_first_pass( num_input_tokens: int, mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None, ) -> tuple[dict[str, Any], int]: - # Context and query positions/slots were written to separate - # buffers by the kernel — no copy needed. + # Context and query positions/slots were written by the kernel. 
num_context = self._dflash_num_context - # Pre-insert context KVs directly into cache self.model.precompute_and_store_context_kv( self._dflash_hidden_states, # Shape is already [num_context, hidden_size] self._context_positions_buffer[:num_context], - self._context_slot_mapping_buffer[:num_context], + self._get_dflash_context_slot_mapping(num_context), ) return ( dict( @@ -278,10 +393,52 @@ def build_model_inputs_first_pass( def build_per_group_and_layer_attn_metadata( self, cad: CommonAttentionMetadata, draft_index: int = 0 ) -> tuple[list[object], dict[str, object]]: - per_group, per_layer = super().build_per_group_and_layer_attn_metadata( - cad, draft_index + self._ensure_slot_mapping_buffers() + sliding_layer_names: set[str] = getattr( + self.model, "sliding_attention_layer_names", set() ) + + per_group: list[object] = [] + per_layer: dict[str, object] = {} + for attn_group in self.draft_attn_groups: + kv_cache_gid = attn_group.kv_cache_group_id + group_cad = cad.replace( + block_table_tensor=self._get_dflash_block_table(kv_cache_gid, cad), + slot_mapping=self._slot_mapping_buffers_by_gid[kv_cache_gid][1][ + : cad.num_actual_tokens + ], + causal=False, + ) + attn_metadata = attn_group.get_metadata_builder().build_for_drafting( + common_attn_metadata=group_cad, + draft_index=draft_index, + ) + per_group.append(attn_metadata) + for layer_name in attn_group.layer_names: + per_layer[layer_name] = attn_metadata + + # DFlash layers consume attention metadata through the per-layer + # forward context. Keep the non-causal group metadata for + # group-level spec decode checks, and specialize only the SWA + # layers that need a causal sliding-window mask. + causal_layers = sliding_layer_names & set(attn_group.layer_names) + if causal_layers: + causal_attn_metadata = ( + attn_group.get_metadata_builder().build_for_drafting( + common_attn_metadata=group_cad.replace(causal=True), + draft_index=draft_index, + ) + ) + for layer_name in causal_layers: + per_layer[layer_name] = causal_attn_metadata + for layer_name, attn_metadata in per_layer.items(): + if layer_name in sliding_layer_names: + assert getattr(attn_metadata, "causal", None) is True, ( + f"Attention metadata for sliding layer {layer_name} does not have" + " causal support, which is required for DFlash SWA." + ) + continue assert getattr(attn_metadata, "causal", None) is False, ( f"Attention metadata for layer {layer_name} does not have" " non-causal support, which is required for DFlash." 
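The per-KV-group slot-mapping plumbing above routes each draft layer's K/V writes through the buffer owned by that layer's KV cache group, and precompute_and_store_context_kv now accepts that routing as a dict keyed by layer name. Below is a minimal, self-contained sketch of that dict shape; the layer names, group ids, and buffer sizes are hypothetical stand-ins rather than the real vLLM objects:

    import torch

    layer_to_gid = {"layer.full": 1, "layer.sw": 2}
    context_buffers = {
        1: torch.zeros(16, dtype=torch.int64),  # group 1: full-attention draft layers
        2: torch.zeros(16, dtype=torch.int64),  # group 2: sliding-window draft layers
    }
    num_context = 4

    # Per-layer view of the first num_context context slots, keyed by layer name,
    # analogous to what _get_dflash_context_slot_mapping returns.
    context_slot_mapping = {
        layer_name: context_buffers[gid][:num_context]
        for layer_name, gid in layer_to_gid.items()
    }

    # Each layer's slots alias its own group's buffer, never another group's.
    assert context_slot_mapping["layer.full"].data_ptr() == context_buffers[1].data_ptr()
    assert context_slot_mapping["layer.sw"].data_ptr() == context_buffers[2].data_ptr()

With this shape, draft layers that live in different KV cache groups never share slot-mapping buffers, which is what the multi-group slot-mapping test earlier in this patch exercises.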
diff --git a/vllm/v1/spec_decode/llm_base_proposer.py b/vllm/v1/spec_decode/llm_base_proposer.py index cc113025c129..1fbd3a3ee1a8 100644 --- a/vllm/v1/spec_decode/llm_base_proposer.py +++ b/vllm/v1/spec_decode/llm_base_proposer.py @@ -30,7 +30,11 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher -from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs +from vllm.v1.kv_cache_interface import ( + KVCacheConfig, + KVCacheSpec, + UniformTypeKVCacheSpecs, +) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import _SAMPLING_EPS from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -130,6 +134,8 @@ def __init__( self.draft_attn_groups: list[AttentionGroup] = [] self.kv_cache_gid: int = -1 + self._draft_layer_to_kv_cache_gid: dict[str, int] = {} + self._draft_kv_cache_group_ids: list[int] = [] self.eagle3_use_aux_hidden_state: bool = ( self._get_eagle3_use_aux_hidden_state_from_config() ) @@ -1485,6 +1491,9 @@ def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: == 1 ), "All drafting layers should belong to the same kv cache group" + def allow_multiple_draft_kv_cache_groups(self) -> bool: + return False + def initialize_attn_backend( self, kv_cache_config: KVCacheConfig, @@ -1499,47 +1508,61 @@ def initialize_attn_backend( AttentionLayerBase, # type: ignore[type-abstract] ) - # Find which kv_cache_group the draft layers belong to - self.validate_same_kv_cache_group(kv_cache_config) - kv_cache_spec = None - for gid, group in enumerate(kv_cache_config.kv_cache_groups): - if self._draft_attn_layer_names & set(group.layer_names): - self.kv_cache_gid = gid - kv_cache_spec = group.kv_cache_spec - break - - attention_groups: dict[tuple[str, str], AttentionGroup] = {} - if kv_cache_spec is not None: - for layer_name in self._draft_attn_layer_names: - attn_backend = all_attn_layers[layer_name].get_attn_backend() - backend_key = attn_backend.full_cls_name() - if backend_key not in attention_groups: - layer_kv_cache_spec = kv_cache_spec - if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs): - layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[ - layer_name - ] - - kernel_block_size = ( - kernel_block_sizes[self.kv_cache_gid] - if kernel_block_sizes is not None - and self.kv_cache_gid < len(kernel_block_sizes) - else None - ) - attn_group = AttentionGroup( - backend=attn_backend, - layer_names=[layer_name], - kv_cache_spec=layer_kv_cache_spec, - kv_cache_group_id=self.kv_cache_gid, - ) - attn_group.create_metadata_builders( - self.vllm_config, - self.device, - kernel_block_size=kernel_block_size, - ) - attention_groups[backend_key] = attn_group - else: - attention_groups[backend_key].layer_names.append(layer_name) + self._draft_layer_to_kv_cache_gid = { + layer_name: gid + for gid, group in enumerate(kv_cache_config.kv_cache_groups) + for layer_name in group.layer_names + if layer_name in self._draft_attn_layer_names + } + missing_layers = self._draft_attn_layer_names - set( + self._draft_layer_to_kv_cache_gid + ) + assert not missing_layers, ( + "Draft attention layers are missing from KV cache groups: " + f"{sorted(missing_layers)}" + ) + self._draft_kv_cache_group_ids = sorted( + set(self._draft_layer_to_kv_cache_gid.values()) + ) + if not self.allow_multiple_draft_kv_cache_groups(): + assert len(self._draft_kv_cache_group_ids) == 1, ( + "All 
drafting layers should belong to the same kv cache group" + ) + self.kv_cache_gid = self._draft_kv_cache_group_ids[0] + + attention_groups: dict[ + tuple[int, tuple[str, str], KVCacheSpec], AttentionGroup + ] = {} + for layer_name in self._draft_attn_layer_names: + gid = self._draft_layer_to_kv_cache_gid[layer_name] + kv_cache_spec = kv_cache_config.kv_cache_groups[gid].kv_cache_spec + layer_kv_cache_spec = kv_cache_spec + if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs): + layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name] + + attn_backend = all_attn_layers[layer_name].get_attn_backend() + backend_key = attn_backend.full_cls_name() + group_key = (gid, backend_key, layer_kv_cache_spec) + if group_key not in attention_groups: + kernel_block_size = ( + kernel_block_sizes[gid] + if kernel_block_sizes is not None and gid < len(kernel_block_sizes) + else None + ) + attn_group = AttentionGroup( + backend=attn_backend, + layer_names=[layer_name], + kv_cache_spec=layer_kv_cache_spec, + kv_cache_group_id=gid, + ) + attn_group.create_metadata_builders( + self.vllm_config, + self.device, + kernel_block_size=kernel_block_size, + ) + attention_groups[group_key] = attn_group + else: + attention_groups[group_key].layer_names.append(layer_name) self.draft_attn_groups = list(attention_groups.values()) self.block_size = ( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 98b2fbf120eb..1f7e56130828 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2318,6 +2318,14 @@ def _build_attn_group_metadata( # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. spec_decode_common_attn_metadata = None + dflash_drafter = ( + self.drafter + if self.speculative_config and isinstance(self.drafter, DFlashProposer) + else None + ) + if dflash_drafter is not None: + dflash_drafter.clear_draft_block_tables() + for kv_cache_gid, kv_cache_group in enumerate(kv_cache_groups): cm = copy(cm_base) # shallow copy @@ -2333,6 +2341,11 @@ def _build_attn_group_metadata( cm.block_table_tensor = _get_block_table(kv_cache_gid) cm.slot_mapping = slot_mappings[kv_cache_gid] + if dflash_drafter is not None: + dflash_drafter.set_draft_block_table( + kv_cache_gid, cm.block_table_tensor + ) + if self.speculative_config and spec_decode_common_attn_metadata is None: if isinstance( self.drafter, (EagleProposer, DFlashProposer, Gemma4Proposer) @@ -5042,7 +5055,8 @@ def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None: if not layer_ids: dflash_config = getattr(hf_config, "dflash_config", None) if dflash_config and isinstance(dflash_config, dict): - layer_ids = dflash_config.get("target_layer_ids") + # Add 1 to convert DFlash's aux layer id semantics. 
+ layer_ids = [i + 1 for i in dflash_config.get("target_layer_ids", [])] if layer_ids and isinstance(layer_ids, (list, tuple)): return tuple(layer_ids) @@ -6675,6 +6689,118 @@ def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]: for attn_groups in self.attn_groups: yield from attn_groups + @staticmethod + def _get_kv_cache_stride_order( + attn_backend: type[AttentionBackend], + kv_cache_shape: tuple[int, ...], + ) -> tuple[int, ...]: + try: + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() + assert len(kv_cache_stride_order) == len(kv_cache_shape) + return kv_cache_stride_order + except (AttributeError, NotImplementedError): + return tuple(range(len(kv_cache_shape))) + + @classmethod + def _get_standard_kv_cache_orders( + cls, + attn_backend: type[AttentionBackend], + kv_cache_shape: tuple[int, ...], + num_blocks: int, + ) -> tuple[tuple[int, ...], tuple[str, ...] | None, tuple[str, ...] | None]: + stride_order = cls._get_kv_cache_stride_order(attn_backend, kv_cache_shape) + if len(kv_cache_shape) != 5: + return stride_order, None, None + if kv_cache_shape[:2] == (2, num_blocks): + public_order = ("kv", "block", "token", "head", "dim") + elif kv_cache_shape[:2] == (num_blocks, 2): + public_order = ("block", "kv", "token", "head", "dim") + else: + return stride_order, None, None + return stride_order, public_order, tuple(public_order[i] for i in stride_order) + + @staticmethod + def _view_kv_cache_with_physical_order( + raw_tensor: torch.Tensor, + kv_cache_shape: tuple[int, ...], + public_order: tuple[str, ...], + physical_order: tuple[str, ...], + ) -> torch.Tensor: + # Backends can expose the same standard KV cache with different public + # shapes or physical orders. When they share one raw tensor, keep each + # backend's public shape but map it to the shared physical layout. + dim_sizes = dict(zip(public_order, kv_cache_shape)) + physical_strides: dict[str, int] = {} + stride = 1 + for dim in reversed(physical_order): + physical_strides[dim] = stride + stride *= dim_sizes[dim] + public_strides = tuple(physical_strides[dim] for dim in public_order) + return torch.as_strided( + raw_tensor, + size=kv_cache_shape, + stride=public_strides, + ) + + @staticmethod + def _get_attention_kv_cache_shape( + attn_backend: type[AttentionBackend], + kv_cache_spec: AttentionSpec, + num_blocks: int, + kernel_block_size: int, + cache_dtype_str: str, + ) -> tuple[int, ...]: + # For MLA with compression, storage_block_size != block_size. 
+ if kv_cache_spec.storage_block_size != kv_cache_spec.block_size: + shape_block_size = kv_cache_spec.storage_block_size + else: + shape_block_size = kernel_block_size + return attn_backend.get_kv_cache_shape( + num_blocks, + shape_block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + cache_dtype_str=cache_dtype_str, + ) + + def _get_raw_tensor_physical_orders( + self, + kv_cache_raw_tensors: dict[str, torch.Tensor], + kernel_block_sizes: list[int], + ) -> dict[int, set[tuple[str, ...]]]: + raw_tensor_physical_orders: dict[int, set[tuple[str, ...]]] = defaultdict(set) + for group in self._kv_cache_spec_attn_group_iterator(): + kv_cache_spec = group.kv_cache_spec + if group.kv_cache_group_id == len(kernel_block_sizes) or not isinstance( + kv_cache_spec, AttentionSpec + ): + continue + kernel_block_size = kernel_block_sizes[group.kv_cache_group_id] + for layer_name in group.layer_names: + if ( + layer_name in self.runner_only_attn_layers + or layer_name not in kv_cache_raw_tensors + ): + continue + raw_tensor = kv_cache_raw_tensors[layer_name] + num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes + num_blocks *= kv_cache_spec.block_size // kernel_block_size + kv_cache_shape = self._get_attention_kv_cache_shape( + group.backend, + kv_cache_spec, + num_blocks, + kernel_block_size, + self.cache_config.cache_dtype, + ) + _, _, physical_order = self._get_standard_kv_cache_orders( + group.backend, + kv_cache_shape, + num_blocks, + ) + if physical_order is not None: + raw_tensor_physical_orders[id(raw_tensor)].add(physical_order) + return raw_tensor_physical_orders + def _reshape_kv_cache_tensors( self, kv_cache_raw_tensors: dict[str, torch.Tensor], @@ -6693,6 +6819,10 @@ def _reshape_kv_cache_tensors( """ kv_caches: dict[str, torch.Tensor] = {} has_attn, has_mamba = False, False + raw_tensor_physical_orders = self._get_raw_tensor_physical_orders( + kv_cache_raw_tensors, kernel_block_sizes + ) + for group in self._kv_cache_spec_attn_group_iterator(): kv_cache_spec = group.kv_cache_spec attn_backend = group.backend @@ -6713,25 +6843,45 @@ def _reshape_kv_cache_tensors( ) kernel_num_blocks = num_blocks * num_blocks_per_kv_block - # For MLA with compression, storage_block_size != block_size - if kv_cache_spec.storage_block_size != kv_cache_spec.block_size: - shape_block_size = kv_cache_spec.storage_block_size - else: - shape_block_size = kernel_block_size - - kv_cache_shape = attn_backend.get_kv_cache_shape( + kv_cache_shape = self._get_attention_kv_cache_shape( + attn_backend, + kv_cache_spec, + kernel_num_blocks, + kernel_block_size, + self.cache_config.cache_dtype, + ) + ( + kv_cache_stride_order, + public_order, + physical_order, + ) = self._get_standard_kv_cache_orders( + attn_backend, + kv_cache_shape, kernel_num_blocks, - shape_block_size, - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size, - cache_dtype_str=self.cache_config.cache_dtype, ) dtype = kv_cache_spec.dtype - try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() - assert len(kv_cache_stride_order) == len(kv_cache_shape) - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(kv_cache_shape))) + shared_orders = raw_tensor_physical_orders[id(raw_tensor)] + block_major_orders = ( + order for order in shared_orders if order[:2] == ("block", "kv") + ) + shared_physical_order = next( + iter(sorted(block_major_orders)), None + ) or next(iter(sorted(shared_orders)), None) + raw_tensor = raw_tensor.view(dtype) + if ( + public_order is not None + and 
physical_order != shared_physical_order + and shared_physical_order is not None + and kv_cache_spec.page_size_padded is None + ): + kv_caches[layer_name] = self._view_kv_cache_with_physical_order( + raw_tensor, + kv_cache_shape, + public_order, + shared_physical_order, + ) + continue + # The allocation respects the backend-defined stride order # to ensure the semantic remains consistent for each # backend. We first obtain the generic kv cache shape and @@ -6746,16 +6896,15 @@ def _reshape_kv_cache_tensors( for i in range(len(kv_cache_stride_order)) ] - raw_tensor = kv_cache_raw_tensors[layer_name].view(dtype) if kv_cache_spec.page_size_padded is not None: # Use strided view to handle page_size_bytes that - # include padding. This follows - # the same pattern as MambaSpec handling below. + # include padding. This follows the same pattern as + # MambaSpec handling below. # NOTE: This assumes kv_cache_shape[0] == num_blocks # (i.e. the first physical dimension is the block # index), which holds for MLA backends but NOT for - # standard attention backends whose shape starts with - # a K/V dimension of size 2. + # standard attention backends whose shape starts + # with a K/V dimension of size 2. dtype_size = get_dtype_size(dtype) page_stride = kv_cache_spec.page_size_bytes // dtype_size strides = list(torch.empty(kv_cache_shape).stride())