Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion vllm/v1/worker/gpu/attn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:


def init_attn_backend(
kv_cache_config: KVCacheConfig, vllm_config: VllmConfig, device: torch.device
kv_cache_config: KVCacheConfig,
vllm_config: VllmConfig,
device: torch.device,
active_layer_names: set[str] | None = None,
):
attn_backends: dict[str, type[AttentionBackend]] = {}
attn_groups: list[list[AttentionGroup]] = []
Expand All @@ -39,6 +42,8 @@ def init_attn_backend(
kv_cache_config.kv_cache_groups
):
layer_names = kv_cache_group_spec.layer_names
if active_layer_names is not None:
layer_names = list(active_layer_names.intersection(layer_names))

layer_type = cast(type[Any], AttentionLayerBase)
attn_layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
Expand Down
1 change: 0 additions & 1 deletion vllm/v1/worker/gpu/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
self.speculator.set_attn(
self.model_state,
self.kv_cache_config,
self.attn_groups,
self.block_tables,
)

Expand Down
26 changes: 22 additions & 4 deletions vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.compilation import CUDAGraphMode
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.triton_utils import tl, triton
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
build_slot_mappings_by_layer,
init_attn_backend,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding
Expand All @@ -22,7 +24,6 @@
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager
from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
from vllm.v1.worker.utils import AttentionGroup

logger = init_logger(__name__)

Expand Down Expand Up @@ -87,18 +88,35 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
)

def load_model(self, target_model: nn.Module) -> None:
    """Load the draft model and record the attention layers it introduced.

    The attention layers known to ``self.vllm_config`` are queried once
    before and once after ``load_eagle_model`` runs. Names present only in
    the second query are stored in ``self.draft_attn_layer_names``
    (presumably the layers registered by the draft model -- confirm against
    ``load_eagle_model``).

    Args:
        target_model: The target model, forwarded to ``load_eagle_model``.
    """
    # Attention layer names visible before the draft model is loaded.
    pre_load_layer_names = get_layers_from_vllm_config(
        self.vllm_config,
        AttentionLayerBase,  # type: ignore[type-abstract]
    ).keys()

    self.model = load_eagle_model(target_model, self.vllm_config)

    # Same query again; anything new appeared during the load above.
    post_load_layer_names = get_layers_from_vllm_config(
        self.vllm_config,
        AttentionLayerBase,  # type: ignore[type-abstract]
    ).keys()
    self.draft_attn_layer_names = set(post_load_layer_names).difference(
        pre_load_layer_names
    )

def set_attn(
    self,
    model_state: ModelState,
    kv_cache_config: KVCacheConfig,
    attn_groups: list[list[AttentionGroup]],
    block_tables: BlockTables,
) -> None:
    """Attach model state, KV-cache config and block tables to the speculator.

    The attention groups are built here by calling ``init_attn_backend``
    with ``active_layer_names=self.draft_attn_layer_names``, so
    ``load_model`` must have run first to populate that attribute.

    Args:
        model_state: Stored as ``self.model_state``.
        kv_cache_config: Stored as ``self.kv_cache_config`` and passed to
            ``init_attn_backend``.
        attn_groups: Accepted so existing four-argument callers keep
            working, but not used: the value was previously assigned to
            ``self.attn_groups`` and then immediately overwritten.
        block_tables: Stored as ``self.block_tables``.
    """
    self.model_state = model_state
    self.kv_cache_config = kv_cache_config
    # Only the second element of the returned pair (the attention groups)
    # is kept; the first element is discarded.
    _, self.attn_groups = init_attn_backend(
        kv_cache_config,
        self.vllm_config,
        self.device,
        active_layer_names=self.draft_attn_layer_names,
    )
    self.block_tables = block_tables

@torch.inference_mode()
Expand Down
Loading