Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions tests/entrypoints/test_async_omni_diffusion_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from vllm.utils.argparse_utils import FlexibleArgumentParser

from vllm_omni.config.stage_config import deploy_override_field_names
from vllm_omni.engine.async_omni_engine import AsyncOmniEngine
from vllm_omni.entrypoints.cli.serve import OmniServeCommand, _create_default_diffusion_stage_cfg

Expand All @@ -30,6 +31,15 @@ def test_default_stage_config_includes_cache_backend():
assert engine_args["model_stage"] == "diffusion"


def test_default_stage_config_ignores_none_deploy_overrides():
"""Ensure nullified deploy override defaults do not alter diffusion defaults."""
baseline = AsyncOmniEngine._create_default_diffusion_stage_cfg({})[0]
nullified_overrides = {name: None for name in deploy_override_field_names()}
stage_cfg = AsyncOmniEngine._create_default_diffusion_stage_cfg(nullified_overrides)[0]

assert stage_cfg == baseline


def test_default_cache_config_used_when_missing():
"""Ensure default cache_config is synthesized when only backend is given."""
stage_cfg = AsyncOmniEngine._create_default_diffusion_stage_cfg(
Expand Down
1 change: 1 addition & 0 deletions tests/helpers/stage_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ def delete_by_path(config_dict: dict, path: str) -> None:
"max_num_seqs": 1,
"gpu_memory_utilization": 0.9,
"enforce_eager": True,
"enable_prefix_caching": False,
"max_num_batched_tokens": 16384,
"max_model_len": 16384,
"skip_mm_profiling": True,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,8 @@ def _build_full_serve_parser():
def test_nullify_stage_engine_defaults_resets_inherited_defaults():
import argparse

from vllm_omni.config.stage_config import deploy_override_field_names
from vllm_omni.engine.arg_utils import (
deploy_override_field_names,
nullify_stage_engine_defaults,
)

Expand Down
69 changes: 67 additions & 2 deletions tests/test_config_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def test_to_omegaconf_basic(self):
assert omega_config.engine_args.worker_type == "ar"
assert omega_config.final_output is True
assert omega_config.final_output_type == "text"
assert "max_num_seqs" not in omega_config.engine_args
# Legacy field name for backward compatibility
assert omega_config.engine_input_source == []

Expand Down Expand Up @@ -146,6 +147,24 @@ def test_to_omegaconf_max_num_seqs_in_engine_args(self):
omega_config = config.to_omegaconf()
assert omega_config.engine_args.max_num_seqs == 32

def test_to_omegaconf_omits_none_deploy_overrides_for_engine_args(self):
"""None deploy overrides must fall through to EngineArgs defaults."""
from vllm_omni.config.stage_config import deploy_override_field_names

config = StageConfig(
stage_id=0,
model_stage="thinker",
runtime_overrides={name: None for name in deploy_override_field_names()},
)

omega_config = config.to_omegaconf()
engine_args = dict(omega_config.engine_args)

assert "devices" not in engine_args
assert "max_batch_size" not in engine_args
for name in deploy_override_field_names() - {"devices"}:
assert name not in engine_args


class TestModelPipeline:
"""Tests for ModelPipeline class."""
Expand Down Expand Up @@ -802,6 +821,40 @@ def test_register_and_lookup(self):


class TestDeployConfigLoading:
def test_deploy_override_fields_include_deploy_schema_fields(self):
from vllm_omni.config.stage_config import deploy_override_field_names

expected_fields = {
"async_chunk",
"async_scheduling",
"config_format",
"data_parallel_size",
"devices",
"disable_hybrid_kv_cache_manager",
"distributed_executor_backend",
"dtype",
"enable_chunked_prefill",
"enable_flashinfer_autotune",
"enable_prefix_caching",
"enforce_eager",
"gpu_memory_utilization",
"load_format",
"max_model_len",
"max_num_batched_tokens",
"max_num_seqs",
"mm_processor_cache_gb",
"pipeline_parallel_size",
"profiler_config",
"quantization",
"skip_mm_profiling",
"subtalker_sampling_params",
"tensor_parallel_size",
"tokenizer_mode",
"trust_remote_code",
}

assert expected_fields == deploy_override_field_names()

def test_load_deploy_config(self):
from pathlib import Path

Expand All @@ -817,6 +870,17 @@ def test_load_deploy_config(self):
assert deploy.connectors is not None
assert deploy.platforms is not None

voxtral_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "voxtral_tts.yaml"
if voxtral_path.exists():
voxtral_deploy = load_deploy_config(voxtral_path)
assert voxtral_deploy.stages[0].config_format == "mistral"
assert voxtral_deploy.stages[0].load_format == "mistral"
assert voxtral_deploy.stages[0].tokenizer_mode == "mistral"
assert not any(
name in voxtral_deploy.stages[0].engine_extras
for name in ("config_format", "load_format", "tokenizer_mode")
)

def test_merge_pipeline_deploy(self):
from pathlib import Path

Expand Down Expand Up @@ -1011,7 +1075,8 @@ def test_ci_inherits_from_main(self):
deploy = load_deploy_config(ci_path)
assert len(deploy.stages) == 3
# CI overrides
assert deploy.stages[0].engine_extras.get("load_format") == "dummy"
assert deploy.stages[0].load_format == "dummy"
assert "load_format" not in deploy.stages[0].engine_extras
assert deploy.stages[0].max_num_seqs == 5
# Inherited from base
assert deploy.stages[0].gpu_memory_utilization == 0.9
Expand Down Expand Up @@ -1216,7 +1281,7 @@ def test_typed_kwarg_overrides_yaml(self):
def test_none_value_skipped_yaml_wins(self):
stages = self._stages({"max_num_seqs": None})
assert stages[2].runtime_overrides.get("max_num_seqs") is None
assert stages[2].yaml_engine_args.get("max_num_seqs") == 1
assert "max_num_seqs" not in stages[2].yaml_engine_args

def test_empty_kwargs_yaml_only(self):
stages = self._stages({})
Expand Down
44 changes: 32 additions & 12 deletions vllm_omni/config/stage_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,18 +402,26 @@ class StageDeployConfig:
"""

stage_id: int
max_num_seqs: int = 64
gpu_memory_utilization: float = 0.9
tensor_parallel_size: int = 1
enforce_eager: bool = False
max_num_batched_tokens: int = 32768
max_num_seqs: int | None = None
gpu_memory_utilization: float | None = None
tensor_parallel_size: int | None = None
enforce_eager: bool | None = None
max_num_batched_tokens: int | None = None
max_model_len: int | None = None
async_scheduling: bool | None = None
devices: str = "0"
output_connectors: dict[str, str] | None = None
input_connectors: dict[str, str] | None = None
default_sampling_params: dict[str, Any] | None = None
subtalker_sampling_params: dict[str, Any] | None = None
profiler_config: dict[str, Any] | None = None
disable_hybrid_kv_cache_manager: bool | None = None
mm_processor_cache_gb: float | None = None
skip_mm_profiling: bool | None = None
enable_flashinfer_autotune: bool | None = None
config_format: str | None = None
load_format: str | None = None
tokenizer_mode: str | None = None
engine_extras: dict[str, Any] = field(default_factory=dict)


Expand All @@ -438,14 +446,14 @@ class DeployConfig:
pipeline: str | None = None

# === Pipeline-wide engine settings (applied uniformly to every stage) ===
trust_remote_code: bool = True
trust_remote_code: bool | None = None
distributed_executor_backend: str | None = None
dtype: str | None = None
quantization: str | None = None
enable_prefix_caching: bool = False
enable_prefix_caching: bool | None = None
enable_chunked_prefill: bool | None = None
data_parallel_size: int = 1
pipeline_parallel_size: int = 1
data_parallel_size: int | None = None
pipeline_parallel_size: int | None = None


_STAGE_NON_ENGINE_KEYS = frozenset(
Expand Down Expand Up @@ -689,6 +697,18 @@ def _select_processor_funcs(
)


def deploy_override_field_names() -> frozenset[str]:
"""Return deploy-schema fields whose CLI defaults must not override YAML."""
return (
frozenset(_STAGE_DEPLOY_FIELDS)
| frozenset(_PIPELINE_WIDE_ENGINE_FIELDS)
| {
"async_chunk",
"devices",
}
)


def _build_engine_args(
ps: StagePipelineConfig,
ds: StageDeployConfig | None,
Expand Down Expand Up @@ -861,13 +881,15 @@ def to_omegaconf(self) -> Any:

# CLI overrides take precedence over YAML defaults
for key, value in self.runtime_overrides.items():
if value is None:
continue
if key not in ("devices", "max_batch_size"):
engine_args[key] = value

# Build runtime config from YAML defaults + CLI overrides
runtime: dict[str, Any] = dict(self.yaml_runtime)
runtime.setdefault("process", True)
if "devices" in self.runtime_overrides:
if self.runtime_overrides.get("devices") is not None:
runtime["devices"] = self.runtime_overrides["devices"]

# Legacy compat: migrate runtime.max_batch_size → engine_args.max_num_seqs
Expand All @@ -883,8 +905,6 @@ def to_omegaconf(self) -> Any:
effective_mbs = int(cli_mbs or legacy_mbs or 1)
engine_args.setdefault("max_num_seqs", effective_mbs)

engine_args.setdefault("max_num_seqs", 1)

# Build full config dict
config_dict: dict[str, Any] = {
"stage_id": self.stage_id,
Expand Down
4 changes: 2 additions & 2 deletions vllm_omni/deploy/cosyvoice3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ stages:
# near-identity repetition penalty forces vLLM to track
# output_token_ids for RAS (stop-token logit logsumexp).
repetition_penalty: 1.0001
disable_hybrid_kv_cache_manager: true
enable_prefix_caching: false
mm_processor_cache_gb: 0
skip_mm_profiling: true

Expand All @@ -54,5 +54,5 @@ stages:
from_stage_0: connector_of_shared_memory
default_sampling_params:
max_tokens: 2048
disable_hybrid_kv_cache_manager: true
enable_prefix_caching: false
skip_mm_profiling: true
1 change: 1 addition & 0 deletions vllm_omni/deploy/mimo_audio.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ stages:
max_num_seqs: 1
gpu_memory_utilization: 0.2
enforce_eager: true
enable_prefix_caching: false
async_scheduling: false
max_num_batched_tokens: 8192
max_model_len: 8192
Expand Down
2 changes: 2 additions & 0 deletions vllm_omni/deploy/qwen2_5_omni.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ stages:
max_num_seqs: 1
gpu_memory_utilization: 0.8
enforce_eager: true
enable_prefix_caching: false
mm_processor_cache_gb: 0
devices: "0"
default_sampling_params:
Expand Down Expand Up @@ -49,6 +50,7 @@ stages:
max_num_seqs: 1
gpu_memory_utilization: 0.15
enforce_eager: true
enable_prefix_caching: false
enable_flashinfer_autotune: false
async_scheduling: false
devices: "0"
Expand Down
2 changes: 2 additions & 0 deletions vllm_omni/deploy/qwen3_omni_moe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ connectors:
stages:
- stage_id: 0
gpu_memory_utilization: 0.9
enable_prefix_caching: false
devices: "0"
default_sampling_params:
temperature: 0.4
Expand All @@ -47,6 +48,7 @@ stages:
- stage_id: 2
gpu_memory_utilization: 0.1
enforce_eager: true
enable_prefix_caching: false
async_scheduling: false
max_num_batched_tokens: 51200
devices: "1"
Expand Down
1 change: 1 addition & 0 deletions vllm_omni/deploy/qwen3_tts.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ stages:
max_num_seqs: 1
gpu_memory_utilization: 0.3
enforce_eager: true
enable_prefix_caching: false
async_scheduling: true
# Must be divisible by num_code_groups and cover (left_context + chunk).
# Prefill length is Q * num_frames (e.g. 16 * 2148 = 34368); keep
Expand Down
1 change: 1 addition & 0 deletions vllm_omni/deploy/voxcpm2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ stages:
max_num_batched_tokens: 4096
max_model_len: 4096
devices: "0"
trust_remote_code: true
default_sampling_params:
temperature: 0.0
top_p: 1.0
Expand Down
2 changes: 2 additions & 0 deletions vllm_omni/deploy/voxtral_tts.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ stages:
max_num_seqs: 32
gpu_memory_utilization: 0.8
enforce_eager: false
enable_prefix_caching: false
async_scheduling: true
max_model_len: 4096
devices: "0"
Expand All @@ -48,6 +49,7 @@ stages:
max_num_seqs: 32
gpu_memory_utilization: 0.1
enforce_eager: true
enable_prefix_caching: false
async_scheduling: false
max_num_batched_tokens: 65536
max_model_len: 65536
Expand Down
40 changes: 2 additions & 38 deletions vllm_omni/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,50 +456,12 @@ class OrchestratorArgs:
}
)

_DEPLOY_ENGINE_ARG_OVERRIDE_FIELDS: frozenset[str] = frozenset(
{
# Capacity / scheduling.
"async_scheduling",
"max_model_len",
"max_num_batched_tokens",
"max_num_seqs",
# Memory / parallelism.
"data_parallel_size",
"gpu_memory_utilization",
"pipeline_parallel_size",
"tensor_parallel_size",
# Execution / loading.
"enforce_eager",
"distributed_executor_backend",
"dtype",
"quantization",
"trust_remote_code",
# Caching / chunking.
"async_chunk",
"enable_prefix_caching",
"enable_chunked_prefill",
# Model-specific engine extras.
"subtalker_sampling_params",
}
)

_DEPLOY_RUNTIME_OVERRIDE_FIELDS: frozenset[str] = frozenset(
{
"devices",
}
)


def orchestrator_field_names() -> frozenset[str]:
"""Return the names of every field on OrchestratorArgs."""
return frozenset(f.name for f in fields(OrchestratorArgs))


def deploy_override_field_names() -> frozenset[str]:
"""Return kwargs whose parser defaults must not override deploy YAML."""
return _DEPLOY_ENGINE_ARG_OVERRIDE_FIELDS | _DEPLOY_RUNTIME_OVERRIDE_FIELDS


def internal_blacklist_keys() -> frozenset[str]:
"""Return the set of CLI keys that must never be forwarded as per-stage
engine overrides.
Expand Down Expand Up @@ -653,6 +615,8 @@ def nullify_stage_engine_defaults(parser: argparse.ArgumentParser) -> None:
"""Reset stage-level engine flag defaults to ``None``; preserve real
default in help text. Only deploy-YAML override fields are touched.
Idempotent."""
from vllm_omni.config.stage_config import deploy_override_field_names

override_dests = deploy_override_field_names()

for action in parser._actions:
Expand Down
Loading
Loading