diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 678654cb78a4..89cf47e8fdaa 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -62,9 +62,11 @@ def __init__( sampler: Sampler, spec_config: SpeculativeConfig | None = None, device: torch.device | None = None, + eos_token_ids: list[int] | None = None, ): super().__init__() self.sampler = sampler + self.eos_token_ids = eos_token_ids logprobs_mode = self.sampler.logprobs_mode self.is_processed_logprobs_mode = logprobs_mode.startswith("processed") self.is_logits_logprobs_mode = logprobs_mode.endswith("logits") @@ -165,6 +167,17 @@ def forward( sampling_metadata, ) + # Suppress EOS at draft positions so MTP speculative decoding + # doesn't prematurely terminate tool-call generation. + # Use scalar column indexing (select + fill_) instead of list + # indexing to avoid the indexSelectSmallIndex CUDA kernel which + # can assert with large vocab sizes. + if self.eos_token_ids: + vocab_size = target_logits.shape[-1] + for eid in self.eos_token_ids: + if 0 <= eid < vocab_size: + target_logits[:, eid].fill_(float('-inf')) + output_token_ids = rejection_sample( metadata.draft_token_ids, metadata.num_draft_tokens, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bcab2ca2d4c2..740057277b99 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -218,6 +218,25 @@ logger = init_logger(__name__) + +def _collect_eos_token_ids(model_config) -> list[int] | None: + ids: set[int] = set() + for cfg in (model_config.hf_config, model_config.hf_text_config): + val = getattr(cfg, 'eos_token_id', None) + if val is not None: + if isinstance(val, int): + ids.add(val) + else: + ids.update(val) + gen_config = model_config.try_get_generation_config() + if gen_config and (val := gen_config.get('eos_token_id')) is not None: + if isinstance(val, int): + ids.add(val) + else: + ids.update(val) + return sorted(ids) if ids else None + + AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata] # list when ubatching is enabled PerLayerAttnMetadata: TypeAlias = list[AttnMetadataDict] | AttnMetadataDict @@ -578,7 +597,8 @@ def __init__( f"{self.speculative_config.method}" ) self.rejection_sampler = RejectionSampler( - self.sampler, self.speculative_config, self.device + self.sampler, self.speculative_config, self.device, + eos_token_ids=_collect_eos_token_ids(self.model_config), ) self.num_spec_tokens = 0