diff --git a/tests/diffusion/test_hunyuan_image3_edit_preprocess.py b/tests/diffusion/test_hunyuan_image3_edit_preprocess.py
new file mode 100644
index 00000000000..3394b379d6e
--- /dev/null
+++ b/tests/diffusion/test_hunyuan_image3_edit_preprocess.py
@@ -0,0 +1,91 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from types import SimpleNamespace
+
+import numpy as np
+import pytest
+import torch
+from PIL import Image
+from pytest_mock import MockerFixture
+
+from vllm_omni.diffusion.models.hunyuan_image_3 import pipeline_hunyuan_image_3 as hy3_module
+from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+
+pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu]
+
+
+def test_hunyuan_image3_preprocess_builds_joint_image_info(mocker: MockerFixture):
+    config = SimpleNamespace(vae_downsample_factor=(8, 8), patch_size=2)
+    mocker.patch.object(hy3_module, "get_config", return_value=config)
+
+    class _DummyVisionProcessor:
+        patch_size = 16
+
+        def __call__(self, image, return_tensors="pt"):
+            del image, return_tensors
+            return {
+                "pixel_values": torch.zeros(1, 10, 4),
+                "spatial_shapes": torch.tensor([[2, 5]], dtype=torch.long),
+                "pixel_attention_mask": torch.ones(1, 10, dtype=torch.long),
+            }
+
+    class _DummyImageProcessor:
+        def __init__(self, _cfg):
+            self.reso_group = SimpleNamespace(
+                get_target_size=lambda width, height: (np.int64(width), np.int64(height)),
+                get_base_size_and_ratio_index=lambda width, height: (np.int64(1024), np.int64(3)),
+            )
+            self.vae_processor = lambda image: torch.zeros(3, image.size[1], image.size[0])
+            self.vision_encoder_processor = _DummyVisionProcessor()
+
+    mocker.patch.object(hy3_module, "HunyuanImage3ImageProcessor", _DummyImageProcessor)
+
+    preprocess = hy3_module.get_hunyuan_image_3_pre_process_func(SimpleNamespace(model="dummy-model"))
+    request = OmniDiffusionRequest(
+        prompts=[{"prompt": "edit image", "multi_modal_data": {"image": [Image.new("RGB", (32, 16), "white")]}}],
+        sampling_params=OmniDiffusionSamplingParams(),
+    )
+
+    request = preprocess(request)
+    prompt = request.prompts[0]
+    cond_infos = prompt["additional_information"]["batch_cond_image_info"]
+    assert len(cond_infos) == 1
+
+    cond_info_payload = cond_infos[0]
+    assert isinstance(cond_info_payload, dict)
+    assert cond_info_payload["type"] == "joint_image_info"
+    assert cond_info_payload["vae_image_info"]["image_tensor"].shape == (3, 16, 32)
+    assert cond_info_payload["vae_image_info"]["token_width"] == 2
+    assert cond_info_payload["vae_image_info"]["token_height"] == 1
+    assert isinstance(cond_info_payload["vae_image_info"]["image_width"], int)
+    assert isinstance(cond_info_payload["vae_image_info"]["image_height"], int)
+    assert isinstance(cond_info_payload["vae_image_info"]["base_size"], int)
+    assert isinstance(cond_info_payload["vae_image_info"]["ratio_index"], int)
+    assert tuple(cond_info_payload["vision_encoder_kwargs"]["spatial_shapes"].tolist()) == (2, 5)
+    roundtrip_cond_info = hy3_module._joint_image_info_from_payload(cond_info_payload)
+    assert isinstance(roundtrip_cond_info, hy3_module.JointImageInfo)
+    assert roundtrip_cond_info.vae_image_info.image_tensor.shape == (3, 16, 32)
+    assert request.sampling_params.width == 32
+    assert request.sampling_params.height == 16
+
+
+def test_hunyuan_image3_light_projector_is_callable():
+    projector = hy3_module.LightProjector(
+        {
+            "projector_type": "linear",
+            "input_dim": 4,
+            "n_embed": 8,
+        }
+    )
+    inputs = torch.randn(2, 3, 4)
+    outputs = projector(inputs)
+    assert outputs.shape == (2, 3, 8)
+
+
+def test_hunyuan_image3_to_pil_image_accepts_single_channel_inputs():
+    from_array = hy3_module._to_pil_image(np.zeros((1, 4, 6), dtype=np.uint8))
+    from_tensor = hy3_module._to_pil_image(torch.zeros(1, 4, 6))
+    assert from_array.size == (6, 4)
+    assert from_tensor.size == (6, 4)
diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py b/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py
index bc81ca9c3ed..b1f7b7a51b6 100644
--- a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py
+++ b/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py
@@ -746,6 +746,9 @@ def __init__(self, config):
 
         self.layers = modules
 
+    def forward(self, x):
+        return self.layers(x)
+
 
 class HunYuanRotary2DEmbedder:
     r"""
diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py b/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py
index 7e9e2d27877..38b6361dc26 100644
--- a/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py
+++ b/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py
@@ -10,6 +10,7 @@
 import torch
 import torch.nn as nn
 from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
+from PIL import Image as PILImage
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.generation.utils import ALL_CACHE_NAMES, GenerationMixin
 from transformers.models.siglip2 import Siglip2VisionConfig, Siglip2VisionModel
@@ -23,6 +24,7 @@
 from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
 from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
 from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.inputs.data import OmniTextPrompt
 
 from .autoencoder import AutoencoderKLConv3D
 from .hunyuan_image_3_tokenizer import TokenizerWrapper
@@ -64,7 +66,271 @@ def to_device(data, device):
     return data
 
 
+def _array_to_pil(array: np.ndarray) -> PILImage.Image:
+    """Convert an image-like array into a PIL image.
+
+    Accepts ``(H, W)``, ``(H, W, C)`` and ``(C, H, W)`` layouts, optionally wrapped in a
+    batch axis of size one. Float data in ``[-1, 1]`` or ``[0, 1]`` is rescaled to
+    ``[0, 255]``; any other non-``uint8`` data is clipped to that range.
+
+    Raises:
+        ValueError: If a 4-D input holds more than one image.
+    """
+    if array.ndim == 4:
+        if array.shape[0] != 1:
+            raise ValueError(f"Only a single image tensor is supported, but got shape {tuple(array.shape)}.")
+        array = array[0]
+    if array.dtype != np.uint8:
+        if np.issubdtype(array.dtype, np.floating):
+            if float(np.min(array)) < 0.0:
+                array = (np.clip(array, -1.0, 1.0) + 1.0) / 2.0
+            if float(np.max(array)) <= 1.0:
+                array = array * 255.0
+        array = np.clip(array, 0, 255).astype(np.uint8)
+    if array.ndim == 3 and array.shape[0] in (1, 3, 4):
+        array = np.transpose(array, (1, 2, 0))
+    if array.ndim == 3 and array.shape[-1] == 1:
+        # PIL cannot build an image from an (H, W, 1) array, so drop the channel axis.
+        array = array[..., 0]
+    return PILImage.fromarray(array)
+
+
+def _to_pil_image(image: Any) -> PILImage.Image:
+    """Coerce a PIL image, file path, NumPy array or torch tensor into a PIL image.
+
+    Raises:
+        TypeError: If ``image`` is none of the supported types.
+        ValueError: If a batched array/tensor holds more than one image.
+    """
+    if isinstance(image, PILImage.Image):
+        return image
+    if isinstance(image, str):
+        # Decode eagerly inside the context manager so the file handle is not leaked.
+        with PILImage.open(image) as opened:
+            opened.load()
+        return opened
+    if isinstance(image, np.ndarray):
+        return _array_to_pil(image)
+    if isinstance(image, torch.Tensor):
+        tensor = image.detach().cpu()
+        if tensor.dtype.is_floating_point:
+            # NumPy has no bfloat16; float32 represents every half-precision value exactly.
+            tensor = tensor.float()
+        return _array_to_pil(tensor.numpy())
+    raise TypeError(f"Unsupported image input type: {type(image)}")
+
+
+def _resize_and_crop_center(image: PILImage.Image, target_width: int, target_height: int) -> PILImage.Image:
+    """Scale ``image`` to cover the target size, then crop the centred target-sized window."""
+    src_width, src_height = image.size
+    scale = max(target_width / src_width, target_height / src_height)
+    resized_width = max(target_width, int(round(src_width * scale)))
+    resized_height = max(target_height, int(round(src_height * scale)))
+    resized = image.resize((resized_width, resized_height), PILImage.Resampling.LANCZOS)
+    left = max((resized_width - target_width) // 2, 0)
+    top = max((resized_height - target_height) // 2, 0)
+    right = left + target_width
+    bottom = top + target_height
+    return resized.crop((left, top, right, bottom))
+
+
+def _to_python_scalar(value: Any) -> Any:
+    """Unwrap NumPy scalars into plain Python numbers; return anything else unchanged."""
+    if isinstance(value, np.generic):
+        return value.item()
+    return value
+
+
+def _image_info_to_payload(image_info: ImageInfo) -> dict[str, Any]:
+    """Flatten an ``ImageInfo`` into a plain dict, unwrapping NumPy scalar fields."""
+    return {
+        "image_type": image_info.image_type,
+        "image_tensor": image_info.image_tensor,
+        "image_width": _to_python_scalar(image_info.image_width),
+        "image_height": _to_python_scalar(image_info.image_height),
+        "token_width": _to_python_scalar(image_info.token_width),
+        "token_height": _to_python_scalar(image_info.token_height),
+        "image_token_length": _to_python_scalar(image_info.image_token_length),
+        "base_size": _to_python_scalar(image_info.base_size),
+        "ratio_index": _to_python_scalar(image_info.ratio_index),
+        "add_timestep_token": image_info.add_timestep_token,
+        "add_guidance_token": image_info.add_guidance_token,
+        "use_front_boi_token": image_info.use_front_boi_token,
+        "add_image_shape_token": image_info.add_image_shape_token,
+    }
+
+
+def _to_tensor_if_needed(value: Any) -> Any:
+    """Turn lists into tensors and NumPy scalars into Python numbers; pass others through."""
+    if isinstance(value, np.generic):
+        return value.item()
+    if isinstance(value, list):
+        return torch.tensor(value)
+    return value
+
+
+def _image_info_from_payload(payload: dict[str, Any]) -> ImageInfo:
+    """Rebuild an ``ImageInfo`` from a dict produced by ``_image_info_to_payload``."""
+    return ImageInfo(
+        image_type=payload.get("image_type"),
+        image_tensor=_to_tensor_if_needed(payload.get("image_tensor")),
+        image_width=payload.get("image_width"),
+        image_height=payload.get("image_height"),
+        token_width=payload.get("token_width"),
+        token_height=payload.get("token_height"),
+        image_token_length=payload.get("image_token_length"),
+        base_size=payload.get("base_size"),
+        ratio_index=payload.get("ratio_index"),
+        add_timestep_token=payload.get("add_timestep_token", True),
+        add_guidance_token=payload.get("add_guidance_token", False),
+        use_front_boi_token=payload.get("use_front_boi_token", True),
+        add_image_shape_token=payload.get("add_image_shape_token", True),
+    )
+
+
+def _joint_image_info_to_payload(joint_image_info: JointImageInfo) -> dict[str, Any]:
+    """Flatten a ``JointImageInfo`` into a dict tagged with ``type="joint_image_info"``."""
+    return {
+        "type": "joint_image_info",
+        "vae_image_info": _image_info_to_payload(joint_image_info.vae_image_info),
+        "vision_image_info": _image_info_to_payload(joint_image_info.vision_image_info),
+        "vision_encoder_kwargs": joint_image_info.vision_encoder_kwargs,
+    }
+
+
+def _joint_image_info_from_payload(payload: Any) -> JointImageInfo:
+    """Rebuild a ``JointImageInfo`` from its dict payload; existing instances pass through.
+
+    Raises:
+        TypeError: If ``payload`` is neither a dict nor a ``JointImageInfo``.
+    """
+    if isinstance(payload, JointImageInfo):
+        return payload
+    if not isinstance(payload, dict):
+        raise TypeError(f"Expected dict or JointImageInfo for conditional image payload, got {type(payload)}.")
+
+    vae_image_info = _image_info_from_payload(payload["vae_image_info"])
+    vision_image_info = _image_info_from_payload(payload["vision_image_info"])
+    vision_encoder_kwargs = payload.get("vision_encoder_kwargs") or {}
+    if isinstance(vision_encoder_kwargs, dict):
+        vision_encoder_kwargs = {k: _to_tensor_if_needed(v) for k, v in vision_encoder_kwargs.items()}
+    return JointImageInfo(
+        vae_image_info=vae_image_info,
+        vision_image_info=vision_image_info,
+        vision_encoder_kwargs=vision_encoder_kwargs,
+    )
+
+
+def get_hunyuan_image_3_pre_process_func(
+    od_config: OmniDiffusionConfig,
+):
+    """Build the request pre-processor that prepares conditional images for editing.
+
+    The returned callable converts each prompt's ``multi_modal_data["image"]`` entries into
+    joint-image payloads stored under ``additional_information["batch_cond_image_info"]``
+    and, when the request leaves them unset, defaults the sampling width/height to the
+    first conditional image's size.
+    """
+    hf_config = get_config(od_config.model, trust_remote_code=True)
+    image_processor = HunyuanImage3ImageProcessor(hf_config)
+    vae_h_factor = hf_config.vae_downsample_factor[0] * hf_config.patch_size
+    vae_w_factor = hf_config.vae_downsample_factor[1] * hf_config.patch_size
+    vit_patch_size = getattr(image_processor.vision_encoder_processor, "patch_size", 1)
+    if isinstance(vit_patch_size, (tuple, list)):
+        vit_patch_size = int(vit_patch_size[0])
+
+    def _build_cond_joint_image(raw_image: Any) -> dict[str, Any]:
+        """Build the VAE + vision-encoder payload for one conditional image."""
+        pil_image = _to_pil_image(raw_image).convert("RGB")
+        orig_width, orig_height = pil_image.size
+
+        target_width, target_height = image_processor.reso_group.get_target_size(orig_width, orig_height)
+        target_width = int(target_width)
+        target_height = int(target_height)
+        vae_input = _resize_and_crop_center(pil_image, target_width, target_height)
+        vae_tensor = image_processor.vae_processor(vae_input)
+        base_size, ratio_idx = image_processor.reso_group.get_base_size_and_ratio_index(orig_width, orig_height)
+        base_size = int(base_size)
+        ratio_idx = int(ratio_idx)
+
+        vae_info = ImageInfo(
+            image_type="vae",
+            image_tensor=vae_tensor,
+            image_width=target_width,
+            image_height=target_height,
+            token_width=target_width // vae_w_factor,
+            token_height=target_height // vae_h_factor,
+            base_size=base_size,
+            ratio_index=ratio_idx,
+        )
+
+        vit_inputs = image_processor.vision_encoder_processor(pil_image, return_tensors="pt")
+        vit_tensor = vit_inputs["pixel_values"]
+        spatial_shapes = vit_inputs["spatial_shapes"].squeeze(0)
+        pixel_attention_mask = vit_inputs["pixel_attention_mask"].squeeze(0)
+        vit_token_h = int(spatial_shapes[0].item())
+        vit_token_w = int(spatial_shapes[1].item())
+
+        vit_info = ImageInfo(
+            image_type="siglip2",
+            image_tensor=vit_tensor,
+            image_width=vit_token_w * vit_patch_size,
+            image_height=vit_token_h * vit_patch_size,
+            token_width=vit_token_w,
+            token_height=vit_token_h,
+            image_token_length=int(vit_tensor.shape[1]),
+        )
+
+        return _joint_image_info_to_payload(
+            JointImageInfo(
+                vae_image_info=vae_info,
+                vision_image_info=vit_info,
+                vision_encoder_kwargs={
+                    "spatial_shapes": spatial_shapes,
+                    "pixel_attention_mask": pixel_attention_mask,
+                },
+            )
+        )
+
+    def pre_process_func(request: OmniDiffusionRequest):
+        """Attach conditional-image payloads to every prompt in ``request``."""
+        for i, prompt in enumerate(request.prompts):
+            if isinstance(prompt, str):
+                prompt = OmniTextPrompt(prompt=prompt)
+
+            # The key may be present but explicitly None; normalise both cases to a dict.
+            if prompt.get("additional_information") is None:
+                prompt["additional_information"] = {}
+
+            multi_modal_data = prompt.get("multi_modal_data") or {}
+            raw_images = multi_modal_data.get("image")
+            if raw_images is None:
+                image_list = []
+            elif isinstance(raw_images, (list, tuple)):
+                image_list = list(raw_images)
+            else:
+                image_list = [raw_images]
+            if image_list:
+                # Decode each input once and reuse it for both the payload and the size lookup.
+                pil_images = [_to_pil_image(image) for image in image_list]
+                cond_image_infos = [_build_cond_joint_image(image) for image in pil_images]
+                prompt["additional_information"]["batch_cond_image_info"] = cond_image_infos
+
+                first_image_w, first_image_h = pil_images[0].size
+                if request.sampling_params.width is None:
+                    request.sampling_params.width = int(first_image_w)
+                if request.sampling_params.height is None:
+                    request.sampling_params.height = int(first_image_h)
+
+            request.prompts[i] = prompt
+
+        return request
+
+    return pre_process_func
+
+
 class HunyuanImage3Pipeline(HunyuanImage3PreTrainedModel, GenerationMixin, DiffusionPipelineProfilerMixin):
+    support_image_input = True
     _PROFILER_TARGETS = [
         "model.forward",
         "model.layers[0].forward",
@@ -371,6 +637,12 @@ def build_batch_rope_image_info(output, sections):
 
     def vae_encode(self, image, cfg_factor=1):
         config = self.vae.config
+        if image.ndim == 3:
+            image = image.unsqueeze(0)
+        if image.ndim == 4:
+            image = image.unsqueeze(2)
+        if image.ndim != 5:
+            raise ValueError(f"Expected image tensor with 3/4/5 dims, got shape {tuple(image.shape)}.")
 
         with torch.autocast(device_type=self.model.device.type, dtype=torch.float16, enabled=True):
             vae_encode_result = self.vae.encode(image)
@@ -491,8 +763,7 @@ def prepare_model_inputs(
         batch_cot_text = cot_text
         batch_system_prompt = system_prompt
         batch_gen_image_info = None
-        # TODO: construct with user input images
-        batch_cond_image_info = None
+        batch_cond_image_info = kwargs.pop("batch_cond_image_info", None)
 
         # -- 2.1 message_list
         if batch_message_list is not None:
@@ -537,6 +808,12 @@ def prepare_model_inputs(
         if mode == "gen_image":
             batch_gen_image_info = [self.image_processor.build_image_info(image_size) for _ in range(batch_size)]
 
+        if batch_cond_image_info is not None:
+            assert isinstance(batch_cond_image_info, list) and len(batch_cond_image_info) == batch_size, (
+                "`batch_cond_image_info` should be a list with the same batch size as `prompt`."
+            )
+            batch_cond_image_info = [cond if isinstance(cond, list) else [cond] for cond in batch_cond_image_info]
+
         # -- 2.3 seed
         generator = kwargs.get("generator", None)
         if generator is None:
@@ -1002,6 +1279,28 @@ def forward(
         system_prompt = get_system_prompt(use_system_prompt, "image", system_prompt)
         system_prompt = system_prompt.strip() if system_prompt is not None else ""
         prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt
+        batch_cond_image_info: list[list[JointImageInfo]] | None = None
+        if any(not isinstance(p, str) for p in req.prompts):
+            batch_cond_image_info = []
+            for p in req.prompts:
+                if isinstance(p, str):
+                    batch_cond_image_info.append([])
+                    continue
+                prompt_additional_information = p.get("additional_information") or {}
+                prompt_cond_infos = prompt_additional_information.get("batch_cond_image_info", [])
+                if isinstance(prompt_cond_infos, (JointImageInfo, dict)):
+                    prompt_cond_infos = [prompt_cond_infos]
+                if prompt_cond_infos is None:
+                    prompt_cond_infos = []
+                batch_cond_image_info.append([_joint_image_info_from_payload(item) for item in prompt_cond_infos])
+            has_cond_image = [len(cond_infos) > 0 for cond_infos in batch_cond_image_info]
+            if any(has_cond_image) and not all(has_cond_image):
+                raise ValueError(
+                    "When batching Hunyuan image editing requests, every prompt must include input image(s)."
+                )
+            if not any(has_cond_image):
+                batch_cond_image_info = None
+
         generator = req.sampling_params.generator or generator
         height = req.sampling_params.height or height
         width = req.sampling_params.width or width
@@ -1021,6 +1320,7 @@ def forward(
             image_size=image_size,
             num_inference_steps=num_inference_steps,
             guidance_scale=guidance_scale,
+            batch_cond_image_info=batch_cond_image_info,
         )
         outputs = self._generate(**model_inputs, **kwargs)
         return DiffusionOutput(
diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py
index 97bc7fa2925..4291608cf97 100644
--- a/vllm_omni/diffusion/registry.py
+++ b/vllm_omni/diffusion/registry.py
@@ -392,6 +392,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) -
     "OmniGen2Pipeline": "get_omnigen2_pre_process_func",
     "HeliosPipeline": "get_helios_pre_process_func",
     "HeliosPyramidPipeline": "get_helios_pre_process_func",
+    "HunyuanImage3ForCausalMM": "get_hunyuan_image_3_pre_process_func",
     "HunyuanVideo15ImageToVideoPipeline": "get_hunyuan_video_15_i2v_pre_process_func",
     "MagiHumanPipeline": "get_magi_human_pre_process_func",
 }