Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions tests/entrypoints/openai_api/test_image_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def test_client(mock_async_diffusion):
[BaseModelPath(name="Qwen/Qwen-Image", model_path="Qwen/Qwen-Image")]
)
app.state.args = Namespace(
default_sampling_params='{"0": {"num_inference_steps":4, "guidance_scale":7.5}}',
default_sampling_params='{"0": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}',
max_generated_image_size=1024 * 1792,
)

Expand All @@ -200,7 +200,7 @@ def async_omni_test_client():
SimpleNamespace(stage_type="diffusion"),
]
app.state.args = Namespace(
default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}',
default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}',
max_generated_image_size=1048576, # 1024*1024 to support resolution tests
)
return TestClient(app)
Expand All @@ -222,7 +222,7 @@ def async_omni_rgba_test_client():
SimpleNamespace(stage_type="diffusion"),
]
app.state.args = Namespace(
default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}',
default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}',
max_generated_image_size=1048576,
)
return TestClient(app)
Expand All @@ -244,7 +244,7 @@ def async_omni_stage_configs_only_client():
# Intentionally do not populate app.state.stage_configs. Refactored
# AsyncOmni exposes stage_configs on the engine instance.
app.state.args = Namespace(
default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}',
default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}',
max_generated_image_size=1024 * 1792,
)
return TestClient(app)
Expand Down Expand Up @@ -960,6 +960,7 @@ def test_image_edit_parameter_default(async_omni_test_client):
assert captured_sampling_params.num_outputs_per_prompt == 1
assert captured_sampling_params.num_inference_steps == 4
assert captured_sampling_params.guidance_scale == 7.5
assert captured_sampling_params.generator_device == "cpu"

# Test that a size exceeding max_generated_image_size returns 400
response = async_omni_test_client.post(
Expand Down Expand Up @@ -993,6 +994,7 @@ def test_image_edit_parameter_default_single_stage(test_client):
assert captured_sampling_params.num_outputs_per_prompt == 1
assert captured_sampling_params.num_inference_steps == 4
assert captured_sampling_params.guidance_scale == 7.5
assert captured_sampling_params.generator_device == "cpu"

# Size exceeding max_generated_image_size (1024*1792) returns 400
response = test_client.post(
Expand Down
14 changes: 14 additions & 0 deletions tests/entrypoints/test_async_omni_diffusion_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ def test_default_stage_config_propagates_ulysses_mode():
assert parallel_config.ulysses_mode == "advanced_uaa"


def test_default_stage_config_includes_default_sampling_params():
"""Ensure default sampling params survive the default diffusion-stage builder."""
stage_cfg = AsyncOmniEngine._create_default_diffusion_stage_cfg(
{
"default_sampling_params": '{"0": {"generator_device":"cpu", "guidance_scale":7.5}}',
}
)[0]

assert stage_cfg["default_sampling_params"] == {
"generator_device": "cpu",
"guidance_scale": 7.5,
}


def test_serve_cli_accepts_ulysses_mode():
"""Ensure diffusion serve CLI exposes ulysses_mode and wires it to parallel_config."""
parser = FlexibleArgumentParser()
Expand Down
11 changes: 11 additions & 0 deletions vllm_omni/engine/async_omni_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,16 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list:
# We temporally create a default config for diffusion stage.
# In the future, we should merge the default config with the user-provided config.
normalized_kwargs = dict(kwargs)
default_sampling_params = normalized_kwargs.get("default_sampling_params")
if isinstance(default_sampling_params, str):
try:
default_sampling_params = json.loads(default_sampling_params)
except json.JSONDecodeError:
logger.warning("Invalid default_sampling_params JSON, ignoring stage defaults.")
default_sampling_params = None
if not isinstance(default_sampling_params, dict):
default_sampling_params = None
stage_default_sampling_params = default_sampling_params.get("0", {}) if default_sampling_params else {}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So here the logic is, for multiple stages model, we take stage-0 as defaulte params for all stages. Right?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not exactly.

This code path is only used when we synthesize the fallback single-stage diffusion config in _create_default_diffusion_stage_cfg(). In that case, there is only one generated stage, and its stage_id is always 0, so reading default_sampling_params["0"] is intentional.

For multi-stage models, we do not go through this helper. We load the resolved stage configs directly, and each stage keeps its own default_sampling_params, which are later read from that stage’s config during metadata extraction. So this change does not apply stage-0 defaults to all stages; it only preserves stage-0 defaults for the synthetic single-stage diffusion path.


# TODO: hack, convert dtype to string to avoid non-premitive omegaconf create error.
if "dtype" in normalized_kwargs and not isinstance(normalized_kwargs["dtype"], str):
Expand Down Expand Up @@ -1234,6 +1244,7 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list:
"devices": devices,
},
"engine_args": stage_engine_args,
"default_sampling_params": stage_default_sampling_params,
"final_output": True,
"final_output_type": "image",
}
Expand Down
Loading