Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions tests/entrypoints/test_async_omni_diffusion_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,46 @@ def test_default_stage_config_includes_default_sampling_params():
}


def test_default_stage_config_engine_args():
"""Ensure default diffusion-stage builder sets and propagates engine_args."""
stage_cfg = AsyncOmniEngine._create_default_diffusion_stage_cfg(
{
"distributed_executor_backend": "ray",
"boundary_ratio": 0.875,
"flow_shift": 5.0,
"trust_remote_code": True,
}
)[0]

engine_args = stage_cfg["engine_args"]
assert engine_args["distributed_executor_backend"] == "ray"
assert engine_args["boundary_ratio"] == 0.875
assert engine_args["flow_shift"] == 5.0
assert engine_args["trust_remote_code"] is True


def test_default_stage_config_whitelist_none_fallback():
"""DeployConfig / StageDeployConfig whitelist fields with value None
fall back to OmniDiffusionConfig dataclass defaults."""
stage_cfg = AsyncOmniEngine._create_default_diffusion_stage_cfg(
{
# DeployConfig pipeline-wide
"trust_remote_code": None,
"distributed_executor_backend": None,
"dtype": None,
# StageDeployConfig
"enforce_eager": None,
}
)[0]

engine_args = stage_cfg["engine_args"]

assert engine_args["trust_remote_code"] is False
assert engine_args["distributed_executor_backend"] == "mp"
assert engine_args["dtype"] == "auto"
assert engine_args["enforce_eager"] is False


def test_serve_cli_accepts_ulysses_mode():
"""Ensure diffusion serve CLI exposes ulysses_mode and wires it to parallel_config."""
parser = FlexibleArgumentParser()
Expand Down
4 changes: 4 additions & 0 deletions vllm_omni/engine/async_omni_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,6 +1346,10 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list:
"diffusion_load_format": kwargs.get("diffusion_load_format", "default"),
"custom_pipeline_args": kwargs.get("custom_pipeline_args", None),
"worker_extension_cls": kwargs.get("worker_extension_cls", None),
"trust_remote_code": (False if kwargs.get("trust_remote_code") is None else kwargs["trust_remote_code"]),
"distributed_executor_backend": (
"mp" if kwargs.get("distributed_executor_backend") is None else kwargs["distributed_executor_backend"]
),
"enable_sleep_mode": kwargs.get("enable_sleep_mode", False),
"enable_multithread_weight_load": kwargs.get("enable_multithread_weight_load", True),
"num_weight_load_threads": kwargs.get("num_weight_load_threads", 4),
Expand Down
Loading