Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions tests/v1/logits_processors/test_entrypoint_output_token_tracking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test that entry-point logits processor plugins correctly enable
output token id tracking.

Previously, ``gpu_model_runner.py`` derived the
``logitsprocs_need_output_token_ids`` flag solely from CLI-passed
``custom_logitsprocs``, ignoring entry-point plugins loaded by
``build_logitsprocs()``. When the flag was ``False`` and all penalties
were neutral (``repetition_penalty=1.0``), vLLM's async scheduling path
filled the output token id buffer with ``-1`` placeholders instead of
real tokens, silently breaking any entry-point logits processor that
inspects generation history.

This test verifies that entry-point plugins cause
``LogitsProcessors.has_custom`` (and therefore the flag) to be ``True``.
"""

import importlib.metadata

import pytest
import torch

from tests.v1.logits_processors.utils import (
DummyLogitsProcessor,
entry_points as fake_entry_points,
)
from vllm.config import VllmConfig
from vllm.platforms import current_platform
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.sample.logits_processor import build_logitsprocs

PIN_MEMORY_AVAILABLE = is_pin_memory_available()
DEVICE = current_platform.device_type


@pytest.fixture()
def _mock_entry_points(monkeypatch):
    """Patch ``importlib.metadata.entry_points`` so plugin discovery sees a
    dummy non-argmax-invariant logits processor entry point."""
    monkeypatch.setattr(
        "importlib.metadata.entry_points", fake_entry_points
    )


@pytest.mark.usefixtures("_mock_entry_points")
def test_entrypoint_plugin_enables_output_token_tracking():
    """A non-argmax-invariant entry-point plugin must set the
    output-token-tracking flag even without any CLI logits processors.

    This is exactly the pre-fix failure mode: a plugin such as
    ``NoRepeatNGramLogitsProcessor`` loaded through the
    ``vllm.logits_processors`` entry-point group with
    ``is_argmax_invariant() = False`` did not cause
    ``logitsprocs_need_output_token_ids`` to be set, so the output buffer
    was filled with ``-1`` placeholders.
    """
    # Build with an empty CLI processor tuple — only the mocked
    # entry-point plugin should be discovered.
    logitsprocs = build_logitsprocs(
        VllmConfig(),
        torch.device(DEVICE),
        PIN_MEMORY_AVAILABLE,
        is_pooling_model=False,
        custom_logitsprocs=(),
    )

    # DummyLogitsProcessor reports is_argmax_invariant() = False, so the
    # loader must place it in the non_argmax_invariant bucket.
    assert any(
        type(proc) is DummyLogitsProcessor
        for proc in logitsprocs.non_argmax_invariant
    ), "DummyLogitsProcessor should be loaded via the entry-point mock"

    # Key assertion: entry-point plugins alone must flip has_custom,
    # even though custom_logitsprocs is empty.
    assert logitsprocs.has_custom is True, (
        "has_custom should be True when entry-point plugins are loaded, "
        "even with no CLI-passed custom_logitsprocs"
    )


@pytest.mark.usefixtures("_mock_entry_points")
def test_cli_logitsprocs_still_enable_tracking():
    """CLI-passed custom logits processors must keep enabling output token
    tracking — the pre-existing behavior is preserved by the fix."""
    # Pass the processor explicitly on the "CLI" path rather than relying
    # on entry-point discovery alone.
    logitsprocs = build_logitsprocs(
        VllmConfig(),
        torch.device(DEVICE),
        PIN_MEMORY_AVAILABLE,
        is_pooling_model=False,
        custom_logitsprocs=(DummyLogitsProcessor,),
    )

    assert logitsprocs.has_custom is True


def test_no_logitsprocs_disables_tracking():
    """With neither CLI processors nor entry-point plugins, ``has_custom``
    must stay ``False``.

    Note: the ``_mock_entry_points`` fixture is deliberately NOT used here,
    so entry-point discovery finds nothing and only builtins are loaded.
    """
    logitsprocs = build_logitsprocs(
        VllmConfig(),
        torch.device(DEVICE),
        PIN_MEMORY_AVAILABLE,
        is_pooling_model=False,
        custom_logitsprocs=(),
    )

    assert logitsprocs.has_custom is False, (
        "has_custom should be False when only builtins are loaded"
    )
11 changes: 7 additions & 4 deletions vllm/v1/sample/logits_processor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,13 @@ def build_logitsprocs(

custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs)
return LogitsProcessors(
ctor(vllm_config, device, is_pin_memory)
for ctor in itertools.chain(
BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes
)
(
ctor(vllm_config, device, is_pin_memory)
for ctor in itertools.chain(
BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes
)
),
has_custom=bool(custom_logitsprocs_classes),
)


Expand Down
10 changes: 9 additions & 1 deletion vllm/v1/sample/logits_processor/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,17 @@ def get_and_reset(self, batch_size: int) -> BatchUpdate | None:
class LogitsProcessors:
"""Encapsulates initialized logitsproc objects."""

def __init__(self, logitsprocs: Iterable["LogitsProcessor"] | None = None) -> None:
def __init__(
self,
logitsprocs: Iterable["LogitsProcessor"] | None = None,
*,
has_custom: bool = False,
) -> None:
self.argmax_invariant: list[LogitsProcessor] = []
self.non_argmax_invariant: list[LogitsProcessor] = []
# True when non-builtin processors (entry-point plugins or
# CLI-passed custom processors) are loaded.
self.has_custom = has_custom
if logitsprocs:
for logitproc in logitsprocs:
(
Expand Down
28 changes: 15 additions & 13 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,19 @@ def __init__(
)
self._init_block_sizes = [placeholder_block_size]
self._init_kernel_block_sizes = [placeholder_block_size]
logitsprocs = build_logitsprocs(
self.vllm_config,
self.device,
self.pin_memory,
self.is_pooling_model,
custom_logitsprocs,
)
# ThinkingTokenBudgetLogitsProcessor also needs output token ids to
# correctly track think start/end token sequences in async scheduling.
logitsprocs_need_output_token_ids = (
logitsprocs.has_custom
or self.vllm_config.reasoning_config is not None
)
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
# We need to use the encoder length for encoder-decoder
Expand All @@ -625,19 +638,8 @@ def __init__(
block_sizes=[placeholder_block_size],
kernel_block_sizes=[placeholder_block_size],
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=build_logitsprocs(
self.vllm_config,
self.device,
self.pin_memory,
self.is_pooling_model,
custom_logitsprocs,
),
# We currently don't know whether a particular custom logits processor
# uses output token ids so we set this conservatively.
# ThinkingTokenBudgetLogitsProcessor also needs output token ids to
# correctly track think start/end token sequences in async scheduling.
logitsprocs_need_output_token_ids=bool(custom_logitsprocs)
or self.vllm_config.reasoning_config is not None,
logitsprocs=logitsprocs,
logitsprocs_need_output_token_ids=logitsprocs_need_output_token_ids,
is_pooling_model=self.is_pooling_model,
cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
)
Expand Down
Loading