28 changes: 28 additions & 0 deletions tests/v1/core/test_kv_sharing.py
@@ -1,9 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from types import SimpleNamespace

import pytest
import torch

from vllm.v1.core.kv_cache_utils import get_kv_cache_config_from_groups
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheGroupSpec
from vllm.v1.worker.utils import add_kv_sharing_layers_to_kv_cache_groups

@@ -103,3 +106,28 @@ def test_initialize_kv_cache_for_kv_sharing_no_attn_groups():
assert len(kv_cache_groups) == 2
assert kv_cache_groups[0].layer_names == ["model.layers.0", "model.layers.2"]
assert kv_cache_groups[1].layer_names == ["model.layers.1", "model.layers.3"]


def test_dflash_draft_kv_groups_keep_hybrid_tensor_sharing():
spec = new_kv_cache_spec()
num_blocks = 8
vllm_config = SimpleNamespace(
speculative_config=SimpleNamespace(method="dflash"),
cache_config=SimpleNamespace(num_gpu_blocks_override=None),
)
kv_cache_groups = [
KVCacheGroupSpec(["model.layers.0", "model.layers.1"], spec),
KVCacheGroupSpec(["model.layers.2", "model.layers.3"], spec),
]

kv_cache_config = get_kv_cache_config_from_groups(
vllm_config=vllm_config,
kv_cache_groups=kv_cache_groups,
available_memory=spec.page_size_bytes * 2 * num_blocks,
)

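# Each raw tensor should stay shared by one layer from each KV cache group,
# i.e. the hybrid tensor-sharing layout is preserved when dflash is enabled.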
assert kv_cache_config.num_blocks == num_blocks
assert [tensor.shared_by for tensor in kv_cache_config.kv_cache_tensors] == [
["model.layers.0", "model.layers.2"],
["model.layers.1", "model.layers.3"],
]
203 changes: 203 additions & 0 deletions tests/v1/spec_decode/test_dflash_swa.py
@@ -0,0 +1,203 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from types import SimpleNamespace

import torch

from vllm.config import SpeculativeConfig
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.models.qwen3_dflash import DFlashAttention
from vllm.transformers_utils.configs.speculators import SpeculatorsConfig
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
SlidingWindowSpec,
)
from vllm.v1.spec_decode.dflash import DFlashProposer


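# Minimal stand-ins for the attention metadata builder and attention group
# that DFlashProposer interacts with in these tests.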
class _FakeBuilder:
def __init__(
self, kv_cache_spec=None, layer_names=None, vllm_config=None, device=None
):
self.kv_cache_spec = kv_cache_spec
self.layer_names = layer_names

def build_for_drafting(self, common_attn_metadata, draft_index):
return SimpleNamespace(
causal=common_attn_metadata.causal,
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping,
)


class _FakeAttentionGroup:
def __init__(self, layer_names, kv_cache_group_id=0):
self.layer_names = layer_names
self.kv_cache_group_id = kv_cache_group_id
self._builder = _FakeBuilder()

def get_metadata_builder(self):
return self._builder


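# Base metadata for a single two-token request with causal=False; individual
# layers may override causality in the tests below.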
def _make_cad(block_table, slot_mapping) -> CommonAttentionMetadata:
return CommonAttentionMetadata(
query_start_loc=torch.tensor([0, 2], dtype=torch.int32),
query_start_loc_cpu=torch.tensor([0, 2], dtype=torch.int32),
seq_lens=torch.tensor([2], dtype=torch.int32),
num_reqs=1,
num_actual_tokens=2,
max_query_len=2,
max_seq_len=2,
block_table_tensor=block_table,
slot_mapping=slot_mapping,
causal=False,
)


def test_dflash_speculators_preserves_swa_config():
layer_types = [
"sliding_attention",
"sliding_attention",
"full_attention",
]
config = {
"speculators_model_type": "dflash",
"transformer_layer_config": {
"num_hidden_layers": len(layer_types),
"sliding_window": None,
},
"draft_vocab_size": 100,
"target_hidden_size": 64,
"aux_hidden_state_layer_ids": [0, 1, 2],
"mask_token_id": 99,
"layer_types": layer_types,
"use_sliding_window": True,
"sliding_window": 2048,
"max_window_layers": len(layer_types),
}

hf_config = SpeculatorsConfig.extract_transformers_pre_trained_config(config)

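# Sliding-window settings pass through unchanged; the 0-indexed aux layer ids
# are shifted to the 1-indexed eagle convention, while dflash_config keeps the
# original checkpoint layer ids.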
assert hf_config["layer_types"] == layer_types
assert hf_config["use_sliding_window"] is True
assert hf_config["sliding_window"] == 2048
assert hf_config["max_window_layers"] == len(layer_types)
assert hf_config["eagle_aux_hidden_state_layer_ids"] == [1, 2, 3]
assert hf_config["dflash_config"]["target_layer_ids"] == [0, 1, 2]


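# Build a SpeculativeConfig without running __init__ so compute_hash can be
# exercised with only the fields it reads.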
def _compute_dflash_hash(hf_config: SimpleNamespace) -> str:
config = object.__new__(SpeculativeConfig)
config.method = "dflash"
config.draft_model_config = SimpleNamespace(hf_config=hf_config)
return config.compute_hash()


def test_dflash_compile_hash_uses_checkpoint_layer_id_semantics():
dflash_hash = _compute_dflash_hash(
SimpleNamespace(dflash_config={"target_layer_ids": [0, 2]})
)
shifted_aux_hash = _compute_dflash_hash(
SimpleNamespace(eagle_aux_hidden_state_layer_ids=[1, 3])
)
different_hash = _compute_dflash_hash(
SimpleNamespace(dflash_config={"target_layer_ids": [0, 3]})
)

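# Checkpoint target_layer_ids [0, 2] correspond to eagle aux layer ids [1, 3],
# so both spellings must hash identically, while a different layer set must not.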
assert dflash_hash == shifted_aux_hash
assert dflash_hash != different_hash


def test_dflash_swa_layers_use_full_kv_cache_spec(monkeypatch):
sliding_spec = SlidingWindowSpec(
block_size=16,
num_kv_heads=1,
head_size=8,
dtype=torch.float16,
sliding_window=4,
)
monkeypatch.setattr(
Attention,
"get_kv_cache_spec",
lambda self, vllm_config: sliding_spec,
)

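# Even though the wrapped Attention layer reports a sliding-window spec, the
# draft layer should advertise a full-attention spec with no window.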
spec = DFlashAttention.get_kv_cache_spec(
object.__new__(DFlashAttention), SimpleNamespace()
)

assert isinstance(spec, FullAttentionSpec)
assert spec.block_size == sliding_spec.block_size
assert spec.num_kv_heads == sliding_spec.num_kv_heads
assert spec.head_size == sliding_spec.head_size
assert spec.sliding_window is None


def test_dflash_swa_layers_use_causal_metadata():
proposer = object.__new__(DFlashProposer)
proposer.model = SimpleNamespace(sliding_attention_layer_names={"layer.sw"})
proposer.draft_attn_groups = [_FakeAttentionGroup(["layer.sw", "layer.full"])]
proposer.kv_cache_gid = 0
proposer._draft_kv_cache_group_ids = [0]
proposer._draft_layer_to_kv_cache_gid = {
"layer.sw": 0,
"layer.full": 0,
}
proposer._draft_block_tables = {}
cad = _make_cad(
torch.empty(1, 1, dtype=torch.int32),
torch.empty(2, dtype=torch.int64),
)
proposer._slot_mapping_buffers_by_gid = {0: (cad.slot_mapping, cad.slot_mapping)}

per_group, per_layer = DFlashProposer.build_per_group_and_layer_attn_metadata(
proposer, cad
)

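# The per-group metadata keeps the base non-causal flag; only the
# sliding-window layer is forced to causal attention.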
assert per_group[0].causal is False
assert per_layer["layer.sw"].causal is True
assert per_layer["layer.full"].causal is False


def test_dflash_metadata_uses_per_kv_group_slot_mapping():
proposer = object.__new__(DFlashProposer)
proposer.model = SimpleNamespace(sliding_attention_layer_names={"layer.sw"})
proposer.draft_attn_groups = [
_FakeAttentionGroup(["layer.full"], kv_cache_group_id=1),
_FakeAttentionGroup(["layer.sw"], kv_cache_group_id=2),
]
proposer.kv_cache_gid = 1
proposer._draft_kv_cache_group_ids = [1, 2]
proposer._draft_layer_to_kv_cache_gid = {
"layer.full": 1,
"layer.sw": 2,
}

full_block_table = torch.tensor([[11, 12]], dtype=torch.int32)
sw_block_table = torch.tensor([[21, 22]], dtype=torch.int32)
full_slots = torch.tensor([111, 112], dtype=torch.int64)
sw_slots = torch.tensor([211, 212], dtype=torch.int64)

base_cad = _make_cad(full_block_table, full_slots)
proposer._draft_block_tables = {
1: full_block_table,
2: sw_block_table,
}
proposer._slot_mapping_buffers_by_gid = {
1: (full_slots, full_slots),
2: (sw_slots, sw_slots),
}

_, per_layer = DFlashProposer.build_per_group_and_layer_attn_metadata(
proposer, base_cad
)

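# Each layer should pick up the block table and slot mapping registered for
# its own KV cache group, with causality again driven by whether the layer is
# sliding-window.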
assert per_layer["layer.full"].block_table_tensor is full_block_table
torch.testing.assert_close(per_layer["layer.full"].slot_mapping, full_slots)
assert per_layer["layer.full"].causal is False
assert per_layer["layer.sw"].block_table_tensor is sw_block_table
torch.testing.assert_close(per_layer["layer.sw"].slot_mapping, sw_slots)
assert per_layer["layer.sw"].causal is True
35 changes: 35 additions & 0 deletions tests/v1/worker/test_gpu_model_runner.py
@@ -716,6 +716,41 @@ def rnd_stride_order(
assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)


@pytest.mark.parametrize(
"physical_order",
[
("block", "kv", "token", "head", "dim"),
("block", "kv", "head", "token", "dim"),
],
)
def test_kv_major_cache_can_share_block_major_raw_tensor(physical_order):
kv_cache_shape = (2, 3, 4, 2, 8)
num_kv, num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
block_elems = block_size * num_kv_heads * head_size
raw_tensor = torch.arange(num_kv * num_blocks * block_elems)
public_order = ("kv", "block", "token", "head", "dim")
dim_sizes = dict(zip(public_order, kv_cache_shape))
expected_strides = {}
stride = 1
for dim in reversed(physical_order):
expected_strides[dim] = stride
stride *= dim_sizes[dim]

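# Viewing the raw buffer in the public (kv, block, token, head, dim) order
# should only permute strides to match the physical layout; values stay at
# their raw offsets.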
kv_cache = GPUModelRunner._view_kv_cache_with_physical_order(
raw_tensor,
kv_cache_shape,
public_order,
physical_order,
)

assert kv_cache.shape == kv_cache_shape
assert kv_cache.stride() == tuple(expected_strides[dim] for dim in public_order)
assert kv_cache[0, 0, 0, 0, 0] == raw_tensor[0]
assert kv_cache[1, 0, 0, 0, 0] == raw_tensor[block_elems]
assert kv_cache[0, 1, 0, 0, 0] == raw_tensor[2 * block_elems]
assert kv_cache[1, 1, 0, 0, 0] == raw_tensor[3 * block_elems]


def test_update_config(model_runner):
# Simple update
model_runner.update_config({"load_config": {"load_format": "dummy"}})
12 changes: 12 additions & 0 deletions vllm/config/speculative.py
@@ -289,6 +289,14 @@ def compute_hash(self) -> str:
"eagle_aux_hidden_state_layer_ids",
None,
)
if not layer_ids:
dflash_config = getattr(
self.draft_model_config.hf_config, "dflash_config", None
)
if dflash_config and isinstance(dflash_config, dict):
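# dflash checkpoints store 0-indexed target layer ids; shift by one to
# match the eagle aux layer id convention hashed above.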
layer_ids = [
i + 1 for i in dflash_config.get("target_layer_ids", [])
]
if layer_ids is not None:
# Convert to tuple to make it hashable
factors.append(tuple(layer_ids))
@@ -1047,6 +1055,10 @@ def use_gemma4_mtp(self) -> bool:
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "mtp", "dflash")

def requires_eagle_cache_drop(self) -> bool:
"""Whether prefix cache hits must drop one block for hidden states."""
return self.use_eagle() and not self.use_dflash()

def use_dflash(self) -> bool:
return self.method == "dflash"
