diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index c1467f7190a..ddc6e36815a 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -1533,6 +1533,12 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) # Keep /images validation semantics: invalid LoRA should fail with 400. _parse_lora_request(request.lora) extra_body["lora"] = request.lora + if request.bot_task is not None: + extra_body["bot_task"] = request.bot_task + if request.use_system_prompt is not None: + extra_body["use_system_prompt"] = request.use_system_prompt + if request.system_prompt is not None: + extra_body["system_prompt"] = request.system_prompt generation_result = await chat_handler.generate_diffusion_images( prompt=request.prompt, @@ -1544,8 +1550,9 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) status_code=generation_result.error.code if generation_result.error else 400, content=generation_result.model_dump(), ) - flat_images, _, _ = generation_result + flat_images, _, _, _ = generation_result image_data = [ImageData(b64_json=encode_image_base64(img), revised_prompt=None) for img in flat_images] + return ImageGenerationResponse(created=int(time.time()), data=image_data) # Build params - pass through user values directly @@ -1558,6 +1565,8 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) extra_args["use_system_prompt"] = request.use_system_prompt if request.system_prompt is not None: extra_args["system_prompt"] = request.system_prompt + if request.bot_task is not None: + extra_args["bot_task"] = request.bot_task if extra_args: gen_params.extra_args = extra_args # Parse per-request LoRA (compatible with chat's extra_body.lora shape). @@ -1725,6 +1734,7 @@ async def edit_images( ) try: # 2. Build prompt & images params + cot_output = None prompt: OmniTextPrompt = {"prompt": prompt} if negative_prompt is not None: prompt["negative_prompt"] = negative_prompt @@ -1935,7 +1945,7 @@ async def edit_images( status_code=generation_result.error.code if generation_result.error else 400, detail=generation_result.message, ) - images, _, _ = generation_result + images, _, _, cot_output = generation_result else: # Single-stage diffusion: use the direct path. result = await _generate_with_async_omni( @@ -1965,6 +1975,7 @@ async def edit_images( data=image_data, output_format=output_format, size=size_str, + cot_output=cot_output, ) except (EngineGenerateError, EngineDeadError) as exc: diff --git a/vllm_omni/entrypoints/openai/protocol/images.py b/vllm_omni/entrypoints/openai/protocol/images.py index 0fb22a548cf..dbf4c24b348 100644 --- a/vllm_omni/entrypoints/openai/protocol/images.py +++ b/vllm_omni/entrypoints/openai/protocol/images.py @@ -32,6 +32,11 @@ class ImageGenerationRequest(BaseModel): # Required fields prompt: str = Field(..., description="Text description of the desired image(s)") + bot_task: str | None = Field( + None, + description="Task mode for the model (e.g., 'cot' enables chain-of-thought generation). " + "Only supported by specific diffusion models.", + ) # OpenAI standard fields model: str | None = Field( @@ -165,3 +170,8 @@ class ImageGenerationResponse(BaseModel): data: list[ImageData] = Field(..., description="Array of generated images") output_format: str = Field(None, description="The output format of the image generation") size: str = Field(None, description="The size of the image generated") + cot_output: str | None = Field( + None, + description="Chain-of-thought text output from the AR stage. " + "Only present for image editing (IT2I) with CoT-enabled models.", + ) diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 2c375fa2928..4f9e7dcc8b8 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -2248,7 +2248,7 @@ def _build_multistage_generation_inputs( layers = extra_body.get("layers") resolution = extra_body.get("resolution") bot_task = extra_body.get("bot_task") - sys_type = extra_body.get("sys_type") + use_system_prompt = extra_body.get("use_system_prompt") or extra_body.get("sys_type") custom_system_prompt = extra_body.get("system_prompt") engine_prompt_data: dict[str, Any] | None = None @@ -2262,7 +2262,9 @@ def _build_multistage_generation_inputs( prompt_token_ids: list[int] | None = None system_prompt_type: str | None = None - if bot_task is not None or sys_type is not None or custom_system_prompt is not None: + build_kwargs: dict[str, Any] = {} + + if bot_task is not None or use_system_prompt is not None or custom_system_prompt is not None: from vllm_omni.diffusion.models.hunyuan_image3.prompt_utils import ( build_prompt, build_prompt_tokens, @@ -2270,16 +2272,17 @@ def _build_multistage_generation_inputs( build_kwargs: dict[str, Any] = { "task": "it2i" if reference_images else "t2i", - "sys_type": sys_type, + "sys_type": use_system_prompt, "custom_system_prompt": custom_system_prompt, "num_images": len(reference_images) if reference_images else 1, } + if bot_task is not None: build_kwargs["bot_task"] = bot_task elif "bot_task" in extra_body: # Explicit None from the caller is plain-mode; omitted lets # each task fall back to its default trigger. - build_kwargs["bot_task"] = None + build_kwargs["bot_task"] = extra_body["bot_task"] if tokenizer is not None: # Feed segment-tokenized prompt_token_ids so AR matches HF # apply_chat_template byte-for-byte (engine BPE would merge @@ -2515,6 +2518,17 @@ async def generate_diffusion_images( images = getattr(result.request_output, "images", []) stage_durations = result.stage_durations peak_memory_mb = result.peak_memory_mb + cot_output = None + + req_out = getattr(result, "request_output", None) + if req_out: + prompt_obj = getattr(req_out, "prompt", None) + if isinstance(prompt_obj, dict): + extra = prompt_obj.get("extra", {}) + if isinstance(extra, dict): + ar_text = extra.get("ar_generated_text") + if isinstance(ar_text, str) and ar_text.strip(): + cot_output = ar_text flat_images: list[Image.Image] = [] for item in images: @@ -2523,7 +2537,7 @@ async def generate_diffusion_images( else: flat_images.append(item) - return flat_images, stage_durations, peak_memory_mb + return flat_images, stage_durations, peak_memory_mb, cot_output async def _create_diffusion_chat_completion( self,