diff --git a/vllm/model_executor/models/qwen3_5_mtp.py b/vllm/model_executor/models/qwen3_5_mtp.py
index 0eca47492c91..d8aa9e53d62e 100644
--- a/vllm/model_executor/models/qwen3_5_mtp.py
+++ b/vllm/model_executor/models/qwen3_5_mtp.py
@@ -407,10 +407,16 @@ def forward(
         hidden_states: torch.Tensor,
         intermediate_tensors: IntermediateTensors | None = None,
         inputs_embeds: torch.Tensor | None = None,
+        spec_step_idx: int = 0,
         **kwargs: object,
     ):
         hidden_states = self.model(
-            input_ids, positions, hidden_states, intermediate_tensors, inputs_embeds
+            input_ids,
+            positions,
+            hidden_states,
+            intermediate_tensors,
+            inputs_embeds,
+            spec_step_idx,
         )
         return hidden_states
 
diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py
index e76664bedff9..802965ac9d28 100644
--- a/vllm/model_executor/models/qwen3_next_mtp.py
+++ b/vllm/model_executor/models/qwen3_next_mtp.py
@@ -266,10 +266,16 @@ def forward(
         hidden_states: torch.Tensor,
         intermediate_tensors: IntermediateTensors | None = None,
         inputs_embeds: torch.Tensor | None = None,
+        spec_step_idx: int = 0,
         **kwargs: object,
     ):
         hidden_states = self.model(
-            input_ids, positions, hidden_states, intermediate_tensors, inputs_embeds
+            input_ids,
+            positions,
+            hidden_states,
+            intermediate_tensors,
+            inputs_embeds,
+            spec_step_idx,
         )
         return hidden_states
 
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 445bb403b4b3..c77a1731b5a9 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -376,11 +376,13 @@ def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
 
         self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)
 
-    def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def _greedy_sample(
+        self, hidden_states: torch.Tensor, spec_step_idx: int = 0
+    ) -> torch.Tensor:
         """Greedy-sample draft tokens from hidden states."""
         if self.use_local_argmax_reduction:
             return self.model.get_top_tokens(hidden_states)
-        return self.model.compute_logits(hidden_states).argmax(dim=-1)
+        return self.model.compute_logits(hidden_states, spec_step_idx).argmax(dim=-1)
 
     def propose(
         self,
@@ -621,6 +623,7 @@ def propose(
             "input_ids": input_ids,
             "positions": self._get_positions(input_batch_size),
             "inputs_embeds": inputs_embeds,
+            "spec_step_idx": token_index + 1,
         }
         if self.pass_hidden_states_to_model:
             model_kwargs["hidden_states"] = self.hidden_states[:input_batch_size]
@@ -641,7 +644,10 @@ def propose(
 
             last_hidden_states, hidden_states = ret_hidden_states
             hidden_states = hidden_states[:batch_size]
-            draft_token_ids = self._greedy_sample(last_hidden_states[:batch_size])
+            draft_token_ids = self._greedy_sample(
+                last_hidden_states[:batch_size],
+                spec_step_idx=token_index + 1,
+            )
             draft_token_ids_list.append(draft_token_ids)
 
         # [batch_size, num_speculative_tokens]
diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
index 922031a52180..7432645a7dd5 100644
--- a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
+++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
@@ -109,6 +109,7 @@ def run_model(
         slot_mappings: dict[str, torch.Tensor] | None,
         num_tokens_across_dp: torch.Tensor | None,
         cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        spec_step_idx: int = 0,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         batch_descriptor = BatchDescriptor(num_tokens=num_tokens)
         with set_forward_context(
@@ -124,6 +125,7 @@ def run_model(
                 input_ids=self.input_buffers.input_ids[:num_tokens],
                 positions=self.input_buffers.positions[:num_tokens],
                 hidden_states=self.hidden_states[:num_tokens],
+                spec_step_idx=spec_step_idx,
             )
             if self.method == "mtp":
                 last_hidden_states = ret_hidden_states
@@ -153,10 +155,11 @@ def generate_draft(
                 slot_mappings,
                 num_tokens_across_dp,
                 cudagraph_runtime_mode,
+                spec_step_idx=step,
             )
             last_hidden_states = last_hidden_states[:num_reqs]
             hidden_states = hidden_states[:num_reqs]
-            logits = self.model.compute_logits(last_hidden_states)
+            logits = self.model.compute_logits(last_hidden_states, spec_step_idx=step)
 
             # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
             # used for draft and target sampling.
@@ -264,7 +267,7 @@ def propose(
             num_tokens_across_dp=num_tokens_across_dp,
         )
         sample_hidden_states = last_hidden_states[last_token_indices]
-        logits = self.model.compute_logits(sample_hidden_states)
+        logits = self.model.compute_logits(sample_hidden_states, spec_step_idx=0)
 
         num_reqs = input_batch.num_reqs
         num_reqs_padded = input_batch.num_reqs_after_padding