Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions vllm/v1/cudagraph_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ class CudagraphDispatcher:
runnable without cudagraph (if the mode does not match or mode is NONE).
"""

def __init__(self, vllm_config: VllmConfig):
def __init__(self, vllm_config: VllmConfig, for_draft_model: bool = False):
self.vllm_config = vllm_config
self.for_draft_model = for_draft_model
self.compilation_config = vllm_config.compilation_config
self.uniform_decode_query_len = (
1
Expand Down Expand Up @@ -134,9 +135,11 @@ def _create_padded_batch_descriptor(
uniform_decode: bool,
has_lora: bool,
num_active_loras: int = 0,
uniform_decode_query_len: int | None = None,
) -> BatchDescriptor:
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
uniform_decode_query_len = self.uniform_decode_query_len
if uniform_decode_query_len is None:
uniform_decode_query_len = self.uniform_decode_query_len
num_tokens_padded = self._bs_to_padded_graph_size[num_tokens]

if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
Expand Down Expand Up @@ -229,6 +232,26 @@ def initialize_cudagraph_keys(
),
)

if self.for_draft_model and cudagraph_mode.has_full_cudagraphs():
max_num_tokens = self.vllm_config.scheduler_config.max_num_seqs
assert self.compilation_config.cudagraph_capture_sizes is not None, (
"Cudagraph capture sizes must be set when full mode is enabled."
)
capture_sizes_for_draft_model = []
for size in self.compilation_config.cudagraph_capture_sizes:
capture_sizes_for_draft_model.append(size)
if size >= max_num_tokens:
break
for bs, num_active_loras in product(
capture_sizes_for_draft_model, lora_cases
):
self.add_cudagraph_key(
CUDAGraphMode.FULL,
self._create_padded_batch_descriptor(
bs, True, num_active_loras > 0, num_active_loras, 1
),
)

self.keys_initialized = True

def dispatch(
Expand All @@ -239,6 +262,7 @@ def dispatch(
num_active_loras: int = 0,
valid_modes: AbstractSet[CUDAGraphMode] | None = None,
invalid_modes: AbstractSet[CUDAGraphMode] | None = None,
uniform_decode_query_len: int | None = None,
) -> tuple[CUDAGraphMode, BatchDescriptor]:
"""
Given conditions (e.g., batch descriptor and whether only piecewise is used),
Expand Down Expand Up @@ -298,9 +322,15 @@ def dispatch(
)
effective_num_active_loras = self.vllm_config.lora_config.max_loras + 1

normalized_uniform = uniform_decode and self.cudagraph_mode.separate_routine()
normalized_uniform = uniform_decode and (
self.cudagraph_mode.separate_routine() or self.for_draft_model
)
batch_desc = self._create_padded_batch_descriptor(
num_tokens, normalized_uniform, has_lora, effective_num_active_loras
num_tokens,
normalized_uniform,
has_lora,
effective_num_active_loras,
uniform_decode_query_len,
)

if CUDAGraphMode.FULL in allowed_modes:
Expand Down
140 changes: 105 additions & 35 deletions vllm/v1/spec_decode/eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
import torch
import torch.nn as nn

from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.config import (
CUDAGraphMode,
VllmConfig,
get_layers_from_vllm_config,
)
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
Expand Down Expand Up @@ -123,7 +124,7 @@ def __init__(
# Keys are initialized later via initialize_cudagraph_keys() called from
# gpu_model_runner._check_and_update_cudagraph_mode after
# adjust_cudagraph_sizes_for_spec_decode is called.
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config, True)

# persistent buffers for cuda graph
self.input_ids = torch.zeros(
Expand Down Expand Up @@ -359,21 +360,11 @@ def _get_slot_mapping(
return {name: view for name in self._draft_attn_layer_names}

def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
"""Initialize cudagraph dispatcher keys for eagle.
"""Initialize cudagraph dispatcher keys for eagle."""
if self.speculative_config.enforce_eager:
cudagraph_mode = CUDAGraphMode.NONE

Eagle only supports PIECEWISE cudagraphs (via mixed_mode).
This should be called after adjust_cudagraph_sizes_for_spec_decode.
"""
if (
not self.speculative_config.enforce_eager
and cudagraph_mode.mixed_mode()
in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]
):
eagle_cudagraph_mode = CUDAGraphMode.PIECEWISE
else:
eagle_cudagraph_mode = CUDAGraphMode.NONE

self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)
self.cudagraph_dispatcher.initialize_cudagraph_keys(cudagraph_mode)

def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Greedy-sample draft tokens from hidden states."""
Expand All @@ -393,6 +384,7 @@ def propose(
next_token_ids: torch.Tensor,
token_indices_to_sample: torch.Tensor | None,
common_attn_metadata: CommonAttentionMetadata,
target_model_batch_desc: BatchDescriptor,
sampling_metadata: SamplingMetadata,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
num_rejected_tokens_gpu: torch.Tensor | None = None,
Expand All @@ -403,8 +395,12 @@ def propose(
batch_size = common_attn_metadata.batch_size()

if self.method == "eagle3":
if isinstance(self.model, CUDAGraphWrapper):
model = self.model.unwrap()
else:
model = self.model
assert isinstance(
self.model, (Eagle3LlamaForCausalLM, Eagle3DeepseekV2ForCausalLM)
model, (Eagle3LlamaForCausalLM, Eagle3DeepseekV2ForCausalLM)
)
target_hidden_states = self.model.combine_hidden_states(
target_hidden_states
Expand All @@ -431,8 +427,9 @@ def propose(
for layer_name in attn_group.layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata

cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(num_tokens)
uniform_decode = target_model_batch_desc.uniform
cudagraph_runtime_mode, batch_desc, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(num_tokens, uniform_decode)
)

if self.supports_mm_inputs:
Expand Down Expand Up @@ -464,6 +461,7 @@ def propose(
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_desc,
slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping
),
Expand Down Expand Up @@ -517,14 +515,18 @@ def propose(
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]

cudagraph_runtime_mode, input_batch_size, batch_size_across_dp = (
self._determine_batch_execution_and_padding(batch_size)
cudagraph_runtime_mode, batch_desc, input_batch_size, batch_size_across_dp = (
self._determine_batch_execution_and_padding(
batch_size, True, uniform_decode_query_len=1
)
)

common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1
common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
common_attn_metadata.query_start_loc[: batch_size + 1] = self.arange[
: batch_size + 1
]
common_attn_metadata.query_start_loc_cpu[: batch_size + 1] = torch.from_numpy(
self.token_arange_np[: batch_size + 1]
).clone()

Expand Down Expand Up @@ -628,6 +630,7 @@ def propose(
num_tokens=input_batch_size,
num_tokens_across_dp=batch_size_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_desc,
slot_mapping=self._get_slot_mapping(input_batch_size),
):
ret_hidden_states = self.model(**model_kwargs)
Expand Down Expand Up @@ -823,6 +826,7 @@ def prepare_next_token_ids_padded(
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
discard_request_mask: torch.Tensor,
batch_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding.
Expand All @@ -844,16 +848,17 @@ def prepare_next_token_ids_padded(
self.backup_next_token_ids.copy_to_gpu(num_reqs)
backup_tokens_gpu = self.backup_next_token_ids.gpu

batch_size, num_tokens = sampled_token_ids.shape
_, num_tokens = sampled_token_ids.shape
device = sampled_token_ids.device

assert discard_request_mask.dtype == torch.bool
assert backup_tokens_gpu.dtype == torch.int32

next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device)
valid_sampled_tokens_count = next_token_ids.new_empty(batch_size)
next_token_ids = torch.zeros(batch_size, dtype=torch.int32, device=device)
valid_sampled_tokens_count = next_token_ids.new_zeros(batch_size)

# Kernel grid: one program per request (row)
# NOTE: For CUDA Graph, we need the `batch_size` to be `num_reqs_padded` here
grid = (batch_size,)

# Find the next power of 2 for block sizes
Expand All @@ -878,6 +883,7 @@ def prepare_inputs_padded(
common_attn_metadata: CommonAttentionMetadata,
spec_decode_metadata: SpecDecodeMetadata,
valid_sampled_tokens_count: torch.Tensor,
gpu_input_batch: InputBatch,
) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding
Expand All @@ -897,6 +903,14 @@ def prepare_inputs_padded(
(num_reqs,), dtype=torch.int32, device=device
)

actual_num_reqs = gpu_input_batch.num_reqs
spec_decode_metadata.cu_num_draft_tokens = nn.functional.pad(
spec_decode_metadata.cu_num_draft_tokens,
(0, num_reqs - actual_num_reqs),
mode="constant",
value=spec_decode_metadata.cu_num_draft_tokens[-1],
)

grid = (num_reqs,)
eagle_prepare_inputs_padded_kernel[grid](
spec_decode_metadata.cu_num_draft_tokens,
Expand Down Expand Up @@ -1237,6 +1251,19 @@ def load_model(self, target_model: nn.Module) -> None:
)

self.model = self._get_model()
# Wrap the model with the full cudagraph wrapper if needed.
cudagraph_mode = self.compilation_config.cudagraph_mode
if (
cudagraph_mode.has_full_cudagraphs()
and not self.vllm_config.parallel_config.use_ubatching
and not self.speculative_config.disable_padded_drafter_batch
):
# Ubatching does not currently support FULL cudagraphs in speculative
# decoding, and neither does the unpadded drafter batch, because its number
# of tokens is dynamic. Support for these cases can be added later if needed.
self.model = CUDAGraphWrapper(
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
)

# Find draft layers (attention layers added by draft model)
all_attn_layers = get_layers_from_vllm_config(
Expand Down Expand Up @@ -1469,21 +1496,46 @@ def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None:
def dummy_run(
self,
num_tokens: int,
common_attn_metadata: CommonAttentionMetadata | None = None,
use_cudagraphs: bool = True,
is_graph_capturing: bool = False,
slot_mappings: dict[str, torch.Tensor] | None = None,
) -> None:
# FIXME: when using tree-based specdec, adjust number of forward-passes
# according to the depth of the tree.
for fwd_idx in range(
self.num_speculative_tokens if not is_graph_capturing else 1
self.num_speculative_tokens
if not is_graph_capturing
else min(self.num_speculative_tokens, 2)
):
if fwd_idx <= 1:
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(
num_tokens, use_cudagraphs=use_cudagraphs
)
if fwd_idx > 0 and common_attn_metadata is not None:
# All speculative steps except the first one typically use
# a uniform decode with 1 token per request.
uniform_decode = True
num_tokens = common_attn_metadata.num_reqs
uniform_decode_query_len = 1
else:
# For the first step, uniform_decode must be set to True for
# FULL_DECODE_ONLY and FULL_AND_PIECEWISE, whereas for FULL it
# must not be.
mode = self.cudagraph_dispatcher.cudagraph_mode
is_full_sep = (
mode.decode_mode() == CUDAGraphMode.FULL and mode.separate_routine()
)
uniform_decode = is_full_sep and common_attn_metadata is not None
uniform_decode_query_len = None

(
cudagraph_runtime_mode,
batch_desc,
num_input_tokens,
num_tokens_across_dp,
) = self._determine_batch_execution_and_padding(
num_tokens,
uniform_decode,
use_cudagraphs=use_cudagraphs,
uniform_decode_query_len=uniform_decode_query_len,
)

# Make sure to use EAGLE's own buffer during cudagraph capture.
if (
Expand All @@ -1495,12 +1547,26 @@ def dummy_run(
else:
slot_mapping_dict = slot_mappings or {}

if common_attn_metadata is not None:
dummy_attn_metadata = {}
for attn_group in self.draft_attn_groups:
attn_metadata = (
attn_group.get_metadata_builder().build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0
)
)
for layer_name in attn_group.layer_names:
dummy_attn_metadata[layer_name] = attn_metadata
else:
dummy_attn_metadata = None

with set_forward_context(
None,
dummy_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_desc,
slot_mapping=slot_mapping_dict,
):
if self.supports_mm_inputs:
Expand Down Expand Up @@ -1623,11 +1689,15 @@ def initialize_attn_backend(
def _determine_batch_execution_and_padding(
self,
num_tokens: int,
uniform_decode: bool = False,
use_cudagraphs: bool = True,
) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
uniform_decode_query_len: int | None = None,
) -> tuple[CUDAGraphMode, BatchDescriptor, int, torch.Tensor | None]:
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens,
uniform_decode,
valid_modes=({CUDAGraphMode.NONE} if not use_cudagraphs else None),
uniform_decode_query_len=uniform_decode_query_len,
)
num_tokens_padded = batch_desc.num_tokens

Expand Down Expand Up @@ -1662,7 +1732,7 @@ def _determine_batch_execution_and_padding(
assert batch_desc.num_tokens == num_tokens_padded
num_tokens_across_dp[dp_rank] = num_tokens_padded

return cudagraph_mode, num_tokens_padded, num_tokens_across_dp
return cudagraph_mode, batch_desc, num_tokens_padded, num_tokens_across_dp


class EagleProposer(SpecDecodeBaseProposer):
Expand Down
Loading
Loading