From cfb8f711681b115d4e070e18946b7c9d75183914 Mon Sep 17 00:00:00 2001
From: noonghunna <10742901+noonghunna@users.noreply.github.com>
Date: Fri, 8 May 2026 17:42:55 +0000
Subject: [PATCH 1/2] [Spec Decode] Allow DFlash drafter to coexist with quantized target KV via independent KV groups + dtype override
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Refs: vllm-project/vllm#41559

The original framing in #41559 (filed against v0.20.0), that DFlash + KV-quant fails because backend allowlists exclude KV-quant when causal=False, no longer fully matches current main. FLASH_ATTN already gates dynamically via flash_attn_supports_fp8(); FLEX_ATTENTION rejects all quantized KV at implementation construction time; TRITON_ATTN remains causal-only. An allowlist-only patch is therefore no longer the right shape.

This patch addresses the actual remaining blocker for the *common case* on Ampere consumer hardware: a BF16 DFlash drafter coexisting with a quantized target KV cache (e.g. int8_per_token_head). The fix is to stop forcing target and drafter through a single page-size unification pass: the drafter has its own independent KV pool by design, so it can keep its existing BF16 page geometry instead of being reconciled against the target's quantized geometry.

Three layers:

1. vllm/v1/core/kv_cache_utils.py: partition DFlash draft KV specs before page-size unification (see the sketch after the test list below). Target specs go through the existing unify path; drafter specs form independent KV groups with their own page_size_bytes. The allocator is extended to size isolated DFlash tensors by their own page size rather than via get_uniform_page_size().

2. vllm/model_executor/models/qwen3_dflash.py: override the DFlash drafter's cache_dtype to "auto" when the engine's global cache_dtype is quantized. The drafter's KV pool is independent after (1), so it does not need to inherit the target's dtype.

3. vllm/v1/attention/backends/flash_attn.py: in the metadata scheduler, when the per-spec kv_quant_mode is NONE, use the spec's local kv_cache_dtype rather than the global cache_config.cache_dtype.

This patch builds on (and does not duplicate) the already-merged PR #39930 (independent drafter attention backend selection). The DFlash drafter naturally selects FLASH_ATTN as its backend with non-causal support; the target stays on TRITON_ATTN; each backend sees the KV dtype it can handle.

## Why this is not duplicating an existing PR

- #41559 itself has 0 cross-referenced PRs.
- #42069 (mikeumus, OPEN/BLOCKED) addresses TRITON_ATTN propagating to the drafter on Gemma 4: a different layer of the bug, and complementary to this one.
- #40425 addresses quantized DRAFTER WEIGHTS (NVFP4 draft model), not a quantized KV cache.
- #41703 / #40898 / #41971 / #39995 are different DFlash bugs.

## Testing

tests/v1/core/test_kv_cache_utils.py adds three test cases:

- DFlash draft specs are partitioned before page-size unification.
- Heterogeneous target/draft page-size allocation produces correct KVCacheTensor layouts.
- Non-DFlash regression: when no DFlash isolated groups exist, the existing unify behavior is unchanged byte-for-byte.
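For illustration, a minimal, self-contained sketch of the partition-before-unify idea from layer (1). This is not the patch's code: the real logic lives in _partition_dflash_isolated_specs() and keys off vllm_config.speculative_config; the ToySpec type, the partition_dflash_specs() name, and the explicit target_num_layers argument below are simplifications for readability.

```python
import re
from dataclasses import dataclass


# Toy stand-in for KVCacheSpec; the real spec also carries dtype, block_size,
# head_size, and so on.
@dataclass
class ToySpec:
    page_size_bytes: int


# Same idea as the patch's _LAYER_INDEX_RE: extract the layer index from
# names like "model.layers.12.self_attn.attn".
LAYER_INDEX_RE = re.compile(r"(?:^|[.])layers[.](\d+)(?:[.]|$)")


def partition_dflash_specs(
    specs: dict[str, ToySpec], target_num_layers: int
) -> tuple[dict[str, ToySpec], dict[str, ToySpec]]:
    """Split specs into (shared target specs, isolated DFlash draft specs).

    DFlash draft layers are numbered after the target's layers, so any layer
    whose index is >= target_num_layers belongs to the drafter. Names that do
    not match the pattern stay with the target specs and go through the
    existing unify path.
    """
    shared: dict[str, ToySpec] = {}
    isolated: dict[str, ToySpec] = {}
    for name, spec in specs.items():
        match = LAYER_INDEX_RE.search(name)
        if match is not None and int(match.group(1)) >= target_num_layers:
            isolated[name] = spec
        else:
            shared[name] = spec
    return shared, isolated


specs = {
    "model.layers.0.self_attn.attn": ToySpec(page_size_bytes=4096),   # target
    "model.layers.1.self_attn.attn": ToySpec(page_size_bytes=4096),   # target
    "model.layers.2.self_attn.attn": ToySpec(page_size_bytes=12288),  # drafter
}
shared, isolated = partition_dflash_specs(specs, target_num_layers=2)
assert set(isolated) == {"model.layers.2.self_attn.attn"}
# Only `shared` is fed to the page-size unify pass; each isolated layer keeps
# its own page geometry and gets its own KVCacheTensor in the allocator.
```

The test cases above exercise exactly this split, plus the downstream allocator behavior, on the real spec types.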
Local end-to-end validation on 2x RTX 3090 (sm_86), Gemma 4 31B + z-lab DFlash drafter, target INT8 PTH KV via PR #40391 (vendored locally), 65K max_model_len: | Metric | This patch | gemma-dflash.yml (32K bf16 baseline) | |-------------------------|-----------:|-------------------------------------:| | Boot HEALTHY | yes | yes | | Paris smoke output | clean | clean | | Narrative TPS | 95.89 | 95-104 | | Code TPS | 168.09 | 168-177 | | AL on long-ctx code | 5.0-5.3 | 4.94 | | NIAH at 32K prompt | PASS | n/a | | KV pool tokens | 149,345 | ~38,000 | | Max ctx validated | 65K | 32K | | VRAM | 23.85 GB | 22.7 GB | The DFlash long-context code-optimal cell of the matrix — previously unreachable on Ampere — is now reachable. ## AI assistance This patch was developed with AI assistance (Claude + Codex) using a Claude Code -> MCP -> Codex collaboration. The human submitter reviewed every changed line, validated the patch end-to-end on the rig described above, and is accountable for the change. Co-authored-by: Codex Co-authored-by: Claude Signed-off-by: noonghunna <10742901+noonghunna@users.noreply.github.com> --- tests/v1/core/test_kv_cache_utils.py | 109 ++++++++++ vllm/model_executor/models/qwen3_dflash.py | 12 +- vllm/v1/attention/backends/flash_attn.py | 6 +- vllm/v1/core/kv_cache_utils.py | 241 ++++++++++++++++++--- 4 files changed, 333 insertions(+), 35 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 985b97c69ca4..084555468d51 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -3,6 +3,7 @@ import hashlib import importlib from collections.abc import Callable +from types import SimpleNamespace from typing import Any import pytest @@ -177,6 +178,19 @@ def new_mamba_spec( ) +def make_dflash_test_config(target_num_layers=2): + return SimpleNamespace( + speculative_config=SimpleNamespace(method="dflash"), + model_config=SimpleNamespace( + max_model_len=16, + get_num_layers=lambda parallel_config: target_num_layers, + ), + parallel_config=SimpleNamespace(), + scheduler_config=SimpleNamespace(disable_hybrid_kv_cache_manager=False), + cache_config=SimpleNamespace(num_gpu_blocks_override=None), + ) + + @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_none_hash(monkeypatch, hash_fn): import vllm.v1.core.kv_cache_utils @@ -1768,6 +1782,101 @@ def test_get_kv_cache_config_one_worker(): ) +def test_dflash_isolated_specs_are_partitioned_before_page_size_unify(): + vllm_config = make_dflash_test_config(target_num_layers=2) + kv_cache_specs = { + "model.layers.0.self_attn.attn": new_kv_cache_spec(head_size=64), + "model.layers.1.self_attn.attn": new_sliding_window_spec(head_size=32), + "model.layers.2.self_attn.attn": new_kv_cache_spec( + dtype=torch.bfloat16, + head_size=192, + ), + } + + groups = kv_cache_utils.get_kv_cache_groups(vllm_config, kv_cache_specs) + + assert groups == [ + KVCacheGroupSpec( + ["model.layers.0.self_attn.attn"], + new_kv_cache_spec(head_size=64), + ), + KVCacheGroupSpec( + ["model.layers.1.self_attn.attn"], + new_sliding_window_spec(block_size=32, head_size=32), + ), + KVCacheGroupSpec( + ["model.layers.2.self_attn.attn"], + new_kv_cache_spec(dtype=torch.bfloat16, head_size=192), + ), + ] + + +def test_dflash_heterogeneous_page_size_allocator_keeps_isolated_pool(): + vllm_config = make_dflash_test_config(target_num_layers=2) + target_page_size = new_kv_cache_spec(head_size=64).page_size_bytes + draft_spec = new_kv_cache_spec(dtype=torch.bfloat16, 
head_size=192) + draft_page_size = draft_spec.page_size_bytes + groups = [ + KVCacheGroupSpec( + ["model.layers.0.self_attn.attn"], + new_kv_cache_spec(head_size=64), + ), + KVCacheGroupSpec( + ["model.layers.1.self_attn.attn"], + new_sliding_window_spec(block_size=32, head_size=32), + ), + KVCacheGroupSpec( + ["model.layers.2.self_attn.attn"], + draft_spec, + ), + ] + num_blocks = 10 + available_memory = (target_page_size + draft_page_size) * num_blocks + + kv_cache_config = kv_cache_utils.get_kv_cache_config_from_groups( + vllm_config, groups, available_memory + ) + + assert kv_cache_config.num_blocks == num_blocks + assert kv_cache_config.kv_cache_tensors == [ + KVCacheTensor( + size=target_page_size * num_blocks, + shared_by=[ + "model.layers.0.self_attn.attn", + "model.layers.1.self_attn.attn", + ], + ), + KVCacheTensor( + size=draft_page_size * num_blocks, + shared_by=["model.layers.2.self_attn.attn"], + ), + ] + assert sum(t.size for t in kv_cache_config.kv_cache_tensors) == available_memory + + +def test_non_dflash_grouping_still_uses_existing_unify_path(): + model_config = ModelConfig(max_model_len=16) + vllm_config = VllmConfig(model_config=model_config) + kv_cache_specs = { + "model.layers.0.self_attn.attn": new_kv_cache_spec(head_size=64), + "model.layers.1.self_attn.attn": new_sliding_window_spec(head_size=32), + } + + assert kv_cache_utils._partition_dflash_isolated_specs( + vllm_config, kv_cache_specs + ) == (kv_cache_specs, {}) + assert kv_cache_utils.get_kv_cache_groups(vllm_config, kv_cache_specs) == [ + KVCacheGroupSpec( + ["model.layers.0.self_attn.attn"], + new_kv_cache_spec(head_size=64), + ), + KVCacheGroupSpec( + ["model.layers.1.self_attn.attn"], + new_sliding_window_spec(block_size=32, head_size=32), + ), + ] + + def test_get_kv_cache_configs_attention_free(): kv_cache_specs: dict[str, KVCacheSpec] = {} vllm_config = VllmConfig(model_config=ModelConfig(max_model_len=16)) diff --git a/vllm/model_executor/models/qwen3_dflash.py b/vllm/model_executor/models/qwen3_dflash.py index cffe6267a4b3..0207d157f365 100644 --- a/vllm/model_executor/models/qwen3_dflash.py +++ b/vllm/model_executor/models/qwen3_dflash.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable +from dataclasses import replace import torch import torch.nn.functional as F @@ -33,6 +34,7 @@ ) from vllm.multimodal.inputs import NestedTensors from vllm.transformers_utils.config import set_default_rope_theta +from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backend import AttentionType from .qwen2 import Qwen2MLP as Qwen3MLP @@ -109,12 +111,20 @@ def __init__( max_position=max_position, rope_parameters=rope_parameters, ) + # DFlash draft layers use an independent KV cache pool. Keep the + # target's block/sliding-window settings, but do not inherit a + # quantized target KV dtype into the BF16 draft attention path. 
+ draft_cache_config = cache_config + if draft_cache_config is not None and is_quantized_kv_cache( + draft_cache_config.cache_dtype + ): + draft_cache_config = replace(draft_cache_config, cache_dtype="auto") self.attn = Attention( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config, + cache_config=draft_cache_config, quant_config=quant_config, prefix=f"{prefix}.attn", attn_type=attn_type, diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 4b8b86d864be..93c488d37746 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -61,7 +61,7 @@ from vllm.v1.attention.backends.utils import ( get_kv_cache_layout, ) -from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.kv_cache_interface import AttentionSpec, KVQuantMode logger = init_logger(__name__) @@ -449,7 +449,9 @@ def schedule( batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal ): cache_dtype = self.cache_config.cache_dtype - if is_quantized_kv_cache(cache_dtype): + if self.kv_cache_spec.kv_quant_mode == KVQuantMode.NONE: + qkv_dtype = self.kv_cache_dtype + elif is_quantized_kv_cache(cache_dtype): qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( cache_dtype ) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index b57e10b67faa..cace1cb5d5c8 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -8,10 +8,12 @@ import os from collections import defaultdict from collections.abc import Callable, Iterable, Iterator, Sequence -from dataclasses import dataclass, replace +from dataclasses import dataclass from functools import partial from typing import Any, NewType, TypeAlias, cast, overload +import regex as re + from vllm import envs from vllm.config import VllmConfig from vllm.logger import init_logger @@ -79,6 +81,7 @@ def maybe_convert_block_hash(hash_bytes: BlockHash) -> ExternalBlockHash: logger = init_logger(__name__) +_LAYER_INDEX_RE = re.compile(r"(?:^|[.])layers[.](\d+)(?:[.]|$)") # The hash seed for the first block of any prefix block sequence. # @@ -900,7 +903,10 @@ def may_override_num_blocks(vllm_config: VllmConfig, num_blocks: int) -> int: return num_blocks -def _pool_bytes_per_block(kv_cache_groups: list[KVCacheGroupSpec]) -> int: +def _pool_bytes_per_block( + vllm_config: VllmConfig, + kv_cache_groups: list[KVCacheGroupSpec], +) -> int: """ Bytes consumed by one block in the worker's shared KV cache pool, mirroring the divisor used by `get_kv_cache_config_from_groups` to convert @@ -913,7 +919,7 @@ def _pool_bytes_per_block(kv_cache_groups: list[KVCacheGroupSpec]) -> int: return kv_cache_groups[0].kv_cache_spec.page_size_bytes if all( isinstance(g.kv_cache_spec, UniformTypeKVCacheSpecs) for g in kv_cache_groups - ): + ) and not _get_dflash_isolated_group_ids(vllm_config, kv_cache_groups): # DeepseekV4: shared layout sized by the largest per-page-size bucket. 
full_mla_spec = cast(UniformTypeKVCacheSpecs, kv_cache_groups[0].kv_cache_spec) layer_tuple_page_bytes = sum(full_mla_spec.get_page_sizes()) @@ -922,6 +928,18 @@ def _pool_bytes_per_block(kv_cache_groups: list[KVCacheGroupSpec]) -> int: for g in kv_cache_groups ) return layer_tuple_page_bytes * num_layer_tuples + isolated_layers = _get_dflash_isolated_layers(vllm_config, kv_cache_groups) + if isolated_layers: + shared_groups = _get_shared_kv_cache_groups(vllm_config, kv_cache_groups) + shared_bytes = 0 + if shared_groups: + shared_group_size = max(len(g.layer_names) for g in shared_groups) + shared_page_size = get_uniform_page_size( + [g.kv_cache_spec for g in shared_groups] + ) + shared_bytes = shared_page_size * shared_group_size + return shared_bytes + sum(spec.page_size_bytes for _, spec in isolated_layers) + group_size = max(len(g.layer_names) for g in kv_cache_groups) page_size = get_uniform_page_size([g.kv_cache_spec for g in kv_cache_groups]) return page_size * group_size @@ -956,6 +974,108 @@ def get_uniform_page_size(kv_cache_specs: Iterable[KVCacheSpec]) -> int: return page_sizes.pop() +def _get_dflash_isolated_layer_names( + vllm_config: VllmConfig, + layer_names: Iterable[str], +) -> set[str]: + spec_config = vllm_config.speculative_config + if spec_config is None or spec_config.method != "dflash": + return set() + + try: + target_num_layers = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config + ) + except Exception: + return set() + + isolated_layer_names: set[str] = set() + for layer_name in layer_names: + match = _LAYER_INDEX_RE.search(layer_name) + if match is not None and int(match.group(1)) >= target_num_layers: + isolated_layer_names.add(layer_name) + return isolated_layer_names + + +def _partition_dflash_isolated_specs( + vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], +) -> tuple[dict[str, KVCacheSpec], dict[str, KVCacheSpec]]: + isolated_layer_names = _get_dflash_isolated_layer_names( + vllm_config, kv_cache_spec.keys() + ) + if not isolated_layer_names: + return kv_cache_spec, {} + + shared_specs = { + layer_name: layer_spec + for layer_name, layer_spec in kv_cache_spec.items() + if layer_name not in isolated_layer_names + } + isolated_specs = { + layer_name: layer_spec + for layer_name, layer_spec in kv_cache_spec.items() + if layer_name in isolated_layer_names + } + if not shared_specs or not isolated_specs: + return kv_cache_spec, {} + return shared_specs, isolated_specs + + +def _get_dflash_isolated_group_ids( + vllm_config: VllmConfig, + kv_cache_groups: list[KVCacheGroupSpec], +) -> set[int]: + isolated_layer_names = _get_dflash_isolated_layer_names( + vllm_config, + (layer_name for group in kv_cache_groups for layer_name in group.layer_names), + ) + if not isolated_layer_names: + return set() + + group_ids: set[int] = set() + for group_id, group in enumerate(kv_cache_groups): + if group.layer_names and all( + layer_name in isolated_layer_names for layer_name in group.layer_names + ): + group_ids.add(group_id) + return group_ids + + +def _get_layer_spec_from_group( + kv_cache_group: KVCacheGroupSpec, + layer_name: str, +) -> KVCacheSpec: + group_spec = kv_cache_group.kv_cache_spec + if isinstance(group_spec, UniformTypeKVCacheSpecs): + return group_spec.kv_cache_specs[layer_name] + return group_spec + + +def _get_dflash_isolated_layers( + vllm_config: VllmConfig, + kv_cache_groups: list[KVCacheGroupSpec], +) -> list[tuple[str, KVCacheSpec]]: + isolated_group_ids = _get_dflash_isolated_group_ids(vllm_config, 
kv_cache_groups) + return [ + (layer_name, _get_layer_spec_from_group(kv_cache_groups[group_id], layer_name)) + for group_id in sorted(isolated_group_ids) + for layer_name in kv_cache_groups[group_id].layer_names + ] + + +def _get_shared_kv_cache_groups( + vllm_config: VllmConfig, + kv_cache_groups: list[KVCacheGroupSpec], +) -> list[KVCacheGroupSpec]: + isolated_group_ids = _get_dflash_isolated_group_ids(vllm_config, kv_cache_groups) + return [ + group + for group_id, group in enumerate(kv_cache_groups) + if group_id not in isolated_group_ids + ] + + def _get_kv_cache_groups_uniform_spec( kv_cache_specs: dict[str, KVCacheSpec], ) -> list[KVCacheGroupSpec]: @@ -1038,7 +1158,7 @@ def unify_kv_cache_spec_page_size( ) ratio = max_page_size // layer_page_size new_block_size = layer_spec.block_size * ratio - new_spec = replace(layer_spec, block_size=new_block_size) + new_spec = layer_spec.copy_with_new_block_size(new_block_size) assert new_spec.page_size_bytes == max_page_size new_kv_cache_spec[layer_name] = new_spec return new_kv_cache_spec @@ -1275,7 +1395,7 @@ def get_kv_cache_config_from_groups( elif all( isinstance(group.kv_cache_spec, UniformTypeKVCacheSpecs) for group in kv_cache_groups - ): + ) and not _get_dflash_isolated_group_ids(vllm_config, kv_cache_groups): # DeepseekV4: UniformTypeKVCacheSpecs but multiple groups. # Delegate to the DeepseekV4-specific allocator. num_blocks, kv_cache_tensors = _get_kv_cache_config_deepseek_v4( @@ -1290,24 +1410,45 @@ def get_kv_cache_config_from_groups( # (sw.1, padding) will be: (group_size = 2) # full.0, sw.0, sw.1: share a Tensor with size=available_memory//2 # full.1, sw.2: share another Tensor with size=available_memory//2 - group_size = max(len(group.layer_names) for group in kv_cache_groups) - - page_size = get_uniform_page_size( - [group.kv_cache_spec for group in kv_cache_groups] + isolated_layers = _get_dflash_isolated_layers(vllm_config, kv_cache_groups) + shared_groups = ( + _get_shared_kv_cache_groups(vllm_config, kv_cache_groups) + if isolated_layers + else kv_cache_groups + ) + shared_group_size = ( + max(len(group.layer_names) for group in shared_groups) + if shared_groups + else 0 + ) + page_size = ( + get_uniform_page_size([group.kv_cache_spec for group in shared_groups]) + if shared_groups + else 0 ) - assert group_size > 0, "group_size must be greater than 0" - num_blocks = get_num_blocks( - vllm_config, group_size, available_memory, page_size + bytes_per_block = page_size * shared_group_size + sum( + spec.page_size_bytes for _, spec in isolated_layers ) + assert bytes_per_block > 0, "bytes_per_block must be greater than 0" + num_blocks = int(available_memory // bytes_per_block) + num_blocks = max(num_blocks, 0) + num_blocks = may_override_num_blocks(vllm_config, num_blocks) kv_cache_tensors = [] - for i in range(group_size): + for i in range(shared_group_size): shared_by = [] - for j in range(len(kv_cache_groups)): - if i < len(kv_cache_groups[j].layer_names): - shared_by.append(kv_cache_groups[j].layer_names[i]) + for group in shared_groups: + if i < len(group.layer_names): + shared_by.append(group.layer_names[i]) kv_cache_tensors.append( KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by) ) + for layer_name, layer_spec in isolated_layers: + kv_cache_tensors.append( + KVCacheTensor( + size=layer_spec.page_size_bytes * num_blocks, + shared_by=[layer_name], + ) + ) return KVCacheConfig( num_blocks=num_blocks, @@ -1610,7 +1751,7 @@ def _annotate_eagle_groups_deepseek_v4( break -def get_kv_cache_groups( +def 
_get_kv_cache_groups( vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec] ) -> list[KVCacheGroupSpec]: """ @@ -1623,9 +1764,6 @@ def get_kv_cache_groups( Returns: The generated KVCacheGroups """ - if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager: - unify_hybrid_kv_cache_specs(kv_cache_spec) - if is_kv_cache_type_attention_free(kv_cache_spec): # This returns an empty list to allow for the KVCacheManager to handle # attention free models. @@ -1661,6 +1799,34 @@ def get_kv_cache_groups( return _get_kv_cache_groups_uniform_page_size(kv_cache_spec) +def get_kv_cache_groups( + vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec] +) -> list[KVCacheGroupSpec]: + """ + Split the layers in the model into groups with the same KV cache spec. + + Args: + vllm_config: The global VllmConfig + kv_cache_spec: The kv cache spec of each attention layer in the model + + Returns: + The generated KVCacheGroups + """ + if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager: + unify_hybrid_kv_cache_specs(kv_cache_spec) + + shared_specs, isolated_specs = _partition_dflash_isolated_specs( + vllm_config, kv_cache_spec + ) + if isolated_specs: + return [ + *_get_kv_cache_groups(vllm_config, shared_specs), + *_get_kv_cache_groups(vllm_config, isolated_specs), + ] + + return _get_kv_cache_groups(vllm_config, kv_cache_spec) + + def generate_scheduler_kv_cache_config( kv_cache_configs: list[KVCacheConfig], ) -> KVCacheConfig: @@ -1741,7 +1907,7 @@ def _max_memory_usage_bytes_from_groups( elif all( isinstance(group.kv_cache_spec, UniformTypeKVCacheSpecs) for group in kv_cache_groups - ): + ) and not _get_dflash_isolated_group_ids(vllm_config, kv_cache_groups): # Special case (only DeepseekV4 for now): all groups are # UniformTypeKVCacheSpecs. 
# They must already be page_size aligned and share a common padded @@ -1764,18 +1930,29 @@ def _max_memory_usage_bytes_from_groups( total_max_mem_usage_bytes += g_max_mem_usage_page_bytes return total_max_mem_usage_bytes - # General case: group_size pools, each shared by one layer per group - # Memory = group_size * page_size * blocks_for_max_len - group_size = max(len(group.layer_names) for group in kv_cache_groups) - page_size = get_uniform_page_size( - [group.kv_cache_spec for group in kv_cache_groups] - ) - blocks_needed = sum( - cdiv(group.kv_cache_spec.max_memory_usage_bytes(vllm_config), page_size) - for group in kv_cache_groups + isolated_layers = _get_dflash_isolated_layers(vllm_config, kv_cache_groups) + shared_groups = ( + _get_shared_kv_cache_groups(vllm_config, kv_cache_groups) + if isolated_layers + else kv_cache_groups ) - return group_size * page_size * blocks_needed + total_max_mem_usage_bytes = 0 + if shared_groups: + group_size = max(len(group.layer_names) for group in shared_groups) + page_size = get_uniform_page_size( + [group.kv_cache_spec for group in shared_groups] + ) + blocks_needed = sum( + cdiv(group.kv_cache_spec.max_memory_usage_bytes(vllm_config), page_size) + for group in shared_groups + ) + total_max_mem_usage_bytes += group_size * page_size * blocks_needed + + total_max_mem_usage_bytes += sum( + spec.max_memory_usage_bytes(vllm_config) for _, spec in isolated_layers + ) + return total_max_mem_usage_bytes def _estimate_max_model_len_from_groups( @@ -1994,7 +2171,7 @@ def get_kv_cache_configs( if not groups: adjusted_memory.append(avail_mem) continue - bytes_per_block = _pool_bytes_per_block(groups) + bytes_per_block = _pool_bytes_per_block(vllm_config, groups) logger.info( "Overriding num_gpu_blocks=%d with num_gpu_blocks_override=%d", avail_mem // bytes_per_block, From 5cb61c60d7d483dd0b2731a542ff72b04665355e Mon Sep 17 00:00:00 2001 From: noonghunna <10742901+noonghunna@users.noreply.github.com> Date: Fri, 8 May 2026 18:19:27 +0000 Subject: [PATCH 2/2] [Spec Decode] Address review: don't swallow get_num_layers errors in DFlash isolation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per review on PR #42102 by @gemini-code-assist[bot]: > This try-except Exception block is too broad and can hide important > configuration errors. If get_num_layers fails when DFlash is enabled, > it should be a hard failure rather than being silently ignored. Silently > returning an empty set of isolated layers will cause the DFlash-specific > logic to be skipped, leading to the old unification path which this PR > aims to avoid for DFlash, likely causing more obscure errors later. The try/except was a defensive safety net but, on reflection, it would mask real config errors and produce harder-to-diagnose downstream failures (the exact unify failure this PR aims to fix). The "drafter convention doesn't match" graceful-fallback case is already handled separately by the layer name regex check below — if a layer name doesn't match the index pattern, it's simply skipped, no exception involved. Removing the try/except so a genuine get_num_layers() failure propagates with a clear message rather than silently degrading. 
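To make the intended contract concrete, here is a minimal standalone sketch of the two cases (illustrative names only; the real check is _get_dflash_isolated_layer_names() in the patch above): a layer name that does not follow the expected convention is skipped by the regex and simply contributes no isolated layers, while a genuine get_num_layers() failure now propagates instead of silently sending DFlash down the unify path.

```python
import re

LAYER_INDEX_RE = re.compile(r"(?:^|[.])layers[.](\d+)(?:[.]|$)")


def isolated_layer_names(layer_names, get_num_layers):
    # No try/except here: a genuine config error from get_num_layers()
    # propagates to the caller with its original message.
    target_num_layers = get_num_layers()
    isolated = set()
    for name in layer_names:
        match = LAYER_INDEX_RE.search(name)
        # Names that do not follow the "...layers.<i>..." convention are
        # skipped gracefully; no exception is involved.
        if match is not None and int(match.group(1)) >= target_num_layers:
            isolated.add(name)
    return isolated


# Unconventional drafter layer names -> empty set, not an error.
assert isolated_layer_names(["decoder.block_a.attn"], lambda: 2) == set()


# A broken model config now fails loudly instead of degrading silently.
def broken_get_num_layers():
    raise ValueError("num_hidden_layers missing from hf config")


try:
    isolated_layer_names(["model.layers.2.self_attn.attn"], broken_get_num_layers)
except ValueError as exc:
    print(f"propagated as intended: {exc}")
```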
Co-authored-by: gemini-code-assist[bot] Signed-off-by: noonghunna <10742901+noonghunna@users.noreply.github.com> --- vllm/v1/core/kv_cache_utils.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index cace1cb5d5c8..be0b2894e3aa 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -982,12 +982,9 @@ def _get_dflash_isolated_layer_names( if spec_config is None or spec_config.method != "dflash": return set() - try: - target_num_layers = vllm_config.model_config.get_num_layers( - vllm_config.parallel_config - ) - except Exception: - return set() + target_num_layers = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config + ) isolated_layer_names: set[str] = set() for layer_name in layer_names: