-
Notifications
You must be signed in to change notification settings - Fork 1.1k
[Config Refactor] Validate Engine Args From CLI #3008
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,11 @@ | ||
| import argparse | ||
| import dataclasses | ||
| import json | ||
| import os | ||
| import tempfile | ||
| from dataclasses import dataclass, field, fields | ||
| from typing import Any | ||
| from typing import Any, get_type_hints | ||
|
|
||
| from pydantic import TypeAdapter | ||
| from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs | ||
| from vllm.logger import init_logger | ||
|
|
||
|
|
@@ -146,7 +146,7 @@ class OmniEngineArgs(EngineArgs): | |
| quantization_config: Any | None = None | ||
| worker_type: str | None = None | ||
| task_type: str | None = None | ||
| worker_cls: str = None | ||
| worker_cls: str | None = None | ||
| enable_sleep_mode: bool = False | ||
| omni: bool = False | ||
|
|
||
|
|
@@ -183,9 +183,29 @@ def __post_init__(self) -> None: | |
|
|
||
| @classmethod | ||
| def from_cli_args(cls, args: argparse.Namespace) -> "OmniEngineArgs": | ||
| attrs = [attr.name for attr in dataclasses.fields(cls)] | ||
| engine_args = cls(**{attr: getattr(args, attr) for attr in attrs if hasattr(args, attr)}) | ||
| return engine_args | ||
| """Create an instance of this class from an argparse namespace.""" | ||
| validated_dict = cls.get_validated_args_dict(vars(args)) | ||
| return cls(**validated_dict) | ||
|
|
||
| @classmethod | ||
| def get_validated_args_dict(cls, engine_kwargs: dict) -> dict: | ||
| """Given a proposed dict of engine kwargs, warn if we receive any keys | ||
| that are unknown, and validate the values of known keys, raising if we | ||
| get incorrect types.""" | ||
| validated_kwargs = {} | ||
| field_type_map = get_type_hints(cls) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| skip_keys = [] | ||
|
|
||
| for key, value in engine_kwargs.items(): | ||
| if key not in field_type_map: | ||
| skip_keys.append(key) | ||
| else: | ||
| field_type = field_type_map[key] | ||
| if field_type is not Any: | ||
| value = TypeAdapter(field_type).validate_python(value, strict=True) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| validated_kwargs[key] = value | ||
| logger.debug("OmniEngineArgs filtered invalid keys: %s", skip_keys) | ||
| return validated_kwargs | ||
|
|
||
| def _ensure_omni_models_registered(self): | ||
| if hasattr(self, "_omni_models_registered"): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing positive round-trip test — only the log-fires path is asserted. Add a case that constructs a valid
Namespace, callsfrom_cli_args, and verifies values pass throughTypeAdapterunchanged.