diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index 4b38692da33..cd1e83393b0 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -177,7 +177,7 @@ def test_client(mock_async_diffusion): [BaseModelPath(name="Qwen/Qwen-Image", model_path="Qwen/Qwen-Image")] ) app.state.args = Namespace( - default_sampling_params='{"0": {"num_inference_steps":4, "guidance_scale":7.5}}', + default_sampling_params='{"0": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}', max_generated_image_size=1024 * 1792, ) @@ -200,7 +200,7 @@ def async_omni_test_client(): SimpleNamespace(stage_type="diffusion"), ] app.state.args = Namespace( - default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}', + default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}', max_generated_image_size=1048576, # 1024*1024 to support resolution tests ) return TestClient(app) @@ -222,7 +222,7 @@ def async_omni_rgba_test_client(): SimpleNamespace(stage_type="diffusion"), ] app.state.args = Namespace( - default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}', + default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}', max_generated_image_size=1048576, ) return TestClient(app) @@ -244,7 +244,7 @@ def async_omni_stage_configs_only_client(): # Intentionally do not populate app.state.stage_configs. Refactored # AsyncOmni exposes stage_configs on the engine instance. app.state.args = Namespace( - default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}', + default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}', max_generated_image_size=1024 * 1792, ) return TestClient(app) @@ -960,6 +960,7 @@ def test_image_edit_parameter_default(async_omni_test_client): assert captured_sampling_params.num_outputs_per_prompt == 1 assert captured_sampling_params.num_inference_steps == 4 assert captured_sampling_params.guidance_scale == 7.5 + assert captured_sampling_params.generator_device == "cpu" # Test that a size exceeding max_generated_image_size returns 400 response = async_omni_test_client.post( @@ -993,6 +994,7 @@ def test_image_edit_parameter_default_single_stage(test_client): assert captured_sampling_params.num_outputs_per_prompt == 1 assert captured_sampling_params.num_inference_steps == 4 assert captured_sampling_params.guidance_scale == 7.5 + assert captured_sampling_params.generator_device == "cpu" # Size exceeding max_generated_image_size (1024*1792) returns 400 response = test_client.post( diff --git a/tests/entrypoints/test_async_omni_diffusion_config.py b/tests/entrypoints/test_async_omni_diffusion_config.py index a55eaf05b9c..7ed8128260e 100644 --- a/tests/entrypoints/test_async_omni_diffusion_config.py +++ b/tests/entrypoints/test_async_omni_diffusion_config.py @@ -69,6 +69,20 @@ def test_default_stage_config_propagates_ulysses_mode(): assert parallel_config.ulysses_mode == "advanced_uaa" +def test_default_stage_config_includes_default_sampling_params(): + """Ensure default sampling params survive the default diffusion-stage builder.""" + stage_cfg = AsyncOmniEngine._create_default_diffusion_stage_cfg( + { + "default_sampling_params": '{"0": {"generator_device":"cpu", "guidance_scale":7.5}}', + } + )[0] + + assert stage_cfg["default_sampling_params"] == { + "generator_device": "cpu", + "guidance_scale": 7.5, + } + + def test_serve_cli_accepts_ulysses_mode(): """Ensure diffusion serve CLI exposes ulysses_mode and wires it to parallel_config.""" parser = FlexibleArgumentParser() diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 23a85e9f5f2..01c11cb9603 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -1138,6 +1138,16 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: # We temporally create a default config for diffusion stage. # In the future, we should merge the default config with the user-provided config. normalized_kwargs = dict(kwargs) + default_sampling_params = normalized_kwargs.get("default_sampling_params") + if isinstance(default_sampling_params, str): + try: + default_sampling_params = json.loads(default_sampling_params) + except json.JSONDecodeError: + logger.warning("Invalid default_sampling_params JSON, ignoring stage defaults.") + default_sampling_params = None + if not isinstance(default_sampling_params, dict): + default_sampling_params = None + stage_default_sampling_params = default_sampling_params.get("0", {}) if default_sampling_params else {} # TODO: hack, convert dtype to string to avoid non-premitive omegaconf create error. if "dtype" in normalized_kwargs and not isinstance(normalized_kwargs["dtype"], str): @@ -1234,6 +1244,7 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: "devices": devices, }, "engine_args": stage_engine_args, + "default_sampling_params": stage_default_sampling_params, "final_output": True, "final_output_type": "image", }