Skip to content
Merged
31 changes: 31 additions & 0 deletions tests/v1/sample/test_logprobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import itertools
import math
from collections.abc import Generator
from types import SimpleNamespace
from typing import get_args

import pytest
Expand All @@ -20,6 +21,7 @@
from vllm import SamplingParams
from vllm.config.model import LogprobsMode
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.exceptions import VLLMValidationError
from vllm.platforms import current_platform

from ...conftest import HfRunner, VllmRunner
Expand Down Expand Up @@ -78,6 +80,14 @@ def hf_model(hf_runner) -> Generator[HfRunner, None, None]:
yield hf_model


def _model_config(vocab_size: int = 10):
return SimpleNamespace(
max_logprobs=20,
logits_processors=None,
get_vocab_size=lambda: vocab_size,
)


def _repeat_logprob_config(
test_prompts,
logprob_prompt_logprob_list: BatchLogprobsSpecType,
Expand Down Expand Up @@ -397,6 +407,27 @@ def test_max_logprobs():
runner.generate(["Hello world"], sampling_params=bad_sampling_params)


@pytest.mark.parametrize("token_ids", [[0], [0, 9]])
def test_logprob_token_ids_validate_vocab_bounds_valid(token_ids: list[int]):
    """In-range ``logprob_token_ids`` must pass ``SamplingParams.verify``.

    The stub config reports a vocabulary of 10, so ids 0 and 9 are the
    lowest and highest ids used by the parametrized cases. The test passes
    as long as ``verify`` does not raise.
    """
    sampling_params = SamplingParams(logprob_token_ids=token_ids)
    stub_config = _model_config()
    sampling_params.verify(
        stub_config,
        speculative_config=None,
        structured_outputs_config=None,
        tokenizer=None,
    )


@pytest.mark.parametrize("token_ids", [[-1], [10], [-35, 1873042417]])
def test_logprob_token_ids_validate_vocab_bounds_invalid(token_ids: list[int]):
    """Out-of-range ``logprob_token_ids`` must be rejected.

    Every parametrized case holds at least one id that is negative or not
    below the stub config's vocabulary size of 10. A ``VLLMValidationError``
    whose message mentions ``logprob_token_ids`` is expected.
    """
    stub_config = _model_config()
    expected_failure = pytest.raises(
        VLLMValidationError, match="logprob_token_ids"
    )
    with expected_failure:
        # Construction stays inside the context so an error raised by the
        # constructor is treated the same way as one raised by verify().
        SamplingParams(logprob_token_ids=token_ids).verify(
            stub_config,
            speculative_config=None,
            structured_outputs_config=None,
            tokenizer=None,
        )


def test_none_logprobs(vllm_model, example_prompts):
"""Engine should return `logprobs` and `prompt_logprobs` as `None`

Expand Down
14 changes: 14 additions & 0 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,20 @@ def _validate_logprobs(self, model_config: ModelConfig) -> None:
parameter="logprob_token_ids",
value=n,
)
vocab_size = model_config.get_vocab_size()
invalid_token_ids = [
token_id
for token_id in self.logprob_token_ids
if token_id < 0 or token_id >= vocab_size
]
if invalid_token_ids:
raise VLLMValidationError(
f"token_id(s) {invalid_token_ids} in logprob_token_ids "
f"contain out-of-vocab token ids. Vocabulary size: "
f"{vocab_size}",
parameter="logprob_token_ids",
value=invalid_token_ids,
)
if self.logprobs is not None and self.logprobs != n:
raise VLLMValidationError(
f"When both logprobs and logprob_token_ids are set, "
Expand Down
36 changes: 27 additions & 9 deletions vllm/v1/sample/thinking_budget_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import torch

from vllm.platforms import current_platform
from vllm.utils.torch_utils import async_tensor_h2d
from vllm.v1.sample.logits_processor.interface import (
BatchUpdate,
Expand Down Expand Up @@ -511,14 +512,31 @@ def _apply_forcing_to_logits(

if active_indices_cpu:
device = logits.device
active_indices = async_tensor_h2d(
active_indices_cpu, dtype=torch.long, device=device
)
force_tokens = async_tensor_h2d(
force_tokens_cpu, dtype=torch.long, device=device
)
# Avoid CPU->GPU sync.
fill = logits.new_full((len(active_indices_cpu),), 1e9)
logits.index_put_((active_indices, force_tokens), fill)
if current_platform.is_rocm() and logits.is_contiguous():
# Flattened index_fill avoids ROCm faults seen with 2-D
# advanced-indexing writes on the thinking-budget path.
vocab_size = logits.shape[1]
flat_indices_cpu = [
row * vocab_size + token
for row, token in zip(active_indices_cpu, force_tokens_cpu)
]
flat_indices = async_tensor_h2d(
flat_indices_cpu, dtype=torch.long, device=device
)
logits.view(-1).index_fill_(0, flat_indices, 1e9)
elif current_platform.is_rocm():
fill = logits.new_tensor(1e9)
for row, token in zip(active_indices_cpu, force_tokens_cpu):
logits[row, token] = fill
else:
active_indices = async_tensor_h2d(
active_indices_cpu, dtype=torch.long, device=device
)
force_tokens = async_tensor_h2d(
force_tokens_cpu, dtype=torch.long, device=device
)
# Avoid CPU->GPU sync.
fill = logits.new_full((len(active_indices_cpu),), 1e9)
logits.index_put_((active_indices, force_tokens), fill)

return logits
10 changes: 10 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6264,11 +6264,21 @@ def shutdown(self) -> None:

# Calls torch.accelerator.synchronize()
self._cleanup_profiling_kv_cache()
if current_platform.is_rocm():
# Drop captured graphs before distributed teardown. On ROCm, delayed
# graph destruction can surface HSA faults in the next engine startup.
CUDAGraphWrapper.clear_all_graphs()
BreakableCUDAGraphWrapper.clear_all_graphs()
self.encoder_cudagraph_manager = None
self.compilation_config.static_forward_context.clear()
self.model = None # type: ignore[assignment]
_ROPE_DICT.clear()

reset_workspace_manager()
if current_platform.is_rocm():
gc.collect()
torch.accelerator.empty_cache()
torch.accelerator.synchronize()

def _cleanup_profiling_kv_cache(self) -> None:
torch.accelerator.synchronize()
Expand Down
Loading