Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions tests/engine/test_arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import argparse
import inspect
import logging
from types import SimpleNamespace
from unittest.mock import Mock

Expand Down Expand Up @@ -174,6 +175,24 @@ def fake_create_model_config(self):
}


def test_from_cli_args_raise_for_invalid_types():
"""Ensures that from_cli_args does type validation."""
ns = argparse.Namespace(model=34983589)
with pytest.raises(ValidationError):
OmniEngineArgs.from_cli_args(ns)


def test_invalid_cli_args_logs_unrecognized_kwargs(caplog):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing positive round-trip test — only the log-fires path is asserted. Add a case that constructs a valid Namespace, calls from_cli_args, and verifies values pass through TypeAdapter unchanged.

"""Ensures that from_cli_args logs unrecognized namespace keys."""
logger_module = "vllm_omni.engine.arg_utils"
logging.getLogger(logger_module).addHandler(caplog.handler)
# NOTE: for now we keep this at debug level, since depending on where we
# pull the kwargs dict from, this could be noisy, but it's still useful.
with caplog.at_level(logging.DEBUG, logger=logger_module):
OmniEngineArgs.from_cli_args(argparse.Namespace(garbage_arg="Foo"))
assert any("garbage_arg" in r.message for r in caplog.records)


def test_stage_specific_text_config_override():
"""Ensure dependent attributes are updated when using stage-specific config."""
vllm_config = EngineArgs().create_model_config()
Expand Down
32 changes: 26 additions & 6 deletions vllm_omni/engine/arg_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import argparse
import dataclasses
import json
import os
import tempfile
from dataclasses import dataclass, field, fields
from typing import Any
from typing import Any, get_type_hints

from pydantic import TypeAdapter
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.logger import init_logger

Expand Down Expand Up @@ -146,7 +146,7 @@ class OmniEngineArgs(EngineArgs):
quantization_config: Any | None = None
worker_type: str | None = None
task_type: str | None = None
worker_cls: str = None
worker_cls: str | None = None
enable_sleep_mode: bool = False
omni: bool = False

Expand Down Expand Up @@ -183,9 +183,29 @@ def __post_init__(self) -> None:

@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "OmniEngineArgs":
attrs = [attr.name for attr in dataclasses.fields(cls)]
engine_args = cls(**{attr: getattr(args, attr) for attr in attrs if hasattr(args, attr)})
return engine_args
"""Create an instance of this class from an argparse namespace."""
validated_dict = cls.get_validated_args_dict(vars(args))
return cls(**validated_dict)

@classmethod
def get_validated_args_dict(cls, engine_kwargs: dict) -> dict:
"""Given a proposed dict of engine kwargs, warn if we receive any keys
that are unknown, and validate the values of known keys, raising if we
get incorrect types."""
validated_kwargs = {}
field_type_map = get_type_hints(cls)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_type_hints(cls) evaluates string annotations at call time and walks the parent EngineArgs from upstream vllm. If any inherited field's annotation references a symbol not importable from this module's namespace, this raises NameError rather than the cleaner dataclasses.fields(cls) + per-field resolution. Has bitten upstream before across vllm version bumps.

skip_keys = []

for key, value in engine_kwargs.items():
if key not in field_type_map:
skip_keys.append(key)
else:
field_type = field_type_map[key]
if field_type is not Any:
value = TypeAdapter(field_type).validate_python(value, strict=True)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strict=True rejects argparse None for fields typed as plain int (no | None). Argparse defaults frequently produce None for unprovided optional args. Add a positive-path test that runs against a realistic full argparse namespace, not just Namespace(model=...).

validated_kwargs[key] = value
logger.debug("OmniEngineArgs filtered invalid keys: %s", skip_keys)
return validated_kwargs

def _ensure_omni_models_registered(self):
if hasattr(self, "_omni_models_registered"):
Expand Down
Loading