Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion tests/entrypoints/openai_api/test_image_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def async_omni_test_client():
SimpleNamespace(stage_type="diffusion"),
]
app.state.args = Namespace(
default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}',
default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generative-device":"cuda"}}',
max_generated_image_size=1048576, # 1024*1024 to support resolution tests
)
return TestClient(app)
Expand Down Expand Up @@ -347,6 +347,36 @@ def test_generate_images_async_omni_sampling_params(async_omni_test_client):
assert captured[1].seed == 7


def test_generate_images_async_omni_default_generator_device(async_omni_test_client):
    """Check that the generation endpoint honours generator_device from --default-sampling-params.

    The request body below carries no device setting of its own; the test then
    inspects the sampling params captured by the engine client.
    """
    payload = {"prompt": "a cat", "n": 1, "size": "256x256"}
    response = async_omni_test_client.post("/v1/images/generations", json=payload)
    assert response.status_code == 200

    captured = async_omni_test_client.app.state.engine_client.captured_sampling_params_list
    assert captured is not None
    # Entry 1 is the one inspected here; presumably it lines up with stage key "1"
    # in the fixture's default_sampling_params -- verify against the fixture.
    assert captured[1].generator_device == "cuda"


def test_apply_stage_default_sampling_params_accepts_generator_device_alias():
    """Check that a hyphenated ``generator-device`` default ends up on ``generator_device``."""
    from vllm_omni.entrypoints.openai.api_server import apply_stage_default_sampling_params

    # Arrange: stage "1" defaults spelled with the hyphenated key.
    stage_key = "1"
    defaults_json = '{"1": {"generator-device": "cuda"}}'
    sampling_params = OmniDiffusionSamplingParams()

    # Act: apply the defaults for that stage onto the fresh params object.
    apply_stage_default_sampling_params(defaults_json, sampling_params, stage_key)

    # Assert: the value is readable through the underscore attribute name.
    assert sampling_params.generator_device == "cuda"


def test_generate_images_async_omni_stage_configs_only(async_omni_stage_configs_only_client):
"""Regression: image generation accepts refactored AsyncOmni without stage_list."""
response = async_omni_stage_configs_only_client.post(
Expand Down
27 changes: 24 additions & 3 deletions vllm_omni/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,6 +1287,22 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request)
prompt["negative_prompt"] = request.negative_prompt
gen_params = OmniDiffusionSamplingParams(num_outputs_per_prompt=request.n)

# Apply stage defaults from --default-sampling-params, if provided.
app_state_args = getattr(raw_request.app.state, "args", None)
default_sample_param = getattr(app_state_args, "default_sampling_params", None)
diffusion_stage_ids = [i for i, cfg in enumerate(stage_configs) if get_stage_type(cfg) == "diffusion"]
if not diffusion_stage_ids:
raise HTTPException(
status_code=HTTPStatus.SERVICE_UNAVAILABLE.value,
detail="No diffusion stage found in multi-stage pipeline.",
)
diffusion_stage_id = diffusion_stage_ids[0]
apply_stage_default_sampling_params(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This now lets --default-sampling-params override the request-level n, because num_outputs_per_prompt is only set in the constructor before apply_stage_default_sampling_params(). The image edits path re-applies n after defaults, so can we do the same here and add a regression test for request-level n precedence?

default_sample_param,
gen_params,
str(diffusion_stage_id),
)

# Parse per-request LoRA (compatible with chat's extra_body.lora shape).
lora_request, lora_scale = _parse_lora_request(request.lora)
_update_if_not_none(gen_params, "lora_request", lora_request)
Expand All @@ -1300,7 +1316,6 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request)
else:
size_str = "model default"

app_state_args = getattr(raw_request.app.state, "args", None)
_check_max_generated_image_size(app_state_args, width, height)

_update_if_not_none(gen_params, "width", width)
Expand Down Expand Up @@ -1867,13 +1882,19 @@ def apply_stage_default_sampling_params(
sampling_params: The sampling parameters object to update
stage_key: The stage ID/key in the pipeline
"""
param_aliases = {
"generator-device": "generator_device",
"generative-device": "generator_device",
"generative_device": "generator_device",
}
if default_params_json is not None:
default_params_dict = json.loads(default_params_json)
if stage_key in default_params_dict:
stage_defaults = default_params_dict[stage_key]
for param_name, param_value in stage_defaults.items():
if hasattr(sampling_params, param_name):
setattr(sampling_params, param_name, param_value)
normalized_name = param_aliases.get(param_name, param_name.replace("-", "_"))
if hasattr(sampling_params, normalized_name):
setattr(sampling_params, normalized_name, param_value)


def _resolve_video_runtime_context(raw_request: Request) -> tuple[str | None, list[Any] | None]:
Expand Down