Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions tests/entrypoints/test_cli_main.py
Comment thread
AbhiOnGithub marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the CLI main entrypoint to ensure help doesn't trigger
platform detection.
"""

import sys
from unittest.mock import patch

import pytest


@pytest.mark.parametrize(
    "argv",
    [
        ["vllm", "--help"],
        ["vllm", "serve", "--help"],
        ["vllm", "-h"],
        ["vllm", "bench", "--help"],
        ["vllm", "serve", "--help=ModelConfig"],
    ],
)
def test_needs_help_detects_help_flags(argv):
    """needs_help() must be truthy for each help spelling listed above.

    Covers ``--help``, ``-h`` and ``--help=<value>``, both directly after the
    program name and after a subcommand.
    """
    from vllm.engine.arg_utils import needs_help

    # sys.argv is swapped only for the duration of the ``with`` block;
    # patch.object puts the real value back on exit.
    with patch.object(sys, "argv", argv):
        help_requested = needs_help()
        assert help_requested, f"needs_help() should return True for {argv}"


@pytest.mark.parametrize(
    "argv",
    [
        ["vllm", "serve", "--model", "test"],
        ["vllm", "bench", "latency", "--model", "test"],
        ["vllm", "collect-env"],
    ],
)
def test_needs_help_returns_false_without_help_flags(argv):
    """needs_help() must be falsy for command lines without a help flag."""
    from vllm.engine.arg_utils import needs_help

    # Capture the result while sys.argv is patched, then assert afterwards.
    with patch.object(sys, "argv", argv):
        help_requested = needs_help()

    assert not help_requested, f"needs_help() should return False for {argv}"
Comment on lines +13 to +46
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These can be merged into

def test_needs_help(argv, expected):
    ...



def test_bench_help_skips_platform_detection():
    """Check the help predicate that gates the ``bench`` branch of main().

    The ``bench`` platform-override branch in main() is only meant to run
    when no help flag is present. This test does not call main(); it only
    exercises ``needs_help()`` for the two relevant command lines:

    * ``vllm bench --help``   -> truthy, so the branch would be skipped;
    * ``vllm bench latency``  -> falsy, so the branch would be entered.
    """
    from vllm.engine.arg_utils import needs_help

    bench_help_argv = ["vllm", "bench", "--help"]
    bench_run_argv = ["vllm", "bench", "latency"]

    # Help requested: the guard condition evaluates to False.
    with patch.object(sys, "argv", bench_help_argv):
        assert needs_help(), "needs_help() should be True for bench --help"

    # No help flag: the guard condition can evaluate to True.
    with patch.object(sys, "argv", bench_run_argv):
        assert not needs_help(), "needs_help() should be False without --help"
Comment on lines +49 to +66
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This duplicates what is already done in the previous 2 tests?

33 changes: 33 additions & 0 deletions tests/entrypoints/test_utils_lazy_import.py
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is strangely specific. Surely we just want to check that there's no vllm.platforms in the global scope, not that there is vllm.platforms in get_max_tokens?

Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test that get_max_tokens lazily imports current_platform
rather than relying on a module-level import."""

import subprocess
import sys


def test_get_max_tokens_lazy_platform_import():
    """get_max_tokens must contain its own ``from vllm.platforms import ...``.

    A fresh interpreter parses the source of ``utils.get_max_tokens`` and
    asserts that at least one ``ImportFrom`` node targeting ``vllm.platforms``
    appears inside the function body.
    """
    # One-liner executed via ``python -c``; statements are ';'-separated.
    probe_script = (
        "import ast, inspect, textwrap; "
        "from vllm.entrypoints import utils; "
        "src = inspect.getsource(utils.get_max_tokens); "
        "tree = ast.parse(textwrap.dedent(src)); "
        "imports = [n for n in ast.walk(tree) "
        " if isinstance(n, ast.ImportFrom) "
        " and n.module == 'vllm.platforms']; "
        "assert imports, 'get_max_tokens should have a local platform import'"
    )
    command = [sys.executable, "-c", probe_script]

    result = subprocess.run(command, capture_output=True, text=True)

    # Include both streams so a failure in the child is diagnosable.
    failure_report = (
        f"Subprocess failed (rc={result.returncode}):\n"
        f"stdout: {result.stdout}\nstderr: {result.stderr}"
    )
    assert result.returncode == 0, failure_report
20 changes: 14 additions & 6 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,18 @@ def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
return type_hints


# Module-level flag computed once at import time from sys.argv.
# It is truthy when any argument contains "--help", or when argv[0] ends
# with "mkdocs" or "mkdocs/__main__.py".
# The walrus binds argv0 in the second clause; the third clause is only
# reached when the second was evaluated, so argv0 is always bound there.
NEEDS_HELP = (
any("--help" in arg for arg in sys.argv) # vllm SUBCOMMAND --help
or (argv0 := sys.argv[0]).endswith("mkdocs") # mkdocs SUBCOMMAND
or argv0.endswith("mkdocs/__main__.py") # python -m mkdocs SUBCOMMAND
)
def needs_help() -> bool:
    """Report whether the current command line is asking for help output.

    Help is considered requested when any of the following holds:

    * some element of ``sys.argv`` is exactly ``-h`` or starts with
      ``--help`` (covers ``--help`` and ``--help=X``);
    * ``sys.argv[0]`` ends with ``mkdocs`` (``mkdocs SUBCOMMAND``);
    * ``sys.argv[0]`` ends with ``mkdocs/__main__.py``
      (``python -m mkdocs SUBCOMMAND``).

    ``sys.argv`` is read on every call, so the result reflects its current
    contents.

    Returns:
        ``True`` if help is being requested, ``False`` otherwise. An empty
        ``sys.argv`` yields ``False`` instead of raising ``IndexError``.
    """
    argv = sys.argv
    if not argv:
        # Nothing to inspect, and argv[0] below would raise IndexError.
        return False

    # vllm SUBCOMMAND --help / -h / --help=X
    if any(arg == "-h" or arg.startswith("--help") for arg in argv):
        return True

    # mkdocs SUBCOMMAND, or python -m mkdocs SUBCOMMAND
    return argv[0].endswith(("mkdocs", "mkdocs/__main__.py"))


# Snapshot taken once when this module is imported; later changes to
# sys.argv are not reflected here (call needs_help() for a fresh check).
NEEDS_HELP = needs_help()


def _maybe_add_docs_url(cls: Any) -> str:
Expand Down Expand Up @@ -2412,7 +2419,8 @@ def add_cli_args(
"- DEBUG: Prompt inputs (e.g: text, token IDs).\n"
"You can set the minimum log level via `VLLM_LOGGING_LEVEL`.",
)
current_platform.pre_register_and_update(parser)
if not NEEDS_HELP:
current_platform.pre_register_and_update(parser)
return parser


Expand Down
6 changes: 4 additions & 2 deletions vllm/entrypoints/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import importlib.metadata
import sys

from vllm.engine.arg_utils import needs_help
from vllm.logger import init_logger

logger = init_logger(__name__)
Expand All @@ -34,8 +35,9 @@ def main():

cli_env_setup()

# For 'vllm bench *': use CPU instead of UnspecifiedPlatform by default
if len(sys.argv) > 1 and sys.argv[1] == "bench":
# For 'vllm bench *': use CPU instead of UnspecifiedPlatform by default.
# When showing help, skip this to avoid triggering platform detection.
if len(sys.argv) > 1 and sys.argv[1] == "bench" and not needs_help():
logger.debug(
"Bench command detected, must ensure current platform is not "
"UnspecifiedPlatform to avoid device type inference error"
Expand Down
3 changes: 2 additions & 1 deletion vllm/entrypoints/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
)
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser

logger = init_logger(__name__)
Expand Down Expand Up @@ -178,6 +177,8 @@ def get_max_tokens(
default_sampling_params: dict,
override_max_tokens: int | None = None,
) -> int:
from vllm.platforms import current_platform

if max_model_len < input_length:
raise ValueError(
f"Input length ({input_length}) exceeds model's maximum "
Expand Down
Loading