From 32e79ebe4e1071a7222c555c73231326f3df43b3 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Tue, 21 Apr 2026 16:05:03 +0000 Subject: [PATCH 1/3] validate engine args from cli Signed-off-by: Alex Brooks minor refactoring Signed-off-by: Alex Brooks --- tests/engine/test_arg_utils.py | 18 ++++++++++++++++++ vllm_omni/engine/arg_utils.py | 32 ++++++++++++++++++++++++++------ 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 0d61f6a675b..d5dc5ba0369 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -6,6 +6,7 @@ import argparse import inspect +import logging from types import SimpleNamespace from unittest.mock import Mock @@ -173,6 +174,23 @@ def fake_create_model_config(self): }, } +def test_from_cli_args_raise_for_invalid_types(): + """Ensures that from_cli_args does type validation.""" + ns = argparse.Namespace(model=34983589) + with pytest.raises(ValidationError): + OmniEngineArgs.from_cli_args(ns) + + +def test_invalid_cli_args_logs_unrecognized_kwargs(caplog): + """Ensures that from_cli_args logs unrecognized namespace keys.""" + logger_module = "vllm_omni.engine.arg_utils" + logging.getLogger(logger_module).addHandler(caplog.handler) + # NOTE: for now we keep this at debug level, since depending on where we + # pull the kwargs dict from, this could be noisy, but it's still useful. + with caplog.at_level(logging.DEBUG, logger=logger_module): + OmniEngineArgs.from_cli_args(argparse.Namespace(garbage_arg="Foo")) + assert any("garbage_arg" in r.message for r in caplog.records) + def test_stage_specific_text_config_override(): """Ensure dependent attributes are updated when using stage-specific config.""" diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py index 89139bf1b0b..ff308ba6130 100644 --- a/vllm_omni/engine/arg_utils.py +++ b/vllm_omni/engine/arg_utils.py @@ -1,11 +1,11 @@ import argparse -import dataclasses import json import os import tempfile from dataclasses import dataclass, field, fields -from typing import Any +from typing import Any, get_type_hints +from pydantic import TypeAdapter from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.logger import init_logger @@ -146,7 +146,7 @@ class OmniEngineArgs(EngineArgs): quantization_config: Any | None = None worker_type: str | None = None task_type: str | None = None - worker_cls: str = None + worker_cls: str | None = None enable_sleep_mode: bool = False omni: bool = False @@ -183,9 +183,29 @@ def __post_init__(self) -> None: @classmethod def from_cli_args(cls, args: argparse.Namespace) -> "OmniEngineArgs": - attrs = [attr.name for attr in dataclasses.fields(cls)] - engine_args = cls(**{attr: getattr(args, attr) for attr in attrs if hasattr(args, attr)}) - return engine_args + """Create an instance of this class from an argparse namespace.""" + validated_dict = cls.get_validated_args_dict(vars(args)) + return cls(**validated_dict) + + @classmethod + def get_validated_args_dict(cls, engine_kwargs: dict) -> dict: + """Given a proposed dict of engine kwargs, warn if we receive any keys + that are unknown, and validate the values of known keys, raising if we + get incorrect types.""" + validated_kwargs = {} + field_type_map = get_type_hints(cls) + skip_keys = [] + + for key, value in engine_kwargs.items(): + if key not in field_type_map: + skip_keys.append(key) + else: + field_type = field_type_map[key] + if field_type is not Any: + TypeAdapter(field_type).validate_python(value) + validated_kwargs[key] = value + logger.debug("OmniEngineArgs filtered invalid keys: %s", skip_keys) + return validated_kwargs def _ensure_omni_models_registered(self): if hasattr(self, "_omni_models_registered"): From e2ca7a959c9917736ab9ad3fd6d1fdf0b1ce1268 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Wed, 22 Apr 2026 03:59:14 +0000 Subject: [PATCH 2/3] use strict pydantic validator Signed-off-by: Alex Brooks --- vllm_omni/engine/arg_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py index ff308ba6130..07116e31982 100644 --- a/vllm_omni/engine/arg_utils.py +++ b/vllm_omni/engine/arg_utils.py @@ -202,7 +202,7 @@ def get_validated_args_dict(cls, engine_kwargs: dict) -> dict: else: field_type = field_type_map[key] if field_type is not Any: - TypeAdapter(field_type).validate_python(value) + value = TypeAdapter(field_type).validate_python(value, strict=True) validated_kwargs[key] = value logger.debug("OmniEngineArgs filtered invalid keys: %s", skip_keys) return validated_kwargs From 9bc6e8fd6310e731febe8e8b3b24b8ed6e167f24 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Wed, 22 Apr 2026 04:36:46 +0000 Subject: [PATCH 3/3] fmt Signed-off-by: Alex Brooks --- tests/engine/test_arg_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index d5dc5ba0369..bfc9d278cd2 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -174,6 +174,7 @@ def fake_create_model_config(self): }, } + def test_from_cli_args_raise_for_invalid_types(): """Ensures that from_cli_args does type validation.""" ns = argparse.Namespace(model=34983589)