Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions examples/offline_inference/bagel/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,30 +153,20 @@ def main():

from vllm_omni.entrypoints.omni import Omni

omni_kwargs = {}
omni_kwargs = vars(args).copy()
deploy_config = args.deploy_config
if args.think and deploy_config is None:
deploy_config = "vllm_omni/deploy/bagel_think.yaml"
print(f"[Info] Think mode enabled, using deploy config: {deploy_config}")
if deploy_config:
omni_kwargs["deploy_config"] = deploy_config

omni_kwargs.update(
{
"log_stats": args.log_stats,
"init_sleep_seconds": args.init_sleep_seconds,
"batch_timeout": args.batch_timeout,
"init_timeout": args.init_timeout,
"shm_threshold_bytes": args.shm_threshold_bytes,
"worker_backend": args.worker_backend,
"ray_address": args.ray_address,
"enable_diffusion_pipeline_profiler": args.enable_diffusion_pipeline_profiler,
}
)
if args.quantization:
omni_kwargs["quantization_config"] = args.quantization

omni = Omni.from_cli_args(args, model=model_name, **omni_kwargs)
# Override CLI --model with the derived model_name.
omni_kwargs["model"] = model_name
omni = Omni(**omni_kwargs)

formatted_prompts = []
for p in prompts:
Expand Down
3 changes: 2 additions & 1 deletion examples/offline_inference/dynin_omni/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import torch
from PIL import Image

from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults

TASK_CHOICES = ("t2t", "t2i", "t2s", "i2i", "i2t", "s2t", "v2t")

TASK_DEFAULT_RUNTIME = {
Expand Down Expand Up @@ -970,7 +972,6 @@ def parse_args(repo_root: Path) -> argparse.Namespace:
parser.add_argument("--vq-model-audio-local-files-only", action=argparse.BooleanOptionalAction, default=None)

parser.add_argument("--disable-hf-xet", action=argparse.BooleanOptionalAction, default=True)
from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults

nullify_stage_engine_defaults(parser)
return parser.parse_args()
Expand Down
2 changes: 2 additions & 0 deletions examples/offline_inference/mimo_audio/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from vllm import SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser

from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniTokensPrompt

Expand Down Expand Up @@ -438,6 +439,7 @@ def parse_args():
"vllm_omni/deploy/mimo_audio.yaml based on the HF model_type.",
)

nullify_stage_engine_defaults(parser)
return parser.parse_args()


Expand Down
5 changes: 4 additions & 1 deletion examples/offline_inference/ming_flash_omni/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,10 @@ def main(args):
else:
query_result = query_func(processor)

omni = Omni.from_cli_args(args, model=MODEL_NAME)
omni_kwargs = vars(args).copy()
# override CLI --model with derived model_name
omni_kwargs["model"] = MODEL_NAME
omni = Omni(**omni_kwargs)

# Thinker sampling params
thinker_sampling_params = SamplingParams(
Expand Down
7 changes: 6 additions & 1 deletion examples/offline_inference/qwen2_5_omni/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from vllm.sampling_params import SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser

from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
from vllm_omni.entrypoints.omni import Omni

SEED = 42
Expand Down Expand Up @@ -325,7 +326,10 @@ def main(args):
else:
query_result = query_func()
args.quantization_config = quantization_config
omni = Omni.from_cli_args(args, model=model_name)
omni_kwargs = vars(args).copy()
# Override CLI --model with the derived model_name.
omni_kwargs["model"] = model_name
Comment thread
xiaohajiayou marked this conversation as resolved.
omni = Omni(**omni_kwargs)
thinker_sampling_params = SamplingParams(
temperature=0.0, # Deterministic - no randomness
top_p=1.0, # Disable nucleus sampling
Expand Down Expand Up @@ -550,6 +554,7 @@ def parse_args():
default=False,
help="Use py_generator mode. The returned type of Omni.generate() is a Python Generator object.",
)
nullify_stage_engine_defaults(parser)
return parser.parse_args()


Expand Down
5 changes: 4 additions & 1 deletion examples/offline_inference/qwen3_omni/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,10 @@ def main(args):
else:
query_result = query_func()

omni = Omni.from_cli_args(args, model=model_name)
omni_kwargs = vars(args).copy()
# Override CLI --model with the derived model_name.
omni_kwargs["model"] = model_name
omni = Omni(**omni_kwargs)

thinker_sampling_params = SamplingParams(
temperature=0.9,
Expand Down
10 changes: 3 additions & 7 deletions examples/offline_inference/qwen3_omni/end2end_async_chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from vllm.multimodal.media.audio import load_audio
from vllm.utils.argparse_utils import FlexibleArgumentParser

from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
from vllm_omni.entrypoints.async_omni import AsyncOmni

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -382,13 +383,7 @@ async def run_all(args):
print(f"[Info] Creating AsyncOmni with deploy_config={args.deploy_config}")
async_omni = None
try:
# ``from_cli_args`` expands vars(args) into kwargs and auto-captures
# ``_cli_explicit_keys`` from ``sys.argv[1:]`` so argparse defaults
# do not silently override deploy YAML values. Mirrors the
# ``EngineArgs.from_cli_args`` pattern used throughout vllm /
# vllm-omni. ``deploy_config=None`` (the default) falls through to
# the bundled ``vllm_omni/deploy/qwen3_omni_moe.yaml``.
async_omni = AsyncOmni.from_cli_args(args)
async_omni = AsyncOmni(**vars(args))

# Use default sampling params from stage config (they are pre-configured
# in the YAML for each stage).
Expand Down Expand Up @@ -594,6 +589,7 @@ def parse_args():
default=16000,
help="Sampling rate for audio loading.",
)
nullify_stage_engine_defaults(parser)
return parser.parse_args()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from vllm.assets.audio import AudioAsset
from vllm.multimodal.media.audio import load_audio

from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.model_executor.models.cosyvoice3.config import CosyVoice3Config
from vllm_omni.model_executor.models.cosyvoice3.tokenizer import get_qwen_tokenizer
Expand Down Expand Up @@ -44,6 +45,7 @@ def run_e2e():
required=True,
help="Path to tokenizer directory (e.g., <model_path>/CosyVoice-BlankEN).",
)
nullify_stage_engine_defaults(parser)
args = parser.parse_args()
# Ensure tokenizer directory exists
if not os.path.exists(args.tokenizer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from vllm.utils.argparse_utils import FlexibleArgumentParser

from vllm_omni import AsyncOmni, Omni
from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
from vllm_omni.model_executor.models.fish_speech.dac_utils import DAC_HOP_LENGTH, DAC_SAMPLE_RATE
from vllm_omni.model_executor.models.fish_speech.prompt_utils import (
build_fish_text_only_prompt_ids,
Expand Down Expand Up @@ -266,6 +267,7 @@ def parse_args():
default=False,
help="Stream audio chunks as they arrive via AsyncOmni.",
)
nullify_stage_engine_defaults(parser)
return parser.parse_args()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def parse_args():
def main():
args = parse_args()

omni = Omni.from_cli_args(args, model=args.model)
omni = Omni(**vars(args))

messages = get_messages(args.case, args.text)
decode_args = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import numpy as np
import soundfile as sf

from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams

Expand Down Expand Up @@ -79,6 +80,7 @@ def run_e2e():
default=600,
help="Stage initialization timeout in seconds",
)
nullify_stage_engine_defaults(parser)
args = parser.parse_args()

if not os.path.exists(args.stage_config):
Expand Down
12 changes: 10 additions & 2 deletions examples/offline_inference/text_to_speech/qwen3_tts/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from vllm.utils.argparse_utils import FlexibleArgumentParser

from vllm_omni import AsyncOmni, Omni
from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -383,7 +384,10 @@ def main(args):
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)

omni = Omni.from_cli_args(args, model=model_name)
omni_kwargs = vars(args).copy()
# Override CLI --model with the derived model_name.
omni_kwargs["model"] = model_name
omni = Omni(**omni_kwargs)

batch_size = args.batch_size
for batch_start in range(0, len(inputs), batch_size):
Expand All @@ -399,7 +403,10 @@ async def main_streaming(args):
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)

omni = AsyncOmni.from_cli_args(args, model=model_name)
omni_kwargs = vars(args).copy()
# Override CLI --model with the derived model_name.
omni_kwargs["model"] = model_name
omni = AsyncOmni(**omni_kwargs)

for i, prompt in enumerate(inputs):
request_id = str(i)
Expand Down Expand Up @@ -541,6 +548,7 @@ def parse_args():
help="Number of prompts per batch (default: 1, sequential).",
)

nullify_stage_engine_defaults(parser)
return parser.parse_args()


Expand Down
2 changes: 2 additions & 0 deletions examples/offline_inference/text_to_speech/voxcpm/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from vllm.utils.argparse_utils import FlexibleArgumentParser

from vllm_omni import AsyncOmni, Omni
from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults

REPO_ROOT = Path(__file__).resolve().parents[4]
DEFAULT_SYNC_STAGE_CONFIG = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm.yaml"
Expand Down Expand Up @@ -185,6 +186,7 @@ def parse_args():
)
parser.add_argument("--stage-init-timeout", type=int, default=600, help="Stage initialization timeout in seconds.")
parser.add_argument("--log-stats", action="store_true", help="Enable vLLM Omni stats logging.")
nullify_stage_engine_defaults(parser)
args = parser.parse_args()
if (args.ref_audio is None) != (args.ref_text is None):
raise ValueError("Voice cloning requires --ref-audio and --ref-text together.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from vllm.utils.argparse_utils import FlexibleArgumentParser

from vllm_omni import Omni
from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults

REPO_ROOT = Path(__file__).resolve().parents[4]
DEFAULT_STAGE_CONFIGS_PATH = str(REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm2.yaml")
Expand Down Expand Up @@ -59,6 +60,7 @@ def parse_args():
default=None,
help="Optional transcript of --ref-audio (enables continuation mode).",
)
nullify_stage_engine_defaults(parser)
return parser.parse_args()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from vllm.utils.argparse_utils import FlexibleArgumentParser

from vllm_omni import AsyncOmni
from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
from vllm_omni.entrypoints.omni import Omni

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -304,6 +305,7 @@ def parse_args() -> Namespace:
default=None,
help="CFG alpha for flow-matching guidance (default: use value from stage config, typically 1.2).",
)
nullify_stage_engine_defaults(parser)
return parser.parse_args()


Expand Down
56 changes: 54 additions & 2 deletions tests/entrypoints/test_omni_entrypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,58 @@ def _patch_engine(monkeypatch: pytest.MonkeyPatch, engine: FakeAsyncOmniEngine)
monkeypatch.setattr("vllm_omni.entrypoints.omni_base.omni_snapshot_download", lambda model: model)


def test_from_cli_args_only_nulls_untyped_override_fields(monkeypatch: pytest.MonkeyPatch):
def test_direct_omni_with_nullified_parser_only_nulls_untyped_override_fields(
    monkeypatch: pytest.MonkeyPatch,
):
    """Check what ``Omni(**vars(args))`` forwards after ``nullify_stage_engine_defaults``.

    The parser is built with two options, run through
    ``nullify_stage_engine_defaults`` and parsed with an empty argv. The
    engine must then receive ``gpu_memory_utilization=None``, keep
    ``hsdp_shard_size=-1``, and never see a ``_cli_explicit_keys`` kwarg.
    """
    from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
    from vllm_omni.entrypoints.omni import Omni

    captured: dict[str, Any] = {}

    def fake_engine(*args: Any, **kwargs: Any) -> FakeAsyncOmniEngine:
        # Record the keyword arguments the entrypoint hands to the engine.
        captured.update(kwargs)
        return FakeAsyncOmniEngine()

    # Replace the engine constructor and the snapshot download with stubs.
    monkeypatch.setattr("vllm_omni.entrypoints.omni_base.AsyncOmniEngine", fake_engine)
    monkeypatch.setattr("vllm_omni.entrypoints.omni_base.omni_snapshot_download", lambda model: model)

    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
    parser.add_argument("--hsdp-shard-size", type=int, default=-1)
    nullify_stage_engine_defaults(parser)
    args = parser.parse_args([])
    args.model = "fake-model"

    Omni(**vars(args))

    assert captured["gpu_memory_utilization"] is None
    assert captured["hsdp_shard_size"] == -1
    assert "_cli_explicit_keys" not in captured


def test_from_cli_args_warns_and_forwards_without_internal_keys(
    monkeypatch: pytest.MonkeyPatch,
):
    """Check that ``Omni.from_cli_args`` warns and drops ``_cli_explicit_keys``.

    The call must raise a deprecation warning whose message matches
    ``from_cli_args``. The engine must receive ``gpu_memory_utilization=0.9``
    as given in the namespace, and must not receive ``_cli_explicit_keys``.
    """
    # Imported locally, as the sibling tests do, so the test does not rely on
    # a module-level ``Omni`` name.
    from vllm_omni.entrypoints.omni import Omni

    captured: dict[str, Any] = {}

    def fake_engine(*args: Any, **kwargs: Any) -> FakeAsyncOmniEngine:
        # Record the keyword arguments the entrypoint hands to the engine.
        captured.update(kwargs)
        return FakeAsyncOmniEngine()

    # Replace the engine constructor and the snapshot download with stubs.
    monkeypatch.setattr("vllm_omni.entrypoints.omni_base.AsyncOmniEngine", fake_engine)
    monkeypatch.setattr("vllm_omni.entrypoints.omni_base.omni_snapshot_download", lambda model: model)

    args = argparse.Namespace(model="fake-model", gpu_memory_utilization=0.9, _cli_explicit_keys={"model"})
    with pytest.deprecated_call(match="from_cli_args"):
        Omni.from_cli_args(args)

    assert captured["gpu_memory_utilization"] == 0.9
    assert "_cli_explicit_keys" not in captured


def test_deprecated_from_cli_args_preserves_legacy_parser_nulling(
monkeypatch: pytest.MonkeyPatch,
):
from vllm_omni.entrypoints.omni import Omni

captured: dict[str, Any] = {}
Expand All @@ -186,7 +237,8 @@ def fake_engine(*args: Any, **kwargs: Any) -> FakeAsyncOmniEngine:
args = parser.parse_args([])
args.model = "fake-model"

Omni.from_cli_args(args, parser=parser)
with pytest.deprecated_call(match="from_cli_args"):
Omni.from_cli_args(args, parser=parser)

assert captured["gpu_memory_utilization"] is None
assert captured["hsdp_shard_size"] == -1
Expand Down
11 changes: 4 additions & 7 deletions tests/entrypoints/test_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
from pytest_mock import MockerFixture

from vllm_omni.entrypoints.cli.serve import OmniServeCommand, run_headless
from vllm_omni.entrypoints.utils import detect_explicit_cli_keys

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]


def test_serve_parser_accepts_no_async_chunk_and_marks_it_explicit() -> None:
"""``--no-async-chunk`` should parse to ``async_chunk=False`` and mark the
shared deploy-level dest as explicitly provided by the user."""
def test_serve_parser_accepts_no_async_chunk() -> None:
"""``--no-async-chunk`` should parse after deploy-overriding parser
defaults are nullified."""
try:
from vllm.utils.argparse_utils import FlexibleArgumentParser
except Exception as exc:
Expand All @@ -24,14 +23,12 @@ def test_serve_parser_accepts_no_async_chunk_and_marks_it_explicit() -> None:
root = FlexibleArgumentParser()
subparsers = root.add_subparsers(dest="subcommand")
cmd = OmniServeCommand()
serve_parser = cmd.subparser_init(subparsers)
cmd.subparser_init(subparsers)

argv = ["serve", "fake-model", "--omni", "--no-async-chunk"]
args = root.parse_args(argv)

assert args.async_chunk is False
explicit = detect_explicit_cli_keys(argv, serve_parser)
assert "async_chunk" in explicit


def _make_headless_args() -> argparse.Namespace:
Expand Down
Loading
Loading