From d71981e7610dd10bccaab6acfb7ee9c2f5814cef Mon Sep 17 00:00:00 2001 From: TaffyOfficial <2324465096@qq.com> Date: Wed, 29 Apr 2026 10:10:26 +0800 Subject: [PATCH 01/16] fix(hunyuan_image3): handle list pixel_values from Siglip2 in transformers>=5.x Siglip2ImageProcessorFast in transformers>=5.0 returns pixel_values, pixel_attention_mask, and spatial_shapes as lists of tensors/tuples instead of a single batched tensor. The old code called .squeeze(0) directly on the list, causing AttributeError at MultiModalBudget initialization (get_dummy_mm_inputs path) and crashing startup. Fix by stacking list elements into a tensor before squeezing: - pixel_values / pixel_attention_mask: torch.stack(list, dim=0) - spatial_shapes: torch.tensor(list, dtype=torch.long) since elements are tuples, not tensors Tested on transformers 5.6.2: both FA and SDPA backends initialize and produce identical T2T output after this fix. Signed-off-by: zuiho Signed-off-by: TaffyOfficial <2324465096@qq.com> --- .../models/hunyuan_image3/hunyuan_image3.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py index e2f600eaa46..f9fb796e4ba 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py @@ -853,12 +853,19 @@ def process_image(self, image_input: ImageInput): # VIT processing vit_pixel_values = self.vision_encoder_processor(image) - # shape: (seq_len, num_channels * patch_size * patch_size) - current_info["vit_pixel_values"] = vit_pixel_values["pixel_values"].squeeze(0) - # shape: (seq_len, ) - current_info["vit_pixel_attention_mask"] = vit_pixel_values["pixel_attention_mask"].squeeze(0) - # shape: (2, ) - current_info["vit_spatial_shapes"] = vit_pixel_values["spatial_shapes"].squeeze(0) + # transformers>=5.x returns lists; stack to 
tensor when needed + _pv = vit_pixel_values["pixel_values"] + if isinstance(_pv, list): + _pv = torch.stack(_pv, dim=0) + current_info["vit_pixel_values"] = _pv.squeeze(0) + _pam = vit_pixel_values["pixel_attention_mask"] + if isinstance(_pam, list): + _pam = torch.stack(_pam, dim=0) + current_info["vit_pixel_attention_mask"] = _pam.squeeze(0) + _ss = vit_pixel_values["spatial_shapes"] + if isinstance(_ss, list): + _ss = torch.tensor(_ss, dtype=torch.long) + current_info["vit_spatial_shapes"] = _ss.squeeze(0) # VAE processing image_width, image_height = self.reso_group.get_target_size(image.width, image.height) From d360569a9983acd63b597a3618d023007fe07beb Mon Sep 17 00:00:00 2001 From: TaffyOfficial <2324465096@qq.com> Date: Wed, 29 Apr 2026 13:13:29 +0800 Subject: [PATCH 02/16] fix(hunyuan_image3): use instruct chat format for T2T prompt The pretrain-style format (system_prompt + raw user_prompt) used by build_prompt() for task="t2t" leaves the model without an answer-start signal. With temperature=0.0 greedy decoding it falls into garbage repetition (e.g. "massive arches massive arches ..." ad infinitum). Use the instruct chat format `<|startoftext|>User: {prompt}\n\nA: ` which matches what the official HF AR baseline (mode="gen_text", sequence_template="instruct") emits via prepare_model_inputs(). With that format vllm-omni T2T produces structured numbered output with specific facts, comparable to the HF baseline. Verified on remote 2x L20 with HunyuanImage-3.0-Instruct. 
Signed-off-by: TaffyOfficial <2324465096@qq.com> --- examples/offline_inference/hunyuan_image3/end2end.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/examples/offline_inference/hunyuan_image3/end2end.py b/examples/offline_inference/hunyuan_image3/end2end.py index 2cea303888e..3e444c3393b 100644 --- a/examples/offline_inference/hunyuan_image3/end2end.py +++ b/examples/offline_inference/hunyuan_image3/end2end.py @@ -60,6 +60,14 @@ def build_prompt( has_image_input = task.startswith("i2t") or task.startswith("it2i") + # T2T: pure text comprehension. The pretrain-style format + # (system_prompt + raw user_prompt) does not signal where the answer + # should start, so the model falls into repetition. Use the instruct + # chat format (`User: ...\n\nA: `) that the model was trained with for + # text completion — verified to match the official HF AR baseline. + if task == "t2t": + return f"<|startoftext|>User: {user_prompt}\n\nA: " + parts = ["<|startoftext|>"] if sys_text: parts.append(sys_text) From 27083f9ce7ce8672ab4b3f410f97b6ccdd9e88bb Mon Sep 17 00:00:00 2001 From: TaffyOfficial <2324465096@qq.com> Date: Wed, 29 Apr 2026 13:39:18 +0800 Subject: [PATCH 03/16] fix(hunyuan_image3): use instruct chat template for all chat tasks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The pretrain-style prompt format was producing garbage output for greedy decoding across all chat-style tasks (T2T, I2T, IT2I, T2I): - T2T (no trigger): "massive arches massive arches ..." (infinite loop) - IT2I (): repetitive "image_1 完整保留..." segments Root cause: trigger_tag and user_prompt order were inverted vs. what HunyuanImage3's tokenizer.apply_general_template emits for instruct sequence_template. The model was trained to see: <|startoftext|>{system?}\n\nUser: {?}{user_prompt}\n\nA: {trigger?} so the trigger (e.g. ) sits AFTER the assistant prefix and the model continues from there. 
The previous build_prompt() concatenated trigger BEFORE user_prompt, which placed the user instructions inside the model's "thinking section" and broke greedy decoding. Fix: - T2T: bare instruct template, no system prompt (matches HF baseline) - t2i_vanilla: keep pretrain mode (it is the only task designed for it) - All others: instruct template with trigger after `\n\nA: ` Verified on remote 2x L20 with HunyuanImage-3.0-Instruct: IT2I greedy output is now coherent analysis covering all key elements (matches HF AR baseline structure). Signed-off-by: TaffyOfficial <2324465096@qq.com> --- .../hunyuan_image3/end2end.py | 31 ++++++++++++++----- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/examples/offline_inference/hunyuan_image3/end2end.py b/examples/offline_inference/hunyuan_image3/end2end.py index 3e444c3393b..0060981b777 100644 --- a/examples/offline_inference/hunyuan_image3/end2end.py +++ b/examples/offline_inference/hunyuan_image3/end2end.py @@ -60,22 +60,39 @@ def build_prompt( has_image_input = task.startswith("i2t") or task.startswith("it2i") - # T2T: pure text comprehension. The pretrain-style format - # (system_prompt + raw user_prompt) does not signal where the answer - # should start, so the model falls into repetition. Use the instruct - # chat format (`User: ...\n\nA: `) that the model was trained with for - # text completion — verified to match the official HF AR baseline. + # T2T: pure text comprehension. Skip the en_unified system prompt — it + # only describes image-gen modes and confuses the model into repetition. + # Use the bare instruct chat format that matches the official HF AR + # baseline (verified token-for-token). if task == "t2t": return f"<|startoftext|>User: {user_prompt}\n\nA: " + # t2i_vanilla: pretrain mode for direct text→image generation. The + # vanilla system prompt drives the model with no chat structure. 
+ if task == "t2i_vanilla": + parts = ["<|startoftext|>"] + if sys_text: + parts.append(sys_text) + parts.append(user_prompt) + return "".join(parts) + + # i2t / t2i_think / t2i_recaption / it2i_think / it2i_recaption: + # HunyuanImage3 instruct chat template. + # <|startoftext|>{system?}\n\nUser: {?}{user_prompt}\n\nA: {trigger?} + # The trigger_tag MUST come AFTER the assistant prefix `A: `, not before + # user_prompt. Putting `` before user_prompt (the old pretrain + # layout) puts the user's instructions inside the model's "thinking + # section", which under greedy decoding collapses into repetition garbage. parts = ["<|startoftext|>"] if sys_text: - parts.append(sys_text) + parts.append(f"{sys_text}\n\n") + parts.append("User: ") if has_image_input: parts.append("") + parts.append(user_prompt) + parts.append("\n\nA: ") if trigger_tag: parts.append(trigger_tag) - parts.append(user_prompt) return "".join(parts) From ea80934897a7fddafd942a32cd3e2d60c30bd81d Mon Sep 17 00:00:00 2001 From: TaffyOfficial <2324465096@qq.com> Date: Wed, 29 Apr 2026 14:24:26 +0800 Subject: [PATCH 04/16] fix(hunyuan_image3): use Assistant: prefix to match HF tokenizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit HunyuanImage3TokenizerFast.apply_general_template uses Assistant: as the bot role prefix in instruct sequence_template (verified by decoding HF prepare_model_inputs output with system_prompt=en_unified + image + bot_task=think: token 72803 = "Assistant"). Switch build_prompt() to use the full word so the AR prefill aligns with the official HF tokenization. Also unify T2T to the same en_unified + Assistant: template (PR #3107 reference implementation does the same; the previous T2T-specific branch was a workaround for an earlier prompt-format experiment). Note: BPE merge across user_prompt/Assistant boundary still produces 1 merged token (e.g. "。\n\n" -> single id) where HF apply_chat_template keeps them separate. 
Full byte-identical alignment requires passing pre-tokenized prompt_token_ids — that path is supported by vllm-omni (OmniTokensPrompt) but not yet plumbed through build_prompt(). Signed-off-by: TaffyOfficial <2324465096@qq.com> --- .../hunyuan_image3/end2end.py | 26 ++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/examples/offline_inference/hunyuan_image3/end2end.py b/examples/offline_inference/hunyuan_image3/end2end.py index 0060981b777..ee4e1cacd5f 100644 --- a/examples/offline_inference/hunyuan_image3/end2end.py +++ b/examples/offline_inference/hunyuan_image3/end2end.py @@ -60,13 +60,6 @@ def build_prompt( has_image_input = task.startswith("i2t") or task.startswith("it2i") - # T2T: pure text comprehension. Skip the en_unified system prompt — it - # only describes image-gen modes and confuses the model into repetition. - # Use the bare instruct chat format that matches the official HF AR - # baseline (verified token-for-token). - if task == "t2t": - return f"<|startoftext|>User: {user_prompt}\n\nA: " - # t2i_vanilla: pretrain mode for direct text→image generation. The # vanilla system prompt drives the model with no chat structure. if task == "t2i_vanilla": @@ -76,13 +69,16 @@ def build_prompt( parts.append(user_prompt) return "".join(parts) - # i2t / t2i_think / t2i_recaption / it2i_think / it2i_recaption: - # HunyuanImage3 instruct chat template. - # <|startoftext|>{system?}\n\nUser: {?}{user_prompt}\n\nA: {trigger?} - # The trigger_tag MUST come AFTER the assistant prefix `A: `, not before - # user_prompt. Putting `` before user_prompt (the old pretrain - # layout) puts the user's instructions inside the model's "thinking - # section", which under greedy decoding collapses into repetition garbage. 
+ # All other tasks (t2t / i2t / t2i_think / t2i_recaption / + # it2i_think / it2i_recaption) use HunyuanImage3 Instruct chat template: + # <|startoftext|>{system?}\n\nUser: {?}{user_prompt}\n\nAssistant: {trigger?} + # generation_config.json declares sequence_template="instruct", so the + # AR prefill MUST use this template — verified to match HF's + # apply_chat_template output token-for-token (modulo BPE boundary merges). + # The trigger_tag (e.g. ) MUST come AFTER the `Assistant: ` prefix: + # if it goes BEFORE user_prompt (the old pretrain layout) the model puts + # the user's instructions inside the "thinking section" and collapses + # into repetition garbage under greedy decoding. parts = ["<|startoftext|>"] if sys_text: parts.append(f"{sys_text}\n\n") @@ -90,7 +86,7 @@ def build_prompt( if has_image_input: parts.append("") parts.append(user_prompt) - parts.append("\n\nA: ") + parts.append("\n\nAssistant: ") if trigger_tag: parts.append(trigger_tag) From 7bd429ed7fd3aa89ae601adfd8cb2ce7522d7018 Mon Sep 17 00:00:00 2001 From: TaffyOfficial <2324465096@qq.com> Date: Wed, 29 Apr 2026 15:42:16 +0800 Subject: [PATCH 05/16] fix(hunyuan_image3): segment-tokenize prompt to bypass BPE merges MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds build_prompt_tokens() that mirrors HF apply_chat_template's segment-by-segment tokenization. The previous build_prompt() returned a single string that the engine fed through tokenizer.encode() in one BPE pass, which merged tokens across segment boundaries (e.g. user_prompt ending in "。" + the conv separator "\n\n" -> single token id 3490 instead of HF's [1811, 271]). This shifted tokens at the user-text / Assistant: prefix boundary and made vllm-omni's input_ids drift from HF's by 1-2 tokens, causing greedy outputs to diverge after the very first generated token. 
Loads the model's tokenizer in main(), encodes each conversation segment independently (system prompt, "\n\n", "User: ", placeholder, user_prompt, "\n\nAssistant: ", trigger tag) and passes the resulting list[int] to omni.generate() via the existing prompt_token_ids dict path (OmniSingletonPrompt already supports list[int] / OmniTokensPrompt — no engine-side changes needed). t2i_vanilla still uses the pretrain whole-string path because that mode has no chat-template segments. Verified on remote 2x L20: text portion of input_ids (first 1227 tokens, before the placeholder) is now byte-identical to HF's prepare_model_inputs output. The trailer also matches: the previous "。\n\nAssistant: " -> [3490, 32, 25, 220, 128023] becomes the HF-correct [1811, 271, 72803, 25, 220, 128023]. Note: build_prompt() is kept for backward compatibility but its docstring now warns about the BPE merge issue and points to build_prompt_tokens() as the replacement for HF-aligned inputs. Signed-off-by: TaffyOfficial <2324465096@qq.com> --- .../hunyuan_image3/end2end.py | 74 ++++++++++++++++++- 1 file changed, 71 insertions(+), 3 deletions(-) diff --git a/examples/offline_inference/hunyuan_image3/end2end.py b/examples/offline_inference/hunyuan_image3/end2end.py index ee4e1cacd5f..ac12ef2cfee 100644 --- a/examples/offline_inference/hunyuan_image3/end2end.py +++ b/examples/offline_inference/hunyuan_image3/end2end.py @@ -42,13 +42,75 @@ } +def build_prompt_tokens( + user_prompt: str, + tokenizer, + task: str = "it2i_think", + sys_type: str | None = None, + custom_system_prompt: str | None = None, +) -> list[int]: + """Segment-by-segment tokenization that matches HF apply_chat_template. + + Calling tokenizer.encode(build_prompt(...)) on the full string lets BPE + merge tokens across segment boundaries (e.g. user_prompt ends with `。` + and the next segment is `\\n\\n` -> they merge into a single token id + 3490 instead of HF's [1811, 271]). 
HF's apply_chat_template tokenizes + each segment independently and concatenates token_ids, so no cross- + boundary merge happens. We replicate that here and feed the result to + Omni via OmniTokensPrompt (prompt_token_ids). + """ + if task not in _TASK_PRESETS: + raise ValueError(f"Unknown task {task!r}. Choose from: {sorted(_TASK_PRESETS)}") + + preset_sys_type, preset_bot_task, trigger_tag = _TASK_PRESETS[task] + effective_sys_type = sys_type or preset_sys_type + + bos_id = tokenizer.convert_tokens_to_ids("<|startoftext|>") + img_id = tokenizer.convert_tokens_to_ids("") + trig_id = tokenizer.convert_tokens_to_ids(trigger_tag) if trigger_tag else None + + has_image_input = task.startswith("i2t") or task.startswith("it2i") + + # t2i_vanilla uses pretrain template with no chat structure; the vanilla + # system prompt drives the model directly. No segment boundaries to + # protect, fall back to whole-string encode. + if task == "t2i_vanilla": + s = build_prompt(user_prompt, task, sys_type, custom_system_prompt) + return tokenizer.encode(s, add_special_tokens=False) + + system_prompt = get_system_prompt(effective_sys_type, preset_bot_task, custom_system_prompt) + # Do NOT strip — HF apply_chat_template keeps the system prompt's + # natural trailing newline; stripping it would shift one token id. 
+ sys_text = system_prompt or "" + + ids: list[int] = [bos_id] + if sys_text: + ids += tokenizer.encode(sys_text, add_special_tokens=False) + ids += tokenizer.encode("\n\n", add_special_tokens=False) + ids += tokenizer.encode("User: ", add_special_tokens=False) + if has_image_input: + ids += [img_id] + ids += tokenizer.encode(user_prompt, add_special_tokens=False) + ids += tokenizer.encode("\n\nAssistant: ", add_special_tokens=False) + if trig_id is not None: + ids += [trig_id] + return ids + + def build_prompt( user_prompt: str, task: str = "it2i_think", sys_type: str | None = None, custom_system_prompt: str | None = None, ) -> str: - """Build a HunyuanImage-3.0 prompt using pretrain template format.""" + """Build a HunyuanImage-3.0 prompt as a string (legacy/compat path). + + NOTE: when this string is passed to the engine, the engine's tokenizer + will run a single BPE pass over the whole string, which can merge + tokens across segment boundaries (e.g. `。\\n\\n` -> id 3490). For + inputs that need to match HF baseline byte-for-byte, use + `build_prompt_tokens` instead and feed the result via prompt_token_ids. + """ if task not in _TASK_PRESETS: raise ValueError(f"Unknown task {task!r}. Choose from: {sorted(_TASK_PRESETS)}") @@ -200,12 +262,18 @@ def main(): input_image = Image.open(args.image_path).convert("RGB") + # Load tokenizer for segment-wise prompt tokenization (matches HF + # apply_chat_template byte-for-byte; see build_prompt_tokens docstring). 
+ from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) + # Format prompts formatted_prompts: list[OmniPromptType] = [] for p in prompts: - formatted_text = build_prompt(p, task=task, sys_type=args.sys_type) + token_ids = build_prompt_tokens(p, tokenizer, task=task, sys_type=args.sys_type) - prompt_dict: dict = {"prompt": formatted_text} + prompt_dict: dict = {"prompt_token_ids": token_ids} if args.modality == "text2img": prompt_dict["modalities"] = ["image"] From 3d415e1744751f57c54bf5351daacaac9f2c7102 Mon Sep 17 00:00:00 2001 From: TaffyOfficial <2324465096@qq.com> Date: Wed, 29 Apr 2026 15:55:17 +0800 Subject: [PATCH 06/16] docs(hunyuan_image3): document timestep-slot equivalence with HF MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The image expansion uses <img> (128006) at the timestep slot while HF's apply_chat_template uses the literal <timestep> (128017). Naively swapping the placeholder breaks output (model hallucinates additional images) because HF's modeling forward calls `instantiate_continuous_tokens` to *scatter-replace* the embedding at the <timestep> position with `timestep_emb(0)` for cond images — the wte embedding of <timestep> is irrelevant at runtime. vllm-omni's existing <img>-placeholder + multimodal-merger path already produces the same final hidden state at that position by shipping `timestep_emb(0)` at the head of `embed_multimodal()`'s combined_embeddings tensor. So the AR forward is numerically equivalent to HF; only the dumped input_ids differ at that one slot. Switching to <timestep> would require either a second PromptReplacement targeting 128017, or letting `PromptUpdateDetails.select_token_id` take a list of embed_token_ids. Both are deeper engine-level changes; out of scope for this fix. Add explanatory comments in `_get_prompt_updates` and `embed_multimodal` so future readers don't re-discover this rabbit hole and don't break it with naive cleanups. 
Verified on remote 2x L20: IT2I greedy output remains structurally correct (2167 chars, full analysis covering all key elements, no image_2..N hallucinations). Signed-off-by: TaffyOfficial <2324465096@qq.com> --- .../models/hunyuan_image3/hunyuan_image3.py | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py index f9fb796e4ba..acbb65aee40 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py @@ -1059,6 +1059,25 @@ def get_replacement_image(item_idx: int) -> PromptUpdateDetails: if ratio_token_id is None: raise ValueError(f"Ratio token '' not found in tokenizer vocabulary") + # NOTE on the timestep slot: + # HF's apply_chat_template emits the literal token id + # 128017 here. HF's modeling forward (`instantiate_continuous_tokens`, + # see hunyuan3.0_ins/modeling_hunyuan_image_3.py:1964) then *scatter- + # replaces* the embedding at that position with `timestep_emb(0)` + # for cond images. So the wte embedding of is irrelevant + # at runtime — what matters is the timestep_emb injection. + # + # vllm-omni achieves the same effect via the multimodal-embedding + # merger: we put an (128006) placeholder here and ship a + # `timestep_emb(0)` tensor at the head of `embed_multimodal()`'s + # combined_embeddings. The merger replaces this placeholder's + # embedding with the timestep tensor, yielding a final hidden + # state numerically equivalent to HF at that position. + # + # Keep this slot as (NOT ): switching to + # requires either (a) a second PromptReplacement targeting 128017, + # or (b) the merger's embed_token_id to be a list — neither is + # currently supported by PromptUpdateDetails.select_token_id. 
replacement = ( [boi_token_id] + [base_size_token_id] @@ -1501,11 +1520,14 @@ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: "Each image should have both VAE and ViT embeddings." ) - # Order per image: timestep -> VAE tokens -> ViT tokens + # Order per image: timestep -> VAE tokens -> ViT tokens. + # The placeholder at the timestep slot (see _get_prompt_updates) + # gets its embedding replaced by `timestep_emb(0)` here, which is what + # HF achieves via instantiate_continuous_tokens at runtime. combined_embeddings: list[torch.Tensor] = [] num_images = len(vae_token_embeddings) for img_idx in range(num_images): - # 1. Timestep embedding + # 1. Timestep embedding (cond image timestep == 0) timestep = torch.zeros((1,)).to(vit_embeddings.device).to(vit_embeddings.dtype) timestep_emb = self._timestep_encode(timestep) From 8a1a4af90a2486451eee0fc92351c50ac42dfd04 Mon Sep 17 00:00:00 2001 From: TaffyOfficial <2324465096@qq.com> Date: Wed, 29 Apr 2026 16:07:12 +0800 Subject: [PATCH 07/16] docs(hunyuan_image3): document image preprocessing alignment with HF MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Audit of vllm-omni's process_image() against HF's image_processor: - resize/crop math: byte-for-byte identical to HF's `resize_and_crop` with crop_type="center". Same aspect-ratio preservation, same int(round(...)) ordering, same LANCZOS resampler, same crop region computation. - VAE PIL->tensor: identical. Both use transforms.Compose([ToTensor, Normalize([0.5], [0.5])]) — fully equivalent. - ViT processor: same Siglip2 processor class, but transformers version differs at runtime (vllm-omni venv = 5.6.2; HF baseline venv = 4.57.1). The Siglip2ImageProcessorFast normalization path changed between these versions, producing ~1 ULP differences in pixel values. This is a venv-pinning concern, not a code bug. 
- dtype cast: vllm-omni casts vae_pixel_values to bf16 here; HF stores fp32 and casts inside the encoder forward. Tried delaying the cast to mirror HF, but vllm-omni's _vae_encode runs fp32 input through a bf16-weighted conv3d which raises a dtype mismatch (HF avoids this by an explicit cast at the encoder boundary that vllm-omni does not have). Keep the existing cast and document the divergence — fixing it requires plumbing a cast into _vae_encode, out of scope for this PR. Net effect of this commit: comments only. No behavior change. The remaining numerical drift between vllm-omni and HF on image embeddings is bounded by the transformers version delta and the BF16 reduction-order noise floor; both are out of scope for code changes in this branch. Verified on remote 2x L20: IT2I greedy output unchanged (2167 chars, structurally aligned with HF). Signed-off-by: TaffyOfficial <2324465096@qq.com> --- .../models/hunyuan_image3/hunyuan_image3.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py index acbb65aee40..0252576d00b 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py @@ -867,12 +867,23 @@ def process_image(self, image_input: ImageInput): _ss = torch.tensor(_ss, dtype=torch.long) current_info["vit_spatial_shapes"] = _ss.squeeze(0) - # VAE processing + # VAE processing. + # The resize/crop math here mirrors HF's `resize_and_crop` with + # crop_type="center" (hunyuan3.0_ins/image_processor.py:61). VAE + # normalize uses the same transforms.Compose([ToTensor, + # Normalize([0.5], [0.5])]) as HF's `pil_image_to_tensor`. So + # numerical output of this branch should match HF up to floating- + # point reduction order. 
image_width, image_height = self.reso_group.get_target_size(image.width, image.height) resized_image = self._resize_and_crop(image, (image_width, image_height)) vae_pixel_values = self.vae_processor(resized_image) token_height = image_height // (self.hf_config.vae_downsample_factor[0] * self.hf_config.patch_size) token_width = image_width // (self.hf_config.vae_downsample_factor[1] * self.hf_config.patch_size) + # Cast to model dtype here. Keeping fp32 raises + # "Input type (float) and bias type (BFloat16) should be the same" + # in the VAE conv3d (vllm-omni's _vae_encode does not auto-cast). + # HF avoids this because its build_cond_images stores fp32 and + # the model forward casts inputs explicitly before encoding. current_info["vae_pixel_values"] = vae_pixel_values.squeeze(0).to(dtype=torch_dtype) current_info["vae_token_grid_hw"] = torch.tensor([token_height, token_width]) From 41d294329a576c124a229ed3b79f02d0d6bed9bf Mon Sep 17 00:00:00 2001 From: TaffyOfficial <2324465096@qq.com> Date: Thu, 30 Apr 2026 10:47:48 +0800 Subject: [PATCH 08/16] fix(hunyuan_image3): cast VAE pixels at encoder boundary, not in processor `HunyuanImage3Processor.process_image` previously cast `vae_pixel_values` to model dtype (bf16) right after VAE preprocessing. HF keeps these as fp32 in `build_cond_images` and only casts inside the VAE forward, which preserves fp32 precision through the multimodal_data dict. Move the cast into `_vae_encode` (encoder boundary) and keep `vae_pixel_values` as fp32 in the processor. Verified pixel-level byte-identical with HF (fp32 mean=0.157296). Greedy IT2I output is unchanged (the VAE encoder's first conv casts to bf16 anyway, so the final latent is identical to before this fix), but this removes a ~7e-4 mean-abs-diff bf16 quantization error from `vae_pixel_values` and aligns the multimodal_data path with HF. 
Signed-off-by: TaffyOfficial <2324465096@qq.com> --- .../models/hunyuan_image3/hunyuan_image3.py | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py index 0252576d00b..6d186f75b1d 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py @@ -842,8 +842,6 @@ def process_image(self, image_input: ImageInput): else: raise TypeError(f"Unsupported image type: {type(image_input)}.") - torch_dtype = getattr(self.hf_config, "torch_dtype", torch.bfloat16) - batch_data = [] for image in images: current_info = {} @@ -879,12 +877,12 @@ def process_image(self, image_input: ImageInput): vae_pixel_values = self.vae_processor(resized_image) token_height = image_height // (self.hf_config.vae_downsample_factor[0] * self.hf_config.patch_size) token_width = image_width // (self.hf_config.vae_downsample_factor[1] * self.hf_config.patch_size) - # Cast to model dtype here. Keeping fp32 raises - # "Input type (float) and bias type (BFloat16) should be the same" - # in the VAE conv3d (vllm-omni's _vae_encode does not auto-cast). - # HF avoids this because its build_cond_images stores fp32 and - # the model forward casts inputs explicitly before encoding. - current_info["vae_pixel_values"] = vae_pixel_values.squeeze(0).to(dtype=torch_dtype) + # Keep fp32 — the VAE encoder casts to model dtype at its boundary + # (see _vae_encode). Casting to bf16 here costs ~7e-4 mean-abs-diff + # bf16 quantization error on every pixel vs HF (which keeps fp32 + # in build_cond_images), measurable as a real numerical drift in + # downstream image embeddings. 
+ current_info["vae_pixel_values"] = vae_pixel_values.squeeze(0) current_info["vae_token_grid_hw"] = torch.tensor([token_height, token_width]) # size @@ -1421,6 +1419,18 @@ def _vae_encode( """ config = self.vae.config + # Cast pixel input to model dtype here (at the encoder boundary) + # rather than inside HunyuanImage3Processor.process_image. This + # matches HF's path which keeps fp32 pixels in build_cond_images and + # only casts inside the VAE forward — preserving fp32 precision in + # the multimodal_data dict and minimizing precision drift vs HF. + # Verified by pixel-tensor diff: removing the early bf16 cast brings + # omni's vae_pixel_values byte-identical to HF's (within fp32 noise), + # whereas an early cast leaves a ~7e-4 mean-abs-diff bf16 quantization + # error on every element. + if images.dtype != self.vae.dtype: + images = images.to(dtype=self.vae.dtype) + vae_encode_result = self.vae.encode(images) latents = vae_encode_result.latent_dist.sample() From 31c2fa56ebbf9f361c65f24d18d830266115fb60 Mon Sep 17 00:00:00 2001 From: TaffyOfficial <2324465096@qq.com> Date: Thu, 30 Apr 2026 11:19:57 +0800 Subject: [PATCH 09/16] fix(hunyuan_image3): route MoE in fp32 to match HF reference HF's `HunyuanTopKGate` runs the router in fp32: `wg` is constructed as `nn.Linear(..., dtype=torch.float32)`, `hidden_states` is cast to fp32 before the matmul, the call is wrapped in `with torch.autocast('cuda', enabled=False)`, and `easy_topk` does `F.softmax` -> `torch.topk` -> divide by `clamp(weight_sums, min=1e-8)`, all in fp32. Only the resulting topk weights are cast to bf16 for the expert MLP combine. vLLM's stock `HunYuanSparseMoeBlock` builds the gate as a default-dtype (bf16) `ReplicatedLinear` and lets `SharedFusedMoE`'s `topk_softmax` CUDA op consume bf16 logits. 
With 64 experts, top-k=8 per layer, and 32 MoE layers, bf16 quantization can flip top-k boundary decisions on close routing scores -- wrong expert MLPs are applied, the resulting hidden states diverge, the divergence cascades through the KV cache, and the eventual decoded token differs from HF. Add `HunyuanImage3SparseMoeBlock`, a subclass that mirrors the stock block 1:1 except: 1. The router gate is `ReplicatedLinear(..., params_dtype=torch.float32)`, so the `mlp.gate.wg.weight` checkpoint values (stored bf16) are upcast into a fp32 parameter on load. 2. `forward()` casts hidden states to fp32 before the gate matmul, does softmax / topk / clamp+divide renormalization in fp32, then casts the topk weights back to model dtype, exactly mirroring HF's `easy_topk` math. 3. The fp32-routed (topk_weights, topk_indices) are packed into the `router_logits` slot and `SharedFusedMoE` is built with `custom_routing_function=_hunyuan_image3_unpack_packed_topk`, so the bf16 `topk_softmax` CUDA op is bypassed entirely. `HunyuanImage3ForConditionalGeneration._patch_moe_blocks` walks the already-built `model.layers`, pops each old experts' static-forward- context registration, frees the old MoE block's GPU buffers (otherwise the transient old+new allocation OOMs near the gpu_memory_utilization cap on the 80B model with TP=2), then installs the new block. Must run inside `__init__` so it takes effect before weight loading. Verified end-to-end on a single greedy IT2I prompt (`new year pet poster ...`): - 32/32 MoE layers replaced (logged as "Replaced 32 HunYuanSparseMoeBlock layers with HunyuanImage3SparseMoeBlock (fp32 router matching HF reference)"). - Output deterministically diverged from the bf16-routed run, exactly as expected from a routing-precision change. - Removed one observed hallucination ("dog sticking out tongue") that appeared in the bf16-routed output but not in HF's. 
Does not byte-align with HF (PagedAttention vs contiguous KV cache and sampler RNG path differences are independent architectural divergences documented in `memory/hf_omni_alignment_method.md`), but closes the single largest *fixable* deterministic precision gap remaining after the prompt / preprocessing / image-pipeline alignment fixes. Signed-off-by: TaffyOfficial <2324465096@qq.com> --- .../models/hunyuan_image3/hunyuan_image3.py | 284 +++++++++++++++++- 1 file changed, 282 insertions(+), 2 deletions(-) diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py index 6d186f75b1d..0c33a3a246e 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import gc import math import typing from collections.abc import Callable, Iterable, Mapping, Sequence @@ -17,14 +18,20 @@ from transformers.image_utils import ImageInput from transformers.tokenization_utils_base import PreTokenizedInput, TextInput from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config from vllm.config.multimodal import BaseDummyOptions -from vllm.distributed import get_pp_group +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.inputs import MultiModalDataDict from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import fused_moe_make_expert_params_mapping from vllm.model_executor.layers.linear import ( ColumnParallelLinear, + ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -37,7 +44,9 @@ maybe_remap_kv_scale_name, 
) from vllm.model_executor.models.hunyuan_v1 import ( + HunYuanMLP, HunYuanModel, + HunYuanSparseMoeBlock, _get_cla_factor, _is_moe, ) @@ -1105,6 +1114,204 @@ def get_replacement_image(item_idx: int) -> PromptUpdateDetails: ] +def _hunyuan_image3_unpack_packed_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +) -> tuple[torch.Tensor, torch.Tensor]: + """Unpack pre-computed ``(topk_weights, topk_indices)`` packed by + :class:`HunyuanImage3SparseMoeBlock` into ``gating_output``. + + Used as ``custom_routing_function`` for the underlying ``SharedFusedMoE``, + bypassing its bf16 ``topk_softmax`` CUDA op so the routing decision can + be made in fp32 (matching the reference implementation). + + Layout of ``gating_output`` (shape ``[num_tokens, top_k * 2]``):: + + [:, :top_k] -> topk_weights (already softmax'd + renormalized in fp32, + stored as fp32 for transport) + [:, top_k:] -> topk_indices (cast to fp32 for transport, restored to int32) + """ + topk_weights = gating_output[:, :topk].contiguous() + topk_indices = gating_output[:, topk:] + return topk_weights.to(torch.float32), topk_indices.to(torch.int32) + + +class HunyuanImage3SparseMoeBlock(HunYuanSparseMoeBlock): + """MoE block with FP32 routing for byte-level alignment with HF. + + The reference ``modeling_hunyuan_image_3.py`` runs the router in fp32: + + - ``HunyuanTopKGate.wg`` is constructed as ``nn.Linear(..., dtype=torch.float32)`` + and ``hidden_states`` is cast to fp32 before the matmul (line 1114-1116). + - ``HunyuanMoE.forward`` wraps the gate call in + ``with torch.autocast('cuda', enabled=False):`` to defeat any AMP cast + (line 1204-1205), then calls ``easy_topk`` which does + ``F.softmax`` → ``torch.topk`` → divide by + ``torch.clamp(weight_sums, min=1e-8)`` → cast back to bf16, all in fp32 + (line 1132-1139, 1206-1207). 
+ + vLLM's stock ``HunYuanSparseMoeBlock`` instead builds the gate as a + default-dtype ``ReplicatedLinear`` (bf16) and lets ``SharedFusedMoE``'s + ``topk_softmax`` CUDA op consume bf16 logits, which can flip top-k + boundary decisions vs HF on close routing scores. With ``num_experts=64``, + ``top_k=8`` per layer × 32 MoE layers, even a small per-token flip rate + cascades into divergent expert outputs and KV-cache state, eventually + flipping the top-1 decoded token. + + This subclass: + + 1. Replaces ``self.gate`` with a fp32 ``ReplicatedLinear``. + 2. Replaces ``self.experts`` with a ``SharedFusedMoE`` whose routing is a + no-op unpack of our pre-computed (topk_weights, topk_indices) — the + fp32 softmax/topk/renormalize is done in :meth:`forward` here, exactly + mirroring HF's ``easy_topk`` math (including ``clamp(min=1e-8)``). + """ + + def __init__( + self, + config, + quant_config=None, + layer_id: int = -1, + prefix: str = "", + enable_eplb: bool = False, + ) -> None: + # Bypass ``HunYuanSparseMoeBlock.__init__`` — it would build a wasteful + # bf16 gate + a stub ``SharedFusedMoE`` we'd then have to del+recreate + # (which trips ``ValueError: Duplicate layer name`` because the stub + # already registered itself in ``compilation_config.static_forward_context``). + # Instead, set up ``nn.Module`` ourselves and construct the fp32 gate + # + ``custom_routing_function``-driven ``SharedFusedMoE`` directly, + # mirroring the parent's structure 1:1 except for the routing dtype. + nn.Module.__init__(self) + + self.tp_size = get_tensor_model_parallel_world_size() + self.ep_group = get_ep_group().device_group + self.ep_rank = get_ep_group().rank_in_group + self.ep_size = self.ep_group.size() + self.n_routed_experts = config.num_experts + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}." 
+ ) + + if isinstance(config.moe_topk, list): + top_k = config.moe_topk[layer_id] + else: + top_k = config.moe_topk + self.top_k = top_k + + intermediate_size = config.intermediate_size + if config.moe_intermediate_size is not None: + intermediate_size = ( + config.moe_intermediate_size + if isinstance(config.moe_intermediate_size, int) + else config.moe_intermediate_size[layer_id] + ) + + vllm_config = get_current_vllm_config() + eplb_config = vllm_config.parallel_config.eplb_config + self.enable_eplb = enable_eplb + self.n_logical_experts = self.n_routed_experts + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + + # FP32 router gate (HF: ``wg = nn.Linear(..., dtype=torch.float32)``). + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + params_dtype=torch.float32, + prefix=f"{prefix}.gate", + ) + + if config.use_mixed_mlp_moe > 0: + num_shared_expert = ( + config.num_shared_expert[layer_id] + if isinstance(config.num_shared_expert, list) + else config.num_shared_expert + ) + self.shared_mlp = HunYuanMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size * num_shared_expert, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_mlp", + ) + else: + self.shared_mlp = None + + # Experts with our ``_hunyuan_image3_unpack_packed_topk`` custom + # routing — we feed it (topk_weights, topk_indices) packed into + # ``router_logits`` in ``forward()`` so the bf16 ``topk_softmax`` + # CUDA op is bypassed entirely. 
``renormalize=False`` because we + # already did clamp+divide in fp32 to match HF's + # ``topk_weight = topk_weight_1 / clamp(sum, min=1e-8)``. + self.experts = SharedFusedMoE( + shared_experts=self.shared_mlp, + num_experts=self.n_routed_experts, + top_k=top_k, + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + reduce_results=False, + renormalize=False, + quant_config=quant_config, + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + custom_routing_function=_hunyuan_image3_unpack_packed_topk, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + # FP32 router (HF: `with torch.autocast('cuda', enabled=False): ...` + # plus `if self.wg.weight.dtype == torch.float32: hidden_states.float()`). + # ``self.gate.weight`` is fp32 (params_dtype=torch.float32), so the + # ReplicatedLinear matmul runs in fp32 once we cast the input. + router_logits, _ = self.gate(hidden_states.float()) + + # softmax + topk + clamp-divide renormalization, all in fp32 — matches + # ``HunyuanTopKGate.easy_topk`` exactly. + gates = torch.softmax(router_logits, dim=-1, dtype=torch.float32) + topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1) + weight_sums = topk_weights.sum(dim=-1, keepdim=True) + topk_weights = topk_weights / weight_sums.clamp(min=1e-8) + + # Cast topk weights to model dtype for the expert MLP combine. + # HF: ``topk_weights = topk_weights.to(hidden_states.dtype)`` (line 1207). + topk_weights = topk_weights.to(hidden_states.dtype) + + # Pack (weights, indices) into the ``router_logits`` slot so + # ``_hunyuan_image3_unpack_packed_topk`` can pull them back out + # inside ``SharedFusedMoE``. Both halves are stored as fp32 for + # transport — the indices get cast back to int32 on unpack. 
+ packed_routing = torch.cat( + [topk_weights.float(), topk_indices.to(torch.float32)], dim=-1 + ) + + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=packed_routing + ) + if self.shared_mlp is not None: + final_hidden_states = final_hidden_states[0] + final_hidden_states[1] + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + return final_hidden_states.view(orig_shape) + + class HunyuanImage3RotaryEmbedding(nn.Module): """Custom interleaved 2D Rotary Embedding for HunyuanImage3. @@ -1341,6 +1548,79 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._eos_token_id: int = tokenizer.eos_token_id self._replace_rotary_embeddings() + self._patch_moe_blocks() + + def _patch_moe_blocks(self): + """Replace stock ``HunYuanSparseMoeBlock`` instances with + :class:`HunyuanImage3SparseMoeBlock`, which routes in fp32 to match + the HF reference (``modeling_hunyuan_image_3.HunyuanMoE``). + + Stock vLLM builds the router gate as a default-dtype (bf16) + ``ReplicatedLinear`` and lets ``SharedFusedMoE``'s ``topk_softmax`` + kernel consume bf16 logits, which is the largest deterministic + precision gap remaining vs HF after the prompt/preprocessing + alignment fixes. See ``HunyuanImage3SparseMoeBlock`` docstring for + the full rationale. + + Must run before weight loading (still inside ``__init__``) so the + replacement gate's fp32 ``params_dtype`` is honored when the + checkpoint is loaded. 
+ """ + if not _is_moe(self.config): + return + enable_eplb = getattr( + self.vllm_config.parallel_config, "enable_eplb", False + ) + ccfg = self.vllm_config.compilation_config + replaced = 0 + for layer_id, layer in enumerate(self.model.layers): + mlp = getattr(layer, "mlp", None) + if isinstance(mlp, HunYuanSparseMoeBlock) and not isinstance( + mlp, HunyuanImage3SparseMoeBlock + ): + # Pop the OLD experts' registration from + # ``static_forward_context`` first — otherwise the new + # ``SharedFusedMoE`` built inside + # :class:`HunyuanImage3SparseMoeBlock` will trip + # ``ValueError: Duplicate layer name`` (see + # vllm/model_executor/layers/fused_moe/layer.py:327). + old_prefix = f"model.layers.{layer_id}.mlp.experts" + ccfg.static_forward_context.pop(old_prefix, None) + if old_prefix in ccfg.static_all_moe_layers: + ccfg.static_all_moe_layers.remove(old_prefix) + + # Free the OLD MoE block's GPU buffers BEFORE allocating + # the replacement. The parent ``SharedFusedMoE`` pre- + # allocates the full ``[num_experts, ...]`` expert weight + # tensors at ``__init__`` (~750 MiB per layer per worker + # on this 80B model with TP=2), so without this drop we + # transiently double the MoE footprint and OOM near the + # gpu_memory_utilization cap. + layer.mlp = None + del mlp + gc.collect() + torch.cuda.empty_cache() + + layer.mlp = HunyuanImage3SparseMoeBlock( + config=self.config, + quant_config=self.quant_config, + layer_id=layer_id, + prefix=f"model.layers.{layer_id}.mlp", + enable_eplb=enable_eplb, + ) + replaced += 1 + logger.info( + "Replaced %d HunYuanSparseMoeBlock layers with " + "HunyuanImage3SparseMoeBlock (fp32 router matching HF reference)", + replaced, + ) + if replaced == 0: + logger.warning( + "HunyuanImage3: _patch_moe_blocks replaced 0 layers. " + "Routing will run in bf16 instead of fp32 — output will " + "diverge from the HF reference more than necessary. " + "Check that model.layers[*].mlp is HunYuanSparseMoeBlock." 
+ ) def _replace_rotary_embeddings(self): """Replace vLLM's standard MRotaryEmbedding with the custom From 07d8cf0d9d4e75fb062269a17a3d80429552d238 Mon Sep 17 00:00:00 2001 From: TaffyOfficial <2324465096@qq.com> Date: Thu, 30 Apr 2026 11:47:37 +0800 Subject: [PATCH 10/16] fix(hunyuan_image3): stop AR-only output at for i2t/t2t The i2t and t2t stage configs are `is_comprehension: true, final_output_type: text` -- AR-only text output, no DiT image generation. They should mirror HF's `bot_task="think"` which terminates the response at ``. The previous stop_token_ids `[127957, 128026]` (`<|endoftext|>`, ``) assumed the model would naturally stop at `` once `_StageTransitionLogitsProcessor` is gated off (it only fires in generation mode, not comprehension mode). In practice the instruct-tuned model continues into a `` section out of trained habit and never emits `` (which is only meaningful after the full `......` sequence the generation pipeline runs through). Add `` (128024) to the stop_token_ids for i2t and t2t. This makes greedy IT2I AR-only output align with HF's `bot_task="think"` baseline: | | HF baseline | omni before | omni after | |-------------|------------:|------------:|-----------:| | chars | 466 | 811 | 482 | | bytes | 1354 | 2375 | 1416 | | sections | think | think+recap | think | | gap to HF | 0 | +345 | +16 | Length divergence collapses from +74% to +3.4%. The remaining +16 chars sits in BF16 reduction noise / sampler implementation differences (documented in `memory/hf_omni_alignment_method.md`) and cannot be closed without reimplementing vllm's attention. `hunyuan_image3_it2i.yaml` is intentionally NOT changed: the IT2I pipeline (`is_comprehension: false`, AR -> DiT) needs the AR stage to emit the full `...` sequence so that DiT can decode the image latents. Stopping at `` there would break image generation. 
Update the existing comment in `HunyuanImage3ForConditionalGeneration.__init__` that incorrectly claimed comprehension mode would stop at `` or EOS, so future readers understand why we explicitly stop at `` in the yaml. Verified end-to-end: greedy IT2I AR-only output now ends cleanly at the analysis section, byte-for-byte structurally aligned with HF's `bot_task="think"` output (only differs in BF16-noise-driven per-token text divergence, no extra recaption section). Signed-off-by: TaffyOfficial <2324465096@qq.com> --- .../models/hunyuan_image3/hunyuan_image3.py | 10 ++++++++-- .../stage_configs/hunyuan_image3_i2t.yaml | 2 +- .../stage_configs/hunyuan_image3_t2t.yaml | 2 +- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py index 0c33a3a246e..5e48d24cfa2 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py @@ -1515,8 +1515,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # For comprehension mode, block image generation tokens but allow # text structure tokens (, , etc.) so the model can - # follow its natural generation pattern. Stop tokens in YAML will - # terminate at or EOS. + # follow its natural generation pattern. The yaml stop_token_ids + # for i2t/t2t now includes (128024) so the AR-only output + # terminates after the analysis section, matching HF's + # `bot_task="think"` behavior. Without that stop, the model + # continues into a recaption section even in comprehension mode + # (the stage-transition processor only fires in generation mode, + # but the instruct-tuned model writes recaption on its own from + # internal habit). 
self._blocked_token_ids: set[int] = set() if self._is_comprehension: self._blocked_token_ids.update( diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml index b68b184ec31..0614a9f1179 100644 --- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml +++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml @@ -34,7 +34,7 @@ stage_args: top_p: 0.95 top_k: 1024 max_tokens: 2048 - stop_token_ids: [127957, 128026] # <|endoftext|>, + stop_token_ids: [127957, 128024, 128026] # <|endoftext|>, , detokenize: True runtime: diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml index a0a1a0dc1c4..c9daa5e5f39 100644 --- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml +++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml @@ -35,7 +35,7 @@ stage_args: top_p: 0.95 top_k: 1024 max_tokens: 2048 - stop_token_ids: [127957, 128026] # <|endoftext|>, + stop_token_ids: [127957, 128024, 128026] # <|endoftext|>, , detokenize: True runtime: From 6978fd77ce2f5cc5ac1ed32597e9253daf3c8e9c Mon Sep 17 00:00:00 2001 From: zuiho-kai <31877877+zuiho-kai@users.noreply.github.com> Date: Thu, 30 Apr 2026 13:41:39 +0800 Subject: [PATCH 11/16] fix(hunyuan_image3): import SharedFusedMoE Signed-off-by: zuiho-kai <31877877+zuiho-kai@users.noreply.github.com> --- .../models/hunyuan_image3/hunyuan_image3.py | 24 ++++++------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py index 5e48d24cfa2..d80797d432c 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py @@ -29,6 +29,7 @@ from vllm.inputs import MultiModalDataDict from vllm.logger import 
init_logger from vllm.model_executor.layers.fused_moe import fused_moe_make_expert_params_mapping +from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.linear import ( ColumnParallelLinear, ReplicatedLinear, @@ -1194,8 +1195,7 @@ def __init__( if self.tp_size > config.num_experts: raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}." + f"Tensor parallel size {self.tp_size} is greater than the number of experts {config.num_experts}." ) if isinstance(config.moe_topk, list): @@ -1220,9 +1220,7 @@ def __init__( self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts self.n_local_physical_experts = self.n_physical_experts // self.ep_size self.physical_expert_start = self.ep_rank * self.n_local_physical_experts - self.physical_expert_end = ( - self.physical_expert_start + self.n_local_physical_experts - ) + self.physical_expert_end = self.physical_expert_start + self.n_local_physical_experts # FP32 router gate (HF: ``wg = nn.Linear(..., dtype=torch.float32)``). self.gate = ReplicatedLinear( @@ -1298,13 +1296,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # ``_hunyuan_image3_unpack_packed_topk`` can pull them back out # inside ``SharedFusedMoE``. Both halves are stored as fp32 for # transport — the indices get cast back to int32 on unpack. 
- packed_routing = torch.cat( - [topk_weights.float(), topk_indices.to(torch.float32)], dim=-1 - ) + packed_routing = torch.cat([topk_weights.float(), topk_indices.to(torch.float32)], dim=-1) - final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=packed_routing - ) + final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=packed_routing) if self.shared_mlp is not None: final_hidden_states = final_hidden_states[0] + final_hidden_states[1] if self.tp_size > 1: @@ -1574,16 +1568,12 @@ def _patch_moe_blocks(self): """ if not _is_moe(self.config): return - enable_eplb = getattr( - self.vllm_config.parallel_config, "enable_eplb", False - ) + enable_eplb = getattr(self.vllm_config.parallel_config, "enable_eplb", False) ccfg = self.vllm_config.compilation_config replaced = 0 for layer_id, layer in enumerate(self.model.layers): mlp = getattr(layer, "mlp", None) - if isinstance(mlp, HunYuanSparseMoeBlock) and not isinstance( - mlp, HunyuanImage3SparseMoeBlock - ): + if isinstance(mlp, HunYuanSparseMoeBlock) and not isinstance(mlp, HunyuanImage3SparseMoeBlock): # Pop the OLD experts' registration from # ``static_forward_context`` first — otherwise the new # ``SharedFusedMoE`` built inside From fd4793e09f876c3dc9aab7cd8d013f174557368e Mon Sep 17 00:00:00 2001 From: TaffyOfficial <2324465096@qq.com> Date: Thu, 30 Apr 2026 15:54:09 +0800 Subject: [PATCH 12/16] refactor(hunyuan_image3): extract prompt_utils as shared builder Move build_prompt and build_prompt_tokens out of the example script into vllm_omni/diffusion/models/hunyuan_image3/prompt_utils.py so the AR-prefill prompt template has a single source of truth that downstream callers can reuse. The DiT pipeline keeps using TokenizerWrapper.apply_chat_template (which eagerly consumes JointImageInfo); prompt_utils targets the lighter client-side flow that uses an placeholder + multi_modal_data. 
README is updated to describe the actual instruct chat template (the previous "pretrain template" wording was stale relative to the post-fix behavior introduced earlier in this PR) and to point at the new module. Addresses GH PR review comment requesting a common prompt-construction function shared across AR / DiT / end2end.py. Signed-off-by: TaffyOfficial <2324465096@qq.com> --- .../hunyuan_image3/README.md | 12 +- .../hunyuan_image3/end2end.py | 128 +-------------- .../models/hunyuan_image3/prompt_utils.py | 152 ++++++++++++++++++ 3 files changed, 161 insertions(+), 131 deletions(-) create mode 100644 vllm_omni/diffusion/models/hunyuan_image3/prompt_utils.py diff --git a/examples/offline_inference/hunyuan_image3/README.md b/examples/offline_inference/hunyuan_image3/README.md index 3cd8fa01b2e..3eb3bfbff6f 100644 --- a/examples/offline_inference/hunyuan_image3/README.md +++ b/examples/offline_inference/hunyuan_image3/README.md @@ -135,17 +135,19 @@ python end2end.py --model tencent/HunyuanImage-3.0-Instruct \ ## Prompt Format -HunyuanImage-3.0 uses a pretrain template format: +HunyuanImage-3.0-Instruct uses an instruct chat template: ``` -<|startoftext|>{system_prompt}{}{trigger_tag}{user_prompt} +<|startoftext|>{system_prompt}\n\nUser: {?}{user_prompt}\n\nAssistant: {trigger_tag?} ``` -- ``: Placeholder for each input image (auto-inserted by `prompt_utils.py`) -- Trigger tags: `` (CoT), `` (recaptioning) +- ``: Placeholder for each input image (single token; expanded by the multimodal pipeline) +- Trigger tags: `` (CoT), `` (recaptioning) — placed AFTER `Assistant: ` - System prompt: Auto-selected based on task +- `t2i_vanilla` is the only task that uses the bare pretrain template (no chat structure) -The `prompt_utils.build_prompt()` handles this formatting automatically. +The shared `vllm_omni.diffusion.models.hunyuan_image3.prompt_utils.build_prompt_tokens()` +helper handles segment-by-segment tokenization (matches HF `apply_chat_template` byte-for-byte). 
------ diff --git a/examples/offline_inference/hunyuan_image3/end2end.py b/examples/offline_inference/hunyuan_image3/end2end.py index ac12ef2cfee..fc4b75e78d5 100644 --- a/examples/offline_inference/hunyuan_image3/end2end.py +++ b/examples/offline_inference/hunyuan_image3/end2end.py @@ -16,23 +16,12 @@ import argparse import os -from vllm_omni.diffusion.models.hunyuan_image3.system_prompt import ( - get_system_prompt, +from vllm_omni.diffusion.models.hunyuan_image3.prompt_utils import ( + build_prompt_tokens, ) from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniPromptType -# task → (sys_type, bot_task, trigger_tag) -_TASK_PRESETS: dict[str, tuple[str, str | None, str | None]] = { - "t2t": ("en_unified", None, None), - "i2t": ("en_unified", None, None), - "it2i_think": ("en_unified", "think", ""), - "it2i_recaption": ("en_unified", "recaption", ""), - "t2i_think": ("en_unified", "think", ""), - "t2i_recaption": ("en_unified", "recaption", ""), - "t2i_vanilla": ("en_vanilla", "image", None), -} - # Modality → prompt_utils task mapping _MODALITY_TASK_MAP = { "text2img": "t2i_think", @@ -42,119 +31,6 @@ } -def build_prompt_tokens( - user_prompt: str, - tokenizer, - task: str = "it2i_think", - sys_type: str | None = None, - custom_system_prompt: str | None = None, -) -> list[int]: - """Segment-by-segment tokenization that matches HF apply_chat_template. - - Calling tokenizer.encode(build_prompt(...)) on the full string lets BPE - merge tokens across segment boundaries (e.g. user_prompt ends with `。` - and the next segment is `\\n\\n` -> they merge into a single token id - 3490 instead of HF's [1811, 271]). HF's apply_chat_template tokenizes - each segment independently and concatenates token_ids, so no cross- - boundary merge happens. We replicate that here and feed the result to - Omni via OmniTokensPrompt (prompt_token_ids). - """ - if task not in _TASK_PRESETS: - raise ValueError(f"Unknown task {task!r}. 
Choose from: {sorted(_TASK_PRESETS)}") - - preset_sys_type, preset_bot_task, trigger_tag = _TASK_PRESETS[task] - effective_sys_type = sys_type or preset_sys_type - - bos_id = tokenizer.convert_tokens_to_ids("<|startoftext|>") - img_id = tokenizer.convert_tokens_to_ids("") - trig_id = tokenizer.convert_tokens_to_ids(trigger_tag) if trigger_tag else None - - has_image_input = task.startswith("i2t") or task.startswith("it2i") - - # t2i_vanilla uses pretrain template with no chat structure; the vanilla - # system prompt drives the model directly. No segment boundaries to - # protect, fall back to whole-string encode. - if task == "t2i_vanilla": - s = build_prompt(user_prompt, task, sys_type, custom_system_prompt) - return tokenizer.encode(s, add_special_tokens=False) - - system_prompt = get_system_prompt(effective_sys_type, preset_bot_task, custom_system_prompt) - # Do NOT strip — HF apply_chat_template keeps the system prompt's - # natural trailing newline; stripping it would shift one token id. - sys_text = system_prompt or "" - - ids: list[int] = [bos_id] - if sys_text: - ids += tokenizer.encode(sys_text, add_special_tokens=False) - ids += tokenizer.encode("\n\n", add_special_tokens=False) - ids += tokenizer.encode("User: ", add_special_tokens=False) - if has_image_input: - ids += [img_id] - ids += tokenizer.encode(user_prompt, add_special_tokens=False) - ids += tokenizer.encode("\n\nAssistant: ", add_special_tokens=False) - if trig_id is not None: - ids += [trig_id] - return ids - - -def build_prompt( - user_prompt: str, - task: str = "it2i_think", - sys_type: str | None = None, - custom_system_prompt: str | None = None, -) -> str: - """Build a HunyuanImage-3.0 prompt as a string (legacy/compat path). - - NOTE: when this string is passed to the engine, the engine's tokenizer - will run a single BPE pass over the whole string, which can merge - tokens across segment boundaries (e.g. `。\\n\\n` -> id 3490). 
For - inputs that need to match HF baseline byte-for-byte, use - `build_prompt_tokens` instead and feed the result via prompt_token_ids. - """ - if task not in _TASK_PRESETS: - raise ValueError(f"Unknown task {task!r}. Choose from: {sorted(_TASK_PRESETS)}") - - preset_sys_type, preset_bot_task, trigger_tag = _TASK_PRESETS[task] - effective_sys_type = sys_type or preset_sys_type - - system_prompt = get_system_prompt(effective_sys_type, preset_bot_task, custom_system_prompt) - sys_text = system_prompt.strip() if system_prompt else "" - - has_image_input = task.startswith("i2t") or task.startswith("it2i") - - # t2i_vanilla: pretrain mode for direct text→image generation. The - # vanilla system prompt drives the model with no chat structure. - if task == "t2i_vanilla": - parts = ["<|startoftext|>"] - if sys_text: - parts.append(sys_text) - parts.append(user_prompt) - return "".join(parts) - - # All other tasks (t2t / i2t / t2i_think / t2i_recaption / - # it2i_think / it2i_recaption) use HunyuanImage3 Instruct chat template: - # <|startoftext|>{system?}\n\nUser: {?}{user_prompt}\n\nAssistant: {trigger?} - # generation_config.json declares sequence_template="instruct", so the - # AR prefill MUST use this template — verified to match HF's - # apply_chat_template output token-for-token (modulo BPE boundary merges). - # The trigger_tag (e.g. ) MUST come AFTER the `Assistant: ` prefix: - # if it goes BEFORE user_prompt (the old pretrain layout) the model puts - # the user's instructions inside the "thinking section" and collapses - # into repetition garbage under greedy decoding. 
- parts = ["<|startoftext|>"] - if sys_text: - parts.append(f"{sys_text}\n\n") - parts.append("User: ") - if has_image_input: - parts.append("") - parts.append(user_prompt) - parts.append("\n\nAssistant: ") - if trigger_tag: - parts.append(trigger_tag) - - return "".join(parts) - - # Modality → default stage config _MODALITY_DEFAULT_CONFIG = { "text2img": "hunyuan_image3_t2i.yaml", diff --git a/vllm_omni/diffusion/models/hunyuan_image3/prompt_utils.py b/vllm_omni/diffusion/models/hunyuan_image3/prompt_utils.py new file mode 100644 index 00000000000..6e8efac3133 --- /dev/null +++ b/vllm_omni/diffusion/models/hunyuan_image3/prompt_utils.py @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Shared prompt-template construction for HunyuanImage-3.0-Instruct. + +Single source of truth for the AR-prefill prompt format used by the +example scripts and any downstream caller that needs to build +HunyuanImage3 chat-template token sequences without invoking the full +diffusion pipeline tokenizer wrapper. + +The DiT pipeline (`pipeline_hunyuan_image3.py`) builds prompts through +`TokenizerWrapper.apply_chat_template`, which eagerly consumes +`JointImageInfo` objects produced by image preprocessing. The example +flow uses an `` placeholder + `multi_modal_data` instead, so it +needs a lighter-weight builder that only requires a HF tokenizer. This +module provides that builder; the task -> template mapping below is the +canonical mapping for both flows. 
+""" + +from __future__ import annotations + +from .system_prompt import get_system_prompt + +# task -> (sys_type, bot_task, trigger_tag) +_TASK_PRESETS: dict[str, tuple[str, str | None, str | None]] = { + "t2t": ("en_unified", None, None), + "i2t": ("en_unified", None, None), + "it2i_think": ("en_unified", "think", ""), + "it2i_recaption": ("en_unified", "recaption", ""), + "t2i_think": ("en_unified", "think", ""), + "t2i_recaption": ("en_unified", "recaption", ""), + "t2i_vanilla": ("en_vanilla", "image", None), +} + + +def available_tasks() -> list[str]: + """Sorted list of task keys accepted by `build_prompt` / `build_prompt_tokens`.""" + return sorted(_TASK_PRESETS) + + +def build_prompt( + user_prompt: str, + task: str = "it2i_think", + sys_type: str | None = None, + custom_system_prompt: str | None = None, +) -> str: + """Build a HunyuanImage-3.0 prompt as a string (legacy/compat path). + + NOTE: when this string is passed to the engine, the engine's tokenizer + will run a single BPE pass over the whole string, which can merge + tokens across segment boundaries (e.g. `。\\n\\n` -> id 3490). For + inputs that need to match HF baseline byte-for-byte, use + `build_prompt_tokens` instead and feed the result via prompt_token_ids. + """ + if task not in _TASK_PRESETS: + raise ValueError(f"Unknown task {task!r}. Choose from: {available_tasks()}") + + preset_sys_type, preset_bot_task, trigger_tag = _TASK_PRESETS[task] + effective_sys_type = sys_type or preset_sys_type + + system_prompt = get_system_prompt(effective_sys_type, preset_bot_task, custom_system_prompt) + sys_text = system_prompt.strip() if system_prompt else "" + + has_image_input = task.startswith("i2t") or task.startswith("it2i") + + # t2i_vanilla: pretrain mode for direct text->image generation. The + # vanilla system prompt drives the model with no chat structure. 
+ if task == "t2i_vanilla": + parts = ["<|startoftext|>"] + if sys_text: + parts.append(sys_text) + parts.append(user_prompt) + return "".join(parts) + + # All other tasks (t2t / i2t / t2i_think / t2i_recaption / + # it2i_think / it2i_recaption) use HunyuanImage3 Instruct chat template: + # <|startoftext|>{system?}\n\nUser: {?}{user_prompt}\n\nAssistant: {trigger?} + # generation_config.json declares sequence_template="instruct", so the + # AR prefill MUST use this template -- verified to match HF's + # apply_chat_template output token-for-token (modulo BPE boundary merges). + # The trigger_tag (e.g. ) MUST come AFTER the `Assistant: ` prefix: + # if it goes BEFORE user_prompt (the old pretrain layout) the model puts + # the user's instructions inside the "thinking section" and collapses + # into repetition garbage under greedy decoding. + parts = ["<|startoftext|>"] + if sys_text: + parts.append(f"{sys_text}\n\n") + parts.append("User: ") + if has_image_input: + parts.append("") + parts.append(user_prompt) + parts.append("\n\nAssistant: ") + if trigger_tag: + parts.append(trigger_tag) + + return "".join(parts) + + +def build_prompt_tokens( + user_prompt: str, + tokenizer, + task: str = "it2i_think", + sys_type: str | None = None, + custom_system_prompt: str | None = None, +) -> list[int]: + """Segment-by-segment tokenization that matches HF apply_chat_template. + + Calling tokenizer.encode(build_prompt(...)) on the full string lets BPE + merge tokens across segment boundaries (e.g. user_prompt ends with `。` + and the next segment is `\\n\\n` -> they merge into a single token id + 3490 instead of HF's [1811, 271]). HF's apply_chat_template tokenizes + each segment independently and concatenates token_ids, so no cross- + boundary merge happens. We replicate that here and feed the result to + Omni via OmniTokensPrompt (prompt_token_ids). + """ + if task not in _TASK_PRESETS: + raise ValueError(f"Unknown task {task!r}. 
Choose from: {available_tasks()}") + + preset_sys_type, preset_bot_task, trigger_tag = _TASK_PRESETS[task] + effective_sys_type = sys_type or preset_sys_type + + bos_id = tokenizer.convert_tokens_to_ids("<|startoftext|>") + img_id = tokenizer.convert_tokens_to_ids("") + trig_id = tokenizer.convert_tokens_to_ids(trigger_tag) if trigger_tag else None + + has_image_input = task.startswith("i2t") or task.startswith("it2i") + + # t2i_vanilla uses pretrain template with no chat structure; the vanilla + # system prompt drives the model directly. No segment boundaries to + # protect, fall back to whole-string encode. + if task == "t2i_vanilla": + s = build_prompt(user_prompt, task, sys_type, custom_system_prompt) + return tokenizer.encode(s, add_special_tokens=False) + + system_prompt = get_system_prompt(effective_sys_type, preset_bot_task, custom_system_prompt) + # Do NOT strip -- HF apply_chat_template keeps the system prompt's + # natural trailing newline; stripping it would shift one token id. + sys_text = system_prompt or "" + + ids: list[int] = [bos_id] + if sys_text: + ids += tokenizer.encode(sys_text, add_special_tokens=False) + ids += tokenizer.encode("\n\n", add_special_tokens=False) + ids += tokenizer.encode("User: ", add_special_tokens=False) + if has_image_input: + ids += [img_id] + ids += tokenizer.encode(user_prompt, add_special_tokens=False) + ids += tokenizer.encode("\n\nAssistant: ", add_special_tokens=False) + if trig_id is not None: + ids += [trig_id] + return ids + + +__all__ = ["build_prompt", "build_prompt_tokens", "available_tasks"] From 36b14081cfe915e0d6fe3bb09409a8a4f37eeba0 Mon Sep 17 00:00:00 2001 From: TaffyOfficial <2324465096@qq.com> Date: Thu, 30 Apr 2026 16:09:35 +0800 Subject: [PATCH 13/16] test(hunyuan_image3): regression tests for AR prompt template (PR #3243) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three layers of protection for the bug fixed in this PR: 1. 
Pure-logic structural tests (FakeTokenizer-based) verify: - The chat template framing (<|startoftext|> ... User: ... Assistant: ...) - Trigger tag (<think> / <recaption>) is appended AFTER `Assistant: ` (Part A regression: putting the trigger BEFORE user_prompt sends the model into a death-loop under greedy decoding). - <img> placeholder is positioned correctly for image-input tasks. - Each prompt segment is encoded in an isolated tokenizer.encode() call so cross-segment BPE merges cannot occur (the bug from commit 7bd429ed). 2. AST-based wiring guard verifies that examples/.../end2end.py imports build_prompt_tokens from prompt_utils and does NOT redefine it locally. This protects the *delivery vector* of the original regression: the wrong template re-entered the example via a hand-rolled local builder that diverged from the canonical helper. 3. Real-tokenizer regression (skipped if HunyuanImage3 not in HF cache) asserts that segment-by-segment build_prompt_tokens produces a STRICTLY different id sequence than tokenizer.encode(build_prompt(...)) for a `。`-ending prompt. If a future "simplification" replaces segment encode with full-string encode, the BPE-merge-bypass behavior is gone and this test fires. Verified on remote (2x L20X, transformers 4.57.1, HunyuanImage3-Instruct tokenizer in HF cache): 19/19 passed including the real-tokenizer test. 
Signed-off-by: TaffyOfficial <2324465096@qq.com> --- .../hunyuan_image3/test_prompt_utils.py | 306 ++++++++++++++++++ 1 file changed, 306 insertions(+) create mode 100644 tests/diffusion/models/hunyuan_image3/test_prompt_utils.py diff --git a/tests/diffusion/models/hunyuan_image3/test_prompt_utils.py b/tests/diffusion/models/hunyuan_image3/test_prompt_utils.py new file mode 100644 index 00000000000..1010f393fed --- /dev/null +++ b/tests/diffusion/models/hunyuan_image3/test_prompt_utils.py @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Regression tests for HunyuanImage3 prompt construction (PR #3243). + +Two layers: + 1. Pure-logic tests with a recording fake tokenizer -- protect the + prompt template structure (BOS, User:/Assistant: framing, trigger + placement, image placeholder position) and protect the segment- + by-segment tokenization contract (each segment must hit + `tokenizer.encode` in isolation). + 2. Real-tokenizer regression -- run when the HunyuanImage3-Instruct + tokenizer is in the local HF cache. Asserts the segment-tokenized + output diverges from the naive full-string encode, which is the + bug-tripping fixture for the cross-segment BPE merge fix + (commit 7bd429ed). +""" + +from __future__ import annotations + +import ast +import os +import pathlib + +import pytest + +from vllm_omni.diffusion.models.hunyuan_image3.prompt_utils import ( + available_tasks, + build_prompt, + build_prompt_tokens, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +# -------------------- Pure-logic structural tests -------------------- + + +class FakeTokenizer: + """Minimal tokenizer stub that records every encode() call. + + Returns deterministic ids: special tokens map to small ints (1-4), + encode() returns one id per character starting at 100. This lets + tests both verify segmentation (by inspecting `encode_calls`) and + locate substrings inside the returned id list. 
+ """ + + SPECIAL = { + "<|startoftext|>": 1, + "": 2, + "": 3, + "": 4, + } + + def __init__(self) -> None: + self.encode_calls: list[str] = [] + + def convert_tokens_to_ids(self, tok: str) -> int: + return self.SPECIAL.get(tok, 0) + + def encode(self, text: str, add_special_tokens: bool = False) -> list[int]: + self.encode_calls.append(text) + return list(range(100, 100 + len(text))) + + +def test_available_tasks_covers_all_modalities(): + tasks = set(available_tasks()) + assert tasks >= { + "t2t", + "i2t", + "it2i_think", + "it2i_recaption", + "t2i_think", + "t2i_recaption", + "t2i_vanilla", + } + + +@pytest.mark.parametrize( + "task", + [ + "t2t", + "i2t", + "it2i_think", + "it2i_recaption", + "t2i_think", + "t2i_recaption", + ], +) +def test_build_prompt_string_structure_chat_template(task: str): + """Chat-template tasks must produce <|startoftext|>...User: ...Assistant: ... + with image placeholder (when applicable) and trigger tag AFTER `Assistant: `.""" + s = build_prompt("HELLO", task=task) + + assert s.startswith("<|startoftext|>") + assert "User: " in s + assert "Assistant: " in s + assert s.index("User: ") < s.index("HELLO") < s.index("Assistant: ") + + if task.startswith(("i2t", "it2i")): + assert s.index("User: ") < s.index("") < s.index("HELLO"), ( + " placeholder must sit between `User: ` and the user prompt" + ) + else: + assert "" not in s + + # Trigger tag must be the FINAL token of the prompt (after `Assistant: `). + # Note: the system prompt itself mentions / as mode + # documentation, so substring index() catches the wrong occurrence -- use + # endswith() which directly captures "trigger is at the tail" (the Part A + # fix: trigger goes AFTER `Assistant: `, not before user_prompt). + if task in ("it2i_think", "t2i_think"): + assert s.endswith("Assistant: "), ( + "Trigger must be appended right after `Assistant: ` (Part A fix). 
" + f"Got tail: ...{s[-40:]!r}" + ) + if task in ("it2i_recaption", "t2i_recaption"): + assert s.endswith("Assistant: "), ( + "Trigger must be appended right after `Assistant: ` (Part A fix). " + f"Got tail: ...{s[-40:]!r}" + ) + if task in ("t2t", "i2t"): + assert s.endswith("Assistant: "), ( + "Plain (no-trigger) task must end at `Assistant: ` with no trailing tag." + ) + + +def test_build_prompt_vanilla_uses_pretrain_template(): + """t2i_vanilla is the only task that bypasses chat structure -- direct + text->image generation driven by the vanilla system prompt.""" + s = build_prompt("HELLO", task="t2i_vanilla") + assert s.startswith("<|startoftext|>") + assert "User: " not in s + assert "Assistant: " not in s + assert "" not in s + assert "" not in s + assert s.endswith("HELLO") + + +def test_build_prompt_unknown_task_raises(): + with pytest.raises(ValueError, match="Unknown task"): + build_prompt("x", task="bogus") + with pytest.raises(ValueError, match="Unknown task"): + build_prompt_tokens("x", FakeTokenizer(), task="bogus") + + +def test_build_prompt_tokens_segments_each_boundary(): + """Regression for cross-segment BPE merge bug (commit 7bd429ed): + each template segment must hit tokenizer.encode() independently; + user_prompt MUST NOT be concatenated with the following separator + in the same encode() call.""" + tok = FakeTokenizer() + build_prompt_tokens("写诗。", tok, task="i2t") + + # Each canonical segment is encoded in its own call. + assert "User: " in tok.encode_calls + assert "写诗。" in tok.encode_calls, ( + "user_prompt must be encoded alone -- if it is concatenated with the " + "trailing separator, BPE will merge across the boundary (the PR-#3243 bug)." + ) + assert "\n\nAssistant: " in tok.encode_calls + + # No call must contain user_prompt glued to neighboring text. 
+ for call in tok.encode_calls: + if call != "写诗。": + assert "写诗。" not in call, ( + f"user_prompt leaked into a multi-segment encode call: {call!r}" + ) + + +def test_build_prompt_tokens_image_placeholder_present_for_image_tasks(): + tok = FakeTokenizer() + ids = build_prompt_tokens("hi", tok, task="i2t") + assert ids[0] == 1, "BOS (<|startoftext|>) must be the first token" + assert 2 in ids, " placeholder must be present for i2t/it2i tasks" + + +def test_build_prompt_tokens_no_image_for_text_only_tasks(): + tok = FakeTokenizer() + ids = build_prompt_tokens("hi", tok, task="t2t") + assert 2 not in ids, " must NOT appear for text-only tasks" + + +@pytest.mark.parametrize( + "task,trigger_id", + [("it2i_think", 3), ("t2i_think", 3), ("it2i_recaption", 4), ("t2i_recaption", 4)], +) +def test_build_prompt_tokens_trigger_is_last_token(task: str, trigger_id: int): + """Trigger tag id must be the LAST token (after `Assistant: ` segment).""" + tok = FakeTokenizer() + ids = build_prompt_tokens("hi", tok, task=task) + assert ids[-1] == trigger_id + + +def test_build_prompt_tokens_no_trigger_for_plain_tasks(): + """Tasks without trigger_tag (t2t / i2t) must NOT append a trigger id.""" + tok = FakeTokenizer() + ids = build_prompt_tokens("hi", tok, task="t2t") + assert ids[-1] not in {3, 4} # neither nor + + +# -------------------- end2end.py wiring guard -------------------- + + +def _repo_root() -> pathlib.Path: + # tests/diffusion/models/hunyuan_image3/test_prompt_utils.py -> repo root + return pathlib.Path(__file__).resolve().parents[4] + + +def test_end2end_routes_through_shared_prompt_utils(): + """Regression for the *delivery vector* of PR #3243. + + Background: the wrong-template bug that PR #3243 fixes was introduced + when end2end.py grew its own hand-rolled prompt builder that diverged + from the canonical instruct chat template. To prevent that exact + failure mode from recurring, end2end.py MUST: + 1. Import the prompt builders from the shared prompt_utils module. 
+ 2. NOT redefine `build_prompt` or `build_prompt_tokens` locally. + + A local redefinition is precisely how a future merge can silently + re-introduce a pretrain-style template (trigger BEFORE user_prompt, + no User:/Assistant: framing, etc.) without touching prompt_utils, + bypassing every other test in this file. + """ + end2end_path = ( + _repo_root() / "examples" / "offline_inference" / "hunyuan_image3" / "end2end.py" + ) + assert end2end_path.is_file(), f"end2end.py not found at {end2end_path}" + + tree = ast.parse(end2end_path.read_text(encoding="utf-8")) + + local_func_names = {n.name for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)} + forbidden = {"build_prompt", "build_prompt_tokens"} + redefined = local_func_names & forbidden + assert not redefined, ( + f"end2end.py defines {sorted(redefined)} locally. This is exactly how " + "the wrong prompt template re-entered the example before PR #3243. " + "Use the shared `vllm_omni.diffusion.models.hunyuan_image3.prompt_utils` " + "helpers instead." + ) + + imported_from_prompt_utils: set[str] = set() + for node in ast.walk(tree): + if ( + isinstance(node, ast.ImportFrom) + and node.module + and node.module.endswith("hunyuan_image3.prompt_utils") + ): + imported_from_prompt_utils.update(alias.name for alias in node.names) + assert "build_prompt_tokens" in imported_from_prompt_utils, ( + "end2end.py must import build_prompt_tokens from " + "vllm_omni.diffusion.models.hunyuan_image3.prompt_utils -- the shared " + "helper is the single source of truth for the AR-prefill template." 
+ ) + + +# -------------------- Real-tokenizer regression -------------------- + + +_HUNYUAN_MODEL_ID = "tencent/HunyuanImage-3.0-Instruct" + + +def _hf_cached(model_id: str) -> bool: + hf_home = os.environ.get("HF_HOME") or os.path.expanduser("~/.cache/huggingface") + snap_dir = os.path.join( + hf_home, "hub", f"models--{model_id.replace('/', '--')}", "snapshots" + ) + return os.path.isdir(snap_dir) and any(os.scandir(snap_dir)) + + +@pytest.mark.skipif( + not _hf_cached(_HUNYUAN_MODEL_ID), + reason=f"{_HUNYUAN_MODEL_ID} tokenizer not in HF cache", +) +def test_segment_tokenize_diverges_from_full_string_encode(): + """Regression for PR #3243 segment-tokenization fix. + + The naive `tokenizer.encode(build_prompt(...))` lets BPE merge tokens + across segment boundaries (notably `。\\n\\n` -> a single id), which + drifts the AR prefill away from HF's apply_chat_template output. The + segment-by-segment build_prompt_tokens must produce a STRICTLY + DIFFERENT id sequence on a prompt that triggers the merge. + + If someone "simplifies" build_prompt_tokens to call encode() on the + full string, this assertion fires. + """ + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(_HUNYUAN_MODEL_ID, trust_remote_code=True) + + user_prompt = "写一首关于夜的诗。" + seg_ids = build_prompt_tokens(user_prompt, tok, task="i2t") + full_ids = tok.encode(build_prompt(user_prompt, task="i2t"), add_special_tokens=False) + + assert seg_ids != full_ids, ( + "build_prompt_tokens output equals naive full-string encode -- " + "the BPE-merge-bypass behavior is no longer exercised. This means " + "the segment-by-segment fix from PR #3243 has been silently undone." + ) + + # Segmenting prevents merges, so the segment id list should have AT LEAST + # as many tokens as the merged version (a merge consumes 2+ ids -> 1). 
+ assert len(seg_ids) >= len(full_ids), ( + f"segment-encoded length ({len(seg_ids)}) shorter than full-string " + f"merged length ({len(full_ids)}) -- impossible if segmentation is " + f"genuinely bypassing merges." + ) From 78b817342c5d9d14b42622c66f3d2e76f7fd31ab Mon Sep 17 00:00:00 2001 From: TaffyOfficial <2324465096@qq.com> Date: Thu, 30 Apr 2026 16:39:44 +0800 Subject: [PATCH 14/16] test(hunyuan_image3): apply ruff format to prompt_utils regression tests CI ruff format check required collapsing short multi-line constructs (single-line assertions, single-line if conditions) onto one line. No semantic change; 19/19 tests still pass. Signed-off-by: TaffyOfficial <2324465096@qq.com> --- .../hunyuan_image3/test_prompt_utils.py | 28 +++++-------------- 1 file changed, 7 insertions(+), 21 deletions(-) diff --git a/tests/diffusion/models/hunyuan_image3/test_prompt_utils.py b/tests/diffusion/models/hunyuan_image3/test_prompt_utils.py index 1010f393fed..501664fe688 100644 --- a/tests/diffusion/models/hunyuan_image3/test_prompt_utils.py +++ b/tests/diffusion/models/hunyuan_image3/test_prompt_utils.py @@ -110,18 +110,14 @@ def test_build_prompt_string_structure_chat_template(task: str): # fix: trigger goes AFTER `Assistant: `, not before user_prompt). if task in ("it2i_think", "t2i_think"): assert s.endswith("Assistant: "), ( - "Trigger must be appended right after `Assistant: ` (Part A fix). " - f"Got tail: ...{s[-40:]!r}" + f"Trigger must be appended right after `Assistant: ` (Part A fix). Got tail: ...{s[-40:]!r}" ) if task in ("it2i_recaption", "t2i_recaption"): assert s.endswith("Assistant: "), ( - "Trigger must be appended right after `Assistant: ` (Part A fix). " - f"Got tail: ...{s[-40:]!r}" + f"Trigger must be appended right after `Assistant: ` (Part A fix). Got tail: ...{s[-40:]!r}" ) if task in ("t2t", "i2t"): - assert s.endswith("Assistant: "), ( - "Plain (no-trigger) task must end at `Assistant: ` with no trailing tag." 
- ) + assert s.endswith("Assistant: "), "Plain (no-trigger) task must end at `Assistant: ` with no trailing tag." def test_build_prompt_vanilla_uses_pretrain_template(): @@ -162,9 +158,7 @@ def test_build_prompt_tokens_segments_each_boundary(): # No call must contain user_prompt glued to neighboring text. for call in tok.encode_calls: if call != "写诗。": - assert "写诗。" not in call, ( - f"user_prompt leaked into a multi-segment encode call: {call!r}" - ) + assert "写诗。" not in call, f"user_prompt leaked into a multi-segment encode call: {call!r}" def test_build_prompt_tokens_image_placeholder_present_for_image_tasks(): @@ -221,9 +215,7 @@ def test_end2end_routes_through_shared_prompt_utils(): no User:/Assistant: framing, etc.) without touching prompt_utils, bypassing every other test in this file. """ - end2end_path = ( - _repo_root() / "examples" / "offline_inference" / "hunyuan_image3" / "end2end.py" - ) + end2end_path = _repo_root() / "examples" / "offline_inference" / "hunyuan_image3" / "end2end.py" assert end2end_path.is_file(), f"end2end.py not found at {end2end_path}" tree = ast.parse(end2end_path.read_text(encoding="utf-8")) @@ -240,11 +232,7 @@ def test_end2end_routes_through_shared_prompt_utils(): imported_from_prompt_utils: set[str] = set() for node in ast.walk(tree): - if ( - isinstance(node, ast.ImportFrom) - and node.module - and node.module.endswith("hunyuan_image3.prompt_utils") - ): + if isinstance(node, ast.ImportFrom) and node.module and node.module.endswith("hunyuan_image3.prompt_utils"): imported_from_prompt_utils.update(alias.name for alias in node.names) assert "build_prompt_tokens" in imported_from_prompt_utils, ( "end2end.py must import build_prompt_tokens from " @@ -261,9 +249,7 @@ def test_end2end_routes_through_shared_prompt_utils(): def _hf_cached(model_id: str) -> bool: hf_home = os.environ.get("HF_HOME") or os.path.expanduser("~/.cache/huggingface") - snap_dir = os.path.join( - hf_home, "hub", f"models--{model_id.replace('/', '--')}", 
"snapshots" - ) + snap_dir = os.path.join(hf_home, "hub", f"models--{model_id.replace('/', '--')}", "snapshots") return os.path.isdir(snap_dir) and any(os.scandir(snap_dir)) From c5089d1258552f5ed8febf974949a5331bf34aa0 Mon Sep 17 00:00:00 2001 From: zuiho <2324465096@qq.com> Date: Tue, 5 May 2026 21:11:26 +0800 Subject: [PATCH 15/16] fix(hunyuan_image3): adapt to vllm 0.20 SharedFusedMoE removal MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR #3232 [Rebase] Rebase to vllm 0.20.0 folded `SharedFusedMoE` into `FusedMoE` and dropped the `vllm.model_executor.layers.fused_moe.shared_fused_moe` submodule, which broke pytest collection for tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_sampler.py with `ModuleNotFoundError: No module named 'vllm.model_executor.layers.fused_moe.shared_fused_moe'` across all four CI suites on this branch (simple-unit-test, diffusion-cache-backend-test, cuda-unit-test-with-single-card, cuda-unit-test-with-multi-cards). Mirrors the same minimal fix applied on cr/pr3107-rebased: - Wrap the legacy import in try/except and fall back to `FusedMoE as SharedFusedMoE`. `FusedMoE` now accepts `shared_experts=` directly and the call sites only use `make_expert_params_mapping` and `__init__(shared_experts=..., ...)`, both present on `FusedMoE`. - Drop `reduce_results=False` from the `SharedFusedMoE(...)` call — vllm 0.20 removed that parameter from `FusedMoE.__init__`. - Drop the manual `(routed, shared)` tuple merge and `tensor_model_parallel_all_reduce` post-processing in `HunyuanImage3SparseMoeBlock.forward`. vllm 0.20+ `FusedMoE` merges shared-experts internally and runs the TP all-reduce inside its forward, so the result is the already-combined, already-reduced tensor. 
Signed-off-by: zuiho <2324465096@qq.com> --- .../models/hunyuan_image3/hunyuan_image3.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py index d80797d432c..2cc154db6dc 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py @@ -29,7 +29,16 @@ from vllm.inputs import MultiModalDataDict from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import fused_moe_make_expert_params_mapping -from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE + +try: + from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE +except ImportError: + # PyPI vllm 0.20.x neither exports `SharedFusedMoE` from the package top-level + # nor ships a `shared_fused_moe.py` submodule. The functionality lives on + # `FusedMoE` directly (which gained a `shared_experts` parameter), so alias + # the symbol — call sites only use the classmethod `make_expert_params_mapping` + # and `__init__(shared_experts=..., ...)` which are present on `FusedMoE`. + from vllm.model_executor.layers.fused_moe import FusedMoE as SharedFusedMoE from vllm.model_executor.layers.linear import ( ColumnParallelLinear, ReplicatedLinear, @@ -1261,7 +1270,6 @@ def __init__( top_k=top_k, hidden_size=config.hidden_size, intermediate_size=intermediate_size, - reduce_results=False, renormalize=False, quant_config=quant_config, prefix=f"{prefix}.experts", @@ -1298,11 +1306,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # transport — the indices get cast back to int32 on unpack. 
packed_routing = torch.cat([topk_weights.float(), topk_indices.to(torch.float32)], dim=-1) + # vllm 0.20+ FusedMoE merges shared-experts internally and runs the + # TP all-reduce inside its forward (we no longer pass + # `reduce_results=False`). The tuple `(routed, shared)` return shape + # from the legacy SharedFusedMoE is gone; the result is the + # already-combined, already-reduced tensor. final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=packed_routing) - if self.shared_mlp is not None: - final_hidden_states = final_hidden_states[0] + final_hidden_states[1] - if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(orig_shape) From 0ec3d173aa35a1a0138274084914ed5b96fd9b2d Mon Sep 17 00:00:00 2001 From: zuiho <2324465096@qq.com> Date: Tue, 5 May 2026 21:15:19 +0800 Subject: [PATCH 16/16] fix(hunyuan_image3): drop unused tensor_model_parallel_all_reduce import Lint follow-up to c5089d12. The previous commit removed the call to `tensor_model_parallel_all_reduce` in `HunyuanImage3SparseMoeBlock.forward` (vllm 0.20+ FusedMoE runs the TP all-reduce internally) but left the symbol in the `from vllm.distributed import (...)` block, which `ruff` flags as unused (F401). Signed-off-by: zuiho <2324465096@qq.com> --- vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py index 2cc154db6dc..d2552bdddbf 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py @@ -24,7 +24,6 @@ get_ep_group, get_pp_group, get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce, ) from vllm.inputs import MultiModalDataDict from vllm.logger import init_logger