152 changes: 97 additions & 55 deletions vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
@@ -19,6 +19,9 @@
init_attn_backend,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import (
BatchExecutionDescriptor,
)
from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
@@ -239,6 +242,66 @@ def generate_draft(
idx_mapping, query_start_loc, pos, num_tokens_padded
)

def _dispatch_and_sync_dp(
self,
cudagraph_manager: EagleCudaGraphManager,
num_reqs: int,
num_tokens: int,
uniform_token_count: int | None,
) -> tuple[BatchExecutionDescriptor, torch.Tensor | None]:
batch_desc = cudagraph_manager.dispatch(
num_reqs, num_tokens, uniform_token_count
)
num_tokens_across_dp = None
if self.dp_size > 1:
batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding(
cudagraph_manager,
batch_desc,
num_tokens,
num_reqs,
uniform_token_count,
self.dp_size,
self.dp_rank,
)
return batch_desc, num_tokens_across_dp

def _build_draft_attn_metadata(
self,
num_reqs: int,
num_reqs_padded: int,
num_tokens_padded: int,
max_query_len: int,
) -> dict[str, Any] | None:
if not self.draft_attn_layer_names:
return None

query_start_loc_cpu = (
torch.arange(num_reqs_padded + 1, dtype=torch.int32, device="cpu").clamp_(
max=num_reqs
)
* max_query_len
)
block_tables = [
x[:num_reqs_padded] for x in self.block_tables.input_block_tables
]
slot_mappings = self.block_tables.slot_mappings[:, :num_tokens_padded]
attn_metadata = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=num_reqs_padded,
num_tokens=num_tokens_padded,
query_start_loc_gpu=self.input_buffers.query_start_loc[
: num_reqs_padded + 1
],
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,
seq_lens=self.input_buffers.seq_lens[:num_reqs_padded],
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
return attn_metadata

def capture_model(self) -> None:
if self.num_speculative_steps == 1:
return
@@ -319,7 +382,6 @@ def propose(
logits = self.model.compute_logits(sample_hidden_states)

num_reqs = input_batch.num_reqs
num_reqs_padded = input_batch.num_reqs_after_padding
# NOTE(woosuk): For draft sampling, we only consider the temperature
# and ignore the other sampling parameters such as top_k and top_p,
# for simplicity and performance.
@@ -366,69 +428,49 @@ def propose(
self.max_num_reqs,
)

        # Get the batch descriptor and sync across DP ranks.
        # Eagle uses FULL-only mode; dispatch with uniform_token_count=1 for decode.

batch_desc = self.cudagraph_manager.dispatch(num_reqs, num_reqs, 1)
num_tokens_across_dp = None

if self.dp_size > 1:
batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding(
self.cudagraph_manager,
batch_desc,
num_reqs,
num_reqs,
1, # uniform_token_count
self.dp_size,
self.dp_rank,
)

if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos, batch_desc.num_tokens
)

if batch_desc.cg_mode == CUDAGraphMode.FULL:
return self.cudagraph_manager.run_fullgraph(batch_desc)[:num_reqs]
# Each request produces exactly 1 token per draft decode step,
# enabling FULL cudagraph.
decode_batch_desc, num_tokens_across_dp = self._dispatch_and_sync_dp(
self.cudagraph_manager,
num_reqs,
num_reqs,
uniform_token_count=1,
)

# Run eager or piecewise CUDA graph.
attn_metadata_updated = None
slot_mappings_updated = None
if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc_cpu = torch.arange(
num_reqs_padded + 1, dtype=torch.int32, device="cpu"
)
block_tables = [
x[:num_reqs_padded] for x in self.block_tables.input_block_tables
]

# FIXME(woosuk): This is UNSAFE!!
attn_metadata_updated = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=num_reqs_padded,
num_tokens=num_reqs_padded,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=1,
seq_lens=self.input_buffers.seq_lens[:num_reqs_padded],
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
# Build attention metadata and slot mappings for the draft
# decode steps. It is necessary to rebuild the attention
# metadata even when replaying the FULL cudagraph so that
# any attention metadata builder state is updated.
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping,
self.input_buffers.query_start_loc[: num_reqs + 1],
pos,
decode_batch_desc.num_tokens,
)
slot_mappings_updated = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
attn_metadata_updated = self._build_draft_attn_metadata(
num_reqs=num_reqs,
num_reqs_padded=decode_batch_desc.num_reqs or num_reqs,
num_tokens_padded=decode_batch_desc.num_tokens,
max_query_len=1,
)

self.generate_draft(
num_reqs,
batch_desc.num_tokens,
attn_metadata_updated,
slot_mappings_updated,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=batch_desc.cg_mode,
)
if decode_batch_desc.cg_mode == CUDAGraphMode.FULL:
self.cudagraph_manager.run_fullgraph(decode_batch_desc)
else:
self.generate_draft(
num_reqs,
decode_batch_desc.num_tokens,
attn_metadata_updated,
slot_mappings_updated,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=decode_batch_desc.cg_mode,
)
return self.draft_tokens[:num_reqs]


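For reviewers skimming the new `_dispatch_and_sync_dp` helper, here is a minimal, self-contained sketch of the dispatch-then-sync pattern it factors out. Everything below (`BatchDesc`, `ToyCudaGraphManager`, the power-of-two padding policy, the fake all-gather) is an illustrative assumption, not vLLM's actual API; it only shows the control flow: dispatch a padded batch descriptor locally, then, when `dp_size > 1`, agree on a common padded token count across data-parallel ranks so CUDA graph replays stay in lockstep.

```python
from dataclasses import dataclass


@dataclass
class BatchDesc:
    """Toy stand-in for BatchExecutionDescriptor (illustrative only)."""
    num_reqs: int
    num_tokens: int
    cg_mode: str  # e.g. "FULL", "PIECEWISE", or "NONE"


class ToyCudaGraphManager:
    """Toy stand-in for EagleCudaGraphManager with a made-up padding policy."""

    def dispatch(
        self, num_reqs: int, num_tokens: int, uniform_token_count: int | None
    ) -> BatchDesc:
        # Pad the token count up to the next power of two and pretend that
        # bucket always has a captured FULL graph (illustrative only).
        padded = 1
        while padded < num_tokens:
            padded *= 2
        return BatchDesc(num_reqs=num_reqs, num_tokens=padded, cg_mode="FULL")


def dispatch_and_sync_dp(
    manager: ToyCudaGraphManager,
    num_reqs: int,
    num_tokens: int,
    uniform_token_count: int | None,
    dp_size: int,
    dp_rank: int,
) -> tuple[BatchDesc, list[int] | None]:
    # Step 1: pick a (possibly padded) batch descriptor for the local rank.
    batch_desc = manager.dispatch(num_reqs, num_tokens, uniform_token_count)
    num_tokens_across_dp = None
    if dp_size > 1:
        # Step 2: with data parallelism, all ranks must agree on one padded
        # size. The real helper (sync_cudagraph_and_dp_padding) all-gathers
        # per-rank sizes; this toy version just assumes every rank reported
        # the same count, so dp_rank goes unused here.
        num_tokens_across_dp = [batch_desc.num_tokens] * dp_size
        batch_desc.num_tokens = max(num_tokens_across_dp)
    return batch_desc, num_tokens_across_dp


if __name__ == "__main__":
    desc, across_dp = dispatch_and_sync_dp(
        ToyCudaGraphManager(), num_reqs=5, num_tokens=5,
        uniform_token_count=1, dp_size=2, dp_rank=0,
    )
    # Prints: BatchDesc(num_reqs=5, num_tokens=8, cg_mode='FULL') [8, 8]
    print(desc, across_dp)
```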