13 changes: 9 additions & 4 deletions tests/v1/sample/test_logprobs.py

@@ -521,8 +521,8 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
         pytest.param(
             (
                 "eagle",
-                "meta-llama/Llama-3.1-8B-Instruct",
-                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
+                "meta-llama/Llama-3.2-1B-Instruct",
+                "nm-testing/Llama3_2_1B_speculator.eagle3",
             ),
             marks=large_gpu_mark(min_gb=32),
         ),
@@ -541,7 +541,7 @@ def test_spec_decode_logprobs(
     """
     from vllm import LLM
 
-    prompt = "Hello world"
+    prompt = "Hello world " * 50
     sampling_params = SamplingParams(
         temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
     )
@@ -582,6 +582,9 @@ def test_spec_decode_logprobs(
         seed=42,
         logprobs_mode=logprobs_mode,
         gpu_memory_utilization=0.4,
+        # Force prefill chunking
+        enable_chunked_prefill=True,
+        max_num_batched_tokens=32,
     )
     spec_results = spec_llm.generate([prompt], sampling_params)
     # Collect logprobs outputs from spec decode LLM.
@@ -597,6 +600,8 @@
     # Per-token logprobs are expected to be the same.
     assert len(ref_logprobs) == len(spec_logprobs)
     for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
-        assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3)
+        assert math.isclose(
+            ref_logprob.logprob, spec_logprob.logprob, rel_tol=5e-2, abs_tol=1e-1
+        )
         assert ref_logprob.rank == spec_logprob.rank
         assert ref_logprob.decoded_token == spec_logprob.decoded_token
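Why the looser tolerance fits here: math.isclose(a, b, rel_tol=r, abs_tol=t) accepts when abs(a - b) <= max(r * max(abs(a), abs(b)), t), so the absolute floor covers near-zero logprobs while the relative term scales with large-magnitude ones. A minimal standalone sketch with illustrative values (not taken from the test):

import math

# Near zero, the absolute floor (abs_tol=1e-1) dominates:
# |(-0.02) - (-0.09)| = 0.07 <= max(0.05 * 0.09, 0.1) = 0.1
assert math.isclose(-0.02, -0.09, rel_tol=5e-2, abs_tol=1e-1)

# For large-magnitude logprobs the relative term takes over:
# the tolerance is max(0.05 * 12.0, 0.1) = 0.6 here.
assert math.isclose(-12.0, -11.5, rel_tol=5e-2, abs_tol=1e-1)

# The previous effectively-absolute check (abs_tol=1e-3) rejects both.
assert not math.isclose(-12.0, -11.5, abs_tol=1e-3)
assert not math.isclose(-0.02, -0.09, abs_tol=1e-3)
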
5 changes: 4 additions & 1 deletion vllm/v1/sample/sampler.py

@@ -81,7 +81,10 @@ def forward(
         if logprobs_mode == "raw_logprobs":
             raw_logprobs = self.compute_logprobs(logits)
         elif logprobs_mode == "raw_logits":
-            raw_logprobs = logits.clone()
+            if logits.dtype == torch.float32:
+                raw_logprobs = logits.clone()
+            else:
+                raw_logprobs = logits.to(torch.float32)
 
         # Use float32 for the logits.
         logits = logits.to(torch.float32)
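The dtype guard matters because Tensor.to is a no-op that returns self when the target dtype and device already match, so without the explicit clone() the saved raw_logprobs would share storage with logits, which later sampling steps may modify in place; conversely, when a real cast happens, .to() already allocates fresh memory and an extra clone() would be a wasted copy. A small standalone sketch of that PyTorch behavior (not vLLM code):

import torch

logits = torch.zeros(4, dtype=torch.float32)

# .to() with a matching dtype returns the very same tensor object...
assert logits.to(torch.float32) is logits
# ...so clone() is required to get an independent copy.
assert logits.clone() is not logits

# With a non-matching dtype, .to() allocates a new tensor already,
# so cloning on top of it would copy the data twice.
half = torch.zeros(4, dtype=torch.float16)
assert half.to(torch.float32) is not half
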
19 changes: 9 additions & 10 deletions vllm/v1/worker/gpu_model_runner.py

@@ -2466,7 +2466,9 @@ def _bookkeeping_sync(
 
         num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
         sampled_token_ids = sampler_output.sampled_token_ids
+        logprobs_tensors = sampler_output.logprobs_tensors
         invalid_req_indices = []
+        cu_num_new_tokens: list[int] | None = None
         if not self.use_async_scheduling:
             # Get the valid generated tokens.
             max_gen_len = sampled_token_ids.shape[-1]
@@ -2479,6 +2481,12 @@
                 sampled_token_ids,
                 self.input_batch.vocab_size,
             )
+            if logprobs_tensors:
+                # Needed for extracting logprobs when spec decoding.
+                # This must be done prior to discarding sampled tokens.
+                cu_num_new_tokens = [0]
+                for toks in valid_sampled_token_ids:
+                    cu_num_new_tokens.append(cu_num_new_tokens[-1] + len(toks))
             # Mask out the sampled tokens that should not be sampled.
             for i in discard_sampled_tokens_req_indices:
                 valid_sampled_token_ids[int(i)].clear()
@@ -2506,10 +2514,6 @@
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
         req_ids = self.input_batch.req_ids
-        logprobs_tensors = sampler_output.logprobs_tensors
-        cu_num_accepted_tokens = (
-            [0] if spec_decode_metadata and logprobs_tensors else None
-        )
         for req_idx in range(num_sampled_tokens):
             if self.use_async_scheduling:
                 sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None
@@ -2518,11 +2522,6 @@
 
             num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0
 
-            if cu_num_accepted_tokens is not None:
-                cu_num_accepted_tokens.append(
-                    cu_num_accepted_tokens[-1] + num_sampled_ids
-                )
-
             if not sampled_ids:
                 continue
 
@@ -2544,7 +2543,7 @@
             req_state.output_token_ids.extend(sampled_ids)
 
         logprobs_lists = (
-            logprobs_tensors.tolists(cu_num_accepted_tokens)
+            logprobs_tensors.tolists(cu_num_new_tokens)
             if not self.use_async_scheduling and logprobs_tensors is not None
             else None
        )
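The new bookkeeping works because cu_num_new_tokens is a prefix sum over per-request token counts, captured before discarded requests have their token lists cleared, so the offsets still line up with the flat per-token logprobs the sampler produced. A standalone sketch of the splitting idea, using plain lists as a stand-in for LogprobsTensors.tolists (its role as a per-request splitter is an assumption here):

# Per-request sampled tokens; with spec decoding a request can
# contribute several accepted tokens in a single step.
valid_sampled_token_ids = [[11, 12, 13], [21], [31, 32]]

# Prefix sums taken BEFORE any request's tokens are discarded,
# so they match the flat per-token logprobs layout.
cu_num_new_tokens = [0]
for toks in valid_sampled_token_ids:
    cu_num_new_tokens.append(cu_num_new_tokens[-1] + len(toks))
assert cu_num_new_tokens == [0, 3, 4, 6]

# Flat per-token logprobs (one entry per sampled token, all requests).
flat_logprobs = [-0.1, -0.2, -0.3, -1.5, -0.4, -0.5]

# Adjacent offsets recover each request's slice.
per_request = [
    flat_logprobs[start:end]
    for start, end in zip(cu_num_new_tokens, cu_num_new_tokens[1:])
]
assert per_request == [[-0.1, -0.2, -0.3], [-1.5], [-0.4, -0.5]]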