diff --git a/tests/entrypoints/test_cli_main.py b/tests/entrypoints/test_cli_main.py new file mode 100644 index 000000000000..2223a8c32733 --- /dev/null +++ b/tests/entrypoints/test_cli_main.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the CLI main entrypoint to ensure help doesn't trigger +platform detection. +""" + +import sys +from unittest.mock import patch + +import pytest + + +@pytest.mark.parametrize( + "argv", + [ + ["vllm", "--help"], + ["vllm", "serve", "--help"], + ["vllm", "-h"], + ["vllm", "bench", "--help"], + ["vllm", "serve", "--help=ModelConfig"], + ], +) +def test_needs_help_detects_help_flags(argv): + """Test that needs_help() correctly detects help flags in sys.argv.""" + from vllm.engine.arg_utils import needs_help + + # patch.object on sys.argv is safe — it's a simple list attribute + # with no lazy-init or side-effect machinery. + with patch.object(sys, "argv", argv): + assert needs_help(), f"needs_help() should return True for {argv}" + + +@pytest.mark.parametrize( + "argv", + [ + ["vllm", "serve", "--model", "test"], + ["vllm", "bench", "latency", "--model", "test"], + ["vllm", "collect-env"], + ], +) +def test_needs_help_returns_false_without_help_flags(argv): + """Test that needs_help() returns False when no help flag is present.""" + from vllm.engine.arg_utils import needs_help + + with patch.object(sys, "argv", argv): + assert not needs_help(), f"needs_help() should return False for {argv}" + + +def test_bench_help_skips_platform_detection(): + """Test that the bench guard in main() is skipped when --help is present. + + The guard in main.py is: + if len(sys.argv) > 1 and sys.argv[1] == "bench" and not needs_help() + When needs_help() returns True, current_platform is never accessed for + the bench CPU-override, avoiding unnecessary platform detection. 
+ """ + from vllm.engine.arg_utils import needs_help + + # Verify the guard: when needs_help() is True, "not needs_help()" is False, + # so the bench platform-override block is skipped. + with patch.object(sys, "argv", ["vllm", "bench", "--help"]): + assert needs_help(), "needs_help() should be True for bench --help" + + # Without --help the guard would be entered + with patch.object(sys, "argv", ["vllm", "bench", "latency"]): + assert not needs_help(), "needs_help() should be False without --help" diff --git a/tests/entrypoints/test_utils_lazy_import.py b/tests/entrypoints/test_utils_lazy_import.py new file mode 100644 index 000000000000..65d21b9365ae --- /dev/null +++ b/tests/entrypoints/test_utils_lazy_import.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Test that get_max_tokens lazily imports current_platform +rather than relying on a module-level import.""" + +import subprocess +import sys + + +def test_get_max_tokens_lazy_platform_import(): + """get_max_tokens should import current_platform inside its own body.""" + result = subprocess.run( + [ + sys.executable, + "-c", + ( + "import ast, inspect, textwrap; " + "from vllm.entrypoints import utils; " + "src = inspect.getsource(utils.get_max_tokens); " + "tree = ast.parse(textwrap.dedent(src)); " + "imports = [n for n in ast.walk(tree) " + " if isinstance(n, ast.ImportFrom) " + " and n.module == 'vllm.platforms']; " + "assert imports, 'get_max_tokens should have a local platform import'" + ), + ], + capture_output=True, + text=True, + ) + assert result.returncode == 0, ( + f"Subprocess failed (rc={result.returncode}):\n" + f"stdout: {result.stdout}\nstderr: {result.stderr}" + ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7028b12dab32..e85d427be5f8 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -241,11 +241,18 @@ def get_type_hints(type_hint: TypeHint) -> set[TypeHint]: 
return type_hints -NEEDS_HELP = ( - any("--help" in arg for arg in sys.argv) # vllm SUBCOMMAND --help - or (argv0 := sys.argv[0]).endswith("mkdocs") # mkdocs SUBCOMMAND - or argv0.endswith("mkdocs/__main__.py") # python -m mkdocs SUBCOMMAND -) +def needs_help() -> bool: + """Check if help is being requested via CLI flags or mkdocs.""" + return ( + any( + arg == "-h" or arg.startswith("--help") for arg in sys.argv + ) # vllm SUBCOMMAND --help/-h/--help=X + or (argv0 := sys.argv[0]).endswith("mkdocs") # mkdocs SUBCOMMAND + or argv0.endswith("mkdocs/__main__.py") # python -m mkdocs SUBCOMMAND + ) + + +NEEDS_HELP = needs_help() def _maybe_add_docs_url(cls: Any) -> str: @@ -2412,7 +2419,8 @@ def add_cli_args( "- DEBUG: Prompt inputs (e.g: text, token IDs).\n" "You can set the minimum log level via `VLLM_LOGGING_LEVEL`.", ) - current_platform.pre_register_and_update(parser) + if not NEEDS_HELP: + current_platform.pre_register_and_update(parser) return parser diff --git a/vllm/entrypoints/cli/main.py b/vllm/entrypoints/cli/main.py index 2261ef233134..83bbe2999468 100644 --- a/vllm/entrypoints/cli/main.py +++ b/vllm/entrypoints/cli/main.py @@ -8,6 +8,7 @@ import importlib.metadata import sys +from vllm.engine.arg_utils import needs_help from vllm.logger import init_logger logger = init_logger(__name__) @@ -34,8 +35,9 @@ def main(): cli_env_setup() - # For 'vllm bench *': use CPU instead of UnspecifiedPlatform by default - if len(sys.argv) > 1 and sys.argv[1] == "bench": + # For 'vllm bench *': use CPU instead of UnspecifiedPlatform by default. + # When showing help, skip this to avoid triggering platform detection. 
+ if len(sys.argv) > 1 and sys.argv[1] == "bench" and not needs_help(): logger.debug( "Bench command detected, must ensure current platform is not " "UnspecifiedPlatform to avoid device type inference error" diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index e3682280ec50..2d3fccb02678 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -25,7 +25,6 @@ ) from vllm.entrypoints.openai.models.protocol import LoRAModulePath from vllm.logger import current_formatter_type, init_logger -from vllm.platforms import current_platform from vllm.utils.argparse_utils import FlexibleArgumentParser logger = init_logger(__name__) @@ -178,6 +177,8 @@ def get_max_tokens( default_sampling_params: dict, override_max_tokens: int | None = None, ) -> int: + from vllm.platforms import current_platform + if max_model_len < input_length: raise ValueError( f"Input length ({input_length}) exceeds model's maximum "