Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions tests/entrypoints/openai_api/test_image_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,19 @@ def test_model_field_omitted_works(test_client):
assert response.status_code == 200


def test_generate_images_rejects_model_mismatch(test_client):
    """Image generation must answer 400 when the request names another model.

    Posts to ``/v1/images/generations`` with ``model`` set to
    ``Qwen/Qwen-Image-2512`` and checks both the status code and that the
    error detail mentions a model mismatch (case-insensitive).

    NOTE(review): assumes the ``test_client`` fixture serves a model other
    than ``Qwen/Qwen-Image-2512`` -- confirm against the fixture definition.
    """
    payload = {
        "prompt": "test",
        "model": "Qwen/Qwen-Image-2512",
        "size": "1024x1024",
    }
    resp = test_client.post("/v1/images/generations", json=payload)

    assert resp.status_code == 400
    error_detail = resp.json()["detail"]
    assert "model mismatch" in error_detail.lower()


def make_test_image_bytes(size=(64, 64)) -> bytes:
img = Image.new(
"RGB",
Expand Down Expand Up @@ -782,6 +795,20 @@ def test_image_edit_rejects_multiple_images_when_model_does_not_support_them(asy
assert engine.captured_prompt is None


def test_image_edit_rejects_model_mismatch(test_client):
    """Image edit must answer 400 when the form names another model.

    Uploads a 16x16 test image to ``/v1/images/edits`` with ``model`` set to
    ``Qwen/Qwen-Image-Edit`` and checks both the status code and that the
    error detail mentions a model mismatch (case-insensitive).

    NOTE(review): assumes the ``test_client`` fixture serves a model other
    than ``Qwen/Qwen-Image-Edit`` -- confirm against the fixture definition.
    """
    upload = make_test_image_bytes((16, 16))
    form_fields = {
        "prompt": "edit me",
        "model": "Qwen/Qwen-Image-Edit",
    }
    resp = test_client.post(
        "/v1/images/edits",
        files=[("image", upload)],
        data=form_fields,
    )

    assert resp.status_code == 400
    error_detail = resp.json()["detail"]
    assert "model mismatch" in error_detail.lower()


def test_image_edit_parameter_pass(async_omni_test_client):
img_bytes_1 = make_test_image_bytes((16, 16))

Expand Down
12 changes: 12 additions & 0 deletions tests/entrypoints/openai_api/test_video_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,18 @@ def test_missing_prompt_returns_422(test_client):
assert response.status_code == 422


def test_video_generation_rejects_model_mismatch(test_client):
    """Video generation must answer 400 when the form names another model.

    Posts form data to ``/v1/videos`` with ``model`` set to
    ``Wan-AI/Wan2.1-T2V-14B-Diffusers`` and checks both the status code and
    that the error detail mentions a model mismatch (case-insensitive).

    NOTE(review): assumes the ``test_client`` fixture serves a model other
    than ``Wan-AI/Wan2.1-T2V-14B-Diffusers`` -- confirm against the fixture.
    """
    form_fields = {
        "prompt": "bad model",
        "model": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    }
    resp = test_client.post("/v1/videos", data=form_fields)

    assert resp.status_code == 400
    error_detail = resp.json()["detail"]
    assert "model mismatch" in error_detail.lower()


def test_invalid_size_parse_returns_422(test_client):
response = test_client.post(
"/v1/videos",
Expand Down
22 changes: 12 additions & 10 deletions vllm_omni/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1304,11 +1304,10 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request)
# Get engine client (AsyncOmni) from app state
engine_client, model_name, stage_configs = _get_engine_and_model(raw_request)

# Validate model field (reject with HTTP 400 on mismatch)
if request.model is not None and request.model != model_name:
logger.warning(
f"Model mismatch: request specifies '{request.model}' but "
f"server is running '{model_name}'. Using server model."
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail=(f"Model mismatch: request specifies '{request.model}' but server is running '{model_name}'."),
)

try:
Expand Down Expand Up @@ -1446,8 +1445,9 @@ async def edit_images(
# 1. get engine and model
engine_client, model_name, stage_configs = _get_engine_and_model(raw_request)
if model is not None and model != model_name:
logger.warning(
f"Model mismatch: request specifies '{model}' but server is running '{model_name}'. Using server model."
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail=(f"Model mismatch: request specifies '{model}' but server is running '{model_name}'."),
)
# 2. get output format & compression
output_format = _choose_output_format(output_format, background)
Expand Down Expand Up @@ -2136,10 +2136,12 @@ async def _parse_video_form(
app_model_name, app_stage_configs = _resolve_video_runtime_context(raw_request)
effective_model_name = handler.model_name or app_model_name or request.model or "unknown"

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `or request.model` fallback silently defeats the new hard-error guard. When both `handler.model_name` and `app_model_name` are `None` (which is possible if `OmniOpenAIServingVideo` was constructed via `__init__` with the default `model_name=None`, e.g. not via `for_diffusion`), `effective_model_name` becomes `request.model`, so the check `request.model != effective_model_name` is always `False` — no 400 is ever raised.

The image endpoints don't have this hole: `_get_engine_and_model` uses `"unknown"` as the fallback, not the client's model string.

If the intent is a hard reject on mismatch, pull the model name directly from the handler/app and error explicitly when it's unset, rather than falling back to what the client sent:

Suggested change
effective_model_name = handler.model_name or app_model_name or request.model or "unknown"
effective_model_name = handler.model_name or app_model_name
if effective_model_name is None:
raise HTTPException(
status_code=HTTPStatus.SERVICE_UNAVAILABLE.value,
detail="Server model name is not configured.",
)

Then keep the existing mismatch check below unchanged.

if request.model is not None and effective_model_name is not None and request.model != effective_model_name:
Comment on lines 2137 to 2138

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `or request.model or "unknown"` fallback defeats the validation. When both `handler.model_name` and `app_model_name` are `None`, `effective_model_name` falls through to `request.model`, so `request.model != effective_model_name` is always `False` — the 400 below can never fire. The image endpoints don't have this hole because they compare directly against `model_name` from `_get_engine_and_model`.

Suggested change
effective_model_name = handler.model_name or app_model_name or request.model or "unknown"
if request.model is not None and effective_model_name is not None and request.model != effective_model_name:
effective_model_name = handler.model_name or app_model_name
if request.model is not None and effective_model_name is not None and request.model != effective_model_name:

logger.warning(
"Model mismatch: request specifies '%s' but server is running '%s'. Using server model.",
request.model,
effective_model_name,
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail=(
f"Model mismatch: request specifies '{request.model}' but server is running "
f"'{effective_model_name}'."
),
)
handler.set_stage_configs_if_missing(app_stage_configs)
except HTTPException:
Expand Down
Loading