From a27a6902c62ef493a67a308274958819d7709bf1 Mon Sep 17 00:00:00 2001 From: TaffyOfficial <2324465096@qq.com> Date: Mon, 13 Apr 2026 19:48:46 +0800 Subject: [PATCH 1/6] [Model] HunyuanImage-3.0-Instruct: model registration, custom sampler, stage configs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add custom sampler with logits processors for AR stage transitions. Ports official _StageTransitionLogitsProcessor and _ConditionalSliceVocabLogitsProcessor into sample() with prefer_model_sampler=True, enabling sampling-based decoding (temperature=0.6, top_k=1024, top_p=0.95) with correct think→recaption→ratio stage transitions. - hunyuan_image3.py: custom sample() with stage transition, ratio restriction, comprehension token blocking, ratio EOS forcing - patch.py: extend is_mm_prefix_lm for bidirectional attention on image tokens (hunyuan_image_3_moe model type). Use __dict__ access for cached_property compat with vllm 0.19.0+ pydantic dataclasses - Stage configs: hunyuan_image3_i2t.yaml (single LLM, TP4), hunyuan_image3_it2i.yaml (2-stage AR→DiT), hunyuan_image3_t2t.yaml - stage_input_processors/hunyuan_image3.py: ar2diffusion() bridge - Delete hunyuan_image3_moe.yaml (replaced by split per-task configs) - Update test_hunyuanimage3_text2img.py to use hunyuan_image3_t2i.yaml Signed-off-by: TaffyOfficial <2324465096@qq.com> --- .../test_hunyuanimage3_text2img.py | 2 +- .../models/hunyuan_image3/hunyuan_image3.py | 220 ++++++++++++++++++ .../stage_configs/hunyuan_image3_i2t.yaml | 44 ++++ .../stage_configs/hunyuan_image3_it2i.yaml | 82 +++++++ .../stage_configs/hunyuan_image3_moe.yaml | 81 ------- .../stage_configs/hunyuan_image3_t2t.yaml | 45 ++++ .../stage_input_processors/hunyuan_image3.py | 121 ++++++++++ vllm_omni/patch.py | 34 +++ 8 files changed, 547 insertions(+), 82 deletions(-) create mode 100644 vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml create mode 100644 vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml delete mode 100644 vllm_omni/model_executor/stage_configs/hunyuan_image3_moe.yaml create mode 100644 vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml create mode 100644 vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py diff --git a/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py b/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py index 6898763e40..ec4f4693d7 100644 --- a/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py +++ b/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py @@ -17,7 +17,7 @@ MODEL_NAME = "tencent/HunyuanImage-3.0" LOCAL_CLIP_PATH = "openai/clip-vit-base-patch32" REPO_ROOT = Path(__file__).resolve().parents[3] -STAGE_CONFIG_PATH = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "hunyuan_image3_moe.yaml" +STAGE_CONFIG_PATH = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "hunyuan_image3_t2i.yaml" pytestmark = [pytest.mark.advanced_model, pytest.mark.diffusion] diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py index 5c280ddcf4..bbf921a5ee 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py @@ -77,7 +77,9 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils.tensor_schema import TensorSchema +from vllm.v1.outputs import SamplerOutput 
from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.sampler import Sampler from vllm_omni.model_executor.models.hunyuan_image3.autoencoder_kl_3d import AutoencoderKLConv3D from vllm_omni.model_executor.models.hunyuan_image3.siglip2 import LightProjector, Siglip2VisionTransformer @@ -175,8 +177,11 @@ def contains_unexpected_keyword(name, keywords): return True return False + skipped_unexpected: set[str] = set() + for name, loaded_weight in weights: if contains_unexpected_keyword(name, unexpected_keywords): + skipped_unexpected.add(name) continue if "rotary_emb.inv_freq" in name: @@ -362,6 +367,17 @@ def contains_unexpected_keyword(name, keywords): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) + + if skipped_unexpected: + logger.warning_once( + "Skipped %d weights matching unexpected_keywords " + "(e.g. vae, vision_model, patch_embed, timestep_emb). " + "If upstream renamed components, these may be silently " + "lost. Skipped names: %s", + len(skipped_unexpected), + sorted(skipped_unexpected)[:10], + ) + return loaded_params @@ -1149,6 +1165,8 @@ class HunyuanImage3ForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo HunyuanImage3Inputs: TypeAlias = HunyuanImage3PixelInputs + prefer_model_sampler = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1199,6 +1217,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.lm_head = PPMissingLayer() + # --- AR-stage components --- + # These are needed for image encoding in the AR stage. + # If a future text-only stage is added, gate on vllm_config.model_config.model_stage. + # vae self.vae = AutoencoderKLConv3D.from_config(config.vae) self.patch_embed = UNetDown( @@ -1226,6 +1248,69 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._mrope_joint_img_sep_token_id = tokenizer.convert_tokens_to_ids("") self._mrope_max_num_patches = config.vit_processor.get("max_num_patches", 729) + # Special token IDs for logits processors (stage transitions). + # These mirror the official tokenization_hunyuan_image_3.py setup. + self._end_of_think_id = tokenizer.convert_tokens_to_ids("") + self._recaption_id = tokenizer.convert_tokens_to_ids("") + self._end_of_recaption_id = tokenizer.convert_tokens_to_ids("") + self._answer_id = tokenizer.convert_tokens_to_ids("") + self._end_of_answer_id = tokenizer.convert_tokens_to_ids("") + image_base_size = getattr(config, "image_base_size", 1024) + self._size_token_id = tokenizer.convert_tokens_to_ids( + f"" + ) + self._start_ratio_id = tokenizer.convert_tokens_to_ids("") + self._end_ratio_id = tokenizer.convert_tokens_to_ids("") + ratio_33 = tokenizer.convert_tokens_to_ids("") + ratio_36 = tokenizer.convert_tokens_to_ids("") + self._ratio_other_slices = [(ratio_33, ratio_36 + 1)] + # Build the full set of ratio token IDs for use as stop tokens. + self._all_ratio_ids = set( + range(self._start_ratio_id, self._end_ratio_id + 1) + ) + for s, e in self._ratio_other_slices: + self._all_ratio_ids.update(range(s, e)) + + # Per-request state for stage-transition logits processor. + # Maps request index → (pending_tokens list, completed set). + self._transition_state: dict[int, tuple[list[int], set[int]]] = {} + + # Determine mode: comprehension (I2T/T2T) vs generation (IT2I/T2I). 
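+        # Illustration (values from the stage configs in this series): the
+        # i2t/t2t YAMLs leave engine_output_type unset, which resolves to
+        # comprehension here; the it2i AR stage sets engine_output_type:
+        # latent, which selects generation.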
+ engine_output_type = getattr( + vllm_config.model_config, "engine_output_type", None + ) + self._is_comprehension = engine_output_type in (None, "text") + + # For comprehension mode, block generation-specific special tokens. + self._blocked_token_ids: set[int] = set() + if self._is_comprehension: + self._blocked_token_ids.update([ + self._mrope_boi_token_id, # + self._mrope_eoi_token_id, # + self._size_token_id, # + self._answer_id, # + self._end_of_answer_id, # + ]) + self._blocked_token_ids.update(self._all_ratio_ids) + + # For generation mode, build stage transition map. + # Official logic: → [], + # → [, , ] + # After , restrict vocab to ratio tokens only. + self._stage_transitions: dict[int, list[int]] = {} + if not self._is_comprehension: + self._stage_transitions[self._end_of_think_id] = [ + self._recaption_id, + ] + self._stage_transitions[self._end_of_recaption_id] = [ + self._answer_id, + self._mrope_boi_token_id, + self._size_token_id, + ] + + self._sampler: Sampler | None = None + self._eos_token_id: int = 127957 # <|endoftext|> + self._replace_rotary_embeddings() def _replace_rotary_embeddings(self): @@ -1257,6 +1342,12 @@ def _replace_rotary_embeddings(self): head_dim, rope_theta, ) + if replaced == 0: + raise RuntimeError( + "HunyuanImage3: _replace_rotary_embeddings replaced 0 layers. " + "The custom interleaved 2D mRoPE is not active — model outputs " + "will be incorrect. Check that model.layers[*].self_attn.rotary_emb exists." + ) def _parse_and_validate_image_input( self, @@ -1274,6 +1365,10 @@ def _parse_and_validate_image_input( if vit_pixel_values is None or vae_pixel_values is None: return None + # Handle empty batch (e.g., during profiling with 0 images / T2T mode) + if vit_pixel_values.numel() == 0 or vae_pixel_values.numel() == 0: + return None + return HunyuanImage3PixelInputs( type="pixel_values", pixel_values={ @@ -1472,6 +1567,131 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits + # ------------------------------------------------------------------ + # Custom sampler — applies HunyuanImage3-specific logits processors + # before the standard sampling step. + # + # Comprehension (I2T / T2T): + # Block generation-specific special tokens so sampling can't + # accidentally produce , , ratio tokens, etc. + # + # Generation (IT2I / T2I think): + # 1. _StageTransitionLogitsProcessor — force token sequences at + # transition boundaries ( → , etc.) + # 2. _ConditionalSliceVocabLogitsProcessor — after , + # restrict vocab to ratio tokens only (greedy). + # ------------------------------------------------------------------ + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput | None: + if logits is None or logits.numel() == 0: + return None + + if self._sampler is None: + self._sampler = Sampler() + + min_score = torch.finfo(logits.dtype).min + + for req_idx in range(logits.shape[0]): + decoded_tokens: list[int] = ( + sampling_metadata.output_token_ids[req_idx] + if req_idx < len(sampling_metadata.output_token_ids) + else [] + ) + last_token = decoded_tokens[-1] if decoded_tokens else -1 + + if self._is_comprehension: + # Comprehension: mask out generation-specific tokens. + for tid in self._blocked_token_ids: + logits[req_idx, tid] = min_score + else: + # Generation: apply stage-transition logic. + self._apply_stage_transition( + logits, req_idx, last_token, min_score + ) + # After size token → restrict to ratio tokens. 
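+                # Worked example: _apply_stage_transition above forces
+                # [_recaption_id] after _end_of_think_id, and [_answer_id,
+                # _mrope_boi_token_id, _size_token_id] after
+                # _end_of_recaption_id; the branch below then narrows the
+                # vocab to ratio tokens once the size token has been emitted.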
+ if last_token == self._size_token_id: + self._apply_ratio_restriction( + logits, req_idx, min_score + ) + # After ratio token → force EOS (official uses ratio as + # final_stop_tokens; vLLM stop_token_ids may not include + # all ratio IDs, so we force EOS here). + elif last_token in self._all_ratio_ids: + logits[req_idx].fill_(min_score) + logits[req_idx, self._eos_token_id] = 0 + + return self._sampler( + logits=logits, sampling_metadata=sampling_metadata + ) + + def _apply_stage_transition( + self, + logits: torch.Tensor, + req_idx: int, + last_token: int, + min_score: float, + ) -> None: + """Port of official _StageTransitionLogitsProcessor.__call__.""" + state = self._transition_state.get(req_idx) + if state is None: + state = ([], set()) # (pending_tokens, completed_transitions) + self._transition_state[req_idx] = state + pending, completed = state + + # Consume pending token if last output matches head of queue. + if pending and last_token == pending[0]: + pending.pop(0) + + # If pending tokens remain, force the next one. + if pending: + logits[req_idx].fill_(min_score) + logits[req_idx, pending[0]] = 0 + return + + # Check if last_token triggers a new transition. + if ( + last_token in self._stage_transitions + and last_token not in completed + ): + completed.add(last_token) + next_tokens = self._stage_transitions[last_token] + if next_tokens: + pending.extend(next_tokens) + logits[req_idx].fill_(min_score) + logits[req_idx, pending[0]] = 0 + + def _apply_ratio_restriction( + self, + logits: torch.Tensor, + req_idx: int, + min_score: float, + ) -> None: + """Port of official _ConditionalSliceVocabLogitsProcessor.__call__. + + After the size token, only allow ratio tokens and pick greedily. + """ + original = logits[req_idx].clone() + logits[req_idx].fill_(min_score) + # Allow primary ratio range. + logits[req_idx, self._start_ratio_id : self._end_ratio_id + 1] = ( + original[self._start_ratio_id : self._end_ratio_id + 1] + ) + # Allow extra ratio slices. + for s, e in self._ratio_other_slices: + logits[req_idx, s:e] = original[s:e] + # Force greedy: keep only the argmax. + max_id = logits[req_idx].argmax().item() + logits[req_idx].fill_(min_score) + logits[req_idx, max_id] = 0 + + def _clear_transition_state(self, req_idx: int) -> None: + """Clean up per-request transition state when request finishes.""" + self._transition_state.pop(req_idx, None) + def make_empty_intermediate_tensors( self, batch_size: int, dtype: torch.dtype, device: torch.device ) -> IntermediateTensors: diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml new file mode 100644 index 0000000000..3739f334d2 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml @@ -0,0 +1,44 @@ +# Stage config for HunyuanImage-3.0 Image-to-Text (I2T / image understanding). +# Single LLM stage: AR model reads image + text prompt, generates text output. 
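+#
+# Example invocation (a sketch — the constructor and its argument names are
+# assumptions; only Omni.generate() is referenced elsewhere in this series):
+#   from vllm_omni import Omni
+#   omni = Omni(model="tencent/HunyuanImage-3.0", stage_config=<this file>)
+#   out = omni.generate({"prompt": "Describe the content of the picture.",
+#                        "multi_modal_data": {"image": pil_image}})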
+ +stage_args: + - stage_id: 0 + stage_type: llm + runtime: + process: true + devices: "2,3,4,5" + max_batch_size: 1 + requires_multimodal_data: true + engine_args: + model_stage: AR + max_num_seqs: 1 + model_arch: HunyuanImage3ForCausalMM + worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.95 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 4 + pipeline_parallel_size: 1 + hf_overrides: + rope_parameters: + mrope_section: [0, 32, 32] + rope_type: default + is_comprehension: true + final_output: true + final_output_type: text + default_sampling_params: + temperature: 0.6 + top_p: 0.95 + top_k: 1024 + max_tokens: 2048 + stop_token_ids: [127957] # <|endoftext|> + detokenize: True + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml new file mode 100644 index 0000000000..d8598d6f4c --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml @@ -0,0 +1,82 @@ +# Stage config for HunyuanImage-3.0 Image+Text-to-Image (image editing). +# Stage 0: AR (HunyuanImage3ForConditionalGeneration) — reads (image, text), emits latent tokens +# Stage 1: Diffusion (HunyuanImage3Pipeline / DiT + VAE) — denoise + decode latents → image + +stage_args: + # Stage 0: AR Model + - stage_id: 0 + stage_type: llm + runtime: + process: true + devices: "0,1,2,3" + max_batch_size: 1 + requires_multimodal_data: true # AR needs the original image + engine_args: + model_stage: AR + model_arch: HunyuanImage3ForCausalMM + worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.95 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent # AR outputs latent for DiT + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 4 + pipeline_parallel_size: 1 + hf_overrides: + rope_parameters: + mrope_section: [0, 32, 32] + rope_type: default + is_comprehension: false # Generation task, not comprehension + final_output: false # AR is not the final output + default_sampling_params: + temperature: 0.6 + top_p: 0.95 + top_k: 1024 + max_tokens: 4096 + stop_token_ids: [127957] # <|endoftext|> + detokenize: false + + # Stage 1: Diffusion (DiT + VAE) + # Receives latents from AR stage, performs denoising + VAE decode + - stage_id: 1 + stage_type: diffusion + runtime: + process: true + devices: "4,5,6,7" + max_batch_size: 1 + requires_multimodal_data: true # May need condition images + engine_args: + model_stage: dit + model_arch: HunyuanImage3ForCausalMM + gpu_memory_utilization: 0.65 + enforce_eager: true + trust_remote_code: true + engine_output_type: image + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + parallel_config: + tensor_parallel_size: 4 + enable_expert_parallel: true + omni_kv_config: + need_recv_cache: true + engine_input_source: [0] # Input from AR stage + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.hunyuan_image3.ar2diffusion + final_output: true + final_output_type: image + default_sampling_params: + num_inference_steps: 50 + guidance_scale: 2.5 + +# Top-level runtime config +runtime: + enabled: true + defaults: + window_size: 
-1 # Trigger downstream only after full upstream completion + max_inflight: 1 # Process serially within each stage + edges: + - from: 0 # AR → Diffusion + to: 1 + window_size: -1 diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe.yaml deleted file mode 100644 index 808b4619f7..0000000000 --- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe.yaml +++ /dev/null @@ -1,81 +0,0 @@ -# Stage config for running Hunyuan-Image3.0 for multi-stage omni runtime. -# Stage 0: AR Model (vLLM implementation) - -# The following config has been verified on 8x L40S-48G GPU. -modes: - - mode: text-to-image - stages: [1] - - mode: image-to-text - stages: [0] -stage_args: - - stage_id: 0 - stage_type: llm # Use llm stage type for AR stages - runtime: - process: true # Run this stage in a separate process - devices: "0,1,2,3,4,5,6,7" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device) - engine_args: - model_stage: AR - max_num_seqs: 1 - model_arch: HunyuanImage3ForCausalMM - worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.3 - enforce_eager: true # Now we only support eager mode - trust_remote_code: true - engine_output_type: latent - enable_prefix_caching: false - max_num_batched_tokens: 32768 - tensor_parallel_size: 8 - pipeline_parallel_size: 1 - hf_overrides: - rope_parameters: - mrope_section: [0, 32, 32] - rope_type: default - is_comprehension: true - final_output: true - final_output_type: text - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.1 - - stage_id: 1 - stage_type: diffusion - runtime: - process: true - devices: "0,1,2,3,4,5,6,7" - max_batch_size: 1 - engine_args: - model_stage: diffusion - enforce_eager: true - distributed_executor_backend: "mp" - vae_use_slicing: false - vae_use_tiling: false - cache_backend: null - cache_config: null - enable_cache_dit_summary: false - parallel_config: - pipeline_parallel_size: 1 - data_parallel_size: 1 - tensor_parallel_size: 8 - enable_expert_parallel: false - sequence_parallel_size: 1 - ulysses_degree: 1 - ring_degree: 1 - cfg_parallel_size: 1 - vae_patch_parallel_size: 1 - use_hsdp: false - hsdp_shard_size: -1 - hsdp_replicate_size: 1 - final_output: true - final_output_type: image - -# Top-level runtime config (concise): default windows and stage edges -runtime: - enabled: true - defaults: - window_size: -1 # Simplified: trigger downstream only after full upstream completion - max_inflight: 1 # Simplified: process serially within each stage diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml new file mode 100644 index 0000000000..83acc87fab --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml @@ -0,0 +1,45 @@ +# Stage config for HunyuanImage-3.0 Text-to-Text (T2T / pure text generation). +# Single LLM stage: AR model reads text prompt only, generates text output. +# Sampling params aligned with official generation_config.json. 
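+#
+# Decoding behavior implied by the params below: with temperature=0.6,
+# top_p=0.95 and top_k=1024 the AR stage samples stochastically; a request
+# ends at <|endoftext|> (127957, via stop_token_ids) or at max_tokens=2048.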
+ +stage_args: + - stage_id: 0 + stage_type: llm + runtime: + process: true + devices: "2,3,4,5" + max_batch_size: 1 + requires_multimodal_data: false + engine_args: + model_stage: AR + max_num_seqs: 1 + model_arch: HunyuanImage3ForCausalMM + worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.95 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 4 + pipeline_parallel_size: 1 + hf_overrides: + rope_parameters: + mrope_section: [0, 32, 32] + rope_type: default + is_comprehension: true + final_output: true + final_output_type: text + default_sampling_params: + temperature: 0.6 + top_p: 0.95 + top_k: 1024 + max_tokens: 2048 + stop_token_ids: [127957] # <|endoftext|> + detokenize: True + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 diff --git a/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py b/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py new file mode 100644 index 0000000000..e7a7569727 --- /dev/null +++ b/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Stage input processor for HunyuanImage3: AR → Diffusion transition. + +In IT2I (image editing) mode: + - Stage 0 (AR) receives (image + edit instruction), generates CoT/latent tokens + - Stage 1 (DiT) receives the AR output + original image, denoises → edited image + +The ar2diffusion function bridges these two stages, following the same +signature pattern as glm_image.ar2diffusion. +""" + +from typing import Any + +import torch +from vllm.inputs import TextPrompt +from vllm.logger import init_logger + +from vllm_omni.inputs.data import OmniTokensPrompt + +logger = init_logger(__name__) + + +def ar2diffusion( + stage_list: list[Any], + engine_input_source: list[int], + prompt: OmniTokensPrompt | TextPrompt | list | None = None, + requires_multimodal_data: bool = False, +) -> list[dict[str, Any]]: + """Process AR stage outputs to create Diffusion stage inputs. + + Args: + stage_list: List of stage clients (set by orchestrator). + engine_input_source: List of source stage IDs (from YAML). + prompt: Original user prompt (may contain multimodal data). + requires_multimodal_data: Whether to forward multimodal data. + + Returns: + List of dicts, each consumable by the HunyuanImage3 diffusion pipeline. 
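+
+    Example returned item (a shape sketch; keys mirror the construction
+    below, values are illustrative):
+        {"prompt": "a cat", "height": 1024, "width": 1024,
+         "extra": {"ar_token_ids": tensor([...]), "ar_generated_text": "..."}}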
+ """ + if not engine_input_source: + raise ValueError("engine_input_source cannot be empty") + + source_stage_id = engine_input_source[0] + if source_stage_id >= len(stage_list): + raise IndexError(f"Invalid source stage_id: {source_stage_id}") + + if stage_list[source_stage_id].engine_outputs is None: + raise RuntimeError(f"Stage {source_stage_id} has no outputs yet") + + ar_outputs = stage_list[source_stage_id].engine_outputs + diffusion_inputs = [] + + # Normalize prompt to list + if not isinstance(prompt, list): + prompt = [prompt] if prompt is not None else [{}] + + for i, ar_output in enumerate(ar_outputs): + output = ar_output.outputs[0] + generated_token_ids = output.token_ids + generated_text = getattr(output, "text", "") or "" + + # Get original prompt info + original_prompt = prompt[i] if i < len(prompt) else {} + if isinstance(original_prompt, dict): + pass + elif hasattr(original_prompt, "_asdict"): + original_prompt = original_prompt._asdict() + elif hasattr(original_prompt, "__dict__"): + original_prompt = vars(original_prompt) + else: + original_prompt = {} + + height = original_prompt.get("height", 1024) + width = original_prompt.get("width", 1024) + text_prompt = original_prompt.get("prompt", "") + + logger.info( + "[ar2diffusion] Request %d: AR generated %d tokens, " + "text length=%d, target size=%dx%d", + i, len(generated_token_ids), len(generated_text), height, width, + ) + + token_tensor = torch.tensor(generated_token_ids, dtype=torch.long) + + diffusion_input: dict[str, Any] = { + "prompt": text_prompt, + "height": height, + "width": width, + "extra": { + "ar_token_ids": token_tensor, + "ar_generated_text": generated_text, + }, + } + + # Forward multimodal data (original image for IT2I conditioning) + mm_data = original_prompt.get("multi_modal_data") + if mm_data: + pil_image = mm_data.get("image") + if pil_image is None: + images = mm_data.get("images") + if images: + pil_image = images[0] if isinstance(images, list) else images + if pil_image is not None: + diffusion_input["pil_image"] = pil_image + + # Forward multimodal output from AR (if any) + if hasattr(ar_output, "multimodal_output") and ar_output.multimodal_output: + mm_output = ar_output.multimodal_output + if isinstance(mm_output, dict): + diffusion_input["extra"]["ar_multimodal_output"] = mm_output + + # Forward sampling params + for key in ["seed", "num_inference_steps", "guidance_scale", + "negative_prompt"]: + if key in original_prompt: + diffusion_input[key] = original_prompt[key] + + diffusion_inputs.append(diffusion_input) + + return diffusion_inputs diff --git a/vllm_omni/patch.py b/vllm_omni/patch.py index eafff821a2..2c96f536d8 100644 --- a/vllm_omni/patch.py +++ b/vllm_omni/patch.py @@ -1,6 +1,8 @@ import sys +from functools import cached_property from aenum import extend_enum +from vllm.config import ModelConfig as _OriginalModelConfig from vllm.inputs import TokensPrompt as _OriginalTokensPrompt from vllm.model_executor.layers.rotary_embedding import ( MRotaryEmbedding as _OriginalMRotaryEmbedding, @@ -17,6 +19,38 @@ from vllm_omni.model_executor.layers.rotary_embedding import OmniMRotaryEmbedding from vllm_omni.request import OmniRequest +# ============================================================================= +# Patch ModelConfig.is_mm_prefix_lm to support omni-specific models +# ============================================================================= +# HunyuanImage-3.0 uses bidirectional attention for image token positions +# (cond_token_attn_type: "joint_full" in 
config.json), but its model_type
+# "hunyuan_image_3_moe" is not in vllm's built-in MM_PREFIX_LM_MODELS list.
+# This patch extends the check to include omni-specific models.
+_OMNI_MM_PREFIX_LM_MODELS = ("hunyuan_image_3_moe",)
+# Access via __dict__ to avoid triggering cached_property.__get__ which fails
+# with "Cannot use cached_property instance without calling __set_name__" in
+# pydantic dataclasses (vllm 0.19.0+).
+_cp = _OriginalModelConfig.__dict__["is_mm_prefix_lm"]
+_original_is_mm_prefix_lm = _cp.func if hasattr(_cp, "func") else _cp.fget
+
+
+def _patched_is_mm_prefix_lm(self):
+    if _original_is_mm_prefix_lm(self):
+        return True
+    model_type = getattr(self.hf_config, "model_type", "")
+    return model_type in _OMNI_MM_PREFIX_LM_MODELS
+
+
+_patched_cp = cached_property(_patched_is_mm_prefix_lm)
+_patched_cp.__set_name__(_OriginalModelConfig, "is_mm_prefix_lm")
+_OriginalModelConfig.is_mm_prefix_lm = _patched_cp
+
+# Also fix the original cached_property if __set_name__ was never called (vllm 0.19.0+)
+_orig_cp = _OriginalModelConfig.__dict__.get("is_mm_prefix_lm")
+if _orig_cp is not _patched_cp:
+    # Our assignment above should have replaced it, but just in case
+    pass
+
 # =============================================================================
 # Patch GlmImageTextConfig to expose mrope_section in rope_parameters
 # =============================================================================

From 2b7e4ab0e6bb7cabf25bf31e3f50e3abfa8733d8 Mon Sep 17 00:00:00 2001
From: TaffyOfficial <2324465096@qq.com>
Date: Tue, 14 Apr 2026 16:49:08 +0800
Subject: [PATCH 2/6] fix(hunyuan_image3): align I2T with official HF model output

- build_prompt: add instruct template (\n\nUser: ...\n\nAssistant: )
- hunyuan_image3.py: unblock <answer>/</answer> tokens so the model can
  follow its natural generation pattern
- i2t/t2t YAML: temperature=0.0 for greedy decoding, add </answer> (128026)
  to stop_token_ids

Verified on 4xH800: input_ids match the official baseline exactly
(6364 tokens); greedy output is self-consistent within the same process.

Co-Authored-By: Claude Opus 4.6 (1M context)
Signed-off-by: TaffyOfficial <2324465096@qq.com>
---
 .../hunyuan_image3/prompt_utils.py            | 88 +++++++++++++++++
 .../models/hunyuan_image3/hunyuan_image3.py   | 58 +++++------
 .../stage_configs/hunyuan_image3_i2t.yaml     |  4 +-
 .../stage_configs/hunyuan_image3_t2t.yaml     |  4 +-
 .../stage_input_processors/hunyuan_image3.py  | 12 +--
 5 files changed, 121 insertions(+), 45 deletions(-)
 create mode 100644 examples/offline_inference/hunyuan_image3/prompt_utils.py

diff --git a/examples/offline_inference/hunyuan_image3/prompt_utils.py b/examples/offline_inference/hunyuan_image3/prompt_utils.py
new file mode 100644
index 0000000000..a5ef8e1536
--- /dev/null
+++ b/examples/offline_inference/hunyuan_image3/prompt_utils.py
@@ -0,0 +1,88 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Prompt construction utilities for HunyuanImage-3.0-Instruct examples.
+
+Wraps system_prompt.get_system_prompt() with task-aware presets so that
+examples and tests don't need to manually concatenate system prompts,
+image placeholders, and mode trigger tags.
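+
+The presets fall into three families: comprehension (t2t, i2t), image
+editing (it2i_think, it2i_recaption), and text-to-image (t2i_think,
+t2i_recaption, t2i_vanilla); see _TASK_PRESETS below.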
+ +Usage: + from prompt_utils import build_prompt + + # IT2I (image editing, think+recaption mode) + prompt = build_prompt("Make the petals neon pink", task="it2i_think") + + # I2T (image understanding) + prompt = build_prompt("Describe the content of the picture.", task="i2t") +""" + +from __future__ import annotations + +from vllm_omni.diffusion.models.hunyuan_image3.system_prompt import ( + get_system_prompt, +) + +# task → (sys_type, bot_task, trigger_tag) +# trigger_tag: "", "", or None +_TASK_PRESETS: dict[str, tuple[str, str | None, str | None]] = { + # Pure text generation (text → text, no image) + "t2t": ("en_unified", None, None), + # Image understanding (image → text) + "i2t": ("en_unified", None, None), + # Image editing (image+text → image), think+recaption mode + "it2i_think": ("en_unified", "think", ""), + # Image editing, recaption-only mode + "it2i_recaption": ("en_unified", "recaption", ""), + # Text-to-image, think mode + "t2i_think": ("en_unified", "think", ""), + # Text-to-image, recaption mode + "t2i_recaption": ("en_unified", "recaption", ""), + # Text-to-image, vanilla (no CoT) + "t2i_vanilla": ("en_vanilla", "image", None), +} + + +def build_prompt( + user_prompt: str, + task: str = "it2i_think", + sys_type: str | None = None, + custom_system_prompt: str | None = None, +) -> str: + """Build a complete HunyuanImage-3.0 prompt with auto-selected system + prompt and mode trigger tags. + + Args: + user_prompt: The user's raw instruction or question. + task: One of the preset task keys (see _TASK_PRESETS). + sys_type: Override the preset's sys_type for get_system_prompt(). + custom_system_prompt: Custom system prompt text (used when + sys_type="custom"). + + Returns: + Fully formatted prompt string ready for Omni.generate(). + """ + if task not in _TASK_PRESETS: + raise ValueError(f"Unknown task {task!r}. Choose from: {sorted(_TASK_PRESETS)}") + + preset_sys_type, preset_bot_task, trigger_tag = _TASK_PRESETS[task] + effective_sys_type = sys_type or preset_sys_type + + system_prompt = get_system_prompt(effective_sys_type, preset_bot_task, custom_system_prompt) + sys_text = system_prompt.strip() if system_prompt else "" + + has_image_input = task.startswith("i2t") or task.startswith("it2i") + + parts = ["<|startoftext|>"] + if sys_text: + parts.append(sys_text) + # Instruct conversation template: \n\nUser: ... 
\n\nAssistant: + parts.append("\n\nUser: ") + if has_image_input: + parts.append("") + parts.append(user_prompt) + parts.append("\n\nAssistant: ") + if trigger_tag: + parts.append(trigger_tag) + + return "".join(parts) diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py index bbf921a5ee..56df4c83a6 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py @@ -1256,18 +1256,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._answer_id = tokenizer.convert_tokens_to_ids("") self._end_of_answer_id = tokenizer.convert_tokens_to_ids("") image_base_size = getattr(config, "image_base_size", 1024) - self._size_token_id = tokenizer.convert_tokens_to_ids( - f"" - ) + self._size_token_id = tokenizer.convert_tokens_to_ids(f"") self._start_ratio_id = tokenizer.convert_tokens_to_ids("") self._end_ratio_id = tokenizer.convert_tokens_to_ids("") ratio_33 = tokenizer.convert_tokens_to_ids("") ratio_36 = tokenizer.convert_tokens_to_ids("") self._ratio_other_slices = [(ratio_33, ratio_36 + 1)] # Build the full set of ratio token IDs for use as stop tokens. - self._all_ratio_ids = set( - range(self._start_ratio_id, self._end_ratio_id + 1) - ) + self._all_ratio_ids = set(range(self._start_ratio_id, self._end_ratio_id + 1)) for s, e in self._ratio_other_slices: self._all_ratio_ids.update(range(s, e)) @@ -1276,21 +1272,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._transition_state: dict[int, tuple[list[int], set[int]]] = {} # Determine mode: comprehension (I2T/T2T) vs generation (IT2I/T2I). - engine_output_type = getattr( - vllm_config.model_config, "engine_output_type", None - ) + engine_output_type = getattr(vllm_config.model_config, "engine_output_type", None) self._is_comprehension = engine_output_type in (None, "text") - # For comprehension mode, block generation-specific special tokens. + # For comprehension mode, block image generation tokens but allow + # text structure tokens (, , etc.) so the model can + # follow its natural generation pattern. Stop tokens in YAML will + # terminate at or EOS. self._blocked_token_ids: set[int] = set() if self._is_comprehension: - self._blocked_token_ids.update([ - self._mrope_boi_token_id, # - self._mrope_eoi_token_id, # - self._size_token_id, # - self._answer_id, # - self._end_of_answer_id, # - ]) + self._blocked_token_ids.update( + [ + self._mrope_boi_token_id, # + self._mrope_eoi_token_id, # + self._size_token_id, # + ] + ) self._blocked_token_ids.update(self._all_ratio_ids) # For generation mode, build stage transition map. @@ -1597,9 +1594,7 @@ def sample( for req_idx in range(logits.shape[0]): decoded_tokens: list[int] = ( - sampling_metadata.output_token_ids[req_idx] - if req_idx < len(sampling_metadata.output_token_ids) - else [] + sampling_metadata.output_token_ids[req_idx] if req_idx < len(sampling_metadata.output_token_ids) else [] ) last_token = decoded_tokens[-1] if decoded_tokens else -1 @@ -1609,14 +1604,10 @@ def sample( logits[req_idx, tid] = min_score else: # Generation: apply stage-transition logic. - self._apply_stage_transition( - logits, req_idx, last_token, min_score - ) + self._apply_stage_transition(logits, req_idx, last_token, min_score) # After size token → restrict to ratio tokens. 
if last_token == self._size_token_id: - self._apply_ratio_restriction( - logits, req_idx, min_score - ) + self._apply_ratio_restriction(logits, req_idx, min_score) # After ratio token → force EOS (official uses ratio as # final_stop_tokens; vLLM stop_token_ids may not include # all ratio IDs, so we force EOS here). @@ -1624,9 +1615,7 @@ def sample( logits[req_idx].fill_(min_score) logits[req_idx, self._eos_token_id] = 0 - return self._sampler( - logits=logits, sampling_metadata=sampling_metadata - ) + return self._sampler(logits=logits, sampling_metadata=sampling_metadata) def _apply_stage_transition( self, @@ -1653,10 +1642,7 @@ def _apply_stage_transition( return # Check if last_token triggers a new transition. - if ( - last_token in self._stage_transitions - and last_token not in completed - ): + if last_token in self._stage_transitions and last_token not in completed: completed.add(last_token) next_tokens = self._stage_transitions[last_token] if next_tokens: @@ -1677,9 +1663,9 @@ def _apply_ratio_restriction( original = logits[req_idx].clone() logits[req_idx].fill_(min_score) # Allow primary ratio range. - logits[req_idx, self._start_ratio_id : self._end_ratio_id + 1] = ( - original[self._start_ratio_id : self._end_ratio_id + 1] - ) + logits[req_idx, self._start_ratio_id : self._end_ratio_id + 1] = original[ + self._start_ratio_id : self._end_ratio_id + 1 + ] # Allow extra ratio slices. for s, e in self._ratio_other_slices: logits[req_idx, s:e] = original[s:e] diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml index 3739f334d2..0e2425f886 100644 --- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml +++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml @@ -30,11 +30,11 @@ stage_args: final_output: true final_output_type: text default_sampling_params: - temperature: 0.6 + temperature: 0.0 top_p: 0.95 top_k: 1024 max_tokens: 2048 - stop_token_ids: [127957] # <|endoftext|> + stop_token_ids: [127957, 128026] # <|endoftext|>, detokenize: True runtime: diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml index 83acc87fab..af55f0688d 100644 --- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml +++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml @@ -31,11 +31,11 @@ stage_args: final_output: true final_output_type: text default_sampling_params: - temperature: 0.6 + temperature: 0.0 top_p: 0.95 top_k: 1024 max_tokens: 2048 - stop_token_ids: [127957] # <|endoftext|> + stop_token_ids: [127957, 128026] # <|endoftext|>, detokenize: True runtime: diff --git a/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py b/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py index e7a7569727..89a7a28f6c 100644 --- a/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py +++ b/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py @@ -76,9 +76,12 @@ def ar2diffusion( text_prompt = original_prompt.get("prompt", "") logger.info( - "[ar2diffusion] Request %d: AR generated %d tokens, " - "text length=%d, target size=%dx%d", - i, len(generated_token_ids), len(generated_text), height, width, + "[ar2diffusion] Request %d: AR generated %d tokens, text length=%d, target size=%dx%d", + i, + len(generated_token_ids), + len(generated_text), + height, + width, ) token_tensor = torch.tensor(generated_token_ids, dtype=torch.long) @@ -111,8 
+114,7 @@ def ar2diffusion( diffusion_input["extra"]["ar_multimodal_output"] = mm_output # Forward sampling params - for key in ["seed", "num_inference_steps", "guidance_scale", - "negative_prompt"]: + for key in ["seed", "num_inference_steps", "guidance_scale", "negative_prompt"]: if key in original_prompt: diffusion_input[key] = original_prompt[key] From d9d5136f21d55184f38af3e3b24fdf591b0a7d21 Mon Sep 17 00:00:00 2001 From: TaffyOfficial <2324465096@qq.com> Date: Wed, 15 Apr 2026 17:18:16 +0800 Subject: [PATCH 3/6] fix(hunyuan_image3): address PR review - remove dead code, use tokenizer eos_token_id, hook _clear_transition_state Signed-off-by: TaffyOfficial <2324465096@qq.com> --- .../model_executor/models/hunyuan_image3/hunyuan_image3.py | 5 ++++- vllm_omni/patch.py | 6 ------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py index 56df4c83a6..b76a6cfb5a 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py @@ -1306,7 +1306,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ] self._sampler: Sampler | None = None - self._eos_token_id: int = 127957 # <|endoftext|> + self._eos_token_id: int = tokenizer.eos_token_id self._replace_rotary_embeddings() @@ -1598,6 +1598,9 @@ def sample( ) last_token = decoded_tokens[-1] if decoded_tokens else -1 + if not decoded_tokens: + self._clear_transition_state(req_idx) + if self._is_comprehension: # Comprehension: mask out generation-specific tokens. for tid in self._blocked_token_ids: diff --git a/vllm_omni/patch.py b/vllm_omni/patch.py index 2c96f536d8..4a0e856c19 100644 --- a/vllm_omni/patch.py +++ b/vllm_omni/patch.py @@ -45,12 +45,6 @@ def _patched_is_mm_prefix_lm(self): _patched_cp.__set_name__(_OriginalModelConfig, "is_mm_prefix_lm") _OriginalModelConfig.is_mm_prefix_lm = _patched_cp -# Also fix the original cached_property if __set_name__ was never called (vllm 0.19.0+) -_orig_cp = _OriginalModelConfig.__dict__.get("is_mm_prefix_lm") -if _orig_cp is not _patched_cp: - # Our assignment above should have replaced it, but just in case - pass - # ============================================================================= # Patch GlmImageTextConfig to expose mrope_section in rope_parameters # ============================================================================= From e1d7babd917eede7613bdb688b1b6b3ca41a1a8f Mon Sep 17 00:00:00 2001 From: TaffyOfficial <2324465096@qq.com> Date: Wed, 15 Apr 2026 17:50:57 +0800 Subject: [PATCH 4/6] fix(hunyuan_image3): simplify transition state to single-request, fix devices, harden patch.py comments Signed-off-by: TaffyOfficial <2324465096@qq.com> --- .../models/hunyuan_image3/hunyuan_image3.py | 86 ++++++++----------- .../stage_configs/hunyuan_image3_i2t.yaml | 2 +- .../stage_configs/hunyuan_image3_t2t.yaml | 2 +- vllm_omni/patch.py | 23 ++++- 4 files changed, 56 insertions(+), 57 deletions(-) diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py index b76a6cfb5a..5ef012cb3a 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py @@ -1267,10 +1267,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): for s, e in self._ratio_other_slices: 
self._all_ratio_ids.update(range(s, e)) - # Per-request state for stage-transition logits processor. - # Maps request index → (pending_tokens list, completed set). - self._transition_state: dict[int, tuple[list[int], set[int]]] = {} - # Determine mode: comprehension (I2T/T2T) vs generation (IT2I/T2I). engine_output_type = getattr(vllm_config.model_config, "engine_output_type", None) self._is_comprehension = engine_output_type in (None, "text") @@ -1294,6 +1290,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Official logic: → [], # → [, , ] # After , restrict vocab to ratio tokens only. + # Stage-transition forced sequences, keyed by trigger token. self._stage_transitions: dict[int, list[int]] = {} if not self._is_comprehension: self._stage_transitions[self._end_of_think_id] = [ @@ -1598,60 +1595,51 @@ def sample( ) last_token = decoded_tokens[-1] if decoded_tokens else -1 - if not decoded_tokens: - self._clear_transition_state(req_idx) - if self._is_comprehension: - # Comprehension: mask out generation-specific tokens. for tid in self._blocked_token_ids: logits[req_idx, tid] = min_score else: - # Generation: apply stage-transition logic. - self._apply_stage_transition(logits, req_idx, last_token, min_score) - # After size token → restrict to ratio tokens. - if last_token == self._size_token_id: + forced = self._get_forced_token(decoded_tokens) + if forced is not None: + logits[req_idx].fill_(min_score) + logits[req_idx, forced] = 0 + elif last_token == self._size_token_id: self._apply_ratio_restriction(logits, req_idx, min_score) - # After ratio token → force EOS (official uses ratio as - # final_stop_tokens; vLLM stop_token_ids may not include - # all ratio IDs, so we force EOS here). elif last_token in self._all_ratio_ids: logits[req_idx].fill_(min_score) logits[req_idx, self._eos_token_id] = 0 return self._sampler(logits=logits, sampling_metadata=sampling_metadata) - def _apply_stage_transition( - self, - logits: torch.Tensor, - req_idx: int, - last_token: int, - min_score: float, - ) -> None: - """Port of official _StageTransitionLogitsProcessor.__call__.""" - state = self._transition_state.get(req_idx) - if state is None: - state = ([], set()) # (pending_tokens, completed_transitions) - self._transition_state[req_idx] = state - pending, completed = state - - # Consume pending token if last output matches head of queue. - if pending and last_token == pending[0]: - pending.pop(0) - - # If pending tokens remain, force the next one. - if pending: - logits[req_idx].fill_(min_score) - logits[req_idx, pending[0]] = 0 - return - - # Check if last_token triggers a new transition. - if last_token in self._stage_transitions and last_token not in completed: - completed.add(last_token) - next_tokens = self._stage_transitions[last_token] - if next_tokens: - pending.extend(next_tokens) - logits[req_idx].fill_(min_score) - logits[req_idx, pending[0]] = 0 + def _get_forced_token(self, decoded_tokens: list[int]) -> int | None: + """Derive the next forced token from output history (stateless). + + Scans decoded_tokens backwards for the most recent trigger token, + then prefix-matches the forced sequence against what followed. + Returns the next token to force, or None if the sequence is complete + or history has diverged from the expected forced sequence. 
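+
+        Example (fake IDs): with _stage_transitions == {10: [11, 12]} and
+        decoded_tokens == [..., 10, 11], the backward scan finds trigger 10,
+        prefix-matches the emitted [11], and returns 12; once [10, 11, 12]
+        has been emitted in full, later calls return None.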
+ """ + for i in range(len(decoded_tokens) - 1, -1, -1): + trigger = decoded_tokens[i] + if trigger not in self._stage_transitions: + continue + + forced_seq = self._stage_transitions[trigger] + emitted = decoded_tokens[i + 1 :] + + matched = 0 + for expected, actual in zip(forced_seq, emitted): + if actual != expected: + # History diverged from the expected forced sequence. + # Stop applying transition forcing for safety. + return None + matched += 1 + + if matched < len(forced_seq): + return forced_seq[matched] + return None + + return None def _apply_ratio_restriction( self, @@ -1677,10 +1665,6 @@ def _apply_ratio_restriction( logits[req_idx].fill_(min_score) logits[req_idx, max_id] = 0 - def _clear_transition_state(self, req_idx: int) -> None: - """Clean up per-request transition state when request finishes.""" - self._transition_state.pop(req_idx, None) - def make_empty_intermediate_tensors( self, batch_size: int, dtype: torch.dtype, device: torch.device ) -> IntermediateTensors: diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml index 0e2425f886..203b54f257 100644 --- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml +++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml @@ -6,7 +6,7 @@ stage_args: stage_type: llm runtime: process: true - devices: "2,3,4,5" + devices: "0,1,2,3" max_batch_size: 1 requires_multimodal_data: true engine_args: diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml index af55f0688d..60da8e0bc7 100644 --- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml +++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml @@ -7,7 +7,7 @@ stage_args: stage_type: llm runtime: process: true - devices: "2,3,4,5" + devices: "0,1,2,3" max_batch_size: 1 requires_multimodal_data: false engine_args: diff --git a/vllm_omni/patch.py b/vllm_omni/patch.py index 4a0e856c19..a0b3c54eec 100644 --- a/vllm_omni/patch.py +++ b/vllm_omni/patch.py @@ -22,10 +22,25 @@ # ============================================================================= # Patch ModelConfig.is_mm_prefix_lm to support omni-specific models # ============================================================================= -# HunyuanImage-3.0 uses bidirectional attention for image token positions -# (cond_token_attn_type: "joint_full" in config.json), but its model_type -# "hunyuan_image_3_moe" is not in vllm's built-in MM_PREFIX_LM_MODELS list. -# This patch extends the check to include omni-specific models. +# WHY: HunyuanImage-3.0 requires bidirectional attention for image tokens +# (cond_token_attn_type: "joint_full" in config.json). vLLM gates this on +# is_mm_prefix_lm, which checks an internal MM_PREFIX_LM_MODELS list that +# does not include "hunyuan_image_3_moe" (the upstream HF model_type). +# +# WHY NOT model-level: is_mm_prefix_lm is checked in vLLM core (scheduler, +# attention backend selection) before model code runs — no model-level hook. +# +# SCOPE: Only affects model_type in _OMNI_MM_PREFIX_LM_MODELS (currently +# just "hunyuan_image_3_moe"). All other models fall through to the +# original vLLM implementation unchanged. +# +# FRAGILITY: Relies on is_mm_prefix_lm being a cached_property on +# ModelConfig. The __dict__ access + __set_name__ dance works around a +# pydantic dataclass issue in vllm 0.19.0+. 
If vLLM changes +# is_mm_prefix_lm to a regular method or removes it, this will break. +# +# TODO: Upstream a configurable MM_PREFIX_LM_MODELS or a model_config flag +# so this patch can be removed. _OMNI_MM_PREFIX_LM_MODELS = ("hunyuan_image_3_moe",) # Access via __dict__ to avoid triggering cached_property.__get__ which fails # with "Cannot use cached_property instance without calling __set_name__" in From e18e9209e69373891e0c5d4f1c036fc2f13b3793 Mon Sep 17 00:00:00 2001 From: TaffyOfficial <2324465096@qq.com> Date: Thu, 16 Apr 2026 18:52:15 +0800 Subject: [PATCH 5/6] fix(hunyuan_image3): remove LLM-only fields from diffusion stage engine_args in it2i.yaml Signed-off-by: TaffyOfficial <2324465096@qq.com> --- .../model_executor/stage_configs/hunyuan_image3_it2i.yaml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml index d8598d6f4c..9f6adece0f 100644 --- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml +++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml @@ -50,13 +50,9 @@ stage_args: engine_args: model_stage: dit model_arch: HunyuanImage3ForCausalMM - gpu_memory_utilization: 0.65 enforce_eager: true trust_remote_code: true - engine_output_type: image distributed_executor_backend: "mp" - enable_prefix_caching: false - max_num_batched_tokens: 32768 parallel_config: tensor_parallel_size: 4 enable_expert_parallel: true From fffba5090582d143aac02f696d3b1c74ef3df984 Mon Sep 17 00:00:00 2001 From: TaffyOfficial <2324465096@qq.com> Date: Thu, 16 Apr 2026 19:02:46 +0800 Subject: [PATCH 6/6] fix(hunyuan_image3): add batch assertion, patch sanity check, sampler unit tests Signed-off-by: TaffyOfficial <2324465096@qq.com> --- .../test_hunyuan_image3_sampler.py | 190 ++++++++++++++++++ .../models/hunyuan_image3/hunyuan_image3.py | 2 + vllm_omni/patch.py | 9 + 3 files changed, 201 insertions(+) create mode 100644 tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_sampler.py diff --git a/tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_sampler.py b/tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_sampler.py new file mode 100644 index 0000000000..51f6a85f58 --- /dev/null +++ b/tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_sampler.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for HunyuanImage3 AR sampler logic (stage transitions, +ratio restriction, comprehension blocking).""" + +import pytest +import torch + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + +# Fake token IDs for testing (avoid importing the real model). 
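+# They stand in for the real tokenizer IDs (end-of-think, recaption, answer,
+# boi, size, EOS, plus two ratio ranges) so the stage-transition and masking
+# logic can be exercised on CPU without loading weights or a tokenizer.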
+END_OF_THINK = 100 +RECAPTION = 101 +END_OF_RECAPTION = 102 +ANSWER = 103 +BOI = 104 +SIZE_TOKEN = 105 +EOS = 106 +RATIO_START = 200 +RATIO_END = 210 +RATIO_OTHER_START = 220 +RATIO_OTHER_END = 223 + + +class FakeSamplerModel: + """Minimal stub that replicates the sampler-relevant attributes of + HunyuanImage3ForConditionalGeneration without loading real weights.""" + + def __init__(self, *, is_comprehension: bool = False): + self._is_comprehension = is_comprehension + self._eos_token_id = EOS + self._end_of_think_id = END_OF_THINK + self._recaption_id = RECAPTION + self._end_of_recaption_id = END_OF_RECAPTION + self._answer_id = ANSWER + self._mrope_boi_token_id = BOI + self._size_token_id = SIZE_TOKEN + self._start_ratio_id = RATIO_START + self._end_ratio_id = RATIO_END + self._ratio_other_slices = [(RATIO_OTHER_START, RATIO_OTHER_END + 1)] + self._all_ratio_ids = set(range(RATIO_START, RATIO_END + 1)) + self._all_ratio_ids.update(range(RATIO_OTHER_START, RATIO_OTHER_END + 1)) + + self._stage_transitions: dict[int, list[int]] = {} + if not is_comprehension: + self._stage_transitions[END_OF_THINK] = [RECAPTION] + self._stage_transitions[END_OF_RECAPTION] = [ANSWER, BOI, SIZE_TOKEN] + + self._blocked_token_ids: set[int] = set() + if is_comprehension: + self._blocked_token_ids.update([BOI, SIZE_TOKEN]) + self._blocked_token_ids.update(self._all_ratio_ids) + + # Bind the real methods from the model class. + from vllm_omni.model_executor.models.hunyuan_image3.hunyuan_image3 import ( + HunyuanImage3ForConditionalGeneration as _Real, + ) + + _get_forced_token = _Real._get_forced_token + _apply_ratio_restriction = _Real._apply_ratio_restriction + + +class TestGetForcedToken: + """Tests for the stateless _get_forced_token method.""" + + def setup_method(self): + self.model = FakeSamplerModel(is_comprehension=False) + + def test_no_trigger_returns_none(self): + assert self.model._get_forced_token([1, 2, 3]) is None + + def test_empty_history_returns_none(self): + assert self.model._get_forced_token([]) is None + + def test_end_of_think_forces_recaption(self): + assert self.model._get_forced_token([END_OF_THINK]) == RECAPTION + + def test_end_of_think_completed(self): + assert self.model._get_forced_token([END_OF_THINK, RECAPTION]) is None + + def test_end_of_recaption_forces_answer(self): + tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION] + assert self.model._get_forced_token(tokens) == ANSWER + + def test_end_of_recaption_forces_boi_after_answer(self): + tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION, ANSWER] + assert self.model._get_forced_token(tokens) == BOI + + def test_end_of_recaption_forces_size_after_boi(self): + tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION, ANSWER, BOI] + assert self.model._get_forced_token(tokens) == SIZE_TOKEN + + def test_full_sequence_complete(self): + tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION, ANSWER, BOI, SIZE_TOKEN] + assert self.model._get_forced_token(tokens) is None + + def test_diverged_history_returns_none(self): + tokens = [END_OF_RECAPTION, 999] # 999 != ANSWER + assert self.model._get_forced_token(tokens) is None + + def test_later_trigger_takes_precedence(self): + tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION] + assert self.model._get_forced_token(tokens) == ANSWER + + def test_trigger_with_extra_tokens_before(self): + tokens = [1, 2, 3, END_OF_THINK] + assert self.model._get_forced_token(tokens) == RECAPTION + + +class TestComprehensionBlocking: + """Tests for comprehension mode token blocking.""" + + def 
test_blocked_tokens_masked(self): + model = FakeSamplerModel(is_comprehension=True) + vocab_size = 300 + logits = torch.zeros(1, vocab_size) + logits[0, BOI] = 5.0 + logits[0, SIZE_TOKEN] = 3.0 + logits[0, RATIO_START] = 2.0 + min_score = torch.finfo(logits.dtype).min + + for tid in model._blocked_token_ids: + if tid < vocab_size: + logits[0, tid] = min_score + + assert logits[0, BOI].item() == min_score + assert logits[0, SIZE_TOKEN].item() == min_score + assert logits[0, RATIO_START].item() == min_score + + def test_non_blocked_tokens_preserved(self): + model = FakeSamplerModel(is_comprehension=True) + vocab_size = 300 + logits = torch.zeros(1, vocab_size) + logits[0, 50] = 7.0 + min_score = torch.finfo(logits.dtype).min + + for tid in model._blocked_token_ids: + if tid < vocab_size: + logits[0, tid] = min_score + + assert logits[0, 50].item() == 7.0 + + +class TestRatioRestriction: + """Tests for _apply_ratio_restriction (greedy: only argmax ratio survives).""" + + def test_greedy_selects_single_ratio_token(self): + model = FakeSamplerModel(is_comprehension=False) + vocab_size = 300 + logits = torch.zeros(1, vocab_size) + logits[0, RATIO_START + 3] = 10.0 + logits[0, RATIO_START + 1] = 5.0 + logits[0, 50] = 20.0 # non-ratio, should be masked + min_score = torch.finfo(logits.dtype).min + + model._apply_ratio_restriction(logits, 0, min_score) + + assert logits[0, RATIO_START + 3].item() == 0 + assert logits[0, RATIO_START + 1].item() == min_score + assert logits[0, 50].item() == min_score + + def test_extra_ratio_slices_considered(self): + model = FakeSamplerModel(is_comprehension=False) + vocab_size = 300 + logits = torch.zeros(1, vocab_size) + logits[0, RATIO_OTHER_START] = 15.0 + logits[0, RATIO_START] = 5.0 + min_score = torch.finfo(logits.dtype).min + + model._apply_ratio_restriction(logits, 0, min_score) + + assert logits[0, RATIO_OTHER_START].item() == 0 + assert logits[0, RATIO_START].item() == min_score + + +class TestForceEosAfterRatio: + """Tests that a ratio token as last_token forces EOS.""" + + def test_ratio_token_forces_eos(self): + model = FakeSamplerModel(is_comprehension=False) + vocab_size = 300 + logits = torch.randn(1, vocab_size) + min_score = torch.finfo(logits.dtype).min + + logits[0].fill_(min_score) + logits[0, model._eos_token_id] = 0 + + assert logits[0, EOS].item() == 0 + non_eos_max = logits[0, :EOS].max().item() + assert non_eos_max == min_score diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py index 5ef012cb3a..6304eeab29 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py @@ -1589,6 +1589,8 @@ def sample( min_score = torch.finfo(logits.dtype).min + assert logits.shape[0] == 1, f"HunyuanImage3 sampler requires max_num_seqs=1, got batch size {logits.shape[0]}" + for req_idx in range(logits.shape[0]): decoded_tokens: list[int] = ( sampling_metadata.output_token_ids[req_idx] if req_idx < len(sampling_metadata.output_token_ids) else [] diff --git a/vllm_omni/patch.py b/vllm_omni/patch.py index a0b3c54eec..d4ab78f13a 100644 --- a/vllm_omni/patch.py +++ b/vllm_omni/patch.py @@ -60,6 +60,15 @@ def _patched_is_mm_prefix_lm(self): _patched_cp.__set_name__(_OriginalModelConfig, "is_mm_prefix_lm") _OriginalModelConfig.is_mm_prefix_lm = _patched_cp +# Sanity check: verify the patch is active. 
If vLLM changes the descriptor +# type or __set_name__ semantics, this will fail loudly at import time +# rather than silently falling back to unpatched behavior. +_installed = _OriginalModelConfig.__dict__.get("is_mm_prefix_lm") +assert _installed is _patched_cp, ( + "is_mm_prefix_lm patch failed to install — bidirectional attention " + "for HunyuanImage3 will not work. Check vLLM ModelConfig changes." +) + # ============================================================================= # Patch GlmImageTextConfig to expose mrope_section in rope_parameters # =============================================================================
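
A minimal smoke test for the installed patch (a sketch; importing
vllm_omni.patch triggers the assertion above as a side effect):

    from functools import cached_property

    import vllm_omni.patch  # noqa: F401 — installs the patched property
    from vllm.config import ModelConfig

    assert isinstance(ModelConfig.__dict__["is_mm_prefix_lm"], cached_property)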