diff --git a/benchmarks/diffusion/quantization_quality.py b/benchmarks/diffusion/quantization_quality.py index 4a916e7ea62..85c6d784459 100644 --- a/benchmarks/diffusion/quantization_quality.py +++ b/benchmarks/diffusion/quantization_quality.py @@ -58,6 +58,8 @@ import numpy as np import torch +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides + def compute_lpips_images( baseline_images: list, @@ -137,6 +139,7 @@ def _build_omni_kwargs(args, quantization=None): ) kwargs = { "model": args.model, + "explicit_overrides": getattr(args, "_explicit_overrides", None), "parallel_config": parallel_config, "enforce_eager": args.enforce_eager, } @@ -452,7 +455,7 @@ def parse_args(): parser.add_argument("--ring-degree", type=int, default=1) parser.add_argument("--tensor-parallel-size", type=int, default=1) parser.add_argument("--enforce-eager", action="store_true") - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) if __name__ == "__main__": diff --git a/examples/offline_inference/bagel/end2end.py b/examples/offline_inference/bagel/end2end.py index 472d748d1e6..9bd2d9a91ff 100644 --- a/examples/offline_inference/bagel/end2end.py +++ b/examples/offline_inference/bagel/end2end.py @@ -1,6 +1,7 @@ import argparse import os +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni.inputs.data import OmniPromptType from vllm_omni.model_executor.stage_input_processors.bagel import ( GEN_THINK_SYSTEM_PROMPT, @@ -98,7 +99,7 @@ def parse_args(): help="Enable thinking mode: AR stage decodes ... planning tokens before image generation.", ) - args = parser.parse_args() + args = parse_args_with_explicit_overrides(parser) return args @@ -152,6 +153,7 @@ def main(): ) if args.quantization: omni_kwargs["quantization_config"] = args.quantization + omni_kwargs["explicit_overrides"] = getattr(args, "_explicit_overrides", None) omni = Omni(model=model_name, **omni_kwargs) diff --git a/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py b/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py index 68ab72b3870..fc792b314d8 100644 --- a/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py +++ b/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py @@ -10,6 +10,7 @@ from vllm import SamplingParams from vllm.assets.audio import AudioAsset +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni.entrypoints.omni import Omni from vllm_omni.model_executor.models.cosyvoice3.config import CosyVoice3Config from vllm_omni.model_executor.models.cosyvoice3.tokenizer import get_qwen_tokenizer @@ -55,7 +56,7 @@ def run_e2e(): required=True, help="Path to tokenizer directory (e.g., /CosyVoice-BlankEN).", ) - args = parser.parse_args() + args = parse_args_with_explicit_overrides(parser) _ensure_mel_filters_asset() # Ensure tokenizer directory exists if not os.path.exists(args.tokenizer): @@ -72,6 +73,7 @@ def run_e2e(): # We pass trust_remote_code=True same as Qwen examples omni = Omni( model=args.model, + explicit_overrides=getattr(args, "_explicit_overrides", None), stage_configs_path=args.stage_config, trust_remote_code=True, tokenizer=args.tokenizer, diff --git a/examples/offline_inference/custom_pipeline/image_to_image/image_edit.py b/examples/offline_inference/custom_pipeline/image_to_image/image_edit.py index 8ab5e0d9a6c..8dd4f02d3c7 100644 --- a/examples/offline_inference/custom_pipeline/image_to_image/image_edit.py +++ b/examples/offline_inference/custom_pipeline/image_to_image/image_edit.py @@ -52,6 +52,7 @@ from PIL import Image from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput @@ -100,7 +101,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--vae-use-tiling", action="store_true") parser.add_argument("--enable-cpu-offload", action="store_true") - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) # =========================== @@ -149,6 +150,7 @@ async def main(): # ---- Initialize Omni ---- omni = Omni( model=args.model, + explicit_overrides=getattr(args, "_explicit_overrides", None), vae_use_slicing=args.vae_use_slicing, vae_use_tiling=args.vae_use_tiling, cache_backend=args.cache_backend, diff --git a/examples/offline_inference/dynin_omni/end2end.py b/examples/offline_inference/dynin_omni/end2end.py index 66047934d51..dc7f0793fb0 100644 --- a/examples/offline_inference/dynin_omni/end2end.py +++ b/examples/offline_inference/dynin_omni/end2end.py @@ -18,6 +18,8 @@ import torch from PIL import Image +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides + TASK_CHOICES = ("t2t", "t2i", "t2s", "i2i", "i2t", "s2t", "v2t") TASK_DEFAULT_RUNTIME = { @@ -970,7 +972,7 @@ def parse_args(repo_root: Path) -> argparse.Namespace: parser.add_argument("--vq-model-audio-local-files-only", action=argparse.BooleanOptionalAction, default=None) parser.add_argument("--disable-hf-xet", action=argparse.BooleanOptionalAction, default=True) - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) def main() -> None: @@ -1395,7 +1397,12 @@ def main() -> None: from vllm_omni.entrypoints.omni import Omni stage_config_path = str(Path(args.stage_config_path).expanduser()) - omni = Omni(model=model_source, stage_configs_path=stage_config_path, dtype=args.dtype) + omni = Omni( + model=model_source, + explicit_overrides=getattr(args, "_explicit_overrides", None), + stage_configs_path=stage_config_path, + dtype=args.dtype, + ) sampling_params_list = [ SamplingParams(max_tokens=int(args.max_tokens_per_stage), temperature=0.0, top_p=1.0, detokenize=False) for _ in range(omni.num_stages) diff --git a/examples/offline_inference/fish_speech/end2end.py b/examples/offline_inference/fish_speech/end2end.py index 31c24d3d5d6..5db70ef38c3 100644 --- a/examples/offline_inference/fish_speech/end2end.py +++ b/examples/offline_inference/fish_speech/end2end.py @@ -29,6 +29,7 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni import AsyncOmni, Omni from vllm_omni.model_executor.models.fish_speech.dac_utils import DAC_HOP_LENGTH, DAC_SAMPLE_RATE from vllm_omni.model_executor.models.fish_speech.prompt_utils import ( @@ -149,6 +150,7 @@ def main(args): omni = Omni( model=model_name, + explicit_overrides=getattr(args, "_explicit_overrides", None), stage_configs_path=stage_configs_path, log_stats=args.log_stats, stage_init_timeout=args.stage_init_timeout, @@ -185,6 +187,7 @@ async def main_streaming(args): omni = AsyncOmni( model=model_name, + explicit_overrides=getattr(args, "_explicit_overrides", None), stage_configs_path=stage_configs_path, log_stats=args.log_stats, stage_init_timeout=args.stage_init_timeout, @@ -273,7 +276,7 @@ def parse_args(): default=False, help="Stream audio chunks as they arrive via AsyncOmni.", ) - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) if __name__ == "__main__": diff --git a/examples/offline_inference/glm_image/end2end.py b/examples/offline_inference/glm_image/end2end.py index 13bcd23f55a..be4a5205734 100644 --- a/examples/offline_inference/glm_image/end2end.py +++ b/examples/offline_inference/glm_image/end2end.py @@ -44,6 +44,7 @@ from PIL import Image +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniDiffusionSamplingParams @@ -260,6 +261,7 @@ def main(args: argparse.Namespace) -> None: omni = Omni( model=args.model_path, + explicit_overrides=getattr(args, "_explicit_overrides", None), stage_configs_path=config_path, log_stats=args.enable_stats, stage_init_timeout=args.stage_init_timeout, @@ -503,7 +505,7 @@ def parse_args() -> argparse.Namespace: help="Enable diffusion pipeline profiler to display stage durations.", ) - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) if __name__ == "__main__": diff --git a/examples/offline_inference/helios/end2end.py b/examples/offline_inference/helios/end2end.py index 88c3b865d42..afcc77b4ab3 100644 --- a/examples/offline_inference/helios/end2end.py +++ b/examples/offline_inference/helios/end2end.py @@ -52,6 +52,7 @@ import torch from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput @@ -196,7 +197,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--cfg-parallel-size", type=int, default=1, choices=[1, 2], help="CFG parallel size.") parser.add_argument("--tensor-parallel-size", type=int, default=1, help="Tensor parallelism size.") - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) def main(): @@ -212,6 +213,7 @@ def main(): omni = Omni( model=args.model, + explicit_overrides=getattr(args, "_explicit_overrides", None), enable_layerwise_offload=args.enable_layerwise_offload, vae_use_slicing=args.vae_use_slicing, vae_use_tiling=args.vae_use_tiling, diff --git a/examples/offline_inference/hunyuan_image3/image_to_text.py b/examples/offline_inference/hunyuan_image3/image_to_text.py index d40134ac0a0..0d64f064c06 100644 --- a/examples/offline_inference/hunyuan_image3/image_to_text.py +++ b/examples/offline_inference/hunyuan_image3/image_to_text.py @@ -6,6 +6,7 @@ from PIL import Image +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni.entrypoints.omni import Omni """ @@ -46,7 +47,7 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Enable diffusion pipeline profiler to display stage durations.", ) - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) def load_image(image_path: str) -> Image.Image: @@ -59,6 +60,7 @@ def load_image(image_path: str) -> Image.Image: def main(args: argparse.Namespace) -> None: omni = Omni( model=args.model, + explicit_overrides=getattr(args, "_explicit_overrides", None), enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler, mode="image-to-text", ) diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py index a8035a3fdcb..f5e706609a7 100644 --- a/examples/offline_inference/image_to_image/image_edit.py +++ b/examples/offline_inference/image_to_image/image_edit.py @@ -95,6 +95,7 @@ from PIL import Image from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput @@ -330,7 +331,7 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Enable diffusion pipeline profiler to display stage durations.", ) - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) def main(): @@ -386,6 +387,7 @@ def main(): # Initialize Omni with appropriate pipeline omni = Omni( model=args.model, + explicit_overrides=getattr(args, "_explicit_overrides", None), enable_layerwise_offload=args.enable_layerwise_offload, vae_use_slicing=args.vae_use_slicing, vae_use_tiling=args.vae_use_tiling, diff --git a/examples/offline_inference/image_to_video/image_to_video.py b/examples/offline_inference/image_to_video/image_to_video.py index 7e7cfbf84e8..2633c73fe20 100644 --- a/examples/offline_inference/image_to_video/image_to_video.py +++ b/examples/offline_inference/image_to_video/image_to_video.py @@ -42,6 +42,7 @@ import torch from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput @@ -187,7 +188,7 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Enable diffusion pipeline profiler to display stage durations.", ) - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) def calculate_dimensions( @@ -281,6 +282,7 @@ def main(): ) omni = Omni( model=args.model, + explicit_overrides=getattr(args, "_explicit_overrides", None), enable_layerwise_offload=args.enable_layerwise_offload, vae_use_slicing=args.vae_use_slicing, vae_use_tiling=args.vae_use_tiling, diff --git a/examples/offline_inference/magi_human/end2end.py b/examples/offline_inference/magi_human/end2end.py index 64f11c4658b..4cfa5d90bd9 100644 --- a/examples/offline_inference/magi_human/end2end.py +++ b/examples/offline_inference/magi_human/end2end.py @@ -1,6 +1,7 @@ import argparse from vllm_omni.diffusion.utils.media_utils import mux_video_audio_bytes +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniDiffusionSamplingParams @@ -24,7 +25,7 @@ def parse_args(): parser.add_argument("--width", type=int, default=448, help="Video width.") parser.add_argument("--num-inference-steps", type=int, default=8, help="Number of denoising steps.") parser.add_argument("--seed", type=int, default=52, help="Random seed for generation.") - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) def main(): @@ -33,6 +34,7 @@ def main(): print(f"Initializing MagiHuman pipeline with TP={args.tensor_parallel_size}...") omni = Omni( model=args.model, + explicit_overrides=getattr(args, "_explicit_overrides", None), init_timeout=1200, tensor_parallel_size=args.tensor_parallel_size, devices=list(range(args.tensor_parallel_size)), diff --git a/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summarize.py b/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summarize.py index ca87b9e9a94..5c8623d3098 100644 --- a/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summarize.py +++ b/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summarize.py @@ -19,6 +19,7 @@ from vllm.multimodal.image import convert_image_mode from vllm_omni import Omni +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides DEFAULT_SYSTEM = "You are a helpful assistant." DEFAULT_QUESTION = "Please summarize the content of this image." @@ -48,7 +49,7 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Enable diffusion pipeline profiler to display stage durations.", ) - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) def build_prompt(system: str, question: str) -> str: @@ -73,6 +74,7 @@ def main() -> None: omni = Omni( model=args.model, + explicit_overrides=getattr(args, "_explicit_overrides", None), stage_configs_path=args.stage_config, trust_remote_code=args.trust_remote_code, enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler, diff --git a/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py b/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py index a4c41fee1f8..4a18feaa844 100644 --- a/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py +++ b/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py @@ -29,6 +29,7 @@ from vllm.sampling_params import SamplingParams from vllm_omni import Omni +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) @@ -117,7 +118,7 @@ def parse_args() -> argparse.Namespace: ) p.add_argument("--out", type=str, default="output.png", help="Path to save the generated image.") p.add_argument("--trust-remote-code", action="store_true", help="Trust remote code when loading the model.") - args = p.parse_args() + args = parse_args_with_explicit_overrides(p) if not args.prompt: args.prompt = ["A stylish woman with sunglasses riding a motorcycle in NYC."] return args @@ -194,7 +195,12 @@ def main() -> None: expected_grid_tokens = ar_height * (ar_width + 1) logger.info("Initializing Omni pipeline...") - omni = Omni(model=args.model, stage_configs_path=args.stage_config, trust_remote_code=args.trust_remote_code) + omni = Omni( + model=args.model, + explicit_overrides=getattr(args, "_explicit_overrides", None), + stage_configs_path=args.stage_config, + trust_remote_code=args.trust_remote_code, + ) try: ar_sampling = SamplingParams( temperature=1.0, diff --git a/examples/offline_inference/mimo_audio/end2end.py b/examples/offline_inference/mimo_audio/end2end.py index ae044d2e8a1..163568f1e41 100644 --- a/examples/offline_inference/mimo_audio/end2end.py +++ b/examples/offline_inference/mimo_audio/end2end.py @@ -23,6 +23,7 @@ from vllm import SamplingParams from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniTokensPrompt @@ -182,6 +183,7 @@ def main(args): omni = Omni( model=model_name, + explicit_overrides=getattr(args, "_explicit_overrides", None), stage_configs_path=args.stage_configs_path, log_stats=args.enable_stats, log_file=("omni_pipeline.log" if args.enable_stats else None), @@ -434,7 +436,7 @@ def parse_args(): help="Path to a stage configs file.", ) - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) if __name__ == "__main__": diff --git a/examples/offline_inference/omnivoice/end2end.py b/examples/offline_inference/omnivoice/end2end.py index b41379b011a..4678f0d4809 100644 --- a/examples/offline_inference/omnivoice/end2end.py +++ b/examples/offline_inference/omnivoice/end2end.py @@ -21,6 +21,7 @@ import numpy as np import soundfile as sf +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniDiffusionSamplingParams @@ -79,7 +80,7 @@ def run_e2e(): default=600, help="Stage initialization timeout in seconds", ) - args = parser.parse_args() + args = parse_args_with_explicit_overrides(parser) if not os.path.exists(args.stage_config): raise FileNotFoundError(f"Stage config not found: {args.stage_config}") @@ -88,6 +89,7 @@ def run_e2e(): omni = Omni( model=args.model, + explicit_overrides=getattr(args, "_explicit_overrides", None), stage_configs_path=args.stage_config, trust_remote_code=True, log_stats=True, diff --git a/examples/offline_inference/qwen2_5_omni/end2end.py b/examples/offline_inference/qwen2_5_omni/end2end.py index 7bba5998308..33f73ab17f3 100644 --- a/examples/offline_inference/qwen2_5_omni/end2end.py +++ b/examples/offline_inference/qwen2_5_omni/end2end.py @@ -20,6 +20,7 @@ from vllm.sampling_params import SamplingParams from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni.entrypoints.omni import Omni SEED = 42 @@ -322,6 +323,7 @@ def main(args): query_result = query_func() omni = Omni( model=model_name, + explicit_overrides=getattr(args, "_explicit_overrides", None), log_stats=args.log_stats, stage_init_timeout=args.stage_init_timeout, batch_timeout=args.batch_timeout, @@ -540,7 +542,7 @@ def parse_args(): default=False, help="Use py_generator mode. The returned type of Omni.generate() is a Python Generator object.", ) - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) if __name__ == "__main__": diff --git a/examples/offline_inference/qwen3_omni/end2end.py b/examples/offline_inference/qwen3_omni/end2end.py index 155eca4ed9f..1f6352c2070 100644 --- a/examples/offline_inference/qwen3_omni/end2end.py +++ b/examples/offline_inference/qwen3_omni/end2end.py @@ -21,6 +21,7 @@ from vllm.multimodal.image import convert_image_mode from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni.entrypoints.omni import Omni SEED = 42 @@ -296,6 +297,7 @@ def main(args): omni = Omni( model=model_name, + explicit_overrides=getattr(args, "_explicit_overrides", None), dtype=args.dtype, stage_configs_path=args.stage_configs_path, log_stats=args.log_stats, @@ -557,7 +559,7 @@ def parse_args(): help="Model dtype (auto, half, float16, bfloat16, float, float32).", ) - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) if __name__ == "__main__": diff --git a/examples/offline_inference/qwen3_omni/end2end_async_chunk.py b/examples/offline_inference/qwen3_omni/end2end_async_chunk.py index 8adbae9eb66..60723215fe7 100644 --- a/examples/offline_inference/qwen3_omni/end2end_async_chunk.py +++ b/examples/offline_inference/qwen3_omni/end2end_async_chunk.py @@ -41,6 +41,7 @@ from vllm.multimodal.image import convert_image_mode from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni.entrypoints.async_omni import AsyncOmni logger = logging.getLogger(__name__) @@ -379,6 +380,7 @@ async def run_all(args): try: async_omni = AsyncOmni( model=args.model, + explicit_overrides=getattr(args, "_explicit_overrides", None), stage_configs_path=args.stage_configs_path, log_stats=args.log_stats, stage_init_timeout=args.stage_init_timeout, @@ -584,7 +586,7 @@ def parse_args(): default=16000, help="Sampling rate for audio loading.", ) - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) if __name__ == "__main__": diff --git a/examples/offline_inference/qwen3_tts/end2end.py b/examples/offline_inference/qwen3_tts/end2end.py index 901418c39b8..e56b516fa6e 100644 --- a/examples/offline_inference/qwen3_tts/end2end.py +++ b/examples/offline_inference/qwen3_tts/end2end.py @@ -17,6 +17,7 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni import AsyncOmni, Omni logger = logging.getLogger(__name__) @@ -368,6 +369,7 @@ def main(args): omni = Omni( model=model_name, + explicit_overrides=getattr(args, "_explicit_overrides", None), stage_configs_path=args.stage_configs_path, log_stats=args.log_stats, stage_init_timeout=args.stage_init_timeout, @@ -389,6 +391,7 @@ async def main_streaming(args): omni = AsyncOmni( model=model_name, + explicit_overrides=getattr(args, "_explicit_overrides", None), stage_configs_path=args.stage_configs_path, log_stats=args.log_stats, stage_init_timeout=args.stage_init_timeout, @@ -534,7 +537,7 @@ def parse_args(): help="Number of prompts per batch (default: 1, sequential).", ) - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) if __name__ == "__main__": diff --git a/examples/offline_inference/text_to_audio/text_to_audio.py b/examples/offline_inference/text_to_audio/text_to_audio.py index a6968c419f6..29688984a3b 100644 --- a/examples/offline_inference/text_to_audio/text_to_audio.py +++ b/examples/offline_inference/text_to_audio/text_to_audio.py @@ -20,6 +20,7 @@ import numpy as np import torch +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.platforms import current_omni_platform @@ -95,7 +96,7 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Enable diffusion pipeline profiler to display stage durations.", ) - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) def save_audio(audio_data: np.ndarray, output_path: str, sample_rate: int = 44100): @@ -140,6 +141,7 @@ def main(): # Initialize Omni with Stable Audio model omni = Omni( model=args.model, + explicit_overrides=getattr(args, "_explicit_overrides", None), enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler, ) diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 615e4067ed8..41bc5953c52 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -10,6 +10,7 @@ import torch from vllm_omni.diffusion.data import DiffusionParallelConfig, logger +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.lora.request import LoRARequest @@ -255,7 +256,7 @@ def parse_args() -> argparse.Namespace: default=None, help=("Custom system prompt. Used when --use-system-prompt is custom. "), ) - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) def main(): @@ -334,6 +335,7 @@ def main(): omni_kwargs = { "model": args.model, + "explicit_overrides": getattr(args, "_explicit_overrides", None), "enable_layerwise_offload": args.enable_layerwise_offload, "vae_use_slicing": args.vae_use_slicing, "vae_use_tiling": args.vae_use_tiling, diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py index 83925cc458a..430217bacfe 100644 --- a/examples/offline_inference/text_to_video/text_to_video.py +++ b/examples/offline_inference/text_to_video/text_to_video.py @@ -10,6 +10,7 @@ import torch from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput @@ -185,7 +186,7 @@ def parse_args() -> argparse.Namespace: choices=["fp8", "gguf"], help="Quantization method for the transformer (fp8 for online FP8 quantization).", ) - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) def main(): @@ -229,6 +230,7 @@ def main(): omni_kwargs = dict( model=args.model, + explicit_overrides=getattr(args, "_explicit_overrides", None), enable_layerwise_offload=args.enable_layerwise_offload, vae_use_slicing=args.vae_use_slicing, vae_use_tiling=args.vae_use_tiling, diff --git a/examples/offline_inference/vace/vace_video_generation.py b/examples/offline_inference/vace/vace_video_generation.py index 6ca0d74c52e..963814cb4f3 100644 --- a/examples/offline_inference/vace/vace_video_generation.py +++ b/examples/offline_inference/vace/vace_video_generation.py @@ -34,6 +34,7 @@ import torch from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.platforms import current_omni_platform @@ -71,7 +72,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--ulysses-degree", type=int, default=1, help="Ulysses SP degree.") parser.add_argument("--ring-degree", type=int, default=1, help="Ring attention degree.") parser.add_argument("--cfg-parallel-size", type=int, default=1, choices=[1, 2], help="CFG parallel size.") - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) def build_prompts(args): @@ -155,6 +156,7 @@ def main(): omni = Omni( model=args.model, + explicit_overrides=getattr(args, "_explicit_overrides", None), vae_use_tiling=args.vae_use_tiling, flow_shift=args.flow_shift, enforce_eager=args.enforce_eager, diff --git a/examples/offline_inference/voxtral_tts/end2end.py b/examples/offline_inference/voxtral_tts/end2end.py index 0750246450a..08948be09ec 100644 --- a/examples/offline_inference/voxtral_tts/end2end.py +++ b/examples/offline_inference/voxtral_tts/end2end.py @@ -29,6 +29,7 @@ from vllm import SamplingParams from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni import AsyncOmni from vllm_omni.entrypoints.omni import Omni @@ -39,6 +40,7 @@ async def run_streaming(inputs, sampling_params_list, model_name, args, output_dir): async_omni = AsyncOmni( model=model_name, + explicit_overrides=getattr(args, "_explicit_overrides", None), stage_configs_path=args.stage_configs_path, log_stats=args.log_stats, ) @@ -191,6 +193,7 @@ async def _generate_one(batch_idx, single_input): def run_non_streaming(inputs, sampling_params_list, model_name, args, output_dir): llm = Omni( model=model_name, + explicit_overrides=getattr(args, "_explicit_overrides", None), log_stats=args.log_stats, stage_configs_path=args.stage_configs_path, ) @@ -297,7 +300,7 @@ def parse_args() -> Namespace: default=None, help="Voice to use instead of audio file.", ) - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) def compose_request( diff --git a/examples/offline_inference/x_to_video_audio/x_to_video_audio.py b/examples/offline_inference/x_to_video_audio/x_to_video_audio.py index e0424add69b..c28a65404b9 100644 --- a/examples/offline_inference/x_to_video_audio/x_to_video_audio.py +++ b/examples/offline_inference/x_to_video_audio/x_to_video_audio.py @@ -9,6 +9,7 @@ from PIL import Image from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniDiffusionSamplingParams @@ -56,7 +57,7 @@ def parse_args() -> argparse.Namespace: default=False, help="Enable CPU offloading for diffusion models.", ) - return parser.parse_args() + return parse_args_with_explicit_overrides(parser) def load_image_and_audio(image_paths, audio_paths): @@ -121,6 +122,7 @@ def main() -> None: omni = Omni( model=args.model, + explicit_overrides=getattr(args, "_explicit_overrides", None), parallel_config=parallel_config, model_type=args.model_type, enable_cpu_offload=args.enable_cpu_offload, diff --git a/tests/entrypoints/test_utils.py b/tests/entrypoints/test_utils.py index 6e44fe533c2..103278aef60 100644 --- a/tests/entrypoints/test_utils.py +++ b/tests/entrypoints/test_utils.py @@ -1,8 +1,11 @@ """Unit tests for vllm_omni.entrypoints.utils module.""" +import argparse import os from collections import Counter from dataclasses import dataclass +from pathlib import Path +from textwrap import dedent import pytest import torch @@ -16,6 +19,7 @@ _filter_dict_like_object, filter_dataclass_kwargs, load_and_resolve_stage_configs, + parse_args_with_explicit_overrides, resolve_model_config_path, ) @@ -309,6 +313,28 @@ def mock_exists(path): assert "glm_image.yaml" in result +class TestParseArgsWithExplicitStageOverrides: + def test_only_explicit_cli_values_are_captured(self): + parser = argparse.ArgumentParser() + parser.add_argument("--foo", default="default-foo") + parser.add_argument("--bar", type=int, default=3) + + args = parse_args_with_explicit_overrides(parser, ["--foo", "cli-foo"]) + + assert args.foo == "cli-foo" + assert args.bar == 3 + assert args._explicit_overrides == {"foo": "cli-foo"} + + def test_equals_style_option_is_supported(self): + parser = argparse.ArgumentParser() + parser.add_argument("--tensor-parallel-size", dest="tensor_parallel_size", type=int, default=1) + + args = parse_args_with_explicit_overrides(parser, ["--tensor-parallel-size=4"]) + + assert args.tensor_parallel_size == 4 + assert args._explicit_overrides == {"tensor_parallel_size": 4} + + class TestLoadAndResolveStageConfigs: def test_load_and_resolve_with_kwargs(self): """Ensure that dtype survives default stage creation.""" @@ -322,3 +348,134 @@ def test_load_and_resolve_with_kwargs(self): assert config_path is None assert len(stage_configs) == 1 assert "dtype" in stage_configs[0]["engine_args"] + + def test_explicit_yaml_overrides_explicit_overrides(self, tmp_path: Path): + config_path = tmp_path / "stage.yaml" + config_path.write_text( + dedent( + """ + stage_args: + - stage_id: 0 + stage_type: diffusion + runtime: + process: true + devices: "0" + engine_args: + tensor_parallel_size: 1 + """ + ).strip(), + encoding="utf-8", + ) + + _, stage_configs = load_and_resolve_stage_configs( + model="dummy-model", + stage_configs_path=str(config_path), + kwargs={ + "tensor_parallel_size": 4, + "explicit_overrides": {"tensor_parallel_size": 4}, + }, + ) + + assert stage_configs[0].engine_args.tensor_parallel_size == 1 + + def test_default_yaml_uses_explicit_overrides_instead_of_full_kwargs(self, mocker: MockerFixture): + config_path = "/tmp/default-stage.yaml" + stage_cfg = [ + { + "stage_id": 0, + "stage_type": "diffusion", + "runtime": {"process": True, "devices": "0"}, + "engine_args": { + "tensor_parallel_size": 1, + "distributed_executor_backend": "mp", + }, + } + ] + + mocker.patch( + "vllm_omni.entrypoints.utils.resolve_model_config_path", + return_value=config_path, + ) + mocker.patch( + "vllm_omni.entrypoints.utils.load_yaml_config", + side_effect=lambda path: __import__( + "vllm_omni.config.yaml_util", fromlist=["create_config"] + ).create_config({"stage_args": stage_cfg}), + ) + + _, stage_configs = load_and_resolve_stage_configs( + model="dummy-model", + stage_configs_path=None, + kwargs={ + "distributed_executor_backend": None, + "tensor_parallel_size": 8, + "explicit_overrides": {"tensor_parallel_size": 8}, + }, + ) + + assert stage_configs[0].engine_args.tensor_parallel_size == 8 + assert stage_configs[0].engine_args.distributed_executor_backend == "mp" + + def test_default_yaml_plain_kwargs_do_not_override_defaults(self, mocker: MockerFixture): + config_path = "/tmp/default-stage.yaml" + stage_cfg = [ + { + "stage_id": 0, + "stage_type": "diffusion", + "runtime": {"process": True, "devices": "0"}, + "engine_args": {"tensor_parallel_size": 1}, + } + ] + + mocker.patch( + "vllm_omni.entrypoints.utils.resolve_model_config_path", + return_value=config_path, + ) + mocker.patch( + "vllm_omni.entrypoints.utils.load_yaml_config", + side_effect=lambda path: __import__( + "vllm_omni.config.yaml_util", fromlist=["create_config"] + ).create_config({"stage_args": stage_cfg}), + ) + + _, stage_configs = load_and_resolve_stage_configs( + model="dummy-model", + stage_configs_path=None, + kwargs={"tensor_parallel_size": 8}, + ) + + assert stage_configs[0].engine_args.tensor_parallel_size == 1 + + def test_parser_defaults_still_fill_missing_keys_under_default_yaml(self, mocker: MockerFixture): + config_path = "/tmp/default-stage.yaml" + stage_cfg = [ + { + "stage_id": 0, + "stage_type": "diffusion", + "runtime": {"process": True, "devices": "0"}, + "engine_args": {"tensor_parallel_size": 1}, + } + ] + + mocker.patch( + "vllm_omni.entrypoints.utils.resolve_model_config_path", + return_value=config_path, + ) + mocker.patch( + "vllm_omni.entrypoints.utils.load_yaml_config", + side_effect=lambda path: __import__( + "vllm_omni.config.yaml_util", fromlist=["create_config"] + ).create_config({"stage_args": stage_cfg}), + ) + + _, stage_configs = load_and_resolve_stage_configs( + model="dummy-model", + stage_configs_path=None, + kwargs={ + "pipeline_parallel_size": 2, + "explicit_overrides": {}, + }, + ) + + assert stage_configs[0].engine_args.tensor_parallel_size == 1 + assert stage_configs[0].engine_args.pipeline_parallel_size == 2 diff --git a/vllm_omni/entrypoints/cli/main.py b/vllm_omni/entrypoints/cli/main.py index affa6c83349..cbbbd496622 100644 --- a/vllm_omni/entrypoints/cli/main.py +++ b/vllm_omni/entrypoints/cli/main.py @@ -20,6 +20,7 @@ def main(): import vllm_omni.entrypoints.cli.benchmark.main import vllm_omni.entrypoints.cli.serve + from vllm_omni.entrypoints.utils import parse_args_with_explicit_overrides CMD_MODULES = [ vllm_omni.entrypoints.cli.serve, @@ -49,7 +50,7 @@ def main(): for cmd in new_cmds: cmd.subparser_init(subparsers).set_defaults(dispatch_function=cmd.cmd) cmds[cmd.name] = cmd - args = parser.parse_args() + args = parse_args_with_explicit_overrides(parser) if args.subparser in cmds: cmds[args.subparser].validate(args) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 0706b98987c..87b5ddb083c 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -430,8 +430,9 @@ async def build_async_omni_from_stage_config( async_omni: EngineClient | None = None try: - kwargs = vars(args).copy() + kwargs = {k: v for k, v in vars(args).items() if not k.startswith("_")} kwargs.pop("model", None) + kwargs["explicit_overrides"] = getattr(args, "_explicit_overrides", None) async_omni = AsyncOmni(model=args.model, **kwargs) # # Don't keep the dummy data in memory diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py index c5e49a93364..0a4725c3ed0 100644 --- a/vllm_omni/entrypoints/utils.py +++ b/vllm_omni/entrypoints/utils.py @@ -1,4 +1,6 @@ +import argparse import os +import sys import types from collections import Counter from dataclasses import fields, is_dataclass @@ -23,6 +25,53 @@ } +def parse_args_with_explicit_overrides( + parser: argparse.ArgumentParser, + argv: list[str] | None = None, +) -> argparse.Namespace: + """Parse CLI args and attach explicit override kwargs onto the namespace.""" + if argv is None: + argv = sys.argv[1:] + + option_to_dest: dict[str, str] = {} + for action in parser._actions: + for option_string in action.option_strings: + option_to_dest[option_string] = action.dest + + choices = getattr(action, "choices", None) + if isinstance(choices, dict): + for choice in choices.values(): + if not isinstance(choice, argparse.ArgumentParser): + continue + for sub_action in choice._actions: + for option_string in sub_action.option_strings: + option_to_dest[option_string] = sub_action.dest + + explicit_dests: set[str] = set() + for arg in argv: + if arg == "--": + break + if not arg.startswith("-") or arg == "-": + continue + option = arg.split("=", 1)[0] + dest = option_to_dest.get(option) + if dest is not None: + explicit_dests.add(dest) + + args = parser.parse_args(argv) + explicit_overrides: dict[str, Any] = {} + for dest in explicit_dests: + if dest.startswith("_") or not hasattr(args, dest): + continue + value = getattr(args, dest) + if callable(value): + continue + explicit_overrides[dest] = value + + setattr(args, "_explicit_overrides", explicit_overrides) + return args + + def inject_omni_kv_config(stage: Any, omni_conn_cfg: dict[str, Any], omni_from: str, omni_to: str) -> None: """Inject connector configuration into stage engine arguments.""" # Prepare omni_kv_config dict @@ -320,28 +369,38 @@ def load_stage_configs_from_yaml( Args: config_path: Path to the YAML configuration file base_engine_args: Engine args supplied by the caller. - prefer_stage_engine_args: When True, YAML stage args override caller - engine args. When False, caller engine args override YAML defaults. + prefer_stage_engine_args: Controls precedence between YAML stage args + and explicit_overrides. Plain caller kwargs only fill keys missing + from the YAML and never override YAML defaults. Returns: List of stage configuration dictionaries from the file's stage_args """ if base_engine_args is None: base_engine_args = {} + base_engine_args = dict(base_engine_args) + explicit_overrides = base_engine_args.pop("explicit_overrides", None) config_data = load_yaml_config(config_path) stage_args = config_data.stage_args global_async_chunk = config_data.get("async_chunk", False) - # Convert any nested dataclass objects to dicts before creating DictConfig - base_engine_args = _convert_dataclasses_to_dict(base_engine_args) - base_engine_args = create_config(base_engine_args) + # Convert nested dataclass objects before creating DictConfigs. + base_engine_args = create_config(_convert_dataclasses_to_dict(base_engine_args)) + explicit_engine_args = ( + create_config(_convert_dataclasses_to_dict(explicit_overrides)) if explicit_overrides else None + ) for stage_arg in stage_args: base_engine_args_tmp = base_engine_args.copy() - # Update base_engine_args with stage-specific engine_args if they exist + # Plain caller kwargs only provide fallback values; stage YAML always + # wins over them. Only explicit_overrides participate in conflict + # resolution against the YAML defaults below. if hasattr(stage_arg, "engine_args") and stage_arg.engine_args is not None: + base_engine_args_tmp = create_config(merge_configs(base_engine_args_tmp, stage_arg.engine_args)) + if explicit_engine_args is not None: + explicit_engine_args_tmp = explicit_engine_args.copy() if prefer_stage_engine_args: - merged_engine_args = merge_configs(base_engine_args_tmp, stage_arg.engine_args) + merged_engine_args = merge_configs(explicit_engine_args_tmp, base_engine_args_tmp) else: - merged_engine_args = merge_configs(stage_arg.engine_args, base_engine_args_tmp) + merged_engine_args = merge_configs(base_engine_args_tmp, explicit_engine_args_tmp) base_engine_args_tmp = create_config(merged_engine_args) stage_type = getattr(stage_arg, "stage_type", "llm") if hasattr(stage_arg, "runtime") and stage_arg.runtime is not None and stage_type != "diffusion": @@ -462,6 +521,8 @@ def load_and_resolve_stage_configs( Returns: Tuple of (config_path, stage_configs) """ + kwargs = kwargs or {} + if stage_configs_path is None: config_path = resolve_model_config_path(model) stage_configs = load_stage_configs_from_model(model, base_engine_args=kwargs)