Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion vllm/model_executor/models/qwen3_5_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,10 +407,16 @@ def forward(
hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
spec_step_idx: int = 0,
**kwargs: object,
):
hidden_states = self.model(
input_ids, positions, hidden_states, intermediate_tensors, inputs_embeds
input_ids,
positions,
hidden_states,
intermediate_tensors,
inputs_embeds,
spec_step_idx,
)
return hidden_states

Expand Down
8 changes: 7 additions & 1 deletion vllm/model_executor/models/qwen3_next_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,16 @@ def forward(
hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
spec_step_idx: int = 0,
**kwargs: object,
):
hidden_states = self.model(
input_ids, positions, hidden_states, intermediate_tensors, inputs_embeds
input_ids,
positions,
hidden_states,
intermediate_tensors,
inputs_embeds,
spec_step_idx,
)
return hidden_states

Expand Down
12 changes: 9 additions & 3 deletions vllm/v1/spec_decode/eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,11 +376,13 @@ def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:

self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)

def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor:
def _greedy_sample(
self, hidden_states: torch.Tensor, spec_step_idx: int = 0
) -> torch.Tensor:
"""Greedy-sample draft tokens from hidden states."""
if self.use_local_argmax_reduction:
return self.model.get_top_tokens(hidden_states)
return self.model.compute_logits(hidden_states).argmax(dim=-1)
return self.model.compute_logits(hidden_states, spec_step_idx).argmax(dim=-1)
Comment on lines 383 to +385
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The get_top_tokens method is called without the spec_step_idx argument. When use_local_argmax_reduction is enabled, the wrong decoder layer will be selected for speculative steps beyond the first, producing incorrect draft tokens — which defeats the purpose of this bug fix for that code path.

To fix this, spec_step_idx should be passed to get_top_tokens. You will also need to update the get_top_tokens method on the MTP models to accept and use this parameter, similar to how compute_logits is being updated.

Suggested change
if self.use_local_argmax_reduction:
return self.model.get_top_tokens(hidden_states)
return self.model.compute_logits(hidden_states).argmax(dim=-1)
return self.model.compute_logits(hidden_states, spec_step_idx).argmax(dim=-1)
if self.use_local_argmax_reduction:
return self.model.get_top_tokens(hidden_states, spec_step_idx=spec_step_idx)
return self.model.compute_logits(hidden_states, spec_step_idx).argmax(dim=-1)


def propose(
self,
Expand Down Expand Up @@ -621,6 +623,7 @@ def propose(
"input_ids": input_ids,
"positions": self._get_positions(input_batch_size),
"inputs_embeds": inputs_embeds,
"spec_step_idx": token_index + 1,
}
if self.pass_hidden_states_to_model:
model_kwargs["hidden_states"] = self.hidden_states[:input_batch_size]
Expand All @@ -641,7 +644,10 @@ def propose(
last_hidden_states, hidden_states = ret_hidden_states

hidden_states = hidden_states[:batch_size]
draft_token_ids = self._greedy_sample(last_hidden_states[:batch_size])
draft_token_ids = self._greedy_sample(
last_hidden_states[:batch_size],
spec_step_idx=token_index + 1,
)
draft_token_ids_list.append(draft_token_ids)

# [batch_size, num_speculative_tokens]
Expand Down
7 changes: 5 additions & 2 deletions vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def run_model(
slot_mappings: dict[str, torch.Tensor] | None,
num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
spec_step_idx: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
batch_descriptor = BatchDescriptor(num_tokens=num_tokens)
with set_forward_context(
Expand All @@ -124,6 +125,7 @@ def run_model(
input_ids=self.input_buffers.input_ids[:num_tokens],
positions=self.input_buffers.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens],
spec_step_idx=spec_step_idx,
)
if self.method == "mtp":
last_hidden_states = ret_hidden_states
Expand Down Expand Up @@ -153,10 +155,11 @@ def generate_draft(
slot_mappings,
num_tokens_across_dp,
cudagraph_runtime_mode,
spec_step_idx=step,
)
last_hidden_states = last_hidden_states[:num_reqs]
hidden_states = hidden_states[:num_reqs]
logits = self.model.compute_logits(last_hidden_states)
logits = self.model.compute_logits(last_hidden_states, spec_step_idx=step)
Comment on lines 155 to +162
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The run_model and compute_logits calls now pass spec_step_idx. However, self.model here is the target model, which is not guaranteed to be an MTP model that accepts this argument. If a non-MTP model is used as the target with EAGLE speculative decoding, this will raise a TypeError, because that model's forward and compute_logits methods may not accept a spec_step_idx parameter. A similar issue exists in the propose method.

To make this more robust, you could check if the model's methods support the spec_step_idx parameter before passing it, for example by using inspect.signature.


# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
Expand Down Expand Up @@ -264,7 +267,7 @@ def propose(
num_tokens_across_dp=num_tokens_across_dp,
)
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
logits = self.model.compute_logits(sample_hidden_states, spec_step_idx=0)

num_reqs = input_batch.num_reqs
num_reqs_padded = input_batch.num_reqs_after_padding
Expand Down
Loading