diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 06e91d380298..c05423672f51 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -195,7 +195,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             num_speculative_steps=self.num_speculative_steps,
             vocab_size=self.vocab_size,
             device=self.device,
-            cache_draft_logits=not use_strict_rejection_sampling,
         )
         self.input_buffers = InputBuffers(
             max_num_reqs=self.max_num_reqs,
@@ -447,7 +446,6 @@ def _dummy_run(
             next_prefill_tokens=self.req_states.next_prefill_tokens,
             temperature=self.sampler.sampling_states.temperature.gpu,
             seeds=self.sampler.sampling_states.seeds.gpu,
-            draft_logits_out=self.req_states.draft_logits,
             num_tokens_across_dp=num_tokens_across_dp,
             dummy_run=True,
             skip_attn_for_dummy_run=skip_attn,
@@ -816,11 +814,12 @@ def sample(
         else:
             # Rejection sampling for spec decoding.
             assert self.rejection_sampler is not None
+            assert self.speculator is not None
             sampler_output = self.rejection_sampler(
                 logits,
                 input_batch,
                 # Draft logits are needed for probabilistic rejection sampling.
-                self.req_states.draft_logits,
+                self.speculator.draft_logits,
             )

         # Get the number of sampled and rejected tokens.
@@ -1146,7 +1145,6 @@ def sample_tokens(
             self.req_states.next_prefill_tokens,
             self.sampler.sampling_states.temperature.gpu,
             self.sampler.sampling_states.seeds.gpu,
-            self.req_states.draft_logits,
             num_tokens_across_dp=num_tokens_across_dp,
         )
         self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
index 922031a52180..9e4090bd6888 100644
--- a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
+++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
@@ -75,6 +75,17 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             device=device,
         )

+        cache_draft_logits = self.speculative_config.rejection_sample_method != "strict"
+        self.draft_logits: torch.Tensor | None = None
+        if cache_draft_logits:
+            self.draft_logits = torch.zeros(
+                self.max_num_reqs,
+                self.num_speculative_steps,
+                self.vocab_size,
+                dtype=torch.float32,
+                device=device,
+            )
+
         # currently we don't support PIECEWISE for Eagle.
         cudagraph_mode = vllm_config.compilation_config.cudagraph_mode
         if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL:
@@ -140,7 +151,6 @@ def generate_draft(
         slot_mappings: dict[str, torch.Tensor] | None,
         num_tokens_across_dp: torch.Tensor | None,
         cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-        draft_logits_out: torch.Tensor | None = None,
     ) -> None:
         pos = self.input_buffers.positions[:num_reqs]
         query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
@@ -167,8 +177,8 @@ def generate_draft(
                 self.seeds,
                 pos + 1,
                 apply_temperature=True,
-                processed_logits_out=draft_logits_out[:, step]
-                if draft_logits_out is not None
+                processed_logits_out=self.draft_logits[:, step]
+                if self.draft_logits is not None
                 else None,
             )
             self.draft_tokens[:num_reqs, step] = draft_tokens
@@ -223,8 +233,6 @@ def propose(
         temperature: torch.Tensor,
         # [max_num_reqs]
         seeds: torch.Tensor,
-        # [max_num_reqs, num_speculative_steps, vocab_size]
-        draft_logits_out: torch.Tensor | None,
         num_tokens_across_dp: torch.Tensor | None = None,
         dummy_run: bool = False,
         skip_attn_for_dummy_run: bool = False,
@@ -290,8 +298,8 @@ def propose(
             self.seeds,
             pos + 1,
             apply_temperature=True,
-            processed_logits_out=draft_logits_out[:, 0]
-            if draft_logits_out is not None
+            processed_logits_out=self.draft_logits[:, 0]
+            if self.draft_logits is not None
             else None,
         )

@@ -376,7 +384,6 @@ def propose(
             slot_mappings_updated,
             num_tokens_across_dp=num_tokens_across_dp,
             cudagraph_runtime_mode=batch_desc.cg_mode,
-            draft_logits_out=draft_logits_out,
         )

         return self.draft_tokens[:num_reqs]
diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py
index 3fb02c12d999..f929b5eddf89 100644
--- a/vllm/v1/worker/gpu/states.py
+++ b/vllm/v1/worker/gpu/states.py
@@ -15,7 +15,6 @@ def __init__(
         num_speculative_steps: int,
         vocab_size: int,
         device: torch.device,
-        cache_draft_logits: bool,
     ):
         self.max_num_reqs = max_num_reqs
         self.max_model_len = max_model_len
@@ -71,18 +70,6 @@ def __init__(
             dtype=torch.int64,
             device=device,
         )
-        # Draft token logits.
-        # NOTE: This tensor maintains the "processed" logits after applying temperature,
-        # top-p, etc.
-        self.draft_logits: torch.Tensor | None = None
-        if cache_draft_logits:
-            self.draft_logits = torch.zeros(
-                self.max_num_reqs,
-                self.num_speculative_steps,
-                self.vocab_size,
-                dtype=torch.float32,
-                device=device,
-            )

         self.next_prefill_tokens = torch.zeros(
             self.max_num_reqs, dtype=torch.int32, device=device