35 changes: 15 additions & 20 deletions tests/v1/logits_processors/test_custom_offline.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pathlib import Path
from typing import Any

import pytest
@@ -15,9 +16,9 @@
CustomLogitprocSource,
DummyLogitsProcessor,
WrappedPerReqLogitsProcessor,
install_dummy_logitproc_entrypoint,
prompts,
)
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
from vllm import LLM, SamplingParams
from vllm.v1.sample.logits_processor import (
STR_POOLING_REJECTS_LOGITSPROCS,
@@ -102,7 +103,11 @@ def _run_test(kwargs: dict, logitproc_loaded: bool) -> None:

@create_new_process_for_each_test()
@pytest.mark.parametrize("logitproc_source", list(CustomLogitprocSource))
def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource):
def test_custom_logitsprocs(
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
logitproc_source: CustomLogitprocSource,
) -> None:
"""Test offline Python interface for passing custom logitsprocs

Construct an `LLM` instance which loads a custom logitproc that has a
@@ -145,13 +150,7 @@ def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource

if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT:
# Scenario: vLLM loads a logitproc from a preconfigured entrypoint
# To that end, mock a dummy logitproc entrypoint
import importlib.metadata

importlib.metadata.entry_points = fake_entry_points # type: ignore

# fork is required for workers to see entrypoint patch
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork")
install_dummy_logitproc_entrypoint(monkeypatch, tmp_path)
_run_test({}, logitproc_loaded=True)
return

@@ -167,7 +166,7 @@ def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource


@create_new_process_for_each_test()
def test_custom_logitsprocs_req(monkeypatch):
def test_custom_logitsprocs_req(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test passing request-level logits processor to offline Python interface

Wrap a request-level logits processor to create a batch level logits
@@ -210,8 +209,11 @@ def test_custom_logitsprocs_req(monkeypatch):
],
)
def test_rejects_custom_logitsprocs(
monkeypatch, model_scenario: str, logitproc_source: CustomLogitprocSource
):
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
model_scenario: str,
logitproc_source: CustomLogitprocSource,
) -> None:
"""Validate that vLLM engine initialization properly rejects custom
logitsprocs when the model is a pooling model or speculative decoding
is enabled.
@@ -266,14 +268,7 @@ def test_rejects_custom_logitsprocs(
# Scenario: vLLM loads a model and ignores a logitproc that is
# available at a preconfigured entrypoint

# Patch in dummy logitproc entrypoint
import importlib.metadata

importlib.metadata.entry_points = fake_entry_points # type: ignore

# fork is required for entrypoint patch to be visible to workers,
# although they should ignore the entrypoint patch anyway
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork")
install_dummy_logitproc_entrypoint(monkeypatch, tmp_path)

llm = LLM(**llm_kwargs)
# Require that no custom logitsprocs have been loaded
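A note on the pattern replaced above: assigning a fake to `importlib.metadata.entry_points` only patches the current interpreter, so workers saw the patch only when created via fork (hence the removed `VLLM_WORKER_MULTIPROC_METHOD=fork` pin). The new `install_dummy_logitproc_entrypoint` helper instead writes a real distribution to disk and exports its location through the environment, which survives any worker start method. A minimal sketch of the difference, using an illustrative path rather than the tests' actual `tmp_path`:

import os
import subprocess
import sys

# Environment variables are inherited by spawned children, unlike in-process
# patches such as a replaced importlib.metadata.entry_points, which vanish
# when a worker starts from a fresh interpreter.
env = dict(os.environ)
env["PYTHONPATH"] = os.pathsep.join(
    filter(None, ["/tmp/fake-dist-dir", env.get("PYTHONPATH")])
)
out = subprocess.run(
    [sys.executable, "-c", "import sys; print('/tmp/fake-dist-dir' in sys.path)"],
    env=env,
    capture_output=True,
    text=True,
    check=True,
)
print(out.stdout.strip())  # "True": the child rebuilt sys.path from PYTHONPATH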
29 changes: 16 additions & 13 deletions tests/v1/logits_processors/test_custom_online.py
@@ -4,6 +4,8 @@
import os
import random
import sys
from collections.abc import AsyncIterator, Iterator
from pathlib import Path
from typing import Any

import openai
@@ -17,9 +19,9 @@
MAX_TOKENS,
MODEL_NAME,
TEMP_GREEDY,
install_dummy_logitproc_entrypoint,
prompts,
)
from tests.v1.logits_processors.utils import entry_points as fake_entry_points


def _server_with_logitproc_entrypoint(
@@ -28,15 +30,8 @@ def _server_with_logitproc_entrypoint(
vllm_serve_args: list[str],
) -> None:
"""Start vLLM server, inject dummy logitproc entrypoint"""

# Patch `entry_points` to inject logitproc entrypoint
import importlib.metadata

importlib.metadata.entry_points = fake_entry_points # type: ignore
from vllm.entrypoints.cli import main

# fork is required for workers to see entrypoint patch
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "fork"
if env_dict is not None:
os.environ.update(env_dict)

@@ -62,7 +57,7 @@ def _server_with_logitproc_fqcn(


@pytest.fixture(scope="module")
def default_server_args():
def default_server_args() -> list[str]:
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
@@ -77,7 +72,12 @@ def default_server_args():
@pytest.fixture(
scope="function", params=[[], ["--logits-processors", DUMMY_LOGITPROC_FQCN]]
)
def server(default_server_args, request, monkeypatch):
def server(
default_server_args: list[str],
request: pytest.FixtureRequest,
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
) -> Iterator[RemoteOpenAIServerCustom]:
"""Consider two server configurations:
(1) --logits-processors cli arg specifies dummy logits processor via fully-
qualified class name (FQCN); patch in a dummy logits processor module
@@ -93,7 +93,8 @@ def server(default_server_args, request, monkeypatch):
args = default_server_args + request.param
_server_fxn = _server_with_logitproc_fqcn
else:
# Launch server, inject dummy logitproc entrypoint
# Launch server with a discoverable dummy logitproc entrypoint
install_dummy_logitproc_entrypoint(monkeypatch, tmp_path)
args = default_server_args
_server_fxn = _server_with_logitproc_entrypoint

@@ -102,7 +103,9 @@


@pytest_asyncio.fixture
async def client(server):
async def client(
server: RemoteOpenAIServerCustom,
) -> AsyncIterator[openai.AsyncOpenAI]:
async with server.get_async_client() as async_client:
yield async_client

@@ -124,7 +127,7 @@
"model_name",
[MODEL_NAME],
)
def test_custom_logitsprocs(server, model_name: str):
def test_custom_logitsprocs(server: RemoteOpenAIServerCustom, model_name: str) -> None:
"""Test custom logitsprocs when starting OpenAI server from CLI

Launch vLLM OpenAI-compatible server, configured to load a custom logitproc
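The parametrized `server` fixture above drives both wiring styles: one run passes the processor's fully-qualified class name on the command line, while the other passes no flag and relies on `install_dummy_logitproc_entrypoint` running before launch, so the server discovers the processor through entrypoint metadata. A hypothetical sketch of the two argument lists (the FQCN is a placeholder, not the tests' actual `DUMMY_LOGITPROC_FQCN`):

# Configuration (1): the processor is named explicitly on the CLI.
fqcn_args = [
    "--dtype", "half",  # half precision for CI speed, per the fixture comment
    "--logits-processors", "my_pkg.procs:MyLogitsProcessor",  # placeholder
]

# Configuration (2): no --logits-processors flag; the dummy entrypoint is
# installed first, so the server finds the processor via metadata discovery.
entrypoint_args = [
    "--dtype", "half",
]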
27 changes: 27 additions & 0 deletions tests/v1/logits_processors/utils.py
@@ -1,10 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import types
from enum import Enum, auto
from pathlib import Path
from typing import Any

import pytest
import torch

from vllm.config import VllmConfig
@@ -189,3 +192,27 @@ def new_req_logits_processor(

"""Fake version of importlib.metadata.entry_points"""
entry_points = lambda group: EntryPoints(group)


def install_dummy_logitproc_entrypoint(
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
"""Install a temporary entrypoint visible to spawned worker processes."""
dist_info = tmp_path / "dummy_vllm_logitproc-0.0.dist-info"
dist_info.mkdir()
(dist_info / "METADATA").write_text(
"Metadata-Version: 2.1\nName: dummy-vllm-logitproc\nVersion: 0.0\n",
encoding="utf-8",
)
(dist_info / "entry_points.txt").write_text(
f"[{LOGITSPROCS_GROUP}]\n"
f"{DUMMY_LOGITPROC_ENTRYPOINT} = {DUMMY_LOGITPROC_FQCN}\n",
encoding="utf-8",
)

monkeypatch.syspath_prepend(str(tmp_path))
pythonpath = os.environ.get("PYTHONPATH")
monkeypatch.setenv(
"PYTHONPATH",
os.pathsep.join(filter(None, [str(tmp_path), pythonpath])),
)
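
With placeholder values standing in for `LOGITSPROCS_GROUP`, `DUMMY_LOGITPROC_ENTRYPOINT`, and `DUMMY_LOGITPROC_FQCN` (the real constants are defined earlier in utils.py), the helper produces a layout that stock `importlib.metadata` discovery resolves as soon as the directory is on `sys.path`. A self-contained sketch of that round trip:

import sys
import tempfile
from importlib.metadata import entry_points
from pathlib import Path

root = Path(tempfile.mkdtemp())  # stands in for pytest's tmp_path

# Mirror the helper's layout: a *.dist-info directory holding METADATA plus
# an entry_points.txt in INI form. Group, name, and target are placeholders.
dist_info = root / "dummy_vllm_logitproc-0.0.dist-info"
dist_info.mkdir()
(dist_info / "METADATA").write_text(
    "Metadata-Version: 2.1\nName: dummy-vllm-logitproc\nVersion: 0.0\n",
    encoding="utf-8",
)
(dist_info / "entry_points.txt").write_text(
    "[vllm.logits_processors]\n"  # placeholder group
    "dummy = my_pkg.procs:DummyLogitsProcessor\n",  # placeholder entrypoint
    encoding="utf-8",
)

sys.path.insert(0, str(root))  # the helper uses monkeypatch.syspath_prepend
found = entry_points(group="vllm.logits_processors")
print([ep.value for ep in found])  # ['my_pkg.procs:DummyLogitsProcessor']

Because the helper also prepends the same directory to PYTHONPATH, a spawned worker repeating this lookup succeeds as well, which is what the old fork-only workaround was compensating for.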