Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions vllm_omni/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,6 +1533,12 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request)
# Keep /images validation semantics: invalid LoRA should fail with 400.
_parse_lora_request(request.lora)
extra_body["lora"] = request.lora
if request.bot_task is not None:
extra_body["bot_task"] = request.bot_task
if request.use_system_prompt is not None:
extra_body["use_system_prompt"] = request.use_system_prompt
if request.system_prompt is not None:
extra_body["system_prompt"] = request.system_prompt

generation_result = await chat_handler.generate_diffusion_images(
prompt=request.prompt,
Expand All @@ -1544,8 +1550,9 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request)
status_code=generation_result.error.code if generation_result.error else 400,
content=generation_result.model_dump(),
)
flat_images, _, _ = generation_result
flat_images, _, _, _ = generation_result
image_data = [ImageData(b64_json=encode_image_base64(img), revised_prompt=None) for img in flat_images]

return ImageGenerationResponse(created=int(time.time()), data=image_data)

# Build params - pass through user values directly
Expand All @@ -1558,6 +1565,8 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request)
extra_args["use_system_prompt"] = request.use_system_prompt
if request.system_prompt is not None:
extra_args["system_prompt"] = request.system_prompt
if request.bot_task is not None:
extra_args["bot_task"] = request.bot_task
if extra_args:
gen_params.extra_args = extra_args
# Parse per-request LoRA (compatible with chat's extra_body.lora shape).
Expand Down Expand Up @@ -1725,6 +1734,7 @@ async def edit_images(
)
try:
# 2. Build prompt & images params
cot_output = None
prompt: OmniTextPrompt = {"prompt": prompt}
if negative_prompt is not None:
prompt["negative_prompt"] = negative_prompt
Expand Down Expand Up @@ -1935,7 +1945,7 @@ async def edit_images(
status_code=generation_result.error.code if generation_result.error else 400,
detail=generation_result.message,
)
images, _, _ = generation_result
images, _, _, cot_output = generation_result
else:
# Single-stage diffusion: use the direct path.
result = await _generate_with_async_omni(
Expand Down Expand Up @@ -1965,6 +1975,7 @@ async def edit_images(
data=image_data,
output_format=output_format,
size=size_str,
cot_output=cot_output,
)

except (EngineGenerateError, EngineDeadError) as exc:
Expand Down
10 changes: 10 additions & 0 deletions vllm_omni/entrypoints/openai/protocol/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ class ImageGenerationRequest(BaseModel):

# Required fields
prompt: str = Field(..., description="Text description of the desired image(s)")
bot_task: str | None = Field(
None,
description="Task mode for the model (e.g., 'cot' enables chain-of-thought generation). "
"Only supported by specific diffusion models.",
)

# OpenAI standard fields
model: str | None = Field(
Expand Down Expand Up @@ -165,3 +170,8 @@ class ImageGenerationResponse(BaseModel):
data: list[ImageData] = Field(..., description="Array of generated images")
output_format: str = Field(None, description="The output format of the image generation")
size: str = Field(None, description="The size of the image generated")
cot_output: str | None = Field(
None,
description="Chain-of-thought text output from the AR stage. "
"Only present for image editing (IT2I) with CoT-enabled models.",
)
24 changes: 19 additions & 5 deletions vllm_omni/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2248,7 +2248,7 @@ def _build_multistage_generation_inputs(
layers = extra_body.get("layers")
resolution = extra_body.get("resolution")
bot_task = extra_body.get("bot_task")
sys_type = extra_body.get("sys_type")
use_system_prompt = extra_body.get("use_system_prompt") or extra_body.get("sys_type")
custom_system_prompt = extra_body.get("system_prompt")

engine_prompt_data: dict[str, Any] | None = None
Expand All @@ -2262,24 +2262,27 @@ def _build_multistage_generation_inputs(

prompt_token_ids: list[int] | None = None
system_prompt_type: str | None = None
if bot_task is not None or sys_type is not None or custom_system_prompt is not None:
build_kwargs: dict[str, Any] = {}

if bot_task is not None or use_system_prompt is not None or custom_system_prompt is not None:
from vllm_omni.diffusion.models.hunyuan_image3.prompt_utils import (
build_prompt,
build_prompt_tokens,
)

build_kwargs: dict[str, Any] = {
"task": "it2i" if reference_images else "t2i",
"sys_type": sys_type,
"sys_type": use_system_prompt,
"custom_system_prompt": custom_system_prompt,
"num_images": len(reference_images) if reference_images else 1,
}

if bot_task is not None:
build_kwargs["bot_task"] = bot_task
elif "bot_task" in extra_body:
# Explicit None from the caller is plain-mode; omitted lets
# each task fall back to its default trigger.
build_kwargs["bot_task"] = None
build_kwargs["bot_task"] = extra_body["bot_task"]
if tokenizer is not None:
# Feed segment-tokenized prompt_token_ids so AR matches HF
# apply_chat_template byte-for-byte (engine BPE would merge
Expand Down Expand Up @@ -2515,6 +2518,17 @@ async def generate_diffusion_images(
images = getattr(result.request_output, "images", [])
stage_durations = result.stage_durations
peak_memory_mb = result.peak_memory_mb
cot_output = None

req_out = getattr(result, "request_output", None)
if req_out:
prompt_obj = getattr(req_out, "prompt", None)
if isinstance(prompt_obj, dict):
extra = prompt_obj.get("extra", {})
if isinstance(extra, dict):
ar_text = extra.get("ar_generated_text")
if isinstance(ar_text, str) and ar_text.strip():
cot_output = ar_text

flat_images: list[Image.Image] = []
for item in images:
Expand All @@ -2523,7 +2537,7 @@ async def generate_diffusion_images(
else:
flat_images.append(item)

return flat_images, stage_durations, peak_memory_mb
return flat_images, stage_durations, peak_memory_mb, cot_output

async def _create_diffusion_chat_completion(
self,
Expand Down
Loading