diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py
index 5354ef088d01..59786ed7a153 100644
--- a/vllm/v1/worker/gpu/attn_utils.py
+++ b/vllm/v1/worker/gpu/attn_utils.py
@@ -30,7 +30,10 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
 
 
 def init_attn_backend(
-    kv_cache_config: KVCacheConfig, vllm_config: VllmConfig, device: torch.device
+    kv_cache_config: KVCacheConfig,
+    vllm_config: VllmConfig,
+    device: torch.device,
+    active_layer_names: set[str] | None = None,
 ):
     attn_backends: dict[str, type[AttentionBackend]] = {}
     attn_groups: list[list[AttentionGroup]] = []
@@ -39,6 +42,8 @@ def init_attn_backend(
         kv_cache_config.kv_cache_groups
     ):
         layer_names = kv_cache_group_spec.layer_names
+        if active_layer_names is not None:
+            layer_names = list(active_layer_names.intersection(layer_names))
         layer_type = cast(type[Any], AttentionLayerBase)
         attn_layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
 
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 57f170b59000..3f0bf4342a12 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -351,7 +351,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
 
         self.speculator.set_attn(
             self.model_state,
             self.kv_cache_config,
-            self.attn_groups,
             self.block_tables,
         )
diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
index 922031a52180..49b6b5331b5c 100644
--- a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
+++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
@@ -5,15 +5,17 @@
 import torch
 import torch.nn as nn
 
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.config.compilation import CUDAGraphMode
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.logger import init_logger
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.triton_utils import tl, triton
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.worker.gpu.attn_utils import (
     build_attn_metadata,
     build_slot_mappings_by_layer,
+    init_attn_backend,
 )
 from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding
@@ -22,7 +24,6 @@
 from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
 from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager
 from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
-from vllm.v1.worker.utils import AttentionGroup
 
 logger = init_logger(__name__)
 
@@ -87,18 +88,35 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         )
 
     def load_model(self, target_model: nn.Module) -> None:
+        target_attn_layer_names = get_layers_from_vllm_config(
+            self.vllm_config,
+            AttentionLayerBase,  # type: ignore[type-abstract]
+        ).keys()
+
         self.model = load_eagle_model(target_model, self.vllm_config)
 
+        all_attn_layers = get_layers_from_vllm_config(
+            self.vllm_config,
+            AttentionLayerBase,  # type: ignore[type-abstract]
+        ).keys()
+        self.draft_attn_layer_names = set(all_attn_layers) - set(
+            target_attn_layer_names
+        )
+
     def set_attn(
         self,
         model_state: ModelState,
         kv_cache_config: KVCacheConfig,
-        attn_groups: list[list[AttentionGroup]],
         block_tables: BlockTables,
     ) -> None:
         self.model_state = model_state
         self.kv_cache_config = kv_cache_config
-        self.attn_groups = attn_groups
+        _, self.attn_groups = init_attn_backend(
+            kv_cache_config,
+            self.vllm_config,
+            self.device,
+            active_layer_names=self.draft_attn_layer_names,
+        )
         self.block_tables = block_tables
 
     @torch.inference_mode()