Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 107 additions & 1 deletion tests/entrypoints/test_serve.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Unit tests for the Omni serve CLI helpers."""
"""Unit tests for the Omni serve CLI helpers and detect_explicit_cli_keys."""

from __future__ import annotations

Expand All @@ -13,6 +13,112 @@
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]


# ============================================================================
# detect_explicit_cli_keys — parser-aware mode
# ============================================================================


def test_detect_explicit_cli_keys_with_parser_basic() -> None:
    """With a parser supplied, typed flags are reported by their dest names.

    Three options are registered but only two appear on the command line;
    the option that was not typed must be absent from the result.
    """
    cli_parser = argparse.ArgumentParser()
    for flag, flag_type, flag_default in (
        ("--max-num-seqs", int, 256),
        ("--gpu-memory-utilization", float, 0.9),
        ("--dtype", str, "auto"),
    ):
        cli_parser.add_argument(flag, type=flag_type, default=flag_default)

    typed_keys = detect_explicit_cli_keys(
        ["--max-num-seqs", "64", "--dtype", "float16"],
        cli_parser,
    )

    assert typed_keys == {"max_num_seqs", "dtype"}
    assert "gpu_memory_utilization" not in typed_keys


def test_detect_explicit_cli_keys_with_parser_equals_syntax() -> None:
    """A flag written as ``--flag=value`` is detected in parser-aware mode."""
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("--max-num-seqs", type=int, default=256)
    argv = ["--max-num-seqs=64"]

    typed_keys = detect_explicit_cli_keys(argv, cli_parser)

    assert "max_num_seqs" in typed_keys


def test_detect_explicit_cli_keys_with_parser_alias() -> None:
    """An alias flag is reported under the option's dest, not the alias text.

    ``--usp`` and ``--ulysses-degree`` share ``dest="ulysses_degree"``; typing
    the short alias must yield the dest name only.
    """
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument(
        "--usp",
        "--ulysses-degree",
        type=int,
        default=1,
        dest="ulysses_degree",
    )
    argv = ["--usp", "4"]

    typed_keys = detect_explicit_cli_keys(argv, cli_parser)

    assert "ulysses_degree" in typed_keys
    assert "usp" not in typed_keys


def test_detect_explicit_cli_keys_with_parser_store_false() -> None:
    """A value-less ``--disable-X`` switch is reported under its declared dest.

    NOTE: despite the test name, the flag here is registered with
    ``action="store_true"``; the check is only that a boolean switch taking
    no value resolves to the dest given to ``add_argument``.
    """
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument(
        "--disable-log-requests",
        action="store_true",
        dest="disable_log_requests",
    )
    argv = ["--disable-log-requests"]

    typed_keys = detect_explicit_cli_keys(argv, cli_parser)

    assert "disable_log_requests" in typed_keys


def test_detect_explicit_cli_keys_empty_argv() -> None:
    """An empty argv yields an empty set, even when the parser has options."""
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("--max-num-seqs", type=int, default=256)

    typed_keys = detect_explicit_cli_keys([], cli_parser)

    assert typed_keys == set()


def test_detect_explicit_cli_keys_ignores_positional_args() -> None:
    """Tokens without a ``--`` prefix do not contribute any keys."""
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("--dtype", type=str, default="auto")
    # Two leading positionals followed by a single real option.
    argv = ["serve", "fake-model", "--dtype", "float16"]

    typed_keys = detect_explicit_cli_keys(argv, cli_parser)

    assert typed_keys == {"dtype"}


# ============================================================================
# detect_explicit_cli_keys — heuristic fallback (parser=None)
# ============================================================================


def test_detect_explicit_cli_keys_heuristic_basic() -> None:
    """Without a parser, flag names are reported with hyphens as underscores."""
    argv = ["--max-num-seqs", "64", "--dtype", "float16"]

    typed_keys = detect_explicit_cli_keys(argv, None)

    assert "max_num_seqs" in typed_keys
    assert "dtype" in typed_keys


def test_detect_explicit_cli_keys_heuristic_no_prefix() -> None:
    """Without a parser, ``--no-X`` is reported under both ``no_x`` and ``x``.

    The prefix-stripped form is there for BooleanOptionalAction compatibility.
    """
    argv = ["--no-async-chunk"]

    typed_keys = detect_explicit_cli_keys(argv, None)

    assert "no_async_chunk" in typed_keys
    # The form with the ``no_`` prefix removed must be present as well.
    assert "async_chunk" in typed_keys


def test_detect_explicit_cli_keys_heuristic_equals() -> None:
    """Without a parser, ``--flag=value`` tokens are still recognised."""
    argv = ["--gpu-memory-utilization=0.8"]

    typed_keys = detect_explicit_cli_keys(argv, None)

    assert "gpu_memory_utilization" in typed_keys


def test_detect_explicit_cli_keys_user_types_default_value() -> None:
    """A flag typed with its own default value still counts as explicit.

    This is the case a defaults-comparison approach cannot distinguish: the
    parsed value equals the default, yet the user did type the flag.
    """
    default_value = 256
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("--max-num-seqs", type=int, default=default_value)

    # The command line repeats the default verbatim.
    argv = ["--max-num-seqs", "256"]
    typed_keys = detect_explicit_cli_keys(argv, cli_parser)

    assert "max_num_seqs" in typed_keys


# ============================================================================
# serve parser integration
# ============================================================================


def test_serve_parser_accepts_no_async_chunk_and_marks_it_explicit() -> None:
"""``--no-async-chunk`` should parse to ``async_chunk=False`` and mark the
shared deploy-level dest as explicitly provided by the user."""
Expand Down
132 changes: 132 additions & 0 deletions tests/entrypoints/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,138 @@ def test_load_and_resolve_with_kwargs(self):
assert "dtype" in stage_configs[0]["engine_args"]


class TestExplicitCliKeysFiltering:
    """Tests for _explicit_cli_keys filtering in load_stage_configs_from_model."""

    @staticmethod
    def _capture_cli_overrides(mocker: MockerFixture, **load_kwargs):
        """Call ``load_stage_configs_from_model`` and return its ``cli_overrides``.

        ``StageConfigFactory.create_from_model`` and
        ``resolve_model_config_path`` are patched to return ``None`` so that
        no real model or config lookup happens; the only thing inspected is
        the ``cli_overrides`` keyword handed to the (mocked) factory.

        Args:
            mocker: The pytest-mock fixture of the calling test.
            **load_kwargs: Forwarded verbatim to
                ``load_stage_configs_from_model`` alongside
                ``model="fake-model"``.

        Returns:
            The object passed as ``cli_overrides`` to ``create_from_model``.
        """
        mock_factory = mocker.patch(
            "vllm_omni.entrypoints.utils.StageConfigFactory.create_from_model",
            return_value=None,
        )
        mocker.patch(
            "vllm_omni.entrypoints.utils.resolve_model_config_path",
            return_value=None,
        )

        from vllm_omni.entrypoints.utils import load_stage_configs_from_model

        load_stage_configs_from_model(model="fake-model", **load_kwargs)

        # ``call_args`` is None when the mock was never invoked; fail with a
        # readable message instead of an AttributeError on None.
        assert mock_factory.call_args is not None, (
            "StageConfigFactory.create_from_model was never called"
        )
        # ``call_args.kwargs`` and ``call_args[1]`` are the same mapping, so a
        # single lookup is sufficient.
        factory_kwargs = mock_factory.call_args.kwargs
        assert "cli_overrides" in factory_kwargs, (
            "create_from_model was not given a cli_overrides keyword argument"
        )
        return factory_kwargs["cli_overrides"]

    def test_explicit_cli_keys_filters_overrides(self, mocker: MockerFixture):
        """Only user-typed CLI flags should survive as cli_overrides;
        argparse defaults must be discarded."""
        base_engine_args = {
            "max_num_seqs": 64,  # user typed
            "dtype": "auto",  # argparse default (not typed)
            "enforce_eager": False,  # argparse default (not typed)
            "_explicit_cli_keys": {"max_num_seqs"},
        }

        cli_overrides = self._capture_cli_overrides(
            mocker,
            base_engine_args=base_engine_args,
        )

        # StageConfigFactory.create_from_model receives filtered overrides
        assert "max_num_seqs" in cli_overrides
        assert "dtype" not in cli_overrides
        assert "enforce_eager" not in cli_overrides
        # The bookkeeping key itself must not leak through as an override.
        assert "_explicit_cli_keys" not in cli_overrides

    def test_no_explicit_cli_keys_passes_all_overrides(self, mocker: MockerFixture):
        """When _explicit_cli_keys is absent (programmatic use), all args
        pass through as overrides for backward compatibility."""
        base_engine_args = {
            "max_num_seqs": 64,
            "dtype": "auto",
        }

        cli_overrides = self._capture_cli_overrides(
            mocker,
            base_engine_args=base_engine_args,
        )

        assert "max_num_seqs" in cli_overrides
        assert "dtype" in cli_overrides

    def test_explicit_cli_keys_with_stage_overrides(self, mocker: MockerFixture):
        """Per-stage overrides from --stage-overrides are always included,
        even when _explicit_cli_keys filters out global defaults."""
        base_engine_args = {
            "max_num_seqs": 64,
            "dtype": "auto",  # will be filtered out
            "_explicit_cli_keys": {"max_num_seqs"},
        }
        stage_overrides = {"0": {"gpu_memory_utilization": 0.8}}

        cli_overrides = self._capture_cli_overrides(
            mocker,
            base_engine_args=base_engine_args,
            stage_overrides=stage_overrides,
        )

        assert "max_num_seqs" in cli_overrides
        assert "dtype" not in cli_overrides
        # The per-stage entry is expected under a ``stage_<id>_<name>`` key.
        assert "stage_0_gpu_memory_utilization" in cli_overrides

    def test_explicit_cli_keys_preserves_dataclass_values(self, mocker: MockerFixture):
        """Dataclass-typed values (e.g. profiler_config dict) survive
        filtering when the key is in _explicit_cli_keys."""
        profiler_dict = {"profiler": "torch", "torch_profiler_dir": "./test"}
        base_engine_args = {
            "profiler_config": profiler_dict,
            "dtype": "auto",
            "_explicit_cli_keys": {"profiler_config"},
        }

        cli_overrides = self._capture_cli_overrides(
            mocker,
            base_engine_args=base_engine_args,
        )

        assert "profiler_config" in cli_overrides
        assert cli_overrides["profiler_config"]["profiler"] == "torch"
        assert "dtype" not in cli_overrides


class TestLoadStageConfigsFromYaml:
"""Regression tests for stage-config loading and merging."""

Expand Down
97 changes: 49 additions & 48 deletions tests/test_arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,72 +353,73 @@ class _AmbiguousEngine:
assert any("both OrchestratorArgs" in r.message for r in caplog.records)


# Sentinel-default precedence invariants (#3035)

# ============================================================================
# deploy_override_field_names — dynamic derivation (replaces whitelist).
# ============================================================================

def _build_full_serve_parser():
from vllm.utils.argparse_utils import FlexibleArgumentParser

def test_deploy_override_field_names_covers_all_engine_fields():
"""deploy_override_field_names must dynamically include every OmniEngineArgs
field (minus orchestrator-only keys) so that YAML values for any vLLM
parameter are never silently overwritten by argparse defaults."""
try:
from vllm.entrypoints.openai.cli_args import make_arg_parser
except ImportError:
pytest.skip("vllm parser not importable")
return make_arg_parser(FlexibleArgumentParser())

from vllm_omni.engine.arg_utils import OmniEngineArgs, deploy_override_field_names
except Exception as exc:
pytest.skip(f"OmniEngineArgs not importable: {exc}")

def test_nullify_stage_engine_defaults_resets_inherited_defaults():
import argparse
override_names = deploy_override_field_names()
engine_fields = {f.name for f in fields(OmniEngineArgs)}
orch_fields = orchestrator_field_names()

from vllm_omni.engine.arg_utils import (
deploy_override_field_names,
nullify_stage_engine_defaults,
# Every engine field that is not orchestrator-only should be in the set.
missing = (engine_fields - orch_fields) - override_names
assert not missing, (
f"Engine fields missing from deploy_override_field_names: {sorted(missing)}. "
f"These fields' YAML values could be silently overridden by argparse defaults."
)

parser = _build_full_serve_parser()
nullify_stage_engine_defaults(parser)

override_dests = deploy_override_field_names()
offenders = [
(a.dest, a.default)
for a in parser._actions
if a.dest not in ("help", "version")
and a.option_strings
and a.dest in override_dests
and a.default is not None
and a.default is not argparse.SUPPRESS
]
assert not offenders, f"Stage flags with non-None defaults after nullify: {offenders}"
def test_deploy_override_field_names_includes_profiler_config():
    """Regression guard: ``profiler_config`` is among the deploy-override names."""
    from vllm_omni.engine.arg_utils import deploy_override_field_names

    override_names = deploy_override_field_names()
    assert "profiler_config" in override_names

def test_non_override_flags_keep_real_defaults_after_nullify():
import argparse

from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
# ============================================================================
# OmniEngineArgs.__post_init__ — dict → ProfilerConfig conversion.
# ============================================================================

parser = argparse.ArgumentParser()
parser.add_argument("--hsdp-shard-size", type=int, default=-1, help="HSDP shard size.")
parser.add_argument("--max-num-seqs", type=int, default=64, help="Max num seqs.")
nullify_stage_engine_defaults(parser)

hsdp = next(a for a in parser._actions if a.dest == "hsdp_shard_size")
max_num_seqs = next(a for a in parser._actions if a.dest == "max_num_seqs")
assert hsdp.default == -1
assert max_num_seqs.default is None
def test_omniengineargs_converts_dict_profiler_config():
"""profiler_config from YAML arrives as a dict; __post_init__ must convert
it to a ProfilerConfig dataclass so downstream code works correctly."""
try:
from vllm.config import ProfilerConfig

from vllm_omni.engine.arg_utils import OmniEngineArgs
except Exception as exc:
pytest.skip(f"Not importable: {exc}")

def test_help_text_preserves_default_after_nullify():
# Real defaults must stay visible in --help even though parser stores None.
import argparse
ea = OmniEngineArgs(
model="fake-model",
profiler_config={"profiler": "torch", "torch_profiler_dir": "./test-dir"},
)
assert isinstance(ea.profiler_config, ProfilerConfig)
assert ea.profiler_config.profiler == "torch"
assert ea.profiler_config.torch_profiler_dir == "./test-dir"

from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults

parser = argparse.ArgumentParser()
parser.add_argument("--max-num-seqs", type=int, default=42, help="Example knob.")
nullify_stage_engine_defaults(parser)
def test_omniengineargs_preserves_none_profiler_config():
"""When profiler_config is not set (None), it should remain None."""
try:
from vllm_omni.engine.arg_utils import OmniEngineArgs
except Exception as exc:
pytest.skip(f"Not importable: {exc}")

action = next(a for a in parser._actions if a.dest == "max_num_seqs")
assert action.default is None
assert "(default: 42)" in action.help
ea = OmniEngineArgs(model="fake-model")
# Default from vLLM — should not be converted or raise
assert ea.profiler_config is None or hasattr(ea.profiler_config, "profiler")


_OMNIENGINEARGS_USER_INPUT_FIELDS = frozenset(
Expand Down
Loading
Loading