From b9d2ed1cdb9901cf163f9489f40167b3260a09a3 Mon Sep 17 00:00:00 2001 From: Lancer Date: Fri, 27 Mar 2026 12:46:10 +0800 Subject: [PATCH 1/5] [Bugfix] Fix image quality in /v1/images/generations for multi-stage pipeline Signed-off-by: Lancer --- .../openai_api/test_image_server.py | 53 ++- ...test_serving_chat_multistage_generation.py | 82 ++++ vllm_omni/entrypoints/openai/api_server.py | 46 ++ vllm_omni/entrypoints/openai/serving_chat.py | 430 ++++++++++++------ 4 files changed, 471 insertions(+), 140 deletions(-) create mode 100644 tests/entrypoints/openai_api/test_serving_chat_multistage_generation.py diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index b5d22dc09fd..656b3978aa2 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -106,10 +106,13 @@ def test_encode_image_base64(): class MockGenerationResult: - """Mock result object from AsyncOmniDiffusion.generate()""" + """Mock result object compatible with current diffusion output shape.""" def __init__(self, images): self.images = images + self.request_output = SimpleNamespace(images=images) + self.stage_durations = {} + self.peak_memory_mb = 0.0 class FakeAsyncOmni: @@ -117,8 +120,8 @@ class FakeAsyncOmni: def __init__(self): self.stage_configs = [ - SimpleNamespace(stage_type="llm"), - SimpleNamespace(stage_type="diffusion"), + SimpleNamespace(stage_type="llm", is_comprehension=True), + SimpleNamespace(stage_type="diffusion", is_comprehension=False), ] self.default_sampling_params_list = [SamplingParams(temperature=0.1), OmniDiffusionSamplingParams()] self.captured_sampling_params_list = None @@ -160,6 +163,7 @@ def test_client(mock_async_diffusion): from fastapi import FastAPI from vllm_omni.entrypoints.openai.api_server import router + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat app = FastAPI() app.include_router(router) @@ -183,11 +187,16 @@ def async_omni_test_client(): from fastapi import FastAPI from vllm_omni.entrypoints.openai.api_server import router + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat app = FastAPI() app.include_router(router) app.state.engine_client = FakeAsyncOmni() + chat_handler = object.__new__(OmniOpenAIServingChat) + chat_handler.engine_client = app.state.engine_client + chat_handler._diffusion_engine = None + app.state.openai_serving_chat = chat_handler app.state.stage_configs = [ SimpleNamespace(stage_type="llm"), SimpleNamespace(stage_type="diffusion"), @@ -345,6 +354,44 @@ def test_generate_images_async_omni_stage_configs_only(async_omni_stage_configs_ assert captured[1].seed == 11 +def test_multistage_images_async_omni_construction(async_omni_test_client, mocker: MockerFixture): + mocker.patch("vllm_omni.entrypoints.openai.serving_chat.AsyncOmni", FakeAsyncOmni) + """Regression: multistage image generation builds the expected chat-style payload.""" + response = async_omni_test_client.post( + "/v1/images/generations", + json={ + "prompt": "a cat", + "n": 2, + "size": "128x256", + "seed": 7, + "num_inference_steps": 12, + "guidance_scale": 6.5, + }, + ) + assert response.status_code == 200 + + engine = async_omni_test_client.app.state.engine_client + captured_prompt = engine.captured_prompt + assert captured_prompt["prompt"] == "a cat" + assert captured_prompt["modalities"] == ["image"] + assert captured_prompt["mm_processor_kwargs"] == { + "target_h": 256, + "target_w": 128, + } + + captured = engine.captured_sampling_params_list + assert captured is not None + assert len(captured) == 2 + assert captured[0].temperature == 0.1 + assert captured[0].seed == 7 + assert captured[1].num_outputs_per_prompt == 2 + assert captured[1].width == 128 + assert captured[1].height == 256 + assert captured[1].seed == 7 + assert captured[1].num_inference_steps == 12 + assert captured[1].guidance_scale == 6.5 + + def test_image_edits_async_omni_stage_configs_only(async_omni_stage_configs_only_client): """Regression: image edits accepts refactored AsyncOmni without stage_list.""" img_bytes = make_test_image_bytes((16, 16)) diff --git a/tests/entrypoints/openai_api/test_serving_chat_multistage_generation.py b/tests/entrypoints/openai_api/test_serving_chat_multistage_generation.py new file mode 100644 index 00000000000..a9b9f53ba8a --- /dev/null +++ b/tests/entrypoints/openai_api/test_serving_chat_multistage_generation.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Regression tests for multistage diffusion generation input construction.""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +from PIL import Image +from vllm.sampling_params import SamplingParams + +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +@pytest.fixture +def serving_chat(): + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + return object.__new__(OmniOpenAIServingChat) + + +def test_build_multistage_generation_inputs_applies_stage_specific_overrides(serving_chat): + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + engine = SimpleNamespace( + stage_configs=[ + SimpleNamespace(stage_type="llm", is_comprehension=True), + SimpleNamespace(stage_type="diffusion", is_comprehension=False), + SimpleNamespace(stage_type="diffusion", is_comprehension=False), + ], + default_sampling_params_list=[ + SamplingParams(temperature=0.2, seed=11), + OmniDiffusionSamplingParams(), + OmniDiffusionSamplingParams(), + ], + ) + reference_image = Image.new("RGB", (24, 24), color="green") + extra_body = { + "negative_prompt": "blurry", + "num_inference_steps": 28, + "guidance_scale": 7.5, + "true_cfg_scale": 5.0, + "guidance_scale_2": 1.25, + "layers": 6, + "resolution": 1024, + "lora": {"name": "adapter-a", "path": "/tmp/adapter-a", "scale": 0.6}, + } + gen_params = OmniDiffusionSamplingParams(height=768, width=1024, seed=0, num_outputs_per_prompt=2) + + engine_prompt, sampling_params_list = OmniOpenAIServingChat._build_multistage_generation_inputs( + serving_chat, + engine=engine, + prompt="draw a robot", + extra_body=extra_body, + reference_images=[reference_image], + gen_params=gen_params, + ) + + assert engine_prompt["prompt"] == "draw a robot" + assert engine_prompt["modalities"] == ["img2img"] + assert engine_prompt["negative_prompt"] == "blurry" + assert engine_prompt["mm_processor_kwargs"] == {"target_h": 768, "target_w": 1024} + assert engine_prompt["multi_modal_data"]["img2img"].size == (24, 24) + + assert len(sampling_params_list) == 3 + assert sampling_params_list[0].temperature == 0.2 + assert sampling_params_list[0].seed == 0 + assert sampling_params_list[1].height == 768 + assert sampling_params_list[1].width == 1024 + assert sampling_params_list[1].seed == 0 + assert sampling_params_list[1].num_inference_steps == 28 + assert sampling_params_list[1].guidance_scale == 7.5 + assert sampling_params_list[1].num_outputs_per_prompt == 2 + assert sampling_params_list[1].true_cfg_scale == 5.0 + assert sampling_params_list[1].lora_request.name == "adapter-a" + assert sampling_params_list[2].height == 768 + assert sampling_params_list[2].width == 1024 + assert sampling_params_list[2].num_inference_steps == 28 + assert engine.default_sampling_params_list[1].height is None + assert engine.default_sampling_params_list[2].resolution == 640 diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 717bc08ce40..83a6ba777aa 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -1279,7 +1279,53 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) f"server is running '{model_name}'. Using server model." ) + def _should_route_images_via_chat(configs: list[Any]) -> bool: + # Unify request construction for any multi-stage pipeline to avoid + # divergence between /v1/images and /v1/chat/completions. + return len(configs) > 1 + try: + if _should_route_images_via_chat(stage_configs): + chat_handler = Omnichat(raw_request) + if chat_handler is not None: + effective_seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1) + extra_body: dict[str, Any] = { + "seed": effective_seed, + "num_outputs_per_prompt": request.n, + } + if request.size is not None: + # Keep /images validation semantics: invalid size should fail with 400. + parse_size(request.size) + extra_body["size"] = request.size + if request.negative_prompt is not None: + extra_body["negative_prompt"] = request.negative_prompt + if request.num_inference_steps is not None: + extra_body["num_inference_steps"] = request.num_inference_steps + if request.guidance_scale is not None: + extra_body["guidance_scale"] = request.guidance_scale + if request.true_cfg_scale is not None: + extra_body["true_cfg_scale"] = request.true_cfg_scale + if request.generator_device is not None: + extra_body["generator_device"] = request.generator_device + if request.lora is not None: + # Keep /images validation semantics: invalid LoRA should fail with 400. + _parse_lora_request(request.lora) + extra_body["lora"] = request.lora + + generation_result = await chat_handler.generate_diffusion_images( + prompt=request.prompt, + extra_body=extra_body, + request_id=f"img_gen-{random_uuid()}", + ) + if isinstance(generation_result, ErrorResponse): + return JSONResponse( + status_code=generation_result.error.code if generation_result.error else 400, + content=generation_result.model_dump(), + ) + flat_images, _, _ = generation_result + image_data = [ImageData(b64_json=encode_image_base64(img), revised_prompt=None) for img in flat_images] + return ImageGenerationResponse(created=int(time.time()), data=image_data) + # Build params - pass through user values directly prompt: OmniTextPrompt = {"prompt": request.prompt} if request.negative_prompt is not None: diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 1ae0102a33c..d87276c59a5 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -84,7 +84,7 @@ from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin from vllm_omni.entrypoints.openai.protocol import OmniChatCompletionStreamResponse from vllm_omni.entrypoints.openai.protocol.audio import AudioResponse, CreateAudio -from vllm_omni.entrypoints.openai.utils import parse_lora_request +from vllm_omni.entrypoints.openai.utils import get_stage_type, parse_lora_request from vllm_omni.lora.request import LoRARequest from vllm_omni.outputs import OmniRequestOutput @@ -2021,6 +2021,288 @@ def _create_image_choice( return choices # ==================== Diffusion Mode Methods ==================== + + def _build_multistage_generation_inputs( + self, + *, + engine: AsyncOmni, + prompt: str, + extra_body: dict[str, Any], + reference_images: list[Image.Image], + gen_params: OmniDiffusionSamplingParams, + ) -> tuple[OmniTextPrompt, list[Any]]: + """Build the shared multistage generation prompt and stage params.""" + stage_configs = getattr(engine, "stage_configs", None) or [] + default_params_list = list(getattr(engine, "default_sampling_params_list", []) or []) + + height = gen_params.height + width = gen_params.width + seed = gen_params.seed + generator_device = gen_params.generator_device + num_outputs_per_prompt = gen_params.num_outputs_per_prompt + num_inference_steps = extra_body.get("num_inference_steps") + guidance_scale = extra_body.get("guidance_scale") + true_cfg_scale = extra_body.get("true_cfg_scale") or extra_body.get("cfg_scale") + negative_prompt = extra_body.get("negative_prompt") + num_frames = extra_body.get("num_frames") + guidance_scale_2 = extra_body.get("guidance_scale_2") + lora_body = extra_body.get("lora") + layers = extra_body.get("layers") + resolution = extra_body.get("resolution") + + engine_prompt_data: dict[str, Any] | None = None + modalities = ["image"] + if reference_images: + if len(reference_images) == 1: + engine_prompt_data = {"img2img": reference_images[0]} + modalities = ["img2img"] + else: + # Preserve previous multi-image behavior when backend supports it. + engine_prompt_data = {"image": reference_images} + + engine_prompt: OmniTextPrompt = {"prompt": prompt} + engine_prompt["modalities"] = modalities + if negative_prompt is not None: + engine_prompt["negative_prompt"] = negative_prompt + + mm_processor_kwargs: dict[str, Any] = {} + if height is not None: + mm_processor_kwargs["target_h"] = height + if width is not None: + mm_processor_kwargs["target_w"] = width + if mm_processor_kwargs: + engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs + if engine_prompt_data is not None: + engine_prompt["multi_modal_data"] = engine_prompt_data + + comprehension_idx = None + for idx, stage in enumerate(stage_configs): + if getattr(stage, "is_comprehension", False): + comprehension_idx = idx + break + sampling_params_list: list[Any] = [] + for idx, stage_cfg in enumerate(stage_configs): + stage_type = get_stage_type(stage_cfg) + if idx < len(default_params_list): + default_stage_params = default_params_list[idx] + if hasattr(default_stage_params, "clone"): + try: + default_stage_params = default_stage_params.clone() + except Exception: + pass + else: + # If defaults are missing, use a diffusion-typed fallback so fields + # like height/width/steps/CFG can still be applied deterministically. + if stage_type == "diffusion": + default_stage_params = gen_params.clone() + else: + default_stage_params = SamplingParams() + + if comprehension_idx is not None and idx == comprehension_idx and seed is not None and hasattr( + default_stage_params, "seed" + ): + default_stage_params.seed = seed + if stage_type == "diffusion": + if hasattr(default_stage_params, "height") and height is not None: + default_stage_params.height = height + if hasattr(default_stage_params, "width") and width is not None: + default_stage_params.width = width + if hasattr(default_stage_params, "seed") and seed is not None: + default_stage_params.seed = seed + if hasattr(default_stage_params, "generator_device") and generator_device is not None: + default_stage_params.generator_device = generator_device + if hasattr(default_stage_params, "num_outputs_per_prompt") and num_outputs_per_prompt is not None: + default_stage_params.num_outputs_per_prompt = num_outputs_per_prompt + if hasattr(default_stage_params, "num_inference_steps") and num_inference_steps is not None: + default_stage_params.num_inference_steps = num_inference_steps + if hasattr(default_stage_params, "guidance_scale") and guidance_scale is not None: + default_stage_params.guidance_scale = guidance_scale + if hasattr(default_stage_params, "true_cfg_scale") and true_cfg_scale is not None: + default_stage_params.true_cfg_scale = true_cfg_scale + if hasattr(default_stage_params, "num_frames") and num_frames is not None: + default_stage_params.num_frames = num_frames + if hasattr(default_stage_params, "guidance_scale_2") and guidance_scale_2 is not None: + default_stage_params.guidance_scale_2 = guidance_scale_2 + if hasattr(default_stage_params, "layers") and layers is not None: + default_stage_params.layers = layers + if hasattr(default_stage_params, "resolution") and resolution is not None: + default_stage_params.resolution = resolution + if lora_body and isinstance(lora_body, dict): + try: + lora_req, lora_scale = parse_lora_request(lora_body) + if lora_req is not None: + default_stage_params.lora_request = lora_req + if lora_scale is not None: + default_stage_params.lora_scale = lora_scale + except Exception as e: # pragma: no cover - safeguard + logger.warning("Failed to parse LoRA request: %s", e) + sampling_params_list.append(default_stage_params) + + return engine_prompt, sampling_params_list + + async def generate_diffusion_images( + self, + *, + prompt: str, + extra_body: dict[str, Any] | None = None, + reference_images: list[str] | None = None, + request_id: str | None = None, + ) -> tuple[list[Image.Image], dict[str, Any], float] | ErrorResponse: + """Generate diffusion images and return raw images plus generation stats. + + This avoids coupling callers to chat response serialization details. + """ + if request_id is None: + request_id = f"chatcmpl-{uuid.uuid4().hex[:16]}" + + if extra_body is None: + extra_body = {} + if reference_images is None: + reference_images = [] + engine = self._diffusion_engine if self._diffusion_engine is not None else self.engine_client + + # Parse size if provided (supports "1024x1024" format) + height = extra_body.get("height") + width = extra_body.get("width") + if "size" in extra_body: + try: + size_str = extra_body["size"] + if isinstance(size_str, str) and "x" in size_str.lower(): + w, h = size_str.lower().split("x") + width, height = int(w), int(h) + except ValueError: + logger.warning("Invalid size format: %s", extra_body.get("size")) + + # Request-level generation parameters. + num_inference_steps = extra_body.get("num_inference_steps") + guidance_scale = extra_body.get("guidance_scale") + true_cfg_scale = extra_body.get("true_cfg_scale") or extra_body.get("cfg_scale") + seed = extra_body.get("seed") + generator_device = extra_body.get("generator_device") + negative_prompt = extra_body.get("negative_prompt") + num_outputs_per_prompt = extra_body.get("num_outputs_per_prompt", 1) + + # Text-to-video parameters. + num_frames = extra_body.get("num_frames") + guidance_scale_2 = extra_body.get("guidance_scale_2") + lora_body = extra_body.get("lora") + + # Qwen-Image-Layered parameters. + layers = extra_body.get("layers") + resolution = extra_body.get("resolution") + + # Decode reference images if provided. + pil_images: list[Image.Image] = [] + for img_b64 in reference_images: + try: + img_bytes = base64.b64decode(img_b64) + pil_images.append(Image.open(BytesIO(img_bytes))) + except Exception as e: + logger.warning("Failed to decode reference image: %s", e) + + gen_params = OmniDiffusionSamplingParams( + height=height, + width=width, + num_outputs_per_prompt=num_outputs_per_prompt, + seed=seed, + ) + if generator_device is not None: + gen_params.generator_device = generator_device + if num_inference_steps is not None: + gen_params.num_inference_steps = num_inference_steps + if guidance_scale is not None: + gen_params.guidance_scale = guidance_scale + if true_cfg_scale is not None: + gen_params.true_cfg_scale = true_cfg_scale + if num_frames is not None: + gen_params.num_frames = num_frames + if guidance_scale_2 is not None: + gen_params.guidance_scale_2 = guidance_scale_2 + if layers is not None: + gen_params.layers = layers + if resolution is not None: + gen_params.resolution = resolution + + # Parse per-request LoRA (works for both AsyncOmniDiffusion and AsyncOmni). + if lora_body and isinstance(lora_body, dict): + try: + lora_req, lora_scale = parse_lora_request(lora_body) + if lora_req is not None: + gen_params.lora_request = lora_req + if lora_scale is not None: + gen_params.lora_scale = lora_scale + except Exception as e: # pragma: no cover - safeguard + logger.warning("Failed to parse LoRA request: %s", e) + + # Shared prompt building for pure-diffusion invocation. + gen_prompt: OmniTextPrompt = { + "prompt": prompt, + "negative_prompt": negative_prompt, + } + if pil_images: + if len(pil_images) == 1: + gen_prompt["multi_modal_data"] = {"image": pil_images[0]} + else: + od_config = getattr(engine, "od_config", None) + supports_multimodal_inputs = getattr(od_config, "supports_multimodal_inputs", False) + if od_config is None: + supports_multimodal_inputs = True + if supports_multimodal_inputs: + gen_prompt["multi_modal_data"] = {"image": pil_images} + else: + return self._create_error_response( + "Multiple input images are not supported by the current diffusion model. " + "For multi-image editing, start the server with Qwen-Image-Edit-2509 " + "and send multiple images in the user message content.", + status_code=400, + ) + + # Generate image. + if isinstance(engine, AsyncOmni): + diffusion_engine = cast(AsyncOmni, engine) + stage_configs = getattr(diffusion_engine, "stage_configs", None) or [] + if len(stage_configs) > 1: + engine_prompt, sampling_params_list = self._build_multistage_generation_inputs( + engine=diffusion_engine, + prompt=prompt, + extra_body=extra_body, + reference_images=pil_images, + gen_params=gen_params, + ) + else: + # Keep compatibility for single-stage AsyncOmni topologies. + engine_prompt = gen_prompt + sampling_params_list = [gen_params] + result = None + async for output in diffusion_engine.generate( + prompt=engine_prompt, + sampling_params_list=sampling_params_list, + request_id=request_id, + ): + result = output + if result is None: + return self._create_error_response("No output generated from AsyncOmni", status_code=500) + else: + diffusion_engine = engine + result = await diffusion_engine.generate( + prompt=gen_prompt, + sampling_params=gen_params, + request_id=request_id, + ) + + images = getattr(result.request_output, "images", []) + stage_durations = result.stage_durations + peak_memory_mb = result.peak_memory_mb + + flat_images: list[Image.Image] = [] + for item in images: + if isinstance(item, list): + flat_images.extend(item) + else: + flat_images.append(item) + + return flat_images, stage_durations, peak_memory_mb + async def _create_diffusion_chat_completion( self, request: ChatCompletionRequest, @@ -2062,38 +2344,6 @@ async def _create_diffusion_chat_completion( if not extra_body: extra_body = request.model_extra or {} - # Parse size if provided (supports "1024x1024" format) - height = extra_body.get("height") - width = extra_body.get("width") - if "size" in extra_body: - try: - size_str = extra_body["size"] - if isinstance(size_str, str) and "x" in size_str.lower(): - w, h = size_str.lower().split("x") - width, height = int(w), int(h) - except ValueError: - logger.warning("Invalid size format: %s", extra_body.get("size")) - - # Get request parameters from extra_body. - # Avoid hardcoded defaults here — let each pipeline's forward() - # method apply its own model-specific default when the user does - # not provide a value. - num_inference_steps = extra_body.get("num_inference_steps") - guidance_scale = extra_body.get("guidance_scale") - true_cfg_scale = extra_body.get("true_cfg_scale") or extra_body.get("cfg_scale") - seed = extra_body.get("seed") - negative_prompt = extra_body.get("negative_prompt") - num_outputs_per_prompt = extra_body.get("num_outputs_per_prompt", 1) - - # Text-to-video parameters (ref: text_to_video.py) - num_frames = extra_body.get("num_frames") - guidance_scale_2 = extra_body.get("guidance_scale_2") - lora_body = extra_body.get("lora") - - # Qwen-Image-Layered parameters - layers = extra_body.get("layers") - resolution = extra_body.get("resolution") - logger.info( "Diffusion chat request %s: prompt=%r, ref_images=%d, params=%s", request_id, @@ -2102,112 +2352,18 @@ async def _create_diffusion_chat_completion( {k: v for k, v in extra_body.items() if v is not None}, ) - # Decode reference images if provided - pil_images: list[Image.Image] = [] - for img_b64 in reference_images: - try: - img_bytes = base64.b64decode(img_b64) - pil_images.append(Image.open(BytesIO(img_bytes))) - except Exception as e: - logger.warning("Failed to decode reference image: %s", e) - - # Build generation kwargs - gen_prompt: OmniTextPrompt = { - "prompt": prompt, - "negative_prompt": negative_prompt, - } - gen_params = OmniDiffusionSamplingParams( - height=height, - width=width, - num_outputs_per_prompt=num_outputs_per_prompt, - seed=seed, + generation_result = await self.generate_diffusion_images( + prompt=prompt, + extra_body=extra_body, + reference_images=reference_images, + request_id=request_id, ) + if isinstance(generation_result, ErrorResponse): + return generation_result + flat_images, stage_durations, peak_memory_mb = generation_result - # Only override defaults when the user explicitly provides values - if num_inference_steps is not None: - gen_params.num_inference_steps = num_inference_steps - if guidance_scale is not None: - gen_params.guidance_scale = guidance_scale - if true_cfg_scale is not None: - gen_params.true_cfg_scale = true_cfg_scale - if num_frames is not None: - gen_params.num_frames = num_frames - if guidance_scale_2 is not None: - gen_params.guidance_scale_2 = guidance_scale_2 - if layers is not None: - gen_params.layers = layers - if resolution is not None: - gen_params.resolution = resolution - - # Parse per-request LoRA (works for both AsyncOmniDiffusion and AsyncOmni). - if lora_body and isinstance(lora_body, dict): - try: - lora_req, lora_scale = parse_lora_request(lora_body) - if lora_req is not None: - gen_params.lora_request = lora_req - if lora_scale is not None: - gen_params.lora_scale = lora_scale - except Exception as e: # pragma: no cover - safeguard - logger.warning("Failed to parse LoRA request: %s", e) - - # Add reference image if provided - if pil_images: - if len(pil_images) == 1: - gen_prompt["multi_modal_data"] = {} - gen_prompt["multi_modal_data"]["image"] = pil_images[0] - else: - od_config = getattr(self._diffusion_engine, "od_config", None) - supports_multimodal_inputs = getattr(od_config, "supports_multimodal_inputs", False) - if od_config is None: - # TODO: entry is asyncOmni. We hack the od config here. - supports_multimodal_inputs = True - if supports_multimodal_inputs: - gen_prompt["multi_modal_data"] = {} - gen_prompt["multi_modal_data"]["image"] = pil_images - else: - return self._create_error_response( - "Multiple input images are not supported by the current diffusion model. " - "For multi-image editing, start the server with Qwen-Image-Edit-2509 " - "and send multiple images in the user message content.", - status_code=400, - ) - - # Generate image - # Handle both AsyncOmniDiffusion (returns OmniRequestOutput) and AsyncOmni (returns AsyncGenerator) - if isinstance(self._diffusion_engine, AsyncOmni): - diffusion_engine = cast(AsyncOmni, self._diffusion_engine) - result = None - async for output in diffusion_engine.generate( - prompt=gen_prompt, - sampling_params_list=[gen_params], # Pass as single-stage params - request_id=request_id, - ): - result = output - if result is None: - return self._create_error_response("No output generated from AsyncOmni") - else: - # AsyncOmniDiffusion: direct call - diffusion_engine = cast(AsyncOmniDiffusion, self._diffusion_engine) - result = await diffusion_engine.generate( - prompt=gen_prompt, - sampling_params=gen_params, - request_id=request_id, - ) - # Extract images from result - # Handle nested OmniRequestOutput structure where images might be in request_output - images = getattr(result.request_output, "images", []) - stage_durations = result.stage_durations - peak_memory_mb = result.peak_memory_mb - - # Convert images to base64 content + # Convert images to base64 content. image_contents: list[dict[str, Any]] = [] - flat_images = [] - for item in images: - if isinstance(item, list): - flat_images.extend(item) - else: - flat_images.append(item) - for img in flat_images: with BytesIO() as buffer: img.save(buffer, format="PNG") @@ -2264,7 +2420,7 @@ async def _create_diffusion_chat_completion( logger.info( "Diffusion chat completed for request %s: %d images", request_id, - len(images), + len(flat_images), ) return response From 8a9151a61fd50ec7d6aa73c80075c29a16138b33 Mon Sep 17 00:00:00 2001 From: Lancer Date: Fri, 27 Mar 2026 16:51:51 +0800 Subject: [PATCH 2/5] upd Signed-off-by: Lancer --- .../openai_api/test_image_server.py | 4 +- vllm_omni/entrypoints/openai/api_server.py | 85 ++++++++++--------- vllm_omni/entrypoints/openai/serving_chat.py | 26 ++++-- 3 files changed, 70 insertions(+), 45 deletions(-) diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index 656b3978aa2..2f48679d6a9 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -163,7 +163,6 @@ def test_client(mock_async_diffusion): from fastapi import FastAPI from vllm_omni.entrypoints.openai.api_server import router - from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat app = FastAPI() app.include_router(router) @@ -286,6 +285,9 @@ def test_models_endpoint_no_engine(): def test_generate_single_image(test_client): """Test generating a single image""" + # Single-stage path should not require openai_serving_chat. + assert not hasattr(test_client.app.state, "openai_serving_chat") + response = test_client.post( "/v1/images/generations", json={ diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index d3f069ad999..3a6cd71a0c8 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -118,6 +118,7 @@ # Supported resolution buckets for layered models (e.g., Qwen-Image-Layered) SUPPORTED_LAYERED_RESOLUTIONS = (640, 1024) +MAX_UINT32_SEED = 2**32 - 1 profiler_router = APIRouter() @@ -1288,44 +1289,50 @@ def _should_route_images_via_chat(configs: list[Any]) -> bool: try: if _should_route_images_via_chat(stage_configs): chat_handler = Omnichat(raw_request) - if chat_handler is not None: - effective_seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1) - extra_body: dict[str, Any] = { - "seed": effective_seed, - "num_outputs_per_prompt": request.n, - } - if request.size is not None: - # Keep /images validation semantics: invalid size should fail with 400. - parse_size(request.size) - extra_body["size"] = request.size - if request.negative_prompt is not None: - extra_body["negative_prompt"] = request.negative_prompt - if request.num_inference_steps is not None: - extra_body["num_inference_steps"] = request.num_inference_steps - if request.guidance_scale is not None: - extra_body["guidance_scale"] = request.guidance_scale - if request.true_cfg_scale is not None: - extra_body["true_cfg_scale"] = request.true_cfg_scale - if request.generator_device is not None: - extra_body["generator_device"] = request.generator_device - if request.lora is not None: - # Keep /images validation semantics: invalid LoRA should fail with 400. - _parse_lora_request(request.lora) - extra_body["lora"] = request.lora - - generation_result = await chat_handler.generate_diffusion_images( - prompt=request.prompt, - extra_body=extra_body, - request_id=f"img_gen-{random_uuid()}", + if chat_handler is None: + logger.warning("openai_serving_chat is not initialized for multi-stage /v1/images/generations") + raise HTTPException( + status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, + detail="openai_serving_chat is not initialized for multi-stage image generation.", ) - if isinstance(generation_result, ErrorResponse): - return JSONResponse( - status_code=generation_result.error.code if generation_result.error else 400, - content=generation_result.model_dump(), - ) - flat_images, _, _ = generation_result - image_data = [ImageData(b64_json=encode_image_base64(img), revised_prompt=None) for img in flat_images] - return ImageGenerationResponse(created=int(time.time()), data=image_data) + + effective_seed = request.seed if request.seed is not None else random.randint(0, MAX_UINT32_SEED) + extra_body: dict[str, Any] = { + "seed": effective_seed, + "num_outputs_per_prompt": request.n, + } + if request.size is not None: + # Keep /images validation semantics: invalid size should fail with 400. + parse_size(request.size) + extra_body["size"] = request.size + if request.negative_prompt is not None: + extra_body["negative_prompt"] = request.negative_prompt + if request.num_inference_steps is not None: + extra_body["num_inference_steps"] = request.num_inference_steps + if request.guidance_scale is not None: + extra_body["guidance_scale"] = request.guidance_scale + if request.true_cfg_scale is not None: + extra_body["true_cfg_scale"] = request.true_cfg_scale + if request.generator_device is not None: + extra_body["generator_device"] = request.generator_device + if request.lora is not None: + # Keep /images validation semantics: invalid LoRA should fail with 400. + _parse_lora_request(request.lora) + extra_body["lora"] = request.lora + + generation_result = await chat_handler.generate_diffusion_images( + prompt=request.prompt, + extra_body=extra_body, + request_id=f"img_gen-{random_uuid()}", + ) + if isinstance(generation_result, ErrorResponse): + return JSONResponse( + status_code=generation_result.error.code if generation_result.error else 400, + content=generation_result.model_dump(), + ) + flat_images, _, _ = generation_result + image_data = [ImageData(b64_json=encode_image_base64(img), revised_prompt=None) for img in flat_images] + return ImageGenerationResponse(created=int(time.time()), data=image_data) # Build params - pass through user values directly prompt: OmniTextPrompt = {"prompt": request.prompt} @@ -1357,7 +1364,7 @@ def _should_route_images_via_chat(configs: list[Any]) -> bool: # This fixes issues where using the default global generator # might produce blurry images in some environments. _update_if_not_none( - gen_params, "seed", request.seed if request.seed is not None else random.randint(0, 2**32 - 1) + gen_params, "seed", request.seed if request.seed is not None else random.randint(0, MAX_UINT32_SEED) ) _update_if_not_none(gen_params, "generator_device", request.generator_device) @@ -1562,7 +1569,7 @@ async def edit_images( # a proper generator is initialized in the backend. # This fixes issues where using the default global generator # might produce blurry images in some environments. - _update_if_not_none(gen_params, "seed", seed if seed is not None else random.randint(0, 2**32 - 1)) + _update_if_not_none(gen_params, "seed", seed if seed is not None else random.randint(0, MAX_UINT32_SEED)) _update_if_not_none(gen_params, "generator_device", generator_device) _update_if_not_none(gen_params, "layers", layers) _update_if_not_none(gen_params, "resolution", resolution) diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index d87276c59a5..49959cd2896 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -688,7 +688,13 @@ def _apply_request_overrides( Returns: New SamplingParams with YAML defaults overridden by request values. """ - params = default_params.clone() + clone_fn = getattr(default_params, "clone", None) + if not callable(clone_fn): + raise ValueError(f"default sampling params does not support clone(): {type(default_params).__name__}") + try: + params = clone_fn() + except Exception as e: + raise ValueError(f"failed to clone default sampling params: {e}") from e for field_name in self._OPENAI_SAMPLING_FIELDS: value = getattr(request, field_name, None) @@ -718,7 +724,7 @@ def _build_sampling_params_list_from_request( default_params_list = self.engine_client.default_sampling_params_list comprehension_idx = self._get_comprehension_stage_index() - sampling_params_list = [] + sampling_params_list: list[SamplingParams] = [] for idx, default_params in enumerate(default_params_list): if isinstance(default_params, dict): default_params = SamplingParams(**default_params) @@ -727,7 +733,14 @@ def _build_sampling_params_list_from_request( sampling_params_list.append(params) else: # For other stages, clone default params - sampling_params_list.append(default_params.clone()) + clone_fn = getattr(default_params, "clone", None) + if callable(clone_fn): + try: + sampling_params_list.append(clone_fn()) + except Exception: + sampling_params_list.append(default_params) + else: + sampling_params_list.append(default_params) return sampling_params_list @@ -2098,8 +2111,11 @@ def _build_multistage_generation_inputs( else: default_stage_params = SamplingParams() - if comprehension_idx is not None and idx == comprehension_idx and seed is not None and hasattr( - default_stage_params, "seed" + if ( + comprehension_idx is not None + and idx == comprehension_idx + and seed is not None + and hasattr(default_stage_params, "seed") ): default_stage_params.seed = seed if stage_type == "diffusion": From 688aac30ee342e4a0d3314e7576c533844ecd083 Mon Sep 17 00:00:00 2001 From: Lancer Date: Sun, 5 Apr 2026 21:30:45 +0800 Subject: [PATCH 3/5] upd Signed-off-by: Lancer --- vllm_omni/entrypoints/openai/api_server.py | 11 +--- vllm_omni/entrypoints/openai/serving_chat.py | 59 +++++++++----------- 2 files changed, 30 insertions(+), 40 deletions(-) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 0cafb29dd14..6fead6f3723 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -120,8 +120,6 @@ logger = init_logger(__name__) router = APIRouter() -# Supported resolution buckets for layered models (e.g., Qwen-Image-Layered) -SUPPORTED_LAYERED_RESOLUTIONS = (640, 1024) MAX_UINT32_SEED = 2**32 - 1 profiler_router = APIRouter() @@ -1303,14 +1301,11 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) f"server is running '{model_name}'. Using server model." ) - def _should_route_images_via_chat(configs: list[Any]) -> bool: + try: # Unify request construction for any multi-stage pipeline to avoid # divergence between /v1/images and /v1/chat/completions. - return len(configs) > 1 - - try: - if _should_route_images_via_chat(stage_configs): - chat_handler = Omnichat(raw_request) + if len(stage_configs) > 1: + chat_handler = getattr(raw_request.app.state, "openai_serving_chat", None) if chat_handler is None: logger.warning("openai_serving_chat is not initialized for multi-stage /v1/images/generations") raise HTTPException( diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 5988e1567aa..84d96add6f4 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -82,7 +82,6 @@ from vllm.utils.collection_utils import as_list from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin -from vllm_omni.entrypoints.openai.image_api_utils import validate_layered_layers from vllm_omni.entrypoints.openai.protocol import OmniChatCompletionStreamResponse from vllm_omni.entrypoints.openai.protocol.audio import AudioResponse, CreateAudio from vllm_omni.entrypoints.openai.utils import get_stage_type, parse_lora_request @@ -709,6 +708,12 @@ def _apply_request_overrides( return params + @staticmethod + def _set_if_supported(obj: Any, **kwargs: Any) -> None: + for key, value in kwargs.items(): + if value is not None and hasattr(obj, key): + setattr(obj, key, value) + def _build_sampling_params_list_from_request( self, request: ChatCompletionRequest, @@ -740,13 +745,12 @@ def _build_sampling_params_list_from_request( else: # For other stages, clone default params clone_fn = getattr(default_params, "clone", None) - if callable(clone_fn): - try: - sampling_params_list.append(clone_fn()) - except Exception: - sampling_params_list.append(default_params) - else: - sampling_params_list.append(default_params) + if not callable(clone_fn): + raise ValueError("default sampling params must provide clone() for safe per-request copying") + try: + sampling_params_list.append(clone_fn()) + except Exception as e: + raise ValueError(f"failed to clone default sampling params: {e}") from e return sampling_params_list @@ -2125,30 +2129,21 @@ def _build_multistage_generation_inputs( ): default_stage_params.seed = seed if stage_type == "diffusion": - if hasattr(default_stage_params, "height") and height is not None: - default_stage_params.height = height - if hasattr(default_stage_params, "width") and width is not None: - default_stage_params.width = width - if hasattr(default_stage_params, "seed") and seed is not None: - default_stage_params.seed = seed - if hasattr(default_stage_params, "generator_device") and generator_device is not None: - default_stage_params.generator_device = generator_device - if hasattr(default_stage_params, "num_outputs_per_prompt") and num_outputs_per_prompt is not None: - default_stage_params.num_outputs_per_prompt = num_outputs_per_prompt - if hasattr(default_stage_params, "num_inference_steps") and num_inference_steps is not None: - default_stage_params.num_inference_steps = num_inference_steps - if hasattr(default_stage_params, "guidance_scale") and guidance_scale is not None: - default_stage_params.guidance_scale = guidance_scale - if hasattr(default_stage_params, "true_cfg_scale") and true_cfg_scale is not None: - default_stage_params.true_cfg_scale = true_cfg_scale - if hasattr(default_stage_params, "num_frames") and num_frames is not None: - default_stage_params.num_frames = num_frames - if hasattr(default_stage_params, "guidance_scale_2") and guidance_scale_2 is not None: - default_stage_params.guidance_scale_2 = guidance_scale_2 - if hasattr(default_stage_params, "layers") and layers is not None: - default_stage_params.layers = layers - if hasattr(default_stage_params, "resolution") and resolution is not None: - default_stage_params.resolution = resolution + self._set_if_supported( + default_stage_params, + height=height, + width=width, + seed=seed, + generator_device=generator_device, + num_outputs_per_prompt=num_outputs_per_prompt, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + true_cfg_scale=true_cfg_scale, + num_frames=num_frames, + guidance_scale_2=guidance_scale_2, + layers=layers, + resolution=resolution, + ) if lora_body and isinstance(lora_body, dict): try: lora_req, lora_scale = parse_lora_request(lora_body) From 89816541ac5cad972213370030b440ab26204670 Mon Sep 17 00:00:00 2001 From: Lancer Date: Wed, 15 Apr 2026 18:05:47 +0800 Subject: [PATCH 4/5] upd Signed-off-by: Lancer --- vllm_omni/entrypoints/openai/serving_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index dbb589c6e64..aca715f3646 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -85,8 +85,8 @@ from vllm_omni.entrypoints.openai.protocol import OmniChatCompletionStreamResponse from vllm_omni.entrypoints.openai.protocol.audio import AudioResponse, CreateAudio from vllm_omni.entrypoints.openai.utils import ( - get_supported_speakers_from_hf_config, get_stage_type, + get_supported_speakers_from_hf_config, parse_lora_request, validate_requested_speaker, ) From f1ce40f8b0c9167154bf10e53492f332d85971e0 Mon Sep 17 00:00:00 2001 From: Lancer Date: Fri, 17 Apr 2026 17:34:04 +0800 Subject: [PATCH 5/5] upd Signed-off-by: Lancer --- .../openai_api/test_image_server.py | 128 ++++++++- vllm_omni/entrypoints/openai/api_server.py | 4 +- vllm_omni/entrypoints/openai/serving_chat.py | 258 +++++++++++++++++- 3 files changed, 377 insertions(+), 13 deletions(-) diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index 29d778aedf4..75e4a148621 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -128,12 +128,18 @@ def __init__(self, images=None): self.captured_prompt = None self._images = images or [Image.new("RGB", (64, 64), color="green")] - async def generate(self, prompt, request_id, sampling_params_list): - self.captured_sampling_params_list = sampling_params_list + async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None): + if sampling_params_list is not None: + self.captured_sampling_params_list = sampling_params_list + else: + self.captured_sampling_params_list = [sampling_params] self.captured_prompt = prompt images = [img.copy() for img in self._images] yield MockGenerationResult(images) + def __class_getitem__(cls, item): + return cls + @pytest.fixture def mock_async_diffusion(mocker: MockerFixture): @@ -192,17 +198,49 @@ def async_omni_test_client(): """Create test client with mocked AsyncOmni engine.""" from fastapi import FastAPI + from vllm_omni.entrypoints.async_omni import AsyncOmni from vllm_omni.entrypoints.openai.api_server import router from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + class FakeAsyncOmniClass(AsyncOmni): + def __init__(self): + self.stage_configs = [ + SimpleNamespace(stage_type="llm", is_comprehension=True), + SimpleNamespace(stage_type="diffusion", is_comprehension=False), + ] + self.default_sampling_params_list = [ + SamplingParams(temperature=0.1), + OmniDiffusionSamplingParams(), + ] + self.captured_sampling_params_list = None + self.captured_prompt = None + self._images = [Image.new("RGB", (64, 64), color="green")] + self.od_config = SimpleNamespace(supports_multimodal_inputs=True) + + async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None): + if sampling_params_list is not None: + self.captured_sampling_params_list = sampling_params_list + else: + self.captured_sampling_params_list = [sampling_params] + self.captured_prompt = prompt + images = [img.copy() for img in self._images] + yield MockGenerationResult(images) + + def __class_getitem__(cls, item): + return cls + + def get_diffusion_od_config(self): + return self.od_config + app = FastAPI() app.include_router(router) - app.state.engine_client = FakeAsyncOmni() + engine = FakeAsyncOmniClass() chat_handler = object.__new__(OmniOpenAIServingChat) - chat_handler.engine_client = app.state.engine_client + chat_handler.engine_client = engine chat_handler._diffusion_engine = None app.state.openai_serving_chat = chat_handler + app.state.engine_client = engine app.state.stage_configs = [ SimpleNamespace(stage_type="llm"), SimpleNamespace(stage_type="diffusion"), @@ -219,12 +257,49 @@ def async_omni_rgba_test_client(): """Create test client with mocked AsyncOmni engine returning RGBA output.""" from fastapi import FastAPI + from vllm_omni.entrypoints.async_omni import AsyncOmni from vllm_omni.entrypoints.openai.api_server import router + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + class FakeAsyncOmniClass(AsyncOmni): + def __init__(self): + self.stage_configs = [ + SimpleNamespace(stage_type="llm", is_comprehension=True), + SimpleNamespace(stage_type="diffusion", is_comprehension=False), + ] + self.default_sampling_params_list = [ + SamplingParams(temperature=0.1), + OmniDiffusionSamplingParams(), + ] + self.captured_sampling_params_list = None + self.captured_prompt = None + self._images = [Image.new("RGBA", (64, 64), color=(0, 255, 0, 128))] + self.od_config = SimpleNamespace(supports_multimodal_inputs=True) + + async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None): + if sampling_params_list is not None: + self.captured_sampling_params_list = sampling_params_list + else: + self.captured_sampling_params_list = [sampling_params] + self.captured_prompt = prompt + images = [img.copy() for img in self._images] + yield MockGenerationResult(images) + + def __class_getitem__(cls, item): + return cls + + def get_diffusion_od_config(self): + return self.od_config app = FastAPI() app.include_router(router) - app.state.engine_client = FakeAsyncOmni(images=[Image.new("RGBA", (64, 64), color=(0, 255, 0, 128))]) + engine = FakeAsyncOmniClass() + chat_handler = object.__new__(OmniOpenAIServingChat) + chat_handler.engine_client = engine + chat_handler._diffusion_engine = None + app.state.openai_serving_chat = chat_handler + app.state.engine_client = engine app.state.stage_configs = [ SimpleNamespace(stage_type="llm"), SimpleNamespace(stage_type="diffusion"), @@ -241,16 +316,50 @@ def async_omni_stage_configs_only_client(): """Create test client with refactored AsyncOmni compatibility surface only.""" from fastapi import FastAPI + from vllm_omni.entrypoints.async_omni import AsyncOmni from vllm_omni.entrypoints.openai.api_server import router + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + class FakeAsyncOmniClass(AsyncOmni): + def __init__(self): + self.stage_configs = [ + SimpleNamespace(stage_type="llm", is_comprehension=True), + SimpleNamespace(stage_type="diffusion", is_comprehension=False), + ] + self.default_sampling_params_list = [ + SamplingParams(temperature=0.1), + OmniDiffusionSamplingParams(), + ] + self.captured_sampling_params_list = None + self.captured_prompt = None + self._images = [Image.new("RGB", (64, 64), color="green")] + self.od_config = SimpleNamespace(supports_multimodal_inputs=True) + + async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None): + if sampling_params_list is not None: + self.captured_sampling_params_list = sampling_params_list + else: + self.captured_sampling_params_list = [sampling_params] + self.captured_prompt = prompt + images = [img.copy() for img in self._images] + yield MockGenerationResult(images) + + def __class_getitem__(cls, item): + return cls + + def get_diffusion_od_config(self): + return self.od_config app = FastAPI() app.include_router(router) - engine = FakeAsyncOmni() + engine = FakeAsyncOmniClass() assert not hasattr(engine, "stage_list") app.state.engine_client = engine - # Intentionally do not populate app.state.stage_configs. Refactored - # AsyncOmni exposes stage_configs on the engine instance. + chat_handler = object.__new__(OmniOpenAIServingChat) + chat_handler.engine_client = engine + chat_handler._diffusion_engine = None + app.state.openai_serving_chat = chat_handler app.state.args = Namespace( default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}', max_generated_image_size=1024 * 1792, @@ -385,8 +494,7 @@ def test_generate_images_async_omni_stage_configs_only(async_omni_stage_configs_ assert captured[1].seed == 11 -def test_multistage_images_async_omni_construction(async_omni_test_client, mocker: MockerFixture): - mocker.patch("vllm_omni.entrypoints.openai.serving_chat.AsyncOmni", FakeAsyncOmni) +def test_multistage_images_async_omni_construction(async_omni_test_client): """Regression: multistage image generation builds the expected chat-style payload.""" response = async_omni_test_client.post( "/v1/images/generations", diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 9223197bac2..a5ef7e32f78 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -1330,8 +1330,10 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) "num_outputs_per_prompt": request.n, } if request.size is not None: - # Keep /images validation semantics: invalid size should fail with 400. parse_size(request.size) + width, height = parse_size(request.size) + app_state_args = getattr(raw_request.app.state, "args", None) + _check_max_generated_image_size(app_state_args, width, height) extra_body["size"] = request.size if request.negative_prompt is not None: extra_body["negative_prompt"] = request.negative_prompt diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 28d6ef277b3..f1e6d56472b 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -733,6 +733,12 @@ def _apply_request_overrides( return params + @staticmethod + def _set_if_supported(obj: Any, **kwargs: Any) -> None: + for key, value in kwargs.items(): + if value is not None and hasattr(obj, key): + setattr(obj, key, value) + def _build_sampling_params_list_from_request( self, request: ChatCompletionRequest, @@ -2057,6 +2063,254 @@ def _create_image_choice( return choices # ==================== Diffusion Mode Methods ==================== + def _build_multistage_generation_inputs( + self, + *, + engine: AsyncOmni, + prompt: str, + extra_body: dict[str, Any], + reference_images: list[Image.Image], + gen_params: OmniDiffusionSamplingParams, + ) -> tuple[OmniTextPrompt, list[Any]]: + """Build the shared multistage generation prompt and stage params.""" + stage_configs = getattr(engine, "stage_configs", None) or [] + default_params_list = list(getattr(engine, "default_sampling_params_list", []) or []) + + height = gen_params.height + width = gen_params.width + seed = gen_params.seed + generator_device = gen_params.generator_device + num_outputs_per_prompt = gen_params.num_outputs_per_prompt + num_inference_steps = extra_body.get("num_inference_steps") + guidance_scale = extra_body.get("guidance_scale") + true_cfg_scale = extra_body.get("true_cfg_scale") or extra_body.get("cfg_scale") + negative_prompt = extra_body.get("negative_prompt") + num_frames = extra_body.get("num_frames") + guidance_scale_2 = extra_body.get("guidance_scale_2") + lora_body = extra_body.get("lora") + layers = extra_body.get("layers") + resolution = extra_body.get("resolution") + + engine_prompt_data: dict[str, Any] | None = None + modalities = ["image"] + if reference_images: + if len(reference_images) == 1: + engine_prompt_data = {"img2img": reference_images[0]} + modalities = ["img2img"] + else: + engine_prompt_data = {"image": reference_images} + + engine_prompt: OmniTextPrompt = {"prompt": prompt} + engine_prompt["modalities"] = modalities + if negative_prompt is not None: + engine_prompt["negative_prompt"] = negative_prompt + + mm_processor_kwargs: dict[str, Any] = {} + if height is not None: + mm_processor_kwargs["target_h"] = height + if width is not None: + mm_processor_kwargs["target_w"] = width + if mm_processor_kwargs: + engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs + if engine_prompt_data is not None: + engine_prompt["multi_modal_data"] = engine_prompt_data + + comprehension_idx = None + for idx, stage in enumerate(stage_configs): + if getattr(stage, "is_comprehension", False): + comprehension_idx = idx + break + + sampling_params_list: list[Any] = [] + for idx, stage_cfg in enumerate(stage_configs): + stage_type = get_stage_type(stage_cfg) + if idx < len(default_params_list): + default_stage_params = default_params_list[idx] + if hasattr(default_stage_params, "clone"): + try: + default_stage_params = default_stage_params.clone() + except Exception: + pass + elif stage_type == "diffusion": + default_stage_params = gen_params.clone() + else: + default_stage_params = SamplingParams() + + if ( + comprehension_idx is not None + and idx == comprehension_idx + and seed is not None + and hasattr(default_stage_params, "seed") + ): + default_stage_params.seed = seed + + if stage_type == "diffusion": + self._set_if_supported( + default_stage_params, + height=height, + width=width, + seed=seed, + generator_device=generator_device, + num_outputs_per_prompt=num_outputs_per_prompt, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + true_cfg_scale=true_cfg_scale, + num_frames=num_frames, + guidance_scale_2=guidance_scale_2, + layers=layers, + resolution=resolution, + ) + if lora_body and isinstance(lora_body, dict): + try: + lora_req, lora_scale = parse_lora_request(lora_body) + if lora_req is not None: + default_stage_params.lora_request = lora_req + if lora_scale is not None: + default_stage_params.lora_scale = lora_scale + except Exception as e: # pragma: no cover - safeguard + logger.warning("Failed to parse LoRA request: %s", e) + + sampling_params_list.append(default_stage_params) + + return engine_prompt, sampling_params_list + + async def generate_diffusion_images( + self, + *, + prompt: str, + extra_body: dict[str, Any] | None = None, + reference_images: list[str] | None = None, + request_id: str | None = None, + ) -> tuple[list[Image.Image], dict[str, Any], float] | ErrorResponse: + """Generate diffusion images and return raw images plus generation stats.""" + if request_id is None: + request_id = f"chatcmpl-{uuid.uuid4().hex[:16]}" + if extra_body is None: + extra_body = {} + if reference_images is None: + reference_images = [] + + engine = self._diffusion_engine if self._diffusion_engine is not None else self.engine_client + + height = extra_body.get("height") + width = extra_body.get("width") + if "size" in extra_body: + try: + size_str = extra_body["size"] + if isinstance(size_str, str) and "x" in size_str.lower(): + w, h = size_str.lower().split("x") + width, height = int(w), int(h) + except ValueError: + logger.warning("Invalid size format: %s", extra_body.get("size")) + + seed = extra_body.get("seed") + generator_device = extra_body.get("generator_device") + negative_prompt = extra_body.get("negative_prompt") + num_outputs_per_prompt = extra_body.get("num_outputs_per_prompt", 1) + lora_body = extra_body.get("lora") + + pil_images: list[Image.Image] = [] + for img_b64 in reference_images: + try: + img_bytes = base64.b64decode(img_b64) + pil_images.append(Image.open(BytesIO(img_bytes))) + except Exception as e: + logger.warning("Failed to decode reference image: %s", e) + + gen_params = OmniDiffusionSamplingParams( + height=height, + width=width, + num_outputs_per_prompt=num_outputs_per_prompt, + seed=seed, + ) + self._set_if_supported( + gen_params, + generator_device=generator_device, + num_inference_steps=extra_body.get("num_inference_steps"), + guidance_scale=extra_body.get("guidance_scale"), + true_cfg_scale=extra_body.get("true_cfg_scale") or extra_body.get("cfg_scale"), + num_frames=extra_body.get("num_frames"), + guidance_scale_2=extra_body.get("guidance_scale_2"), + layers=extra_body.get("layers"), + resolution=extra_body.get("resolution"), + ) + + if lora_body and isinstance(lora_body, dict): + try: + lora_req, lora_scale = parse_lora_request(lora_body) + if lora_req is not None: + gen_params.lora_request = lora_req + if lora_scale is not None: + gen_params.lora_scale = lora_scale + except Exception as e: # pragma: no cover - safeguard + logger.warning("Failed to parse LoRA request: %s", e) + + gen_prompt: OmniTextPrompt = { + "prompt": prompt, + "negative_prompt": negative_prompt, + } + if pil_images: + if len(pil_images) == 1: + gen_prompt["multi_modal_data"] = {"image": pil_images[0]} + else: + od_config = getattr(engine, "od_config", None) + supports_multimodal_inputs = getattr(od_config, "supports_multimodal_inputs", False) + if od_config is None: + supports_multimodal_inputs = True + if supports_multimodal_inputs: + gen_prompt["multi_modal_data"] = {"image": pil_images} + else: + return self._create_error_response( + "Multiple input images are not supported by the current diffusion model. " + "For multi-image editing, start the server with Qwen-Image-Edit-2509 " + "and send multiple images in the user message content.", + status_code=400, + ) + + if isinstance(engine, AsyncOmni): + diffusion_engine = cast(AsyncOmni, engine) + stage_configs = getattr(diffusion_engine, "stage_configs", None) or [] + if len(stage_configs) > 1: + engine_prompt, sampling_params_list = self._build_multistage_generation_inputs( + engine=diffusion_engine, + prompt=prompt, + extra_body=extra_body, + reference_images=pil_images, + gen_params=gen_params, + ) + else: + engine_prompt = gen_prompt + sampling_params_list = [gen_params] + + result = None + async for output in diffusion_engine.generate( + prompt=engine_prompt, + sampling_params_list=sampling_params_list, + request_id=request_id, + ): + result = output + if result is None: + return self._create_error_response("No output generated from AsyncOmni", status_code=500) + else: + result = await engine.generate( + prompt=gen_prompt, + sampling_params=gen_params, + request_id=request_id, + ) + + images = getattr(result.request_output, "images", []) + stage_durations = result.stage_durations + peak_memory_mb = result.peak_memory_mb + + flat_images: list[Image.Image] = [] + for item in images: + if isinstance(item, list): + flat_images.extend(item) + else: + flat_images.append(item) + + return flat_images, stage_durations, peak_memory_mb + async def _create_diffusion_chat_completion( self, request: ChatCompletionRequest, @@ -2234,8 +2488,8 @@ async def _create_diffusion_chat_completion( if hasattr(default_stage_params, "clone"): try: default_stage_params = default_stage_params.clone() - except Exception: - pass + except Exception as e: + logger.warning("Failed to clone default params for stage %d: %s", idx, e) sampling_params_list.append(default_stage_params) if not sampling_params_list: