diff --git a/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py b/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py
new file mode 100644
index 00000000000..58ce14546e1
--- /dev/null
+++ b/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py
@@ -0,0 +1,102 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""CPU-only unit tests for the tile split/exec/merge hooks of the Hunyuan VAE."""
+
+import pytest
+import torch
+
+from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuan import (
+    DistributedAutoencoderKLHunyuan,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+class _DummyDistributedAutoencoderKLHunyuan(DistributedAutoencoderKLHunyuan):
+    """Lightweight stand-in that stubs the encoder/decoder and blending.
+
+    Only ``torch.nn.Module.__init__`` is run, so no VAE weights or config are
+    needed; the tiling attributes read by the split/merge hooks are set by hand.
+    """
+
+    def __init__(self):
+        torch.nn.Module.__init__(self)
+        self.tile_latent_min_size = 8
+        self.tile_sample_min_size = 8
+        self.tile_overlap_factor = 0.25
+        self.use_spatial_tiling = False
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return torch.float32
+
+    def decoder(self, x: torch.Tensor) -> torch.Tensor:
+        # Shape-preserving stub: a constant offset marks "decoded" values.
+        return x + 10
+
+    def encoder(self, x: torch.Tensor) -> torch.Tensor:
+        # Shape-preserving stub: a different offset marks "encoded" values.
+        return x + 20
+
+    def blend_v(self, _a: torch.Tensor, b: torch.Tensor, _blend_extent: int) -> torch.Tensor:
+        # Identity blend so the merged output can be compared exactly.
+        return b
+
+    def blend_h(self, _a: torch.Tensor, b: torch.Tensor, _blend_extent: int) -> torch.Tensor:
+        # Identity blend so the merged output can be compared exactly.
+        return b
+
+
+def test_hunyuan_vae_use_tiling_aliases_spatial_tiling():
+    # The use_tiling property must read from and write to use_spatial_tiling.
+    vae = _DummyDistributedAutoencoderKLHunyuan()
+
+    assert not vae.use_tiling
+
+    vae.use_tiling = True
+
+    assert vae.use_spatial_tiling
+    assert vae.use_tiling
+
+
+def test_hunyuan_vae_decode_tiles_round_trip():
+    # A 12x12 input with 8-wide tiles and a stride of 6 yields a 2x2 grid; each
+    # merged tile is cropped to 6x6, so the output is the full 12x12 plane.
+    vae = _DummyDistributedAutoencoderKLHunyuan()
+    z = torch.arange(144, dtype=torch.float32).reshape(1, 1, 1, 12, 12)
+
+    tile_tasks, grid_spec = vae.tile_split(z)
+    decoded_tiles = {task.grid_coord: vae.tile_exec(task) for task in tile_tasks}
+    output = vae.tile_merge(decoded_tiles, grid_spec)
+
+    assert grid_spec.split_dims == (3, 4)
+    assert grid_spec.grid_shape == (2, 2)
+    assert grid_spec.tile_spec == {"blend_extent": 2, "row_limit": 6}
+    assert len(tile_tasks) == 4
+    assert torch.equal(output, z + 10)
+
+
+def test_hunyuan_vae_encode_tiles_round_trip():
+    # Same geometry as the decode test, going through the encode hooks.
+    vae = _DummyDistributedAutoencoderKLHunyuan()
+    x = torch.arange(144, dtype=torch.float32).reshape(1, 1, 1, 12, 12)
+
+    tile_tasks, grid_spec = vae.encode_tile_split(x)
+    encoded_tiles = {task.grid_coord: vae.encode_tile_exec(task) for task in tile_tasks}
+    output = vae.encode_tile_merge(encoded_tiles, grid_spec)
+
+    assert grid_spec.split_dims == (3, 4)
+    assert grid_spec.grid_shape == (2, 2)
+    assert grid_spec.tile_spec == {"blend_extent": 2, "row_limit": 6}
+    assert len(tile_tasks) == 4
+    assert torch.equal(output, x + 20)
+
+
+@pytest.mark.parametrize("split_name", ["tile_split", "encode_tile_split"])
+def test_hunyuan_vae_split_rejects_empty_spatial_dims(split_name):
+    # An empty height leaves nothing to tile; the split must fail loudly.
+    vae = _DummyDistributedAutoencoderKLHunyuan()
+    empty = torch.zeros(1, 1, 1, 0, 12)
+
+    with pytest.raises(ValueError):
+        getattr(vae, split_name)(empty)
diff --git a/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_hunyuan.py b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_hunyuan.py
new file mode 100644
index 00000000000..c0e3b4d6cc7
--- /dev/null
+++ b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_hunyuan.py
@@ -0,0 +1,173 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Any
+
+import torch
+from vllm.logger import init_logger
+
+from vllm_omni.diffusion.distributed.autoencoders.distributed_vae_executor import (
+    DistributedOperator,
+    DistributedVaeMixin,
+    GridSpec,
+    TileTask,
+)
+from vllm_omni.diffusion.models.hunyuan_image3.autoencoder import AutoencoderKLConv3D
+
+logger = init_logger(__name__)
+
+
+class DistributedAutoencoderKLHunyuan(AutoencoderKLConv3D, DistributedVaeMixin):
+    """Hunyuan 3D-conv VAE whose spatial tiling can run on the distributed executor.
+
+    Tiled encode/decode is expressed as split / exec / merge hooks. When the
+    distributed path is not enabled, the parent implementation is used as-is.
+    """
+
+    @classmethod
+    def from_config(cls, config: Any, **kwargs: Any):
+        """Build the model from ``config``, then call ``init_distributed()`` on it."""
+        model = super().from_config(config, **kwargs)
+        model.init_distributed()
+        return model
+
+    @classmethod
+    def from_pretrained(cls, *args: Any, **kwargs: Any):
+        """Load the model via the parent loader, then call ``init_distributed()`` on it."""
+        model = super().from_pretrained(*args, **kwargs)
+        model.init_distributed()
+        return model
+
+    @property
+    def use_tiling(self) -> bool:
+        """Alias for ``use_spatial_tiling``."""
+        return self.use_spatial_tiling
+
+    @use_tiling.setter
+    def use_tiling(self, use_tiling: bool) -> None:
+        self.use_spatial_tiling = use_tiling
+
+    def _split_spatial_tiles(
+        self, tensor: torch.Tensor, in_tile_size: int, out_tile_size: int
+    ) -> tuple[list[TileTask], GridSpec]:
+        """Cut the last two dims of a 5-D tensor into overlapping square tiles.
+
+        Args:
+            tensor: Input whose dims 3 and 4 are tiled.
+            in_tile_size: Edge length of each tile taken from ``tensor``.
+            out_tile_size: Tile size used to derive the blend extent and the
+                per-tile crop applied when merging.
+
+        Returns:
+            The tile tasks in row-major order and the grid description that
+            ``tile_merge`` consumes.
+
+        Raises:
+            ValueError: If the tile stride is not positive or a tiled dim is empty.
+        """
+        _, _, _, height, width = tensor.shape
+        stride = int(in_tile_size * (1 - self.tile_overlap_factor))
+        if stride <= 0:
+            raise ValueError(f"Tile stride must be positive, got {stride}")
+        blend_extent = int(out_tile_size * self.tile_overlap_factor)
+        row_limit = int(out_tile_size - blend_extent)
+
+        row_starts = range(0, height, stride)
+        col_starts = range(0, width, stride)
+        if not row_starts or not col_starts:
+            raise ValueError(f"Cannot tile a tensor with empty spatial dims: {tuple(tensor.shape)}")
+
+        tasks: list[TileTask] = []
+        for row_idx, top in enumerate(row_starts):
+            for col_idx, left in enumerate(col_starts):
+                # Slicing clamps at the border, so edge tiles may be smaller.
+                tile = tensor[:, :, :, top : top + in_tile_size, left : left + in_tile_size]
+                tasks.append(
+                    TileTask(
+                        len(tasks),
+                        (row_idx, col_idx),
+                        tile,
+                        workload=tile.shape[3] * tile.shape[4],
+                    )
+                )
+
+        grid_spec = GridSpec(
+            split_dims=(3, 4),
+            grid_shape=(len(row_starts), len(col_starts)),
+            tile_spec={"blend_extent": blend_extent, "row_limit": row_limit},
+            output_dtype=self.dtype,
+        )
+        return tasks, grid_spec
+
+    def tile_split(self, z: torch.Tensor) -> tuple[list[TileTask], GridSpec]:
+        """Split latents ``z`` into overlapping tiles for decoding."""
+        return self._split_spatial_tiles(z, self.tile_latent_min_size, self.tile_sample_min_size)
+
+    def tile_exec(self, task: TileTask) -> torch.Tensor:
+        """Decode a single tile."""
+        return self.decoder(task.tensor)
+
+    def tile_merge(self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec) -> torch.Tensor:
+        """Blend neighbouring tiles and stitch the grid back into one tensor.
+
+        Each tile is blended with the tile above and the tile to its left, cropped
+        to ``row_limit`` along the last two dims, then concatenated row by row.
+        """
+        grid_h, grid_w = grid_spec.grid_shape
+        blend_extent = grid_spec.tile_spec["blend_extent"]
+        row_limit = grid_spec.tile_spec["row_limit"]
+
+        result_rows = []
+        for i in range(grid_h):
+            result_row = []
+            for j in range(grid_w):
+                tile = coord_tensor_map[(i, j)]
+                if i > 0:
+                    tile = self.blend_v(coord_tensor_map[(i - 1, j)], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(coord_tensor_map[(i, j - 1)], tile, blend_extent)
+                result_row.append(tile[:, :, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=-1))
+        return torch.cat(result_rows, dim=-2)
+
+    def encode_tile_split(self, x: torch.Tensor) -> tuple[list[TileTask], GridSpec]:
+        """Split samples ``x`` into overlapping tiles for encoding."""
+        return self._split_spatial_tiles(x, self.tile_sample_min_size, self.tile_latent_min_size)
+
+    def encode_tile_exec(self, task: TileTask) -> torch.Tensor:
+        """Encode a single tile."""
+        return self.encoder(task.tensor)
+
+    def encode_tile_merge(
+        self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec
+    ) -> torch.Tensor:
+        """Merge encoded tiles; the procedure is the same as for decoded tiles."""
+        return self.tile_merge(coord_tensor_map, grid_spec)
+
+    def spatial_tiled_encode(self, x: torch.Tensor):
+        """Tiled encode, routed through the distributed executor when enabled."""
+        if not self.is_distributed_enabled():
+            return super().spatial_tiled_encode(x)
+
+        logger.debug("Encode running with distributed executor")
+        return self.distributed_executor.execute(
+            x,
+            DistributedOperator(
+                split=self.encode_tile_split,
+                exec=self.encode_tile_exec,
+                merge=self.encode_tile_merge,
+            ),
+            broadcast_result=True,
+        )
+
+    def spatial_tiled_decode(self, z: torch.Tensor):
+        """Tiled decode, routed through the distributed executor when enabled."""
+        if not self.is_distributed_enabled():
+            return super().spatial_tiled_decode(z)
+
+        logger.debug("Decode running with distributed executor")
+        return self.distributed_executor.execute(
+            z,
+            DistributedOperator(split=self.tile_split, exec=self.tile_exec, merge=self.tile_merge),
+            broadcast_result=True,
+        )
diff --git a/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py b/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py
index 73b89bb11b0..677d67796be 100644
--- a/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py
+++ b/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py
@@ -30,7 +30,6 @@
 from vllm_omni.inputs.data import OmniTextPrompt
 from vllm_omni.model_executor.models.hunyuan_image3.siglip2 import Siglip2VisionTransformer
 
-from .autoencoder import AutoencoderKLConv3D
 from .hunyuan_image3_tokenizer import TokenizerWrapper
 from .hunyuan_image3_transformer import (
     CausalMMOutputWithPast,
@@ -343,7 +342,13 @@ def __init__(self, od_config: OmniDiffusionConfig) -> None:
         quant_config = od_config.quantization_config
         self.model = HunyuanImage3Model(self.hf_config, quant_config=quant_config)
         self.transformer = self.model
-        self.vae = AutoencoderKLConv3D.from_config(self.hf_config.vae)
+        # Lazy import to break circular dependency:
+        # autoencoder_kl_hunyuan -> hunyuan_image3/__init__ -> pipeline_hunyuan_image3 -> autoencoder_kl_hunyuan
+        from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuan import (  # noqa: PLC0415
+            DistributedAutoencoderKLHunyuan,
+        )
+
+        self.vae = DistributedAutoencoderKLHunyuan.from_config(self.hf_config.vae)
         self.vae.use_spatial_tiling = self.od_config.vae_use_tiling
         self._pipeline = None
         self._tkwrapper = TokenizerWrapper(od_config.model)