Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions vllm_omni/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
"""
Configuration module for vLLM-omni.
"""
from typing import Optional
from pydantic.dataclasses import dataclass
from pydantic import ConfigDict

from vllm.config import ModelConfig
from vllm.config import config

import vllm_omni.model_executor.models as me_models

from .stage_config import (
OmniStageConfig,
Expand All @@ -11,7 +19,28 @@
create_dit_stage_config,
)


@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
Comment on lines +23 to +24
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider adding a docstring to explain what the @config decorator does and why it's being used here. This will improve readability and understanding of the code.

class OmniModelConfig(ModelConfig):
"""Configuration for Omni models, extending the base ModelConfig."""

stage_id: int = 0
model_stage: str = "thinker"
model_arch: str = "Qwen2_5OmniForConditionalGeneration"
engine_output_type: Optional[str] = None

@property
def registry(self):
return me_models.OmniModelRegistry

@property
def architectures(self) -> list[str]:
return [self.model_arch]


__all__ = [
"OmniModelConfig",
"OmniStageConfig",
"DiTConfig",
"DiTCacheConfig",
Expand Down
51 changes: 51 additions & 0 deletions vllm_omni/engine/arg_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Optional
from dataclasses import dataclass
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

from vllm_omni.config import OmniModelConfig


@dataclass
class OmniEngineArgs(EngineArgs):
stage_id: int = 0
model_stage: str = "thinker"
model_arch: str = "Qwen2_5OmniForConditionalGeneration"
engine_output_type: Optional[str] = None

@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Shared CLI arguments for vLLM engine."""
parser.add_argument(
"--engine-output-type",
type=str,
default=EngineArgs.engine_output_type,
help=(
"Declare EngineCoreOutput.output_type (e.g., 'text', 'image', "
"'text+image', 'latent'). This will be written into "
"model_config.engine_output_type for schedulers to use."
),
)
parser.add_argument("--model-stage", type=str, default=OmniEngineArgs.model_stage,
help="Declare model stage (e.g., 'thinker', 'talker', 'token2wav'). This will be written into model_config.model_stage for schedulers to use.")
return parser

def create_model_config(self) -> OmniModelConfig:
# First, get the base ModelConfig from the parent class
base_config = super().create_model_config()

# Create OmniModelConfig by copying all base config attributes
# and adding the new omni-specific fields
config_dict = base_config.__dict__.copy()

# Add the new omni-specific fields
config_dict['stage_id'] = self.stage_id
config_dict['model_stage'] = self.model_stage
config_dict['model_arch'] = self.model_arch
config_dict['engine_output_type'] = self.engine_output_type

# Create and return the OmniModelConfig instance
omni_config = OmniModelConfig(**config_dict)
omni_config.hf_config.architectures = omni_config.architectures

return omni_config
Loading