Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions tests/entrypoints/openai_api/test_image_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,19 @@ def test_model_field_omitted_works(test_client):
assert response.status_code == 200


def test_generate_images_rejects_model_mismatch(test_client):
    """Image generation with an explicit ``model`` field must yield a 400 "model mismatch" error.

    Posts a JSON body to ``/v1/images/generations`` and checks both the status
    code and the ``detail`` field of the JSON error payload.
    NOTE(review): presumably the ``test_client`` fixture serves a model other
    than the one named here -- confirm against the fixture definition.
    """
    payload = {
        "prompt": "test",
        "model": "Qwen/Qwen-Image-2512",
        "size": "1024x1024",
    }

    response = test_client.post("/v1/images/generations", json=payload)

    assert response.status_code == 400
    # Lower-case the message so the check is independent of its capitalisation.
    error_detail = response.json()["detail"]
    assert "model mismatch" in error_detail.lower()


def make_test_image_bytes(size=(64, 64)) -> bytes:
img = Image.new(
"RGB",
Expand Down Expand Up @@ -954,6 +967,20 @@ def test_image_edit_rejects_multiple_images_when_model_does_not_support_them(asy
assert engine.captured_prompt is None


def test_image_edit_rejects_model_mismatch(test_client):
    """Image editing with an explicit ``model`` field must yield a 400 "model mismatch" error.

    Uploads a 16x16 test image as multipart form data to ``/v1/images/edits``
    and checks both the status code and the ``detail`` field of the JSON error
    payload.
    NOTE(review): presumably the ``test_client`` fixture serves a model other
    than the one named here -- confirm against the fixture definition.
    """
    upload = make_test_image_bytes((16, 16))
    form_fields = {
        "prompt": "edit me",
        "model": "Qwen/Qwen-Image-Edit",
    }

    response = test_client.post(
        "/v1/images/edits",
        files=[("image", upload)],
        data=form_fields,
    )

    assert response.status_code == 400
    # Lower-case the message so the check is independent of its capitalisation.
    error_detail = response.json()["detail"]
    assert "model mismatch" in error_detail.lower()


def test_image_edit_rejects_too_many_images_for_qwen_image_edit_2511(async_omni_test_client):
engine = async_omni_test_client.app.state.engine_client
engine.get_diffusion_od_config = lambda: SimpleNamespace(
Expand Down
12 changes: 12 additions & 0 deletions tests/entrypoints/openai_api/test_video_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,18 @@ def test_missing_prompt_returns_422(test_client):
assert response.status_code == 422


def test_video_generation_rejects_model_mismatch(test_client):
    """Video generation with an explicit ``model`` field must yield a 400 "model mismatch" error.

    Posts form data to ``/v1/videos`` and checks both the status code and the
    ``detail`` field of the JSON error payload.
    NOTE(review): presumably the ``test_client`` fixture serves a model other
    than the one named here -- confirm against the fixture definition.
    """
    form_fields = {
        "prompt": "bad model",
        "model": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    }

    response = test_client.post("/v1/videos", data=form_fields)

    assert response.status_code == 400
    # Lower-case the message so the check is independent of its capitalisation.
    error_detail = response.json()["detail"]
    assert "model mismatch" in error_detail.lower()


def test_invalid_size_parse_returns_422(test_client):
response = test_client.post(
"/v1/videos",
Expand Down
22 changes: 12 additions & 10 deletions vllm_omni/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,11 +1321,10 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request)
# Get engine client (AsyncOmni) from app state
engine_client, model_name, stage_configs = _get_engine_and_model(raw_request)

# Validate model field (warn if mismatch, don't error)
if request.model is not None and request.model != model_name:
logger.warning(
f"Model mismatch: request specifies '{request.model}' but "
f"server is running '{model_name}'. Using server model."
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail=(f"Model mismatch: request specifies '{request.model}' but server is running '{model_name}'."),
)

try:
Expand Down Expand Up @@ -1517,8 +1516,9 @@ async def edit_images(
# 1. get engine and model
engine_client, model_name, stage_configs = _get_engine_and_model(raw_request)
if model is not None and model != model_name:
logger.warning(
f"Model mismatch: request specifies '{model}' but server is running '{model_name}'. Using server model."
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail=(f"Model mismatch: request specifies '{model}' but server is running '{model_name}'."),
)
# 2. get output format & compression
output_format = _choose_output_format(output_format, background)
Expand Down Expand Up @@ -2234,10 +2234,12 @@ async def _parse_video_form(
app_model_name, app_stage_configs = _resolve_video_runtime_context(raw_request)
effective_model_name = handler.model_name or app_model_name or request.model or "unknown"
if request.model is not None and effective_model_name is not None and request.model != effective_model_name:
logger.warning(
"Model mismatch: request specifies '%s' but server is running '%s'. Using server model.",
request.model,
effective_model_name,
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail=(
f"Model mismatch: request specifies '{request.model}' but server is running "
f"'{effective_model_name}'."
),
)
handler.set_stage_configs_if_missing(app_stage_configs)
except HTTPException:
Expand Down
Loading