Commit f571ff8

Authored by 22quinn, njhill, and WoosukKwon
[Sampler] Support returning final logprobs (#22387)
Signed-off-by: 22quinn <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
1 parent f64ee61 commit f571ff8

File tree

7 files changed (+126, -70 lines)

docs/usage/v1_guide.md
Lines changed: 5 additions & 2 deletions

@@ -154,12 +154,15 @@ differences compared to V0:
 
 ##### Logprobs Calculation
 
-Logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e.
+By default, logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e.
 before applying any logits post-processing such as temperature scaling or penalty
 adjustments). As a result, the returned logprobs do not reflect the final adjusted
 probabilities used during sampling.
 
-Support for logprobs with post-sampling adjustments is in progress and will be added in future updates.
+You can adjust this behavior by setting the `--logprobs-mode` flag.
+Four modes are supported: `raw_logprobs` (default), `processed_logprobs`, `raw_logits`, `processed_logits`.
+Raw means the values before applying any logit processors, like bad words.
+Processed means the values after applying all processors, including temperature and top_k/top_p.
 
 ##### Prompt Logprobs with Prefix Caching
 
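
For reference, here is a minimal offline sketch of the behavior described in the doc change above. It assumes the `logprobs_mode` keyword is forwarded from the `LLM` entrypoint to the `ModelConfig` field added in this commit (the equivalent of passing `--logprobs-mode` to the server); the model name is only illustrative.

```python
from vllm import LLM, SamplingParams

# Request logprobs after all processing (temperature, top-k/top-p) is applied.
llm = LLM(model="facebook/opt-125m", logprobs_mode="processed_logprobs")
params = SamplingParams(temperature=0.8, top_p=0.9, logprobs=5)

outputs = llm.generate(["The capital of France is"], params)
for output in outputs:
    # With processed_logprobs, the returned values reflect the adjusted
    # distribution actually used for sampling, not the raw model output.
    print(output.outputs[0].logprobs)
```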

tests/v1/sample/test_logprobs.py
Lines changed: 5 additions & 5 deletions

@@ -456,9 +456,7 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
             assert len(logprob) == vocab_size
 
 
-@pytest.mark.parametrize(
-    "logprobs_mode",
-    ["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])
+@pytest.mark.parametrize("logprobs_mode", list(LogprobsMode))
 def test_logprobs_mode(logprobs_mode: LogprobsMode,
                        monkeypatch: pytest.MonkeyPatch):
     """Test with LLM engine with different logprobs_mode.
@@ -487,12 +485,14 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode,
         for logprobs in output.logprobs:
             for token_id in logprobs:
                 logprob = logprobs[token_id]
-                if "logprobs" in logprobs_mode:
+                if logprobs_mode in (LogprobsMode.RAW_LOGPROBS,
+                                     LogprobsMode.PROCESSED_LOGPROBS):
                     assert logprob.logprob <= 0
                 if logprob.logprob > 0:
                     positive_values = positive_values + 1
                 total_token_with_logprobs = total_token_with_logprobs + 1
     assert total_token_with_logprobs >= len(results[0].outputs)
-    if "logits" in logprobs_mode:
+    if logprobs_mode in (LogprobsMode.RAW_LOGITS,
+                         LogprobsMode.PROCESSED_LOGITS):
         assert positive_values > 0
     del llm

vllm/config/__init__.py
Lines changed: 18 additions & 12 deletions

@@ -257,11 +257,16 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
 
 TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
 ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
-LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs",
-                       "processed_logits"]
 MMEncoderTPMode = Literal["weights", "data"]
 
 
+class LogprobsMode(enum.Enum):
+    RAW_LOGITS = "raw_logits"
+    RAW_LOGPROBS = "raw_logprobs"
+    PROCESSED_LOGITS = "processed_logits"
+    PROCESSED_LOGPROBS = "processed_logprobs"
+
+
 @config
 @dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class ModelConfig:
@@ -363,12 +368,13 @@ class ModelConfig:
     specified in `SamplingParams`. The default value comes the default for the
     OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length *
     vocab_size) logprobs are allowed to be returned and it may cause OOM."""
-    logprobs_mode: LogprobsMode = "raw_logprobs"
+    logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS
     """Indicates the content returned in the logprobs and prompt_logprobs.
     Supported mode:
     1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits.
-    Raw means the values before applying logit processors, like bad words.
-    Processed means the values after applying such processors.
+    Raw means the values before applying any logit processors, like bad words.
+    Processed means the values after applying all processors, including
+    temperature and top_k/top_p.
     """
     disable_sliding_window: bool = False
     """Whether to disable sliding window. If True, we will disable the sliding
@@ -2586,7 +2592,7 @@ class MultiModalConfig:
 
     skip_mm_profiling: bool = False
     """
-    When enabled, skips multimodal memory profiling and only profiles with 
+    When enabled, skips multimodal memory profiling and only profiles with
     language backbone model during engine initialization.
 
     This reduces engine startup time but shifts the responsibility to users for
@@ -2649,24 +2655,24 @@ class PoolerConfig:
     ## for embeddings models
     normalize: Optional[bool] = None
     """
-    Whether to normalize the embeddings outputs. 
+    Whether to normalize the embeddings outputs.
     """
     dimensions: Optional[int] = None
    """
-    Reduce the dimensions of embeddings if model 
+    Reduce the dimensions of embeddings if model
     support matryoshka representation.
     """
 
     ## for classification models
     activation: Optional[bool] = None
     """
-    Whether to apply activation function to the classification outputs. 
+    Whether to apply activation function to the classification outputs.
     """
 
     ## for reward models
     softmax: Optional[bool] = None
     """
-    Whether to apply softmax to the reward outputs. 
+    Whether to apply softmax to the reward outputs.
     """
     step_tag_id: Optional[int] = None
     """
@@ -2692,9 +2698,9 @@ class PoolerConfig:
 
     max_embed_len: Optional[int] = None
     """
-    Maximum input length allowed for embedding generation. When set, allows 
+    Maximum input length allowed for embedding generation. When set, allows
     inputs longer than max_embed_len to be accepted for embedding models.
-    This parameter enables accepting long inputs without requiring 
+    This parameter enables accepting long inputs without requiring
     VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds
     max_embed_len, it will be handled according to the original max_model_len
     validation logic. Defaults to None (i.e. set to max_model_len).
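
A side note on the switch from a `Literal` string alias to an enum: substring checks such as `"logits" in logprobs_mode` no longer work, which is why the test above compares against explicit members. A small sketch of the resulting semantics (standard `enum.Enum` behavior, nothing assumed beyond the class defined in this diff):

```python
from vllm.config import LogprobsMode

# The CLI still passes the string value; enum value lookup maps it to a member.
mode = LogprobsMode("processed_logprobs")
assert mode is LogprobsMode.PROCESSED_LOGPROBS

# Call sites now compare against explicit member tuples instead of substrings.
assert mode in (LogprobsMode.RAW_LOGPROBS, LogprobsMode.PROCESSED_LOGPROBS)
```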

vllm/engine/arg_utils.py
Lines changed: 1 addition & 0 deletions

@@ -516,6 +516,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         model_group.add_argument("--max-logprobs",
                                  **model_kwargs["max_logprobs"])
         model_group.add_argument("--logprobs-mode",
+                                 choices=[f.value for f in LogprobsMode],
                                  **model_kwargs["logprobs_mode"])
         model_group.add_argument("--disable-sliding-window",
                                  **model_kwargs["disable_sliding_window"])

vllm/v1/sample/ops/topk_topp_sampler.py
Lines changed: 35 additions & 30 deletions

@@ -8,6 +8,7 @@
 from packaging import version
 
 from vllm import envs
+from vllm.config import LogprobsMode
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
@@ -28,9 +29,16 @@ class TopKTopPSampler(nn.Module):
     Implementations may update the logits tensor in-place.
     """
 
-    def __init__(self):
+    def __init__(
+            self,
+            logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS) -> None:
         super().__init__()
-        if current_platform.is_cuda():
+        self.logprobs_mode = logprobs_mode
+        # flashinfer optimization does not apply if intermediate
+        # logprobs/logits after top_k/top_p need to be returned
+        if logprobs_mode not in (LogprobsMode.PROCESSED_LOGITS,
+                                 LogprobsMode.PROCESSED_LOGPROBS
+                                 ) and current_platform.is_cuda():
             if is_flashinfer_available:
                 flashinfer_version = flashinfer.__version__
                 if version.parse(flashinfer_version) < version.parse("0.2.3"):
@@ -63,61 +71,58 @@ def __init__(self):
                     "native implementation of top-p & top-k sampling. For the "
                     "best performance, please install FlashInfer.")
                 self.forward = self.forward_native
-        elif current_platform.is_tpu():
-            self.forward = self.forward_tpu
         else:
             self.forward = self.forward_native
+        if current_platform.is_tpu():
+            self.apply_top_k_top_p = apply_top_k_top_p_tpu
+        else:
+            self.apply_top_k_top_p = apply_top_k_top_p
 
     def forward_native(
         self,
         logits: torch.Tensor,
         generators: dict[int, torch.Generator],
         k: Optional[torch.Tensor],
         p: Optional[torch.Tensor],
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """
         PyTorch-native implementation of top-k and top-p sampling.
 
         The logits tensor may be updated in-place.
         """
-        logits = apply_top_k_top_p(logits, k, p)
+        logits = self.apply_top_k_top_p(logits, k, p)
+        logits_to_return = None
+        if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS:
+            logits_to_return = logits
+        elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS:
+            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
         probs = logits.softmax(dim=-1, dtype=torch.float32)
-        return random_sample(probs, generators)
+        return random_sample(probs, generators), logits_to_return
 
     def forward_cuda(
         self,
         logits: torch.Tensor,
         generators: dict[int, torch.Generator],
         k: Optional[torch.Tensor],
         p: Optional[torch.Tensor],
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """More optimized implementation for top-k and top-p sampling."""
-        if k is None and p is None:
-            # We prefer `random_sample` over `flashinfer_sample` when sorting is
-            # not needed. This is because `random_sample` does not require
-            # CPU-GPU synchronization while `flashinfer_sample` does.
-            probs = logits.softmax(dim=-1, dtype=torch.float32)
-            return random_sample(probs, generators)
-        if generators:
-            logger.warning_once("FlashInfer 0.2.3+ does not support "
-                                "per-request generators. Falling back to "
-                                "PyTorch-native implementation.")
+        # We prefer `random_sample` over `flashinfer_sample` when sorting is
+        # not needed. This is because `random_sample` does not require
+        # CPU-GPU synchronization while `flashinfer_sample` does.
+        if (k is None and p is None) or generators:
+            if generators:
+                logger.warning_once("FlashInfer 0.2.3+ does not support "
+                                    "per-request generators. Falling back to "
+                                    "PyTorch-native implementation.")
             return self.forward_native(logits, generators, k, p)
+        assert self.logprobs_mode not in (
+            LogprobsMode.PROCESSED_LOGITS, LogprobsMode.PROCESSED_LOGPROBS
+        ), "FlashInfer does not support returning logits/logprobs"
         # flashinfer sampling functions expect contiguous logits.
         # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
         # because of slicing operation in logits_processor.
-        return flashinfer_sample(logits.contiguous(), k, p, generators)
-
-    def forward_tpu(
-        self,
-        logits: torch.Tensor,
-        generators: dict[int, torch.Generator],
-        k: Optional[torch.Tensor],
-        p: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        logits = apply_top_k_top_p_tpu(logits, k, p)
-        probs = logits.softmax(dim=-1, dtype=torch.float32)
-        return random_sample(probs, generators)
+        return flashinfer_sample(logits.contiguous(), k, p, generators), None
 
 
 def apply_top_k_top_p_tpu(
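
For orientation, a rough sketch of how a caller could consume the new two-element return value of `forward_native`. The actual sampler integration lives in files of this commit that are not shown on this page, so the calling pattern below is illustrative only (shapes and values are made up).

```python
import torch

from vllm.config import LogprobsMode
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler

# Construct the sampler in a processed mode so that forward_native also returns
# the logprobs after the top-k/top-p mask has been applied.
sampler = TopKTopPSampler(logprobs_mode=LogprobsMode.PROCESSED_LOGPROBS)

logits = torch.randn(2, 32)      # [num_requests, vocab_size], dummy values
k = torch.tensor([10, 10])       # per-request top-k
p = torch.tensor([0.9, 0.95])    # per-request top-p

sampled, processed = sampler.forward_native(logits, {}, k, p)

# `processed` is None for the raw modes; in the processed modes it holds the
# logits or logprobs after top-k/top-p, with the same shape as `logits`.
if processed is not None:
    assert processed.shape == logits.shape
print(sampled)  # sampled token ids, one per request
```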
