6 changes: 2 additions & 4 deletions vllm/v1/worker/gpu/model_runner.py
@@ -195,7 +195,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
num_speculative_steps=self.num_speculative_steps,
vocab_size=self.vocab_size,
device=self.device,
- cache_draft_logits=not use_strict_rejection_sampling,
)
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
@@ -447,7 +446,6 @@ def _dummy_run(
next_prefill_tokens=self.req_states.next_prefill_tokens,
temperature=self.sampler.sampling_states.temperature.gpu,
seeds=self.sampler.sampling_states.seeds.gpu,
- draft_logits_out=self.req_states.draft_logits,
num_tokens_across_dp=num_tokens_across_dp,
dummy_run=True,
skip_attn_for_dummy_run=skip_attn,
@@ -816,11 +814,12 @@ def sample(
else:
# Rejection sampling for spec decoding.
assert self.rejection_sampler is not None
+ assert self.speculator is not None
sampler_output = self.rejection_sampler(
logits,
input_batch,
# Draft logits are needed for probabilistic rejection sampling.
- self.req_states.draft_logits,
+ self.speculator.draft_logits,
)

# Get the number of sampled and rejected tokens.
@@ -1146,7 +1145,6 @@ def sample_tokens(
self.req_states.next_prefill_tokens,
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
- self.req_states.draft_logits,
num_tokens_across_dp=num_tokens_across_dp,
)
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
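The comment `# Draft logits are needed for probabilistic rejection sampling.` refers to the standard speculative-decoding accept/reject rule, which compares the target model's probability of each draft token against the draft model's probability of the same token; this is why the processed draft logits must be cached whenever `rejection_sample_method` is not `"strict"`. Below is a minimal sketch of that rule, assuming per-step target and draft probability distributions that already have temperature and top-p applied; the function name, shapes, and per-step loop are illustrative only and do not mirror vLLM's actual `rejection_sampler`.

```python
import torch

def probabilistic_rejection_sample(
    target_probs: torch.Tensor,  # [num_steps, vocab_size], processed target distribution
    draft_probs: torch.Tensor,   # [num_steps, vocab_size], softmax of the cached draft logits
    draft_tokens: torch.Tensor,  # [num_steps], tokens sampled from draft_probs
) -> torch.Tensor:
    """Accept draft token t with probability min(1, p_target(t) / p_draft(t));
    on the first rejection, resample from the normalized residual distribution."""
    out = []
    for step in range(draft_tokens.shape[0]):
        tok = draft_tokens[step]
        accept_prob = torch.clamp(target_probs[step, tok] / draft_probs[step, tok], max=1.0)
        if torch.rand(()) < accept_prob:
            out.append(tok)
            continue
        # Rejected: resample from max(0, p_target - p_draft), renormalized.
        residual = torch.clamp(target_probs[step] - draft_probs[step], min=0.0)
        residual = residual / residual.sum()  # sketch assumes a non-degenerate residual
        out.append(torch.multinomial(residual, 1).squeeze(0))
        break
    # The "bonus" token sampled when every draft token is accepted is omitted here.
    return torch.stack(out)
```

A strict rejection sampler, by contrast, only checks whether each draft token matches the token the target model would have produced, so it never needs the draft distributions and the cache can be skipped.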
23 changes: 15 additions & 8 deletions vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
@@ -75,6 +75,17 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
device=device,
)

+ cache_draft_logits = self.speculative_config.rejection_sample_method != "strict"
+ self.draft_logits: torch.Tensor | None = None
+ if cache_draft_logits:
+ self.draft_logits = torch.zeros(
+ self.max_num_reqs,
+ self.num_speculative_steps,
+ self.vocab_size,
+ dtype=torch.float32,
+ device=device,
+ )
+
# currently we don't support PIECEWISE for Eagle.
cudagraph_mode = vllm_config.compilation_config.cudagraph_mode
if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL:
@@ -140,7 +151,6 @@ def generate_draft(
slot_mappings: dict[str, torch.Tensor] | None,
num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
- draft_logits_out: torch.Tensor | None = None,
) -> None:
pos = self.input_buffers.positions[:num_reqs]
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
@@ -167,8 +177,8 @@ def generate_draft(
self.seeds,
pos + 1,
apply_temperature=True,
- processed_logits_out=draft_logits_out[:, step]
- if draft_logits_out is not None
+ processed_logits_out=self.draft_logits[:, step]
+ if self.draft_logits is not None
else None,
)
self.draft_tokens[:num_reqs, step] = draft_tokens
@@ -223,8 +233,6 @@ def propose(
temperature: torch.Tensor,
# [max_num_reqs]
seeds: torch.Tensor,
- # [max_num_reqs, num_speculative_steps, vocab_size]
- draft_logits_out: torch.Tensor | None,
num_tokens_across_dp: torch.Tensor | None = None,
dummy_run: bool = False,
skip_attn_for_dummy_run: bool = False,
@@ -290,8 +298,8 @@ def propose(
self.seeds,
pos + 1,
apply_temperature=True,
- processed_logits_out=draft_logits_out[:, 0]
- if draft_logits_out is not None
+ processed_logits_out=self.draft_logits[:, 0]
+ if self.draft_logits is not None
else None,
)

@@ -376,7 +384,6 @@ def propose(
slot_mappings_updated,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=batch_desc.cg_mode,
- draft_logits_out=draft_logits_out,
)
return self.draft_tokens[:num_reqs]

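Moving the cache onto the speculator does not change its size: it is still a `[max_num_reqs, num_speculative_steps, vocab_size]` fp32 buffer, allocated only when `rejection_sample_method != "strict"`. A rough sizing sketch with assumed example values (the real numbers depend on the engine configuration and model vocabulary, not on anything stated in this PR):

```python
# Illustrative only: these values are assumptions, not taken from this PR.
max_num_reqs, num_speculative_steps, vocab_size = 1024, 3, 128_256
bytes_per_elem = 4  # float32
size_gib = max_num_reqs * num_speculative_steps * vocab_size * bytes_per_elem / 2**30
print(f"draft_logits buffer: {size_gib:.2f} GiB")  # ~1.47 GiB
```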
13 changes: 0 additions & 13 deletions vllm/v1/worker/gpu/states.py
@@ -15,7 +15,6 @@ def __init__(
num_speculative_steps: int,
vocab_size: int,
device: torch.device,
- cache_draft_logits: bool,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
@@ -71,18 +70,6 @@ def __init__(
dtype=torch.int64,
device=device,
)
- # Draft token logits.
- # NOTE: This tensor maintains the "processed" logits after applying temperature,
- # top-p, etc.
- self.draft_logits: torch.Tensor | None = None
- if cache_draft_logits:
- self.draft_logits = torch.zeros(
- self.max_num_reqs,
- self.num_speculative_steps,
- self.vocab_size,
- dtype=torch.float32,
- device=device,
- )

self.next_prefill_tokens = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=device