diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index cd1e83393b0..75e4a148621 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -106,10 +106,13 @@ def test_encode_image_base64(): class MockGenerationResult: - """Mock result object from AsyncOmni.generate()""" + """Mock result object compatible with current diffusion output shape.""" def __init__(self, images): self.images = images + self.request_output = SimpleNamespace(images=images) + self.stage_durations = {} + self.peak_memory_mb = 0.0 class FakeAsyncOmni: @@ -117,20 +120,26 @@ class FakeAsyncOmni: def __init__(self, images=None): self.stage_configs = [ - SimpleNamespace(stage_type="llm"), - SimpleNamespace(stage_type="diffusion"), + SimpleNamespace(stage_type="llm", is_comprehension=True), + SimpleNamespace(stage_type="diffusion", is_comprehension=False), ] self.default_sampling_params_list = [SamplingParams(temperature=0.1), OmniDiffusionSamplingParams()] self.captured_sampling_params_list = None self.captured_prompt = None self._images = images or [Image.new("RGB", (64, 64), color="green")] - async def generate(self, prompt, request_id, sampling_params_list): - self.captured_sampling_params_list = sampling_params_list + async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None): + if sampling_params_list is not None: + self.captured_sampling_params_list = sampling_params_list + else: + self.captured_sampling_params_list = [sampling_params] self.captured_prompt = prompt images = [img.copy() for img in self._images] yield MockGenerationResult(images) + def __class_getitem__(cls, item): + return cls + @pytest.fixture def mock_async_diffusion(mocker: MockerFixture): @@ -189,12 +198,49 @@ def async_omni_test_client(): """Create test client with mocked AsyncOmni engine.""" from fastapi import FastAPI + from vllm_omni.entrypoints.async_omni import AsyncOmni from vllm_omni.entrypoints.openai.api_server import router + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + class FakeAsyncOmniClass(AsyncOmni): + def __init__(self): + self.stage_configs = [ + SimpleNamespace(stage_type="llm", is_comprehension=True), + SimpleNamespace(stage_type="diffusion", is_comprehension=False), + ] + self.default_sampling_params_list = [ + SamplingParams(temperature=0.1), + OmniDiffusionSamplingParams(), + ] + self.captured_sampling_params_list = None + self.captured_prompt = None + self._images = [Image.new("RGB", (64, 64), color="green")] + self.od_config = SimpleNamespace(supports_multimodal_inputs=True) + + async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None): + if sampling_params_list is not None: + self.captured_sampling_params_list = sampling_params_list + else: + self.captured_sampling_params_list = [sampling_params] + self.captured_prompt = prompt + images = [img.copy() for img in self._images] + yield MockGenerationResult(images) + + def __class_getitem__(cls, item): + return cls + + def get_diffusion_od_config(self): + return self.od_config app = FastAPI() app.include_router(router) - app.state.engine_client = FakeAsyncOmni() + engine = FakeAsyncOmniClass() + chat_handler = object.__new__(OmniOpenAIServingChat) + chat_handler.engine_client = engine + chat_handler._diffusion_engine = None + app.state.openai_serving_chat = chat_handler + app.state.engine_client = engine app.state.stage_configs = [ SimpleNamespace(stage_type="llm"), SimpleNamespace(stage_type="diffusion"), @@ -211,12 +257,49 @@ def async_omni_rgba_test_client(): """Create test client with mocked AsyncOmni engine returning RGBA output.""" from fastapi import FastAPI + from vllm_omni.entrypoints.async_omni import AsyncOmni from vllm_omni.entrypoints.openai.api_server import router + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + class FakeAsyncOmniClass(AsyncOmni): + def __init__(self): + self.stage_configs = [ + SimpleNamespace(stage_type="llm", is_comprehension=True), + SimpleNamespace(stage_type="diffusion", is_comprehension=False), + ] + self.default_sampling_params_list = [ + SamplingParams(temperature=0.1), + OmniDiffusionSamplingParams(), + ] + self.captured_sampling_params_list = None + self.captured_prompt = None + self._images = [Image.new("RGBA", (64, 64), color=(0, 255, 0, 128))] + self.od_config = SimpleNamespace(supports_multimodal_inputs=True) + + async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None): + if sampling_params_list is not None: + self.captured_sampling_params_list = sampling_params_list + else: + self.captured_sampling_params_list = [sampling_params] + self.captured_prompt = prompt + images = [img.copy() for img in self._images] + yield MockGenerationResult(images) + + def __class_getitem__(cls, item): + return cls + + def get_diffusion_od_config(self): + return self.od_config app = FastAPI() app.include_router(router) - app.state.engine_client = FakeAsyncOmni(images=[Image.new("RGBA", (64, 64), color=(0, 255, 0, 128))]) + engine = FakeAsyncOmniClass() + chat_handler = object.__new__(OmniOpenAIServingChat) + chat_handler.engine_client = engine + chat_handler._diffusion_engine = None + app.state.openai_serving_chat = chat_handler + app.state.engine_client = engine app.state.stage_configs = [ SimpleNamespace(stage_type="llm"), SimpleNamespace(stage_type="diffusion"), @@ -233,16 +316,50 @@ def async_omni_stage_configs_only_client(): """Create test client with refactored AsyncOmni compatibility surface only.""" from fastapi import FastAPI + from vllm_omni.entrypoints.async_omni import AsyncOmni from vllm_omni.entrypoints.openai.api_server import router + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + class FakeAsyncOmniClass(AsyncOmni): + def __init__(self): + self.stage_configs = [ + SimpleNamespace(stage_type="llm", is_comprehension=True), + SimpleNamespace(stage_type="diffusion", is_comprehension=False), + ] + self.default_sampling_params_list = [ + SamplingParams(temperature=0.1), + OmniDiffusionSamplingParams(), + ] + self.captured_sampling_params_list = None + self.captured_prompt = None + self._images = [Image.new("RGB", (64, 64), color="green")] + self.od_config = SimpleNamespace(supports_multimodal_inputs=True) + + async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None): + if sampling_params_list is not None: + self.captured_sampling_params_list = sampling_params_list + else: + self.captured_sampling_params_list = [sampling_params] + self.captured_prompt = prompt + images = [img.copy() for img in self._images] + yield MockGenerationResult(images) + + def __class_getitem__(cls, item): + return cls + + def get_diffusion_od_config(self): + return self.od_config app = FastAPI() app.include_router(router) - engine = FakeAsyncOmni() + engine = FakeAsyncOmniClass() assert not hasattr(engine, "stage_list") app.state.engine_client = engine - # Intentionally do not populate app.state.stage_configs. Refactored - # AsyncOmni exposes stage_configs on the engine instance. + chat_handler = object.__new__(OmniOpenAIServingChat) + chat_handler.engine_client = engine + chat_handler._diffusion_engine = None + app.state.openai_serving_chat = chat_handler app.state.args = Namespace( default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}', max_generated_image_size=1024 * 1792, @@ -306,6 +423,9 @@ def test_models_endpoint_no_engine(): def test_generate_single_image(test_client): """Test generating a single image""" + # Single-stage path should not require openai_serving_chat. + assert not hasattr(test_client.app.state, "openai_serving_chat") + response = test_client.post( "/v1/images/generations", json={ @@ -374,6 +494,43 @@ def test_generate_images_async_omni_stage_configs_only(async_omni_stage_configs_ assert captured[1].seed == 11 +def test_multistage_images_async_omni_construction(async_omni_test_client): + """Regression: multistage image generation builds the expected chat-style payload.""" + response = async_omni_test_client.post( + "/v1/images/generations", + json={ + "prompt": "a cat", + "n": 2, + "size": "128x256", + "seed": 7, + "num_inference_steps": 12, + "guidance_scale": 6.5, + }, + ) + assert response.status_code == 200 + + engine = async_omni_test_client.app.state.engine_client + captured_prompt = engine.captured_prompt + assert captured_prompt["prompt"] == "a cat" + assert captured_prompt["modalities"] == ["image"] + assert captured_prompt["mm_processor_kwargs"] == { + "target_h": 256, + "target_w": 128, + } + + captured = engine.captured_sampling_params_list + assert captured is not None + assert len(captured) == 2 + assert captured[0].temperature == 0.1 + assert captured[0].seed == 7 + assert captured[1].num_outputs_per_prompt == 2 + assert captured[1].width == 128 + assert captured[1].height == 256 + assert captured[1].seed == 7 + assert captured[1].num_inference_steps == 12 + assert captured[1].guidance_scale == 6.5 + + def test_image_edits_async_omni_stage_configs_only(async_omni_stage_configs_only_client): """Regression: image edits accepts refactored AsyncOmni without stage_list.""" img_bytes = make_test_image_bytes((16, 16)) diff --git a/tests/entrypoints/openai_api/test_serving_chat_multistage_generation.py b/tests/entrypoints/openai_api/test_serving_chat_multistage_generation.py new file mode 100644 index 00000000000..a9b9f53ba8a --- /dev/null +++ b/tests/entrypoints/openai_api/test_serving_chat_multistage_generation.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Regression tests for multistage diffusion generation input construction.""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +from PIL import Image +from vllm.sampling_params import SamplingParams + +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +@pytest.fixture +def serving_chat(): + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + return object.__new__(OmniOpenAIServingChat) + + +def test_build_multistage_generation_inputs_applies_stage_specific_overrides(serving_chat): + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + engine = SimpleNamespace( + stage_configs=[ + SimpleNamespace(stage_type="llm", is_comprehension=True), + SimpleNamespace(stage_type="diffusion", is_comprehension=False), + SimpleNamespace(stage_type="diffusion", is_comprehension=False), + ], + default_sampling_params_list=[ + SamplingParams(temperature=0.2, seed=11), + OmniDiffusionSamplingParams(), + OmniDiffusionSamplingParams(), + ], + ) + reference_image = Image.new("RGB", (24, 24), color="green") + extra_body = { + "negative_prompt": "blurry", + "num_inference_steps": 28, + "guidance_scale": 7.5, + "true_cfg_scale": 5.0, + "guidance_scale_2": 1.25, + "layers": 6, + "resolution": 1024, + "lora": {"name": "adapter-a", "path": "/tmp/adapter-a", "scale": 0.6}, + } + gen_params = OmniDiffusionSamplingParams(height=768, width=1024, seed=0, num_outputs_per_prompt=2) + + engine_prompt, sampling_params_list = OmniOpenAIServingChat._build_multistage_generation_inputs( + serving_chat, + engine=engine, + prompt="draw a robot", + extra_body=extra_body, + reference_images=[reference_image], + gen_params=gen_params, + ) + + assert engine_prompt["prompt"] == "draw a robot" + assert engine_prompt["modalities"] == ["img2img"] + assert engine_prompt["negative_prompt"] == "blurry" + assert engine_prompt["mm_processor_kwargs"] == {"target_h": 768, "target_w": 1024} + assert engine_prompt["multi_modal_data"]["img2img"].size == (24, 24) + + assert len(sampling_params_list) == 3 + assert sampling_params_list[0].temperature == 0.2 + assert sampling_params_list[0].seed == 0 + assert sampling_params_list[1].height == 768 + assert sampling_params_list[1].width == 1024 + assert sampling_params_list[1].seed == 0 + assert sampling_params_list[1].num_inference_steps == 28 + assert sampling_params_list[1].guidance_scale == 7.5 + assert sampling_params_list[1].num_outputs_per_prompt == 2 + assert sampling_params_list[1].true_cfg_scale == 5.0 + assert sampling_params_list[1].lora_request.name == "adapter-a" + assert sampling_params_list[2].height == 768 + assert sampling_params_list[2].width == 1024 + assert sampling_params_list[2].num_inference_steps == 28 + assert engine.default_sampling_params_list[1].height is None + assert engine.default_sampling_params_list[2].resolution == 640 diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 1e45758368d..a5ef7e32f78 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -121,6 +121,7 @@ logger = init_logger(__name__) router = APIRouter() +MAX_UINT32_SEED = 2**32 - 1 profiler_router = APIRouter() @@ -1312,6 +1313,57 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) ) try: + # Unify request construction for any multi-stage pipeline to avoid + # divergence between /v1/images and /v1/chat/completions. + if len(stage_configs) > 1: + chat_handler = getattr(raw_request.app.state, "openai_serving_chat", None) + if chat_handler is None: + logger.warning("openai_serving_chat is not initialized for multi-stage /v1/images/generations") + raise HTTPException( + status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, + detail="openai_serving_chat is not initialized for multi-stage image generation.", + ) + + effective_seed = request.seed if request.seed is not None else random.randint(0, MAX_UINT32_SEED) + extra_body: dict[str, Any] = { + "seed": effective_seed, + "num_outputs_per_prompt": request.n, + } + if request.size is not None: + parse_size(request.size) + width, height = parse_size(request.size) + app_state_args = getattr(raw_request.app.state, "args", None) + _check_max_generated_image_size(app_state_args, width, height) + extra_body["size"] = request.size + if request.negative_prompt is not None: + extra_body["negative_prompt"] = request.negative_prompt + if request.num_inference_steps is not None: + extra_body["num_inference_steps"] = request.num_inference_steps + if request.guidance_scale is not None: + extra_body["guidance_scale"] = request.guidance_scale + if request.true_cfg_scale is not None: + extra_body["true_cfg_scale"] = request.true_cfg_scale + if request.generator_device is not None: + extra_body["generator_device"] = request.generator_device + if request.lora is not None: + # Keep /images validation semantics: invalid LoRA should fail with 400. + _parse_lora_request(request.lora) + extra_body["lora"] = request.lora + + generation_result = await chat_handler.generate_diffusion_images( + prompt=request.prompt, + extra_body=extra_body, + request_id=f"img_gen-{random_uuid()}", + ) + if isinstance(generation_result, ErrorResponse): + return JSONResponse( + status_code=generation_result.error.code if generation_result.error else 400, + content=generation_result.model_dump(), + ) + flat_images, _, _ = generation_result + image_data = [ImageData(b64_json=encode_image_base64(img), revised_prompt=None) for img in flat_images] + return ImageGenerationResponse(created=int(time.time()), data=image_data) + # Build params - pass through user values directly prompt: OmniTextPrompt = {"prompt": request.prompt} if request.negative_prompt is not None: @@ -1352,7 +1404,7 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) # This fixes issues where using the default global generator # might produce blurry images in some environments. _update_if_not_none( - gen_params, "seed", request.seed if request.seed is not None else random.randint(0, 2**32 - 1) + gen_params, "seed", request.seed if request.seed is not None else random.randint(0, MAX_UINT32_SEED) ) _update_if_not_none(gen_params, "generator_device", request.generator_device) _update_if_not_none(gen_params, "layers", request.layers) @@ -1561,7 +1613,7 @@ async def edit_images( # a proper generator is initialized in the backend. # This fixes issues where using the default global generator # might produce blurry images in some environments. - _update_if_not_none(gen_params, "seed", seed if seed is not None else random.randint(0, 2**32 - 1)) + _update_if_not_none(gen_params, "seed", seed if seed is not None else random.randint(0, MAX_UINT32_SEED)) _update_if_not_none(gen_params, "generator_device", generator_device) _update_if_not_none(gen_params, "layers", layers) _update_if_not_none(gen_params, "resolution", resolution) diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 28d6ef277b3..f1e6d56472b 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -733,6 +733,12 @@ def _apply_request_overrides( return params + @staticmethod + def _set_if_supported(obj: Any, **kwargs: Any) -> None: + for key, value in kwargs.items(): + if value is not None and hasattr(obj, key): + setattr(obj, key, value) + def _build_sampling_params_list_from_request( self, request: ChatCompletionRequest, @@ -2057,6 +2063,254 @@ def _create_image_choice( return choices # ==================== Diffusion Mode Methods ==================== + def _build_multistage_generation_inputs( + self, + *, + engine: AsyncOmni, + prompt: str, + extra_body: dict[str, Any], + reference_images: list[Image.Image], + gen_params: OmniDiffusionSamplingParams, + ) -> tuple[OmniTextPrompt, list[Any]]: + """Build the shared multistage generation prompt and stage params.""" + stage_configs = getattr(engine, "stage_configs", None) or [] + default_params_list = list(getattr(engine, "default_sampling_params_list", []) or []) + + height = gen_params.height + width = gen_params.width + seed = gen_params.seed + generator_device = gen_params.generator_device + num_outputs_per_prompt = gen_params.num_outputs_per_prompt + num_inference_steps = extra_body.get("num_inference_steps") + guidance_scale = extra_body.get("guidance_scale") + true_cfg_scale = extra_body.get("true_cfg_scale") or extra_body.get("cfg_scale") + negative_prompt = extra_body.get("negative_prompt") + num_frames = extra_body.get("num_frames") + guidance_scale_2 = extra_body.get("guidance_scale_2") + lora_body = extra_body.get("lora") + layers = extra_body.get("layers") + resolution = extra_body.get("resolution") + + engine_prompt_data: dict[str, Any] | None = None + modalities = ["image"] + if reference_images: + if len(reference_images) == 1: + engine_prompt_data = {"img2img": reference_images[0]} + modalities = ["img2img"] + else: + engine_prompt_data = {"image": reference_images} + + engine_prompt: OmniTextPrompt = {"prompt": prompt} + engine_prompt["modalities"] = modalities + if negative_prompt is not None: + engine_prompt["negative_prompt"] = negative_prompt + + mm_processor_kwargs: dict[str, Any] = {} + if height is not None: + mm_processor_kwargs["target_h"] = height + if width is not None: + mm_processor_kwargs["target_w"] = width + if mm_processor_kwargs: + engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs + if engine_prompt_data is not None: + engine_prompt["multi_modal_data"] = engine_prompt_data + + comprehension_idx = None + for idx, stage in enumerate(stage_configs): + if getattr(stage, "is_comprehension", False): + comprehension_idx = idx + break + + sampling_params_list: list[Any] = [] + for idx, stage_cfg in enumerate(stage_configs): + stage_type = get_stage_type(stage_cfg) + if idx < len(default_params_list): + default_stage_params = default_params_list[idx] + if hasattr(default_stage_params, "clone"): + try: + default_stage_params = default_stage_params.clone() + except Exception: + pass + elif stage_type == "diffusion": + default_stage_params = gen_params.clone() + else: + default_stage_params = SamplingParams() + + if ( + comprehension_idx is not None + and idx == comprehension_idx + and seed is not None + and hasattr(default_stage_params, "seed") + ): + default_stage_params.seed = seed + + if stage_type == "diffusion": + self._set_if_supported( + default_stage_params, + height=height, + width=width, + seed=seed, + generator_device=generator_device, + num_outputs_per_prompt=num_outputs_per_prompt, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + true_cfg_scale=true_cfg_scale, + num_frames=num_frames, + guidance_scale_2=guidance_scale_2, + layers=layers, + resolution=resolution, + ) + if lora_body and isinstance(lora_body, dict): + try: + lora_req, lora_scale = parse_lora_request(lora_body) + if lora_req is not None: + default_stage_params.lora_request = lora_req + if lora_scale is not None: + default_stage_params.lora_scale = lora_scale + except Exception as e: # pragma: no cover - safeguard + logger.warning("Failed to parse LoRA request: %s", e) + + sampling_params_list.append(default_stage_params) + + return engine_prompt, sampling_params_list + + async def generate_diffusion_images( + self, + *, + prompt: str, + extra_body: dict[str, Any] | None = None, + reference_images: list[str] | None = None, + request_id: str | None = None, + ) -> tuple[list[Image.Image], dict[str, Any], float] | ErrorResponse: + """Generate diffusion images and return raw images plus generation stats.""" + if request_id is None: + request_id = f"chatcmpl-{uuid.uuid4().hex[:16]}" + if extra_body is None: + extra_body = {} + if reference_images is None: + reference_images = [] + + engine = self._diffusion_engine if self._diffusion_engine is not None else self.engine_client + + height = extra_body.get("height") + width = extra_body.get("width") + if "size" in extra_body: + try: + size_str = extra_body["size"] + if isinstance(size_str, str) and "x" in size_str.lower(): + w, h = size_str.lower().split("x") + width, height = int(w), int(h) + except ValueError: + logger.warning("Invalid size format: %s", extra_body.get("size")) + + seed = extra_body.get("seed") + generator_device = extra_body.get("generator_device") + negative_prompt = extra_body.get("negative_prompt") + num_outputs_per_prompt = extra_body.get("num_outputs_per_prompt", 1) + lora_body = extra_body.get("lora") + + pil_images: list[Image.Image] = [] + for img_b64 in reference_images: + try: + img_bytes = base64.b64decode(img_b64) + pil_images.append(Image.open(BytesIO(img_bytes))) + except Exception as e: + logger.warning("Failed to decode reference image: %s", e) + + gen_params = OmniDiffusionSamplingParams( + height=height, + width=width, + num_outputs_per_prompt=num_outputs_per_prompt, + seed=seed, + ) + self._set_if_supported( + gen_params, + generator_device=generator_device, + num_inference_steps=extra_body.get("num_inference_steps"), + guidance_scale=extra_body.get("guidance_scale"), + true_cfg_scale=extra_body.get("true_cfg_scale") or extra_body.get("cfg_scale"), + num_frames=extra_body.get("num_frames"), + guidance_scale_2=extra_body.get("guidance_scale_2"), + layers=extra_body.get("layers"), + resolution=extra_body.get("resolution"), + ) + + if lora_body and isinstance(lora_body, dict): + try: + lora_req, lora_scale = parse_lora_request(lora_body) + if lora_req is not None: + gen_params.lora_request = lora_req + if lora_scale is not None: + gen_params.lora_scale = lora_scale + except Exception as e: # pragma: no cover - safeguard + logger.warning("Failed to parse LoRA request: %s", e) + + gen_prompt: OmniTextPrompt = { + "prompt": prompt, + "negative_prompt": negative_prompt, + } + if pil_images: + if len(pil_images) == 1: + gen_prompt["multi_modal_data"] = {"image": pil_images[0]} + else: + od_config = getattr(engine, "od_config", None) + supports_multimodal_inputs = getattr(od_config, "supports_multimodal_inputs", False) + if od_config is None: + supports_multimodal_inputs = True + if supports_multimodal_inputs: + gen_prompt["multi_modal_data"] = {"image": pil_images} + else: + return self._create_error_response( + "Multiple input images are not supported by the current diffusion model. " + "For multi-image editing, start the server with Qwen-Image-Edit-2509 " + "and send multiple images in the user message content.", + status_code=400, + ) + + if isinstance(engine, AsyncOmni): + diffusion_engine = cast(AsyncOmni, engine) + stage_configs = getattr(diffusion_engine, "stage_configs", None) or [] + if len(stage_configs) > 1: + engine_prompt, sampling_params_list = self._build_multistage_generation_inputs( + engine=diffusion_engine, + prompt=prompt, + extra_body=extra_body, + reference_images=pil_images, + gen_params=gen_params, + ) + else: + engine_prompt = gen_prompt + sampling_params_list = [gen_params] + + result = None + async for output in diffusion_engine.generate( + prompt=engine_prompt, + sampling_params_list=sampling_params_list, + request_id=request_id, + ): + result = output + if result is None: + return self._create_error_response("No output generated from AsyncOmni", status_code=500) + else: + result = await engine.generate( + prompt=gen_prompt, + sampling_params=gen_params, + request_id=request_id, + ) + + images = getattr(result.request_output, "images", []) + stage_durations = result.stage_durations + peak_memory_mb = result.peak_memory_mb + + flat_images: list[Image.Image] = [] + for item in images: + if isinstance(item, list): + flat_images.extend(item) + else: + flat_images.append(item) + + return flat_images, stage_durations, peak_memory_mb + async def _create_diffusion_chat_completion( self, request: ChatCompletionRequest, @@ -2234,8 +2488,8 @@ async def _create_diffusion_chat_completion( if hasattr(default_stage_params, "clone"): try: default_stage_params = default_stage_params.clone() - except Exception: - pass + except Exception as e: + logger.warning("Failed to clone default params for stage %d: %s", idx, e) sampling_params_list.append(default_stage_params) if not sampling_params_list: