109 changes: 109 additions & 0 deletions tests/v1/core/test_kv_cache_utils.py
@@ -3,6 +3,7 @@
import hashlib
import importlib
from collections.abc import Callable
from types import SimpleNamespace
from typing import Any

import pytest
@@ -177,6 +178,19 @@ def new_mamba_spec(
)


def make_dflash_test_config(target_num_layers=2):
return SimpleNamespace(
speculative_config=SimpleNamespace(method="dflash"),
model_config=SimpleNamespace(
max_model_len=16,
get_num_layers=lambda parallel_config: target_num_layers,
),
parallel_config=SimpleNamespace(),
scheduler_config=SimpleNamespace(disable_hybrid_kv_cache_manager=False),
cache_config=SimpleNamespace(num_gpu_blocks_override=None),
)


@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_none_hash(monkeypatch, hash_fn):
import vllm.v1.core.kv_cache_utils
@@ -1768,6 +1782,101 @@ def test_get_kv_cache_config_one_worker():
)


def test_dflash_isolated_specs_are_partitioned_before_page_size_unify():
vllm_config = make_dflash_test_config(target_num_layers=2)
kv_cache_specs = {
"model.layers.0.self_attn.attn": new_kv_cache_spec(head_size=64),
"model.layers.1.self_attn.attn": new_sliding_window_spec(head_size=32),
"model.layers.2.self_attn.attn": new_kv_cache_spec(
dtype=torch.bfloat16,
head_size=192,
),
}

groups = kv_cache_utils.get_kv_cache_groups(vllm_config, kv_cache_specs)

assert groups == [
KVCacheGroupSpec(
["model.layers.0.self_attn.attn"],
new_kv_cache_spec(head_size=64),
),
KVCacheGroupSpec(
["model.layers.1.self_attn.attn"],
new_sliding_window_spec(block_size=32, head_size=32),
),
KVCacheGroupSpec(
["model.layers.2.self_attn.attn"],
new_kv_cache_spec(dtype=torch.bfloat16, head_size=192),
),
]


def test_dflash_heterogeneous_page_size_allocator_keeps_isolated_pool():
vllm_config = make_dflash_test_config(target_num_layers=2)
target_page_size = new_kv_cache_spec(head_size=64).page_size_bytes
draft_spec = new_kv_cache_spec(dtype=torch.bfloat16, head_size=192)
draft_page_size = draft_spec.page_size_bytes
groups = [
KVCacheGroupSpec(
["model.layers.0.self_attn.attn"],
new_kv_cache_spec(head_size=64),
),
KVCacheGroupSpec(
["model.layers.1.self_attn.attn"],
new_sliding_window_spec(block_size=32, head_size=32),
),
KVCacheGroupSpec(
["model.layers.2.self_attn.attn"],
draft_spec,
),
]
num_blocks = 10
available_memory = (target_page_size + draft_page_size) * num_blocks

kv_cache_config = kv_cache_utils.get_kv_cache_config_from_groups(
vllm_config, groups, available_memory
)

assert kv_cache_config.num_blocks == num_blocks
assert kv_cache_config.kv_cache_tensors == [
KVCacheTensor(
size=target_page_size * num_blocks,
shared_by=[
"model.layers.0.self_attn.attn",
"model.layers.1.self_attn.attn",
],
),
KVCacheTensor(
size=draft_page_size * num_blocks,
shared_by=["model.layers.2.self_attn.attn"],
),
]
assert sum(t.size for t in kv_cache_config.kv_cache_tensors) == available_memory


def test_non_dflash_grouping_still_uses_existing_unify_path():
model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config)
kv_cache_specs = {
"model.layers.0.self_attn.attn": new_kv_cache_spec(head_size=64),
"model.layers.1.self_attn.attn": new_sliding_window_spec(head_size=32),
}

assert kv_cache_utils._partition_dflash_isolated_specs(
vllm_config, kv_cache_specs
) == (kv_cache_specs, {})
assert kv_cache_utils.get_kv_cache_groups(vllm_config, kv_cache_specs) == [
KVCacheGroupSpec(
["model.layers.0.self_attn.attn"],
new_kv_cache_spec(head_size=64),
),
KVCacheGroupSpec(
["model.layers.1.self_attn.attn"],
new_sliding_window_spec(block_size=32, head_size=32),
),
]


def test_get_kv_cache_configs_attention_free():
kv_cache_specs: dict[str, KVCacheSpec] = {}
vllm_config = VllmConfig(model_config=ModelConfig(max_model_len=16))
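The partitioning helper these tests exercise is not included in the diff itself. The sketch below is a reconstruction of its contract inferred purely from the assertions above: the function name matches the `kv_cache_utils._partition_dflash_isolated_specs` reference in the last test, but the layer-name parsing and the split heuristic are assumptions, not the merged implementation.

```python
# Hypothetical sketch, inferred from the tests above. The real
# _partition_dflash_isolated_specs in vllm.v1.core.kv_cache_utils is not
# part of this diff; the name-parsing heuristic here is an assumption.
def _partition_dflash_isolated_specs(vllm_config, kv_cache_specs):
    """Split layer specs into (target, isolated draft) pools for dflash."""
    spec_cfg = getattr(vllm_config, "speculative_config", None)
    if spec_cfg is None or getattr(spec_cfg, "method", None) != "dflash":
        # Non-dflash models keep the existing unified-pool path.
        return kv_cache_specs, {}
    num_target_layers = vllm_config.model_config.get_num_layers(
        vllm_config.parallel_config
    )
    target, isolated = {}, {}
    for layer_name, spec in kv_cache_specs.items():
        # "model.layers.2.self_attn.attn" -> 2
        layer_idx = int(layer_name.split("model.layers.")[1].split(".")[0])
        if layer_idx < num_target_layers:
            target[layer_name] = spec
        else:
            isolated[layer_name] = spec
    return target, isolated
```

Under this reading, the first test expects layers 0 and 1 (the two target layers) to go through page-size unification together, while layer 2 (the dflash draft layer) keeps its bf16, head-size-192 spec in a separate group and, per the second test, its own backing tensor pool.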
12 changes: 11 additions & 1 deletion vllm/model_executor/models/qwen3_dflash.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from dataclasses import replace

import torch
import torch.nn.functional as F
@@ -33,6 +34,7 @@
)
from vllm.multimodal.inputs import NestedTensors
from vllm.transformers_utils.config import set_default_rope_theta
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import AttentionType

from .qwen2 import Qwen2MLP as Qwen3MLP
@@ -109,12 +111,20 @@ def __init__(
             max_position=max_position,
             rope_parameters=rope_parameters,
         )
+        # DFlash draft layers use an independent KV cache pool. Keep the
+        # target's block/sliding-window settings, but do not inherit a
+        # quantized target KV dtype into the BF16 draft attention path.
+        draft_cache_config = cache_config
+        if draft_cache_config is not None and is_quantized_kv_cache(
+            draft_cache_config.cache_dtype
+        ):
+            draft_cache_config = replace(draft_cache_config, cache_dtype="auto")
         self.attn = Attention(
             self.num_heads,
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config,
+            cache_config=draft_cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
             attn_type=attn_type,
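The `replace` call above is the standard `dataclasses.replace`: it returns a shallow copy with the named field overridden, so the target model's shared cache config is never mutated. A self-contained illustration of that design choice, where `CacheConfig` is a stand-in dataclass rather than vLLM's real `vllm.config.CacheConfig`:

```python
# Illustrative only: shows why dataclasses.replace is used above.
# CacheConfig here is a stand-in, not the real vllm.config.CacheConfig.
from dataclasses import dataclass, replace


@dataclass
class CacheConfig:
    block_size: int = 16
    cache_dtype: str = "fp8"  # target model runs with a quantized KV cache


target_cfg = CacheConfig()
draft_cfg = replace(target_cfg, cache_dtype="auto")

assert target_cfg.cache_dtype == "fp8"    # target config is left untouched
assert draft_cfg.cache_dtype == "auto"    # draft attention stays unquantized
assert draft_cfg.block_size == target_cfg.block_size  # other fields carry over
```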
6 changes: 4 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py
@@ -61,7 +61,7 @@
 from vllm.v1.attention.backends.utils import (
     get_kv_cache_layout,
 )
-from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.kv_cache_interface import AttentionSpec, KVQuantMode
 
 logger = init_logger(__name__)

@@ -449,7 +449,9 @@ def schedule(
         batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
     ):
         cache_dtype = self.cache_config.cache_dtype
-        if is_quantized_kv_cache(cache_dtype):
+        if self.kv_cache_spec.kv_quant_mode == KVQuantMode.NONE:
+            qkv_dtype = self.kv_cache_dtype
+        elif is_quantized_kv_cache(cache_dtype):
             qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                 cache_dtype
             )
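The net effect of the new branch is a precedence rule: a per-layer spec that opted out of KV quantization (`kv_quant_mode == KVQuantMode.NONE`) now overrides the global `cache_dtype`, which previously decided alone. A condensed restatement of that rule; the enum, the helper, and the fp8 fallback below are simplified stand-ins, not the vLLM originals:

```python
from enum import Enum

import torch


class KVQuantMode(Enum):
    NONE = "none"
    FP8 = "fp8"


def is_quantized_kv_cache(cache_dtype: str) -> bool:
    # Stand-in: treats anything other than "auto" as a quantized cache.
    return cache_dtype != "auto"


def select_qkv_dtype(kv_quant_mode, cache_dtype, kv_cache_dtype, model_dtype):
    if kv_quant_mode == KVQuantMode.NONE:
        # Per-layer opt-out wins: e.g. a dflash draft layer keeps its
        # native dtype even when the global cache_dtype requests fp8.
        return kv_cache_dtype
    if is_quantized_kv_cache(cache_dtype):
        # Stand-in for FlashAttentionBackend.get_fp8_dtype_for_flashattn.
        return torch.float8_e4m3fn
    return model_dtype
```

For example, `select_qkv_dtype(KVQuantMode.NONE, "fp8", torch.bfloat16, torch.bfloat16)` returns `torch.bfloat16` for a draft layer, while a target layer with `KVQuantMode.FP8` under the same global `cache_dtype` still gets the fp8 dtype.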