diff --git a/examples/offline_inference/hunyuan_image3/prompt_utils.py b/examples/offline_inference/hunyuan_image3/prompt_utils.py
new file mode 100644
index 0000000000..a5ef8e1536
--- /dev/null
+++ b/examples/offline_inference/hunyuan_image3/prompt_utils.py
@@ -0,0 +1,88 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Prompt construction utilities for HunyuanImage-3.0-Instruct examples.
+
+Wraps system_prompt.get_system_prompt() with task-aware presets so that
+examples and tests don't need to manually concatenate system prompts,
+<think>, <recaption>, and image placeholder tags.
+
+Usage:
+ from prompt_utils import build_prompt
+
+ # IT2I (image editing, think+recaption mode)
+ prompt = build_prompt("Make the petals neon pink", task="it2i_think")
+
+ # I2T (image understanding)
+ prompt = build_prompt("Describe the content of the picture.", task="i2t")
+"""
+
+from __future__ import annotations
+
+from vllm_omni.diffusion.models.hunyuan_image3.system_prompt import (
+ get_system_prompt,
+)
+
+# task → (sys_type, bot_task, trigger_tag)
+# trigger_tag: "<think>", "<recaption>", or None
+_TASK_PRESETS: dict[str, tuple[str, str | None, str | None]] = {
+ # Pure text generation (text → text, no image)
+ "t2t": ("en_unified", None, None),
+ # Image understanding (image → text)
+ "i2t": ("en_unified", None, None),
+ # Image editing (image+text → image), think+recaption mode
+ "it2i_think": ("en_unified", "think", ""),
+ # Image editing, recaption-only mode
+ "it2i_recaption": ("en_unified", "recaption", ""),
+ # Text-to-image, think mode
+ "t2i_think": ("en_unified", "think", ""),
+ # Text-to-image, recaption mode
+ "t2i_recaption": ("en_unified", "recaption", ""),
+ # Text-to-image, vanilla (no CoT)
+ "t2i_vanilla": ("en_vanilla", "image", None),
+}
+
+
+def build_prompt(
+ user_prompt: str,
+ task: str = "it2i_think",
+ sys_type: str | None = None,
+ custom_system_prompt: str | None = None,
+) -> str:
+ """Build a complete HunyuanImage-3.0 prompt with auto-selected system
+ prompt and mode trigger tags.
+
+ Args:
+ user_prompt: The user's raw instruction or question.
+ task: One of the preset task keys (see _TASK_PRESETS).
+ sys_type: Override the preset's sys_type for get_system_prompt().
+ custom_system_prompt: Custom system prompt text (used when
+ sys_type="custom").
+
+ Returns:
+ Fully formatted prompt string ready for Omni.generate().
+ """
+ if task not in _TASK_PRESETS:
+ raise ValueError(f"Unknown task {task!r}. Choose from: {sorted(_TASK_PRESETS)}")
+
+ preset_sys_type, preset_bot_task, trigger_tag = _TASK_PRESETS[task]
+ effective_sys_type = sys_type or preset_sys_type
+
+ system_prompt = get_system_prompt(effective_sys_type, preset_bot_task, custom_system_prompt)
+ sys_text = system_prompt.strip() if system_prompt else ""
+
+ has_image_input = task.startswith("i2t") or task.startswith("it2i")
+
+ parts = ["<|startoftext|>"]
+ if sys_text:
+ parts.append(sys_text)
+ # Instruct conversation template: \n\nUser: ... \n\nAssistant:
+ parts.append("\n\nUser: ")
+ if has_image_input:
+ parts.append("
")
+ parts.append(user_prompt)
+ parts.append("\n\nAssistant: ")
+ if trigger_tag:
+ parts.append(trigger_tag)
+
+ return "".join(parts)
diff --git a/tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_sampler.py b/tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_sampler.py
new file mode 100644
index 0000000000..51f6a85f58
--- /dev/null
+++ b/tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_sampler.py
@@ -0,0 +1,190 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for HunyuanImage3 AR sampler logic (stage transitions,
+ratio restriction, comprehension blocking)."""
+
+import pytest
+import torch
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+# Fake token IDs for testing (avoid importing the real model).
+END_OF_THINK = 100
+RECAPTION = 101
+END_OF_RECAPTION = 102
+ANSWER = 103
+BOI = 104
+SIZE_TOKEN = 105
+EOS = 106
+RATIO_START = 200
+RATIO_END = 210
+RATIO_OTHER_START = 220
+RATIO_OTHER_END = 223
+
+
+class FakeSamplerModel:
+ """Minimal stub that replicates the sampler-relevant attributes of
+ HunyuanImage3ForConditionalGeneration without loading real weights."""
+
+ def __init__(self, *, is_comprehension: bool = False):
+ self._is_comprehension = is_comprehension
+ self._eos_token_id = EOS
+ self._end_of_think_id = END_OF_THINK
+ self._recaption_id = RECAPTION
+ self._end_of_recaption_id = END_OF_RECAPTION
+ self._answer_id = ANSWER
+ self._mrope_boi_token_id = BOI
+ self._size_token_id = SIZE_TOKEN
+ self._start_ratio_id = RATIO_START
+ self._end_ratio_id = RATIO_END
+ self._ratio_other_slices = [(RATIO_OTHER_START, RATIO_OTHER_END + 1)]
+ self._all_ratio_ids = set(range(RATIO_START, RATIO_END + 1))
+ self._all_ratio_ids.update(range(RATIO_OTHER_START, RATIO_OTHER_END + 1))
+
+ self._stage_transitions: dict[int, list[int]] = {}
+ if not is_comprehension:
+ self._stage_transitions[END_OF_THINK] = [RECAPTION]
+ self._stage_transitions[END_OF_RECAPTION] = [ANSWER, BOI, SIZE_TOKEN]
+
+ self._blocked_token_ids: set[int] = set()
+ if is_comprehension:
+ self._blocked_token_ids.update([BOI, SIZE_TOKEN])
+ self._blocked_token_ids.update(self._all_ratio_ids)
+
+ # Bind the real methods from the model class.
+ from vllm_omni.model_executor.models.hunyuan_image3.hunyuan_image3 import (
+ HunyuanImage3ForConditionalGeneration as _Real,
+ )
+
+ _get_forced_token = _Real._get_forced_token
+ _apply_ratio_restriction = _Real._apply_ratio_restriction
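+ # Class-level binding: _Real's plain functions become FakeSamplerModel
+ # methods, so self.model._get_forced_token(...) exercises the real logic
+ # against the fake attributes above.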
+
+
+class TestGetForcedToken:
+ """Tests for the stateless _get_forced_token method."""
+
+ def setup_method(self):
+ self.model = FakeSamplerModel(is_comprehension=False)
+
+ def test_no_trigger_returns_none(self):
+ assert self.model._get_forced_token([1, 2, 3]) is None
+
+ def test_empty_history_returns_none(self):
+ assert self.model._get_forced_token([]) is None
+
+ def test_end_of_think_forces_recaption(self):
+ assert self.model._get_forced_token([END_OF_THINK]) == RECAPTION
+
+ def test_end_of_think_completed(self):
+ assert self.model._get_forced_token([END_OF_THINK, RECAPTION]) is None
+
+ def test_end_of_recaption_forces_answer(self):
+ tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION]
+ assert self.model._get_forced_token(tokens) == ANSWER
+
+ def test_end_of_recaption_forces_boi_after_answer(self):
+ tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION, ANSWER]
+ assert self.model._get_forced_token(tokens) == BOI
+
+ def test_end_of_recaption_forces_size_after_boi(self):
+ tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION, ANSWER, BOI]
+ assert self.model._get_forced_token(tokens) == SIZE_TOKEN
+
+ def test_full_sequence_complete(self):
+ tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION, ANSWER, BOI, SIZE_TOKEN]
+ assert self.model._get_forced_token(tokens) is None
+
+ def test_diverged_history_returns_none(self):
+ tokens = [END_OF_RECAPTION, 999] # 999 != ANSWER
+ assert self.model._get_forced_token(tokens) is None
+
+ def test_later_trigger_takes_precedence(self):
+ # The newest trigger (end-of-recaption) wins even when an earlier,
+ # already-completed end-of-think transition appears in the history.
+ tokens = [END_OF_THINK, RECAPTION, 42, 43, END_OF_RECAPTION]
+ assert self.model._get_forced_token(tokens) == ANSWER
+
+ def test_trigger_with_extra_tokens_before(self):
+ tokens = [1, 2, 3, END_OF_THINK]
+ assert self.model._get_forced_token(tokens) == RECAPTION
+
+
+class TestComprehensionBlocking:
+ """Tests for comprehension mode token blocking."""
+
+ def test_blocked_tokens_masked(self):
+ model = FakeSamplerModel(is_comprehension=True)
+ vocab_size = 300
+ logits = torch.zeros(1, vocab_size)
+ logits[0, BOI] = 5.0
+ logits[0, SIZE_TOKEN] = 3.0
+ logits[0, RATIO_START] = 2.0
+ min_score = torch.finfo(logits.dtype).min
+
+ for tid in model._blocked_token_ids:
+ if tid < vocab_size:
+ logits[0, tid] = min_score
+
+ assert logits[0, BOI].item() == min_score
+ assert logits[0, SIZE_TOKEN].item() == min_score
+ assert logits[0, RATIO_START].item() == min_score
+
+ def test_non_blocked_tokens_preserved(self):
+ model = FakeSamplerModel(is_comprehension=True)
+ vocab_size = 300
+ logits = torch.zeros(1, vocab_size)
+ logits[0, 50] = 7.0
+ min_score = torch.finfo(logits.dtype).min
+
+ for tid in model._blocked_token_ids:
+ if tid < vocab_size:
+ logits[0, tid] = min_score
+
+ assert logits[0, 50].item() == 7.0
+
+
+class TestRatioRestriction:
+ """Tests for _apply_ratio_restriction (greedy: only argmax ratio survives)."""
+
+ def test_greedy_selects_single_ratio_token(self):
+ model = FakeSamplerModel(is_comprehension=False)
+ vocab_size = 300
+ logits = torch.zeros(1, vocab_size)
+ logits[0, RATIO_START + 3] = 10.0
+ logits[0, RATIO_START + 1] = 5.0
+ logits[0, 50] = 20.0 # non-ratio, should be masked
+ min_score = torch.finfo(logits.dtype).min
+
+ model._apply_ratio_restriction(logits, 0, min_score)
+
+ assert logits[0, RATIO_START + 3].item() == 0
+ assert logits[0, RATIO_START + 1].item() == min_score
+ assert logits[0, 50].item() == min_score
+
+ def test_extra_ratio_slices_considered(self):
+ model = FakeSamplerModel(is_comprehension=False)
+ vocab_size = 300
+ logits = torch.zeros(1, vocab_size)
+ logits[0, RATIO_OTHER_START] = 15.0
+ logits[0, RATIO_START] = 5.0
+ min_score = torch.finfo(logits.dtype).min
+
+ model._apply_ratio_restriction(logits, 0, min_score)
+
+ assert logits[0, RATIO_OTHER_START].item() == 0
+ assert logits[0, RATIO_START].item() == min_score
+
+
+class TestForceEosAfterRatio:
+ """Tests that a ratio token as last_token forces EOS."""
+
+ def test_ratio_token_forces_eos(self):
+ model = FakeSamplerModel(is_comprehension=False)
+ vocab_size = 300
+ logits = torch.randn(1, vocab_size)
+ min_score = torch.finfo(logits.dtype).min
+
+ # Mirror the sample() branch for "last token was a ratio token":
+ # mask everything except EOS.
+ logits[0].fill_(min_score)
+ logits[0, model._eos_token_id] = 0
+
+ assert logits[0, EOS].item() == 0
+ mask = torch.ones(vocab_size, dtype=torch.bool)
+ mask[EOS] = False
+ assert (logits[0, mask] == min_score).all()
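+
+
+# Run just this file on CPU (markers defined at the top of the module):
+#   pytest tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_sampler.py -m cpu -q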
diff --git a/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py b/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py
index 6898763e40..ec4f4693d7 100644
--- a/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py
+++ b/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py
@@ -17,7 +17,7 @@
MODEL_NAME = "tencent/HunyuanImage-3.0"
LOCAL_CLIP_PATH = "openai/clip-vit-base-patch32"
REPO_ROOT = Path(__file__).resolve().parents[3]
-STAGE_CONFIG_PATH = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "hunyuan_image3_moe.yaml"
+STAGE_CONFIG_PATH = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "hunyuan_image3_t2i.yaml"
pytestmark = [pytest.mark.advanced_model, pytest.mark.diffusion]
diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py
index 5c280ddcf4..6304eeab29 100644
--- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py
+++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py
@@ -77,7 +77,9 @@
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils.tensor_schema import TensorSchema
+from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.sampler import Sampler
from vllm_omni.model_executor.models.hunyuan_image3.autoencoder_kl_3d import AutoencoderKLConv3D
from vllm_omni.model_executor.models.hunyuan_image3.siglip2 import LightProjector, Siglip2VisionTransformer
@@ -175,8 +177,11 @@ def contains_unexpected_keyword(name, keywords):
return True
return False
+ skipped_unexpected: set[str] = set()
+
for name, loaded_weight in weights:
if contains_unexpected_keyword(name, unexpected_keywords):
+ skipped_unexpected.add(name)
continue
if "rotary_emb.inv_freq" in name:
@@ -362,6 +367,17 @@ def contains_unexpected_keyword(name, keywords):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
+
+ if skipped_unexpected:
+ logger.warning_once(
+ "Skipped %d weights matching unexpected_keywords "
+ "(e.g. vae, vision_model, patch_embed, timestep_emb). "
+ "If upstream renamed components, these may be silently "
+ "lost. Skipped names: %s",
+ len(skipped_unexpected),
+ sorted(skipped_unexpected)[:10],
+ )
+
return loaded_params
@@ -1149,6 +1165,8 @@ class HunyuanImage3ForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo
HunyuanImage3Inputs: TypeAlias = HunyuanImage3PixelInputs
+ prefer_model_sampler = True
+
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -1199,6 +1217,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
else:
self.lm_head = PPMissingLayer()
+ # --- AR-stage components ---
+ # These are needed for image encoding in the AR stage.
+ # If a future text-only stage is added, gate on vllm_config.model_config.model_stage.
+
# vae
self.vae = AutoencoderKLConv3D.from_config(config.vae)
self.patch_embed = UNetDown(
@@ -1226,6 +1248,63 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self._mrope_joint_img_sep_token_id = tokenizer.convert_tokens_to_ids("")
self._mrope_max_num_patches = config.vit_processor.get("max_num_patches", 729)
+ # Special token IDs for logits processors (stage transitions).
+ # These mirror the official tokenization_hunyuan_image_3.py setup.
+ self._end_of_think_id = tokenizer.convert_tokens_to_ids("</think>")
+ self._recaption_id = tokenizer.convert_tokens_to_ids("<recaption>")
+ self._end_of_recaption_id = tokenizer.convert_tokens_to_ids("</recaption>")
+ self._answer_id = tokenizer.convert_tokens_to_ids("<answer>")
+ self._end_of_answer_id = tokenizer.convert_tokens_to_ids("</answer>")
+ image_base_size = getattr(config, "image_base_size", 1024)
+ self._size_token_id = tokenizer.convert_tokens_to_ids(f"<img_size_{image_base_size}>")
+ self._start_ratio_id = tokenizer.convert_tokens_to_ids("<img_ratio_0>")
+ self._end_ratio_id = tokenizer.convert_tokens_to_ids("<img_ratio_32>")
+ ratio_33 = tokenizer.convert_tokens_to_ids("<img_ratio_33>")
+ ratio_36 = tokenizer.convert_tokens_to_ids("<img_ratio_36>")
+ self._ratio_other_slices = [(ratio_33, ratio_36 + 1)]
+ # Build the full set of ratio token IDs for use as stop tokens.
+ self._all_ratio_ids = set(range(self._start_ratio_id, self._end_ratio_id + 1))
+ for s, e in self._ratio_other_slices:
+ self._all_ratio_ids.update(range(s, e))
+
+ # Determine mode: comprehension (I2T/T2T) vs generation (IT2I/T2I).
+ engine_output_type = getattr(vllm_config.model_config, "engine_output_type", None)
+ self._is_comprehension = engine_output_type in (None, "text")
+
+ # For comprehension mode, block image generation tokens but allow
+ # text structure tokens (<think>, <answer>, etc.) so the model can
+ # follow its natural generation pattern. Stop tokens in YAML will
+ # terminate at </answer> or EOS.
+ self._blocked_token_ids: set[int] = set()
+ if self._is_comprehension:
+ self._blocked_token_ids.update(
+ [
+ self._mrope_boi_token_id, # <boi>
+ self._mrope_eoi_token_id, # <eoi>
+ self._size_token_id, # <img_size_*>
+ ]
+ )
+ self._blocked_token_ids.update(self._all_ratio_ids)
+
+ # For generation mode, build stage transition map.
+ # Official logic: </think> → [<recaption>],
+ # </recaption> → [<answer>, <boi>, <img_size_*>]
+ # After <img_size_*>, restrict vocab to ratio tokens only.
+ # Stage-transition forced sequences, keyed by trigger token.
+ self._stage_transitions: dict[int, list[int]] = {}
+ if not self._is_comprehension:
+ self._stage_transitions[self._end_of_think_id] = [
+ self._recaption_id,
+ ]
+ self._stage_transitions[self._end_of_recaption_id] = [
+ self._answer_id,
+ self._mrope_boi_token_id,
+ self._size_token_id,
+ ]
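+ # End-to-end forced tail (per sample() below): ... </recaption> →
+ # <answer> → <boi> → <img_size_*> → (greedy ratio token) → EOS.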
+
+ self._sampler: Sampler | None = None
+ self._eos_token_id: int = tokenizer.eos_token_id
+
self._replace_rotary_embeddings()
def _replace_rotary_embeddings(self):
@@ -1257,6 +1336,12 @@ def _replace_rotary_embeddings(self):
head_dim,
rope_theta,
)
+ if replaced == 0:
+ raise RuntimeError(
+ "HunyuanImage3: _replace_rotary_embeddings replaced 0 layers. "
+ "The custom interleaved 2D mRoPE is not active — model outputs "
+ "will be incorrect. Check that model.layers[*].self_attn.rotary_emb exists."
+ )
def _parse_and_validate_image_input(
self,
@@ -1274,6 +1359,10 @@ def _parse_and_validate_image_input(
if vit_pixel_values is None or vae_pixel_values is None:
return None
+ # Handle empty batch (e.g., during profiling with 0 images / T2T mode)
+ if vit_pixel_values.numel() == 0 or vae_pixel_values.numel() == 0:
+ return None
+
return HunyuanImage3PixelInputs(
type="pixel_values",
pixel_values={
@@ -1472,6 +1561,112 @@ def compute_logits(
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
+ # ------------------------------------------------------------------
+ # Custom sampler — applies HunyuanImage3-specific logits processors
+ # before the standard sampling step.
+ #
+ # Comprehension (I2T / T2T):
+ # Block generation-specific special tokens so sampling can't
+ # accidentally produce <boi>, <img_size_*>, ratio tokens, etc.
+ #
+ # Generation (IT2I / T2I think):
+ # 1. _StageTransitionLogitsProcessor — force token sequences at
+ # transition boundaries ( → , etc.)
+ # 2. _ConditionalSliceVocabLogitsProcessor — after ,
+ # restrict vocab to ratio tokens only (greedy).
+ # ------------------------------------------------------------------
+
+ def sample(
+ self,
+ logits: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> SamplerOutput | None:
+ if logits is None or logits.numel() == 0:
+ return None
+
+ if self._sampler is None:
+ self._sampler = Sampler()
+
+ min_score = torch.finfo(logits.dtype).min
+
+ assert logits.shape[0] == 1, f"HunyuanImage3 sampler requires max_num_seqs=1, got batch size {logits.shape[0]}"
+
+ for req_idx in range(logits.shape[0]):
+ decoded_tokens: list[int] = (
+ sampling_metadata.output_token_ids[req_idx] if req_idx < len(sampling_metadata.output_token_ids) else []
+ )
+ last_token = decoded_tokens[-1] if decoded_tokens else -1
+
+ if self._is_comprehension:
+ for tid in self._blocked_token_ids:
+ logits[req_idx, tid] = min_score
+ else:
+ forced = self._get_forced_token(decoded_tokens)
+ if forced is not None:
+ logits[req_idx].fill_(min_score)
+ logits[req_idx, forced] = 0
+ elif last_token == self._size_token_id:
+ self._apply_ratio_restriction(logits, req_idx, min_score)
+ elif last_token in self._all_ratio_ids:
+ logits[req_idx].fill_(min_score)
+ logits[req_idx, self._eos_token_id] = 0
+
+ return self._sampler(logits=logits, sampling_metadata=sampling_metadata)
+
+ def _get_forced_token(self, decoded_tokens: list[int]) -> int | None:
+ """Derive the next forced token from output history (stateless).
+
+ Scans decoded_tokens backwards for the most recent trigger token,
+ then prefix-matches the forced sequence against what followed.
+ Returns the next token to force, or None if the sequence is complete
+ or history has diverged from the expected forced sequence.
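+
+ Example: with transitions {</recaption>: [<answer>, <boi>, <img_size>]}
+ and decoded_tokens ending [..., </recaption>, <answer>], the backward
+ scan finds </recaption>, matches the emitted <answer> as a prefix of
+ the forced sequence, and returns <boi> as the next forced token.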
+ """
+ for i in range(len(decoded_tokens) - 1, -1, -1):
+ trigger = decoded_tokens[i]
+ if trigger not in self._stage_transitions:
+ continue
+
+ forced_seq = self._stage_transitions[trigger]
+ emitted = decoded_tokens[i + 1 :]
+
+ matched = 0
+ for expected, actual in zip(forced_seq, emitted):
+ if actual != expected:
+ # History diverged from the expected forced sequence.
+ # Stop applying transition forcing for safety.
+ return None
+ matched += 1
+
+ if matched < len(forced_seq):
+ return forced_seq[matched]
+ return None
+
+ return None
+
+ def _apply_ratio_restriction(
+ self,
+ logits: torch.Tensor,
+ req_idx: int,
+ min_score: float,
+ ) -> None:
+ """Port of official _ConditionalSliceVocabLogitsProcessor.__call__.
+
+ After the size token, only allow ratio tokens and pick greedily.
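+
+ Example: if <img_ratio_3> carries the highest logit among all ratio
+ tokens, every other vocab position (even a higher-scoring non-ratio
+ token) is set to min_score and <img_ratio_3> is set to 0, so the
+ downstream sampler deterministically emits <img_ratio_3>.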
+ """
+ original = logits[req_idx].clone()
+ logits[req_idx].fill_(min_score)
+ # Allow primary ratio range.
+ logits[req_idx, self._start_ratio_id : self._end_ratio_id + 1] = original[
+ self._start_ratio_id : self._end_ratio_id + 1
+ ]
+ # Allow extra ratio slices.
+ for s, e in self._ratio_other_slices:
+ logits[req_idx, s:e] = original[s:e]
+ # Force greedy: keep only the argmax.
+ max_id = logits[req_idx].argmax().item()
+ logits[req_idx].fill_(min_score)
+ logits[req_idx, max_id] = 0
+
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype, device: torch.device
) -> IntermediateTensors:
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml
new file mode 100644
index 0000000000..203b54f257
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml
@@ -0,0 +1,44 @@
+# Stage config for HunyuanImage-3.0 Image-to-Text (I2T / image understanding).
+# Single LLM stage: AR model reads image + text prompt, generates text output.
+
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ runtime:
+ process: true
+ devices: "0,1,2,3"
+ max_batch_size: 1
+ requires_multimodal_data: true
+ engine_args:
+ model_stage: AR
+ max_num_seqs: 1
+ model_arch: HunyuanImage3ForCausalMM
+ worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.95
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 4
+ pipeline_parallel_size: 1
+ hf_overrides:
+ rope_parameters:
+ mrope_section: [0, 32, 32]
+ rope_type: default
+ is_comprehension: true
+ final_output: true
+ final_output_type: text
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 0.95
+ top_k: 1024
+ max_tokens: 2048
+ stop_token_ids: [127957, 128026] # <|endoftext|>, </answer>
+ detokenize: True
+
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
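+
+# Usage sketch (constructor and argument names below are assumptions; see the
+# repo's offline_inference examples for the real entrypoint):
+#   from vllm_omni import Omni
+#   omni = Omni(model="tencent/HunyuanImage-3.0",
+#               stage_config="vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml")
+#   out = omni.generate(build_prompt("Describe the content of the picture.", task="i2t"))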
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml
new file mode 100644
index 0000000000..9f6adece0f
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml
@@ -0,0 +1,78 @@
+# Stage config for HunyuanImage-3.0 Image+Text-to-Image (image editing).
+# Stage 0: AR (HunyuanImage3ForConditionalGeneration) — reads (image, text), emits latent tokens
+# Stage 1: Diffusion (HunyuanImage3Pipeline / DiT + VAE) — denoise + decode latents → image
+
+stage_args:
+ # Stage 0: AR Model
+ - stage_id: 0
+ stage_type: llm
+ runtime:
+ process: true
+ devices: "0,1,2,3"
+ max_batch_size: 1
+ requires_multimodal_data: true # AR needs the original image
+ engine_args:
+ model_stage: AR
+ model_arch: HunyuanImage3ForCausalMM
+ worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.95
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: latent # AR outputs latent for DiT
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 4
+ pipeline_parallel_size: 1
+ hf_overrides:
+ rope_parameters:
+ mrope_section: [0, 32, 32]
+ rope_type: default
+ is_comprehension: false # Generation task, not comprehension
+ final_output: false # AR is not the final output
+ default_sampling_params:
+ temperature: 0.6
+ top_p: 0.95
+ top_k: 1024
+ max_tokens: 4096
+ stop_token_ids: [127957] # <|endoftext|>
+ detokenize: false
+
+ # Stage 1: Diffusion (DiT + VAE)
+ # Receives latents from AR stage, performs denoising + VAE decode
+ - stage_id: 1
+ stage_type: diffusion
+ runtime:
+ process: true
+ devices: "4,5,6,7"
+ max_batch_size: 1
+ requires_multimodal_data: true # May need condition images
+ engine_args:
+ model_stage: dit
+ model_arch: HunyuanImage3ForCausalMM
+ enforce_eager: true
+ trust_remote_code: true
+ distributed_executor_backend: "mp"
+ parallel_config:
+ tensor_parallel_size: 4
+ enable_expert_parallel: true
+ omni_kv_config:
+ need_recv_cache: true
+ engine_input_source: [0] # Input from AR stage
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.hunyuan_image3.ar2diffusion
+ final_output: true
+ final_output_type: image
+ default_sampling_params:
+ num_inference_steps: 50
+ guidance_scale: 2.5
+
+# Top-level runtime config
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1 # Trigger downstream only after full upstream completion
+ max_inflight: 1 # Process serially within each stage
+ edges:
+ - from: 0 # AR → Diffusion
+ to: 1
+ window_size: -1
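+
+# Prompts for this config should carry the <think> trigger tag; the example
+# helper builds them, e.g.:
+#   build_prompt("Make the petals neon pink", task="it2i_think")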
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe.yaml
deleted file mode 100644
index 808b4619f7..0000000000
--- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe.yaml
+++ /dev/null
@@ -1,81 +0,0 @@
-# Stage config for running Hunyuan-Image3.0 for multi-stage omni runtime.
-# Stage 0: AR Model (vLLM implementation)
-
-# The following config has been verified on 8x L40S-48G GPU.
-modes:
- - mode: text-to-image
- stages: [1]
- - mode: image-to-text
- stages: [0]
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true # Run this stage in a separate process
- devices: "0,1,2,3,4,5,6,7" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
- engine_args:
- model_stage: AR
- max_num_seqs: 1
- model_arch: HunyuanImage3ForCausalMM
- worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.3
- enforce_eager: true # Now we only support eager mode
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- tensor_parallel_size: 8
- pipeline_parallel_size: 1
- hf_overrides:
- rope_parameters:
- mrope_section: [0, 32, 32]
- rope_type: default
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- - stage_id: 1
- stage_type: diffusion
- runtime:
- process: true
- devices: "0,1,2,3,4,5,6,7"
- max_batch_size: 1
- engine_args:
- model_stage: diffusion
- enforce_eager: true
- distributed_executor_backend: "mp"
- vae_use_slicing: false
- vae_use_tiling: false
- cache_backend: null
- cache_config: null
- enable_cache_dit_summary: false
- parallel_config:
- pipeline_parallel_size: 1
- data_parallel_size: 1
- tensor_parallel_size: 8
- enable_expert_parallel: false
- sequence_parallel_size: 1
- ulysses_degree: 1
- ring_degree: 1
- cfg_parallel_size: 1
- vae_patch_parallel_size: 1
- use_hsdp: false
- hsdp_shard_size: -1
- hsdp_replicate_size: 1
- final_output: true
- final_output_type: image
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml
new file mode 100644
index 0000000000..60da8e0bc7
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml
@@ -0,0 +1,45 @@
+# Stage config for HunyuanImage-3.0 Text-to-Text (T2T / pure text generation).
+# Single LLM stage: AR model reads text prompt only, generates text output.
+# Sampling params aligned with official generation_config.json.
+
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ runtime:
+ process: true
+ devices: "0,1,2,3"
+ max_batch_size: 1
+ requires_multimodal_data: false
+ engine_args:
+ model_stage: AR
+ max_num_seqs: 1
+ model_arch: HunyuanImage3ForCausalMM
+ worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.95
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 4
+ pipeline_parallel_size: 1
+ hf_overrides:
+ rope_parameters:
+ mrope_section: [0, 32, 32]
+ rope_type: default
+ is_comprehension: true
+ final_output: true
+ final_output_type: text
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 0.95
+ top_k: 1024
+ max_tokens: 2048
+ stop_token_ids: [127957, 128026] # <|endoftext|>, </answer>
+ detokenize: True
+
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
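+
+# Prompt construction for this config via the example helper:
+#   build_prompt("Explain mixture-of-experts routing in two sentences.", task="t2t")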
diff --git a/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py b/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py
new file mode 100644
index 0000000000..89a7a28f6c
--- /dev/null
+++ b/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py
@@ -0,0 +1,123 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Stage input processor for HunyuanImage3: AR → Diffusion transition.
+
+In IT2I (image editing) mode:
+ - Stage 0 (AR) receives (image + edit instruction), generates CoT/latent tokens
+ - Stage 1 (DiT) receives the AR output + original image, denoises → edited image
+
+The ar2diffusion function bridges these two stages, following the same
+signature pattern as glm_image.ar2diffusion.
+"""
+
+from typing import Any
+
+import torch
+from vllm.inputs import TextPrompt
+from vllm.logger import init_logger
+
+from vllm_omni.inputs.data import OmniTokensPrompt
+
+logger = init_logger(__name__)
+
+
+def ar2diffusion(
+ stage_list: list[Any],
+ engine_input_source: list[int],
+ prompt: OmniTokensPrompt | TextPrompt | list | None = None,
+ requires_multimodal_data: bool = False,
+) -> list[dict[str, Any]]:
+ """Process AR stage outputs to create Diffusion stage inputs.
+
+ Args:
+ stage_list: List of stage clients (set by orchestrator).
+ engine_input_source: List of source stage IDs (from YAML).
+ prompt: Original user prompt (may contain multimodal data).
+ requires_multimodal_data: Whether to forward multimodal data.
+
+ Returns:
+ List of dicts, each consumable by the HunyuanImage3 diffusion pipeline.
+ """
+ if not engine_input_source:
+ raise ValueError("engine_input_source cannot be empty")
+
+ source_stage_id = engine_input_source[0]
+ if source_stage_id >= len(stage_list):
+ raise IndexError(f"Invalid source stage_id: {source_stage_id}")
+
+ if stage_list[source_stage_id].engine_outputs is None:
+ raise RuntimeError(f"Stage {source_stage_id} has no outputs yet")
+
+ ar_outputs = stage_list[source_stage_id].engine_outputs
+ diffusion_inputs = []
+
+ # Normalize prompt to list
+ if not isinstance(prompt, list):
+ prompt = [prompt] if prompt is not None else [{}]
+
+ for i, ar_output in enumerate(ar_outputs):
+ output = ar_output.outputs[0]
+ generated_token_ids = output.token_ids
+ generated_text = getattr(output, "text", "") or ""
+
+ # Get original prompt info
+ original_prompt = prompt[i] if i < len(prompt) else {}
+ if isinstance(original_prompt, dict):
+ pass
+ elif hasattr(original_prompt, "_asdict"):
+ original_prompt = original_prompt._asdict()
+ elif hasattr(original_prompt, "__dict__"):
+ original_prompt = vars(original_prompt)
+ else:
+ original_prompt = {}
+
+ height = original_prompt.get("height", 1024)
+ width = original_prompt.get("width", 1024)
+ text_prompt = original_prompt.get("prompt", "")
+
+ logger.info(
+ "[ar2diffusion] Request %d: AR generated %d tokens, text length=%d, target size=%dx%d",
+ i,
+ len(generated_token_ids),
+ len(generated_text),
+ height,
+ width,
+ )
+
+ token_tensor = torch.tensor(generated_token_ids, dtype=torch.long)
+
+ diffusion_input: dict[str, Any] = {
+ "prompt": text_prompt,
+ "height": height,
+ "width": width,
+ "extra": {
+ "ar_token_ids": token_tensor,
+ "ar_generated_text": generated_text,
+ },
+ }
+
+ # Forward multimodal data (original image for IT2I conditioning)
+ mm_data = original_prompt.get("multi_modal_data")
+ if mm_data:
+ pil_image = mm_data.get("image")
+ if pil_image is None:
+ images = mm_data.get("images")
+ if images:
+ pil_image = images[0] if isinstance(images, list) else images
+ if pil_image is not None:
+ diffusion_input["pil_image"] = pil_image
+
+ # Forward multimodal output from AR (if any)
+ if hasattr(ar_output, "multimodal_output") and ar_output.multimodal_output:
+ mm_output = ar_output.multimodal_output
+ if isinstance(mm_output, dict):
+ diffusion_input["extra"]["ar_multimodal_output"] = mm_output
+
+ # Forward sampling params
+ for key in ["seed", "num_inference_steps", "guidance_scale", "negative_prompt"]:
+ if key in original_prompt:
+ diffusion_input[key] = original_prompt[key]
+
+ diffusion_inputs.append(diffusion_input)
+
+ return diffusion_inputs
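+
+
+# Illustrative result shape (values are examples, not fixtures): for one AR
+# output with token_ids [5, 6, 7] and text "a recaption", and
+# prompt=[{"prompt": "neon petals", "height": 1024, "width": 1024}], this
+# returns:
+#   [{"prompt": "neon petals", "height": 1024, "width": 1024,
+#     "extra": {"ar_token_ids": tensor([5, 6, 7]),
+#               "ar_generated_text": "a recaption"}}]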
diff --git a/vllm_omni/patch.py b/vllm_omni/patch.py
index eafff821a2..d4ab78f13a 100644
--- a/vllm_omni/patch.py
+++ b/vllm_omni/patch.py
@@ -1,6 +1,8 @@
import sys
+from functools import cached_property
from aenum import extend_enum
+from vllm.config import ModelConfig as _OriginalModelConfig
from vllm.inputs import TokensPrompt as _OriginalTokensPrompt
from vllm.model_executor.layers.rotary_embedding import (
MRotaryEmbedding as _OriginalMRotaryEmbedding,
@@ -17,6 +19,56 @@
from vllm_omni.model_executor.layers.rotary_embedding import OmniMRotaryEmbedding
from vllm_omni.request import OmniRequest
+# =============================================================================
+# Patch ModelConfig.is_mm_prefix_lm to support omni-specific models
+# =============================================================================
+# WHY: HunyuanImage-3.0 requires bidirectional attention for image tokens
+# (cond_token_attn_type: "joint_full" in config.json). vLLM gates this on
+# is_mm_prefix_lm, which checks an internal MM_PREFIX_LM_MODELS list that
+# does not include "hunyuan_image_3_moe" (the upstream HF model_type).
+#
+# WHY NOT model-level: is_mm_prefix_lm is checked in vLLM core (scheduler,
+# attention backend selection) before model code runs — no model-level hook.
+#
+# SCOPE: Only affects model_types listed in _OMNI_MM_PREFIX_LM_MODELS (currently
+# just "hunyuan_image_3_moe"). All other models fall through to the
+# original vLLM implementation unchanged.
+#
+# FRAGILITY: Relies on is_mm_prefix_lm being a cached_property on
+# ModelConfig. The __dict__ access + __set_name__ dance works around a
+# pydantic dataclass issue in vllm 0.19.0+. If vLLM changes
+# is_mm_prefix_lm to a regular method or removes it, this will break.
+#
+# TODO: Upstream a configurable MM_PREFIX_LM_MODELS or a model_config flag
+# so this patch can be removed.
+_OMNI_MM_PREFIX_LM_MODELS = ("hunyuan_image_3_moe",)
+# Access via __dict__ to avoid triggering cached_property.__get__ which fails
+# with "Cannot use cached_property instance without calling __set_name__" in
+# pydantic dataclasses (vllm 0.19.0+).
+_cp = _OriginalModelConfig.__dict__["is_mm_prefix_lm"]
+_original_is_mm_prefix_lm = _cp.func if hasattr(_cp, "func") else _cp.fget
+
+
+def _patched_is_mm_prefix_lm(self):
+ if _original_is_mm_prefix_lm(self):
+ return True
+ model_type = getattr(self.hf_config, "model_type", "")
+ return model_type in _OMNI_MM_PREFIX_LM_MODELS
+
+
+_patched_cp = cached_property(_patched_is_mm_prefix_lm)
+_patched_cp.__set_name__(_OriginalModelConfig, "is_mm_prefix_lm")
+_OriginalModelConfig.is_mm_prefix_lm = _patched_cp
+
+# Sanity check: verify the patch is active. If vLLM changes the descriptor
+# type or __set_name__ semantics, this will fail loudly at import time
+# rather than silently falling back to unpatched behavior.
+_installed = _OriginalModelConfig.__dict__.get("is_mm_prefix_lm")
+assert _installed is _patched_cp, (
+ "is_mm_prefix_lm patch failed to install — bidirectional attention "
+ "for HunyuanImage3 will not work. Check vLLM ModelConfig changes."
+)
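+
+# Behavior sketch: for a ModelConfig whose hf_config.model_type is
+# "hunyuan_image_3_moe", is_mm_prefix_lm now evaluates True even though
+# vLLM's MM_PREFIX_LM_MODELS list lacks that entry; any other model_type
+# falls through to the original implementation unchanged.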
+
# =============================================================================
# Patch GlmImageTextConfig to expose mrope_section in rope_parameters
# =============================================================================