diff --git a/tests/entrypoints/test_serve.py b/tests/entrypoints/test_serve.py index e60afc9cd7b..745d0edee6e 100644 --- a/tests/entrypoints/test_serve.py +++ b/tests/entrypoints/test_serve.py @@ -1,4 +1,4 @@ -"""Unit tests for the Omni serve CLI helpers.""" +"""Unit tests for the Omni serve CLI helpers and detect_explicit_cli_keys.""" from __future__ import annotations @@ -13,6 +13,112 @@ pytestmark = [pytest.mark.core_model, pytest.mark.cpu] +# ============================================================================ +# detect_explicit_cli_keys — parser-aware mode +# ============================================================================ + + +def test_detect_explicit_cli_keys_with_parser_basic() -> None: + """Parser-aware mode returns correct dest names.""" + parser = argparse.ArgumentParser() + parser.add_argument("--max-num-seqs", type=int, default=256) + parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) + parser.add_argument("--dtype", type=str, default="auto") + + argv = ["--max-num-seqs", "64", "--dtype", "float16"] + explicit = detect_explicit_cli_keys(argv, parser) + + assert explicit == {"max_num_seqs", "dtype"} + assert "gpu_memory_utilization" not in explicit + + +def test_detect_explicit_cli_keys_with_parser_equals_syntax() -> None: + """Parser-aware mode handles --flag=value syntax.""" + parser = argparse.ArgumentParser() + parser.add_argument("--max-num-seqs", type=int, default=256) + + explicit = detect_explicit_cli_keys(["--max-num-seqs=64"], parser) + assert "max_num_seqs" in explicit + + +def test_detect_explicit_cli_keys_with_parser_alias() -> None: + """Parser-aware mode resolves alias flags to the canonical dest.""" + parser = argparse.ArgumentParser() + parser.add_argument("--usp", "--ulysses-degree", type=int, default=1, dest="ulysses_degree") + + explicit = detect_explicit_cli_keys(["--usp", "4"], parser) + assert "ulysses_degree" in explicit + assert "usp" not in explicit + + +def test_detect_explicit_cli_keys_with_parser_store_false() -> None: + """Parser-aware mode maps --disable-X to its actual dest.""" + parser = argparse.ArgumentParser() + parser.add_argument("--disable-log-requests", action="store_true", dest="disable_log_requests") + + explicit = detect_explicit_cli_keys(["--disable-log-requests"], parser) + assert "disable_log_requests" in explicit + + +def test_detect_explicit_cli_keys_empty_argv() -> None: + """No flags typed → empty set.""" + parser = argparse.ArgumentParser() + parser.add_argument("--max-num-seqs", type=int, default=256) + + assert detect_explicit_cli_keys([], parser) == set() + + +def test_detect_explicit_cli_keys_ignores_positional_args() -> None: + """Positional arguments (no -- prefix) are ignored.""" + parser = argparse.ArgumentParser() + parser.add_argument("--dtype", type=str, default="auto") + + explicit = detect_explicit_cli_keys(["serve", "fake-model", "--dtype", "float16"], parser) + assert explicit == {"dtype"} + + +# ============================================================================ +# detect_explicit_cli_keys — heuristic fallback (parser=None) +# ============================================================================ + + +def test_detect_explicit_cli_keys_heuristic_basic() -> None: + """Heuristic mode converts hyphens to underscores.""" + explicit = detect_explicit_cli_keys(["--max-num-seqs", "64", "--dtype", "float16"], None) + assert "max_num_seqs" in explicit + assert "dtype" in explicit + + +def test_detect_explicit_cli_keys_heuristic_no_prefix() -> None: + """Heuristic mode strips --no- prefix for BooleanOptionalAction compat.""" + explicit = detect_explicit_cli_keys(["--no-async-chunk"], None) + assert "no_async_chunk" in explicit + assert "async_chunk" in explicit # also adds the stripped form + + +def test_detect_explicit_cli_keys_heuristic_equals() -> None: + """Heuristic mode handles --flag=value syntax.""" + explicit = detect_explicit_cli_keys(["--gpu-memory-utilization=0.8"], None) + assert "gpu_memory_utilization" in explicit + + +def test_detect_explicit_cli_keys_user_types_default_value() -> None: + """When user explicitly types a value that equals the default, it must + still appear in explicit keys — this is the key advantage over the old + defaults-comparison approach.""" + parser = argparse.ArgumentParser() + parser.add_argument("--max-num-seqs", type=int, default=256) + + # User explicitly types the default value + explicit = detect_explicit_cli_keys(["--max-num-seqs", "256"], parser) + assert "max_num_seqs" in explicit + + +# ============================================================================ +# serve parser integration +# ============================================================================ + + def test_serve_parser_accepts_no_async_chunk_and_marks_it_explicit() -> None: """``--no-async-chunk`` should parse to ``async_chunk=False`` and mark the shared deploy-level dest as explicitly provided by the user.""" diff --git a/tests/entrypoints/test_utils.py b/tests/entrypoints/test_utils.py index 248629d51df..14a12442d7e 100644 --- a/tests/entrypoints/test_utils.py +++ b/tests/entrypoints/test_utils.py @@ -359,6 +359,138 @@ def test_load_and_resolve_with_kwargs(self): assert "dtype" in stage_configs[0]["engine_args"] +class TestExplicitCliKeysFiltering: + """Tests for _explicit_cli_keys filtering in load_stage_configs_from_model.""" + + def test_explicit_cli_keys_filters_overrides(self, mocker: MockerFixture): + """Only user-typed CLI flags should survive as cli_overrides; + argparse defaults must be discarded.""" + mock_factory = mocker.patch( + "vllm_omni.entrypoints.utils.StageConfigFactory.create_from_model", + return_value=None, + ) + mocker.patch( + "vllm_omni.entrypoints.utils.resolve_model_config_path", + return_value=None, + ) + + from vllm_omni.entrypoints.utils import load_stage_configs_from_model + + base_engine_args = { + "max_num_seqs": 64, # user typed + "dtype": "auto", # argparse default (not typed) + "enforce_eager": False, # argparse default (not typed) + "_explicit_cli_keys": {"max_num_seqs"}, + } + + load_stage_configs_from_model( + model="fake-model", + base_engine_args=base_engine_args, + ) + + # StageConfigFactory.create_from_model receives filtered overrides + call_kwargs = mock_factory.call_args + cli_overrides = call_kwargs.kwargs.get("cli_overrides") or call_kwargs[1].get("cli_overrides") + assert "max_num_seqs" in cli_overrides + assert "dtype" not in cli_overrides + assert "enforce_eager" not in cli_overrides + assert "_explicit_cli_keys" not in cli_overrides + + def test_no_explicit_cli_keys_passes_all_overrides(self, mocker: MockerFixture): + """When _explicit_cli_keys is absent (programmatic use), all args + pass through as overrides for backward compatibility.""" + mock_factory = mocker.patch( + "vllm_omni.entrypoints.utils.StageConfigFactory.create_from_model", + return_value=None, + ) + mocker.patch( + "vllm_omni.entrypoints.utils.resolve_model_config_path", + return_value=None, + ) + + from vllm_omni.entrypoints.utils import load_stage_configs_from_model + + base_engine_args = { + "max_num_seqs": 64, + "dtype": "auto", + } + + load_stage_configs_from_model( + model="fake-model", + base_engine_args=base_engine_args, + ) + + call_kwargs = mock_factory.call_args + cli_overrides = call_kwargs.kwargs.get("cli_overrides") or call_kwargs[1].get("cli_overrides") + assert "max_num_seqs" in cli_overrides + assert "dtype" in cli_overrides + + def test_explicit_cli_keys_with_stage_overrides(self, mocker: MockerFixture): + """Per-stage overrides from --stage-overrides are always included, + even when _explicit_cli_keys filters out global defaults.""" + mock_factory = mocker.patch( + "vllm_omni.entrypoints.utils.StageConfigFactory.create_from_model", + return_value=None, + ) + mocker.patch( + "vllm_omni.entrypoints.utils.resolve_model_config_path", + return_value=None, + ) + + from vllm_omni.entrypoints.utils import load_stage_configs_from_model + + base_engine_args = { + "max_num_seqs": 64, + "dtype": "auto", # will be filtered out + "_explicit_cli_keys": {"max_num_seqs"}, + } + stage_overrides = {"0": {"gpu_memory_utilization": 0.8}} + + load_stage_configs_from_model( + model="fake-model", + base_engine_args=base_engine_args, + stage_overrides=stage_overrides, + ) + + call_kwargs = mock_factory.call_args + cli_overrides = call_kwargs.kwargs.get("cli_overrides") or call_kwargs[1].get("cli_overrides") + assert "max_num_seqs" in cli_overrides + assert "dtype" not in cli_overrides + assert "stage_0_gpu_memory_utilization" in cli_overrides + + def test_explicit_cli_keys_preserves_dataclass_values(self, mocker: MockerFixture): + """Dataclass-typed values (e.g. profiler_config dict) survive + filtering when the key is in _explicit_cli_keys.""" + mock_factory = mocker.patch( + "vllm_omni.entrypoints.utils.StageConfigFactory.create_from_model", + return_value=None, + ) + mocker.patch( + "vllm_omni.entrypoints.utils.resolve_model_config_path", + return_value=None, + ) + + from vllm_omni.entrypoints.utils import load_stage_configs_from_model + + profiler_dict = {"profiler": "torch", "torch_profiler_dir": "./test"} + base_engine_args = { + "profiler_config": profiler_dict, + "dtype": "auto", + "_explicit_cli_keys": {"profiler_config"}, + } + + load_stage_configs_from_model( + model="fake-model", + base_engine_args=base_engine_args, + ) + + call_kwargs = mock_factory.call_args + cli_overrides = call_kwargs.kwargs.get("cli_overrides") or call_kwargs[1].get("cli_overrides") + assert "profiler_config" in cli_overrides + assert cli_overrides["profiler_config"]["profiler"] == "torch" + assert "dtype" not in cli_overrides + + class TestLoadStageConfigsFromYaml: """Regression tests for stage-config loading and merging.""" diff --git a/tests/test_arg_utils.py b/tests/test_arg_utils.py index ae640b2d861..163c0876569 100644 --- a/tests/test_arg_utils.py +++ b/tests/test_arg_utils.py @@ -353,72 +353,73 @@ class _AmbiguousEngine: assert any("both OrchestratorArgs" in r.message for r in caplog.records) -# Sentinel-default precedence invariants (#3035) - +# ============================================================================ +# deploy_override_field_names — dynamic derivation (replaces whitelist). +# ============================================================================ -def _build_full_serve_parser(): - from vllm.utils.argparse_utils import FlexibleArgumentParser +def test_deploy_override_field_names_covers_all_engine_fields(): + """deploy_override_field_names must dynamically include every OmniEngineArgs + field (minus orchestrator-only keys) so that YAML values for any vLLM + parameter are never silently overwritten by argparse defaults.""" try: - from vllm.entrypoints.openai.cli_args import make_arg_parser - except ImportError: - pytest.skip("vllm parser not importable") - return make_arg_parser(FlexibleArgumentParser()) - + from vllm_omni.engine.arg_utils import OmniEngineArgs, deploy_override_field_names + except Exception as exc: + pytest.skip(f"OmniEngineArgs not importable: {exc}") -def test_nullify_stage_engine_defaults_resets_inherited_defaults(): - import argparse + override_names = deploy_override_field_names() + engine_fields = {f.name for f in fields(OmniEngineArgs)} + orch_fields = orchestrator_field_names() - from vllm_omni.engine.arg_utils import ( - deploy_override_field_names, - nullify_stage_engine_defaults, + # Every engine field that is not orchestrator-only should be in the set. + missing = (engine_fields - orch_fields) - override_names + assert not missing, ( + f"Engine fields missing from deploy_override_field_names: {sorted(missing)}. " + f"These fields' YAML values could be silently overridden by argparse defaults." ) - parser = _build_full_serve_parser() - nullify_stage_engine_defaults(parser) - override_dests = deploy_override_field_names() - offenders = [ - (a.dest, a.default) - for a in parser._actions - if a.dest not in ("help", "version") - and a.option_strings - and a.dest in override_dests - and a.default is not None - and a.default is not argparse.SUPPRESS - ] - assert not offenders, f"Stage flags with non-None defaults after nullify: {offenders}" +def test_deploy_override_field_names_includes_profiler_config(): + """Regression: profiler_config must be covered by deploy overrides.""" + from vllm_omni.engine.arg_utils import deploy_override_field_names + assert "profiler_config" in deploy_override_field_names() -def test_non_override_flags_keep_real_defaults_after_nullify(): - import argparse - from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults +# ============================================================================ +# OmniEngineArgs.__post_init__ — dict → ProfilerConfig conversion. +# ============================================================================ - parser = argparse.ArgumentParser() - parser.add_argument("--hsdp-shard-size", type=int, default=-1, help="HSDP shard size.") - parser.add_argument("--max-num-seqs", type=int, default=64, help="Max num seqs.") - nullify_stage_engine_defaults(parser) - hsdp = next(a for a in parser._actions if a.dest == "hsdp_shard_size") - max_num_seqs = next(a for a in parser._actions if a.dest == "max_num_seqs") - assert hsdp.default == -1 - assert max_num_seqs.default is None +def test_omniengineargs_converts_dict_profiler_config(): + """profiler_config from YAML arrives as a dict; __post_init__ must convert + it to a ProfilerConfig dataclass so downstream code works correctly.""" + try: + from vllm.config import ProfilerConfig + from vllm_omni.engine.arg_utils import OmniEngineArgs + except Exception as exc: + pytest.skip(f"Not importable: {exc}") -def test_help_text_preserves_default_after_nullify(): - # Real defaults must stay visible in --help even though parser stores None. - import argparse + ea = OmniEngineArgs( + model="fake-model", + profiler_config={"profiler": "torch", "torch_profiler_dir": "./test-dir"}, + ) + assert isinstance(ea.profiler_config, ProfilerConfig) + assert ea.profiler_config.profiler == "torch" + assert ea.profiler_config.torch_profiler_dir == "./test-dir" - from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults - parser = argparse.ArgumentParser() - parser.add_argument("--max-num-seqs", type=int, default=42, help="Example knob.") - nullify_stage_engine_defaults(parser) +def test_omniengineargs_preserves_none_profiler_config(): + """When profiler_config is not set (None), it should remain None.""" + try: + from vllm_omni.engine.arg_utils import OmniEngineArgs + except Exception as exc: + pytest.skip(f"Not importable: {exc}") - action = next(a for a in parser._actions if a.dest == "max_num_seqs") - assert action.default is None - assert "(default: 42)" in action.help + ea = OmniEngineArgs(model="fake-model") + # Default from vLLM — should not be converted or raise + assert ea.profiler_config is None or hasattr(ea.profiler_config, "profiler") _OMNIENGINEARGS_USER_INPUT_FIELDS = frozenset( diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py index 3f16c329e27..bd17ad095fe 100644 --- a/vllm_omni/engine/arg_utils.py +++ b/vllm_omni/engine/arg_utils.py @@ -174,6 +174,12 @@ def __post_init__(self) -> None: self.worker_cls = current_omni_platform.get_omni_ar_worker_cls() elif self.worker_type == "generation": self.worker_cls = current_omni_platform.get_omni_generation_worker_cls() + # Convert dict profiler_config (from stage YAML) to ProfilerConfig, + # mirroring what EngineArgs.__post_init__ does for compilation_config etc. + if isinstance(self.profiler_config, dict): + from vllm.config import ProfilerConfig + + self.profiler_config = ProfilerConfig(**self.profiler_config) load_omni_general_plugins() super().__post_init__() @@ -456,33 +462,6 @@ class OrchestratorArgs: } ) -_DEPLOY_ENGINE_ARG_OVERRIDE_FIELDS: frozenset[str] = frozenset( - { - # Capacity / scheduling. - "async_scheduling", - "max_model_len", - "max_num_batched_tokens", - "max_num_seqs", - # Memory / parallelism. - "data_parallel_size", - "gpu_memory_utilization", - "pipeline_parallel_size", - "tensor_parallel_size", - # Execution / loading. - "enforce_eager", - "distributed_executor_backend", - "dtype", - "quantization", - "trust_remote_code", - # Caching / chunking. - "async_chunk", - "enable_prefix_caching", - "enable_chunked_prefill", - # Model-specific engine extras. - "subtalker_sampling_params", - } -) - _DEPLOY_RUNTIME_OVERRIDE_FIELDS: frozenset[str] = frozenset( { "devices", @@ -496,8 +475,16 @@ def orchestrator_field_names() -> frozenset[str]: def deploy_override_field_names() -> frozenset[str]: - """Return kwargs whose parser defaults must not override deploy YAML.""" - return _DEPLOY_ENGINE_ARG_OVERRIDE_FIELDS | _DEPLOY_RUNTIME_OVERRIDE_FIELDS + """Return kwargs whose parser defaults must not override deploy YAML. + + Dynamically computed from all ``OmniEngineArgs`` fields so that newly + added engine args (e.g. ``profiler_config``) are automatically covered + without maintaining a manual whitelist. + """ + engine_fields = frozenset(f.name for f in fields(OmniEngineArgs)) + # Orchestrator-only keys are handled separately; shared keys need to + # flow to both orchestrator and engine, so exclude them here. + return (engine_fields - orchestrator_field_names()) | _DEPLOY_RUNTIME_OVERRIDE_FIELDS def internal_blacklist_keys() -> frozenset[str]: @@ -647,23 +634,3 @@ def orchestrator_args_from_argparse(args: Any) -> OrchestratorArgs: if value is not None or f.default is None: kwargs[f.name] = value return OrchestratorArgs(**kwargs) - - -def nullify_stage_engine_defaults(parser: argparse.ArgumentParser) -> None: - """Reset stage-level engine flag defaults to ``None``; preserve real - default in help text. Only deploy-YAML override fields are touched. - Idempotent.""" - override_dests = deploy_override_field_names() - - for action in parser._actions: - if action.dest in ("help", "version") or not action.option_strings: - continue - if action.dest not in override_dests: - continue - if action.default is None or action.default is argparse.SUPPRESS: - continue - if action.help and "(default:" not in action.help and "%(default)" not in action.help: - action.help = f"{action.help} (default: {action.default})" - action.default = None - - parser._omni_nullified = True # type: ignore[attr-defined] diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index a37afd24b4f..e5a0942dc88 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -1399,7 +1399,8 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st stage_configs_path = kwargs.get("stage_configs_path", None) deploy_config_path = kwargs.pop("deploy_config", None) stage_overrides_json = kwargs.pop("stage_overrides", None) - kwargs.pop("_cli_explicit_keys", None) + # Keep _explicit_cli_keys in kwargs so it flows to + # load_stage_configs_from_model where it filters cli_overrides. explicit_stage_configs = kwargs.pop("stage_configs", None) if explicit_stage_configs is not None: logger.warning( diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index 9f2bef26776..b77f66246ed 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -19,7 +19,6 @@ from vllm.logger import init_logger from vllm.utils.argparse_utils import FlexibleArgumentParser -from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults from vllm_omni.entrypoints.cli.logo import log_logo from vllm_omni.entrypoints.openai.api_server import omni_run_server @@ -485,8 +484,6 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu # Stash via type(self) so the docs hook (which execs this function in a # sandboxed globals dict via ``DummySelf``) doesn't fail on a NameError. type(self)._parser = serve_parser - - nullify_stage_engine_defaults(serve_parser) return serve_parser @@ -542,7 +539,13 @@ def run_headless(args: argparse.Namespace) -> None: raise ValueError("headless mode requires worker_backend=multi_process") args_dict = vars(args).copy() - args_dict.pop("_cli_explicit_keys", None) + # Propagate explicit CLI keys so load_stage_configs_from_model filters + # out argparse defaults, matching the online serving path. + if "_explicit_cli_keys" not in args_dict: + from vllm_omni.entrypoints.utils import detect_explicit_cli_keys + args_dict["_explicit_cli_keys"] = detect_explicit_cli_keys( + sys.argv[1:], None + ) config_path, stage_configs = load_and_resolve_stage_configs( model, args_dict.get("stage_configs_path"), diff --git a/vllm_omni/entrypoints/omni_base.py b/vllm_omni/entrypoints/omni_base.py index 75d3ee2901e..4918fffe9bc 100644 --- a/vllm_omni/entrypoints/omni_base.py +++ b/vllm_omni/entrypoints/omni_base.py @@ -95,22 +95,15 @@ def from_cli_args( parser: argparse.ArgumentParser | None = None, **overrides: Any, ) -> OmniBase: - """Build from argparse. If ``parser`` is passed and not yet nullified, - un-typed engine fields are reset to ``None``.""" - kwargs: dict[str, Any] = {k: v for k, v in vars(args).items() if not k.startswith("_")} - - if parser is not None and not getattr(parser, "_omni_nullified", False): - from vllm_omni.engine.arg_utils import ( - deploy_override_field_names, - ) - from vllm_omni.entrypoints.utils import detect_explicit_cli_keys - - explicit = detect_explicit_cli_keys(sys.argv[1:], parser) or set() - override_dests = deploy_override_field_names() - for key in list(kwargs): - if key in override_dests and key not in explicit: - kwargs[key] = None + """Build from argparse, injecting ``_explicit_cli_keys`` so that + downstream stage-config merging only treats user-typed flags as + overrides.""" + from vllm_omni.entrypoints.utils import detect_explicit_cli_keys + kwargs: dict[str, Any] = {k: v for k, v in vars(args).items() if not k.startswith("_")} + # Inject explicit-keys set; load_stage_configs_from_model will use + # it to filter cli_overrides, keeping only user-typed CLI flags. + kwargs["_explicit_cli_keys"] = detect_explicit_cli_keys(sys.argv[1:], parser) kwargs.update(overrides) return cls(**kwargs) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 646bbd6f913..245bcdbaff9 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -11,6 +11,7 @@ # Image generation API imports import random +import sys import time from argparse import Namespace from collections.abc import AsyncIterator @@ -335,6 +336,14 @@ async def omni_run_server(args, **uvicorn_kwargs) -> None: Unified entry point that automatically handles both LLM and Diffusion models through AsyncOmni, which manages multi-stage pipelines. """ + # Compute explicit CLI keys once at the unified entry point so that + # downstream stage-config merging only treats user-typed flags as + # overrides (argparse defaults never shadow deploy YAML values). + if not hasattr(args, "_explicit_cli_keys"): + from vllm_omni.entrypoints.utils import detect_explicit_cli_keys + + args._explicit_cli_keys = detect_explicit_cli_keys(sys.argv[1:], None) + # Suppress Pydantic serialization warnings globally for multimodal content # (e.g., when ChatMessage.content is a list instead of str) import warnings as warnings_module @@ -520,6 +529,10 @@ async def build_async_omni_from_stage_config( try: kwargs = vars(args).copy() kwargs.pop("model", None) + # Propagate the set of user-typed CLI flags (computed in + # OmniServeCommand.cmd) so that downstream stage-config merging + # can distinguish explicit overrides from argparse defaults. + kwargs.setdefault("_explicit_cli_keys", None) async_omni = AsyncOmni(model=args.model, **kwargs) # # Don't keep the dummy data in memory @@ -2881,8 +2894,6 @@ async def omni_wakeup(request: OmniWakeupRequest, raw_request: Request): from vllm.entrypoints.openai.cli_args import make_arg_parser - from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults - parser = argparse.ArgumentParser(description="vLLM-Omni OpenAI-Compatible REST API server") parser = make_arg_parser(parser) registered_flags = set() @@ -2894,8 +2905,8 @@ async def omni_wakeup(request: OmniWakeupRequest, raw_request: Request): parser.add_argument( "--enable-sleep-mode", action="store_true", default=False, help="Enable GPU memory pool for sleep mode." ) - nullify_stage_engine_defaults(parser) args = parser.parse_args() + # _explicit_cli_keys will be computed in omni_run_server if not set if not hasattr(args, "model_tag"): setattr(args, "model_tag", args.model) if hasattr(args, "model_tag") and args.model_tag is None: diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py index 5541fde30e7..fcfeff4a56c 100644 --- a/vllm_omni/entrypoints/utils.py +++ b/vllm_omni/entrypoints/utils.py @@ -383,7 +383,13 @@ def load_stage_configs_from_model( if base_engine_args is None: base_engine_args = {} + # Pop the explicit-keys set injected by the CLI entry point. + # When present, only user-typed CLI flags become overrides; + # argparse defaults are discarded so they never shadow YAML values. + explicit_cli_keys: set[str] | None = base_engine_args.pop("_explicit_cli_keys", None) cli_overrides = _convert_dataclasses_to_dict(dict(base_engine_args)) + if explicit_cli_keys is not None: + cli_overrides = {k: v for k, v in cli_overrides.items() if k in explicit_cli_keys} if stage_overrides: for stage_id_str, overrides in stage_overrides.items(): for key, val in overrides.items():