9 changes: 9 additions & 0 deletions tests/v1/e2e/test_spec_decode.py
@@ -393,6 +393,15 @@ class ArgsTest:
expected_acceptance_len=2.8 + 1,
expected_acceptance_rate=0.9,
),
# A model mixing full self-attention and sliding-window attention layers
ArgsTest(
target_model="google/gemma-3-270m-it",
draft_model="google/gemma-3-270m-it",
sampling_config=greedy_sampling(),
num_speculative_tokens=3,
expected_acceptance_len=3 + 1,
expected_acceptance_rate=1.0,
),
]


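The expected numbers in the new Gemma-3 case follow from self-speculation: the draft model is the same checkpoint as the target and sampling is greedy, so every proposed token should match the target's choice. A minimal sanity check of that arithmetic (plain Python, not part of the test suite):

# Self-speculation under greedy sampling: the draft proposes exactly what the
# target would pick, so every speculative token is accepted each step.
num_speculative_tokens = 3
expected_acceptance_rate = 1.0
# Accepted length per step = accepted draft tokens + 1 bonus token sampled by the target.
expected_acceptance_len = num_speculative_tokens * expected_acceptance_rate + 1
assert expected_acceptance_len == 3 + 1
assert expected_acceptance_rate == 1.0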
56 changes: 45 additions & 11 deletions vllm/v1/spec_decode/eagle.py
@@ -190,6 +190,25 @@ def __init__(
dtype=torch.int32,
).repeat(max_batch_size, 1)

# Maps each attention layer name to its metadata builder; filled lazily on first use.
self.layer2builder: dict[str, AttentionMetadataBuilder] = {}

def get_attn_metadata_builder(self, layer_name: str) -> AttentionMetadataBuilder:
if len(self.layer2builder) == 0:
self._fill_layer2builder()
return self.layer2builder[layer_name]

def _fill_layer2builder(self) -> None:
for kv_cache_group in self.runner.attn_groups:
for attn_group in kv_cache_group:
builder: AttentionMetadataBuilder = attn_group.get_metadata_builder()
for layer_name in attn_group.layer_names:
if layer_name in self.layer2builder:
raise ValueError(
f"Multiple builders found for layer {layer_name}"
)
self.layer2builder[layer_name] = builder

def _get_positions(self, num_tokens: int):
if self.uses_mrope:
return self.mrope_positions[:, :num_tokens]
@@ -235,14 +254,6 @@ def propose(

assert self.runner is not None

if self.attn_metadata_builder is None:
attn_metadata_builder = self._get_attention_metadata_builder()
else:
attn_metadata_builder = self.attn_metadata_builder

attn_metadata = attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0
)
# FIXME: support hybrid kv for draft model (remove separate indexer)
if self.draft_indexer_metadata_builder:
draft_indexer_metadata = (
@@ -257,6 +268,10 @@
# cache group, thus using the same attention metadata.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
builder = self.get_attn_metadata_builder(layer_name)
attn_metadata = builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0
)
per_layer_attn_metadata[layer_name] = attn_metadata

for layer_name in self.indexer_layer_names:
@@ -311,6 +326,14 @@ def propose(
hidden_states = last_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
self.runner.log_toks("draft input toks", model_kwargs["input_ids"])
if self.runner.do_log:
print("draft positions", model_kwargs["positions"])
_atn_md = list(per_layer_attn_metadata.values())[0]
for idx, block_table in enumerate(_atn_md.block_table):
print(f"block_table {idx}", block_table)
print("slot_mapping", _atn_md.slot_mapping)
print("query_start_loc", _atn_md.query_start_loc)
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)

@@ -435,10 +458,13 @@ def propose(
)

# Rebuild attention metadata
attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore
common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
)
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
builder = self.get_attn_metadata_builder(layer_name)
attn_metadata = builder.build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1,
)
per_layer_attn_metadata[layer_name] = attn_metadata

# copy inputs to buffer for cudagraph
@@ -477,6 +503,14 @@ def propose(
hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
self.runner.log_toks("draft input toks", model_kwargs["input_ids"])
if self.runner.do_log:
print("draft positions", model_kwargs["positions"])
_atn_md = list(per_layer_attn_metadata.values())[0]
for idx, block_table in enumerate(_atn_md.block_table):
print(f"block_table {idx}", block_table)
print("slot_mapping", _atn_md.slot_mapping)
print("query_start_loc", _atn_md.query_start_loc)

hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size])
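For context, here is a rough standalone sketch of what the lazy layer-to-builder mapping added above does when layers fall into different KV cache groups (e.g. full-attention vs. sliding-window layers in Gemma-3). FakeAttnGroup, FakeRunner, and the group layout are invented for illustration; only the mapping logic mirrors the diff.

from dataclasses import dataclass, field


@dataclass
class FakeAttnGroup:
    layer_names: list[str]
    builder: str

    def get_metadata_builder(self) -> str:
        return self.builder


@dataclass
class FakeRunner:
    # One inner list per KV cache group, mirroring how runner.attn_groups is iterated.
    attn_groups: list[list[FakeAttnGroup]] = field(default_factory=list)


runner = FakeRunner(
    attn_groups=[
        [FakeAttnGroup(["layers.0.attn"], "full_attention_builder")],
        [FakeAttnGroup(["layers.1.attn"], "sliding_window_builder")],
    ]
)

layer2builder: dict[str, str] = {}
for kv_cache_group in runner.attn_groups:
    for attn_group in kv_cache_group:
        builder = attn_group.get_metadata_builder()
        for layer_name in attn_group.layer_names:
            if layer_name in layer2builder:
                raise ValueError(f"Multiple builders found for layer {layer_name}")
            layer2builder[layer_name] = builder

# Each layer now resolves to the builder of its own KV cache group, so
# build_for_drafting can produce different metadata per layer.
assert layer2builder["layers.0.attn"] == "full_attention_builder"
assert layer2builder["layers.1.attn"] == "sliding_window_builder"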
21 changes: 21 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -213,11 +213,22 @@ def get_output(self) -> ModelRunnerOutput:


class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def log_toks(self, msg: str, toks: torch.Tensor):
if self.do_log:
print(msg, [self.tokenizer.decode(tok) for tok in toks])

def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
self.do_log = True
if self.do_log:
from transformers import AutoTokenizer

self.tokenizer = AutoTokenizer.from_pretrained(
vllm_config.model_config.model
)
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
@@ -2419,6 +2430,8 @@ def execute_model(
scheduler_output: "SchedulerOutput",
intermediate_tensors: IntermediateTensors | None = None,
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
if self.do_log:
print("======== STEP =========")
with record_function_or_nullcontext("Preprocess"):
with self.synchronize_input_prep():
# Update persistent batch states.
@@ -2518,6 +2531,14 @@
inputs_embeds=inputs_embeds,
**model_kwargs,
)
self.log_toks("tgt input toks", input_ids)
if self.do_log:
print("tgt positions", positions)
_atn_md = list(attn_metadata.values())[0] # type: ignore
for idx, block_table in enumerate(_atn_md.block_table):
print(f"tgt block_table {idx}", block_table)
print("tgt slot_mapping", _atn_md.slot_mapping)
print("tgt query_start_loc", _atn_md.query_start_loc)

with record_function_or_nullcontext("Postprocess"):
if self.use_aux_hidden_state_outputs:
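Finally, a small standalone sketch of what the log_toks debug helper prints, assuming transformers is installed and reusing the Gemma-3 checkpoint from the new test case (downloading its tokenizer is only needed for this illustration):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-3-270m-it")
token_ids = tok("speculative decoding", add_special_tokens=False)["input_ids"]
# log_toks decodes each id separately, which makes it easy to eyeball
# token-by-token alignment between the target and draft inputs.
print("tgt input toks", [tok.decode(t) for t in token_ids])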