Skip to content
177 changes: 167 additions & 10 deletions tests/entrypoints/openai_api/test_image_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,31 +106,40 @@ def test_encode_image_base64():


class MockGenerationResult:
"""Mock result object from AsyncOmni.generate()"""
"""Mock result object compatible with current diffusion output shape."""

def __init__(self, images):
self.images = images
self.request_output = SimpleNamespace(images=images)
self.stage_durations = {}
self.peak_memory_mb = 0.0


class FakeAsyncOmni:
"""Fake AsyncOmni that yields a single diffusion output."""

def __init__(self, images=None):
self.stage_configs = [
SimpleNamespace(stage_type="llm"),
SimpleNamespace(stage_type="diffusion"),
SimpleNamespace(stage_type="llm", is_comprehension=True),
SimpleNamespace(stage_type="diffusion", is_comprehension=False),
]
self.default_sampling_params_list = [SamplingParams(temperature=0.1), OmniDiffusionSamplingParams()]
self.captured_sampling_params_list = None
self.captured_prompt = None
self._images = images or [Image.new("RGB", (64, 64), color="green")]

async def generate(self, prompt, request_id, sampling_params_list):
self.captured_sampling_params_list = sampling_params_list
async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None):
if sampling_params_list is not None:
self.captured_sampling_params_list = sampling_params_list
else:
self.captured_sampling_params_list = [sampling_params]
self.captured_prompt = prompt
images = [img.copy() for img in self._images]
yield MockGenerationResult(images)

def __class_getitem__(cls, item):
return cls


@pytest.fixture
def mock_async_diffusion(mocker: MockerFixture):
Expand Down Expand Up @@ -189,12 +198,49 @@ def async_omni_test_client():
"""Create test client with mocked AsyncOmni engine."""
from fastapi import FastAPI

from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.openai.api_server import router
from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat

class FakeAsyncOmniClass(AsyncOmni):
def __init__(self):
self.stage_configs = [
SimpleNamespace(stage_type="llm", is_comprehension=True),
SimpleNamespace(stage_type="diffusion", is_comprehension=False),
]
self.default_sampling_params_list = [
SamplingParams(temperature=0.1),
OmniDiffusionSamplingParams(),
]
self.captured_sampling_params_list = None
self.captured_prompt = None
self._images = [Image.new("RGB", (64, 64), color="green")]
self.od_config = SimpleNamespace(supports_multimodal_inputs=True)

async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None):
if sampling_params_list is not None:
self.captured_sampling_params_list = sampling_params_list
else:
self.captured_sampling_params_list = [sampling_params]
self.captured_prompt = prompt
images = [img.copy() for img in self._images]
yield MockGenerationResult(images)

def __class_getitem__(cls, item):
return cls

def get_diffusion_od_config(self):
return self.od_config

app = FastAPI()
app.include_router(router)

app.state.engine_client = FakeAsyncOmni()
engine = FakeAsyncOmniClass()
chat_handler = object.__new__(OmniOpenAIServingChat)
chat_handler.engine_client = engine
chat_handler._diffusion_engine = None
app.state.openai_serving_chat = chat_handler
app.state.engine_client = engine
app.state.stage_configs = [
SimpleNamespace(stage_type="llm"),
SimpleNamespace(stage_type="diffusion"),
Expand All @@ -211,12 +257,49 @@ def async_omni_rgba_test_client():
"""Create test client with mocked AsyncOmni engine returning RGBA output."""
from fastapi import FastAPI

from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.openai.api_server import router
from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat

class FakeAsyncOmniClass(AsyncOmni):
def __init__(self):
self.stage_configs = [
SimpleNamespace(stage_type="llm", is_comprehension=True),
SimpleNamespace(stage_type="diffusion", is_comprehension=False),
]
self.default_sampling_params_list = [
SamplingParams(temperature=0.1),
OmniDiffusionSamplingParams(),
]
self.captured_sampling_params_list = None
self.captured_prompt = None
self._images = [Image.new("RGBA", (64, 64), color=(0, 255, 0, 128))]
self.od_config = SimpleNamespace(supports_multimodal_inputs=True)

async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None):
if sampling_params_list is not None:
self.captured_sampling_params_list = sampling_params_list
else:
self.captured_sampling_params_list = [sampling_params]
self.captured_prompt = prompt
images = [img.copy() for img in self._images]
yield MockGenerationResult(images)

def __class_getitem__(cls, item):
return cls

def get_diffusion_od_config(self):
return self.od_config

app = FastAPI()
app.include_router(router)

app.state.engine_client = FakeAsyncOmni(images=[Image.new("RGBA", (64, 64), color=(0, 255, 0, 128))])
engine = FakeAsyncOmniClass()
chat_handler = object.__new__(OmniOpenAIServingChat)
chat_handler.engine_client = engine
chat_handler._diffusion_engine = None
app.state.openai_serving_chat = chat_handler
app.state.engine_client = engine
app.state.stage_configs = [
SimpleNamespace(stage_type="llm"),
SimpleNamespace(stage_type="diffusion"),
Expand All @@ -233,16 +316,50 @@ def async_omni_stage_configs_only_client():
"""Create test client with refactored AsyncOmni compatibility surface only."""
from fastapi import FastAPI

from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.openai.api_server import router
from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat

class FakeAsyncOmniClass(AsyncOmni):
def __init__(self):
self.stage_configs = [
SimpleNamespace(stage_type="llm", is_comprehension=True),
SimpleNamespace(stage_type="diffusion", is_comprehension=False),
]
self.default_sampling_params_list = [
SamplingParams(temperature=0.1),
OmniDiffusionSamplingParams(),
]
self.captured_sampling_params_list = None
self.captured_prompt = None
self._images = [Image.new("RGB", (64, 64), color="green")]
self.od_config = SimpleNamespace(supports_multimodal_inputs=True)

async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None):
if sampling_params_list is not None:
self.captured_sampling_params_list = sampling_params_list
else:
self.captured_sampling_params_list = [sampling_params]
self.captured_prompt = prompt
images = [img.copy() for img in self._images]
yield MockGenerationResult(images)

def __class_getitem__(cls, item):
return cls

def get_diffusion_od_config(self):
return self.od_config

app = FastAPI()
app.include_router(router)

engine = FakeAsyncOmni()
engine = FakeAsyncOmniClass()
assert not hasattr(engine, "stage_list")
app.state.engine_client = engine
# Intentionally do not populate app.state.stage_configs. Refactored
# AsyncOmni exposes stage_configs on the engine instance.
chat_handler = object.__new__(OmniOpenAIServingChat)
chat_handler.engine_client = engine
chat_handler._diffusion_engine = None
app.state.openai_serving_chat = chat_handler
app.state.args = Namespace(
default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}',
max_generated_image_size=1024 * 1792,
Expand Down Expand Up @@ -306,6 +423,9 @@ def test_models_endpoint_no_engine():

def test_generate_single_image(test_client):
"""Test generating a single image"""
# Single-stage path should not require openai_serving_chat.
assert not hasattr(test_client.app.state, "openai_serving_chat")

response = test_client.post(
"/v1/images/generations",
json={
Expand Down Expand Up @@ -374,6 +494,43 @@ def test_generate_images_async_omni_stage_configs_only(async_omni_stage_configs_
assert captured[1].seed == 11


def test_multistage_images_async_omni_construction(async_omni_test_client):
"""Regression: multistage image generation builds the expected chat-style payload."""
response = async_omni_test_client.post(
"/v1/images/generations",
json={
"prompt": "a cat",
"n": 2,
"size": "128x256",
"seed": 7,
"num_inference_steps": 12,
"guidance_scale": 6.5,
},
)
assert response.status_code == 200

engine = async_omni_test_client.app.state.engine_client
captured_prompt = engine.captured_prompt
assert captured_prompt["prompt"] == "a cat"
assert captured_prompt["modalities"] == ["image"]
assert captured_prompt["mm_processor_kwargs"] == {
"target_h": 256,
"target_w": 128,
}

captured = engine.captured_sampling_params_list
assert captured is not None
assert len(captured) == 2
assert captured[0].temperature == 0.1
assert captured[0].seed == 7
assert captured[1].num_outputs_per_prompt == 2
assert captured[1].width == 128
assert captured[1].height == 256
assert captured[1].seed == 7
assert captured[1].num_inference_steps == 12
assert captured[1].guidance_scale == 6.5


def test_image_edits_async_omni_stage_configs_only(async_omni_stage_configs_only_client):
"""Regression: image edits accepts refactored AsyncOmni without stage_list."""
img_bytes = make_test_image_bytes((16, 16))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# SPDX-License-Identifier: Apache-2.0
"""Regression tests for multistage diffusion generation input construction."""

from __future__ import annotations

from types import SimpleNamespace

import pytest
from PIL import Image
from vllm.sampling_params import SamplingParams

from vllm_omni.inputs.data import OmniDiffusionSamplingParams

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]


@pytest.fixture
def serving_chat():
from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat

return object.__new__(OmniOpenAIServingChat)


def test_build_multistage_generation_inputs_applies_stage_specific_overrides(serving_chat):
from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat

engine = SimpleNamespace(
stage_configs=[
SimpleNamespace(stage_type="llm", is_comprehension=True),
SimpleNamespace(stage_type="diffusion", is_comprehension=False),
SimpleNamespace(stage_type="diffusion", is_comprehension=False),
],
default_sampling_params_list=[
SamplingParams(temperature=0.2, seed=11),
OmniDiffusionSamplingParams(),
OmniDiffusionSamplingParams(),
],
)
reference_image = Image.new("RGB", (24, 24), color="green")
extra_body = {
"negative_prompt": "blurry",
"num_inference_steps": 28,
"guidance_scale": 7.5,
"true_cfg_scale": 5.0,
"guidance_scale_2": 1.25,
"layers": 6,
"resolution": 1024,
"lora": {"name": "adapter-a", "path": "/tmp/adapter-a", "scale": 0.6},
}
gen_params = OmniDiffusionSamplingParams(height=768, width=1024, seed=0, num_outputs_per_prompt=2)

engine_prompt, sampling_params_list = OmniOpenAIServingChat._build_multistage_generation_inputs(
serving_chat,
engine=engine,
prompt="draw a robot",
extra_body=extra_body,
reference_images=[reference_image],
gen_params=gen_params,
)

assert engine_prompt["prompt"] == "draw a robot"
assert engine_prompt["modalities"] == ["img2img"]
assert engine_prompt["negative_prompt"] == "blurry"
assert engine_prompt["mm_processor_kwargs"] == {"target_h": 768, "target_w": 1024}
assert engine_prompt["multi_modal_data"]["img2img"].size == (24, 24)

assert len(sampling_params_list) == 3
assert sampling_params_list[0].temperature == 0.2
assert sampling_params_list[0].seed == 0
assert sampling_params_list[1].height == 768
assert sampling_params_list[1].width == 1024
assert sampling_params_list[1].seed == 0
assert sampling_params_list[1].num_inference_steps == 28
assert sampling_params_list[1].guidance_scale == 7.5
assert sampling_params_list[1].num_outputs_per_prompt == 2
assert sampling_params_list[1].true_cfg_scale == 5.0
assert sampling_params_list[1].lora_request.name == "adapter-a"
assert sampling_params_list[2].height == 768
assert sampling_params_list[2].width == 1024
assert sampling_params_list[2].num_inference_steps == 28
assert engine.default_sampling_params_list[1].height is None
assert engine.default_sampling_params_list[2].resolution == 640
Loading
Loading