diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py
index 8695311fb5b..e8d23f4bbeb 100644
--- a/tests/entrypoints/openai_api/test_image_server.py
+++ b/tests/entrypoints/openai_api/test_image_server.py
@@ -194,7 +194,7 @@ def async_omni_test_client():
         SimpleNamespace(stage_type="diffusion"),
     ]
     app.state.args = Namespace(
-        default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}',
+        default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generative-device":"cuda"}}',
         max_generated_image_size=1048576,  # 1024*1024 to support resolution tests
     )
     return TestClient(app)
@@ -347,6 +347,36 @@ def test_generate_images_async_omni_sampling_params(async_omni_test_client):
     assert captured[1].seed == 7
 
 
+def test_generate_images_async_omni_default_generator_device(async_omni_test_client):
+    """Test --default-sampling-params can set generator_device for generation endpoint."""
+    response = async_omni_test_client.post(
+        "/v1/images/generations",
+        json={
+            "prompt": "a cat",
+            "n": 1,
+            "size": "256x256",
+        },
+    )
+    assert response.status_code == 200
+    engine = async_omni_test_client.app.state.engine_client
+    captured = engine.captured_sampling_params_list
+    assert captured is not None
+    assert captured[1].generator_device == "cuda"
+
+
+def test_apply_stage_default_sampling_params_accepts_generator_device_alias():
+    """Test hyphenated alias keys map to generator_device."""
+    from vllm_omni.entrypoints.openai.api_server import apply_stage_default_sampling_params
+
+    params = OmniDiffusionSamplingParams()
+    apply_stage_default_sampling_params(
+        '{"1": {"generator-device": "cuda"}}',
+        params,
+        "1",
+    )
+    assert params.generator_device == "cuda"
+
+
 def test_generate_images_async_omni_stage_configs_only(async_omni_stage_configs_only_client):
     """Regression: image generation accepts refactored AsyncOmni without stage_list."""
     response = async_omni_stage_configs_only_client.post(
diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py
index ceb87f11160..82af95d292e 100644
--- a/vllm_omni/entrypoints/openai/api_server.py
+++ b/vllm_omni/entrypoints/openai/api_server.py
@@ -1287,6 +1287,22 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request)
             prompt["negative_prompt"] = request.negative_prompt
         gen_params = OmniDiffusionSamplingParams(num_outputs_per_prompt=request.n)
 
+        # Apply stage defaults from --default-sampling-params, if provided.
+        app_state_args = getattr(raw_request.app.state, "args", None)
+        default_sample_param = getattr(app_state_args, "default_sampling_params", None)
+        diffusion_stage_ids = [i for i, cfg in enumerate(stage_configs) if get_stage_type(cfg) == "diffusion"]
+        if not diffusion_stage_ids:
+            raise HTTPException(
+                status_code=HTTPStatus.SERVICE_UNAVAILABLE.value,
+                detail="No diffusion stage found in multi-stage pipeline.",
+            )
+        diffusion_stage_id = diffusion_stage_ids[0]
+        apply_stage_default_sampling_params(
+            default_sample_param,
+            gen_params,
+            str(diffusion_stage_id),
+        )
+
         # Parse per-request LoRA (compatible with chat's extra_body.lora shape).
         lora_request, lora_scale = _parse_lora_request(request.lora)
         _update_if_not_none(gen_params, "lora_request", lora_request)
@@ -1300,7 +1316,6 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request)
         else:
             size_str = "model default"
 
-        app_state_args = getattr(raw_request.app.state, "args", None)
         _check_max_generated_image_size(app_state_args, width, height)
 
         _update_if_not_none(gen_params, "width", width)
@@ -1867,13 +1882,21 @@ def apply_stage_default_sampling_params(
         sampling_params: The sampling parameters object to update
         stage_key: The stage ID/key in the pipeline
     """
+    # Explicit aliases for generator_device. Keys not listed here fall back
+    # to a plain "-" -> "_" substitution when they are looked up below.
+    param_aliases = {
+        "generator-device": "generator_device",
+        "generative-device": "generator_device",
+        "generative_device": "generator_device",
+    }
     if default_params_json is not None:
         default_params_dict = json.loads(default_params_json)
         if stage_key in default_params_dict:
             stage_defaults = default_params_dict[stage_key]
             for param_name, param_value in stage_defaults.items():
-                if hasattr(sampling_params, param_name):
-                    setattr(sampling_params, param_name, param_value)
+                normalized_name = param_aliases.get(param_name, param_name.replace("-", "_"))
+                if hasattr(sampling_params, normalized_name):
+                    setattr(sampling_params, normalized_name, param_value)
 
 
 def _resolve_video_runtime_context(raw_request: Request) -> tuple[str | None, list[Any] | None]: