diff --git a/examples/offline_inference/hunyuan_image3/README.md b/examples/offline_inference/hunyuan_image3/README.md index 3eb3bfbff6f..c1c97bfe1fa 100644 --- a/examples/offline_inference/hunyuan_image3/README.md +++ b/examples/offline_inference/hunyuan_image3/README.md @@ -31,6 +31,14 @@ python end2end.py --model tencent/HunyuanImage-3.0-Instruct \ --prompts "A cute cat sitting on a windowsill watching the sunset" ``` +**With VAE tiling (required on A100 GPUs):** +```bash +python end2end.py --model tencent/HunyuanImage-3.0-Instruct \ + --modality text2img \ + --prompts "A cute cat sitting on a windowsill watching the sunset" \ + --vae-use-tiling +``` + #### Image to Image (img2img) - **Pipeline**: Image + Text → AR (CoT + recaption + latent) → DiT → Edited Image @@ -103,6 +111,7 @@ python end2end.py --modality text2img \ | `--stage-configs-path` | string | auto | Custom stage config YAML path | | `--enforce-eager` | flag | `False` | Disable torch.compile | | `--init-timeout` | int | `300` | Initialization timeout (seconds) | +| `--vae-use-tiling` | flag | `False` | Enable VAE tiling for memory optimization (required to avoid OOM on A100) | ------ @@ -153,7 +162,7 @@ helper handles segment-by-segment tokenization (matches HF `apply_chat_template` ## FAQ -- **OOM errors**: Decrease `gpu_memory_utilization` in the YAML stage config, or use a smaller `max_num_batched_tokens`. +- **OOM errors**: Decrease `gpu_memory_utilization` in the YAML stage config, use a smaller `max_num_batched_tokens`, or enable VAE tiling with `--vae-use-tiling` (required on A100 GPUs). - **Custom image sizes**: Use `--height` and `--width` flags (multiples of 16 recommended). 
| Stage | VRAM (approx) | diff --git a/examples/offline_inference/hunyuan_image3/end2end.py b/examples/offline_inference/hunyuan_image3/end2end.py index 8993c8dadca..5a01858a447 100644 --- a/examples/offline_inference/hunyuan_image3/end2end.py +++ b/examples/offline_inference/hunyuan_image3/end2end.py @@ -22,6 +22,17 @@ from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniPromptType +# task -> (sys_type, bot_task, trigger_tag) +_TASK_PRESETS: dict[str, tuple[str, str | None, str | None]] = { + "t2t": ("en_unified", None, None), + "i2t": ("en_unified", None, None), + "it2i_think": ("en_unified", "think", ""), + "it2i_recaption": ("en_unified", "recaption", ""), + "t2i_think": ("en_unified", "think", ""), + "t2i_recaption": ("en_unified", "recaption", ""), + "t2i_vanilla": ("en_vanilla", "image", None), +} + # Modality → prompt_utils task mapping _MODALITY_TASK_MAP = { "text2img": "t2i_think", @@ -73,6 +84,11 @@ def parse_args(): parser.add_argument("--seed", type=int, default=42, help="Random seed.") parser.add_argument("--height", type=int, default=1024, help="Output image height.") parser.add_argument("--width", type=int, default=1024, help="Output image width.") + parser.add_argument( + "--vae-use-tiling", + action="store_true", + help="Enable VAE tiling for memory optimization.", + ) # Prompt configuration parser.add_argument( @@ -113,6 +129,7 @@ def main(): # Build Omni omni_kwargs = { "model": args.model, + "vae_use_tiling": args.vae_use_tiling, "stage_configs_path": stage_configs_path, "log_stats": args.log_stats, "init_timeout": args.init_timeout, @@ -148,8 +165,18 @@ def main(): formatted_prompts: list[OmniPromptType] = [] for p in prompts: token_ids = build_prompt_tokens(p, tokenizer, task=task, sys_type=args.sys_type) - - prompt_dict: dict = {"prompt_token_ids": token_ids, "prompt": p} + preset_sys_type, _, _ = _TASK_PRESETS[task] + effective_sys_type = args.sys_type or preset_sys_type + + # `prompt_token_ids` drives the AR 
stage (matches HF byte-for-byte). + # `prompt` and `use_system_prompt` are forwarded by ar2diffusion to + # the DiT stage so the diffusion pipeline can rebuild the same + # system prefix when constructing its model inputs. + prompt_dict: dict = { + "prompt_token_ids": token_ids, + "prompt": p, + "use_system_prompt": effective_sys_type, + } if args.modality == "text2img": prompt_dict["modalities"] = ["image"] diff --git a/tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_it2i_ar_format.py b/tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_it2i_ar_format.py new file mode 100644 index 00000000000..7e7b7de91b2 --- /dev/null +++ b/tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_it2i_ar_format.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Verify the IT2I AR-prefill prompt matches the official HF chat-template output. + +PR #3107 builds the AR prefill via +:func:`vllm_omni.diffusion.models.hunyuan_image3.prompt_utils.build_prompt_tokens`, +which segment-tokenizes the canonical Instruct chat template (`<|startoftext|>` ++ `{system}\\n\\n` + `User: []{user_prompt}\\n\\nAssistant: {trigger?}`). + +The official HunyuanImage-3.0-Instruct repo ships a Jinja `chat_template` in +its tokenizer config and an `image_processor.py` whose `process_image` +defines the same VAE/VIT preprocessing the diffusion pipeline uses on the +condition image. To prevent silent drift between the AR's input distribution +and what the model was actually trained on, this test asserts: + +1. ``build_prompt_tokens`` token-id sequence equals the HF reference produced + by ``tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)`` + for the same `(system, user_prompt, image)` triple. +2. The image-tensor produced by the diffusion-side ``_resize_and_crop_center`` + is byte-identical to the AR-side ``HunyuanImage3Processor._resize_and_crop`` + output (i.e. 
def _snapshots_root(model_id: str) -> pathlib.Path:
    """Return the ``snapshots`` directory for *model_id* in the HF hub cache.

    The cache root is ``$HF_HOME`` when that variable is set and non-empty,
    otherwise ``~/.cache/huggingface``.  The path is only computed; it is not
    checked for existence.
    """
    hf_home = os.environ.get("HF_HOME") or os.path.expanduser("~/.cache/huggingface")
    return pathlib.Path(hf_home) / "hub" / f"models--{model_id.replace('/', '--')}" / "snapshots"


def _list_snapshots(model_id: str) -> list[pathlib.Path]:
    """Return the snapshot directories cached for *model_id*, sorted by name.

    Returns an empty list when the ``snapshots`` directory does not exist.
    Only sub-directories count as snapshots; stray files are ignored.
    """
    root = _snapshots_root(model_id)
    if not root.is_dir():
        return []
    # ``Path.iterdir`` is fully consumed by ``sorted``, so no directory
    # iterator is left open (unlike a short-circuiting ``any(os.scandir(...))``).
    return sorted(entry for entry in root.iterdir() if entry.is_dir())


def _hf_cached(model_id: str) -> bool:
    """Return ``True`` when at least one snapshot of *model_id* is cached locally."""
    return bool(_list_snapshots(model_id))


def _snapshot_dir(model_id: str) -> pathlib.Path:
    """Return a cached snapshot directory for *model_id*.

    When several snapshots exist, the first one in name order is returned so
    the choice is deterministic across runs.

    Raises:
        FileNotFoundError: if no snapshot directory is cached for *model_id*.
    """
    snapshots = _list_snapshots(model_id)
    if not snapshots:
        raise FileNotFoundError(f"No cached snapshot for {model_id!r} under {_snapshots_root(model_id)}")
    return snapshots[0]
_OFFICIAL_PKG = "_hunyuan_image_3_official_snapshot"


def _import_official_snapshot_modules():
    """Load the snapshot's tokenizer and image-processor modules.

    The HunyuanImage-3.0-Instruct snapshot directory is registered in
    ``sys.modules`` as a synthetic package so that its ``image_processor.py``
    (which does ``from .tokenization_hunyuan_image_3 import ...``) can be
    loaded with relative imports intact.

    Returns:
        ``(tokenization_module, image_processor_module)`` when both modules
        load, otherwise ``(None, None)`` -- e.g. the snapshot is missing, it
        has no ``image_processor.py``, or executing either module fails (such
        as an optional dependency being absent).

    Note:
        The synthetic package stays registered after a failed attempt, so
        later calls in the same process return ``(None, None)`` without
        retrying the load.
    """
    import importlib.util
    import sys
    import types

    tok_name = f"{_OFFICIAL_PKG}.tokenization_hunyuan_image_3"
    img_name = f"{_OFFICIAL_PKG}.image_processor"

    if _OFFICIAL_PKG in sys.modules:
        # A previous call already attempted the load; reuse its outcome.
        tok_mod = sys.modules.get(tok_name)
        img_mod = sys.modules.get(img_name)
        if tok_mod is None or img_mod is None:
            return None, None
        return tok_mod, img_mod

    try:
        snap = _snapshot_dir(_HUNYUAN_MODEL_ID)
    except (OSError, StopIteration):
        # Snapshot directory missing or empty: nothing to import.
        return None, None
    if not (snap / "image_processor.py").is_file():
        return None, None

    pkg = types.ModuleType(_OFFICIAL_PKG)
    pkg.__path__ = [str(snap)]
    sys.modules[_OFFICIAL_PKG] = pkg

    def _load(name: str):
        """Execute ``<snap>/<name>.py`` as ``_OFFICIAL_PKG.<name>``; ``None`` on failure."""
        full = f"{_OFFICIAL_PKG}.{name}"
        spec = importlib.util.spec_from_file_location(full, snap / f"{name}.py")
        if spec is None or spec.loader is None:
            return None
        mod = importlib.util.module_from_spec(spec)
        # Register before executing so relative imports between the snapshot
        # modules can resolve each other.
        sys.modules[full] = mod
        try:
            spec.loader.exec_module(mod)
        except Exception:
            # Deliberately best-effort: any import-time failure in third-party
            # snapshot code means "not loadable", which the caller turns into
            # a test skip.  ``pop`` tolerates the module having removed itself.
            sys.modules.pop(full, None)
            return None
        return mod

    tok_mod = _load("tokenization_hunyuan_image_3")
    img_mod = _load("image_processor") if tok_mod is not None else None
    if tok_mod is None or img_mod is None:
        return None, None
    return tok_mod, img_mod


@pytest.mark.skipif(
    not _hf_cached(_HUNYUAN_MODEL_ID),
    reason=f"{_HUNYUAN_MODEL_ID} not in HF cache",
)
def test_dit_condition_image_preprocessing_byte_matches_official_hf():
    """DiT condition-image preprocessing must match the official HF helper.

    The diffusion pipeline's ``_resize_and_crop_center`` (used to feed the
    VAE encoder for IT2I conditioning) must produce byte-identical pixels to
    the **official** HuggingFace ``image_processor.resize_and_crop`` (loaded
    straight out of the HunyuanImage-3.0-Instruct snapshot's bundled
    ``image_processor.py``) at ``crop_type='center'``.

    A review of PR #3107 flagged that the DiT-side helper had drifted from
    the AR-side processor on rounding boundaries; PR #3107 commit
    ``0a7e0e6f`` aligned the DiT helper to the AR-side algorithm.  AR and DiT
    both *claim* to mirror the HF reference, so the actual contract is "DiT
    (and AR) match the HF reference verbatim".  That contract is enforced
    here by comparing directly to the HF function rather than to a sibling
    vllm-omni copy.
    """
    import numpy as np
    from PIL import Image

    from vllm_omni.diffusion.models.hunyuan_image3.pipeline_hunyuan_image3 import (
        _resize_and_crop_center,
    )

    _tok_mod, official_module = _import_official_snapshot_modules()
    if official_module is None or not hasattr(official_module, "resize_and_crop"):
        pytest.skip("Official HunyuanImage3 image_processor.py not loadable")
    official_resize_and_crop = official_module.resize_and_crop

    # Fixed seed keeps the random source images reproducible across runs.
    rng = np.random.default_rng(seed=42)
    src_size_pairs = [(640, 1024), (1024, 1024), (1280, 720), (480, 800)]
    target_size_pairs = [(1024, 1024), (1024, 768), (768, 1024)]

    for src_w, src_h in src_size_pairs:
        src_arr = rng.integers(0, 256, size=(src_h, src_w, 3), dtype=np.uint8)
        src = Image.fromarray(src_arr, mode="RGB")
        for tw, th in target_size_pairs:
            ref_out = official_resize_and_crop(
                src,
                target_size=(tw, th),
                resample=Image.Resampling.LANCZOS,
                crop_type="center",
            )
            dit_out = _resize_and_crop_center(src, tw, th)
            assert ref_out.size == dit_out.size == (tw, th), (
                f"size mismatch for src={(src_w, src_h)} target={(tw, th)}: "
                f"hf_official={ref_out.size} dit={dit_out.size}"
            )
            ref_pixels = np.asarray(ref_out)
            dit_pixels = np.asarray(dit_out)
            assert np.array_equal(ref_pixels, dit_pixels), (
                f"DiT condition-image preprocessing diverged from HF "
                f"image_processor.resize_and_crop at src={(src_w, src_h)} "
                f"target={(tw, th)}: max abs diff = "
                f"{int(np.abs(ref_pixels.astype(int) - dit_pixels.astype(int)).max())}"
            )
def _drop_singleton_channel(array: np.ndarray) -> np.ndarray:
    """Turn an ``(H, W, 1)`` array into ``(H, W)``; return other shapes unchanged.

    ``PIL.Image.fromarray`` rejects a trailing channel axis of length 1, so
    single-channel images must be handed over as 2-D arrays.
    """
    if array.ndim == 3 and array.shape[-1] == 1:
        return array[..., 0]
    return array


def _to_pil_image(image: Any) -> PILImage.Image:
    """Convert a supported image input to a ``PIL.Image.Image``.

    Accepted inputs:
      * a PIL image -- returned unchanged;
      * a ``str`` -- passed to ``PIL.Image.open``;
      * a ``numpy.ndarray`` or ``torch.Tensor`` holding one image.

    For arrays and tensors:
      * a leading batch axis of size 1 (4-D input) is removed;
      * floating-point data with a negative minimum is treated as ``[-1, 1]``
        and remapped to ``[0, 1]``; floating-point data is then scaled to
        ``[0, 255]`` and cast to ``uint8``;
      * a 3-D input whose first axis has length 1, 3 or 4 is treated as
        channel-first and moved to channel-last.  This is a heuristic: a
        channel-last image whose height is 1, 3 or 4 is misread.
      * a single-channel result is squeezed to 2-D before conversion.

    Raises:
        ValueError: a 4-D array/tensor whose batch axis is not of size 1.
        TypeError: ``image`` is of an unsupported type.
    """
    if isinstance(image, PILImage.Image):
        return image
    if isinstance(image, str):
        return PILImage.open(image)
    if isinstance(image, np.ndarray):
        array = image
        if array.ndim == 4:
            # Mirror the tensor branch: accept a batch of exactly one image.
            if array.shape[0] != 1:
                raise ValueError(f"Only a single image array is supported, but got shape {tuple(array.shape)}.")
            array = array[0]
        if array.dtype != np.uint8:
            if np.issubdtype(array.dtype, np.floating):
                if float(np.min(array)) < 0.0:
                    array = (np.clip(array, -1.0, 1.0) + 1.0) / 2.0
                if float(np.max(array)) <= 1.0:
                    array = array * 255.0
            array = np.clip(array, 0, 255).astype(np.uint8)
        if array.ndim == 3 and array.shape[0] in (1, 3, 4):
            array = np.transpose(array, (1, 2, 0))
        return PILImage.fromarray(_drop_singleton_channel(array))
    if isinstance(image, torch.Tensor):
        tensor = image.detach().cpu()
        if tensor.ndim == 4:
            if tensor.shape[0] != 1:
                raise ValueError(f"Only a single image tensor is supported, but got shape {tuple(tensor.shape)}.")
            tensor = tensor.squeeze(0)
        if tensor.ndim == 3 and tensor.shape[0] in (1, 3, 4):
            tensor = tensor.permute(1, 2, 0)
        if tensor.dtype.is_floating_point:
            if float(tensor.min()) < 0.0:
                tensor = (tensor.clamp(-1.0, 1.0) + 1.0) / 2.0
            if float(tensor.max()) > 1.0:
                tensor = tensor / 255.0
            tensor = (tensor.clamp(0.0, 1.0) * 255.0).to(torch.uint8)
        else:
            tensor = tensor.to(torch.uint8)
        return PILImage.fromarray(_drop_singleton_channel(tensor.numpy()))
    raise TypeError(f"Unsupported image input type: {type(image)}")


def _resize_and_crop_center(image: PILImage.Image, target_width: int, target_height: int) -> PILImage.Image:
    """Resize *image* to cover the target box, then center-crop to it.

    Mirrors HunyuanImage3Processor._resize_and_crop in
    vllm_omni.model_executor.models.hunyuan_image3.hunyuan_image3 so the AR
    and DiT stages preprocess condition images identically.  The rounding
    below is intentional and must not be "simplified": the result is compared
    pixel-for-pixel against the reference implementation.
    """
    tw, th = target_width, target_height
    w, h = image.size
    tr = th / tw
    r = h / w
    if r < tr:
        # Source is relatively wider than the target: match heights, crop width.
        resize_height = th
        resize_width = int(round(th / h * w))
    else:
        # Source is relatively taller (or equal): match widths, crop height.
        resize_width = tw
        resize_height = int(round(tw / w * h))
    resized = image.resize((resize_width, resize_height), PILImage.Resampling.LANCZOS)
    crop_top = int(round((resize_height - th) / 2.0))
    crop_left = int(round((resize_width - tw) / 2.0))
    return resized.crop((crop_left, crop_top, crop_left + tw, crop_top + th))


def _to_python_scalar(value: Any) -> Any:
    """Unwrap a NumPy scalar into the equivalent Python scalar; pass others through."""
    if isinstance(value, np.generic):
        return value.item()
    return value


def _image_info_to_payload(image_info: ImageInfo) -> dict[str, Any]:
    """Flatten an ``ImageInfo`` into a plain dict.

    NumPy scalar fields are unwrapped to Python scalars; ``image_tensor`` and
    the boolean flags are copied as-is.
    """
    return {
        "image_type": image_info.image_type,
        "image_tensor": image_info.image_tensor,
        "image_width": _to_python_scalar(image_info.image_width),
        "image_height": _to_python_scalar(image_info.image_height),
        "token_width": _to_python_scalar(image_info.token_width),
        "token_height": _to_python_scalar(image_info.token_height),
        "image_token_length": _to_python_scalar(image_info.image_token_length),
        "base_size": _to_python_scalar(image_info.base_size),
        "ratio_index": _to_python_scalar(image_info.ratio_index),
        "add_timestep_token": image_info.add_timestep_token,
        "add_guidance_token": image_info.add_guidance_token,
        "use_front_boi_token": image_info.use_front_boi_token,
        "add_image_shape_token": image_info.add_image_shape_token,
    }


def _to_tensor_if_needed(value: Any) -> Any:
    """Normalize a payload value read back from a dict.

    NumPy scalars become Python scalars, lists become ``torch.Tensor``;
    anything else (including existing tensors and ``None``) is returned
    unchanged.
    """
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, list):
        return torch.tensor(value)
    return value


def _image_info_from_payload(payload: dict[str, Any]) -> ImageInfo:
    """Rebuild an ``ImageInfo`` from a dict made by ``_image_info_to_payload``.

    Missing keys fall back to ``None``, except the four boolean flags, which
    use the defaults spelled out below.
    """
    return ImageInfo(
        image_type=payload.get("image_type"),
        image_tensor=_to_tensor_if_needed(payload.get("image_tensor")),
        image_width=payload.get("image_width"),
        image_height=payload.get("image_height"),
        token_width=payload.get("token_width"),
        token_height=payload.get("token_height"),
        image_token_length=payload.get("image_token_length"),
        base_size=payload.get("base_size"),
        ratio_index=payload.get("ratio_index"),
        add_timestep_token=payload.get("add_timestep_token", True),
        add_guidance_token=payload.get("add_guidance_token", False),
        use_front_boi_token=payload.get("use_front_boi_token", True),
        add_image_shape_token=payload.get("add_image_shape_token", True),
    )


def _joint_image_info_to_payload(joint_image_info: JointImageInfo) -> dict[str, Any]:
    """Flatten a ``JointImageInfo`` (VAE + vision-encoder views) into a dict."""
    return {
        "type": "joint_image_info",
        "vae_image_info": _image_info_to_payload(joint_image_info.vae_image_info),
        "vision_image_info": _image_info_to_payload(joint_image_info.vision_image_info),
        "vision_encoder_kwargs": joint_image_info.vision_encoder_kwargs,
    }


def _joint_image_info_from_payload(payload: Any) -> JointImageInfo:
    """Rebuild a ``JointImageInfo`` from its dict payload.

    An existing ``JointImageInfo`` is returned unchanged.

    Raises:
        TypeError: ``payload`` is neither a dict nor a ``JointImageInfo``.
        KeyError: the dict lacks ``vae_image_info`` or ``vision_image_info``.
    """
    if isinstance(payload, JointImageInfo):
        return payload
    if not isinstance(payload, dict):
        raise TypeError(f"Expected dict or JointImageInfo for conditional image payload, got {type(payload)}.")

    vae_image_info = _image_info_from_payload(payload["vae_image_info"])
    vision_image_info = _image_info_from_payload(payload["vision_image_info"])
    vision_encoder_kwargs = payload.get("vision_encoder_kwargs") or {}
    if isinstance(vision_encoder_kwargs, dict):
        vision_encoder_kwargs = {key: _to_tensor_if_needed(value) for key, value in vision_encoder_kwargs.items()}
    return JointImageInfo(
        vae_image_info=vae_image_info,
        vision_image_info=vision_image_info,
        vision_encoder_kwargs=vision_encoder_kwargs,
    )


def get_hunyuan_image_3_pre_process_func(od_config: OmniDiffusionConfig):
    """Build the request pre-processor for HunyuanImage-3 image conditioning.

    The HF config and image processor are created once here and captured by
    the returned closure.

    Returns:
        ``pre_process_func(request)``, which mutates and returns the request:
        every prompt that carries input image(s) gets
        ``additional_information["batch_cond_image_info"]`` filled with one
        joint-image payload per image, and unset sampling width/height are
        defaulted from the first input image.
    """
    hf_config = get_config(od_config.model, trust_remote_code=True)
    image_processor = HunyuanImage3ImageProcessor(hf_config)
    # Pixels per latent token along each axis: VAE downsampling times patch size.
    vae_h_factor = hf_config.vae_downsample_factor[0] * hf_config.patch_size
    vae_w_factor = hf_config.vae_downsample_factor[1] * hf_config.patch_size
    vit_patch_size = getattr(image_processor.vision_encoder_processor, "patch_size", 1)
    if isinstance(vit_patch_size, tuple | list):
        vit_patch_size = int(vit_patch_size[0])

    def _build_cond_joint_image(raw_image: Any) -> dict[str, Any]:
        """Preprocess one condition image for the VAE and vision encoder."""
        pil_image = _to_pil_image(raw_image).convert("RGB")
        orig_width, orig_height = pil_image.size

        # VAE view: snap to a supported resolution bucket, then resize + crop.
        target_width, target_height = image_processor.reso_group.get_target_size(orig_width, orig_height)
        target_width = int(target_width)
        target_height = int(target_height)
        vae_input = _resize_and_crop_center(pil_image, target_width, target_height)
        vae_tensor = image_processor.vae_processor(vae_input)
        base_size, ratio_idx = image_processor.reso_group.get_base_size_and_ratio_index(orig_width, orig_height)
        base_size = int(base_size)
        ratio_idx = int(ratio_idx)

        vae_info = ImageInfo(
            image_type="vae",
            image_tensor=vae_tensor,
            image_width=target_width,
            image_height=target_height,
            token_width=target_width // vae_w_factor,
            token_height=target_height // vae_h_factor,
            base_size=base_size,
            ratio_index=ratio_idx,
        )

        # Vision-encoder view: the processor receives the un-cropped RGB image.
        vit_inputs = image_processor.vision_encoder_processor(pil_image, return_tensors="pt")
        vit_tensor = vit_inputs["pixel_values"]
        # NOTE(review): squeeze(0) presumably drops a batch axis of size 1 --
        # confirm against the processor's output layout.
        spatial_shapes = vit_inputs["spatial_shapes"].squeeze(0)
        pixel_attention_mask = vit_inputs["pixel_attention_mask"].squeeze(0)
        vit_token_h = int(spatial_shapes[0].item())
        vit_token_w = int(spatial_shapes[1].item())

        vit_info = ImageInfo(
            image_type="siglip2",
            image_tensor=vit_tensor,
            image_width=vit_token_w * vit_patch_size,
            image_height=vit_token_h * vit_patch_size,
            token_width=vit_token_w,
            token_height=vit_token_h,
            image_token_length=int(vit_tensor.shape[1]),
        )

        return _joint_image_info_to_payload(
            JointImageInfo(
                vae_image_info=vae_info,
                vision_image_info=vit_info,
                vision_encoder_kwargs={
                    "spatial_shapes": spatial_shapes,
                    "pixel_attention_mask": pixel_attention_mask,
                },
            )
        )

    def pre_process_func(request: OmniDiffusionRequest):
        """Attach condition-image payloads to each prompt of *request*."""
        sampling_params = request.sampling_params
        for i, prompt in enumerate(request.prompts):
            if isinstance(prompt, str):
                prompt = OmniTextPrompt(prompt=prompt)

            if "additional_information" not in prompt:
                prompt["additional_information"] = {}

            # Images may arrive under multi_modal_data["image"] or, as a
            # fallback, under the "pil_image" key.
            multi_modal_data = prompt.get("multi_modal_data") or {}
            raw_images = multi_modal_data.get("image")
            if raw_images is None:
                raw_images = prompt.get("pil_image")
            has_images = raw_images is not None and (not isinstance(raw_images, list) or len(raw_images) > 0)
            if has_images:
                image_list = raw_images if isinstance(raw_images, list) else [raw_images]
                cond_image_infos = [_build_cond_joint_image(image) for image in image_list]
                prompt["additional_information"]["batch_cond_image_info"] = cond_image_infos

                # Only decode the first image again when a default is needed.
                if sampling_params.width is None or sampling_params.height is None:
                    first_image_w, first_image_h = _to_pil_image(image_list[0]).size
                    if sampling_params.width is None:
                        sampling_params.width = int(first_image_w)
                    if sampling_params.height is None:
                        sampling_params.height = int(first_image_h)

            request.prompts[i] = prompt

        return request

    return pre_process_func
build_batch_rope_image_info(output, sections): def vae_encode(self, image, cfg_factor=1): config = self.vae.config + if image.ndim == 3: + image = image.unsqueeze(0) + if image.ndim == 4: + image = image.unsqueeze(2) + if image.ndim != 5: + raise ValueError(f"Expected image tensor with 3/4/5 dims, got shape {tuple(image.shape)}.") + with torch.autocast(device_type=self.model.device.type, dtype=torch.float16, enabled=True): vae_encode_result = self.vae.encode(image) if isinstance(vae_encode_result, torch.Tensor): @@ -499,8 +747,7 @@ def prepare_model_inputs( batch_cot_text = cot_text batch_system_prompt = system_prompt batch_gen_image_info = None - # TODO: construct with user input images - batch_cond_image_info = None + batch_cond_image_info = kwargs.pop("batch_cond_image_info", None) # -- 2.1 message_list if batch_message_list is not None: @@ -545,6 +792,12 @@ def prepare_model_inputs( if mode == "gen_image": batch_gen_image_info = [self.image_processor.build_image_info(image_size) for _ in range(batch_size)] + if batch_cond_image_info is not None: + assert isinstance(batch_cond_image_info, list) and len(batch_cond_image_info) == batch_size, ( + "`batch_cond_image_info` should be a list with the same batch size as `prompt`." + ) + batch_cond_image_info = [cond if isinstance(cond, list) else [cond] for cond in batch_cond_image_info] + # -- 2.3 seed generator = kwargs.get("generator", None) if generator is None: @@ -556,6 +809,11 @@ def prepare_model_inputs( bot_task = kwargs.pop("bot_task", "auto") # If `drop_think` enabled, always drop parts in the context. drop_think = kwargs.get("drop_think", self.generation_config.drop_think) + # Pull sequence_template from the model's generation_config so the DiT + # text prefix matches how the model was trained (Instruct for the + # HunyuanImage-3.0-Instruct checkpoint). Falling back to "pretrain" + # only if the config does not specify it. 
+ sequence_template = getattr(self.generation_config, "sequence_template", "pretrain") # Apply batched prompt or batched message_list to build input sequence with associated info. out = self._tkwrapper.apply_chat_template( batch_prompt=batch_prompt, @@ -568,7 +826,7 @@ def prepare_model_inputs( max_length=kwargs.get("max_length"), bot_task=bot_task, image_base_size=self.config.image_base_size, - sequence_template="pretrain", + sequence_template=sequence_template, cfg_factor=cfg_factor[mode], drop_think=drop_think, ) @@ -579,18 +837,21 @@ def prepare_model_inputs( cond_vae_images, cond_timestep, cond_vit_images = self._encode_cond_image( batch_cond_image_info, cfg_factor[mode] ) - # Pack vit kwargs. Siglip2-so requires spatial_shapes and attention_mask for inference. - vit_kwargs = {"spatial_shapes": [], "attention_mask": []} + # Pack vit kwargs. Siglip2 requires spatial_shapes and pixel_attention_mask + # at the forward boundary. transformers >=5.54 renamed the kwarg from + # `attention_mask` to `pixel_attention_mask` so the dict key must match + # the expected forward signature. 
+ vit_kwargs = {"spatial_shapes": [], "pixel_attention_mask": []} for cond_image_info in batch_cond_image_info: vit_kwargs["spatial_shapes"].append( torch.stack([item.vision_encoder_kwargs["spatial_shapes"] for item in cond_image_info]) ) - vit_kwargs["attention_mask"].append( + vit_kwargs["pixel_attention_mask"].append( torch.stack([item.vision_encoder_kwargs["pixel_attention_mask"] for item in cond_image_info]) ) if cfg_factor[mode] > 1: vit_kwargs["spatial_shapes"] = vit_kwargs["spatial_shapes"] * cfg_factor[mode] - vit_kwargs["attention_mask"] = vit_kwargs["attention_mask"] * cfg_factor[mode] + vit_kwargs["pixel_attention_mask"] = vit_kwargs["pixel_attention_mask"] * cfg_factor[mode] else: cond_vae_images, cond_timestep, cond_vit_images = None, None, None vit_kwargs = None @@ -627,8 +888,16 @@ def prepare_model_inputs( stop_token_id = dict( auto=[tkw.eos_token_id] + extra_auto_stops, image=[tkw.eos_token_id], - recaption=[tkw.end_recaption_token_id, tkw.end_answer_token_id, tkw.eos_token_id], - think=[tkw.end_recaption_token_id, tkw.end_answer_token_id, tkw.eos_token_id], + recaption=[ + tkw.end_recaption_token_id, + tkw.end_answer_token_id, + tkw.eos_token_id, + ], + think=[ + tkw.end_recaption_token_id, + tkw.end_answer_token_id, + tkw.eos_token_id, + ], img_ratio=extra_auto_stops, ) model_input_kwargs = dict( @@ -818,7 +1087,10 @@ def _generate( # 50 and 5.0 hard code results = self.pipeline( batch_size=len(batch_gen_image_info), - image_size=[batch_gen_image_info[0].image_height, batch_gen_image_info[0].image_width], + image_size=[ + batch_gen_image_info[0].image_height, + batch_gen_image_info[0].image_width, + ], num_inference_steps=kwargs.get("num_inference_steps", 50), guidance_scale=kwargs.get("guidance_scale", 5.0), generator=generator, @@ -1006,10 +1278,51 @@ def forward( extra_args = getattr(getattr(req, "sampling_params", None), "extra_args", {}) or {} use_system_prompt = extra_args.get("use_system_prompt") system_prompt = 
extra_args.get("system_prompt") + # Fall back to per-prompt use_system_prompt forwarded by ar2diffusion + if use_system_prompt is None and req.prompts: + first_prompt = req.prompts[0] + if isinstance(first_prompt, dict): + use_system_prompt = first_prompt.get("use_system_prompt") if use_system_prompt is not None: system_prompt = get_system_prompt(use_system_prompt, "image", system_prompt) system_prompt = system_prompt.strip() if system_prompt is not None else "" prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + + # Extract AR-generated CoT/recaption text from each prompt's extra dict. + # The AR-side stage input processor (``ar2diffusion``) already prepends + # the trigger tag (e.g. ````) when the AR used the KV-reuse + # pretrain format, so ``ar_generated_text`` is a self-contained string + # and ``get_cot_sections()`` can parse the think/recaption structure + # directly. + cot_text_list = [] + for p in req.prompts: + extra = p.get("extra", {}) if isinstance(p, dict) else {} + cot_text_list.append(extra.get("ar_generated_text") or None) + cot_text = cot_text_list if any(t is not None for t in cot_text_list) else None + + batch_cond_image_info: list[list[JointImageInfo]] | None = None + if any(not isinstance(p, str) for p in req.prompts): + batch_cond_image_info = [] + for prompt_item in req.prompts: + if isinstance(prompt_item, str): + batch_cond_image_info.append([]) + continue + prompt_additional_information = prompt_item.get("additional_information") or {} + prompt_cond_infos = prompt_additional_information.get("batch_cond_image_info", []) + if isinstance(prompt_cond_infos, JointImageInfo | dict): + prompt_cond_infos = [prompt_cond_infos] + if prompt_cond_infos is None: + prompt_cond_infos = [] + batch_cond_image_info.append([_joint_image_info_from_payload(item) for item in prompt_cond_infos]) + + has_cond_image = [len(cond_infos) > 0 for cond_infos in batch_cond_image_info] + if any(has_cond_image) and not 
all(has_cond_image): + raise ValueError( + "When batching Hunyuan image editing requests, every prompt must include input image(s)." + ) + if not any(has_cond_image): + batch_cond_image_info = None + generator = req.sampling_params.generator or generator height = req.sampling_params.height or height width = req.sampling_params.width or width @@ -1019,17 +1332,20 @@ def forward( if guidance_scale <= 1.0: logger.info("HunyuanImage3.0 runs without classifier-free guidance when guidance_scale <= 1.0.") image_size = (height, width) + model_inputs = self.prepare_model_inputs( prompt=prompt, - cot_text=None, + cot_text=cot_text, system_prompt=system_prompt, mode="gen_image", generator=generator, image_size=image_size, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, + batch_cond_image_info=batch_cond_image_info, ) outputs = self._generate(**model_inputs, **kwargs) return DiffusionOutput( - output=outputs[0], stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None + output=outputs[0], + stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None, ) diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 753e6229352..59497c77808 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -482,6 +482,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "HeliosPipeline": "get_helios_pre_process_func", "HeliosPyramidPipeline": "get_helios_pre_process_func", "HunyuanVideo15ImageToVideoPipeline": "get_hunyuan_video_15_i2v_pre_process_func", + "HunyuanImage3ForCausalMM": "get_hunyuan_image_3_pre_process_func", "MagiHumanPipeline": "get_magi_human_pre_process_func", } diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml index 413e0f09cbe..31511697371 100644 --- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml +++ 
b/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml @@ -36,7 +36,7 @@ stage_args: top_k: 1024 max_tokens: 4096 stop_token_ids: [127957] # <|endoftext|> - detokenize: false + detokenize: true # DiT bridge consumes ar_generated_text; let the AR engine produce it # Stage 1: Diffusion (DiT + VAE) # Receives latents from AR stage, performs denoising + VAE decode @@ -56,8 +56,6 @@ stage_args: parallel_config: tensor_parallel_size: 4 enable_expert_parallel: true - omni_kv_config: - need_recv_cache: true engine_input_source: [0] # Input from AR stage custom_process_input_func: vllm_omni.model_executor.stage_input_processors.hunyuan_image3.ar2diffusion final_output: true diff --git a/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py b/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py index 0c0e6d7b37f..59f3baa8058 100644 --- a/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py +++ b/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py @@ -10,6 +10,8 @@ signature pattern as glm_image.ar2diffusion. """ +from __future__ import annotations + from typing import Any import torch @@ -74,6 +76,7 @@ def ar2diffusion( height = original_prompt.get("height", 1024) width = original_prompt.get("width", 1024) text_prompt = original_prompt.get("prompt", "") + use_system_prompt = original_prompt.get("use_system_prompt") logger.info( "[ar2diffusion] Request %d: AR generated %d tokens, text length=%d, target size=%dx%d", @@ -96,16 +99,20 @@ def ar2diffusion( }, } - # Forward multimodal data (original image for IT2I conditioning) + # Forward use_system_prompt so the DiT can build the same system prefix + if use_system_prompt is not None: + diffusion_input["use_system_prompt"] = use_system_prompt + + # Forward multimodal data (original image for IT2I conditioning). + # The diffusion pre_process_func reads multi_modal_data["image"], which + # matches vLLM's standard prompt schema, so we only need to pass it once. 
mm_data = original_prompt.get("multi_modal_data") if mm_data: - pil_image = mm_data.get("image") - if pil_image is None: - images = mm_data.get("images") - if images: - pil_image = images[0] if isinstance(images, list) else images - if pil_image is not None: - diffusion_input["pil_image"] = pil_image + prompt_images = mm_data.get("image") + if prompt_images is None: + prompt_images = mm_data.get("images") + if prompt_images is not None: + diffusion_input["multi_modal_data"] = {"image": prompt_images} # Forward multimodal output from AR (if any) if hasattr(ar_output, "multimodal_output") and ar_output.multimodal_output: