diff --git a/examples/offline_inference/hunyuan_image3/README.md b/examples/offline_inference/hunyuan_image3/README.md index 6db4cbec9ed..8b90e6b7fa3 100644 --- a/examples/offline_inference/hunyuan_image3/README.md +++ b/examples/offline_inference/hunyuan_image3/README.md @@ -112,6 +112,7 @@ python end2end.py --modality text2img \ --additional-config '{"torchair_graph_config":{"enabled":true}}' ``` + ## Key Arguments | Argument | Description | @@ -123,16 +124,15 @@ python end2end.py --modality text2img \ | `--steps` | Number of diffusion inference steps for image generation. | | `--guidance-scale` | Classifier-free guidance scale for image generation. | | `--height`, `--width` | Output image size for `text2img`. | -| `--bot-task` | Prompt behavior. `auto` selects the default from `--modality`; `think` adds `<think>`; `recaption` adds `<recaption>`; `vanilla` uses the text-to-image pretrain template. | +| `--bot-task` | Override prompt mode. `none`, `think`, `recaption`, `think_recaption`, or `vanilla`. | | `--sys-type` | Override the system prompt type, for example `en_unified` or `en_vanilla`. | | `--vae-use-tiling` | Enable VAE tiling for memory reduction. | ## Notes -- `hunyuan_image3_ar.yaml` is a 4-card AR-only text/comprehension deploy. It sets `engine_output_type: text`, `final_output_type: text`, and text sampling defaults. -- `hunyuan_image3_dit.yaml` is a single-stage DiT deploy with `stage_id: 0`; it does not require stage 1 or a running AR stage. +- `hunyuan_image3_ar.yaml` is a 4-card AR-only text/comprehension deploy. +- `hunyuan_image3_dit.yaml` is a single-stage DiT deploy with `stage_id: 0`. - The old HunyuanImage3 YAMLs under `model_executor/stage_configs/` and `platforms/*/stage_configs/` have been folded into the deploy YAMLs. -- This PR does not keep the HunyuanImage3 AR-to-DiT KV reuse wiring. The deploy YAMLs describe the topology and platform settings only. 
## Prompt Format @@ -148,22 +148,8 @@ Assistant: {trigger_tag?} - `<img>`: Placeholder for each input image (single token; expanded by the multimodal pipeline). - Trigger tags: `<think>` for CoT and `<recaption>` for recaptioning, placed after `Assistant: `. -- System prompt: Auto-selected based on task. -- `t2i_vanilla` is the only task that uses the bare pretrain template without chat structure. -- The example composes the internal prompt task from `--modality` and `--bot-task` - before calling `prompt_utils`; for example, `img2text + think` becomes - `i2t_think` for prompt and stop-token lookup. +- System prompt: Auto-selected from `task` and `bot_task`. +- `bot_task='vanilla'` with `task='t2i'` uses the bare pretrain template. The shared `vllm_omni.diffusion.models.hunyuan_image3.prompt_utils.build_prompt_tokens()` helper handles segment-by-segment tokenization and matches HF `apply_chat_template`. - -## FAQ - -- **OOM errors**: Decrease `gpu_memory_utilization` in the deploy YAML, use a smaller `max_num_batched_tokens`, or enable VAE tiling with `--vae-use-tiling`. -- **Custom image sizes**: Use `--height` and `--width` flags (multiples of 16 recommended). - -| Stage | VRAM (approx) | -| :--- | :--- | -| Stage 0 (AR) | ~15 GiB + KV Cache | -| Stage 1 (DiT) | ~30 GiB | -| Total (8-GPU) | ~45 GiB + KV Cache | diff --git a/examples/offline_inference/hunyuan_image3/end2end.py b/examples/offline_inference/hunyuan_image3/end2end.py index 5232568f11e..16f7d8f06c1 100644 --- a/examples/offline_inference/hunyuan_image3/end2end.py +++ b/examples/offline_inference/hunyuan_image3/end2end.py @@ -1,16 +1,5 @@ """ HunyuanImage-3.0-Instruct unified end-to-end inference script. 
- -Supports all modalities through a single entry point: - - text2img: Text → AR → DiT → Image - - img2img: Text+Image → AR → DiT → Edited Image (IT2I) - - img2text: Image+Text → AR → Text description (I2T) - - text2text: Text → AR → Text (comprehension, no image) - -Usage: - python end2end.py --modality text2img --prompts "A cute cat" - python end2end.py --modality img2img --image-path input.png --prompts "Make it snowy" - python end2end.py --modality img2text --image-path input.png --prompts "Describe this image" """ import argparse @@ -19,18 +8,25 @@ from pathlib import Path from vllm_omni.diffusion.models.hunyuan_image3.prompt_utils import ( - _TASK_PRESETS, + MAX_IMAGES_PER_REQUEST, build_prompt_tokens, resolve_stop_token_ids, + resolve_sys_type, ) from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniPromptType -# Default deploy configs are absolute so this example works from any cwd. _REPO_ROOT = Path(__file__).resolve().parents[3] _DEFAULT_DEPLOY_CONFIG = str(_REPO_ROOT / "vllm_omni" / "deploy" / "hunyuan_image3.yaml") _DEFAULT_AR_DEPLOY_CONFIG = str(_REPO_ROOT / "vllm_omni" / "deploy" / "hunyuan_image3_ar.yaml") +_MODALITY_TASK_MAP: dict[str, tuple[str, str | None]] = { + "text2img": ("t2i", "think"), + "img2img": ("it2i", "think"), + "img2text": ("i2t", None), + "text2text": ("t2t", None), +} + _MODALITY_DEFAULT_DEPLOY_CONFIG = { "text2img": _DEFAULT_DEPLOY_CONFIG, "img2img": _DEFAULT_DEPLOY_CONFIG, @@ -45,73 +41,37 @@ "text2text": "text-to-text", } -_MODALITY_TASK_MAP = { - "text2img": "t2i", - "img2img": "it2i", - "img2text": "i2t", - "text2text": "t2t", -} - def parse_args(): parser = argparse.ArgumentParser(description="HunyuanImage-3.0-Instruct end-to-end inference.") - parser.add_argument( - "--model", - default="tencent/HunyuanImage-3.0-Instruct", - help="Model name or local path.", - ) + parser.add_argument("--model", default="tencent/HunyuanImage-3.0-Instruct", help="Model name or local path.") parser.add_argument( 
"--modality", default="text2img", - choices=["text2img", "img2img", "img2text", "text2text"], - help="Modality mode to control stage execution.", + choices=list(_MODALITY_TASK_MAP), ) parser.add_argument("--prompts", nargs="+", default=None, help="Input text prompts.") parser.add_argument( "--image-path", type=str, default=None, - help="Path to input image (for img2img/img2text).", - ) - parser.add_argument( - "--output", - type=str, - default=".", - help="Output directory to save results.", + help="Input image path(s) for img2img/img2text. Comma-separated for multi-image (up to 3).", ) - - # Generation parameters + parser.add_argument("--output", type=str, default=".", help="Output directory to save results.") parser.add_argument("--steps", type=int, default=50, help="Number of inference steps.") parser.add_argument("--guidance-scale", type=float, default=5.0, help="Classifier-free guidance scale.") parser.add_argument("--seed", type=int, default=42, help="Random seed.") parser.add_argument("--height", type=int, default=1024, help="Output image height.") parser.add_argument("--width", type=int, default=1024, help="Output image width.") - parser.add_argument( - "--vae-use-tiling", - action="store_true", - help="Enable VAE tiling for memory optimization.", - ) - - # Prompt configuration + parser.add_argument("--vae-use-tiling", action="store_true", help="Enable VAE tiling.") parser.add_argument( "--bot-task", type=str, - default="auto", - choices=["auto", "think", "recaption", "think_recaption", "vanilla"], - help=( - "Prompt behavior. 'auto' selects the default for the modality; " - "'think' adds ; 'recaption' adds ; " - "'vanilla' uses the t2i pretrain template." - ), - ) - parser.add_argument( - "--sys-type", - type=str, default=None, - help="Override system prompt type (e.g. en_unified, en_vanilla).", + choices=["none", "think", "recaption", "think_recaption", "vanilla"], + help="Override prompt mode. 
Default: auto from --modality.", ) - - # Omni init args + parser.add_argument("--sys-type", type=str, default=None, help="Override system prompt type.") parser.add_argument("--deploy-config", type=str, default=None, help="Custom deploy YAML path.") parser.add_argument("--stage-configs-path", type=str, default=None, help="Custom legacy stage config YAML path.") parser.add_argument("--log-stats", action="store_true", default=False) @@ -157,22 +117,13 @@ def main(): os.makedirs(args.output, exist_ok=True) additional_config = parse_additional_config(args.additional_config) - # Determine task for prompt formatting from modality + bot behavior. - task = _MODALITY_TASK_MAP[args.modality] - assert task is not None - bot_task = args.bot_task - if bot_task != "auto": - task = task + "_" + bot_task - if task not in _TASK_PRESETS: - valid_bot_tasks = { - "text2img": ["think", "recaption", "vanilla"], - "img2img": ["think", "recaption", "think_recaption"], - "img2text": ["auto"], - "text2text": ["auto"], - }[args.modality] - raise ValueError( - f"--bot-task {bot_task!r} is not supported for {args.modality}. 
Choose from: {valid_bot_tasks}" - ) + task, default_bot_task = _MODALITY_TASK_MAP[args.modality] + if args.bot_task is None: + bot_task: str | None = default_bot_task + elif args.bot_task == "none": + bot_task = None + else: + bot_task = args.bot_task if args.deploy_config is not None and args.stage_configs_path is not None: raise ValueError("--deploy-config and --stage-configs-path are mutually exclusive.") @@ -182,13 +133,13 @@ def main(): if deploy_config is None and stage_configs_path is None: deploy_config = _MODALITY_DEFAULT_DEPLOY_CONFIG[args.modality] - # Build Omni omni_kwargs = { "model": args.model, "vae_use_tiling": args.vae_use_tiling, "log_stats": args.log_stats, "init_timeout": args.init_timeout, "enforce_eager": args.enforce_eager, + "mode": _MODALITY_MODE[args.modality], } if additional_config is not None: @@ -197,71 +148,67 @@ def main(): omni_kwargs["deploy_config"] = deploy_config else: omni_kwargs["stage_configs_path"] = stage_configs_path - omni_kwargs["mode"] = _MODALITY_MODE[args.modality] omni = Omni(**omni_kwargs) - # Prepare prompts prompts = args.prompts or ["A cute cat"] - if not prompts: - print("[Info] No prompts provided, using default.") - prompts = ["A cute cat"] - - # Load image if needed - input_image = None + input_images: list = [] if args.modality in ("img2img", "img2text"): - if not args.image_path or not os.path.exists(args.image_path): + if not args.image_path: raise ValueError(f"--image-path required for {args.modality}, got: {args.image_path}") from PIL import Image - input_image = Image.open(args.image_path).convert("RGB") + image_paths = [p.strip() for p in args.image_path.split(",") if p.strip()] + if len(image_paths) > MAX_IMAGES_PER_REQUEST: + raise ValueError( + f"--image-path accepts at most {MAX_IMAGES_PER_REQUEST} images for " + f"HunyuanImage-3.0 IT2I, got {len(image_paths)}: {args.image_path}" + ) + for image_path in image_paths: + if not os.path.exists(image_path): + raise ValueError(f"Image path does not 
exist: {image_path}") + input_images.append(Image.open(image_path).convert("RGB")) + if not input_images: + raise ValueError(f"--image-path produced no usable paths: {args.image_path!r}") - # Load tokenizer for segment-wise prompt tokenization (matches HF - # apply_chat_template byte-for-byte; see build_prompt_tokens docstring). from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) + mm_image_payload = (input_images[0] if len(input_images) == 1 else input_images) if input_images else None - # Format prompts formatted_prompts: list[OmniPromptType] = [] - for p in prompts: - result = build_prompt_tokens(p, tokenizer, task=task, sys_type=args.sys_type) + for prompt in prompts: + build_kwargs: dict = {"task": task, "bot_task": bot_task, "sys_type": args.sys_type} + if input_images: + build_kwargs["num_images"] = len(input_images) + result = build_prompt_tokens(prompt, tokenizer, **build_kwargs) token_ids = result.token_ids - effective_sys_type = result.system_prompt_type + effective_sys_type = args.sys_type or resolve_sys_type(bot_task) - # `prompt_token_ids` drives the AR stage (matches HF byte-for-byte). - # `prompt` and `use_system_prompt` are forwarded by ar2diffusion to - # the DiT stage so the diffusion pipeline can rebuild the same - # system prefix when constructing its model inputs. 
prompt_dict: dict = { "prompt_token_ids": token_ids, - "prompt": p, + "prompt": prompt, "use_system_prompt": effective_sys_type, } - if args.modality == "text2img": prompt_dict["modalities"] = ["image"] elif args.modality == "img2img": prompt_dict["modalities"] = ["image"] - prompt_dict["multi_modal_data"] = {"image": input_image} - prompt_dict["height"] = input_image.height - prompt_dict["width"] = input_image.width + prompt_dict["multi_modal_data"] = {"image": mm_image_payload} + prompt_dict["height"] = input_images[0].height + prompt_dict["width"] = input_images[0].width elif args.modality == "img2text": prompt_dict["modalities"] = ["text"] - prompt_dict["multi_modal_data"] = {"image": input_image} - elif args.modality == "text2text": + prompt_dict["multi_modal_data"] = {"image": mm_image_payload} + else: prompt_dict["modalities"] = ["text"] - formatted_prompts.append(prompt_dict) - # Build sampling params from defaults params_list = list(omni.default_sampling_params_list) - # Override diffusion params if applicable from vllm_omni.inputs.data import OmniDiffusionSamplingParams ar_stop_token_ids = resolve_stop_token_ids(task=task, bot_task=bot_task, tokenizer=tokenizer) - assert ar_stop_token_ids is not None for sp in params_list: if isinstance(sp, OmniDiffusionSamplingParams): sp.num_inference_steps = args.steps @@ -269,13 +216,12 @@ def main(): sp.guidance_scale_provided = True if args.seed is not None: sp.seed = args.seed - if args.modality in ("text2img",): + if args.modality == "text2img": sp.height = args.height sp.width = args.width elif hasattr(sp, "stop_token_ids"): sp.stop_token_ids = ar_stop_token_ids - # Print configuration print(f"\n{'=' * 60}") print("HunyuanImage-3.0 Generation Configuration:") print(f" Model: {args.model}") @@ -300,13 +246,9 @@ def main(): print(f" Prompts: {prompts}") print(f"{'=' * 60}\n") - # Generate omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list)) - - # Process outputs img_idx = 
0 for req_output in omni_outputs: - # Text output (AR stage or text-only) ro = getattr(req_output, "request_output", None) txt = "" if ro and getattr(ro, "outputs", None): @@ -320,11 +262,9 @@ def main(): if txt: print(f"[Output] Text:\n{txt}") - # Image output (DiT stage) images = getattr(req_output, "images", None) if not images and ro and hasattr(ro, "images"): images = ro.images - if images: for j, img in enumerate(images): save_path = os.path.join(args.output, f"output_{img_idx}_{j}.png") diff --git a/tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_it2i_multi_image.py b/tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_it2i_multi_image.py new file mode 100644 index 00000000000..1e0fd159063 --- /dev/null +++ b/tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_it2i_multi_image.py @@ -0,0 +1,255 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Multi-image input regression for HunyuanImage3 IT2I prompt construction. + +The official HunyuanImage-3.0-Instruct supports up to 3 reference images +per IT2I request ("Multi-Image Fusion"; see hunyuan3.0_ins/README.md +section 200-216 + line 500). Each cond image becomes its own user-role +message and `apply_general_template` concatenates successive user +messages back-to-back inside ONE user_prefix/user_suffix wrap (see +hunyuan3.0_ins/tokenization_hunyuan_image_3.py:1399-1400, 1499-1515). +The lightweight `` + `multi_modal_data` builder used by the example +flow must match that contract: N consecutive `` placeholders sit +between `User: ` and the user prompt, with no separator between them. + +This file pins: + 1. N consecutive `` placeholders for N=1/2/3 across both the + string builder (`build_prompt`) and the token builder + (`build_prompt_tokens`). + 2. The N=1 path stays bit-identical to the legacy single-image builder + (regression guard so default callers don't notice). + 3. 
N=2 / N=3 token sequences differ from N=1 by exactly (N-1) extra + `` ids inserted between `User: ` and `user_prompt`. + 4. Validation: N<1 and N>3 raise ValueError (hard cap N<=3 mirrors + official upstream). + 5. Text-only tasks ignore `num_images` (no validation, no extra ids). +""" + +from __future__ import annotations + +import os + +import pytest + +from vllm_omni.diffusion.models.hunyuan_image3.prompt_utils import ( + MAX_IMAGES_PER_REQUEST, + build_prompt, + build_prompt_tokens, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +class FakeTokenizer: + """Recording fake tokenizer mirroring the one in test_prompt_utils. + + Special token ids: `<|startoftext|>`=1, ``=2, ``=3, + ``=4. encode() returns one id per character starting at + 100, so substring-position assertions are stable. + """ + + SPECIAL = { + "<|startoftext|>": 1, + "": 2, + "": 3, + "": 4, + } + + def __init__(self) -> None: + self.encode_calls: list[str] = [] + + def convert_tokens_to_ids(self, tok: str) -> int: + return self.SPECIAL.get(tok, 0) + + def encode(self, text: str, add_special_tokens: bool = False) -> list[int]: + self.encode_calls.append(text) + return list(range(100, 100 + len(text))) + + +_IMAGE_TASK_COMBOS = ( + ("i2t", None), + ("it2i", "think"), + ("it2i", "recaption"), +) +_TEXT_ONLY_TASK_COMBOS = (("t2t", None),) + + +# -------------------- string builder -------------------- + + +@pytest.mark.parametrize("task,bot_task", _IMAGE_TASK_COMBOS) +@pytest.mark.parametrize("num_images", [1, 2, 3]) +def test_build_prompt_emits_N_consecutive_img_placeholders(task: str, bot_task: str | None, num_images: int): + """N=1/2/3 -> exactly N `` substrings appear consecutively + between `User: ` and the user prompt, with no separator between them.""" + s = build_prompt("HELLO", task=task, bot_task=bot_task, num_images=num_images) + assert s.count("") == num_images, ( + f"task={task} bot_task={bot_task} num_images={num_images}: expected {num_images} " + f"placeholders, found 
{s.count('')} -- prompt was: {s!r}" + ) + + # All `` placeholders must form one contiguous run "..." + # immediately after `User: ` and before HELLO. + user_idx = s.index("User: ") + len("User: ") + hello_idx = s.index("HELLO") + between = s[user_idx:hello_idx] + assert between == "" * num_images, ( + f"region between `User: ` and prompt must be exactly N placeholders; got {between!r}" + ) + + +def test_build_prompt_default_num_images_matches_legacy(): + """num_images default = 1 must produce a string bit-identical to the + pre-multi-image behavior (single `` placeholder).""" + legacy = build_prompt("HELLO", task="it2i", bot_task="think") + explicit = build_prompt("HELLO", task="it2i", bot_task="think", num_images=1) + assert legacy == explicit, "default num_images=1 must match legacy single-image output" + + +# -------------------- token builder -------------------- + + +@pytest.mark.parametrize("task,bot_task", _IMAGE_TASK_COMBOS) +def test_build_prompt_tokens_inserts_N_img_ids(task: str, bot_task: str | None): + """N=1/2/3 -> the resulting id sequence contains exactly N copies of + img_id (=2) sitting consecutively after the `User: ` segment.""" + tok = FakeTokenizer() + ids_n1 = build_prompt_tokens("hi", tok, task=task, bot_task=bot_task, num_images=1).token_ids + tok = FakeTokenizer() + ids_n2 = build_prompt_tokens("hi", tok, task=task, bot_task=bot_task, num_images=2).token_ids + tok = FakeTokenizer() + ids_n3 = build_prompt_tokens("hi", tok, task=task, bot_task=bot_task, num_images=3).token_ids + + assert ids_n1.count(2) == 1 + assert ids_n2.count(2) == 2 + assert ids_n3.count(2) == 3 + + # Each additional image must extend the sequence by exactly one img_id, + # not shift other tokens around. 
+ assert len(ids_n2) == len(ids_n1) + 1 + assert len(ids_n3) == len(ids_n1) + 2 + + # The img_ids must be CONSECUTIVE (no other token between successive + # `` placeholders -- mirrors the official `process_successive_message` + # wrapping where successive user messages share one user_prefix/suffix). + for ids, n in [(ids_n2, 2), (ids_n3, 3)]: + first = ids.index(2) + for k in range(n): + assert ids[first + k] == 2, ( + f"img_ids must be consecutive starting at position {first} for n={n}; got {ids[first : first + n]!r}" + ) + + +def test_build_prompt_tokens_default_num_images_matches_legacy(): + """num_images default = 1 must produce the same id sequence as + omitting the parameter (regression guard for existing single-image + callers).""" + tok_a = FakeTokenizer() + legacy = build_prompt_tokens("hi", tok_a, task="it2i", bot_task="think").token_ids + tok_b = FakeTokenizer() + explicit = build_prompt_tokens("hi", tok_b, task="it2i", bot_task="think", num_images=1).token_ids + assert legacy == explicit + # Also: encode() must have been called on the same set of segments, + # so segment boundaries are preserved. 
+ assert tok_a.encode_calls == tok_b.encode_calls + + +# -------------------- validation -------------------- + + +@pytest.mark.parametrize("task,bot_task", _IMAGE_TASK_COMBOS) +@pytest.mark.parametrize("bad", [0, -1, MAX_IMAGES_PER_REQUEST + 1, 99]) +def test_build_prompt_rejects_out_of_range_num_images(task: str, bot_task: str | None, bad: int): + with pytest.raises(ValueError, match="num_images must be in"): + build_prompt("hi", task=task, bot_task=bot_task, num_images=bad) + with pytest.raises(ValueError, match="num_images must be in"): + build_prompt_tokens("hi", FakeTokenizer(), task=task, bot_task=bot_task, num_images=bad) + + +@pytest.mark.parametrize("task,bot_task", _TEXT_ONLY_TASK_COMBOS) +@pytest.mark.parametrize("num_images", [0, 1, 2, 99]) +def test_text_only_tasks_ignore_num_images(task: str, bot_task: str | None, num_images: int): + """Validation only kicks in for image-input tasks; t2t et al. accept + any num_images and emit zero `` placeholders.""" + s = build_prompt("hi", task=task, bot_task=bot_task, num_images=num_images) + assert "" not in s + ids = build_prompt_tokens("hi", FakeTokenizer(), task=task, bot_task=bot_task, num_images=num_images).token_ids + assert 2 not in ids + + +# -------------------- real HF tokenizer regression -------------------- + +_HUNYUAN_MODEL_ID = "tencent/HunyuanImage-3.0-Instruct" + + +def _hf_cached(model_id: str) -> bool: + hf_home = os.environ.get("HF_HOME") or os.path.expanduser("~/.cache/huggingface") + snap_dir = os.path.join(hf_home, "hub", f"models--{model_id.replace('/', '--')}", "snapshots") + return os.path.isdir(snap_dir) and any(os.scandir(snap_dir)) + + +@pytest.mark.skipif(not _hf_cached(_HUNYUAN_MODEL_ID), reason=f"{_HUNYUAN_MODEL_ID} tokenizer not in HF cache") +@pytest.mark.parametrize("num_images", [1, 2, 3]) +def test_real_tokenizer_emits_n_consecutive_img_ids(num_images: int): + """Real `AutoTokenizer.from_pretrained(...)` (the production path) must + encode N=1/2/3 prompts to a sequence with 
exactly N consecutive `` + token-ids in the right place — proves the placeholder layout from + `build_prompt_tokens` survives a real BPE tokenizer, not just FakeTokenizer. + """ + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(_HUNYUAN_MODEL_ID, trust_remote_code=True) + img_id = tok.convert_tokens_to_ids("") + assert img_id is not None and img_id >= 0, f" not in tokenizer vocab; got id={img_id}" + + ids = build_prompt_tokens("hi", tok, task="it2i", bot_task="think", num_images=num_images).token_ids + + # Exactly N copies of id, all consecutive. + img_positions = [i for i, x in enumerate(ids) if x == img_id] + assert len(img_positions) == num_images, ( + f"expected {num_images} ids, got {len(img_positions)} at positions {img_positions}" + ) + assert img_positions == list(range(img_positions[0], img_positions[0] + num_images)), ( + f" ids must be contiguous; got positions {img_positions}" + ) + + +@pytest.mark.skipif(not _hf_cached(_HUNYUAN_MODEL_ID), reason=f"{_HUNYUAN_MODEL_ID} tokenizer not in HF cache") +def test_real_tokenizer_n_plus_one_extends_by_exactly_one_img_id(): + """Going from N to N+1 images must extend the encoded id sequence by + exactly one extra `` token-id and shift nothing else. 
Catches + accidental separator tokens between successive `` placeholders + that a FakeTokenizer (deterministic encode) can't surface.""" + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(_HUNYUAN_MODEL_ID, trust_remote_code=True) + img_id = tok.convert_tokens_to_ids("") + + ids_n1 = build_prompt_tokens("hi", tok, task="it2i", bot_task="think", num_images=1).token_ids + ids_n2 = build_prompt_tokens("hi", tok, task="it2i", bot_task="think", num_images=2).token_ids + ids_n3 = build_prompt_tokens("hi", tok, task="it2i", bot_task="think", num_images=3).token_ids + + assert len(ids_n2) == len(ids_n1) + 1, f"N=2 should be N=1 + 1 token; got {len(ids_n2)} vs {len(ids_n1)}" + assert len(ids_n3) == len(ids_n1) + 2, f"N=3 should be N=1 + 2 tokens; got {len(ids_n3)} vs {len(ids_n1)}" + + # Insert one img_id at the existing position; everything else unchanged. + p1 = ids_n1.index(img_id) + assert ids_n2[: p1 + 1] == ids_n1[: p1 + 1] + [], "prefix before extra must match N=1" + assert ids_n2[p1] == img_id and ids_n2[p1 + 1] == img_id, "two consecutive ids at the insertion point" + assert ids_n2[p1 + 2 :] == ids_n1[p1 + 1 :], "tail after the extra must match N=1's tail" + # N=3 same pattern, three in a row. 
+ assert ids_n3[p1 : p1 + 3] == [img_id, img_id, img_id] + assert ids_n3[p1 + 3 :] == ids_n1[p1 + 1 :] + + +@pytest.mark.skipif(not _hf_cached(_HUNYUAN_MODEL_ID), reason=f"{_HUNYUAN_MODEL_ID} tokenizer not in HF cache") +def test_real_tokenizer_default_n1_byte_identical_to_legacy(): + """Default `num_images=1` must produce the exact same id sequence as + omitting the parameter — pins the legacy single-image regression + against the real tokenizer (not just FakeTokenizer).""" + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(_HUNYUAN_MODEL_ID, trust_remote_code=True) + legacy = build_prompt_tokens("hi", tok, task="it2i", bot_task="think").token_ids + explicit = build_prompt_tokens("hi", tok, task="it2i", bot_task="think", num_images=1).token_ids + assert legacy == explicit, "real tokenizer: default num_images=1 must be byte-identical to legacy" diff --git a/tests/diffusion/models/hunyuan_image3/test_prompt_utils.py b/tests/diffusion/models/hunyuan_image3/test_prompt_utils.py index 1130c0f6db1..7c3256eee72 100644 --- a/tests/diffusion/models/hunyuan_image3/test_prompt_utils.py +++ b/tests/diffusion/models/hunyuan_image3/test_prompt_utils.py @@ -1,20 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Regression tests for HunyuanImage3 prompt construction (PR #3243). - -Two layers: - 1. Pure-logic tests with a recording fake tokenizer -- protect the - prompt template structure (BOS, User:/Assistant: framing, trigger - placement, image placeholder position) and protect the segment- - by-segment tokenization contract (each segment must hit - `tokenizer.encode` in isolation). - 2. Real-tokenizer regression -- run when the HunyuanImage3-Instruct - tokenizer is in the local HF cache. Asserts the segment-tokenized - output diverges from the naive full-string encode, which is the - bug-tripping fixture for the cross-segment BPE merge fix - (commit 7bd429ed). 
-""" - from __future__ import annotations import ast @@ -24,7 +9,9 @@ import pytest from vllm_omni.diffusion.models.hunyuan_image3.prompt_utils import ( + _TASK_PRESETS, HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS, + available_bot_tasks, available_tasks, build_prompt, build_prompt_tokens, @@ -34,18 +21,7 @@ pytestmark = [pytest.mark.core_model, pytest.mark.cpu] -# -------------------- Pure-logic structural tests -------------------- - - class FakeTokenizer: - """Minimal tokenizer stub that records every encode() call. - - Returns deterministic ids from convert_tokens_to_ids while - encode() returns one id per character starting at 100. This lets - tests both verify segmentation (by inspecting `encode_calls`) and - locate substrings inside the returned id list. - """ - SPECIAL = { "<|startoftext|>": 1, "": 2, @@ -72,85 +48,131 @@ def encode(self, text: str, add_special_tokens: bool = False) -> list[int]: def test_available_tasks_covers_all_modalities(): - tasks = set(available_tasks()) - assert tasks >= { - "t2t", - "i2t", + assert set(available_tasks()) == {"t2t", "i2t", "it2i", "t2i"} + + +def test_available_bot_tasks_covers_all_modes(): + assert set(available_bot_tasks()) == {None, "think", "recaption", "think_recaption", "vanilla"} + + +def test_legacy_task_presets_still_available(): + assert { "it2i_think", "it2i_recaption", "it2i_think_recaption", "t2i_think", "t2i_recaption", "t2i_vanilla", + } <= set(_TASK_PRESETS) + + +def test_legacy_base_task_omitted_bot_task_keeps_plain_mode(): + prompt = build_prompt("HELLO", task="i2t") + assert prompt.endswith("Assistant: ") + assert not prompt.endswith("") + + result = build_prompt_tokens("hi", FakeTokenizer(), task="i2t") + assert result.system_prompt_type == "en_unified" + assert result.token_ids[-1] not in { + FakeTokenizer.SPECIAL[""], + FakeTokenizer.SPECIAL[""], } -def test_resolve_stop_token_ids_uses_answer_for_generation_tasks(): +def test_legacy_composite_task_with_none_bot_task_keeps_encoded_mode(): + prompt = 
build_prompt("HELLO", task="it2i_think", bot_task=None) + assert prompt.endswith("Assistant: ") + + result = build_prompt_tokens("hi", FakeTokenizer(), task="it2i_recaption", bot_task=None) + assert result.token_ids[-1] == FakeTokenizer.SPECIAL[""] + + +def test_default_prompt_still_uses_it2i_think_mode(): + prompt = build_prompt("HELLO") + assert prompt.endswith("Assistant: ") + + result = build_prompt_tokens("hi", FakeTokenizer()) + assert result.system_prompt_type == "en_unified" + assert result.token_ids[-1] == FakeTokenizer.SPECIAL[""] + + +def test_resolve_stop_token_ids_image_tasks_stop_on_ratio_range(): + """Image-output tasks stop on any ```` token. + + Mirrors upstream ``modeling_hunyuan_image_3.py::generate_image`` + (line 3289-3303): when ``need_ratio`` is true, + ``final_stop_tokens = list(range(start_ratio, end_ratio + 1)) + + ratio_token_other_slices``. AR stops AT the ratio token sampled + after ````; the bridge then strips the trailing ratio + token before passing the cot to DiT. + """ tok = FakeTokenizer() + start = HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""] + end = HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""] + other_start = HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""] + other_end = HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""] + expected = list(range(start, end + 1)) + list(range(other_start, other_end + 1)) + + # Image-output: t2i / it2i stop on the full ratio token range. + for bot in ("think", "recaption", "think_recaption", "vanilla"): + assert resolve_stop_token_ids(task="t2i", bot_task=bot, tokenizer=tok) == expected + assert resolve_stop_token_ids(task="it2i", bot_task=bot, tokenizer=tok) == expected + + # Text-output: i2t / t2t comprehension stops on (response sits inside). 
answer_id = HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""] - assert resolve_stop_token_ids(task="t2i_think", tokenizer=tok) == [answer_id] - assert resolve_stop_token_ids(task="t2i_recaption", tokenizer=tok) == [answer_id] + assert resolve_stop_token_ids(task="i2t", bot_task=None, tokenizer=tok) == [answer_id] + assert resolve_stop_token_ids(task="t2t", bot_task=None, tokenizer=tok) == [answer_id] @pytest.mark.parametrize( - "task", + "task,bot_task", [ - "t2t", - "i2t", - "it2i_think", - "it2i_recaption", - "it2i_think_recaption", - "t2i_think", - "t2i_recaption", + ("t2t", None), + ("i2t", None), + ("it2i", "think"), + ("it2i", "recaption"), + ("it2i", "think_recaption"), + ("t2i", "think"), + ("t2i", "recaption"), + ("t2i", "think_recaption"), ], ) -def test_build_prompt_string_structure_chat_template(task: str): - """Chat-template tasks must produce <|startoftext|>...User: ...Assistant: ... - with image placeholder (when applicable) and trigger tag AFTER `Assistant: `.""" - s = build_prompt("HELLO", task=task) - +def test_build_prompt_string_structure_chat_template(task: str, bot_task: str | None): + s = build_prompt("HELLO", task=task, bot_task=bot_task) assert s.startswith("<|startoftext|>") assert "User: " in s assert "Assistant: " in s assert s.index("User: ") < s.index("HELLO") < s.index("Assistant: ") - if task.startswith(("i2t", "it2i")): - assert s.index("User: ") < s.index("") < s.index("HELLO"), ( - " placeholder must sit between `User: ` and the user prompt" - ) + if task in ("i2t", "it2i"): + assert s.index("User: ") < s.index("") < s.index("HELLO") else: assert "" not in s - # Trigger tag must be the FINAL token of the prompt (after `Assistant: `). - # Note: the system prompt itself mentions / as mode - # documentation, so substring index() catches the wrong occurrence -- use - # endswith() which directly captures "trigger is at the tail" (the Part A - # fix: trigger goes AFTER `Assistant: `, not before user_prompt). 
- if task in ("it2i_think", "t2i_think", "it2i_think_recaption"): - assert s.endswith("Assistant: "), ( - f"Trigger must be appended right after `Assistant: ` (Part A fix). Got tail: ...{s[-40:]!r}" - ) - if task in ("it2i_recaption", "t2i_recaption"): - assert s.endswith("Assistant: "), ( - f"Trigger must be appended right after `Assistant: ` (Part A fix). Got tail: ...{s[-40:]!r}" - ) - if task in ("t2t", "i2t"): - assert s.endswith("Assistant: "), "Plain (no-trigger) task must end at `Assistant: ` with no trailing tag." + if bot_task in ("think", "think_recaption"): + assert s.endswith("Assistant: ") + elif bot_task == "recaption": + assert s.endswith("Assistant: ") + elif bot_task is None: + assert s.endswith("Assistant: ") def test_build_prompt_vanilla_uses_pretrain_template(): - """t2i_vanilla is the only task that bypasses chat structure -- direct - text->image generation driven by the vanilla system prompt.""" - s = build_prompt("HELLO", task="t2i_vanilla") + s = build_prompt("HELLO", task="t2i", bot_task="vanilla") assert s.startswith("<|startoftext|>") assert "User: " not in s assert "Assistant: " not in s - assert "" not in s - assert "" not in s assert s.endswith("HELLO") +def test_build_prompt_vanilla_rejects_non_t2i_task(): + with pytest.raises(ValueError, match="bot_task='vanilla'"): + build_prompt("x", task="it2i", bot_task="vanilla") + with pytest.raises(ValueError, match="bot_task='vanilla'"): + build_prompt_tokens("x", FakeTokenizer(), task="i2t", bot_task="vanilla") + + def test_build_prompt_unknown_task_raises(): with pytest.raises(ValueError, match="Unknown task"): build_prompt("x", task="bogus") @@ -158,127 +180,83 @@ def test_build_prompt_unknown_task_raises(): build_prompt_tokens("x", FakeTokenizer(), task="bogus") +def test_build_prompt_unknown_bot_task_raises(): + with pytest.raises(ValueError, match="Unknown bot_task"): + build_prompt("x", task="t2i", bot_task="bogus") + with pytest.raises(ValueError, match="Unknown bot_task"): + 
build_prompt_tokens("x", FakeTokenizer(), task="t2i", bot_task="bogus") + + def test_build_prompt_tokens_segments_each_boundary(): - """Regression for cross-segment BPE merge bug (commit 7bd429ed): - each template segment must hit tokenizer.encode() independently; - user_prompt MUST NOT be concatenated with the following separator - in the same encode() call.""" tok = FakeTokenizer() - build_prompt_tokens("写诗。", tok, task="i2t") - - # Each canonical segment is encoded in its own call. + build_prompt_tokens("写诗。", tok, task="i2t", bot_task=None) assert "User: " in tok.encode_calls - assert "写诗。" in tok.encode_calls, ( - "user_prompt must be encoded alone -- if it is concatenated with the " - "trailing separator, BPE will merge across the boundary (the PR-#3243 bug)." - ) + assert "写诗。" in tok.encode_calls assert "\n\nAssistant: " in tok.encode_calls - - # No call must contain user_prompt glued to neighboring text. for call in tok.encode_calls: if call != "写诗。": - assert "写诗。" not in call, f"user_prompt leaked into a multi-segment encode call: {call!r}" + assert "写诗。" not in call def test_build_prompt_tokens_image_placeholder_present_for_image_tasks(): tok = FakeTokenizer() - result = build_prompt_tokens("hi", tok, task="i2t") + result = build_prompt_tokens("hi", tok, task="i2t", bot_task=None) ids = result.token_ids - assert ids[0] == FakeTokenizer.SPECIAL["<|startoftext|>"], "BOS (<|startoftext|>) must be the first token" - assert FakeTokenizer.SPECIAL[""] in ids, " placeholder must be present for i2t/it2i tasks" + assert ids[0] == FakeTokenizer.SPECIAL["<|startoftext|>"] + assert FakeTokenizer.SPECIAL[""] in ids def test_build_prompt_tokens_no_image_for_text_only_tasks(): tok = FakeTokenizer() - result = build_prompt_tokens("hi", tok, task="t2t") + result = build_prompt_tokens("hi", tok, task="t2t", bot_task=None) ids = result.token_ids - assert FakeTokenizer.SPECIAL[""] not in ids, " must NOT appear for text-only tasks" + assert FakeTokenizer.SPECIAL[""] not in 
ids @pytest.mark.parametrize( - "task,trigger_id", + "task,bot_task,trigger_id", [ - ("it2i_think", FakeTokenizer.SPECIAL[""]), - ("t2i_think", FakeTokenizer.SPECIAL[""]), - ("it2i_recaption", FakeTokenizer.SPECIAL[""]), - ("t2i_recaption", FakeTokenizer.SPECIAL[""]), + ("it2i", "think", FakeTokenizer.SPECIAL[""]), + ("t2i", "think", FakeTokenizer.SPECIAL[""]), + ("t2i", "think_recaption", FakeTokenizer.SPECIAL[""]), + ("it2i", "recaption", FakeTokenizer.SPECIAL[""]), + ("t2i", "recaption", FakeTokenizer.SPECIAL[""]), + ("it2i_think", None, FakeTokenizer.SPECIAL[""]), + ("it2i_recaption", None, FakeTokenizer.SPECIAL[""]), ], ) -def test_build_prompt_tokens_trigger_is_last_token(task: str, trigger_id: int): - """Trigger tag id must be the LAST token (after `Assistant: ` segment).""" +def test_build_prompt_tokens_trigger_is_last_token(task: str, bot_task: str | None, trigger_id: int): tok = FakeTokenizer() - result = build_prompt_tokens("hi", tok, task=task) - ids = result.token_ids - assert ids[-1] == trigger_id + result = build_prompt_tokens("hi", tok, task=task, bot_task=bot_task) + assert result.token_ids[-1] == trigger_id def test_build_prompt_tokens_no_trigger_for_plain_tasks(): - """Tasks without trigger_tag (t2t / i2t) must NOT append a trigger id.""" tok = FakeTokenizer() - result = build_prompt_tokens("hi", tok, task="t2t") - ids = result.token_ids - assert ids[-1] not in { + result = build_prompt_tokens("hi", tok, task="t2t", bot_task=None) + assert result.token_ids[-1] not in { FakeTokenizer.SPECIAL[""], FakeTokenizer.SPECIAL[""], } -# -------------------- end2end.py wiring guard -------------------- - - def _repo_root() -> pathlib.Path: - # tests/diffusion/models/hunyuan_image3/test_prompt_utils.py -> repo root return pathlib.Path(__file__).resolve().parents[4] def test_end2end_routes_through_shared_prompt_utils(): - """Regression for the *delivery vector* of PR #3243. 
- - Background: the wrong-template bug that PR #3243 fixes was introduced - when end2end.py grew its own hand-rolled prompt builder that diverged - from the canonical instruct chat template. To prevent that exact - failure mode from recurring, end2end.py MUST: - 1. Import the prompt builders from the shared prompt_utils module. - 2. NOT redefine `build_prompt` or `build_prompt_tokens` locally. - - A local redefinition is precisely how a future merge can silently - re-introduce a pretrain-style template (trigger BEFORE user_prompt, - no User:/Assistant: framing, etc.) without touching prompt_utils, - bypassing every other test in this file. - """ end2end_path = _repo_root() / "examples" / "offline_inference" / "hunyuan_image3" / "end2end.py" - assert end2end_path.is_file(), f"end2end.py not found at {end2end_path}" - tree = ast.parse(end2end_path.read_text(encoding="utf-8")) local_func_names = {n.name for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)} - forbidden = {"build_prompt", "build_prompt_tokens"} - redefined = local_func_names & forbidden - assert not redefined, ( - f"end2end.py defines {sorted(redefined)} locally. This is exactly how " - "the wrong prompt template re-entered the example before PR #3243. " - "Use the shared `vllm_omni.diffusion.models.hunyuan_image3.prompt_utils` " - "helpers instead." 
- ) + assert not (local_func_names & {"build_prompt", "build_prompt_tokens"}) imported_from_prompt_utils: set[str] = set() for node in ast.walk(tree): if isinstance(node, ast.ImportFrom) and node.module and node.module.endswith("hunyuan_image3.prompt_utils"): imported_from_prompt_utils.update(alias.name for alias in node.names) - expected_imports = { - "_TASK_PRESETS", - "build_prompt_tokens", - "resolve_stop_token_ids", - } - assert expected_imports <= imported_from_prompt_utils, ( - "end2end.py must import the HunyuanImage3 prompt and stop-token helpers from " - "vllm_omni.diffusion.models.hunyuan_image3.prompt_utils -- the shared " - "module is the single source of truth for the AR-prefill template and " - "bot_task-derived AR stop token ids." - ) - - -# -------------------- Real-tokenizer regression -------------------- + expected_imports = {"build_prompt_tokens", "resolve_stop_token_ids", "resolve_sys_type"} + assert expected_imports <= imported_from_prompt_utils _HUNYUAN_MODEL_ID = "tencent/HunyuanImage-3.0-Instruct" @@ -290,41 +268,14 @@ def _hf_cached(model_id: str) -> bool: return os.path.isdir(snap_dir) and any(os.scandir(snap_dir)) -@pytest.mark.skipif( - not _hf_cached(_HUNYUAN_MODEL_ID), - reason=f"{_HUNYUAN_MODEL_ID} tokenizer not in HF cache", -) +@pytest.mark.skipif(not _hf_cached(_HUNYUAN_MODEL_ID), reason=f"{_HUNYUAN_MODEL_ID} tokenizer not in HF cache") def test_segment_tokenize_diverges_from_full_string_encode(): - """Regression for PR #3243 segment-tokenization fix. - - The naive `tokenizer.encode(build_prompt(...))` lets BPE merge tokens - across segment boundaries (notably `。\\n\\n` -> a single id), which - drifts the AR prefill away from HF's apply_chat_template output. The - segment-by-segment build_prompt_tokens must produce a STRICTLY - DIFFERENT id sequence on a prompt that triggers the merge. - - If someone "simplifies" build_prompt_tokens to call encode() on the - full string, this assertion fires. 
- """ from transformers import AutoTokenizer tok = AutoTokenizer.from_pretrained(_HUNYUAN_MODEL_ID, trust_remote_code=True) - user_prompt = "写一首关于夜的诗。" - result = build_prompt_tokens(user_prompt, tok, task="i2t") + result = build_prompt_tokens(user_prompt, tok, task="i2t", bot_task=None) seg_ids = result.token_ids - full_ids = tok.encode(build_prompt(user_prompt, task="i2t"), add_special_tokens=False) - - assert seg_ids != full_ids, ( - "build_prompt_tokens output equals naive full-string encode -- " - "the BPE-merge-bypass behavior is no longer exercised. This means " - "the segment-by-segment fix from PR #3243 has been silently undone." - ) - - # Segmenting prevents merges, so the segment id list should have AT LEAST - # as many tokens as the merged version (a merge consumes 2+ ids -> 1). - assert len(seg_ids) >= len(full_ids), ( - f"segment-encoded length ({len(seg_ids)}) shorter than full-string " - f"merged length ({len(full_ids)}) -- impossible if segmentation is " - f"genuinely bypassing merges." 
- ) + full_ids = tok.encode(build_prompt(user_prompt, task="i2t", bot_task=None), add_special_tokens=False) + assert seg_ids != full_ids + assert len(seg_ids) >= len(full_ids) diff --git a/tests/e2e/accuracy/test_hunyuan_image3.py b/tests/e2e/accuracy/test_hunyuan_image3.py index 93671e7bbf6..0871793c5db 100644 --- a/tests/e2e/accuracy/test_hunyuan_image3.py +++ b/tests/e2e/accuracy/test_hunyuan_image3.py @@ -93,7 +93,13 @@ def _run(stage_config_path: str, output_path: Path) -> tuple[Image.Image, str, f from vllm_omni.platforms import current_omni_platform tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) - result = build_prompt_tokens(PROMPT, tokenizer, task="it2i_recaption", sys_type="en_unified") + result = build_prompt_tokens( + PROMPT, + tokenizer, + task="it2i", + bot_task="recaption", + sys_type="en_unified", + ) token_ids = result.token_ids system_prompt_type = result.system_prompt_type diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index b5ff891f8f6..40adb7a9151 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -1349,8 +1349,16 @@ def test_image_edit_parameter_default(async_omni_test_client): engine = async_omni_test_client.app.state.engine_client captured_sampling_params = engine.captured_sampling_params_list[-1] - assert captured_sampling_params.width == 24 - assert captured_sampling_params.height == 16 + # size="auto" on multi-stage pipelines deliberately leaves the diffusion + # stages sampling_params width/height unset so AR-driven pipelines (e.g. + # HunyuanImage-3.0) can let ar2diffusion override the final bucket from + # the AR-predicted ratio token; see + # test_image_edits_size_auto_preserves_bridge_size for the contract. 
+ # Single-stage diffusion (test_image_edit_parameter_default_single_stage) + # still pins width/height to the input image size via api_servers + # gen_params, which is unchanged. + assert captured_sampling_params.width is None + assert captured_sampling_params.height is None assert captured_sampling_params.num_outputs_per_prompt == 1 assert captured_sampling_params.num_inference_steps == 4 assert captured_sampling_params.guidance_scale == 7.5 @@ -1649,3 +1657,59 @@ def __init__(self): assert len(images) == 1 assert isinstance(images[0], Image.Image) assert images[0].size == (32, 32) + + +def test_image_edits_size_auto_preserves_bridge_size(async_omni_stage_configs_only_client): + """size=auto must NOT pin the diffusion stage sampling_params.height/width. + + Regression: prior to the fix, edit_images resolved size=auto to the + first input image dimensions and forwarded them through gen_params + + extra_body to the diffusion stages sampling_params. AR-driven + pipelines (e.g. HunyuanImage-3.0) rely on ar2diffusions + bridge to override the final bucket via the AR-predicted ratio token, + and the DiT pre_process_func only fills sampling_params from the + bridge value when sampling_params.width is None (see + pipeline_hunyuan_image3.py:290). Non-None width from the input image + silently suppressed the AR decision, producing the wrong bucket + (e.g. 1024x1024 square instead of the AR-decided 1280x720 landscape + for multi-image fusion). + + Cross-pins the multi-image fix at the API level: 2 reference images + with bot_task=think must produce 2 placeholders in the captured + AR prompt (build_prompt called with num_images=2). 
+ """ + img_a = make_test_image_bytes((32, 32)) + img_b = make_test_image_bytes((128, 64)) + response = async_omni_stage_configs_only_client.post( + "/v1/images/edits", + files=[("image", img_a), ("image", img_b)], + data={ + "prompt": "fuse", + "size": "auto", + "bot_task": "think", + }, + ) + assert response.status_code == 200, response.text + + engine = async_omni_stage_configs_only_client.app.state.engine_client + captured = engine.captured_sampling_params_list + assert captured is not None + assert len(captured) == 2 + + diffusion_params = captured[1] + assert diffusion_params.height is None, ( + f"size=auto leaked into diffusion sampling_params.height={diffusion_params.height}; " + "must stay None so AR-driven pipelines can apply the bridges decision." + ) + assert diffusion_params.width is None, ( + f"size=auto leaked into diffusion sampling_params.width={diffusion_params.width}; " + "must stay None so AR-driven pipelines can apply the bridges decision." + ) + + KEY = "prompt" + IMG = "" + captured_prompt = engine.captured_prompt + if isinstance(captured_prompt, dict) and isinstance(captured_prompt.get("prompt"), str): + assert captured_prompt["prompt"].count("") == 2, ( + f"N=2 reference images must emit 2 placeholders in AR prompt; got {captured_prompt[KEY].count(IMG)} -- prompt: {captured_prompt[KEY]!r}" + ) diff --git a/tests/entrypoints/openai_api/test_serving_chat_multistage_generation.py b/tests/entrypoints/openai_api/test_serving_chat_multistage_generation.py index 144a0e97a6c..4b63588bae7 100644 --- a/tests/entrypoints/openai_api/test_serving_chat_multistage_generation.py +++ b/tests/entrypoints/openai_api/test_serving_chat_multistage_generation.py @@ -91,3 +91,275 @@ def test_build_multistage_generation_inputs_applies_stage_specific_overrides(ser assert engine.default_sampling_params_list[1].lora_request is None assert engine.default_sampling_params_list[2].resolution == 640 assert engine.default_sampling_params_list[2].lora_request is None + + 
+def test_build_multistage_generation_inputs_multi_image_emits_n_img_placeholders(serving_chat): + """N reference images with bot_task set must emit N placeholders. + + Regression: prior to the multi-image online fix, build_prompt was + called without num_images, defaulting to 1. A 2-image edit request + would only get a single placeholder in the AR prompt; vLLMs + _process_multimodal then raised + AssertionError(Failed to apply prompt replacement for mm_items[image][1]) + when trying to replace the second image (no placeholder left for it). + + Pins the contract that build_prompt() is invoked with the actual image + count so multi-image IT2I is wired correctly through the online + /v1/images/edits path. + """ + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + engine = SimpleNamespace( + stage_configs=[ + SimpleNamespace(stage_type="llm", is_comprehension=True), + SimpleNamespace(stage_type="diffusion", is_comprehension=False), + ], + default_sampling_params_list=[ + SamplingParams(temperature=0.0), + OmniDiffusionSamplingParams(), + ], + ) + IMG = "" + images = [Image.new("RGB", (32, 32), color="red") for _ in range(3)] + + for n in (1, 2, 3): + engine_prompt, _ = OmniOpenAIServingChat._build_multistage_generation_inputs( + serving_chat, + engine=engine, + prompt="edit me", + extra_body={"bot_task": "think"}, + reference_images=images[:n], + gen_params=OmniDiffusionSamplingParams(), + ) + prompt_str = engine_prompt["prompt"] + assert prompt_str.count("") == n, ( + f"N={n}: expected {n} placeholders, got {prompt_str.count(IMG)} -- prompt: {prompt_str!r}" + ) + + +def test_build_multistage_generation_inputs_tokenizer_path_emits_prompt_token_ids(serving_chat): + """When a tokenizer is provided, the helper must emit HF byte-for-byte + prompt_token_ids and forward use_system_prompt to the engine prompt. + + Regression: prior to the HF-byte-equivalent fix, online IT2I always + passed the prompt as a single string. 
The engine then BPE-merged across + chat-template segment boundaries (e.g. user_prompt-ending punctuation + plus the trailing \n\n before \"Assistant: \") producing a token + sequence that differs from HF apply_chat_template / offline + end2end.py. AR generated different cot_text (706 tokens / 1190 chars + vs offline 661 / 1118 for the same inputs) and DiT produced a visually + different image (yin-yang on brushed-metal vs three-blue swirl on + canvas) under the same seed. + + Pins: + 1. engine_prompt[\"prompt_token_ids\"] is set when tokenizer is passed. + 2. engine_prompt[\"prompt\"] stays as the raw user prompt -- the DiT + side rebuilds its own system prefix via use_system_prompt. + 3. engine_prompt[\"use_system_prompt\"] == \"en_unified\" so + ar2diffusion forwards the matching system prompt to DiT. + 4. N reference images emit N token ids in the AR sequence. + """ + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + # Minimal FakeTokenizer mirroring tests/diffusion/.../test_hunyuan_image3_it2i_multi_image.py + class FakeTokenizer: + SPECIAL = { + "<|startoftext|>": 1, + "": 2, + "": 3, + "": 4, + } + + def convert_tokens_to_ids(self, tok: str) -> int: + return self.SPECIAL.get(tok, 0) + + def encode(self, text: str, add_special_tokens: bool = False) -> list[int]: + return list(range(100, 100 + len(text))) + + engine = SimpleNamespace( + stage_configs=[ + SimpleNamespace(stage_type="llm", is_comprehension=True), + SimpleNamespace(stage_type="diffusion", is_comprehension=False), + ], + default_sampling_params_list=[ + SamplingParams(temperature=0.0), + OmniDiffusionSamplingParams(), + ], + ) + PROMPT_KEY = "prompt" + USP_KEY = "use_system_prompt" + images = [Image.new("RGB", (32, 32), color="red") for _ in range(3)] + + for n in (1, 2, 3): + tok = FakeTokenizer() + engine_prompt, _ = OmniOpenAIServingChat._build_multistage_generation_inputs( + serving_chat, + engine=engine, + prompt="edit me", + extra_body={"bot_task": "think"}, + 
reference_images=images[:n], + gen_params=OmniDiffusionSamplingParams(), + tokenizer=tok, + ) + # (1) prompt_token_ids must be set and non-empty + assert "prompt_token_ids" in engine_prompt, f"N={n}: prompt_token_ids missing" + token_ids = engine_prompt["prompt_token_ids"] + assert isinstance(token_ids, list) and len(token_ids) > 0, f"N={n}: prompt_token_ids empty" + # (2) raw prompt preserved (DiT bridge needs raw user text) + assert engine_prompt["prompt"] == "edit me", ( + f"N={n}: prompt must stay raw user text, got {engine_prompt[PROMPT_KEY]!r}" + ) + # (3) use_system_prompt forwarded for ar2diffusion bridge + assert engine_prompt.get("use_system_prompt") == "en_unified", ( + f"N={n}: use_system_prompt must be en_unified, got {engine_prompt.get(USP_KEY)!r}" + ) + # (4) N token ids (id=2 in FakeTokenizer) + img_count = token_ids.count(2) + assert img_count == n, f"N={n}: expected {n} token ids in prompt_token_ids, got {img_count}" + + +def test_build_multistage_generation_inputs_bot_task_semantic_changes_trigger_and_sys(serving_chat): + """Passing bot_task=think_recaption (vs default "think") must flip the + resolved sys_type to en_think_recaption (and trigger tag is still + ). Pins that the API actually plumbs the bot_task semantic + through to build_prompt rather than ignoring it. + """ + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + engine = SimpleNamespace( + stage_configs=[ + SimpleNamespace(stage_type="llm", is_comprehension=True), + SimpleNamespace(stage_type="diffusion", is_comprehension=False), + ], + default_sampling_params_list=[ + SamplingParams(temperature=0.0), + OmniDiffusionSamplingParams(), + ], + ) + images = [Image.new("RGB", (32, 32), color="red")] + + # Default bot_task (think) -> en_unified system prompt baked into the + # legacy string path. Use legacy build_prompt (tokenizer=None) so the + # rendered prompt is a string we can grep. 
+ think_prompt, _ = OmniOpenAIServingChat._build_multistage_generation_inputs( + serving_chat, + engine=engine, + prompt="edit me", + extra_body={"task": "it2i", "bot_task": "think"}, + reference_images=images, + gen_params=OmniDiffusionSamplingParams(), + ) + # think_recaption -> en_think_recaption system prompt (different content). + recap_prompt, _ = OmniOpenAIServingChat._build_multistage_generation_inputs( + serving_chat, + engine=engine, + prompt="edit me", + extra_body={"task": "it2i", "bot_task": "think_recaption"}, + reference_images=images, + gen_params=OmniDiffusionSamplingParams(), + ) + assert think_prompt["prompt"] != recap_prompt["prompt"], ( + "bot_task semantic must change the rendered system prompt: " + f"think/think_recaption produced identical strings (len={len(think_prompt['prompt'])})" + ) + + +def test_build_multistage_generation_inputs_sys_type_override(serving_chat): + """Caller-supplied sys_type must override the bot_task-derived default. + Mirrors offline `--bot-task think_recaption --sys-type en_unified` + where the user wants think_recaptions trigger but the unified system + prompt body. + """ + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + engine = SimpleNamespace( + stage_configs=[ + SimpleNamespace(stage_type="llm", is_comprehension=True), + SimpleNamespace(stage_type="diffusion", is_comprehension=False), + ], + default_sampling_params_list=[ + SamplingParams(temperature=0.0), + OmniDiffusionSamplingParams(), + ], + ) + images = [Image.new("RGB", (32, 32), color="red")] + + # think_recaption defaults sys_type -> en_think_recaption. + default_sys, _ = OmniOpenAIServingChat._build_multistage_generation_inputs( + serving_chat, + engine=engine, + prompt="edit me", + extra_body={"task": "it2i", "bot_task": "think_recaption"}, + reference_images=images, + gen_params=OmniDiffusionSamplingParams(), + ) + # sys_type=en_unified overrides -> same system body as bot_task=think. 
+ overridden, _ = OmniOpenAIServingChat._build_multistage_generation_inputs( + serving_chat, + engine=engine, + prompt="edit me", + extra_body={"task": "it2i", "bot_task": "think_recaption", "sys_type": "en_unified"}, + reference_images=images, + gen_params=OmniDiffusionSamplingParams(), + ) + plain_think, _ = OmniOpenAIServingChat._build_multistage_generation_inputs( + serving_chat, + engine=engine, + prompt="edit me", + extra_body={"task": "it2i", "bot_task": "think"}, + reference_images=images, + gen_params=OmniDiffusionSamplingParams(), + ) + + # Override must (a) differ from the no-override default, and (b) equal + # the prompt that bot_task=think produces (both end up with + # en_unified system body + trigger). + assert overridden["prompt"] != default_sys["prompt"], ( + "sys_type override must change the rendered prompt body vs the bot_task default" + ) + assert overridden["prompt"] == plain_think["prompt"], ( + "sys_type=en_unified + bot_task=think_recaption must produce the same prompt as " + "bot_task=think (both = en_unified system body + trigger)" + ) + + +def test_build_multistage_generation_inputs_custom_system_prompt(serving_chat): + """`extra_body["system_prompt"]` must reach build_prompt as + `custom_system_prompt`, enabling sys_type="custom" callers to inject + a verbatim system body. Without this plumbing the sys_type="custom" + branch in get_system_prompt() returns None and silently drops the + user-supplied content. 
+ """ + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + engine = SimpleNamespace( + stage_configs=[ + SimpleNamespace(stage_type="llm", is_comprehension=True), + SimpleNamespace(stage_type="diffusion", is_comprehension=False), + ], + default_sampling_params_list=[ + SamplingParams(temperature=0.0), + OmniDiffusionSamplingParams(), + ], + ) + images = [Image.new("RGB", (32, 32), color="red")] + + marker = "ZZZ_CUSTOM_SYSTEM_PROMPT_MARKER_ZZZ" + + out, _ = OmniOpenAIServingChat._build_multistage_generation_inputs( + serving_chat, + engine=engine, + prompt="edit me", + extra_body={ + "task": "it2i", + "bot_task": "think", + "sys_type": "custom", + "system_prompt": marker, + }, + reference_images=images, + gen_params=OmniDiffusionSamplingParams(), + ) + assert marker in out["prompt"], ( + f"custom system_prompt content must reach the rendered prompt; " + f"marker {marker!r} not found in prompt of length {len(out['prompt'])}" + ) diff --git a/tests/model_executor/stage_input_processors/test_hunyuan_image3_bridge.py b/tests/model_executor/stage_input_processors/test_hunyuan_image3_bridge.py new file mode 100644 index 00000000000..76f3e500622 --- /dev/null +++ b/tests/model_executor/stage_input_processors/test_hunyuan_image3_bridge.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for HunyuanImage3 stage input processor.""" + +import builtins +from types import SimpleNamespace + +import pytest + +from vllm_omni.diffusion.models.hunyuan_image3.prompt_utils import ( + HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS, +) +from vllm_omni.model_executor.stage_input_processors.hunyuan_image3 import ( + _extract_ratio_index, + _truncate_at_cot_end, + ar2diffusion, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def _source_output(token_ids: list[int], text: str = ""): + return SimpleNamespace( + outputs=[ + SimpleNamespace( + token_ids=token_ids, + 
cumulative_token_ids=token_ids, + text=text, + ) + ], + multimodal_output=None, + ) + + +def test_extract_ratio_index_uses_fixed_special_token_ids(): + ratio_33 = HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""] + ratio_36 = HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""] + + assert _extract_ratio_index([1, ratio_33, 2]) == 33 + assert _extract_ratio_index([1, ratio_33, 2, ratio_36]) == 36 + + +def test_truncate_at_cot_end_strips_tail_after_recaption_marker(): + text = _truncate_at_cot_end("body text") + assert text == "body text" + + +def test_ar2diffusion_applies_ratio_and_truncates_tail_without_tokenizer(monkeypatch: pytest.MonkeyPatch): + real_import = builtins.__import__ + + def _block_transformers_import(name, *args, **kwargs): + if name == "transformers" or name.startswith("transformers."): + raise AssertionError("ar2diffusion must not import transformers on the bridge path") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", _block_transformers_import) + + end_recaption = HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""] + answer = HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""] + boi = HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""] + size = HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""] + ratio_0 = HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""] + token_ids = [100, 101, end_recaption, answer, boi, size, ratio_0] + + result = ar2diffusion( + [_source_output(token_ids, text="decoded without special tokens")], + prompt=[{"prompt": "edit", "height": 64, "width": 64}], + ) + + assert len(result) == 1 + assert (result[0]["height"], result[0]["width"]) == (512, 2048) + assert result[0]["extra"]["ar_generated_text"] == "decoded without special tokens" + assert "ar_token_ids" not in result[0]["extra"] + + +def test_ar2diffusion_forwards_custom_system_prompt_body(): + end_think = HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""] + marker = "CUSTOM_SYSTEM_BODY" + + result = ar2diffusion( + [_source_output([100, end_think], text="thought")], + prompt=[ + { + "prompt": "edit", + "use_system_prompt": "custom", + 
"system_prompt": marker, + } + ], + ) + + assert result[0]["use_system_prompt"] == "custom" + assert result[0]["system_prompt"] == marker diff --git a/vllm_omni/deploy/hunyuan_image3.yaml b/vllm_omni/deploy/hunyuan_image3.yaml index c52f28674db..20d72304b9d 100644 --- a/vllm_omni/deploy/hunyuan_image3.yaml +++ b/vllm_omni/deploy/hunyuan_image3.yaml @@ -22,6 +22,13 @@ connectors: stages: - stage_id: 0 + # ``is_comprehension`` in vllm-omni names the tokenizer-owning AR stage + # (see config/stage_config.py + serving_chat AR-stage lookup), independent + # of whether the AR's task is comprehension (i2t/t2t) or generation + # (it2i/t2i). HunyuanImage-3.0's stage-0 owns the tokenizer and emits the + # cot+ratio token sequence consumed by stage-1, so it must be marked True + # for the serving path to set AR seed/stop_token_ids on this stage. + is_comprehension: true final_output: true final_output_type: text max_num_seqs: 1 diff --git a/vllm_omni/diffusion/model_metadata.py b/vllm_omni/diffusion/model_metadata.py index ec133e7380e..f3346338434 100644 --- a/vllm_omni/diffusion/model_metadata.py +++ b/vllm_omni/diffusion/model_metadata.py @@ -13,6 +13,8 @@ class DiffusionModelMetadata: QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES = 4 +# Upstream HunyuanImage-3.0 "Multi-Image Fusion" caps reference images at 3. 
+HUNYUAN_IMAGE3_MAX_INPUT_IMAGES = 3 _DIFFUSION_MODEL_METADATA: dict[str, DiffusionModelMetadata] = { @@ -20,6 +22,10 @@ class DiffusionModelMetadata: supports_multimodal_inputs=True, max_multimodal_image_inputs=QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES, ), + "HunyuanImage3Pipeline": DiffusionModelMetadata( + supports_multimodal_inputs=True, + max_multimodal_image_inputs=HUNYUAN_IMAGE3_MAX_INPUT_IMAGES, + ), } diff --git a/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py b/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py index 1eb0cdf113b..4edcfb6ca3a 100644 --- a/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py +++ b/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py @@ -471,8 +471,21 @@ def __str__(self): return f"{self.h}x{self.w}" +# Baked-in extras matching the official model's +# `HunyuanImage3ImageProcessor.vae_reso_group` (image_processor.py:147-152). +# These four aspect buckets sit at ratio_token indices 33-36 in the trained +# model and the AR was trained to address them, so any deviation breaks the +# ratio-token vocab → output-shape lookup. +HUNYUAN_IMAGE3_EXTRA_RESOLUTIONS: tuple[str, ...] 
= ( + "1024x768", + "1280x720", + "768x1024", + "720x1280", +) + + class ResolutionGroup: - def __init__(self, base_size=None, step=None, align=1): + def __init__(self, base_size=None, step=None, align=1, extra_resolutions=None): self.align = align self.base_size = base_size assert base_size % align == 0, f"base_size {base_size} is not divisible by align {align}" @@ -486,6 +499,11 @@ def __init__(self, base_size=None, step=None, align=1): self.step = step self.data = self._calc_by_step() + if extra_resolutions is not None: + for er in extra_resolutions: + if not any(r.ratio == er.ratio for r in self.data): + self.data.append(er) + self.ratio = np.array([x.ratio for x in self.data]) self.attr = ["" for _ in range(len(self.data))] self.prefix_space = 0 @@ -1351,7 +1369,10 @@ class HunyuanImage3ImageProcessor: def __init__(self, config): self.config = config - self.reso_group = ResolutionGroup(base_size=config.image_base_size) + self.reso_group = ResolutionGroup( + base_size=config.image_base_size, + extra_resolutions=[Resolution(s) for s in HUNYUAN_IMAGE3_EXTRA_RESOLUTIONS], + ) self.vae_processor = transforms.Compose( [ transforms.ToTensor(), diff --git a/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py b/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py index 1f88e9e7155..73b89bb11b0 100644 --- a/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py +++ b/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py @@ -283,11 +283,13 @@ def pre_process_func(request: OmniDiffusionRequest): cond_image_infos = [_build_cond_joint_image(image) for image in image_list] prompt["additional_information"]["batch_cond_image_info"] = cond_image_infos + bridge_h = prompt.get("height") if isinstance(prompt, dict) else None + bridge_w = prompt.get("width") if isinstance(prompt, dict) else None first_image_w, first_image_h = _to_pil_image(image_list[0]).size if request.sampling_params.width is None: - 
request.sampling_params.width = int(first_image_w) + request.sampling_params.width = int(bridge_w or first_image_w) if request.sampling_params.height is None: - request.sampling_params.height = int(first_image_h) + request.sampling_params.height = int(bridge_h or first_image_h) request.prompts[i] = prompt @@ -539,7 +541,12 @@ def instantiate_timestep_tokens( timestep_scatter_index: BatchRaggedTensor, ): batch_size, seq_len, n_embd = x.shape - # batch_size x n x n_embd + # `_encode_cond_image` returns `t` as list[Tensor] for the + # multi-image branch (outer length = batch_size, currently fixed + # at 1 by the stage runtime `max_batch_size`); flatten to a Tensor + # before reshape. + if isinstance(t, list): + t = torch.cat([ti.reshape(-1) for ti in t], dim=0) timestep_scatter_src = self.timestep_emb(t.reshape(-1)).reshape(batch_size, -1, n_embd) x.scatter_( dim=1, @@ -627,7 +634,11 @@ def vae_encode(self, image, cfg_factor=1): if isinstance(vae_encode_result, torch.Tensor): latents = vae_encode_result else: - latents = vae_encode_result.latent_dist.sample() + # Match HunyuanImage-3's cond encode path: sample the + # posterior, but use a fixed generator so repeated online + # requests are deterministic. 
+ _cond_vae_gen = torch.Generator(device=image.device).manual_seed(0) + latents = vae_encode_result.latent_dist.sample(_cond_vae_gen) if hasattr(config, "shift_factor") and config.shift_factor: latents.sub_(config.shift_factor) if hasattr(config, "scaling_factor") and config.scaling_factor: @@ -1353,25 +1364,21 @@ def forward( use_system_prompt = extra_args.get("use_system_prompt") system_prompt = extra_args.get("system_prompt") # Fall back to per-prompt use_system_prompt forwarded by ar2diffusion - if use_system_prompt is None and req.prompts: + if req.prompts: first_prompt = req.prompts[0] if isinstance(first_prompt, dict): - use_system_prompt = first_prompt.get("use_system_prompt") + if use_system_prompt is None: + use_system_prompt = first_prompt.get("use_system_prompt") + if system_prompt is None: + system_prompt = first_prompt.get("system_prompt") if use_system_prompt is not None: system_prompt = get_system_prompt(use_system_prompt, "image", system_prompt) system_prompt = system_prompt.strip() if system_prompt is not None else "" prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt - # Extract AR-generated CoT/recaption text from each prompt's extra dict. - # The AR-side stage input processor (``ar2diffusion``) already prepends - # the trigger tag (e.g. ````) when the AR used the KV-reuse - # pretrain format, so ``ar_generated_text`` is a self-contained string - # and ``get_cot_sections()`` can parse the think/recaption structure - # directly. 
- cot_text_list = [] - for p in req.prompts: - extra = p.get("extra", {}) if isinstance(p, dict) else {} - cot_text_list.append(extra.get("ar_generated_text") or None) + cot_text_list = [ + (p.get("extra", {}).get("ar_generated_text") if isinstance(p, dict) else None) or None for p in req.prompts + ] cot_text = ( [self._normalize_cot_text(t) for t in cot_text_list] if any(t is not None for t in cot_text_list) else None ) diff --git a/vllm_omni/diffusion/models/hunyuan_image3/prompt_utils.py b/vllm_omni/diffusion/models/hunyuan_image3/prompt_utils.py index 5d8e9af6ab8..b178b021fd6 100644 --- a/vllm_omni/diffusion/models/hunyuan_image3/prompt_utils.py +++ b/vllm_omni/diffusion/models/hunyuan_image3/prompt_utils.py @@ -11,8 +11,23 @@ `JointImageInfo` objects produced by image preprocessing. The example flow uses an `` placeholder + `multi_modal_data` instead, so it needs a lighter-weight builder that only requires a HF tokenizer. This -module provides that builder; the task -> template mapping below is the -canonical mapping for both flows. +module provides that builder; the (task, bot_task) -> template mapping +below is the canonical mapping for both flows. + +Two orthogonal axes: + + * `task` selects the I/O modality combination, which only controls + whether `` placeholders are emitted between `User: ` and the + user prompt: ``i2t`` / ``it2i`` produce them, ``t2t`` / ``t2i`` do + not. + + * `bot_task` selects the prompting mode and drives both the system + prompt and the trigger tag appended after ``Assistant: ``. ``None`` + (default) gives a plain Assistant turn under the unified prompt; + ``think`` / ``recaption`` switch the trigger tag to ```` / + ````; ``think_recaption`` swaps the system prompt for + the dedicated combined-mode template; ``vanilla`` drops the chat + structure entirely (pretrain template, ``t2i`` only). 
""" from __future__ import annotations @@ -45,143 +60,234 @@ "": 130106, } -# task -> (sys_type, bot_task, trigger_tag) +# bot_task -> (sys_type, trigger_tag). +# ``vanilla`` is special-cased downstream: it bypasses the chat template +# (no ``User:`` / ``Assistant:`` framing) and is only valid with +# ``task='t2i'``. +_BOT_TASK_PRESETS: dict[str | None, tuple[str, str | None]] = { + None: ("en_unified", None), + "think": ("en_unified", ""), + "recaption": ("en_unified", ""), + "think_recaption": ("en_think_recaption", ""), + "vanilla": ("en_vanilla", None), +} + +_TASKS: frozenset[str] = frozenset({"t2t", "i2t", "it2i", "t2i"}) + + +class _DefaultBotTask: + pass + + +_DEFAULT_BOT_TASK = _DefaultBotTask() + +# Legacy composite task alias -> (task, bot_task). Keep this during rebase so +# older callers and intermediate commits still resolve cleanly. _TASK_PRESETS: dict[str, tuple[str, str | None, str | None]] = { "t2t": ("en_unified", None, None), "i2t": ("en_unified", None, None), "it2i_think": ("en_unified", "think", ""), "it2i_recaption": ("en_unified", "recaption", ""), "it2i_think_recaption": ("en_unified", "think_recaption", ""), - "t2i": ("en_unified", "image", None), - "t2i_vanilla": ("en_vanilla", "image", None), + "t2i": ("en_unified", None, None), + "t2i_vanilla": ("en_vanilla", "vanilla", None), "t2i_think": ("en_unified", "think", ""), "t2i_recaption": ("en_unified", "recaption", ""), } +_LEGACY_COMPOSITE_TASKS: frozenset[str] = frozenset(_TASK_PRESETS) - {"t2t", "i2t", "t2i"} + + +def _normalize_task_and_bot_task( + task: str, + bot_task: str | None | _DefaultBotTask, +) -> tuple[str, str | None]: + bot_task_was_omitted = bot_task is _DEFAULT_BOT_TASK + if task in _TASK_PRESETS: + _, legacy_bot_task, _ = _TASK_PRESETS[task] + base_task = task.split("_", 1)[0] + if base_task == "t2i" and task == "t2i": + base_task = "t2i" + if task in ("t2t", "i2t", "t2i"): + base_task = task + if bot_task_was_omitted: + bot_task = legacy_bot_task + elif task in 
_LEGACY_COMPOSITE_TASKS and bot_task is None: + # Composite task names already encode the legacy bot_task. Keep + # calls like build_prompt_tokens(task="it2i_think", bot_task=None) + # on their historical meaning; explicit None is the plain-mode + # escape hatch only for the new two-axis base tasks. + bot_task = legacy_bot_task + task = base_task + elif bot_task_was_omitted: + bot_task = "think" + return task, bot_task + def available_tasks() -> list[str]: - """Sorted list of task keys accepted by `build_prompt` / `build_prompt_tokens`.""" - return sorted(_TASK_PRESETS) + """Sorted list of `task` values accepted by the prompt builders.""" + return sorted(_TASKS) + + +def available_bot_tasks() -> list[str | None]: + """Sorted list of `bot_task` values (with ``None`` first).""" + rest = sorted(k for k in _BOT_TASK_PRESETS if k is not None) + return [None, *rest] + + +def resolve_sys_type(bot_task: str | None) -> str: + """Default system-prompt type for a given ``bot_task``.""" + if bot_task not in _BOT_TASK_PRESETS: + raise ValueError(f"Unknown bot_task {bot_task!r}. Choose from: {available_bot_tasks()}") + return _BOT_TASK_PRESETS[bot_task][0] def resolve_stop_token_ids( - task: str = "it2i_think", - bot_task: str = "think", + task: str = "it2i", + bot_task: str | None | _DefaultBotTask = _DEFAULT_BOT_TASK, tokenizer: Any | None = None, -): +) -> list[int]: + """AR stop-token ids for a given (task, bot_task) generation request. + + Image-output tasks (``it2i`` / ``t2i``) stop on any ```` + token. Upstream ``modeling_hunyuan_image_3.py::generate_image`` + (line 3289-3303) sets ``final_stop_tokens`` to the full ratio token + range when ``need_ratio`` is true, then strips the trailing ratio + token before passing the cot to the image stage. 
AR's natural + trajectory under ``_stage_transitions`` is + ````; stopping + AT the ratio token means KV ends exactly at the prefix DiT reuses, + and ``ar2diffusion`` can read the ratio off the last sampled token + without AR wasting decode steps on ``<|endoftext|>``. + + Text-output tasks (``i2t`` / ``t2t``) stop on ```` -- the AR + is the final stage, and the comprehension response sits inside the + ```` body so the answer-open is the natural cot/recaption + terminator. + """ + task, bot_task = _normalize_task_and_bot_task(task, bot_task) + if task not in _TASKS: + raise ValueError(f"Unknown task {task!r}. Choose from: {available_tasks()}") + if bot_task not in _BOT_TASK_PRESETS: + raise ValueError(f"Unknown bot_task {bot_task!r}. Choose from: {available_bot_tasks()}") + if task in ("it2i", "t2i"): + # Main ratio range: .. . + start = HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""] + end = HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""] + stops = list(range(start, end + 1)) + # Other slices (upstream tokenizer ``ratio_token_other_slices``): + # .. . + other_start = HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""] + other_end = HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""] + stops.extend(range(other_start, other_end + 1)) + return stops return [HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""]] +# Upstream "Multi-Image Fusion" caps reference images at 3 per request. +MAX_IMAGES_PER_REQUEST = 3 + + +def _validate_num_images(num_images: int) -> None: + if not (1 <= num_images <= MAX_IMAGES_PER_REQUEST): + raise ValueError(f"num_images must be in [1, {MAX_IMAGES_PER_REQUEST}], got {num_images}") + + +def _resolve_preset(task: str, bot_task: str | None) -> tuple[str, str | None]: + """Validate (task, bot_task) and return ``(sys_type, trigger_tag)``.""" + task, bot_task = _normalize_task_and_bot_task(task, bot_task) + if task not in _TASKS: + raise ValueError(f"Unknown task {task!r}. Choose from: {available_tasks()}") + if bot_task not in _BOT_TASK_PRESETS: + raise ValueError(f"Unknown bot_task {bot_task!r}. 
Choose from: {available_bot_tasks()}") + if bot_task == "vanilla" and task != "t2i": + raise ValueError(f"bot_task='vanilla' is only valid with task='t2i' (pretrain template); got task={task!r}") + return _BOT_TASK_PRESETS[bot_task] + + def build_prompt( user_prompt: str, - task: str = "it2i_think", + task: str = "it2i", + bot_task: str | None | _DefaultBotTask = _DEFAULT_BOT_TASK, sys_type: str | None = None, custom_system_prompt: str | None = None, + num_images: int = 1, ) -> str: - """Build a HunyuanImage-3.0 prompt as a string (legacy/compat path). - - NOTE: when this string is passed to the engine, the engine's tokenizer - will run a single BPE pass over the whole string, which can merge - tokens across segment boundaries (e.g. `。\\n\\n` -> id 3490). For - inputs that need to match HF baseline byte-for-byte, use - `build_prompt_tokens` instead and feed the result via prompt_token_ids. - """ - if task not in _TASK_PRESETS: - raise ValueError(f"Unknown task {task!r}. Choose from: {available_tasks()}") - - preset_sys_type, preset_bot_task, trigger_tag = _TASK_PRESETS[task] + """Build a HunyuanImage-3.0 prompt as a string (legacy/compat path).""" + task, bot_task = _normalize_task_and_bot_task(task, bot_task) + preset_sys_type, trigger_tag = _resolve_preset(task, bot_task) effective_sys_type = sys_type or preset_sys_type - system_prompt = get_system_prompt(effective_sys_type, preset_bot_task, custom_system_prompt) - sys_text = system_prompt.strip() if system_prompt else "" + system_prompt = get_system_prompt(effective_sys_type, bot_task, custom_system_prompt) + sys_text = system_prompt or "" - has_image_input = task.startswith("i2t") or task.startswith("it2i") + has_image_input = task in ("i2t", "it2i") + if has_image_input: + _validate_num_images(num_images) - # t2i_vanilla: pretrain mode for direct text->image generation. The - # vanilla system prompt drives the model with no chat structure. 
- if task == "t2i_vanilla": + if bot_task == "vanilla": parts = ["<|startoftext|>"] if sys_text: parts.append(sys_text) parts.append(user_prompt) return "".join(parts) - # All other tasks (t2t / i2t / t2i_think / t2i_recaption / - # it2i_think / it2i_recaption) use HunyuanImage3 Instruct chat template: - # <|startoftext|>{system?}\n\nUser: {?}{user_prompt}\n\nAssistant: {trigger?} - # generation_config.json declares sequence_template="instruct", so the - # AR prefill MUST use this template -- verified to match HF's - # apply_chat_template output token-for-token (modulo BPE boundary merges). - # The trigger_tag (e.g. ) MUST come AFTER the `Assistant: ` prefix: - # if it goes BEFORE user_prompt (the old pretrain layout) the model puts - # the user's instructions inside the "thinking section" and collapses - # into repetition garbage under greedy decoding. parts = ["<|startoftext|>"] if sys_text: parts.append(f"{sys_text}\n\n") parts.append("User: ") if has_image_input: - parts.append("") + parts.extend([""] * num_images) parts.append(user_prompt) parts.append("\n\nAssistant: ") if trigger_tag: parts.append(trigger_tag) - return "".join(parts) @dataclass class PromptTokensResult: - token_ids: list[int] # The tokenized prompt - system_prompt_type: str # The effective system prompt type used + token_ids: list[int] + system_prompt_type: str def build_prompt_tokens( user_prompt: str, tokenizer, - task: str = "it2i_think", + task: str = "it2i", + bot_task: str | None | _DefaultBotTask = _DEFAULT_BOT_TASK, sys_type: str | None = None, custom_system_prompt: str | None = None, + num_images: int = 1, ) -> PromptTokensResult: - """Segment-by-segment tokenization that matches HF apply_chat_template. - - Calling tokenizer.encode(build_prompt(...)) on the full string lets BPE - merge tokens across segment boundaries (e.g. user_prompt ends with `。` - and the next segment is `\\n\\n` -> they merge into a single token id - 3490 instead of HF's [1811, 271]). 
HF's apply_chat_template tokenizes - each segment independently and concatenates token_ids, so no cross- - boundary merge happens. We replicate that here and feed the result to - Omni via OmniTokensPrompt (prompt_token_ids). - - Returns: - PromptTokensResult - """ - if task not in _TASK_PRESETS: - raise ValueError(f"Unknown task {task!r}. Choose from: {available_tasks()}") - - preset_sys_type, preset_bot_task, trigger_tag = _TASK_PRESETS[task] + """Segment-by-segment tokenization that matches HF apply_chat_template.""" + task, bot_task = _normalize_task_and_bot_task(task, bot_task) + preset_sys_type, trigger_tag = _resolve_preset(task, bot_task) effective_sys_type = sys_type or preset_sys_type bos_id = tokenizer.convert_tokens_to_ids("<|startoftext|>") img_id = tokenizer.convert_tokens_to_ids("") trig_id = tokenizer.convert_tokens_to_ids(trigger_tag) if trigger_tag else None - has_image_input = task.startswith("i2t") or task.startswith("it2i") - - # t2i_vanilla uses pretrain template with no chat structure; the vanilla - # system prompt drives the model directly. No segment boundaries to - # protect, fall back to whole-string encode. - if task == "t2i_vanilla": - s = build_prompt(user_prompt, task, sys_type, custom_system_prompt) + has_image_input = task in ("i2t", "it2i") + if has_image_input: + _validate_num_images(num_images) + + if bot_task == "vanilla": + s = build_prompt( + user_prompt, + task=task, + bot_task=bot_task, + sys_type=sys_type, + custom_system_prompt=custom_system_prompt, + ) token_ids = tokenizer.encode(s, add_special_tokens=False) return PromptTokensResult( token_ids=token_ids, system_prompt_type=effective_sys_type, ) - system_prompt = get_system_prompt(effective_sys_type, preset_bot_task, custom_system_prompt) - # Do NOT strip -- HF apply_chat_template keeps the system prompt's - # natural trailing newline; stripping it would shift one token id. 
+ system_prompt = get_system_prompt(effective_sys_type, bot_task, custom_system_prompt) sys_text = system_prompt or "" ids: list[int] = [bos_id] @@ -190,7 +296,7 @@ def build_prompt_tokens( ids += tokenizer.encode("\n\n", add_special_tokens=False) ids += tokenizer.encode("User: ", add_special_tokens=False) if has_image_input: - ids += [img_id] + ids += [img_id] * num_images ids += tokenizer.encode(user_prompt, add_special_tokens=False) ids += tokenizer.encode("\n\nAssistant: ", add_special_tokens=False) if trig_id is not None: @@ -202,4 +308,14 @@ def build_prompt_tokens( ) -__all__ = ["build_prompt", "build_prompt_tokens", "resolve_stop_token_ids", _TASK_PRESETS] +__all__ = [ + "HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS", + "MAX_IMAGES_PER_REQUEST", + "_TASK_PRESETS", + "available_bot_tasks", + "available_tasks", + "build_prompt", + "build_prompt_tokens", + "resolve_stop_token_ids", + "resolve_sys_type", +] diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 06fb0a7f4cb..c1467f7190a 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -1700,7 +1700,10 @@ async def edit_images( # vllm-omni extension for layered models (e.g., Qwen-Image-Layered) layers: int | None = Form(None), resolution: int | None = Form(None), # See SUPPORTED_LAYERED_RESOLUTIONS + # /v1/images/edits is always IT2I; only the prompting knobs are exposed. bot_task: str | None = Form(None), + sys_type: str | None = Form(None), + system_prompt: str | None = Form(None), ) -> ImageGenerationResponse: """ OpenAI-compatible image edit endpoint. @@ -1751,16 +1754,20 @@ async def edit_images( status_code=HTTPStatus.BAD_REQUEST.value, detail=detail, ) - pil_images = await _load_input_images(input_images_list) + # Match the offline path: RGB normalize when the caller opts into + # Hunyuan-aware behavior. RGBA/P uploads otherwise diverge from offline. 
+ normalize_edit_images_rgb = bot_task is not None or sys_type is not None + pil_images = await _load_input_images(input_images_list, normalize_rgb=normalize_edit_images_rgb) prompt["multi_modal_data"] = {} prompt["multi_modal_data"]["image"] = pil_images if mask_image is not None: - loaded = await _load_input_images([mask_image]) + # Mask role is different (alpha channel matters); never normalize. + loaded = await _load_input_images([mask_image], normalize_rgb=False) prompt["multi_modal_data"]["mask_image"] = loaded[0] if reference_image is not None: - loaded = await _load_input_images([reference_image]) + loaded = await _load_input_images([reference_image], normalize_rgb=normalize_edit_images_rgb) prompt["multi_modal_data"]["reference_image"] = loaded[0] # 3 Build sample params @@ -1811,7 +1818,8 @@ async def edit_images( # 3.3 Parse and add size if provided width, height = None, None - if size.lower() == "auto": + size_was_auto = size.lower() == "auto" + if size_was_auto: if resolution is None: # No resolution specified, use input image size width, height = pil_images[0].size @@ -1882,10 +1890,13 @@ async def edit_images( "seed": effective_seed, "num_outputs_per_prompt": n, } - if width is not None: - extra_body["width"] = width - if height is not None: - extra_body["height"] = height + # size="auto" resolves width/height from input image; forwarding + # those would override AR-driven `` token selection. 
+ if not size_was_auto: + if width is not None: + extra_body["width"] = width + if height is not None: + extra_body["height"] = height if negative_prompt is not None: extra_body["negative_prompt"] = negative_prompt if num_inference_steps is not None: @@ -1907,6 +1918,10 @@ async def edit_images( extra_body["lora"] = lora_dict if bot_task is not None: extra_body["bot_task"] = bot_task + if sys_type is not None: + extra_body["sys_type"] = sys_type + if system_prompt is not None: + extra_body["system_prompt"] = system_prompt prompt_text = prompt.get("prompt", "") generation_result = await chat_handler.generate_diffusion_images( @@ -2187,6 +2202,8 @@ def _extract_images_from_result(result: Any) -> list[Any]: async def _load_input_images( inputs: list[str], + *, + normalize_rgb: bool = True, ) -> list[Image.Image]: """ convert to PIL.Image.Image list @@ -2233,7 +2250,18 @@ async def _load_input_images( if not images: raise ValueError("No valid input images found") - return images + if not normalize_rgb: + return images + + # Match the offline HunyuanImage3 image-edit example path, which eagerly + # normalizes input files with ``Image.open(...).convert("RGB")`` before + # they reach the AR stage. Keeping uploads as RGBA/P PIL objects makes + # online IT2I observe a different visual input than offline (for example + # transparent-logo uploads alpha-composited over white instead of black), + # which is enough for HunyuanImage3 AR recaption to diverge before DiT + # sees the request -- root cause of the "online 3 magnets vs offline 1 + # magnet" systematic semantic mismatch. 
+ return [img.convert("RGB") for img in images] def _choose_output_format(output_format: str | None, background: str | None) -> str: diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 99827454e70..2c375fa2928 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -419,7 +419,10 @@ async def create_chat_completion( # consistency. After the multimodal processor consumes # the image data, the uuids remain as a stable reference. tprompt["multi_modal_uuids"] = { - k: [f"{request_id}-{k}-{i}"] for i, k in enumerate(engine_prompt_image) + k: [f"{request_id}-{k}-{i}" for i in range(len(v))] + if isinstance(v, list) + else [f"{request_id}-{k}-0"] + for k, v in engine_prompt_image.items() } engine_prompts = [tprompt] @@ -2245,6 +2248,8 @@ def _build_multistage_generation_inputs( layers = extra_body.get("layers") resolution = extra_body.get("resolution") bot_task = extra_body.get("bot_task") + sys_type = extra_body.get("sys_type") + custom_system_prompt = extra_body.get("system_prompt") engine_prompt_data: dict[str, Any] | None = None modalities = ["image"] @@ -2255,31 +2260,48 @@ def _build_multistage_generation_inputs( else: engine_prompt_data = {"image": reference_images} - engine_prompt: OmniTextPrompt = {"prompt": prompt} - if bot_task: + prompt_token_ids: list[int] | None = None + system_prompt_type: str | None = None + if bot_task is not None or sys_type is not None or custom_system_prompt is not None: from vllm_omni.diffusion.models.hunyuan_image3.prompt_utils import ( build_prompt, build_prompt_tokens, ) - prompt_token_ids: list[int] | None = None - system_prompt_type: str | None = None + build_kwargs: dict[str, Any] = { + "task": "it2i" if reference_images else "t2i", + "sys_type": sys_type, + "custom_system_prompt": custom_system_prompt, + "num_images": len(reference_images) if reference_images else 1, + } + if bot_task is not None: + 
build_kwargs["bot_task"] = bot_task + elif "bot_task" in extra_body: + # Explicit None from the caller is plain-mode; omitted lets + # each task fall back to its default trigger. + build_kwargs["bot_task"] = None if tokenizer is not None: - result = build_prompt_tokens(prompt, tokenizer, task=bot_task) + # Feed segment-tokenized prompt_token_ids so AR matches HF + # apply_chat_template byte-for-byte (engine BPE would merge + # across template boundaries, e.g. "。\n\n" -> single id). + result = build_prompt_tokens(prompt, tokenizer, **build_kwargs) prompt_token_ids = result.token_ids system_prompt_type = result.system_prompt_type else: - prompt = build_prompt(prompt, task=bot_task) - engine_prompt["prompt"] = prompt - + prompt = build_prompt(prompt, **build_kwargs) if reference_images and len(reference_images) == 1: engine_prompt_data = {"image": reference_images[0]} modalities = ["image"] - if prompt_token_ids is not None: - engine_prompt["prompt_token_ids"] = prompt_token_ids - if system_prompt_type is not None: - engine_prompt["use_system_prompt"] = system_prompt_type + engine_prompt: OmniTextPrompt = {"prompt": prompt} + if prompt_token_ids is not None: + engine_prompt["prompt_token_ids"] = prompt_token_ids + if system_prompt_type is not None: + engine_prompt["use_system_prompt"] = system_prompt_type + # DiT's get_system_prompt(use_system_prompt, "image", system_prompt) reads + # this; omitting it makes sys_type=custom yield an empty DiT prefix. + if custom_system_prompt is not None: + engine_prompt["system_prompt"] = custom_system_prompt engine_prompt["modalities"] = modalities if negative_prompt is not None: engine_prompt["negative_prompt"] = negative_prompt @@ -2295,7 +2317,11 @@ def _build_multistage_generation_inputs( engine_prompt["multi_modal_data"] = engine_prompt_data # Provide multi_modal_uuids so that newer vLLM versions can # validate multi_modal_data / multi_modal_uuids consistency. 
- engine_prompt["multi_modal_uuids"] = {k: [f"img-{k}-{i}"] for i, k in enumerate(engine_prompt_data)} + # Generate one uuid per image when the value is a list (multi-image inputs). + engine_prompt["multi_modal_uuids"] = { + k: [f"img-{k}-{i}" for i in range(len(v))] if isinstance(v, list) else [f"img-{k}-0"] + for k, v in engine_prompt_data.items() + } comprehension_idx = None for idx, stage in enumerate(stage_configs): @@ -2320,10 +2346,8 @@ def _build_multistage_generation_inputs( ): default_stage_params.seed = seed - # Inject target_h/w into comprehension (AR) stage sampling params - # for models that need M-RoPE position pre-computation (e.g. - # GLM-Image). max_tokens is handled via the deploy YAML default - # (upper-bound ceiling) rather than computed dynamically here. + # Inject target_h/w into AR stage for M-RoPE position pre-computation + # (e.g. GLM-Image). max_tokens comes from deploy YAML. if comprehension_idx is not None and idx == comprehension_idx and height is not None and width is not None: extra_args = getattr(default_stage_params, "extra_args", None) if extra_args is None: @@ -2449,13 +2473,17 @@ async def generate_diffusion_images( diffusion_engine = cast(AsyncOmni, engine) stage_configs = getattr(diffusion_engine, "stage_configs", None) or [] if len(stage_configs) > 1: + # Pull tokenizer from the comprehension (AR) stage so we can + # build HF byte-for-byte prompt_token_ids in the helper. If + # the engine doesn't expose one, fall back to the legacy + # string-prompt path (engine re-tokenizes). 
tokenizer = None get_tok = getattr(diffusion_engine, "get_tokenizer", None) if get_tok is not None: try: tokenizer = await get_tok() except Exception as exc: - logger.warning("get_tokenizer failed: %s", exc) + logger.warning("get_tokenizer failed; falling back to string prompt path: %s", exc) engine_prompt, sampling_params_list = self._build_multistage_generation_inputs( engine=diffusion_engine, prompt=prompt, diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py index 0f80026210d..2edf3d4df79 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py @@ -737,7 +737,7 @@ def __str__(self): class ResolutionGroup: """Group of resolutions for image processing.""" - def __init__(self, base_size=None, step=None, align=1): + def __init__(self, base_size=None, step=None, align=1, extra_resolutions=None): self.align = align self.base_size = base_size assert base_size % align == 0, f"base_size {base_size} is not divisible by align {align}" @@ -751,6 +751,11 @@ def __init__(self, base_size=None, step=None, align=1): self.step = step self.data = self._calc_by_step() + if extra_resolutions is not None: + for er in extra_resolutions: + if not any(r.ratio == er.ratio for r in self.data): + self.data.append(er) + self.ratio = np.array([x.ratio for x in self.data]) self.attr = ["" for _ in range(len(self.data))] self.prefix_space = 0 @@ -815,7 +820,18 @@ def get_base_size_and_ratio_index(self, width, height): def __init__(self, tokenizer, hf_config, **kwargs: object): self.tokenizer = tokenizer self.hf_config = hf_config - self.reso_group = self.ResolutionGroup(base_size=hf_config.image_base_size) + # `HUNYUAN_IMAGE3_EXTRA_RESOLUTIONS` mirrors the official + # `vae_reso_group` extras (image_processor.py:147-152). 
Build with + # this processor's inner Resolution class so `data` stays + # type-homogeneous. + from vllm_omni.diffusion.models.hunyuan_image3.hunyuan_image3_transformer import ( + HUNYUAN_IMAGE3_EXTRA_RESOLUTIONS, + ) + + self.reso_group = self.ResolutionGroup( + base_size=hf_config.image_base_size, + extra_resolutions=[HunyuanImage3Processor.Resolution(s) for s in HUNYUAN_IMAGE3_EXTRA_RESOLUTIONS], + ) self.vision_encoder_processor = Siglip2ImageProcessorFast.from_dict(hf_config.vit_processor) self.vae_processor = transforms.Compose( [ @@ -860,6 +876,13 @@ def process_image(self, image_input: ImageInput): else: raise TypeError(f"Unsupported image type: {type(image_input)}.") + # Each cond image keeps its own VAE bucket (mirrors official HF's + # ragged behavior in `_encode_cond_image`). VAE pixel tensors have + # different (H_i, W_i) per image, so they're flattened to 1-D and + # concatenated; vLLM `flat_from_sizes("image", vae_pixel_size)` slices + # them back per-image at consumption time. VIT (Siglip2 naflex) pads + # to `max_num_patches` so VIT fields keep the existing `batched` + # stack path. batch_data = [] for image in images: current_info = {} @@ -883,68 +906,80 @@ def process_image(self, image_input: ImageInput): _ss = torch.tensor(_ss, dtype=torch.long) current_info["vit_spatial_shapes"] = _ss.squeeze(0) - # VAE processing. - # The resize/crop math here mirrors HF's `resize_and_crop` with - # crop_type="center" (hunyuan3.0_ins/image_processor.py:61). VAE - # normalize uses the same transforms.Compose([ToTensor, - # Normalize([0.5], [0.5])]) as HF's `pil_image_to_tensor`. So - # numerical output of this branch should match HF up to floating- - # point reduction order. + # VAE: per-image bucket via `reso_group.get_target_size`; mirrors + # HF's `resize_and_crop` (crop_type="center", the official + # generate_image default with infer_align_image_size=False). + # Keep fp32 — the VAE encoder casts to model dtype at its + # boundary (see `_vae_encode`). 
image_width, image_height = self.reso_group.get_target_size(image.width, image.height) resized_image = self._resize_and_crop(image, (image_width, image_height)) - vae_pixel_values = self.vae_processor(resized_image) + vae_pixel_values = self.vae_processor(resized_image).squeeze(0) token_height = image_height // (self.hf_config.vae_downsample_factor[0] * self.hf_config.patch_size) token_width = image_width // (self.hf_config.vae_downsample_factor[1] * self.hf_config.patch_size) - # Keep fp32 — the VAE encoder casts to model dtype at its boundary - # (see _vae_encode). Casting to bf16 here costs ~7e-4 mean-abs-diff - # bf16 quantization error on every pixel vs HF (which keeps fp32 - # in build_cond_images), measurable as a real numerical drift in - # downstream image embeddings. - current_info["vae_pixel_values"] = vae_pixel_values.squeeze(0) + + current_info["vae_pixel_values_flat"] = vae_pixel_values.reshape(-1) + current_info["vae_pixel_size"] = torch.tensor(vae_pixel_values.numel(), dtype=torch.long) current_info["vae_token_grid_hw"] = torch.tensor([token_height, token_width]) - # size base_size, ratio_index = self.reso_group.get_base_size_and_ratio_index(image_width, image_height) current_info["base_size"] = torch.tensor(base_size) current_info["ratio_index"] = torch.tensor(ratio_index) batch_data.append(current_info) - # Stack the tensors in the list into a batch dimension (B, ...) - final_image_info = {} - if len(batch_data) > 0: - for key in batch_data[0].keys(): - final_image_info[key] = torch.stack([d[key] for d in batch_data], dim=0) + final_image_info: dict[str, torch.Tensor] = {} + if not batch_data: + return final_image_info + + # Same-shape fields: stack along a new image-batch dim as before. 
+ same_shape_keys = [ + "vit_pixel_values", + "vit_pixel_attention_mask", + "vit_spatial_shapes", + "vae_token_grid_hw", + "vae_pixel_size", + "base_size", + "ratio_index", + ] + for key in same_shape_keys: + final_image_info[key] = torch.stack([d[key] for d in batch_data], dim=0) + + # Variable-shape VAE pixels: 1-D concat across images (paired with + # `vae_pixel_size` via `flat_from_sizes` in `_get_mm_fields_config`). + final_image_info["vae_pixel_values"] = torch.cat([d["vae_pixel_values_flat"] for d in batch_data], dim=0) - if final_image_info: - shapes_info = {k: tuple(v.shape) for k, v in final_image_info.items()} - logger.info(f"Successfully processed {len(images)} image(s). Final tensor shapes: {shapes_info}") + shapes_info = {k: tuple(v.shape) for k, v in final_image_info.items()} + logger.info(f"Successfully processed {len(images)} image(s). Final tensor shapes: {shapes_info}") return final_image_info - def _resize_and_crop(self, image: Image.Image, target_size: tuple[int, int]) -> Image.Image: + def _resize_and_crop( + self, + image: Image.Image, + target_size: tuple[int, int], + crop_type: str = "resize", + ) -> Image.Image: + # Default mode mirrors the official `infer_align_image_size=True` + # path (image_processor.py:355 → crop_type="resize") used by the + # IT2I demo: stretch the cond image to the bucket dims so its + # `` tag and ViT/VAE features stay aligned with the + # bucket, instead of dropping content via center crop. 
tw, th = target_size + if crop_type == "resize": + return image.resize((tw, th), resample=Image.Resampling.LANCZOS) w, h = image.size - tr = th / tw r = h / w - - # resize if r < tr: resize_height = th resize_width = int(round(th / h * w)) else: resize_width = tw resize_height = int(round(tw / w * h)) - image = image.resize((resize_width, resize_height), resample=Image.Resampling.LANCZOS) - - # center crop crop_top = int(round((resize_height - th) / 2.0)) crop_left = int(round((resize_width - tw) / 2.0)) - - image = image.crop((crop_left, crop_top, crop_left + tw, crop_top + th)) - return image + return image.crop((crop_left, crop_top, crop_left + tw, crop_top + th)) class HunyuanImage3ProcessingInfo(BaseProcessingInfo): @@ -1030,8 +1065,13 @@ def _get_mm_fields_config( config["vit_pixel_attention_mask"] = MultiModalFieldConfig.batched("image") if "vit_spatial_shapes" in hf_inputs: config["vit_spatial_shapes"] = MultiModalFieldConfig.batched("image") - if "vae_pixel_values" in hf_inputs: - config["vae_pixel_values"] = MultiModalFieldConfig.batched("image") + # `vae_pixel_values` is a 1-D concatenation of variable-shape per-image + # VAE tensors (see `process_image`). `vae_pixel_size` carries the + # per-image flat length so vLLM can split the buffer back per image. 
+ if "vae_pixel_values" in hf_inputs and "vae_pixel_size" in hf_inputs: + config["vae_pixel_values"] = MultiModalFieldConfig.flat_from_sizes("image", hf_inputs["vae_pixel_size"]) + if "vae_pixel_size" in hf_inputs: + config["vae_pixel_size"] = MultiModalFieldConfig.batched("image") if "vae_token_grid_hw" in hf_inputs: config["vae_token_grid_hw"] = MultiModalFieldConfig.batched("image") if "base_size" in hf_inputs: @@ -1085,38 +1125,37 @@ def get_replacement_image(item_idx: int) -> PromptUpdateDetails: ratio_token_id = tokenizer.convert_tokens_to_ids(f"") if ratio_token_id is None: raise ValueError(f"Ratio token '' not found in tokenizer vocabulary") + timestep_token_id = tokenizer.convert_tokens_to_ids("") + if timestep_token_id is None: + raise ValueError("Timestep token '' not found in tokenizer vocabulary") - # NOTE on the timestep slot: - # HF's apply_chat_template emits the literal token id - # 128017 here. HF's modeling forward (`instantiate_continuous_tokens`, - # see hunyuan3.0_ins/modeling_hunyuan_image_3.py:1964) then *scatter- - # replaces* the embedding at that position with `timestep_emb(0)` - # for cond images. So the wte embedding of is irrelevant - # at runtime — what matters is the timestep_emb injection. + # Use the real token id (HF parity). The trained wte + # at this slot is overwritten with timestep_emb(0) at runtime by + # `embed_input_ids`. # - # vllm-omni achieves the same effect via the multimodal-embedding - # merger: we put an (128006) placeholder here and ship a - # `timestep_emb(0)` tensor at the head of `embed_multimodal()`'s - # combined_embeddings. The merger replaces this placeholder's - # embedding with the timestep tensor, yielding a final hidden - # state numerically equivalent to HF at that position. 
- # - # Keep this slot as (NOT ): switching to - # requires either (a) a second PromptReplacement targeting 128017, - # or (b) the merger's embed_token_id to be a list — neither is - # currently supported by PromptUpdateDetails.select_token_id. + # Mark *VAE + + *ViT as one contiguous + # embed run so vLLM's prefix-LM mask treats it as a single + # bidirectional region, mirroring official `joint_image_slices` + # full-attention range (image_processor.py:388, with + # cond_token_attn_type effectively spanning VAE+sep+ViT). With the + # default `select_token_id()` mask, sep splits the run into + # two regions; that asymmetry is what biased multi-image AR + # ratio prediction to the first image's bucket. replacement = ( [boi_token_id] + [base_size_token_id] + [ratio_token_id] - + [img_token_id] * timestep_token_num + + [timestep_token_id] * timestep_token_num + [img_token_id] * vae_token_num + [joint_img_sep_token_id] + [img_token_id] * vit_token_num + [eoi_token_id] ) logger.debug(f"actual replacement token count: {timestep_token_num + vae_token_num + vit_token_num}") - return PromptUpdateDetails.select_token_id(replacement, embed_token_id=img_token_id) + return PromptUpdateDetails.select_token_ids( + replacement, + embed_token_ids=[img_token_id, joint_img_sep_token_id], + ) return [ PromptReplacement(modality="image", target=[img_token_id], replacement=get_replacement_image), @@ -1502,6 +1541,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._end_of_answer_id = tokenizer.convert_tokens_to_ids("") image_base_size = getattr(config, "image_base_size", 1024) self._size_token_id = tokenizer.convert_tokens_to_ids(f"") + self._timestep_token_id = tokenizer.convert_tokens_to_ids("") self._start_ratio_id = tokenizer.convert_tokens_to_ids("") self._end_ratio_id = tokenizer.convert_tokens_to_ids("") ratio_33 = tokenizer.convert_tokens_to_ids("") @@ -1669,6 +1709,9 @@ def _parse_and_validate_image_input( vit_pixel_attention_mask = 
kwargs.pop("vit_pixel_attention_mask", None) vit_spatial_shapes = kwargs.pop("vit_spatial_shapes", None) vae_pixel_values = kwargs.pop("vae_pixel_values", None) + # vae_pixel_size is only metadata for vLLM's flat_from_sizes split; + # we reconstruct per-image shapes from vae_token_grid_hw below. + kwargs.pop("vae_pixel_size", None) vae_token_grid_hw = kwargs.pop("vae_token_grid_hw", None) if vit_pixel_values is None or vae_pixel_values is None: @@ -1678,13 +1721,36 @@ def _parse_and_validate_image_input( if vit_pixel_values.numel() == 0 or vae_pixel_values.numel() == 0: return None + # `vae_pixel_values` arrives as a 1-D concatenation of per-image flat + # buffers (see `process_image` + `flat_from_sizes`). Reconstruct a + # list of per-image (3, H_i, W_i) tensors using the per-image grid + # dims so the downstream VAE encoder can run image-by-image. + vae_factor_h = self.config.vae_downsample_factor[0] * self.config.patch_size + vae_factor_w = self.config.vae_downsample_factor[1] * self.config.patch_size + num_images = vae_token_grid_hw.shape[0] + vae_image_list: list[torch.Tensor] = [] + offset = 0 + flat = vae_pixel_values.reshape(-1) + for i in range(num_images): + token_h, token_w = vae_token_grid_hw[i].tolist() + h_i = int(token_h) * vae_factor_h + w_i = int(token_w) * vae_factor_w + n_i = 3 * h_i * w_i + vae_image_list.append(flat[offset : offset + n_i].reshape(3, h_i, w_i)) + offset += n_i + if offset != flat.numel(): + raise ValueError( + f"vae_pixel_values size mismatch: consumed {offset} of {flat.numel()} elements " + f"across {num_images} images (token_grid_hw={vae_token_grid_hw.tolist()})" + ) + return HunyuanImage3PixelInputs( type="pixel_values", pixel_values={ "vit_pixel_values": vit_pixel_values, "vit_pixel_attention_mask": vit_pixel_attention_mask, "vit_spatial_shapes": vit_spatial_shapes, - "vae_pixel_values": vae_pixel_values, + "vae_pixel_values": vae_image_list, "vae_token_grid_hw": vae_token_grid_hw, }, ) @@ -1712,7 +1778,11 @@ def 
_vae_encode( images = images.to(dtype=self.vae.dtype) vae_encode_result = self.vae.encode(images) - latents = vae_encode_result.latent_dist.sample() + # Match HunyuanImage-3's cond encode path: sample the posterior, but + # use a fixed generator so online requests do not consume the global + # RNG and drift across a long-running server. + _cond_vae_gen = torch.Generator(device=images.device).manual_seed(0) + latents = vae_encode_result.latent_dist.sample(_cond_vae_gen) # Apply shift and scaling factors if present if hasattr(config, "shift_factor") and config.shift_factor: @@ -1796,22 +1866,12 @@ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: # Perform ViT encoding vit_embeddings = self._vit_encode(vit_pixel_values, vit_pixel_attention_mask, vit_spatial_shapes) - # Perform VAE encoding - t, latents = self._vae_encode(vae_pixel_values, vae_cfg_factor) - - # Process VAE latents through patch_embed to convert to token embeddings - # VAE latents are in (B, C, H, W) format, need to be converted to (B, seq_len, hidden_size) + # VAE encode + patch_embed per image — each cond image is at its own + # `reso_group` bucket so shapes are ragged across the image-batch dim. vae_token_embeddings = [] - batch_size = latents.shape[0] - for i in range(batch_size): - t_i = t[i] - latents_i = latents[i : i + 1] # Shape: (1, C, H, W) - - # Time embedding for VAE processing - t_emb = self.time_embed(t_i) - - # Process VAE latent through patch_embed - # Input: (1, C, H, W) -> Output: (1, seq_len, hidden_size) + for vae_image_i in vae_pixel_values: + t_i, latents_i = self._vae_encode(vae_image_i.unsqueeze(0), vae_cfg_factor) + t_emb = self.time_embed(t_i[0]) vae_tokens, _, _ = self.patch_embed(latents_i, t_emb) vae_token_embeddings.append(vae_tokens) @@ -1821,27 +1881,31 @@ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: "Each image should have both VAE and ViT embeddings." ) - # Order per image: timestep -> VAE tokens -> ViT tokens. 
- # The placeholder at the timestep slot (see _get_prompt_updates) - # gets its embedding replaced by `timestep_emb(0)` here, which is what - # HF achieves via instantiate_continuous_tokens at runtime. + # Order per image: VAE tokens -> wte -> ViT tokens. + # The wte is included so it joins the bidirectional + # MM region (matching the official `joint_image_slices` full-attn + # range that spans VAE+sep+ViT). The merger replaces the sep slot + # with this wte tensor, which is numerically identical to what + # `model.embed_input_ids` would produce — no semantic change for + # single-image, but with multi-image the sep position now sits + # inside the bidirectional region (matching how the model was + # trained). + sep_token_id = self._mrope_joint_img_sep_token_id + sep_input_ids = torch.tensor([sep_token_id], device=vit_embeddings.device, dtype=torch.long) + sep_embed = self.model.embed_input_ids(sep_input_ids).to(vit_embeddings.dtype) + + # The slot at the head of each per-image scaffold is NOT + # included here — its embedding is patched in by `embed_input_ids` + # via a token-id mask, mirroring HF's `instantiate_continuous_tokens` + # scatter-replace. combined_embeddings: list[torch.Tensor] = [] num_images = len(vae_token_embeddings) for img_idx in range(num_images): - # 1. Timestep embedding (cond image timestep == 0) - timestep = torch.zeros((1,)).to(vit_embeddings.device).to(vit_embeddings.dtype) - timestep_emb = self._timestep_encode(timestep) - - # 2. VAE image token embeddings vae_token_embed = vae_token_embeddings[img_idx] - # Remove batch dimension if present: (B, seq_len, hidden_size) -> (seq_len, hidden_size) if vae_token_embed.ndim == 3: vae_token_embed = vae_token_embed.squeeze(0) - - # 3. 
ViT image embeddings vit_embed = vit_embeddings[img_idx] - - stacked_embed = torch.cat([timestep_emb, vae_token_embed, vit_embed], dim=0) + stacked_embed = torch.cat([vae_token_embed, sep_embed, vit_embed], dim=0) combined_embeddings.append(stacked_embed) return combined_embeddings @@ -1854,14 +1918,23 @@ def embed_input_ids( is_multimodal: torch.Tensor | None = None, ) -> torch.Tensor: """Embed input IDs with optional multimodal embeddings.""" - # Get text embeddings inputs_embeds = self.model.embed_input_ids(input_ids) - # If no multimodal embeddings, return text embeddings + # Patch slots with timestep_emb(0). HF parity: the trained + # wte at this slot is irrelevant; runtime uses + # `instantiate_continuous_tokens(timestep_emb(0))`. With multi-image, + # keeping these slots as ids merged the timestep position into + # the bidirectional MM region and biased AR ratio prediction toward + # the first image's bucket. + timestep_mask = input_ids == self._timestep_token_id + n_timestep = int(timestep_mask.sum().item()) + if n_timestep > 0: + timestep_input = torch.zeros((n_timestep,), device=inputs_embeds.device, dtype=inputs_embeds.dtype) + inputs_embeds[timestep_mask] = self._timestep_encode(timestep_input) + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: return inputs_embeds - # Merge multimodal embeddings with text embeddings merged_embeds = _merge_multimodal_embeddings( inputs_embeds=inputs_embeds, multimodal_embeddings=multimodal_embeddings, @@ -2077,6 +2150,7 @@ def get_mrope_input_positions( boi_token_id = self._mrope_boi_token_id eoi_token_id = self._mrope_eoi_token_id joint_img_sep_token_id = self._mrope_joint_img_sep_token_id + timestep_token_id = self._timestep_token_id # Build position arrays t_pos: list[int] = [] # temporal (same as 1D for this model) @@ -2093,7 +2167,7 @@ def get_mrope_input_positions( if tok == boi_token_id: # Found start of image block. 
- # Structure: *timestep *vae + # Structure: *vae # *vit # token t_pos.append(pos) @@ -2118,8 +2192,8 @@ def get_mrope_input_positions( pos += 1 i += 1 - # Timestep token (1 token) - if i < n and input_tokens[i] == img_token_id: + # token (1 token) + if i < n and input_tokens[i] == timestep_token_id: t_pos.append(pos) h_pos.append(pos) w_pos.append(pos) diff --git a/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py b/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py index b7630bb8ac8..749e213e099 100644 --- a/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py +++ b/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Stage input processor for HunyuanImage3: AR → Diffusion transition. +"""Stage input processor for HunyuanImage3: AR to Diffusion transition. In IT2I (image editing) mode: - Stage 0 (AR) receives (image + edit instruction), generates CoT/latent tokens - - Stage 1 (DiT) receives the AR output + original image, denoises → edited image + - Stage 1 (DiT) receives the AR output + original image, denoises to edited image The ar2diffusion function bridges these two stages, following the same signature pattern as glm_image.ar2diffusion. @@ -12,16 +12,154 @@ from __future__ import annotations +from functools import lru_cache from typing import Any -import torch from vllm.inputs import TextPrompt from vllm.logger import init_logger +from vllm_omni.diffusion.models.hunyuan_image3.prompt_utils import ( + HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS, +) from vllm_omni.inputs.data import OmniTokensPrompt logger = init_logger(__name__) +# AR emits `` after `` in IT2I/T2I +# (see `HunyuanImage3ForCausalMM.sample` and `_stage_transitions`). 
The +# ratio_index resolves to a (height, width) bucket via ResolutionGroup, which +# is the official upstream's mechanism for AR-driven output aspect; without +# this lookup the DiT pipeline falls back to the user-provided width/height +# (in the `/v1/images/edits` path that defaults to `pil_images[0].size`, +# i.e. the first reference image's bucket, usually square, see +# api_server.py:1808-1811). +_HUNYUAN_IMAGE3_EXTRA_RESOLUTIONS: tuple[str, ...] = ( + "1024x768", + "1280x720", + "768x1024", + "720x1280", +) + + +class _Resolution: + def __init__(self, size: str | int | tuple[int, int], *args: int): + if isinstance(size, str): + if "x" in size: + h, w = size.split("x") + size = (int(h), int(w)) + else: + size = int(size) + if args: + size = (int(size), args[0]) + if isinstance(size, int): + size = (size, size) + + self.height = int(size[0]) + self.width = int(size[1]) + self.ratio = self.height / self.width + + +def _build_resolutions_by_step(base_size: int, align: int = 1) -> list[_Resolution]: + step = base_size // 16 + min_height = base_size // 2 + min_width = base_size // 2 + max_height = base_size * 2 + max_width = base_size * 2 + + resolutions = [_Resolution(base_size, base_size)] + + cur_height, cur_width = base_size, base_size + while True: + if cur_height >= max_height and cur_width <= min_width: + break + cur_height = min(cur_height + step, max_height) + cur_width = max(cur_width - step, min_width) + resolutions.append(_Resolution(cur_height // align * align, cur_width // align * align)) + + cur_height, cur_width = base_size, base_size + while True: + if cur_height <= min_height and cur_width >= max_width: + break + cur_height = max(cur_height - step, min_height) + cur_width = min(cur_width + step, max_width) + resolutions.append(_Resolution(cur_height // align * align, cur_width // align * align)) + + return sorted(resolutions, key=lambda x: x.ratio) + + +@lru_cache(maxsize=4) +def _build_ratio_size_table(base_size: int) -> list[tuple[int, int]]: + 
"""Return `[(height, width)]` indexed by ratio_index for HunyuanImage-3. + + Mirrors `HunyuanImage3ImageProcessor.build_image_info`'s + `reso_group[ratio_index]` reverse lookup. Cached because the table + is constant per `base_size`. + """ + resolutions = _build_resolutions_by_step(base_size) + for extra_resolution in (_Resolution(s) for s in _HUNYUAN_IMAGE3_EXTRA_RESOLUTIONS): + if not any(r.ratio == extra_resolution.ratio for r in resolutions): + resolutions.append(extra_resolution) + return [(r.height, r.width) for r in resolutions] + + +def _truncate_at_cot_end(generated_text: str) -> str: + """Truncate AR output at first `` (or `` fallback). + + Mirrors upstream `HunyuanImage3ForCausalMM.generate_image` which feeds + DiT only the cot text up to the closing tag; the trailing + `` is consumed via height/width + extraction and must not leak into DiT's prompt builder. + """ + for marker in ("", ""): + idx = generated_text.find(marker) + if idx != -1: + return generated_text[: idx + len(marker)] + return generated_text + + +@lru_cache(maxsize=4) +def _build_ratio_id_lookup() -> dict[int, int]: + """Return `{token_id: ratio_index}` for HunyuanImage3 ratio tokens. + + The ids are fixed in tokenizer.json and already pinned in prompt_utils. + Avoid loading AutoTokenizer here: this bridge runs on the hot AR->DiT + transition path and must keep working in offline deployments where the + tokenizer object is not exposed to the stage-input processor. 
+ """ + ratio_0 = HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""] + ratio_32 = HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""] + ratio_33 = HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""] + ratio_36 = HUNYUAN_IMAGE3_SPECIAL_TOKEN_IDS[""] + + table: dict[int, int] = {} + for i in range(ratio_32 - ratio_0 + 1): + table[ratio_0 + i] = i + base_idx = ratio_32 - ratio_0 + 1 + for j in range(ratio_36 - ratio_33 + 1): + table[ratio_33 + j] = base_idx + j + return table + + +def _extract_ratio_index(generated_token_ids) -> int | None: + """Resolve the AR-predicted ratio_index from this stage's output. + + `HunyuanImage3ForCausalMM`'s `_stage_transitions` forces the AR to emit + exactly one `` token after ` + `, so we scan the token stream from the tail for the first + id that maps to a ratio. Token-ids are the source of truth; text-side + regex is unreliable because most deploy yamls run AR with + `skip_special_tokens: True` (special tokens are stripped from text but + still present in `cumulative_token_ids`). + """ + if generated_token_ids is None: + return None + table = _build_ratio_id_lookup() + for tid in reversed(list(generated_token_ids)): + idx = table.get(int(tid)) + if idx is not None: + return idx + return None + def ar2diffusion( source_outputs: list[Any], @@ -64,31 +202,65 @@ def ar2diffusion( width = original_prompt.get("width", 1024) text_prompt = original_prompt.get("prompt", "") use_system_prompt = original_prompt.get("use_system_prompt") + custom_system_prompt = original_prompt.get("system_prompt") + + # Prefer the AR's predicted output aspect (`` + # tail emitted by `HunyuanImage3ForCausalMM.sample` under the + # ratio-restriction logits processor) over the carried-through + # height/width, which the serving layer fills with the first + # reference image's bucket and so collapses non-square targets to + # square in the multi-image / mismatched-aspect case. Mirrors the + # official upstream where `reso_group[ratio_index]` is the + # canonical source of the diffusion target shape. 
+ ratio_idx = _extract_ratio_index(generated_token_ids) + ar_predicted = False + if ratio_idx is not None: + base_size = int(original_prompt.get("image_base_size", 1024)) + size_table = _build_ratio_size_table(base_size) + if 0 <= ratio_idx < len(size_table): + height, width = size_table[ratio_idx] + ar_predicted = True + else: + logger.warning( + "[ar2diffusion] Request %d: ratio_index=%d out of range [0,%d), keeping prompt size %dx%d", + i, + ratio_idx, + len(size_table), + height, + width, + ) + + cot_text_for_dit = _truncate_at_cot_end(generated_text) logger.info( - "[ar2diffusion] Request %d: AR generated %d tokens, text length=%d, target size=%dx%d", + "[ar2diffusion] Request %d: AR generated %d tokens, text length=%d, " + "cot_text length=%d, target size=%dx%d (%s)", i, len(generated_token_ids), len(generated_text), + len(cot_text_for_dit), height, width, + f"AR ratio_idx={ratio_idx}" if ar_predicted else "from prompt (no AR ratio token)", ) - token_tensor = torch.tensor(generated_token_ids, dtype=torch.long) - diffusion_input: dict[str, Any] = { "prompt": text_prompt, "height": height, "width": width, "extra": { - "ar_token_ids": token_tensor, - "ar_generated_text": generated_text, + "ar_generated_text": cot_text_for_dit, }, } - # Forward use_system_prompt so the DiT can build the same system prefix + # Forward use_system_prompt so the DiT can build the same system prefix. + # Also forward the custom system prompt body when sys_type=custom so + # DiT's `get_system_prompt(use, "image", body)` doesn't fall back to + # an empty prefix and silently diverge from AR. if use_system_prompt is not None: diffusion_input["use_system_prompt"] = use_system_prompt + if custom_system_prompt is not None: + diffusion_input["system_prompt"] = custom_system_prompt # Forward multimodal data (original image for IT2I conditioning). # The diffusion pre_process_func reads multi_modal_data["image"], which