From 1e74965a222954801137340673f15b739a3b5ff5 Mon Sep 17 00:00:00 2001 From: "Claude 2.0" Date: Sun, 3 May 2026 21:17:37 +0800 Subject: [PATCH] fix(spec decode): suppress EOS at draft positions in rejection sampler MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When using MTP speculative decoding, the rejection sampler's target model can produce EOS as the argmax at a draft position. The scheduler iterates the MTP burst tokens one-by-one via check_stop(), which immediately sets FINISHED_STOPPED when it encounters EOS — discarding all remaining tokens in the burst, including the bonus token that would have continued generation. This manifests as premature stopping at reasoning-to-tool-call boundaries: the client receives finish_reason "stop" with only reasoning_content and no tool_calls or content. The fix masks all EOS tokens in target_logits before the rejection sampling step, setting their logits to -inf at all draft positions. The bonus token (sampled from a separate bonus_logits tensor) still produces EOS for legitimate stops. Draft positions can no longer prematurely terminate the burst. Key implementation details: - _collect_eos_token_ids gathers EOS IDs from hf_config, hf_text_config, and generation_config (multimodal models like Qwen3.6-27B nest eos_token_id inside text_config) - Uses scalar column indexing (select + fill_) to avoid the indexSelectSmallIndex CUDA kernel that asserts with large vocab sizes (observed with Qwen3.6-27B: vocab=248320, eos=248044) - Only the large model runner is patched — the small runner uses a different RejectionSampler with a different API Co-Authored-By: Claude Opus 4.7 --- vllm/v1/sample/rejection_sampler.py | 13 +++++++++++++ vllm/v1/worker/gpu_model_runner.py | 22 +++++++++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 678654cb78a4..89cf47e8fdaa 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -62,9 +62,11 @@ def __init__( sampler: Sampler, spec_config: SpeculativeConfig | None = None, device: torch.device | None = None, + eos_token_ids: list[int] | None = None, ): super().__init__() self.sampler = sampler + self.eos_token_ids = eos_token_ids logprobs_mode = self.sampler.logprobs_mode self.is_processed_logprobs_mode = logprobs_mode.startswith("processed") self.is_logits_logprobs_mode = logprobs_mode.endswith("logits") @@ -165,6 +167,17 @@ def forward( sampling_metadata, ) + # Suppress EOS at draft positions so MTP speculative decoding + # doesn't prematurely terminate tool-call generation. + # Use scalar column indexing (select + fill_) instead of list + # indexing to avoid the indexSelectSmallIndex CUDA kernel which + # can assert with large vocab sizes. + if self.eos_token_ids: + vocab_size = target_logits.shape[-1] + for eid in self.eos_token_ids: + if 0 <= eid < vocab_size: + target_logits[:, eid].fill_(float('-inf')) + output_token_ids = rejection_sample( metadata.draft_token_ids, metadata.num_draft_tokens, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bcab2ca2d4c2..740057277b99 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -218,6 +218,25 @@ logger = init_logger(__name__) + +def _collect_eos_token_ids(model_config) -> list[int] | None: + ids: set[int] = set() + for cfg in (model_config.hf_config, model_config.hf_text_config): + val = getattr(cfg, 'eos_token_id', None) + if val is not None: + if isinstance(val, int): + ids.add(val) + else: + ids.update(val) + gen_config = model_config.try_get_generation_config() + if gen_config and (val := gen_config.get('eos_token_id')) is not None: + if isinstance(val, int): + ids.add(val) + else: + ids.update(val) + return sorted(ids) if ids else None + + AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata] # list when ubatching is enabled PerLayerAttnMetadata: TypeAlias = list[AttnMetadataDict] | AttnMetadataDict @@ -578,7 +597,8 @@ def __init__( f"{self.speculative_config.method}" ) self.rejection_sampler = RejectionSampler( - self.sampler, self.speculative_config, self.device + self.sampler, self.speculative_config, self.device, + eos_token_ids=_collect_eos_token_ids(self.model_config), ) self.num_spec_tokens = 0