Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# SPDX-License-Identifier: Apache-2.0

import json
from pathlib import Path
from types import SimpleNamespace

import numpy as np
import pytest
from PIL import Image

from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit_plus import (
get_qwen_image_edit_plus_pre_process_func,
)

pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu]


def test_qwen_image_edit_plus_rejects_too_many_input_images(tmp_path: Path):
vae_dir = tmp_path / "vae"
vae_dir.mkdir()
# Keep the mock config intentionally minimal: this test only needs the
# fields touched during pre-process initialization.
(vae_dir / "config.json").write_text(json.dumps({"z_dim": 16}))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This mock VAE config is extremely minimal. Add a brief comment indicating it is intentionally minimal. This helps future maintainers understand why the test might break if get_qwen_image_edit_plus_pre_process_func starts reading more keys at initialization.


pre_process = get_qwen_image_edit_plus_pre_process_func(SimpleNamespace(model=str(tmp_path)))
image = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))
request = SimpleNamespace(
prompts=[
{
"prompt": "combine",
"multi_modal_data": {"image": [image, image, image, image, image]},
}
],
sampling_params=SimpleNamespace(height=None, width=None),
)

with pytest.raises(ValueError, match=r"At most 4 images are supported by this model"):
pre_process(request)
57 changes: 57 additions & 0 deletions tests/entrypoints/openai_api/test_image_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,63 @@ def test_image_edit_rejects_multiple_images_when_model_does_not_support_them(asy
assert engine.captured_prompt is None


def test_image_edit_rejects_too_many_images_for_qwen_image_edit_2511(async_omni_test_client):
    """Five uploads against a four-image limit are rejected with HTTP 400.

    The engine must not have captured any prompt for the rejected request.
    """

    def _od_config():
        # Advertise multi-image support capped at four inputs.
        return SimpleNamespace(
            supports_multimodal_inputs=True,
            max_multimodal_image_inputs=4,
        )

    engine = async_omni_test_client.app.state.engine_client
    engine.get_diffusion_od_config = _od_config

    # One upload more than the configured limit allows.
    uploads = [("image", make_test_image_bytes((16, 16))) for _ in range(5)]
    response = async_omni_test_client.post(
        "/v1/images/edits",
        files=uploads,
        data={"prompt": "hello world."},
    )

    assert response.status_code == 400
    assert response.json()["detail"] == "Received 5 input images. At most 4 images are supported by this model."
    assert engine.captured_prompt is None


def test_image_edit_rejects_too_many_images_for_qwen_image_edit_2511_before_loading(
    async_omni_test_client, monkeypatch: pytest.MonkeyPatch
):
    """An over-limit edit request is rejected without `_load_input_images` running.

    `_load_input_images` is replaced with a stub that raises, so the expected
    400 response can only be produced if the limit check happens first.
    """
    import vllm_omni.entrypoints.openai.api_server as api_server_module

    def _od_config():
        # Advertise multi-image support capped at four inputs.
        return SimpleNamespace(
            supports_multimodal_inputs=True,
            max_multimodal_image_inputs=4,
        )

    def _fail_load(*args, **kwargs):
        raise AssertionError("_load_input_images should not run for over-limit requests")

    engine = async_omni_test_client.app.state.engine_client
    engine.get_diffusion_od_config = _od_config
    monkeypatch.setattr(api_server_module, "_load_input_images", _fail_load)

    # One upload more than the configured limit allows.
    uploads = [("image", make_test_image_bytes((16, 16))) for _ in range(5)]
    response = async_omni_test_client.post(
        "/v1/images/edits",
        files=uploads,
        data={"prompt": "hello world."},
    )

    assert response.status_code == 400
    assert response.json()["detail"] == "Received 5 input images. At most 4 images are supported by this model."
    assert engine.captured_prompt is None


def test_image_edit_parameter_pass(async_omni_test_client):
img_bytes_1 = make_test_image_bytes((16, 16))

Expand Down
10 changes: 10 additions & 0 deletions tests/test_diffusion_config_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
DiffusionParallelConfig,
OmniDiffusionConfig,
)
from vllm_omni.diffusion.model_metadata import QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]

Expand Down Expand Up @@ -109,3 +110,12 @@ def test_extra_kwargs_forwarded(self):
ea = stages[0]["engine_args"]
assert ea["enforce_eager"] is True
assert ea["lora_path"] == "/tmp/lora"


def test_qwen_image_edit_plus_sets_generic_multimodal_limit():
    """`update_multimodal_support` fills both generic multimodal fields for this pipeline class."""
    config = OmniDiffusionConfig(
        model="Qwen/Qwen-Image-Edit-2511",
        model_class_name="QwenImageEditPlusPipeline",
    )

    config.update_multimodal_support()

    # The limit is compared against the shared constant rather than a literal
    # so the test follows the metadata module if the value changes.
    assert config.supports_multimodal_inputs is True
    assert config.max_multimodal_image_inputs == QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES
11 changes: 9 additions & 2 deletions vllm_omni/diffusion/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
QuantizationConfig,
)

from vllm_omni.diffusion.model_metadata import get_diffusion_model_metadata
from vllm_omni.diffusion.utils.network_utils import is_port_available
from vllm_omni.quantization import build_quant_config

Expand Down Expand Up @@ -481,8 +482,10 @@ class OmniDiffusionConfig:
# Scheduler flow_shift for Wan2.2 (12.0 for 480p, 5.0 for 720p)
flow_shift: float | None = None

# support multi images input
# Support multi-image inputs and expose any model-specific request limit
# through a generic config field so serving code stays model-agnostic.
supports_multimodal_inputs: bool = False
max_multimodal_image_inputs: int | None = None

log_level: str = "info"

Expand Down Expand Up @@ -664,7 +667,11 @@ def set_tf_model_config(self, tf_config: "TransformerConfig") -> None:
)

def update_multimodal_support(self) -> None:
    """Refresh the multimodal capability fields from the shared model metadata.

    Looks up ``self.model_class_name`` with ``get_diffusion_model_metadata``
    and copies the result onto ``supports_multimodal_inputs`` and
    ``max_multimodal_image_inputs``.
    """
    # Resolve serving-visible multimodal behavior from shared metadata
    # instead of importing concrete pipeline modules into the config layer.
    # The metadata lookup is the only source for both fields; no class-name
    # check is hard-coded here.
    metadata = get_diffusion_model_metadata(self.model_class_name)
    self.supports_multimodal_inputs = metadata.supports_multimodal_inputs
    self.max_multimodal_image_inputs = metadata.max_multimodal_image_inputs

def enrich_config(self) -> None:
"""Load model metadata from HuggingFace and populate config fields.
Expand Down
31 changes: 31 additions & 0 deletions vllm_omni/diffusion/model_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass


@dataclass(frozen=True)
class DiffusionModelMetadata:
# Keep serving-facing capability metadata in a lightweight shared module so
# config/model plumbing can read it without importing concrete pipelines.
supports_multimodal_inputs: bool = False
max_multimodal_image_inputs: int | None = None


# Maximum number of input images declared for QwenImageEditPlusPipeline.
QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES = 4


# Registry of capability metadata keyed by pipeline class name. Class names
# that are not listed here have no entry.
_DIFFUSION_MODEL_METADATA: dict[str, DiffusionModelMetadata] = {
    "QwenImageEditPlusPipeline": DiffusionModelMetadata(
        supports_multimodal_inputs=True,
        max_multimodal_image_inputs=QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES,
    ),
}


def get_diffusion_model_metadata(model_class_name: str | None) -> DiffusionModelMetadata:
    """Return the registered metadata for a pipeline class name.

    Args:
        model_class_name: Pipeline class name to look up, or ``None``.

    Returns:
        The registry entry for ``model_class_name`` when one exists; otherwise
        a fresh default ``DiffusionModelMetadata`` (no multimodal support, no
        image limit).
    """
    try:
        return _DIFFUSION_MODEL_METADATA[model_class_name]
    except KeyError:
        # Unknown models fall back to "no special multimodal capabilities" so
        # new pipelines do not accidentally inherit limits meant for other
        # models. ``None`` is never a registered key, so it lands here too.
        return DiffusionModelMetadata()
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.model_metadata import QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.qwen_image.cfg_parallel import (
QwenImageCFGParallelMixin,
Expand Down Expand Up @@ -56,6 +57,12 @@

CONDITION_IMAGE_SIZE = 384 * 384
VAE_IMAGE_SIZE = 1024 * 1024
# Upper bound on the number of input images this pipeline accepts per request.
# The limit reflects the practical conditioning-token budget for
# Qwen-Image-Edit-2511: empirically, 4 images stays within the supported range
# while 5 images overflows the prompt/conditioning path and fails downstream.
# The value itself is owned by vllm_omni.diffusion.model_metadata; this name is
# only a local alias that gives validation code and tests a nearby, descriptive
# constant. Change the limit in model_metadata, not here.
MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES = QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES


def get_qwen_image_edit_plus_pre_process_func(
Expand Down Expand Up @@ -93,6 +100,11 @@ def pre_process_func(

if not isinstance(raw_image, list):
raw_image = [raw_image]
if len(raw_image) > MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES:
raise ValueError(
f"Received {len(raw_image)} input images. "
f"At most {MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES} images are supported by this model."
)
image = [
PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image | np.ndarray | torch.Tensor, im)
for im in raw_image
Expand Down
33 changes: 26 additions & 7 deletions vllm_omni/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,12 +1473,24 @@ async def edit_images(
input_images_list.extend(urls)
if not input_images_list:
raise HTTPException(status_code=422, detail="Field 'image' or 'url' is required")
pil_images = await _load_input_images(input_images_list)
if len(pil_images) > 1 and not _supports_multimodal_image_inputs(raw_request, engine_client):
# Reject oversized multi-image edit requests before fetching or decoding
# any inputs. This keeps over-limit URL requests from burning network,
# CPU, and memory on work that will be rejected anyway.
max_input_images = _get_max_edit_input_images(raw_request, engine_client)
if max_input_images is not None and len(input_images_list) > max_input_images:
detail = (
"Received multiple input images. Only a single image is supported by this model."
if max_input_images == 1
else (
f"Received {len(input_images_list)} input images. "
f"At most {max_input_images} images are supported by this model."
)
)
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail="Received multiple input images. Only a single image is supported by this model.",
detail=detail,
)
pil_images = await _load_input_images(input_images_list)
prompt["multi_modal_data"] = {}
prompt["multi_modal_data"]["image"] = pil_images

Expand Down Expand Up @@ -1651,18 +1663,25 @@ def _get_engine_and_model(raw_request: Request):
return engine_client, model_name, normalized_stage_configs


def _get_diffusion_od_config(raw_request: Request, engine_client: Any) -> Any:
    """Return the diffusion config exposed on the serving surface, or ``None``.

    The engine is taken from ``raw_request.app.state.diffusion_engine`` when
    that attribute is set and truthy, otherwise ``engine_client`` is used. A
    callable ``get_diffusion_od_config`` on the engine is preferred; failing
    that, the engine's ``od_config`` attribute is returned, or ``None`` when
    neither is present.
    """
    diffusion_engine = getattr(raw_request.app.state, "diffusion_engine", None) or engine_client
    get_diffusion_od_config = getattr(diffusion_engine, "get_diffusion_od_config", None)
    if callable(get_diffusion_od_config):
        return get_diffusion_od_config()
    return getattr(diffusion_engine, "od_config", None)


def _supports_multimodal_image_inputs(raw_request: Request, engine_client: Any) -> bool:
    """Return whether the serving model accepts more than one input image.

    Returns ``True`` when no diffusion config is exposed, otherwise the
    truthiness of the config's ``supports_multimodal_inputs`` attribute
    (``False`` when the attribute is missing).
    """
    od_config = _get_diffusion_od_config(raw_request, engine_client)
    if od_config is None:
        # Preserve the existing compatibility behavior when the diffusion
        # config is not exposed on the serving surface.
        return True
    return bool(getattr(od_config, "supports_multimodal_inputs", False))


def _get_max_edit_input_images(raw_request: Request, engine_client: Any) -> int | None:
    """Return the maximum number of input images allowed for an edit request.

    Returns:
        ``None`` when no diffusion config is exposed or the config declares no
        limit; ``1`` when the config does not support multimodal inputs;
        otherwise the config's ``max_multimodal_image_inputs`` value.
    """
    od_config = _get_diffusion_od_config(raw_request, engine_client)
    if od_config is None:
        # Preserve the existing compatibility behavior when the diffusion
        # config is not exposed on the serving surface: apply no limit.
        return None

    # Models without multi-image support are limited to a single image.
    if not getattr(od_config, "supports_multimodal_inputs", False):
        return 1

    return getattr(od_config, "max_multimodal_image_inputs", None)


def _get_lora_from_json_str(lora_body):
Expand Down
Loading