diff --git a/benchmarks/diffusion/quantization_quality.py b/benchmarks/diffusion/quantization_quality.py
index 4a916e7ea62..85c6d784459 100644
--- a/benchmarks/diffusion/quantization_quality.py
+++ b/benchmarks/diffusion/quantization_quality.py
@@ -58,6 +58,8 @@
import numpy as np
import torch
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
+
def compute_lpips_images(
baseline_images: list,
@@ -137,6 +139,7 @@ def _build_omni_kwargs(args, quantization=None):
)
kwargs = {
"model": args.model,
+ "explicit_overrides": getattr(args, "_explicit_overrides", None),
"parallel_config": parallel_config,
"enforce_eager": args.enforce_eager,
}
@@ -452,7 +455,7 @@ def parse_args():
parser.add_argument("--ring-degree", type=int, default=1)
parser.add_argument("--tensor-parallel-size", type=int, default=1)
parser.add_argument("--enforce-eager", action="store_true")
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
if __name__ == "__main__":
diff --git a/examples/offline_inference/bagel/end2end.py b/examples/offline_inference/bagel/end2end.py
index 472d748d1e6..9bd2d9a91ff 100644
--- a/examples/offline_inference/bagel/end2end.py
+++ b/examples/offline_inference/bagel/end2end.py
@@ -1,6 +1,7 @@
import argparse
import os
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni.inputs.data import OmniPromptType
from vllm_omni.model_executor.stage_input_processors.bagel import (
GEN_THINK_SYSTEM_PROMPT,
@@ -98,7 +99,7 @@ def parse_args():
help="Enable thinking mode: AR stage decodes ... planning tokens before image generation.",
)
- args = parser.parse_args()
+ args = parse_args_with_explicit_overrides(parser)
return args
@@ -152,6 +153,7 @@ def main():
)
if args.quantization:
omni_kwargs["quantization_config"] = args.quantization
+ omni_kwargs["explicit_overrides"] = getattr(args, "_explicit_overrides", None)
omni = Omni(model=model_name, **omni_kwargs)
diff --git a/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py b/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py
index 68ab72b3870..fc792b314d8 100644
--- a/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py
+++ b/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py
@@ -10,6 +10,7 @@
from vllm import SamplingParams
from vllm.assets.audio import AudioAsset
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.model_executor.models.cosyvoice3.config import CosyVoice3Config
from vllm_omni.model_executor.models.cosyvoice3.tokenizer import get_qwen_tokenizer
@@ -55,7 +56,7 @@ def run_e2e():
required=True,
help="Path to tokenizer directory (e.g., /CosyVoice-BlankEN).",
)
- args = parser.parse_args()
+ args = parse_args_with_explicit_overrides(parser)
_ensure_mel_filters_asset()
# Ensure tokenizer directory exists
if not os.path.exists(args.tokenizer):
@@ -72,6 +73,7 @@ def run_e2e():
# We pass trust_remote_code=True same as Qwen examples
omni = Omni(
model=args.model,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
stage_configs_path=args.stage_config,
trust_remote_code=True,
tokenizer=args.tokenizer,
diff --git a/examples/offline_inference/custom_pipeline/image_to_image/image_edit.py b/examples/offline_inference/custom_pipeline/image_to_image/image_edit.py
index 8ab5e0d9a6c..8dd4f02d3c7 100644
--- a/examples/offline_inference/custom_pipeline/image_to_image/image_edit.py
+++ b/examples/offline_inference/custom_pipeline/image_to_image/image_edit.py
@@ -52,6 +52,7 @@
from PIL import Image
from vllm_omni.diffusion.data import DiffusionParallelConfig
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
@@ -100,7 +101,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--vae-use-tiling", action="store_true")
parser.add_argument("--enable-cpu-offload", action="store_true")
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
# ===========================
@@ -149,6 +150,7 @@ async def main():
# ---- Initialize Omni ----
omni = Omni(
model=args.model,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
vae_use_slicing=args.vae_use_slicing,
vae_use_tiling=args.vae_use_tiling,
cache_backend=args.cache_backend,
diff --git a/examples/offline_inference/dynin_omni/end2end.py b/examples/offline_inference/dynin_omni/end2end.py
index 66047934d51..dc7f0793fb0 100644
--- a/examples/offline_inference/dynin_omni/end2end.py
+++ b/examples/offline_inference/dynin_omni/end2end.py
@@ -18,6 +18,8 @@
import torch
from PIL import Image
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
+
TASK_CHOICES = ("t2t", "t2i", "t2s", "i2i", "i2t", "s2t", "v2t")
TASK_DEFAULT_RUNTIME = {
@@ -970,7 +972,7 @@ def parse_args(repo_root: Path) -> argparse.Namespace:
parser.add_argument("--vq-model-audio-local-files-only", action=argparse.BooleanOptionalAction, default=None)
parser.add_argument("--disable-hf-xet", action=argparse.BooleanOptionalAction, default=True)
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
def main() -> None:
@@ -1395,7 +1397,12 @@ def main() -> None:
from vllm_omni.entrypoints.omni import Omni
stage_config_path = str(Path(args.stage_config_path).expanduser())
- omni = Omni(model=model_source, stage_configs_path=stage_config_path, dtype=args.dtype)
+ omni = Omni(
+ model=model_source,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
+ stage_configs_path=stage_config_path,
+ dtype=args.dtype,
+ )
sampling_params_list = [
SamplingParams(max_tokens=int(args.max_tokens_per_stage), temperature=0.0, top_p=1.0, detokenize=False)
for _ in range(omni.num_stages)
diff --git a/examples/offline_inference/fish_speech/end2end.py b/examples/offline_inference/fish_speech/end2end.py
index 31c24d3d5d6..5db70ef38c3 100644
--- a/examples/offline_inference/fish_speech/end2end.py
+++ b/examples/offline_inference/fish_speech/end2end.py
@@ -29,6 +29,7 @@
from vllm.utils.argparse_utils import FlexibleArgumentParser
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni import AsyncOmni, Omni
from vllm_omni.model_executor.models.fish_speech.dac_utils import DAC_HOP_LENGTH, DAC_SAMPLE_RATE
from vllm_omni.model_executor.models.fish_speech.prompt_utils import (
@@ -149,6 +150,7 @@ def main(args):
omni = Omni(
model=model_name,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
stage_configs_path=stage_configs_path,
log_stats=args.log_stats,
stage_init_timeout=args.stage_init_timeout,
@@ -185,6 +187,7 @@ async def main_streaming(args):
omni = AsyncOmni(
model=model_name,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
stage_configs_path=stage_configs_path,
log_stats=args.log_stats,
stage_init_timeout=args.stage_init_timeout,
@@ -273,7 +276,7 @@ def parse_args():
default=False,
help="Stream audio chunks as they arrive via AsyncOmni.",
)
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
if __name__ == "__main__":
diff --git a/examples/offline_inference/glm_image/end2end.py b/examples/offline_inference/glm_image/end2end.py
index 13bcd23f55a..be4a5205734 100644
--- a/examples/offline_inference/glm_image/end2end.py
+++ b/examples/offline_inference/glm_image/end2end.py
@@ -44,6 +44,7 @@
from PIL import Image
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
@@ -260,6 +261,7 @@ def main(args: argparse.Namespace) -> None:
omni = Omni(
model=args.model_path,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
stage_configs_path=config_path,
log_stats=args.enable_stats,
stage_init_timeout=args.stage_init_timeout,
@@ -503,7 +505,7 @@ def parse_args() -> argparse.Namespace:
help="Enable diffusion pipeline profiler to display stage durations.",
)
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
if __name__ == "__main__":
diff --git a/examples/offline_inference/helios/end2end.py b/examples/offline_inference/helios/end2end.py
index 88c3b865d42..afcc77b4ab3 100644
--- a/examples/offline_inference/helios/end2end.py
+++ b/examples/offline_inference/helios/end2end.py
@@ -52,6 +52,7 @@
import torch
from vllm_omni.diffusion.data import DiffusionParallelConfig
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
@@ -196,7 +197,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--cfg-parallel-size", type=int, default=1, choices=[1, 2], help="CFG parallel size.")
parser.add_argument("--tensor-parallel-size", type=int, default=1, help="Tensor parallelism size.")
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
def main():
@@ -212,6 +213,7 @@ def main():
omni = Omni(
model=args.model,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
enable_layerwise_offload=args.enable_layerwise_offload,
vae_use_slicing=args.vae_use_slicing,
vae_use_tiling=args.vae_use_tiling,
diff --git a/examples/offline_inference/hunyuan_image3/image_to_text.py b/examples/offline_inference/hunyuan_image3/image_to_text.py
index d40134ac0a0..0d64f064c06 100644
--- a/examples/offline_inference/hunyuan_image3/image_to_text.py
+++ b/examples/offline_inference/hunyuan_image3/image_to_text.py
@@ -6,6 +6,7 @@
from PIL import Image
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni.entrypoints.omni import Omni
"""
@@ -46,7 +47,7 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Enable diffusion pipeline profiler to display stage durations.",
)
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
def load_image(image_path: str) -> Image.Image:
@@ -59,6 +60,7 @@ def load_image(image_path: str) -> Image.Image:
def main(args: argparse.Namespace) -> None:
omni = Omni(
model=args.model,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler,
mode="image-to-text",
)
diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py
index a8035a3fdcb..f5e706609a7 100644
--- a/examples/offline_inference/image_to_image/image_edit.py
+++ b/examples/offline_inference/image_to_image/image_edit.py
@@ -95,6 +95,7 @@
from PIL import Image
from vllm_omni.diffusion.data import DiffusionParallelConfig
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
@@ -330,7 +331,7 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Enable diffusion pipeline profiler to display stage durations.",
)
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
def main():
@@ -386,6 +387,7 @@ def main():
# Initialize Omni with appropriate pipeline
omni = Omni(
model=args.model,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
enable_layerwise_offload=args.enable_layerwise_offload,
vae_use_slicing=args.vae_use_slicing,
vae_use_tiling=args.vae_use_tiling,
diff --git a/examples/offline_inference/image_to_video/image_to_video.py b/examples/offline_inference/image_to_video/image_to_video.py
index 7e7cfbf84e8..2633c73fe20 100644
--- a/examples/offline_inference/image_to_video/image_to_video.py
+++ b/examples/offline_inference/image_to_video/image_to_video.py
@@ -42,6 +42,7 @@
import torch
from vllm_omni.diffusion.data import DiffusionParallelConfig
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
@@ -187,7 +188,7 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Enable diffusion pipeline profiler to display stage durations.",
)
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
def calculate_dimensions(
@@ -281,6 +282,7 @@ def main():
)
omni = Omni(
model=args.model,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
enable_layerwise_offload=args.enable_layerwise_offload,
vae_use_slicing=args.vae_use_slicing,
vae_use_tiling=args.vae_use_tiling,
diff --git a/examples/offline_inference/magi_human/end2end.py b/examples/offline_inference/magi_human/end2end.py
index 64f11c4658b..4cfa5d90bd9 100644
--- a/examples/offline_inference/magi_human/end2end.py
+++ b/examples/offline_inference/magi_human/end2end.py
@@ -1,6 +1,7 @@
import argparse
from vllm_omni.diffusion.utils.media_utils import mux_video_audio_bytes
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
@@ -24,7 +25,7 @@ def parse_args():
parser.add_argument("--width", type=int, default=448, help="Video width.")
parser.add_argument("--num-inference-steps", type=int, default=8, help="Number of denoising steps.")
parser.add_argument("--seed", type=int, default=52, help="Random seed for generation.")
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
def main():
@@ -33,6 +34,7 @@ def main():
print(f"Initializing MagiHuman pipeline with TP={args.tensor_parallel_size}...")
omni = Omni(
model=args.model,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
init_timeout=1200,
tensor_parallel_size=args.tensor_parallel_size,
devices=list(range(args.tensor_parallel_size)),
diff --git a/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summarize.py b/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summarize.py
index ca87b9e9a94..5c8623d3098 100644
--- a/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summarize.py
+++ b/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summarize.py
@@ -19,6 +19,7 @@
from vllm.multimodal.image import convert_image_mode
from vllm_omni import Omni
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
DEFAULT_SYSTEM = "You are a helpful assistant."
DEFAULT_QUESTION = "Please summarize the content of this image."
@@ -48,7 +49,7 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Enable diffusion pipeline profiler to display stage durations.",
)
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
def build_prompt(system: str, question: str) -> str:
@@ -73,6 +74,7 @@ def main() -> None:
omni = Omni(
model=args.model,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
stage_configs_path=args.stage_config,
trust_remote_code=args.trust_remote_code,
enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler,
diff --git a/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py b/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py
index a4c41fee1f8..4a18feaa844 100644
--- a/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py
+++ b/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py
@@ -29,6 +29,7 @@
from vllm.sampling_params import SamplingParams
from vllm_omni import Omni
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
@@ -117,7 +118,7 @@ def parse_args() -> argparse.Namespace:
)
p.add_argument("--out", type=str, default="output.png", help="Path to save the generated image.")
p.add_argument("--trust-remote-code", action="store_true", help="Trust remote code when loading the model.")
- args = p.parse_args()
+ args = parse_args_with_explicit_overrides(p)
if not args.prompt:
args.prompt = ["A stylish woman with sunglasses riding a motorcycle in NYC."]
return args
@@ -194,7 +195,12 @@ def main() -> None:
expected_grid_tokens = ar_height * (ar_width + 1)
logger.info("Initializing Omni pipeline...")
- omni = Omni(model=args.model, stage_configs_path=args.stage_config, trust_remote_code=args.trust_remote_code)
+ omni = Omni(
+ model=args.model,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
+ stage_configs_path=args.stage_config,
+ trust_remote_code=args.trust_remote_code,
+ )
try:
ar_sampling = SamplingParams(
temperature=1.0,
diff --git a/examples/offline_inference/mimo_audio/end2end.py b/examples/offline_inference/mimo_audio/end2end.py
index ae044d2e8a1..163568f1e41 100644
--- a/examples/offline_inference/mimo_audio/end2end.py
+++ b/examples/offline_inference/mimo_audio/end2end.py
@@ -23,6 +23,7 @@
from vllm import SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniTokensPrompt
@@ -182,6 +183,7 @@ def main(args):
omni = Omni(
model=model_name,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
stage_configs_path=args.stage_configs_path,
log_stats=args.enable_stats,
log_file=("omni_pipeline.log" if args.enable_stats else None),
@@ -434,7 +436,7 @@ def parse_args():
help="Path to a stage configs file.",
)
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
if __name__ == "__main__":
diff --git a/examples/offline_inference/omnivoice/end2end.py b/examples/offline_inference/omnivoice/end2end.py
index b41379b011a..4678f0d4809 100644
--- a/examples/offline_inference/omnivoice/end2end.py
+++ b/examples/offline_inference/omnivoice/end2end.py
@@ -21,6 +21,7 @@
import numpy as np
import soundfile as sf
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
@@ -79,7 +80,7 @@ def run_e2e():
default=600,
help="Stage initialization timeout in seconds",
)
- args = parser.parse_args()
+ args = parse_args_with_explicit_overrides(parser)
if not os.path.exists(args.stage_config):
raise FileNotFoundError(f"Stage config not found: {args.stage_config}")
@@ -88,6 +89,7 @@ def run_e2e():
omni = Omni(
model=args.model,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
stage_configs_path=args.stage_config,
trust_remote_code=True,
log_stats=True,
diff --git a/examples/offline_inference/qwen2_5_omni/end2end.py b/examples/offline_inference/qwen2_5_omni/end2end.py
index 7bba5998308..33f73ab17f3 100644
--- a/examples/offline_inference/qwen2_5_omni/end2end.py
+++ b/examples/offline_inference/qwen2_5_omni/end2end.py
@@ -20,6 +20,7 @@
from vllm.sampling_params import SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni.entrypoints.omni import Omni
SEED = 42
@@ -322,6 +323,7 @@ def main(args):
query_result = query_func()
omni = Omni(
model=model_name,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
log_stats=args.log_stats,
stage_init_timeout=args.stage_init_timeout,
batch_timeout=args.batch_timeout,
@@ -540,7 +542,7 @@ def parse_args():
default=False,
help="Use py_generator mode. The returned type of Omni.generate() is a Python Generator object.",
)
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
if __name__ == "__main__":
diff --git a/examples/offline_inference/qwen3_omni/end2end.py b/examples/offline_inference/qwen3_omni/end2end.py
index 155eca4ed9f..1f6352c2070 100644
--- a/examples/offline_inference/qwen3_omni/end2end.py
+++ b/examples/offline_inference/qwen3_omni/end2end.py
@@ -21,6 +21,7 @@
from vllm.multimodal.image import convert_image_mode
from vllm.utils.argparse_utils import FlexibleArgumentParser
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni.entrypoints.omni import Omni
SEED = 42
@@ -296,6 +297,7 @@ def main(args):
omni = Omni(
model=model_name,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
dtype=args.dtype,
stage_configs_path=args.stage_configs_path,
log_stats=args.log_stats,
@@ -557,7 +559,7 @@ def parse_args():
help="Model dtype (auto, half, float16, bfloat16, float, float32).",
)
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
if __name__ == "__main__":
diff --git a/examples/offline_inference/qwen3_omni/end2end_async_chunk.py b/examples/offline_inference/qwen3_omni/end2end_async_chunk.py
index 8adbae9eb66..60723215fe7 100644
--- a/examples/offline_inference/qwen3_omni/end2end_async_chunk.py
+++ b/examples/offline_inference/qwen3_omni/end2end_async_chunk.py
@@ -41,6 +41,7 @@
from vllm.multimodal.image import convert_image_mode
from vllm.utils.argparse_utils import FlexibleArgumentParser
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni.entrypoints.async_omni import AsyncOmni
logger = logging.getLogger(__name__)
@@ -379,6 +380,7 @@ async def run_all(args):
try:
async_omni = AsyncOmni(
model=args.model,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
stage_configs_path=args.stage_configs_path,
log_stats=args.log_stats,
stage_init_timeout=args.stage_init_timeout,
@@ -584,7 +586,7 @@ def parse_args():
default=16000,
help="Sampling rate for audio loading.",
)
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
if __name__ == "__main__":
diff --git a/examples/offline_inference/qwen3_tts/end2end.py b/examples/offline_inference/qwen3_tts/end2end.py
index 901418c39b8..e56b516fa6e 100644
--- a/examples/offline_inference/qwen3_tts/end2end.py
+++ b/examples/offline_inference/qwen3_tts/end2end.py
@@ -17,6 +17,7 @@
from vllm.utils.argparse_utils import FlexibleArgumentParser
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni import AsyncOmni, Omni
logger = logging.getLogger(__name__)
@@ -368,6 +369,7 @@ def main(args):
omni = Omni(
model=model_name,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
stage_configs_path=args.stage_configs_path,
log_stats=args.log_stats,
stage_init_timeout=args.stage_init_timeout,
@@ -389,6 +391,7 @@ async def main_streaming(args):
omni = AsyncOmni(
model=model_name,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
stage_configs_path=args.stage_configs_path,
log_stats=args.log_stats,
stage_init_timeout=args.stage_init_timeout,
@@ -534,7 +537,7 @@ def parse_args():
help="Number of prompts per batch (default: 1, sequential).",
)
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
if __name__ == "__main__":
diff --git a/examples/offline_inference/text_to_audio/text_to_audio.py b/examples/offline_inference/text_to_audio/text_to_audio.py
index a6968c419f6..29688984a3b 100644
--- a/examples/offline_inference/text_to_audio/text_to_audio.py
+++ b/examples/offline_inference/text_to_audio/text_to_audio.py
@@ -20,6 +20,7 @@
import numpy as np
import torch
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
@@ -95,7 +96,7 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Enable diffusion pipeline profiler to display stage durations.",
)
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
def save_audio(audio_data: np.ndarray, output_path: str, sample_rate: int = 44100):
@@ -140,6 +141,7 @@ def main():
# Initialize Omni with Stable Audio model
omni = Omni(
model=args.model,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler,
)
diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py
index 615e4067ed8..41bc5953c52 100644
--- a/examples/offline_inference/text_to_image/text_to_image.py
+++ b/examples/offline_inference/text_to_image/text_to_image.py
@@ -10,6 +10,7 @@
import torch
from vllm_omni.diffusion.data import DiffusionParallelConfig, logger
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.lora.request import LoRARequest
@@ -255,7 +256,7 @@ def parse_args() -> argparse.Namespace:
default=None,
help=("Custom system prompt. Used when --use-system-prompt is custom. "),
)
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
def main():
@@ -334,6 +335,7 @@ def main():
omni_kwargs = {
"model": args.model,
+ "explicit_overrides": getattr(args, "_explicit_overrides", None),
"enable_layerwise_offload": args.enable_layerwise_offload,
"vae_use_slicing": args.vae_use_slicing,
"vae_use_tiling": args.vae_use_tiling,
diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py
index 83925cc458a..430217bacfe 100644
--- a/examples/offline_inference/text_to_video/text_to_video.py
+++ b/examples/offline_inference/text_to_video/text_to_video.py
@@ -10,6 +10,7 @@
import torch
from vllm_omni.diffusion.data import DiffusionParallelConfig
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
@@ -185,7 +186,7 @@ def parse_args() -> argparse.Namespace:
choices=["fp8", "gguf"],
help="Quantization method for the transformer (fp8 for online FP8 quantization).",
)
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
def main():
@@ -229,6 +230,7 @@ def main():
omni_kwargs = dict(
model=args.model,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
enable_layerwise_offload=args.enable_layerwise_offload,
vae_use_slicing=args.vae_use_slicing,
vae_use_tiling=args.vae_use_tiling,
diff --git a/examples/offline_inference/vace/vace_video_generation.py b/examples/offline_inference/vace/vace_video_generation.py
index 6ca0d74c52e..963814cb4f3 100644
--- a/examples/offline_inference/vace/vace_video_generation.py
+++ b/examples/offline_inference/vace/vace_video_generation.py
@@ -34,6 +34,7 @@
import torch
from vllm_omni.diffusion.data import DiffusionParallelConfig
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
@@ -71,7 +72,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--ulysses-degree", type=int, default=1, help="Ulysses SP degree.")
parser.add_argument("--ring-degree", type=int, default=1, help="Ring attention degree.")
parser.add_argument("--cfg-parallel-size", type=int, default=1, choices=[1, 2], help="CFG parallel size.")
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
def build_prompts(args):
@@ -155,6 +156,7 @@ def main():
omni = Omni(
model=args.model,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
vae_use_tiling=args.vae_use_tiling,
flow_shift=args.flow_shift,
enforce_eager=args.enforce_eager,
diff --git a/examples/offline_inference/voxtral_tts/end2end.py b/examples/offline_inference/voxtral_tts/end2end.py
index 0750246450a..08948be09ec 100644
--- a/examples/offline_inference/voxtral_tts/end2end.py
+++ b/examples/offline_inference/voxtral_tts/end2end.py
@@ -29,6 +29,7 @@
from vllm import SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni import AsyncOmni
from vllm_omni.entrypoints.omni import Omni
@@ -39,6 +40,7 @@
async def run_streaming(inputs, sampling_params_list, model_name, args, output_dir):
async_omni = AsyncOmni(
model=model_name,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
stage_configs_path=args.stage_configs_path,
log_stats=args.log_stats,
)
@@ -191,6 +193,7 @@ async def _generate_one(batch_idx, single_input):
def run_non_streaming(inputs, sampling_params_list, model_name, args, output_dir):
llm = Omni(
model=model_name,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
log_stats=args.log_stats,
stage_configs_path=args.stage_configs_path,
)
@@ -297,7 +300,7 @@ def parse_args() -> Namespace:
default=None,
help="Voice to use instead of audio file.",
)
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
def compose_request(
diff --git a/examples/offline_inference/x_to_video_audio/x_to_video_audio.py b/examples/offline_inference/x_to_video_audio/x_to_video_audio.py
index e0424add69b..c28a65404b9 100644
--- a/examples/offline_inference/x_to_video_audio/x_to_video_audio.py
+++ b/examples/offline_inference/x_to_video_audio/x_to_video_audio.py
@@ -9,6 +9,7 @@
from PIL import Image
from vllm_omni.diffusion.data import DiffusionParallelConfig
+from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
@@ -56,7 +57,7 @@ def parse_args() -> argparse.Namespace:
default=False,
help="Enable CPU offloading for diffusion models.",
)
- return parser.parse_args()
+ return parse_args_with_explicit_overrides(parser)
def load_image_and_audio(image_paths, audio_paths):
@@ -121,6 +122,7 @@ def main() -> None:
omni = Omni(
model=args.model,
+ explicit_overrides=getattr(args, "_explicit_overrides", None),
parallel_config=parallel_config,
model_type=args.model_type,
enable_cpu_offload=args.enable_cpu_offload,
diff --git a/tests/entrypoints/test_utils.py b/tests/entrypoints/test_utils.py
index 6e44fe533c2..103278aef60 100644
--- a/tests/entrypoints/test_utils.py
+++ b/tests/entrypoints/test_utils.py
@@ -1,8 +1,11 @@
"""Unit tests for vllm_omni.entrypoints.utils module."""
+import argparse
import os
from collections import Counter
from dataclasses import dataclass
+from pathlib import Path
+from textwrap import dedent
import pytest
import torch
@@ -16,6 +19,7 @@
_filter_dict_like_object,
filter_dataclass_kwargs,
load_and_resolve_stage_configs,
+ parse_args_with_explicit_overrides,
resolve_model_config_path,
)
@@ -309,6 +313,28 @@ def mock_exists(path):
assert "glm_image.yaml" in result
+class TestParseArgsWithExplicitStageOverrides:
+ def test_only_explicit_cli_values_are_captured(self):
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--foo", default="default-foo")
+ parser.add_argument("--bar", type=int, default=3)
+
+ args = parse_args_with_explicit_overrides(parser, ["--foo", "cli-foo"])
+
+ assert args.foo == "cli-foo"
+ assert args.bar == 3
+ assert args._explicit_overrides == {"foo": "cli-foo"}
+
+ def test_equals_style_option_is_supported(self):
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--tensor-parallel-size", dest="tensor_parallel_size", type=int, default=1)
+
+ args = parse_args_with_explicit_overrides(parser, ["--tensor-parallel-size=4"])
+
+ assert args.tensor_parallel_size == 4
+ assert args._explicit_overrides == {"tensor_parallel_size": 4}
+
+
class TestLoadAndResolveStageConfigs:
def test_load_and_resolve_with_kwargs(self):
"""Ensure that dtype survives default stage creation."""
@@ -322,3 +348,134 @@ def test_load_and_resolve_with_kwargs(self):
assert config_path is None
assert len(stage_configs) == 1
assert "dtype" in stage_configs[0]["engine_args"]
+
+ def test_explicit_yaml_overrides_explicit_overrides(self, tmp_path: Path):
+ config_path = tmp_path / "stage.yaml"
+ config_path.write_text(
+ dedent(
+ """
+ stage_args:
+ - stage_id: 0
+ stage_type: diffusion
+ runtime:
+ process: true
+ devices: "0"
+ engine_args:
+ tensor_parallel_size: 1
+ """
+ ).strip(),
+ encoding="utf-8",
+ )
+
+ _, stage_configs = load_and_resolve_stage_configs(
+ model="dummy-model",
+ stage_configs_path=str(config_path),
+ kwargs={
+ "tensor_parallel_size": 4,
+ "explicit_overrides": {"tensor_parallel_size": 4},
+ },
+ )
+
+ assert stage_configs[0].engine_args.tensor_parallel_size == 1
+
+ def test_default_yaml_uses_explicit_overrides_instead_of_full_kwargs(self, mocker: MockerFixture):
+ config_path = "/tmp/default-stage.yaml"
+ stage_cfg = [
+ {
+ "stage_id": 0,
+ "stage_type": "diffusion",
+ "runtime": {"process": True, "devices": "0"},
+ "engine_args": {
+ "tensor_parallel_size": 1,
+ "distributed_executor_backend": "mp",
+ },
+ }
+ ]
+
+ mocker.patch(
+ "vllm_omni.entrypoints.utils.resolve_model_config_path",
+ return_value=config_path,
+ )
+ mocker.patch(
+ "vllm_omni.entrypoints.utils.load_yaml_config",
+ side_effect=lambda path: __import__(
+ "vllm_omni.config.yaml_util", fromlist=["create_config"]
+ ).create_config({"stage_args": stage_cfg}),
+ )
+
+ _, stage_configs = load_and_resolve_stage_configs(
+ model="dummy-model",
+ stage_configs_path=None,
+ kwargs={
+ "distributed_executor_backend": None,
+ "tensor_parallel_size": 8,
+ "explicit_overrides": {"tensor_parallel_size": 8},
+ },
+ )
+
+ assert stage_configs[0].engine_args.tensor_parallel_size == 8
+ assert stage_configs[0].engine_args.distributed_executor_backend == "mp"
+
+ def test_default_yaml_plain_kwargs_do_not_override_defaults(self, mocker: MockerFixture):
+ config_path = "/tmp/default-stage.yaml"
+ stage_cfg = [
+ {
+ "stage_id": 0,
+ "stage_type": "diffusion",
+ "runtime": {"process": True, "devices": "0"},
+ "engine_args": {"tensor_parallel_size": 1},
+ }
+ ]
+
+ mocker.patch(
+ "vllm_omni.entrypoints.utils.resolve_model_config_path",
+ return_value=config_path,
+ )
+ mocker.patch(
+ "vllm_omni.entrypoints.utils.load_yaml_config",
+ side_effect=lambda path: __import__(
+ "vllm_omni.config.yaml_util", fromlist=["create_config"]
+ ).create_config({"stage_args": stage_cfg}),
+ )
+
+ _, stage_configs = load_and_resolve_stage_configs(
+ model="dummy-model",
+ stage_configs_path=None,
+ kwargs={"tensor_parallel_size": 8},
+ )
+
+ assert stage_configs[0].engine_args.tensor_parallel_size == 1
+
+ def test_parser_defaults_still_fill_missing_keys_under_default_yaml(self, mocker: MockerFixture):
+ config_path = "/tmp/default-stage.yaml"
+ stage_cfg = [
+ {
+ "stage_id": 0,
+ "stage_type": "diffusion",
+ "runtime": {"process": True, "devices": "0"},
+ "engine_args": {"tensor_parallel_size": 1},
+ }
+ ]
+
+ mocker.patch(
+ "vllm_omni.entrypoints.utils.resolve_model_config_path",
+ return_value=config_path,
+ )
+ mocker.patch(
+ "vllm_omni.entrypoints.utils.load_yaml_config",
+ side_effect=lambda path: __import__(
+ "vllm_omni.config.yaml_util", fromlist=["create_config"]
+ ).create_config({"stage_args": stage_cfg}),
+ )
+
+ _, stage_configs = load_and_resolve_stage_configs(
+ model="dummy-model",
+ stage_configs_path=None,
+ kwargs={
+ "pipeline_parallel_size": 2,
+ "explicit_overrides": {},
+ },
+ )
+
+ assert stage_configs[0].engine_args.tensor_parallel_size == 1
+ assert stage_configs[0].engine_args.pipeline_parallel_size == 2
diff --git a/vllm_omni/entrypoints/cli/main.py b/vllm_omni/entrypoints/cli/main.py
index affa6c83349..cbbbd496622 100644
--- a/vllm_omni/entrypoints/cli/main.py
+++ b/vllm_omni/entrypoints/cli/main.py
@@ -20,6 +20,7 @@ def main():
import vllm_omni.entrypoints.cli.benchmark.main
import vllm_omni.entrypoints.cli.serve
+ from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides
CMD_MODULES = [
vllm_omni.entrypoints.cli.serve,
@@ -49,7 +50,7 @@ def main():
for cmd in new_cmds:
cmd.subparser_init(subparsers).set_defaults(dispatch_function=cmd.cmd)
cmds[cmd.name] = cmd
- args = parser.parse_args()
+ args = parse_args_with_explicit_overrides(parser)
if args.subparser in cmds:
cmds[args.subparser].validate(args)
diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py
index 0706b98987c..87b5ddb083c 100644
--- a/vllm_omni/entrypoints/openai/api_server.py
+++ b/vllm_omni/entrypoints/openai/api_server.py
@@ -430,8 +430,9 @@ async def build_async_omni_from_stage_config(
async_omni: EngineClient | None = None
try:
- kwargs = vars(args).copy()
+ kwargs = {k: v for k, v in vars(args).items() if not k.startswith("_")}
kwargs.pop("model", None)
+ kwargs["explicit_overrides"] = getattr(args, "_explicit_overrides", None)
async_omni = AsyncOmni(model=args.model, **kwargs)
# # Don't keep the dummy data in memory
diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py
index c5e49a93364..0a4725c3ed0 100644
--- a/vllm_omni/entrypoints/utils.py
+++ b/vllm_omni/entrypoints/utils.py
@@ -1,4 +1,6 @@
+import argparse
import os
+import sys
import types
from collections import Counter
from dataclasses import fields, is_dataclass
@@ -23,6 +25,53 @@
}
+def parse_args_with_explicit_overrides(
+ parser: argparse.ArgumentParser,
+ argv: list[str] | None = None,
+) -> argparse.Namespace:
+ """Parse CLI args and attach explicit override kwargs onto the namespace."""
+ if argv is None:
+ argv = sys.argv[1:]
+
+ option_to_dest: dict[str, str] = {}
+ for action in parser._actions:
+ for option_string in action.option_strings:
+ option_to_dest[option_string] = action.dest
+
+ choices = getattr(action, "choices", None)
+ if isinstance(choices, dict):
+ for choice in choices.values():
+ if not isinstance(choice, argparse.ArgumentParser):
+ continue
+ for sub_action in choice._actions:
+ for option_string in sub_action.option_strings:
+ option_to_dest[option_string] = sub_action.dest
+
+ explicit_dests: set[str] = set()
+ for arg in argv:
+ if arg == "--":
+ break
+ if not arg.startswith("-") or arg == "-":
+ continue
+ option = arg.split("=", 1)[0]
+ dest = option_to_dest.get(option)
+ if dest is not None:
+ explicit_dests.add(dest)
+
+ args = parser.parse_args(argv)
+ explicit_overrides: dict[str, Any] = {}
+ for dest in explicit_dests:
+ if dest.startswith("_") or not hasattr(args, dest):
+ continue
+ value = getattr(args, dest)
+ if callable(value):
+ continue
+ explicit_overrides[dest] = value
+
+ setattr(args, "_explicit_overrides", explicit_overrides)
+ return args
+
+
def inject_omni_kv_config(stage: Any, omni_conn_cfg: dict[str, Any], omni_from: str, omni_to: str) -> None:
"""Inject connector configuration into stage engine arguments."""
# Prepare omni_kv_config dict
@@ -320,28 +369,38 @@ def load_stage_configs_from_yaml(
Args:
config_path: Path to the YAML configuration file
base_engine_args: Engine args supplied by the caller.
- prefer_stage_engine_args: When True, YAML stage args override caller
- engine args. When False, caller engine args override YAML defaults.
+ prefer_stage_engine_args: Controls precedence between YAML stage args
+ and explicit_overrides. Plain caller kwargs only fill keys missing
+ from the YAML and never override YAML defaults.
Returns:
List of stage configuration dictionaries from the file's stage_args
"""
if base_engine_args is None:
base_engine_args = {}
+ base_engine_args = dict(base_engine_args)
+ explicit_overrides = base_engine_args.pop("explicit_overrides", None)
config_data = load_yaml_config(config_path)
stage_args = config_data.stage_args
global_async_chunk = config_data.get("async_chunk", False)
- # Convert any nested dataclass objects to dicts before creating DictConfig
- base_engine_args = _convert_dataclasses_to_dict(base_engine_args)
- base_engine_args = create_config(base_engine_args)
+ # Convert nested dataclass objects before creating DictConfigs.
+ base_engine_args = create_config(_convert_dataclasses_to_dict(base_engine_args))
+ explicit_engine_args = (
+ create_config(_convert_dataclasses_to_dict(explicit_overrides)) if explicit_overrides else None
+ )
for stage_arg in stage_args:
base_engine_args_tmp = base_engine_args.copy()
- # Update base_engine_args with stage-specific engine_args if they exist
+ # Plain caller kwargs only provide fallback values; stage YAML always
+ # wins over them. Only explicit_overrides participate in conflict
+ # resolution against the YAML defaults below.
if hasattr(stage_arg, "engine_args") and stage_arg.engine_args is not None:
+ base_engine_args_tmp = create_config(merge_configs(base_engine_args_tmp, stage_arg.engine_args))
+ if explicit_engine_args is not None:
+ explicit_engine_args_tmp = explicit_engine_args.copy()
if prefer_stage_engine_args:
- merged_engine_args = merge_configs(base_engine_args_tmp, stage_arg.engine_args)
+ merged_engine_args = merge_configs(explicit_engine_args_tmp, base_engine_args_tmp)
else:
- merged_engine_args = merge_configs(stage_arg.engine_args, base_engine_args_tmp)
+ merged_engine_args = merge_configs(base_engine_args_tmp, explicit_engine_args_tmp)
base_engine_args_tmp = create_config(merged_engine_args)
stage_type = getattr(stage_arg, "stage_type", "llm")
if hasattr(stage_arg, "runtime") and stage_arg.runtime is not None and stage_type != "diffusion":
@@ -462,6 +521,8 @@ def load_and_resolve_stage_configs(
Returns:
Tuple of (config_path, stage_configs)
"""
+ kwargs = kwargs or {}
+
if stage_configs_path is None:
config_path = resolve_model_config_path(model)
stage_configs = load_stage_configs_from_model(model, base_engine_args=kwargs)