diff --git a/examples/offline_inference/hunyuan_image3/prompt_utils.py b/examples/offline_inference/hunyuan_image3/prompt_utils.py new file mode 100644 index 0000000000..a5ef8e1536 --- /dev/null +++ b/examples/offline_inference/hunyuan_image3/prompt_utils.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Prompt construction utilities for HunyuanImage-3.0-Instruct examples. + +Wraps system_prompt.get_system_prompt() with task-aware presets so that +examples and tests don't need to manually concatenate system prompts, +, , and tags. + +Usage: + from prompt_utils import build_prompt + + # IT2I (image editing, think+recaption mode) + prompt = build_prompt("Make the petals neon pink", task="it2i_think") + + # I2T (image understanding) + prompt = build_prompt("Describe the content of the picture.", task="i2t") +""" + +from __future__ import annotations + +from vllm_omni.diffusion.models.hunyuan_image3.system_prompt import ( + get_system_prompt, +) + +# task → (sys_type, bot_task, trigger_tag) +# trigger_tag: "", "", or None +_TASK_PRESETS: dict[str, tuple[str, str | None, str | None]] = { + # Pure text generation (text → text, no image) + "t2t": ("en_unified", None, None), + # Image understanding (image → text) + "i2t": ("en_unified", None, None), + # Image editing (image+text → image), think+recaption mode + "it2i_think": ("en_unified", "think", ""), + # Image editing, recaption-only mode + "it2i_recaption": ("en_unified", "recaption", ""), + # Text-to-image, think mode + "t2i_think": ("en_unified", "think", ""), + # Text-to-image, recaption mode + "t2i_recaption": ("en_unified", "recaption", ""), + # Text-to-image, vanilla (no CoT) + "t2i_vanilla": ("en_vanilla", "image", None), +} + + +def build_prompt( + user_prompt: str, + task: str = "it2i_think", + sys_type: str | None = None, + custom_system_prompt: str | None = None, +) -> str: + """Build a complete HunyuanImage-3.0 prompt with auto-selected system + prompt and mode trigger tags. + + Args: + user_prompt: The user's raw instruction or question. + task: One of the preset task keys (see _TASK_PRESETS). + sys_type: Override the preset's sys_type for get_system_prompt(). + custom_system_prompt: Custom system prompt text (used when + sys_type="custom"). + + Returns: + Fully formatted prompt string ready for Omni.generate(). + """ + if task not in _TASK_PRESETS: + raise ValueError(f"Unknown task {task!r}. Choose from: {sorted(_TASK_PRESETS)}") + + preset_sys_type, preset_bot_task, trigger_tag = _TASK_PRESETS[task] + effective_sys_type = sys_type or preset_sys_type + + system_prompt = get_system_prompt(effective_sys_type, preset_bot_task, custom_system_prompt) + sys_text = system_prompt.strip() if system_prompt else "" + + has_image_input = task.startswith("i2t") or task.startswith("it2i") + + parts = ["<|startoftext|>"] + if sys_text: + parts.append(sys_text) + # Instruct conversation template: \n\nUser: ... \n\nAssistant: + parts.append("\n\nUser: ") + if has_image_input: + parts.append("") + parts.append(user_prompt) + parts.append("\n\nAssistant: ") + if trigger_tag: + parts.append(trigger_tag) + + return "".join(parts) diff --git a/tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_sampler.py b/tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_sampler.py new file mode 100644 index 0000000000..51f6a85f58 --- /dev/null +++ b/tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_sampler.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for HunyuanImage3 AR sampler logic (stage transitions, +ratio restriction, comprehension blocking).""" + +import pytest +import torch + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + +# Fake token IDs for testing (avoid importing the real model). +END_OF_THINK = 100 +RECAPTION = 101 +END_OF_RECAPTION = 102 +ANSWER = 103 +BOI = 104 +SIZE_TOKEN = 105 +EOS = 106 +RATIO_START = 200 +RATIO_END = 210 +RATIO_OTHER_START = 220 +RATIO_OTHER_END = 223 + + +class FakeSamplerModel: + """Minimal stub that replicates the sampler-relevant attributes of + HunyuanImage3ForConditionalGeneration without loading real weights.""" + + def __init__(self, *, is_comprehension: bool = False): + self._is_comprehension = is_comprehension + self._eos_token_id = EOS + self._end_of_think_id = END_OF_THINK + self._recaption_id = RECAPTION + self._end_of_recaption_id = END_OF_RECAPTION + self._answer_id = ANSWER + self._mrope_boi_token_id = BOI + self._size_token_id = SIZE_TOKEN + self._start_ratio_id = RATIO_START + self._end_ratio_id = RATIO_END + self._ratio_other_slices = [(RATIO_OTHER_START, RATIO_OTHER_END + 1)] + self._all_ratio_ids = set(range(RATIO_START, RATIO_END + 1)) + self._all_ratio_ids.update(range(RATIO_OTHER_START, RATIO_OTHER_END + 1)) + + self._stage_transitions: dict[int, list[int]] = {} + if not is_comprehension: + self._stage_transitions[END_OF_THINK] = [RECAPTION] + self._stage_transitions[END_OF_RECAPTION] = [ANSWER, BOI, SIZE_TOKEN] + + self._blocked_token_ids: set[int] = set() + if is_comprehension: + self._blocked_token_ids.update([BOI, SIZE_TOKEN]) + self._blocked_token_ids.update(self._all_ratio_ids) + + # Bind the real methods from the model class. + from vllm_omni.model_executor.models.hunyuan_image3.hunyuan_image3 import ( + HunyuanImage3ForConditionalGeneration as _Real, + ) + + _get_forced_token = _Real._get_forced_token + _apply_ratio_restriction = _Real._apply_ratio_restriction + + +class TestGetForcedToken: + """Tests for the stateless _get_forced_token method.""" + + def setup_method(self): + self.model = FakeSamplerModel(is_comprehension=False) + + def test_no_trigger_returns_none(self): + assert self.model._get_forced_token([1, 2, 3]) is None + + def test_empty_history_returns_none(self): + assert self.model._get_forced_token([]) is None + + def test_end_of_think_forces_recaption(self): + assert self.model._get_forced_token([END_OF_THINK]) == RECAPTION + + def test_end_of_think_completed(self): + assert self.model._get_forced_token([END_OF_THINK, RECAPTION]) is None + + def test_end_of_recaption_forces_answer(self): + tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION] + assert self.model._get_forced_token(tokens) == ANSWER + + def test_end_of_recaption_forces_boi_after_answer(self): + tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION, ANSWER] + assert self.model._get_forced_token(tokens) == BOI + + def test_end_of_recaption_forces_size_after_boi(self): + tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION, ANSWER, BOI] + assert self.model._get_forced_token(tokens) == SIZE_TOKEN + + def test_full_sequence_complete(self): + tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION, ANSWER, BOI, SIZE_TOKEN] + assert self.model._get_forced_token(tokens) is None + + def test_diverged_history_returns_none(self): + tokens = [END_OF_RECAPTION, 999] # 999 != ANSWER + assert self.model._get_forced_token(tokens) is None + + def test_later_trigger_takes_precedence(self): + tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION] + assert self.model._get_forced_token(tokens) == ANSWER + + def test_trigger_with_extra_tokens_before(self): + tokens = [1, 2, 3, END_OF_THINK] + assert self.model._get_forced_token(tokens) == RECAPTION + + +class TestComprehensionBlocking: + """Tests for comprehension mode token blocking.""" + + def test_blocked_tokens_masked(self): + model = FakeSamplerModel(is_comprehension=True) + vocab_size = 300 + logits = torch.zeros(1, vocab_size) + logits[0, BOI] = 5.0 + logits[0, SIZE_TOKEN] = 3.0 + logits[0, RATIO_START] = 2.0 + min_score = torch.finfo(logits.dtype).min + + for tid in model._blocked_token_ids: + if tid < vocab_size: + logits[0, tid] = min_score + + assert logits[0, BOI].item() == min_score + assert logits[0, SIZE_TOKEN].item() == min_score + assert logits[0, RATIO_START].item() == min_score + + def test_non_blocked_tokens_preserved(self): + model = FakeSamplerModel(is_comprehension=True) + vocab_size = 300 + logits = torch.zeros(1, vocab_size) + logits[0, 50] = 7.0 + min_score = torch.finfo(logits.dtype).min + + for tid in model._blocked_token_ids: + if tid < vocab_size: + logits[0, tid] = min_score + + assert logits[0, 50].item() == 7.0 + + +class TestRatioRestriction: + """Tests for _apply_ratio_restriction (greedy: only argmax ratio survives).""" + + def test_greedy_selects_single_ratio_token(self): + model = FakeSamplerModel(is_comprehension=False) + vocab_size = 300 + logits = torch.zeros(1, vocab_size) + logits[0, RATIO_START + 3] = 10.0 + logits[0, RATIO_START + 1] = 5.0 + logits[0, 50] = 20.0 # non-ratio, should be masked + min_score = torch.finfo(logits.dtype).min + + model._apply_ratio_restriction(logits, 0, min_score) + + assert logits[0, RATIO_START + 3].item() == 0 + assert logits[0, RATIO_START + 1].item() == min_score + assert logits[0, 50].item() == min_score + + def test_extra_ratio_slices_considered(self): + model = FakeSamplerModel(is_comprehension=False) + vocab_size = 300 + logits = torch.zeros(1, vocab_size) + logits[0, RATIO_OTHER_START] = 15.0 + logits[0, RATIO_START] = 5.0 + min_score = torch.finfo(logits.dtype).min + + model._apply_ratio_restriction(logits, 0, min_score) + + assert logits[0, RATIO_OTHER_START].item() == 0 + assert logits[0, RATIO_START].item() == min_score + + +class TestForceEosAfterRatio: + """Tests that a ratio token as last_token forces EOS.""" + + def test_ratio_token_forces_eos(self): + model = FakeSamplerModel(is_comprehension=False) + vocab_size = 300 + logits = torch.randn(1, vocab_size) + min_score = torch.finfo(logits.dtype).min + + logits[0].fill_(min_score) + logits[0, model._eos_token_id] = 0 + + assert logits[0, EOS].item() == 0 + non_eos_max = logits[0, :EOS].max().item() + assert non_eos_max == min_score diff --git a/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py b/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py index 6898763e40..ec4f4693d7 100644 --- a/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py +++ b/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py @@ -17,7 +17,7 @@ MODEL_NAME = "tencent/HunyuanImage-3.0" LOCAL_CLIP_PATH = "openai/clip-vit-base-patch32" REPO_ROOT = Path(__file__).resolve().parents[3] -STAGE_CONFIG_PATH = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "hunyuan_image3_moe.yaml" +STAGE_CONFIG_PATH = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "hunyuan_image3_t2i.yaml" pytestmark = [pytest.mark.advanced_model, pytest.mark.diffusion] diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py index 5c280ddcf4..6304eeab29 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py @@ -77,7 +77,9 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils.tensor_schema import TensorSchema +from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.sampler import Sampler from vllm_omni.model_executor.models.hunyuan_image3.autoencoder_kl_3d import AutoencoderKLConv3D from vllm_omni.model_executor.models.hunyuan_image3.siglip2 import LightProjector, Siglip2VisionTransformer @@ -175,8 +177,11 @@ def contains_unexpected_keyword(name, keywords): return True return False + skipped_unexpected: set[str] = set() + for name, loaded_weight in weights: if contains_unexpected_keyword(name, unexpected_keywords): + skipped_unexpected.add(name) continue if "rotary_emb.inv_freq" in name: @@ -362,6 +367,17 @@ def contains_unexpected_keyword(name, keywords): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) + + if skipped_unexpected: + logger.warning_once( + "Skipped %d weights matching unexpected_keywords " + "(e.g. vae, vision_model, patch_embed, timestep_emb). " + "If upstream renamed components, these may be silently " + "lost. Skipped names: %s", + len(skipped_unexpected), + sorted(skipped_unexpected)[:10], + ) + return loaded_params @@ -1149,6 +1165,8 @@ class HunyuanImage3ForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo HunyuanImage3Inputs: TypeAlias = HunyuanImage3PixelInputs + prefer_model_sampler = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1199,6 +1217,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.lm_head = PPMissingLayer() + # --- AR-stage components --- + # These are needed for image encoding in the AR stage. + # If a future text-only stage is added, gate on vllm_config.model_config.model_stage. + # vae self.vae = AutoencoderKLConv3D.from_config(config.vae) self.patch_embed = UNetDown( @@ -1226,6 +1248,63 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._mrope_joint_img_sep_token_id = tokenizer.convert_tokens_to_ids("") self._mrope_max_num_patches = config.vit_processor.get("max_num_patches", 729) + # Special token IDs for logits processors (stage transitions). + # These mirror the official tokenization_hunyuan_image_3.py setup. + self._end_of_think_id = tokenizer.convert_tokens_to_ids("") + self._recaption_id = tokenizer.convert_tokens_to_ids("") + self._end_of_recaption_id = tokenizer.convert_tokens_to_ids("") + self._answer_id = tokenizer.convert_tokens_to_ids("") + self._end_of_answer_id = tokenizer.convert_tokens_to_ids("") + image_base_size = getattr(config, "image_base_size", 1024) + self._size_token_id = tokenizer.convert_tokens_to_ids(f"") + self._start_ratio_id = tokenizer.convert_tokens_to_ids("") + self._end_ratio_id = tokenizer.convert_tokens_to_ids("") + ratio_33 = tokenizer.convert_tokens_to_ids("") + ratio_36 = tokenizer.convert_tokens_to_ids("") + self._ratio_other_slices = [(ratio_33, ratio_36 + 1)] + # Build the full set of ratio token IDs for use as stop tokens. + self._all_ratio_ids = set(range(self._start_ratio_id, self._end_ratio_id + 1)) + for s, e in self._ratio_other_slices: + self._all_ratio_ids.update(range(s, e)) + + # Determine mode: comprehension (I2T/T2T) vs generation (IT2I/T2I). + engine_output_type = getattr(vllm_config.model_config, "engine_output_type", None) + self._is_comprehension = engine_output_type in (None, "text") + + # For comprehension mode, block image generation tokens but allow + # text structure tokens (, , etc.) so the model can + # follow its natural generation pattern. Stop tokens in YAML will + # terminate at or EOS. + self._blocked_token_ids: set[int] = set() + if self._is_comprehension: + self._blocked_token_ids.update( + [ + self._mrope_boi_token_id, # + self._mrope_eoi_token_id, # + self._size_token_id, # + ] + ) + self._blocked_token_ids.update(self._all_ratio_ids) + + # For generation mode, build stage transition map. + # Official logic: → [], + # → [, , ] + # After , restrict vocab to ratio tokens only. + # Stage-transition forced sequences, keyed by trigger token. + self._stage_transitions: dict[int, list[int]] = {} + if not self._is_comprehension: + self._stage_transitions[self._end_of_think_id] = [ + self._recaption_id, + ] + self._stage_transitions[self._end_of_recaption_id] = [ + self._answer_id, + self._mrope_boi_token_id, + self._size_token_id, + ] + + self._sampler: Sampler | None = None + self._eos_token_id: int = tokenizer.eos_token_id + self._replace_rotary_embeddings() def _replace_rotary_embeddings(self): @@ -1257,6 +1336,12 @@ def _replace_rotary_embeddings(self): head_dim, rope_theta, ) + if replaced == 0: + raise RuntimeError( + "HunyuanImage3: _replace_rotary_embeddings replaced 0 layers. " + "The custom interleaved 2D mRoPE is not active — model outputs " + "will be incorrect. Check that model.layers[*].self_attn.rotary_emb exists." + ) def _parse_and_validate_image_input( self, @@ -1274,6 +1359,10 @@ def _parse_and_validate_image_input( if vit_pixel_values is None or vae_pixel_values is None: return None + # Handle empty batch (e.g., during profiling with 0 images / T2T mode) + if vit_pixel_values.numel() == 0 or vae_pixel_values.numel() == 0: + return None + return HunyuanImage3PixelInputs( type="pixel_values", pixel_values={ @@ -1472,6 +1561,112 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits + # ------------------------------------------------------------------ + # Custom sampler — applies HunyuanImage3-specific logits processors + # before the standard sampling step. + # + # Comprehension (I2T / T2T): + # Block generation-specific special tokens so sampling can't + # accidentally produce , , ratio tokens, etc. + # + # Generation (IT2I / T2I think): + # 1. _StageTransitionLogitsProcessor — force token sequences at + # transition boundaries (, etc.) + # 2. _ConditionalSliceVocabLogitsProcessor — after , + # restrict vocab to ratio tokens only (greedy). + # ------------------------------------------------------------------ + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput | None: + if logits is None or logits.numel() == 0: + return None + + if self._sampler is None: + self._sampler = Sampler() + + min_score = torch.finfo(logits.dtype).min + + assert logits.shape[0] == 1, f"HunyuanImage3 sampler requires max_num_seqs=1, got batch size {logits.shape[0]}" + + for req_idx in range(logits.shape[0]): + decoded_tokens: list[int] = ( + sampling_metadata.output_token_ids[req_idx] if req_idx < len(sampling_metadata.output_token_ids) else [] + ) + last_token = decoded_tokens[-1] if decoded_tokens else -1 + + if self._is_comprehension: + for tid in self._blocked_token_ids: + logits[req_idx, tid] = min_score + else: + forced = self._get_forced_token(decoded_tokens) + if forced is not None: + logits[req_idx].fill_(min_score) + logits[req_idx, forced] = 0 + elif last_token == self._size_token_id: + self._apply_ratio_restriction(logits, req_idx, min_score) + elif last_token in self._all_ratio_ids: + logits[req_idx].fill_(min_score) + logits[req_idx, self._eos_token_id] = 0 + + return self._sampler(logits=logits, sampling_metadata=sampling_metadata) + + def _get_forced_token(self, decoded_tokens: list[int]) -> int | None: + """Derive the next forced token from output history (stateless). + + Scans decoded_tokens backwards for the most recent trigger token, + then prefix-matches the forced sequence against what followed. + Returns the next token to force, or None if the sequence is complete + or history has diverged from the expected forced sequence. + """ + for i in range(len(decoded_tokens) - 1, -1, -1): + trigger = decoded_tokens[i] + if trigger not in self._stage_transitions: + continue + + forced_seq = self._stage_transitions[trigger] + emitted = decoded_tokens[i + 1 :] + + matched = 0 + for expected, actual in zip(forced_seq, emitted): + if actual != expected: + # History diverged from the expected forced sequence. + # Stop applying transition forcing for safety. + return None + matched += 1 + + if matched < len(forced_seq): + return forced_seq[matched] + return None + + return None + + def _apply_ratio_restriction( + self, + logits: torch.Tensor, + req_idx: int, + min_score: float, + ) -> None: + """Port of official _ConditionalSliceVocabLogitsProcessor.__call__. + + After the size token, only allow ratio tokens and pick greedily. + """ + original = logits[req_idx].clone() + logits[req_idx].fill_(min_score) + # Allow primary ratio range. + logits[req_idx, self._start_ratio_id : self._end_ratio_id + 1] = original[ + self._start_ratio_id : self._end_ratio_id + 1 + ] + # Allow extra ratio slices. + for s, e in self._ratio_other_slices: + logits[req_idx, s:e] = original[s:e] + # Force greedy: keep only the argmax. + max_id = logits[req_idx].argmax().item() + logits[req_idx].fill_(min_score) + logits[req_idx, max_id] = 0 + def make_empty_intermediate_tensors( self, batch_size: int, dtype: torch.dtype, device: torch.device ) -> IntermediateTensors: diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml new file mode 100644 index 0000000000..203b54f257 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml @@ -0,0 +1,44 @@ +# Stage config for HunyuanImage-3.0 Image-to-Text (I2T / image understanding). +# Single LLM stage: AR model reads image + text prompt, generates text output. + +stage_args: + - stage_id: 0 + stage_type: llm + runtime: + process: true + devices: "0,1,2,3" + max_batch_size: 1 + requires_multimodal_data: true + engine_args: + model_stage: AR + max_num_seqs: 1 + model_arch: HunyuanImage3ForCausalMM + worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.95 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 4 + pipeline_parallel_size: 1 + hf_overrides: + rope_parameters: + mrope_section: [0, 32, 32] + rope_type: default + is_comprehension: true + final_output: true + final_output_type: text + default_sampling_params: + temperature: 0.0 + top_p: 0.95 + top_k: 1024 + max_tokens: 2048 + stop_token_ids: [127957, 128026] # <|endoftext|>, + detokenize: True + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml new file mode 100644 index 0000000000..9f6adece0f --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml @@ -0,0 +1,78 @@ +# Stage config for HunyuanImage-3.0 Image+Text-to-Image (image editing). +# Stage 0: AR (HunyuanImage3ForConditionalGeneration) — reads (image, text), emits latent tokens +# Stage 1: Diffusion (HunyuanImage3Pipeline / DiT + VAE) — denoise + decode latents → image + +stage_args: + # Stage 0: AR Model + - stage_id: 0 + stage_type: llm + runtime: + process: true + devices: "0,1,2,3" + max_batch_size: 1 + requires_multimodal_data: true # AR needs the original image + engine_args: + model_stage: AR + model_arch: HunyuanImage3ForCausalMM + worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.95 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent # AR outputs latent for DiT + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 4 + pipeline_parallel_size: 1 + hf_overrides: + rope_parameters: + mrope_section: [0, 32, 32] + rope_type: default + is_comprehension: false # Generation task, not comprehension + final_output: false # AR is not the final output + default_sampling_params: + temperature: 0.6 + top_p: 0.95 + top_k: 1024 + max_tokens: 4096 + stop_token_ids: [127957] # <|endoftext|> + detokenize: false + + # Stage 1: Diffusion (DiT + VAE) + # Receives latents from AR stage, performs denoising + VAE decode + - stage_id: 1 + stage_type: diffusion + runtime: + process: true + devices: "4,5,6,7" + max_batch_size: 1 + requires_multimodal_data: true # May need condition images + engine_args: + model_stage: dit + model_arch: HunyuanImage3ForCausalMM + enforce_eager: true + trust_remote_code: true + distributed_executor_backend: "mp" + parallel_config: + tensor_parallel_size: 4 + enable_expert_parallel: true + omni_kv_config: + need_recv_cache: true + engine_input_source: [0] # Input from AR stage + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.hunyuan_image3.ar2diffusion + final_output: true + final_output_type: image + default_sampling_params: + num_inference_steps: 50 + guidance_scale: 2.5 + +# Top-level runtime config +runtime: + enabled: true + defaults: + window_size: -1 # Trigger downstream only after full upstream completion + max_inflight: 1 # Process serially within each stage + edges: + - from: 0 # AR → Diffusion + to: 1 + window_size: -1 diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe.yaml deleted file mode 100644 index 808b4619f7..0000000000 --- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe.yaml +++ /dev/null @@ -1,81 +0,0 @@ -# Stage config for running Hunyuan-Image3.0 for multi-stage omni runtime. -# Stage 0: AR Model (vLLM implementation) - -# The following config has been verified on 8x L40S-48G GPU. -modes: - - mode: text-to-image - stages: [1] - - mode: image-to-text - stages: [0] -stage_args: - - stage_id: 0 - stage_type: llm # Use llm stage type for AR stages - runtime: - process: true # Run this stage in a separate process - devices: "0,1,2,3,4,5,6,7" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device) - engine_args: - model_stage: AR - max_num_seqs: 1 - model_arch: HunyuanImage3ForCausalMM - worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.3 - enforce_eager: true # Now we only support eager mode - trust_remote_code: true - engine_output_type: latent - enable_prefix_caching: false - max_num_batched_tokens: 32768 - tensor_parallel_size: 8 - pipeline_parallel_size: 1 - hf_overrides: - rope_parameters: - mrope_section: [0, 32, 32] - rope_type: default - is_comprehension: true - final_output: true - final_output_type: text - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.1 - - stage_id: 1 - stage_type: diffusion - runtime: - process: true - devices: "0,1,2,3,4,5,6,7" - max_batch_size: 1 - engine_args: - model_stage: diffusion - enforce_eager: true - distributed_executor_backend: "mp" - vae_use_slicing: false - vae_use_tiling: false - cache_backend: null - cache_config: null - enable_cache_dit_summary: false - parallel_config: - pipeline_parallel_size: 1 - data_parallel_size: 1 - tensor_parallel_size: 8 - enable_expert_parallel: false - sequence_parallel_size: 1 - ulysses_degree: 1 - ring_degree: 1 - cfg_parallel_size: 1 - vae_patch_parallel_size: 1 - use_hsdp: false - hsdp_shard_size: -1 - hsdp_replicate_size: 1 - final_output: true - final_output_type: image - -# Top-level runtime config (concise): default windows and stage edges -runtime: - enabled: true - defaults: - window_size: -1 # Simplified: trigger downstream only after full upstream completion - max_inflight: 1 # Simplified: process serially within each stage diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml new file mode 100644 index 0000000000..60da8e0bc7 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml @@ -0,0 +1,45 @@ +# Stage config for HunyuanImage-3.0 Text-to-Text (T2T / pure text generation). +# Single LLM stage: AR model reads text prompt only, generates text output. +# Sampling params aligned with official generation_config.json. + +stage_args: + - stage_id: 0 + stage_type: llm + runtime: + process: true + devices: "0,1,2,3" + max_batch_size: 1 + requires_multimodal_data: false + engine_args: + model_stage: AR + max_num_seqs: 1 + model_arch: HunyuanImage3ForCausalMM + worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.95 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 4 + pipeline_parallel_size: 1 + hf_overrides: + rope_parameters: + mrope_section: [0, 32, 32] + rope_type: default + is_comprehension: true + final_output: true + final_output_type: text + default_sampling_params: + temperature: 0.0 + top_p: 0.95 + top_k: 1024 + max_tokens: 2048 + stop_token_ids: [127957, 128026] # <|endoftext|>, + detokenize: True + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 diff --git a/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py b/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py new file mode 100644 index 0000000000..89a7a28f6c --- /dev/null +++ b/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Stage input processor for HunyuanImage3: AR → Diffusion transition. + +In IT2I (image editing) mode: + - Stage 0 (AR) receives (image + edit instruction), generates CoT/latent tokens + - Stage 1 (DiT) receives the AR output + original image, denoises → edited image + +The ar2diffusion function bridges these two stages, following the same +signature pattern as glm_image.ar2diffusion. +""" + +from typing import Any + +import torch +from vllm.inputs import TextPrompt +from vllm.logger import init_logger + +from vllm_omni.inputs.data import OmniTokensPrompt + +logger = init_logger(__name__) + + +def ar2diffusion( + stage_list: list[Any], + engine_input_source: list[int], + prompt: OmniTokensPrompt | TextPrompt | list | None = None, + requires_multimodal_data: bool = False, +) -> list[dict[str, Any]]: + """Process AR stage outputs to create Diffusion stage inputs. + + Args: + stage_list: List of stage clients (set by orchestrator). + engine_input_source: List of source stage IDs (from YAML). + prompt: Original user prompt (may contain multimodal data). + requires_multimodal_data: Whether to forward multimodal data. + + Returns: + List of dicts, each consumable by the HunyuanImage3 diffusion pipeline. + """ + if not engine_input_source: + raise ValueError("engine_input_source cannot be empty") + + source_stage_id = engine_input_source[0] + if source_stage_id >= len(stage_list): + raise IndexError(f"Invalid source stage_id: {source_stage_id}") + + if stage_list[source_stage_id].engine_outputs is None: + raise RuntimeError(f"Stage {source_stage_id} has no outputs yet") + + ar_outputs = stage_list[source_stage_id].engine_outputs + diffusion_inputs = [] + + # Normalize prompt to list + if not isinstance(prompt, list): + prompt = [prompt] if prompt is not None else [{}] + + for i, ar_output in enumerate(ar_outputs): + output = ar_output.outputs[0] + generated_token_ids = output.token_ids + generated_text = getattr(output, "text", "") or "" + + # Get original prompt info + original_prompt = prompt[i] if i < len(prompt) else {} + if isinstance(original_prompt, dict): + pass + elif hasattr(original_prompt, "_asdict"): + original_prompt = original_prompt._asdict() + elif hasattr(original_prompt, "__dict__"): + original_prompt = vars(original_prompt) + else: + original_prompt = {} + + height = original_prompt.get("height", 1024) + width = original_prompt.get("width", 1024) + text_prompt = original_prompt.get("prompt", "") + + logger.info( + "[ar2diffusion] Request %d: AR generated %d tokens, text length=%d, target size=%dx%d", + i, + len(generated_token_ids), + len(generated_text), + height, + width, + ) + + token_tensor = torch.tensor(generated_token_ids, dtype=torch.long) + + diffusion_input: dict[str, Any] = { + "prompt": text_prompt, + "height": height, + "width": width, + "extra": { + "ar_token_ids": token_tensor, + "ar_generated_text": generated_text, + }, + } + + # Forward multimodal data (original image for IT2I conditioning) + mm_data = original_prompt.get("multi_modal_data") + if mm_data: + pil_image = mm_data.get("image") + if pil_image is None: + images = mm_data.get("images") + if images: + pil_image = images[0] if isinstance(images, list) else images + if pil_image is not None: + diffusion_input["pil_image"] = pil_image + + # Forward multimodal output from AR (if any) + if hasattr(ar_output, "multimodal_output") and ar_output.multimodal_output: + mm_output = ar_output.multimodal_output + if isinstance(mm_output, dict): + diffusion_input["extra"]["ar_multimodal_output"] = mm_output + + # Forward sampling params + for key in ["seed", "num_inference_steps", "guidance_scale", "negative_prompt"]: + if key in original_prompt: + diffusion_input[key] = original_prompt[key] + + diffusion_inputs.append(diffusion_input) + + return diffusion_inputs diff --git a/vllm_omni/patch.py b/vllm_omni/patch.py index eafff821a2..d4ab78f13a 100644 --- a/vllm_omni/patch.py +++ b/vllm_omni/patch.py @@ -1,6 +1,8 @@ import sys +from functools import cached_property from aenum import extend_enum +from vllm.config import ModelConfig as _OriginalModelConfig from vllm.inputs import TokensPrompt as _OriginalTokensPrompt from vllm.model_executor.layers.rotary_embedding import ( MRotaryEmbedding as _OriginalMRotaryEmbedding, @@ -17,6 +19,56 @@ from vllm_omni.model_executor.layers.rotary_embedding import OmniMRotaryEmbedding from vllm_omni.request import OmniRequest +# ============================================================================= +# Patch ModelConfig.is_mm_prefix_lm to support omni-specific models +# ============================================================================= +# WHY: HunyuanImage-3.0 requires bidirectional attention for image tokens +# (cond_token_attn_type: "joint_full" in config.json). vLLM gates this on +# is_mm_prefix_lm, which checks an internal MM_PREFIX_LM_MODELS list that +# does not include "hunyuan_image_3_moe" (the upstream HF model_type). +# +# WHY NOT model-level: is_mm_prefix_lm is checked in vLLM core (scheduler, +# attention backend selection) before model code runs — no model-level hook. +# +# SCOPE: Only affects model_type in _OMNI_MM_PREFIX_LM_MODELS (currently +# just "hunyuan_image_3_moe"). All other models fall through to the +# original vLLM implementation unchanged. +# +# FRAGILITY: Relies on is_mm_prefix_lm being a cached_property on +# ModelConfig. The __dict__ access + __set_name__ dance works around a +# pydantic dataclass issue in vllm 0.19.0+. If vLLM changes +# is_mm_prefix_lm to a regular method or removes it, this will break. +# +# TODO: Upstream a configurable MM_PREFIX_LM_MODELS or a model_config flag +# so this patch can be removed. +_OMNI_MM_PREFIX_LM_MODELS = ("hunyuan_image_3_moe",) +# Access via __dict__ to avoid triggering cached_property.__get__ which fails +# with "Cannot use cached_property instance without calling __set_name__" in +# pydantic dataclasses (vllm 0.19.0+). +_cp = _OriginalModelConfig.__dict__["is_mm_prefix_lm"] +_original_is_mm_prefix_lm = _cp.func if hasattr(_cp, "func") else _cp.fget + + +def _patched_is_mm_prefix_lm(self): + if _original_is_mm_prefix_lm(self): + return True + model_type = getattr(self.hf_config, "model_type", "") + return model_type in _OMNI_MM_PREFIX_LM_MODELS + + +_patched_cp = cached_property(_patched_is_mm_prefix_lm) +_patched_cp.__set_name__(_OriginalModelConfig, "is_mm_prefix_lm") +_OriginalModelConfig.is_mm_prefix_lm = _patched_cp + +# Sanity check: verify the patch is active. If vLLM changes the descriptor +# type or __set_name__ semantics, this will fail loudly at import time +# rather than silently falling back to unpatched behavior. +_installed = _OriginalModelConfig.__dict__.get("is_mm_prefix_lm") +assert _installed is _patched_cp, ( + "is_mm_prefix_lm patch failed to install — bidirectional attention " + "for HunyuanImage3 will not work. Check vLLM ModelConfig changes." +) + # ============================================================================= # Patch GlmImageTextConfig to expose mrope_section in rope_parameters # =============================================================================