Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions vllm_omni/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2129,6 +2129,7 @@ def _build_multistage_generation_inputs(
extra_body: dict[str, Any],
reference_images: list[Image.Image],
gen_params: OmniDiffusionSamplingParams,
tokenizer: Any = None,
) -> tuple[OmniTextPrompt, list[Any]]:
"""Build the shared multistage generation prompt and stage params."""
stage_configs = getattr(engine, "stage_configs", None) or []
Expand Down Expand Up @@ -2159,16 +2160,26 @@ def _build_multistage_generation_inputs(
else:
engine_prompt_data = {"image": reference_images}

prompt_token_ids: list[int] | None = None
if bot_task:
from vllm_omni.diffusion.models.hunyuan_image3.prompt_utils import build_prompt
from vllm_omni.diffusion.models.hunyuan_image3.prompt_utils import (
build_prompt,
build_prompt_tokens,
)

if tokenizer is not None:
prompt_token_ids = build_prompt_tokens(prompt, tokenizer, task=bot_task)
else:
prompt = build_prompt(prompt, task=bot_task)
Comment thread
Bounty-hunter marked this conversation as resolved.

prompt = build_prompt(prompt, task=bot_task)
if reference_images and len(reference_images) == 1:
engine_prompt_data = {"image": reference_images[0]}
modalities = ["image"]

engine_prompt: OmniTextPrompt = {"prompt": prompt}
engine_prompt["modalities"] = modalities
if prompt_token_ids is not None:
engine_prompt["prompt_token_ids"] = prompt_token_ids
if negative_prompt is not None:
engine_prompt["negative_prompt"] = negative_prompt

Expand Down Expand Up @@ -2337,12 +2348,20 @@ async def generate_diffusion_images(
diffusion_engine = cast(AsyncOmni, engine)
stage_configs = getattr(diffusion_engine, "stage_configs", None) or []
if len(stage_configs) > 1:
tokenizer = None
get_tok = getattr(diffusion_engine, "get_tokenizer", None)
if get_tok is not None:
try:
tokenizer = await get_tok()
except Exception as exc:
logger.warning("get_tokenizer failed: %s", exc)
engine_prompt, sampling_params_list = self._build_multistage_generation_inputs(
engine=diffusion_engine,
prompt=prompt,
extra_body=extra_body,
reference_images=pil_images,
gen_params=gen_params,
tokenizer=tokenizer,
)
else:
engine_prompt = gen_prompt
Expand Down
Loading