diff --git a/examples/offline_inference/hunyuan_image3/README.md b/examples/offline_inference/hunyuan_image3/README.md
index 3cd8fa01b2e..3eb3bfbff6f 100644
--- a/examples/offline_inference/hunyuan_image3/README.md
+++ b/examples/offline_inference/hunyuan_image3/README.md
@@ -135,17 +135,19 @@ python end2end.py --model tencent/HunyuanImage-3.0-Instruct \
## Prompt Format
-HunyuanImage-3.0 uses a pretrain template format:
+HunyuanImage-3.0-Instruct uses an instruct chat template:
```
-<|startoftext|>{system_prompt}{<img>}{trigger_tag}{user_prompt}
+<|startoftext|>{system_prompt}\n\nUser: {<img>?}{user_prompt}\n\nAssistant: {trigger_tag?}
```
-- `<img>`: Placeholder for each input image (auto-inserted by `prompt_utils.py`)
-- Trigger tags: `<think>` (CoT), `<recaption>` (recaptioning)
+- `<img>`: Placeholder for each input image (single token; expanded by the multimodal pipeline)
+- Trigger tags: `<think>` (CoT), `<recaption>` (recaptioning) — placed AFTER `Assistant: `
- System prompt: Auto-selected based on task
+- `t2i_vanilla` is the only task that uses the bare pretrain template (no chat structure)
-The `prompt_utils.build_prompt()` handles this formatting automatically.
+The shared `vllm_omni.diffusion.models.hunyuan_image3.prompt_utils.build_prompt_tokens()`
+helper handles segment-by-segment tokenization (matches HF `apply_chat_template` byte-for-byte).
------
diff --git a/examples/offline_inference/hunyuan_image3/end2end.py b/examples/offline_inference/hunyuan_image3/end2end.py
index 2cea303888e..fc4b75e78d5 100644
--- a/examples/offline_inference/hunyuan_image3/end2end.py
+++ b/examples/offline_inference/hunyuan_image3/end2end.py
@@ -16,23 +16,12 @@
import argparse
import os
-from vllm_omni.diffusion.models.hunyuan_image3.system_prompt import (
- get_system_prompt,
+from vllm_omni.diffusion.models.hunyuan_image3.prompt_utils import (
+ build_prompt_tokens,
)
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniPromptType
-# task → (sys_type, bot_task, trigger_tag)
-_TASK_PRESETS: dict[str, tuple[str, str | None, str | None]] = {
- "t2t": ("en_unified", None, None),
- "i2t": ("en_unified", None, None),
- "it2i_think": ("en_unified", "think", "<think>"),
- "it2i_recaption": ("en_unified", "recaption", "<recaption>"),
- "t2i_think": ("en_unified", "think", "<think>"),
- "t2i_recaption": ("en_unified", "recaption", "<recaption>"),
- "t2i_vanilla": ("en_vanilla", "image", None),
-}
-
# Modality → prompt_utils task mapping
_MODALITY_TASK_MAP = {
"text2img": "t2i_think",
@@ -42,36 +31,6 @@
}
-def build_prompt(
- user_prompt: str,
- task: str = "it2i_think",
- sys_type: str | None = None,
- custom_system_prompt: str | None = None,
-) -> str:
- """Build a HunyuanImage-3.0 prompt using pretrain template format."""
- if task not in _TASK_PRESETS:
- raise ValueError(f"Unknown task {task!r}. Choose from: {sorted(_TASK_PRESETS)}")
-
- preset_sys_type, preset_bot_task, trigger_tag = _TASK_PRESETS[task]
- effective_sys_type = sys_type or preset_sys_type
-
- system_prompt = get_system_prompt(effective_sys_type, preset_bot_task, custom_system_prompt)
- sys_text = system_prompt.strip() if system_prompt else ""
-
- has_image_input = task.startswith("i2t") or task.startswith("it2i")
-
- parts = ["<|startoftext|>"]
- if sys_text:
- parts.append(sys_text)
- if has_image_input:
- parts.append("<img>")
- if trigger_tag:
- parts.append(trigger_tag)
- parts.append(user_prompt)
-
- return "".join(parts)
-
-
# Modality → default stage config
_MODALITY_DEFAULT_CONFIG = {
"text2img": "hunyuan_image3_t2i.yaml",
@@ -179,12 +138,18 @@ def main():
input_image = Image.open(args.image_path).convert("RGB")
+ # Load tokenizer for segment-wise prompt tokenization (matches HF
+ # apply_chat_template byte-for-byte; see build_prompt_tokens docstring).
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
+
# Format prompts
formatted_prompts: list[OmniPromptType] = []
for p in prompts:
- formatted_text = build_prompt(p, task=task, sys_type=args.sys_type)
+ token_ids = build_prompt_tokens(p, tokenizer, task=task, sys_type=args.sys_type)
- prompt_dict: dict = {"prompt": formatted_text}
+ prompt_dict: dict = {"prompt_token_ids": token_ids}
if args.modality == "text2img":
prompt_dict["modalities"] = ["image"]
diff --git a/tests/diffusion/models/hunyuan_image3/test_prompt_utils.py b/tests/diffusion/models/hunyuan_image3/test_prompt_utils.py
new file mode 100644
index 00000000000..501664fe688
--- /dev/null
+++ b/tests/diffusion/models/hunyuan_image3/test_prompt_utils.py
@@ -0,0 +1,292 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Regression tests for HunyuanImage3 prompt construction (PR #3243).
+
+Two layers:
+ 1. Pure-logic tests with a recording fake tokenizer -- protect the
+ prompt template structure (BOS, User:/Assistant: framing, trigger
+ placement, image placeholder position) and protect the segment-
+ by-segment tokenization contract (each segment must hit
+ `tokenizer.encode` in isolation).
+ 2. Real-tokenizer regression -- run when the HunyuanImage3-Instruct
+ tokenizer is in the local HF cache. Asserts the segment-tokenized
+ output diverges from the naive full-string encode, which is the
+ bug-tripping fixture for the cross-segment BPE merge fix
+ (commit 7bd429ed).
+"""
+
+from __future__ import annotations
+
+import ast
+import os
+import pathlib
+
+import pytest
+
+from vllm_omni.diffusion.models.hunyuan_image3.prompt_utils import (
+ available_tasks,
+ build_prompt,
+ build_prompt_tokens,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+# -------------------- Pure-logic structural tests --------------------
+
+
+class FakeTokenizer:
+ """Minimal tokenizer stub that records every encode() call.
+
+ Returns deterministic ids: special tokens map to small ints (1-4),
+ encode() returns one id per character starting at 100. This lets
+ tests both verify segmentation (by inspecting `encode_calls`) and
+ locate substrings inside the returned id list.
+ """
+
+ SPECIAL = {
+ "<|startoftext|>": 1,
+ "<img>": 2,
+ "<think>": 3,
+ "<recaption>": 4,
+ }
+
+ def __init__(self) -> None:
+ self.encode_calls: list[str] = []
+
+ def convert_tokens_to_ids(self, tok: str) -> int:
+ return self.SPECIAL.get(tok, 0)
+
+ def encode(self, text: str, add_special_tokens: bool = False) -> list[int]:
+ self.encode_calls.append(text)
+ return list(range(100, 100 + len(text)))
+
+
+def test_available_tasks_covers_all_modalities():
+ tasks = set(available_tasks())
+ assert tasks >= {
+ "t2t",
+ "i2t",
+ "it2i_think",
+ "it2i_recaption",
+ "t2i_think",
+ "t2i_recaption",
+ "t2i_vanilla",
+ }
+
+
+@pytest.mark.parametrize(
+ "task",
+ [
+ "t2t",
+ "i2t",
+ "it2i_think",
+ "it2i_recaption",
+ "t2i_think",
+ "t2i_recaption",
+ ],
+)
+def test_build_prompt_string_structure_chat_template(task: str):
+ """Chat-template tasks must produce <|startoftext|>...User: ...Assistant: ...
+ with image placeholder (when applicable) and trigger tag AFTER `Assistant: `."""
+ s = build_prompt("HELLO", task=task)
+
+ assert s.startswith("<|startoftext|>")
+ assert "User: " in s
+ assert "Assistant: " in s
+ assert s.index("User: ") < s.index("HELLO") < s.index("Assistant: ")
+
+ if task.startswith(("i2t", "it2i")):
+ assert s.index("User: ") < s.index("<img>") < s.index("HELLO"), (
+ "<img> placeholder must sit between `User: ` and the user prompt"
+ )
+ else:
+ assert "<img>" not in s
+
+ # Trigger tag must be the FINAL token of the prompt (after `Assistant: `).
+ # Note: the system prompt itself mentions <think> / <recaption> as mode
+ # documentation, so substring index() catches the wrong occurrence -- use
+ # endswith() which directly captures "trigger is at the tail" (the Part A
+ # fix: trigger goes AFTER `Assistant: `, not before user_prompt).
+ if task in ("it2i_think", "t2i_think"):
+ assert s.endswith("Assistant: <think>"), (
+ f"Trigger must be appended right after `Assistant: ` (Part A fix). Got tail: ...{s[-40:]!r}"
+ )
+ if task in ("it2i_recaption", "t2i_recaption"):
+ assert s.endswith("Assistant: <recaption>"), (
+ f"Trigger must be appended right after `Assistant: ` (Part A fix). Got tail: ...{s[-40:]!r}"
+ )
+ if task in ("t2t", "i2t"):
+ assert s.endswith("Assistant: "), "Plain (no-trigger) task must end at `Assistant: ` with no trailing tag."
+
+
+def test_build_prompt_vanilla_uses_pretrain_template():
+ """t2i_vanilla is the only task that bypasses chat structure -- direct
+ text->image generation driven by the vanilla system prompt."""
+ s = build_prompt("HELLO", task="t2i_vanilla")
+ assert s.startswith("<|startoftext|>")
+ assert "User: " not in s
+ assert "Assistant: " not in s
+ assert "<think>" not in s
+ assert "<recaption>" not in s
+ assert s.endswith("HELLO")
+
+
+def test_build_prompt_unknown_task_raises():
+ with pytest.raises(ValueError, match="Unknown task"):
+ build_prompt("x", task="bogus")
+ with pytest.raises(ValueError, match="Unknown task"):
+ build_prompt_tokens("x", FakeTokenizer(), task="bogus")
+
+
+def test_build_prompt_tokens_segments_each_boundary():
+ """Regression for cross-segment BPE merge bug (commit 7bd429ed):
+ each template segment must hit tokenizer.encode() independently;
+ user_prompt MUST NOT be concatenated with the following separator
+ in the same encode() call."""
+ tok = FakeTokenizer()
+ build_prompt_tokens("写诗。", tok, task="i2t")
+
+ # Each canonical segment is encoded in its own call.
+ assert "User: " in tok.encode_calls
+ assert "写诗。" in tok.encode_calls, (
+ "user_prompt must be encoded alone -- if it is concatenated with the "
+ "trailing separator, BPE will merge across the boundary (the PR-#3243 bug)."
+ )
+ assert "\n\nAssistant: " in tok.encode_calls
+
+ # No call must contain user_prompt glued to neighboring text.
+ for call in tok.encode_calls:
+ if call != "写诗。":
+ assert "写诗。" not in call, f"user_prompt leaked into a multi-segment encode call: {call!r}"
+
+
+def test_build_prompt_tokens_image_placeholder_present_for_image_tasks():
+ tok = FakeTokenizer()
+ ids = build_prompt_tokens("hi", tok, task="i2t")
+ assert ids[0] == 1, "BOS (<|startoftext|>) must be the first token"
+ assert 2 in ids, "<img> placeholder must be present for i2t/it2i tasks"
+
+
+def test_build_prompt_tokens_no_image_for_text_only_tasks():
+ tok = FakeTokenizer()
+ ids = build_prompt_tokens("hi", tok, task="t2t")
+ assert 2 not in ids, "<img> must NOT appear for text-only tasks"
+
+
+@pytest.mark.parametrize(
+ "task,trigger_id",
+ [("it2i_think", 3), ("t2i_think", 3), ("it2i_recaption", 4), ("t2i_recaption", 4)],
+)
+def test_build_prompt_tokens_trigger_is_last_token(task: str, trigger_id: int):
+ """Trigger tag id must be the LAST token (after `Assistant: ` segment)."""
+ tok = FakeTokenizer()
+ ids = build_prompt_tokens("hi", tok, task=task)
+ assert ids[-1] == trigger_id
+
+
+def test_build_prompt_tokens_no_trigger_for_plain_tasks():
+ """Tasks without trigger_tag (t2t / i2t) must NOT append a trigger id."""
+ tok = FakeTokenizer()
+ ids = build_prompt_tokens("hi", tok, task="t2t")
+ assert ids[-1] not in {3, 4} # neither <think> nor <recaption>
+
+
+# -------------------- end2end.py wiring guard --------------------
+
+
+def _repo_root() -> pathlib.Path:
+ # tests/diffusion/models/hunyuan_image3/test_prompt_utils.py -> repo root
+ return pathlib.Path(__file__).resolve().parents[4]
+
+
+def test_end2end_routes_through_shared_prompt_utils():
+ """Regression for the *delivery vector* of PR #3243.
+
+ Background: the wrong-template bug that PR #3243 fixes was introduced
+ when end2end.py grew its own hand-rolled prompt builder that diverged
+ from the canonical instruct chat template. To prevent that exact
+ failure mode from recurring, end2end.py MUST:
+ 1. Import the prompt builders from the shared prompt_utils module.
+ 2. NOT redefine `build_prompt` or `build_prompt_tokens` locally.
+
+ A local redefinition is precisely how a future merge can silently
+ re-introduce a pretrain-style template (trigger BEFORE user_prompt,
+ no User:/Assistant: framing, etc.) without touching prompt_utils,
+ bypassing every other test in this file.
+ """
+ end2end_path = _repo_root() / "examples" / "offline_inference" / "hunyuan_image3" / "end2end.py"
+ assert end2end_path.is_file(), f"end2end.py not found at {end2end_path}"
+
+ tree = ast.parse(end2end_path.read_text(encoding="utf-8"))
+
+ local_func_names = {n.name for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)}
+ forbidden = {"build_prompt", "build_prompt_tokens"}
+ redefined = local_func_names & forbidden
+ assert not redefined, (
+ f"end2end.py defines {sorted(redefined)} locally. This is exactly how "
+ "the wrong prompt template re-entered the example before PR #3243. "
+ "Use the shared `vllm_omni.diffusion.models.hunyuan_image3.prompt_utils` "
+ "helpers instead."
+ )
+
+ imported_from_prompt_utils: set[str] = set()
+ for node in ast.walk(tree):
+ if isinstance(node, ast.ImportFrom) and node.module and node.module.endswith("hunyuan_image3.prompt_utils"):
+ imported_from_prompt_utils.update(alias.name for alias in node.names)
+ assert "build_prompt_tokens" in imported_from_prompt_utils, (
+ "end2end.py must import build_prompt_tokens from "
+ "vllm_omni.diffusion.models.hunyuan_image3.prompt_utils -- the shared "
+ "helper is the single source of truth for the AR-prefill template."
+ )
+
+
+# -------------------- Real-tokenizer regression --------------------
+
+
+_HUNYUAN_MODEL_ID = "tencent/HunyuanImage-3.0-Instruct"
+
+
+def _hf_cached(model_id: str) -> bool:
+ hf_home = os.environ.get("HF_HOME") or os.path.expanduser("~/.cache/huggingface")
+ snap_dir = os.path.join(hf_home, "hub", f"models--{model_id.replace('/', '--')}", "snapshots")
+ return os.path.isdir(snap_dir) and any(os.scandir(snap_dir))
+
+
+@pytest.mark.skipif(
+ not _hf_cached(_HUNYUAN_MODEL_ID),
+ reason=f"{_HUNYUAN_MODEL_ID} tokenizer not in HF cache",
+)
+def test_segment_tokenize_diverges_from_full_string_encode():
+ """Regression for PR #3243 segment-tokenization fix.
+
+ The naive `tokenizer.encode(build_prompt(...))` lets BPE merge tokens
+ across segment boundaries (notably `。\\n\\n` -> a single id), which
+ drifts the AR prefill away from HF's apply_chat_template output. The
+ segment-by-segment build_prompt_tokens must produce a STRICTLY
+ DIFFERENT id sequence on a prompt that triggers the merge.
+
+ If someone "simplifies" build_prompt_tokens to call encode() on the
+ full string, this assertion fires.
+ """
+ from transformers import AutoTokenizer
+
+ tok = AutoTokenizer.from_pretrained(_HUNYUAN_MODEL_ID, trust_remote_code=True)
+
+ user_prompt = "写一首关于夜的诗。"
+ seg_ids = build_prompt_tokens(user_prompt, tok, task="i2t")
+ full_ids = tok.encode(build_prompt(user_prompt, task="i2t"), add_special_tokens=False)
+
+ assert seg_ids != full_ids, (
+ "build_prompt_tokens output equals naive full-string encode -- "
+ "the BPE-merge-bypass behavior is no longer exercised. This means "
+ "the segment-by-segment fix from PR #3243 has been silently undone."
+ )
+
+ # Segmenting prevents merges, so the segment id list should have AT LEAST
+ # as many tokens as the merged version (a merge consumes 2+ ids -> 1).
+ assert len(seg_ids) >= len(full_ids), (
+ f"segment-encoded length ({len(seg_ids)}) shorter than full-string "
+ f"merged length ({len(full_ids)}) -- impossible if segmentation is "
+ f"genuinely bypassing merges."
+ )
diff --git a/vllm_omni/diffusion/models/hunyuan_image3/prompt_utils.py b/vllm_omni/diffusion/models/hunyuan_image3/prompt_utils.py
new file mode 100644
index 00000000000..6e8efac3133
--- /dev/null
+++ b/vllm_omni/diffusion/models/hunyuan_image3/prompt_utils.py
@@ -0,0 +1,152 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Shared prompt-template construction for HunyuanImage-3.0-Instruct.
+
+Single source of truth for the AR-prefill prompt format used by the
+example scripts and any downstream caller that needs to build
+HunyuanImage3 chat-template token sequences without invoking the full
+diffusion pipeline tokenizer wrapper.
+
+The DiT pipeline (`pipeline_hunyuan_image3.py`) builds prompts through
+`TokenizerWrapper.apply_chat_template`, which eagerly consumes
+`JointImageInfo` objects produced by image preprocessing. The example
+flow uses an `<img>` placeholder + `multi_modal_data` instead, so it
+needs a lighter-weight builder that only requires a HF tokenizer. This
+module provides that builder; the task -> template mapping below is the
+canonical mapping for both flows.
+"""
+
+from __future__ import annotations
+
+from .system_prompt import get_system_prompt
+
+# task -> (sys_type, bot_task, trigger_tag)
+_TASK_PRESETS: dict[str, tuple[str, str | None, str | None]] = {
+ "t2t": ("en_unified", None, None),
+ "i2t": ("en_unified", None, None),
+ "it2i_think": ("en_unified", "think", "<think>"),
+ "it2i_recaption": ("en_unified", "recaption", "<recaption>"),
+ "t2i_think": ("en_unified", "think", "<think>"),
+ "t2i_recaption": ("en_unified", "recaption", "<recaption>"),
+ "t2i_vanilla": ("en_vanilla", "image", None),
+}
+
+
+def available_tasks() -> list[str]:
+ """Sorted list of task keys accepted by `build_prompt` / `build_prompt_tokens`."""
+ return sorted(_TASK_PRESETS)
+
+
+def build_prompt(
+ user_prompt: str,
+ task: str = "it2i_think",
+ sys_type: str | None = None,
+ custom_system_prompt: str | None = None,
+) -> str:
+ """Build a HunyuanImage-3.0 prompt as a string (legacy/compat path).
+
+ NOTE: when this string is passed to the engine, the engine's tokenizer
+ will run a single BPE pass over the whole string, which can merge
+ tokens across segment boundaries (e.g. `。\\n\\n` -> id 3490). For
+ inputs that need to match HF baseline byte-for-byte, use
+ `build_prompt_tokens` instead and feed the result via prompt_token_ids.
+ """
+ if task not in _TASK_PRESETS:
+ raise ValueError(f"Unknown task {task!r}. Choose from: {available_tasks()}")
+
+ preset_sys_type, preset_bot_task, trigger_tag = _TASK_PRESETS[task]
+ effective_sys_type = sys_type or preset_sys_type
+
+ system_prompt = get_system_prompt(effective_sys_type, preset_bot_task, custom_system_prompt)
+ sys_text = system_prompt.strip() if system_prompt else ""
+
+ has_image_input = task.startswith("i2t") or task.startswith("it2i")
+
+ # t2i_vanilla: pretrain mode for direct text->image generation. The
+ # vanilla system prompt drives the model with no chat structure.
+ if task == "t2i_vanilla":
+ parts = ["<|startoftext|>"]
+ if sys_text:
+ parts.append(sys_text)
+ parts.append(user_prompt)
+ return "".join(parts)
+
+ # All other tasks (t2t / i2t / t2i_think / t2i_recaption /
+ # it2i_think / it2i_recaption) use HunyuanImage3 Instruct chat template:
+ # <|startoftext|>{system?}\n\nUser: {<img>?}{user_prompt}\n\nAssistant: {trigger?}
+ # generation_config.json declares sequence_template="instruct", so the
+ # AR prefill MUST use this template -- verified to match HF's
+ # apply_chat_template output token-for-token (modulo BPE boundary merges).
+ # The trigger_tag (e.g. <think>) MUST come AFTER the `Assistant: ` prefix:
+ # if it goes BEFORE user_prompt (the old pretrain layout) the model puts
+ # the user's instructions inside the "thinking section" and collapses
+ # into repetition garbage under greedy decoding.
+ parts = ["<|startoftext|>"]
+ if sys_text:
+ parts.append(f"{sys_text}\n\n")
+ parts.append("User: ")
+ if has_image_input:
+ parts.append("<img>")
+ parts.append(user_prompt)
+ parts.append("\n\nAssistant: ")
+ if trigger_tag:
+ parts.append(trigger_tag)
+
+ return "".join(parts)
+
+
+def build_prompt_tokens(
+ user_prompt: str,
+ tokenizer,
+ task: str = "it2i_think",
+ sys_type: str | None = None,
+ custom_system_prompt: str | None = None,
+) -> list[int]:
+ """Segment-by-segment tokenization that matches HF apply_chat_template.
+
+ Calling tokenizer.encode(build_prompt(...)) on the full string lets BPE
+ merge tokens across segment boundaries (e.g. user_prompt ends with `。`
+ and the next segment is `\\n\\n` -> they merge into a single token id
+ 3490 instead of HF's [1811, 271]). HF's apply_chat_template tokenizes
+ each segment independently and concatenates token_ids, so no cross-
+ boundary merge happens. We replicate that here and feed the result to
+ Omni via OmniTokensPrompt (prompt_token_ids).
+ """
+ if task not in _TASK_PRESETS:
+ raise ValueError(f"Unknown task {task!r}. Choose from: {available_tasks()}")
+
+ preset_sys_type, preset_bot_task, trigger_tag = _TASK_PRESETS[task]
+ effective_sys_type = sys_type or preset_sys_type
+
+ bos_id = tokenizer.convert_tokens_to_ids("<|startoftext|>")
+ img_id = tokenizer.convert_tokens_to_ids("<img>")
+ trig_id = tokenizer.convert_tokens_to_ids(trigger_tag) if trigger_tag else None
+
+ has_image_input = task.startswith("i2t") or task.startswith("it2i")
+
+ # t2i_vanilla uses pretrain template with no chat structure; the vanilla
+ # system prompt drives the model directly. No segment boundaries to
+ # protect, fall back to whole-string encode.
+ if task == "t2i_vanilla":
+ s = build_prompt(user_prompt, task, sys_type, custom_system_prompt)
+ return tokenizer.encode(s, add_special_tokens=False)
+
+ system_prompt = get_system_prompt(effective_sys_type, preset_bot_task, custom_system_prompt)
+ # Do NOT strip -- HF apply_chat_template keeps the system prompt's
+ # natural trailing newline; stripping it would shift one token id.
+ sys_text = system_prompt or ""
+
+ ids: list[int] = [bos_id]
+ if sys_text:
+ ids += tokenizer.encode(sys_text, add_special_tokens=False)
+ ids += tokenizer.encode("\n\n", add_special_tokens=False)
+ ids += tokenizer.encode("User: ", add_special_tokens=False)
+ if has_image_input:
+ ids += [img_id]
+ ids += tokenizer.encode(user_prompt, add_special_tokens=False)
+ ids += tokenizer.encode("\n\nAssistant: ", add_special_tokens=False)
+ if trig_id is not None:
+ ids += [trig_id]
+ return ids
+
+
+__all__ = ["build_prompt", "build_prompt_tokens", "available_tasks"]
diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py
index e2f600eaa46..d2552bdddbf 100644
--- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py
+++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import gc
import math
import typing
from collections.abc import Callable, Iterable, Mapping, Sequence
@@ -17,14 +18,29 @@
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from vllm.compilation.decorators import support_torch_compile
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.multimodal import BaseDummyOptions
-from vllm.distributed import get_pp_group
+from vllm.distributed import (
+ get_ep_group,
+ get_pp_group,
+ get_tensor_model_parallel_world_size,
+)
from vllm.inputs import MultiModalDataDict
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import fused_moe_make_expert_params_mapping
+
+try:
+ from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
+except ImportError:
+ # PyPI vllm 0.20.x neither exports `SharedFusedMoE` from the package top-level
+ # nor ships a `shared_fused_moe.py` submodule. The functionality lives on
+ # `FusedMoE` directly (which gained a `shared_experts` parameter), so alias
+ # the symbol — call sites only use the classmethod `make_expert_params_mapping`
+ # and `__init__(shared_experts=..., ...)` which are present on `FusedMoE`.
+ from vllm.model_executor.layers.fused_moe import FusedMoE as SharedFusedMoE
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
+ ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -37,7 +53,9 @@
maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.hunyuan_v1 import (
+ HunYuanMLP,
HunYuanModel,
+ HunYuanSparseMoeBlock,
_get_cla_factor,
_is_moe,
)
@@ -842,8 +860,6 @@ def process_image(self, image_input: ImageInput):
else:
raise TypeError(f"Unsupported image type: {type(image_input)}.")
- torch_dtype = getattr(self.hf_config, "torch_dtype", torch.bfloat16)
-
batch_data = []
for image in images:
current_info = {}
@@ -853,20 +869,38 @@ def process_image(self, image_input: ImageInput):
# VIT processing
vit_pixel_values = self.vision_encoder_processor(image)
- # shape: (seq_len, num_channels * patch_size * patch_size)
- current_info["vit_pixel_values"] = vit_pixel_values["pixel_values"].squeeze(0)
- # shape: (seq_len, )
- current_info["vit_pixel_attention_mask"] = vit_pixel_values["pixel_attention_mask"].squeeze(0)
- # shape: (2, )
- current_info["vit_spatial_shapes"] = vit_pixel_values["spatial_shapes"].squeeze(0)
-
- # VAE processing
+ # transformers>=5.x returns lists; stack to tensor when needed
+ _pv = vit_pixel_values["pixel_values"]
+ if isinstance(_pv, list):
+ _pv = torch.stack(_pv, dim=0)
+ current_info["vit_pixel_values"] = _pv.squeeze(0)
+ _pam = vit_pixel_values["pixel_attention_mask"]
+ if isinstance(_pam, list):
+ _pam = torch.stack(_pam, dim=0)
+ current_info["vit_pixel_attention_mask"] = _pam.squeeze(0)
+ _ss = vit_pixel_values["spatial_shapes"]
+ if isinstance(_ss, list):
+ _ss = torch.tensor(_ss, dtype=torch.long)
+ current_info["vit_spatial_shapes"] = _ss.squeeze(0)
+
+ # VAE processing.
+ # The resize/crop math here mirrors HF's `resize_and_crop` with
+ # crop_type="center" (hunyuan3.0_ins/image_processor.py:61). VAE
+ # normalize uses the same transforms.Compose([ToTensor,
+ # Normalize([0.5], [0.5])]) as HF's `pil_image_to_tensor`. So
+ # numerical output of this branch should match HF up to floating-
+ # point reduction order.
image_width, image_height = self.reso_group.get_target_size(image.width, image.height)
resized_image = self._resize_and_crop(image, (image_width, image_height))
vae_pixel_values = self.vae_processor(resized_image)
token_height = image_height // (self.hf_config.vae_downsample_factor[0] * self.hf_config.patch_size)
token_width = image_width // (self.hf_config.vae_downsample_factor[1] * self.hf_config.patch_size)
- current_info["vae_pixel_values"] = vae_pixel_values.squeeze(0).to(dtype=torch_dtype)
+ # Keep fp32 — the VAE encoder casts to model dtype at its boundary
+ # (see _vae_encode). Casting to bf16 here costs ~7e-4 mean-abs-diff
+ # bf16 quantization error on every pixel vs HF (which keeps fp32
+ # in build_cond_images), measurable as a real numerical drift in
+ # downstream image embeddings.
+ current_info["vae_pixel_values"] = vae_pixel_values.squeeze(0)
current_info["vae_token_grid_hw"] = torch.tensor([token_height, token_width])
# size
@@ -1052,6 +1086,25 @@ def get_replacement_image(item_idx: int) -> PromptUpdateDetails:
if ratio_token_id is None:
raise ValueError(f"Ratio token '' not found in tokenizer vocabulary")
+ # NOTE on the timestep slot:
+ # HF's apply_chat_template emits the literal <timestep> token id
+ # 128017 here. HF's modeling forward (`instantiate_continuous_tokens`,
+ # see hunyuan3.0_ins/modeling_hunyuan_image_3.py:1964) then *scatter-
+ # replaces* the embedding at that position with `timestep_emb(0)`
+ # for cond images. So the wte embedding of <timestep> is irrelevant
+ # at runtime — what matters is the timestep_emb injection.
+ #
+ # vllm-omni achieves the same effect via the multimodal-embedding
+ # merger: we put an <img> (128006) placeholder here and ship a
+ # `timestep_emb(0)` tensor at the head of `embed_multimodal()`'s
+ # combined_embeddings. The merger replaces this placeholder's
+ # embedding with the timestep tensor, yielding a final hidden
+ # state numerically equivalent to HF at that position.
+ #
+ # Keep this slot as <img> (NOT <timestep>): switching to <timestep>
+ # requires either (a) a second PromptReplacement targeting 128017,
+ # or (b) the merger's embed_token_id to be a list — neither is
+ # currently supported by PromptUpdateDetails.select_token_id.
replacement = (
[boi_token_id]
+ [base_size_token_id]
@@ -1070,6 +1123,197 @@ def get_replacement_image(item_idx: int) -> PromptUpdateDetails:
]
+def _hunyuan_image3_unpack_packed_topk(
+ hidden_states: torch.Tensor,
+ gating_output: torch.Tensor,
+ topk: int,
+ renormalize: bool,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Unpack pre-computed ``(topk_weights, topk_indices)`` packed by
+ :class:`HunyuanImage3SparseMoeBlock` into ``gating_output``.
+
+ Used as ``custom_routing_function`` for the underlying ``SharedFusedMoE``,
+ bypassing its bf16 ``topk_softmax`` CUDA op so the routing decision can
+ be made in fp32 (matching the reference implementation).
+
+ Layout of ``gating_output`` (shape ``[num_tokens, top_k * 2]``)::
+
+ [:, :top_k] -> topk_weights (already softmax'd + renormalized in fp32,
+ stored as fp32 for transport)
+ [:, top_k:] -> topk_indices (cast to fp32 for transport, restored to int32)
+ """
+ topk_weights = gating_output[:, :topk].contiguous()
+ topk_indices = gating_output[:, topk:]
+ return topk_weights.to(torch.float32), topk_indices.to(torch.int32)
+
+
+class HunyuanImage3SparseMoeBlock(HunYuanSparseMoeBlock):
+ """MoE block with FP32 routing for byte-level alignment with HF.
+
+ The reference ``modeling_hunyuan_image_3.py`` runs the router in fp32:
+
+ - ``HunyuanTopKGate.wg`` is constructed as ``nn.Linear(..., dtype=torch.float32)``
+ and ``hidden_states`` is cast to fp32 before the matmul (line 1114-1116).
+ - ``HunyuanMoE.forward`` wraps the gate call in
+ ``with torch.autocast('cuda', enabled=False):`` to defeat any AMP cast
+ (line 1204-1205), then calls ``easy_topk`` which does
+ ``F.softmax`` → ``torch.topk`` → divide by
+ ``torch.clamp(weight_sums, min=1e-8)`` → cast back to bf16, all in fp32
+ (line 1132-1139, 1206-1207).
+
+ vLLM's stock ``HunYuanSparseMoeBlock`` instead builds the gate as a
+ default-dtype ``ReplicatedLinear`` (bf16) and lets ``SharedFusedMoE``'s
+ ``topk_softmax`` CUDA op consume bf16 logits, which can flip top-k
+ boundary decisions vs HF on close routing scores. With ``num_experts=64``,
+ ``top_k=8`` per layer × 32 MoE layers, even a small per-token flip rate
+ cascades into divergent expert outputs and KV-cache state, eventually
+ flipping the top-1 decoded token.
+
+ This subclass:
+
+ 1. Replaces ``self.gate`` with a fp32 ``ReplicatedLinear``.
+ 2. Replaces ``self.experts`` with a ``SharedFusedMoE`` whose routing is a
+ no-op unpack of our pre-computed (topk_weights, topk_indices) — the
+ fp32 softmax/topk/renormalize is done in :meth:`forward` here, exactly
+ mirroring HF's ``easy_topk`` math (including ``clamp(min=1e-8)``).
+ """
+
+ def __init__(
+ self,
+ config,
+ quant_config=None,
+ layer_id: int = -1,
+ prefix: str = "",
+ enable_eplb: bool = False,
+ ) -> None:
+ # Bypass ``HunYuanSparseMoeBlock.__init__`` — it would build a wasteful
+ # bf16 gate + a stub ``SharedFusedMoE`` we'd then have to del+recreate
+ # (which trips ``ValueError: Duplicate layer name`` because the stub
+ # already registered itself in ``compilation_config.static_forward_context``).
+ # Instead, set up ``nn.Module`` ourselves and construct the fp32 gate
+ # + ``custom_routing_function``-driven ``SharedFusedMoE`` directly,
+ # mirroring the parent's structure 1:1 except for the routing dtype.
+ nn.Module.__init__(self)
+
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.ep_group = get_ep_group().device_group
+ self.ep_rank = get_ep_group().rank_in_group
+ self.ep_size = self.ep_group.size()
+ self.n_routed_experts = config.num_experts
+
+ if self.tp_size > config.num_experts:
+ raise ValueError(
+ f"Tensor parallel size {self.tp_size} is greater than the number of experts {config.num_experts}."
+ )
+
+ if isinstance(config.moe_topk, list):
+ top_k = config.moe_topk[layer_id]
+ else:
+ top_k = config.moe_topk
+ self.top_k = top_k
+
+ intermediate_size = config.intermediate_size
+ if config.moe_intermediate_size is not None:
+ intermediate_size = (
+ config.moe_intermediate_size
+ if isinstance(config.moe_intermediate_size, int)
+ else config.moe_intermediate_size[layer_id]
+ )
+
+ vllm_config = get_current_vllm_config()
+ eplb_config = vllm_config.parallel_config.eplb_config
+ self.enable_eplb = enable_eplb
+ self.n_logical_experts = self.n_routed_experts
+ self.n_redundant_experts = eplb_config.num_redundant_experts
+ self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
+ self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+ self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
+ self.physical_expert_end = self.physical_expert_start + self.n_local_physical_experts
+
+ # FP32 router gate (HF: ``wg = nn.Linear(..., dtype=torch.float32)``).
+ self.gate = ReplicatedLinear(
+ config.hidden_size,
+ config.num_experts,
+ bias=False,
+ quant_config=None,
+ params_dtype=torch.float32,
+ prefix=f"{prefix}.gate",
+ )
+
+ if config.use_mixed_mlp_moe > 0:
+ num_shared_expert = (
+ config.num_shared_expert[layer_id]
+ if isinstance(config.num_shared_expert, list)
+ else config.num_shared_expert
+ )
+ self.shared_mlp = HunYuanMLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size * num_shared_expert,
+ hidden_act=config.hidden_act,
+ quant_config=quant_config,
+ reduce_results=False,
+ prefix=f"{prefix}.shared_mlp",
+ )
+ else:
+ self.shared_mlp = None
+
+ # Experts with our ``_hunyuan_image3_unpack_packed_topk`` custom
+ # routing — we feed it (topk_weights, topk_indices) packed into
+ # ``router_logits`` in ``forward()`` so the bf16 ``topk_softmax``
+ # CUDA op is bypassed entirely. ``renormalize=False`` because we
+ # already did clamp+divide in fp32 to match HF's
+ # ``topk_weight = topk_weight_1 / clamp(sum, min=1e-8)``.
+ self.experts = SharedFusedMoE(
+ shared_experts=self.shared_mlp,
+ num_experts=self.n_routed_experts,
+ top_k=top_k,
+ hidden_size=config.hidden_size,
+ intermediate_size=intermediate_size,
+ renormalize=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.experts",
+ enable_eplb=self.enable_eplb,
+ num_redundant_experts=self.n_redundant_experts,
+ custom_routing_function=_hunyuan_image3_unpack_packed_topk,
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ orig_shape = hidden_states.shape
+ hidden_dim = hidden_states.shape[-1]
+ hidden_states = hidden_states.view(-1, hidden_dim)
+
+ # FP32 router (HF: `with torch.autocast('cuda', enabled=False): ...`
+ # plus `if self.wg.weight.dtype == torch.float32: hidden_states.float()`).
+ # ``self.gate.weight`` is fp32 (params_dtype=torch.float32), so the
+ # ReplicatedLinear matmul runs in fp32 once we cast the input.
+ router_logits, _ = self.gate(hidden_states.float())
+
+ # softmax + topk + clamp-divide renormalization, all in fp32 — matches
+ # ``HunyuanTopKGate.easy_topk`` exactly.
+ gates = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
+ topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1)
+ weight_sums = topk_weights.sum(dim=-1, keepdim=True)
+ topk_weights = topk_weights / weight_sums.clamp(min=1e-8)
+
+ # Cast topk weights to model dtype for the expert MLP combine.
+ # HF: ``topk_weights = topk_weights.to(hidden_states.dtype)`` (line 1207).
+ topk_weights = topk_weights.to(hidden_states.dtype)
+
+ # Pack (weights, indices) into the ``router_logits`` slot so
+ # ``_hunyuan_image3_unpack_packed_topk`` can pull them back out
+ # inside ``SharedFusedMoE``. Both halves are stored as fp32 for
+ # transport — the indices get cast back to int32 on unpack.
+ packed_routing = torch.cat([topk_weights.float(), topk_indices.to(torch.float32)], dim=-1)
+
+ # vllm 0.20+ FusedMoE merges shared-experts internally and runs the
+ # TP all-reduce inside its forward (we no longer pass
+ # `reduce_results=False`). The tuple `(routed, shared)` return shape
+ # from the legacy SharedFusedMoE is gone; the result is the
+ # already-combined, already-reduced tensor.
+ final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=packed_routing)
+ return final_hidden_states.view(orig_shape)
+
+
class HunyuanImage3RotaryEmbedding(nn.Module):
"""Custom interleaved 2D Rotary Embedding for HunyuanImage3.
@@ -1273,8 +1517,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# For comprehension mode, block image generation tokens but allow
# text structure tokens (, , etc.) so the model can
- # follow its natural generation pattern. Stop tokens in YAML will
- # terminate at or EOS.
+ # follow its natural generation pattern. The yaml stop_token_ids
+        # for i2t/t2t now includes token id 128024 so the AR-only output
+ # terminates after the analysis section, matching HF's
+ # `bot_task="think"` behavior. Without that stop, the model
+ # continues into a recaption section even in comprehension mode
+ # (the stage-transition processor only fires in generation mode,
+ # but the instruct-tuned model writes recaption on its own from
+ # internal habit).
self._blocked_token_ids: set[int] = set()
if self._is_comprehension:
self._blocked_token_ids.update(
@@ -1306,6 +1556,75 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self._eos_token_id: int = tokenizer.eos_token_id
self._replace_rotary_embeddings()
+ self._patch_moe_blocks()
+
+    def _patch_moe_blocks(self):
+        """Swap stock ``HunYuanSparseMoeBlock`` layers for the fp32-routing variant.
+
+        Each matching ``layer.mlp`` becomes a :class:`HunyuanImage3SparseMoeBlock`,
+        which routes in fp32 to match the HF reference
+        (``modeling_hunyuan_image_3.HunyuanMoE``). No-op for non-MoE configs.
+
+        Stock vLLM builds the router gate as a default-dtype (bf16)
+        ``ReplicatedLinear`` and lets ``SharedFusedMoE``'s ``topk_softmax``
+        kernel consume bf16 logits, which is the largest deterministic precision
+        gap remaining vs HF after the prompt/preprocessing alignment fixes. See
+        ``HunyuanImage3SparseMoeBlock`` docstring for the full rationale.
+
+        Must run before weight loading (still inside ``__init__``) so the
+        replacement gate's fp32 ``params_dtype`` is honored at checkpoint load.
+        """
+        if not _is_moe(self.config):
+            return
+        enable_eplb = getattr(self.vllm_config.parallel_config, "enable_eplb", False)
+        ccfg = self.vllm_config.compilation_config
+        replaced = 0
+        for layer_id, layer in enumerate(self.model.layers):
+            mlp = getattr(layer, "mlp", None)
+            # Skip MLPs that are not the stock block or are already the replacement.
+            if isinstance(mlp, HunYuanSparseMoeBlock) and not isinstance(mlp, HunyuanImage3SparseMoeBlock):
+                # Pop the OLD experts' registration from
+                # ``static_forward_context`` first — otherwise the new
+                # ``SharedFusedMoE`` built inside :class:`HunyuanImage3SparseMoeBlock`
+                # will trip ``ValueError: Duplicate layer name`` (see
+                # vllm/model_executor/layers/fused_moe/layer.py:327).
+                # NOTE(review): assumes the "model.layers.<i>.mlp.experts" name — confirm.
+                old_prefix = f"model.layers.{layer_id}.mlp.experts"
+                ccfg.static_forward_context.pop(old_prefix, None)
+                if old_prefix in ccfg.static_all_moe_layers:
+                    ccfg.static_all_moe_layers.remove(old_prefix)
+
+                # Free the OLD MoE block's GPU buffers BEFORE allocating the
+                # replacement. The parent ``SharedFusedMoE`` pre-allocates the
+                # full ``[num_experts, ...]`` expert weight tensors at
+                # ``__init__`` (~750 MiB per layer per worker on this 80B model
+                # with TP=2), so without this drop we transiently double the
+                # MoE footprint and OOM near the gpu_memory_utilization cap.
+                layer.mlp = None
+                del mlp
+                gc.collect()
+                torch.cuda.empty_cache()
+
+                layer.mlp = HunyuanImage3SparseMoeBlock(
+                    config=self.config,
+                    quant_config=self.quant_config,
+                    layer_id=layer_id,
+                    prefix=f"model.layers.{layer_id}.mlp",
+                    enable_eplb=enable_eplb,
+                )
+                replaced += 1
+        logger.info(
+            "Replaced %d HunYuanSparseMoeBlock layers with "
+            "HunyuanImage3SparseMoeBlock (fp32 router matching HF reference)",
+            replaced,
+        )
+        if replaced == 0:
+            logger.warning(
+                "HunyuanImage3: _patch_moe_blocks replaced 0 layers. "
+                "Routing will run in bf16 instead of fp32 — output will "
+                "diverge from the HF reference more than necessary. "
+                "Check that model.layers[*].mlp is HunYuanSparseMoeBlock."
+            )
def _replace_rotary_embeddings(self):
"""Replace vLLM's standard MRotaryEmbedding with the custom
@@ -1384,6 +1703,18 @@ def _vae_encode(
"""
config = self.vae.config
+ # Cast pixel input to model dtype here (at the encoder boundary)
+ # rather than inside HunyuanImage3Processor.process_image. This
+ # matches HF's path which keeps fp32 pixels in build_cond_images and
+ # only casts inside the VAE forward — preserving fp32 precision in
+ # the multimodal_data dict and minimizing precision drift vs HF.
+ # Verified by pixel-tensor diff: removing the early bf16 cast brings
+ # omni's vae_pixel_values byte-identical to HF's (within fp32 noise),
+ # whereas an early cast leaves a ~7e-4 mean-abs-diff bf16 quantization
+ # error on every element.
+ if images.dtype != self.vae.dtype:
+ images = images.to(dtype=self.vae.dtype)
+
vae_encode_result = self.vae.encode(images)
latents = vae_encode_result.latent_dist.sample()
@@ -1494,11 +1825,14 @@ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
"Each image should have both VAE and ViT embeddings."
)
- # Order per image: timestep -> VAE tokens -> ViT tokens
+ # Order per image: timestep -> VAE tokens -> ViT tokens.
+        # The <img> placeholder at the timestep slot (see _get_prompt_updates)
+ # gets its embedding replaced by `timestep_emb(0)` here, which is what
+ # HF achieves via instantiate_continuous_tokens at runtime.
combined_embeddings: list[torch.Tensor] = []
num_images = len(vae_token_embeddings)
for img_idx in range(num_images):
- # 1. Timestep embedding
+ # 1. Timestep embedding (cond image timestep == 0)
timestep = torch.zeros((1,)).to(vit_embeddings.device).to(vit_embeddings.dtype)
timestep_emb = self._timestep_encode(timestep)
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml
index b68b184ec31..0614a9f1179 100644
--- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml
@@ -34,7 +34,7 @@ stage_args:
top_p: 0.95
top_k: 1024
max_tokens: 2048
- stop_token_ids: [127957, 128026] # <|endoftext|>,
+ stop_token_ids: [127957, 128024, 128026] # <|endoftext|>, ,
detokenize: True
runtime:
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml
index a0a1a0dc1c4..c9daa5e5f39 100644
--- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml
@@ -35,7 +35,7 @@ stage_args:
top_p: 0.95
top_k: 1024
max_tokens: 2048
- stop_token_ids: [127957, 128026] # <|endoftext|>,
+ stop_token_ids: [127957, 128024, 128026] # <|endoftext|>, ,
detokenize: True
runtime: