diff --git a/tests/entrypoints/test_async_omni_diffusion_config.py b/tests/entrypoints/test_async_omni_diffusion_config.py index 83b465fdb47..063e06094b8 100644 --- a/tests/entrypoints/test_async_omni_diffusion_config.py +++ b/tests/entrypoints/test_async_omni_diffusion_config.py @@ -121,6 +121,46 @@ def test_default_stage_config_includes_default_sampling_params(): } +def test_default_stage_config_engine_args(): + """Ensure default diffusion-stage builder sets and propagates engine_args.""" + stage_cfg = AsyncOmniEngine._create_default_diffusion_stage_cfg( + { + "distributed_executor_backend": "ray", + "boundary_ratio": 0.875, + "flow_shift": 5.0, + "trust_remote_code": True, + } + )[0] + + engine_args = stage_cfg["engine_args"] + assert engine_args["distributed_executor_backend"] == "ray" + assert engine_args["boundary_ratio"] == 0.875 + assert engine_args["flow_shift"] == 5.0 + assert engine_args["trust_remote_code"] is True + + +def test_default_stage_config_whitelist_none_fallback(): + """DeployConfig / StageDeployConfig whitelist fields with value None + fall back to OmniDiffusionConfig dataclass defaults.""" + stage_cfg = AsyncOmniEngine._create_default_diffusion_stage_cfg( + { + # DeployConfig pipeline-wide + "trust_remote_code": None, + "distributed_executor_backend": None, + "dtype": None, + # StageDeployConfig + "enforce_eager": None, + } + )[0] + + engine_args = stage_cfg["engine_args"] + + assert engine_args["trust_remote_code"] is False + assert engine_args["distributed_executor_backend"] == "mp" + assert engine_args["dtype"] == "auto" + assert engine_args["enforce_eager"] is False + + def test_serve_cli_accepts_ulysses_mode(): """Ensure diffusion serve CLI exposes ulysses_mode and wires it to parallel_config.""" parser = FlexibleArgumentParser() diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 66096de7c0c..1483210cdf1 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -1346,6 +1346,10 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: "diffusion_load_format": kwargs.get("diffusion_load_format", "default"), "custom_pipeline_args": kwargs.get("custom_pipeline_args", None), "worker_extension_cls": kwargs.get("worker_extension_cls", None), + "trust_remote_code": (False if kwargs.get("trust_remote_code") is None else kwargs["trust_remote_code"]), + "distributed_executor_backend": ( + "mp" if kwargs.get("distributed_executor_backend") is None else kwargs["distributed_executor_backend"] + ), "enable_sleep_mode": kwargs.get("enable_sleep_mode", False), "enable_multithread_weight_load": kwargs.get("enable_multithread_weight_load", True), "num_weight_load_threads": kwargs.get("num_weight_load_threads", 4),