Skip to content
79 changes: 79 additions & 0 deletions tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuan import (
DistributedAutoencoderKLHunyuan,
)

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]


class _DummyDistributedAutoencoderKLHunyuan(DistributedAutoencoderKLHunyuan):
def __init__(self):
torch.nn.Module.__init__(self)
self.tile_latent_min_size = 8
self.tile_sample_min_size = 8
self.tile_overlap_factor = 0.25
self.use_spatial_tiling = False

@property
def dtype(self) -> torch.dtype:
return torch.float32

def decoder(self, x: torch.Tensor) -> torch.Tensor:
return x + 10

def encoder(self, x: torch.Tensor) -> torch.Tensor:
return x + 20

def blend_v(self, _a: torch.Tensor, b: torch.Tensor, _blend_extent: int) -> torch.Tensor:
return b

def blend_h(self, _a: torch.Tensor, b: torch.Tensor, _blend_extent: int) -> torch.Tensor:
return b


def test_hunyuan_vae_use_tiling_aliases_spatial_tiling():
# Verify use_tiling property maps to use_spatial_tiling.
vae = _DummyDistributedAutoencoderKLHunyuan()

assert not vae.use_tiling

vae.use_tiling = True

assert vae.use_spatial_tiling


def test_hunyuan_vae_decode_tiles_round_trip():
# Validate decode tile split/exec/merge returns expected reconstructed tensor.
vae = _DummyDistributedAutoencoderKLHunyuan()
z = torch.arange(144, dtype=torch.float32).reshape(1, 1, 1, 12, 12)

tile_tasks, grid_spec = vae.tile_split(z)
decoded_tiles = {task.grid_coord: vae.tile_exec(task) for task in tile_tasks}
output = vae.tile_merge(decoded_tiles, grid_spec)

assert grid_spec.split_dims == (3, 4)
assert grid_spec.grid_shape == (2, 2)
assert grid_spec.tile_spec == {"blend_extent": 2, "row_limit": 6}
assert len(tile_tasks) == 4
assert torch.equal(output, z + 10)


def test_hunyuan_vae_encode_tiles_round_trip():
# Validate encode tile split/exec/merge returns expected latent tensor.
vae = _DummyDistributedAutoencoderKLHunyuan()
x = torch.arange(144, dtype=torch.float32).reshape(1, 1, 1, 12, 12)

tile_tasks, grid_spec = vae.encode_tile_split(x)
encoded_tiles = {task.grid_coord: vae.encode_tile_exec(task) for task in tile_tasks}
output = vae.encode_tile_merge(encoded_tiles, grid_spec)

assert grid_spec.split_dims == (3, 4)
assert grid_spec.grid_shape == (2, 2)
assert grid_spec.tile_spec == {"blend_extent": 2, "row_limit": 6}
assert len(tile_tasks) == 4
assert torch.equal(output, x + 20)
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any

import torch
from vllm.logger import init_logger

from vllm_omni.diffusion.distributed.autoencoders.distributed_vae_executor import (
DistributedOperator,
DistributedVaeMixin,
GridSpec,
TileTask,
)
from vllm_omni.diffusion.models.hunyuan_image3.autoencoder import AutoencoderKLConv3D

logger = init_logger(__name__)


class DistributedAutoencoderKLHunyuan(AutoencoderKLConv3D, DistributedVaeMixin):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing from_pretrained, consistency suggests adding it

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modified

@classmethod
def from_config(cls, config: Any, **kwargs: Any):
model = super().from_config(config, **kwargs)
model.init_distributed()
return model

@classmethod
def from_pretrained(cls, *args: Any, **kwargs: Any):
model = super().from_pretrained(*args, **kwargs)
model.init_distributed()
return model

@property
def use_tiling(self) -> bool:
return self.use_spatial_tiling

@use_tiling.setter
def use_tiling(self, use_tiling: bool) -> None:
self.use_spatial_tiling = use_tiling

def tile_split(self, z: torch.Tensor) -> tuple[list[TileTask], GridSpec]:
_, _, _, height, width = z.shape
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = int(self.tile_sample_min_size - blend_extent)

tiletask_list = []
for i in range(0, height, overlap_size):
for j in range(0, width, overlap_size):
tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
tiletask_list.append(
TileTask(
len(tiletask_list),
(i // overlap_size, j // overlap_size),
tile,
workload=tile.shape[3] * tile.shape[4],
)
)

grid_spec = GridSpec(
split_dims=(3, 4),
grid_shape=(tiletask_list[-1].grid_coord[0] + 1, tiletask_list[-1].grid_coord[1] + 1),
tile_spec={"blend_extent": blend_extent, "row_limit": row_limit},
output_dtype=self.dtype,
)
return tiletask_list, grid_spec

def tile_exec(self, task: TileTask) -> torch.Tensor:
return self.decoder(task.tensor)

def tile_merge(self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec) -> torch.Tensor:
grid_h, grid_w = grid_spec.grid_shape
result_rows = []
for i in range(grid_h):
result_row = []
for j in range(grid_w):
tile = coord_tensor_map[(i, j)]
if i > 0:
tile = self.blend_v(coord_tensor_map[(i - 1, j)], tile, grid_spec.tile_spec["blend_extent"])
if j > 0:
tile = self.blend_h(coord_tensor_map[(i, j - 1)], tile, grid_spec.tile_spec["blend_extent"])
result_row.append(tile[:, :, :, : grid_spec.tile_spec["row_limit"], : grid_spec.tile_spec["row_limit"]])
result_rows.append(torch.cat(result_row, dim=-1))
return torch.cat(result_rows, dim=-2)

def encode_tile_split(self, x: torch.Tensor) -> tuple[list[TileTask], GridSpec]:
_, _, _, height, width = x.shape
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = int(self.tile_latent_min_size - blend_extent)

tiletask_list = []
for i in range(0, height, overlap_size):
for j in range(0, width, overlap_size):
tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tiletask_list.append(
TileTask(
len(tiletask_list),
(i // overlap_size, j // overlap_size),
tile,
workload=tile.shape[3] * tile.shape[4],
)
)

grid_spec = GridSpec(
split_dims=(3, 4),
grid_shape=(tiletask_list[-1].grid_coord[0] + 1, tiletask_list[-1].grid_coord[1] + 1),
tile_spec={"blend_extent": blend_extent, "row_limit": row_limit},
output_dtype=self.dtype,
)
return tiletask_list, grid_spec

def encode_tile_exec(self, task: TileTask) -> torch.Tensor:
return self.encoder(task.tensor)

def encode_tile_merge(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tile_merge and encode_tile_merge are byte-for-byte identical. Could extract helper

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modified

self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec
) -> torch.Tensor:
return self.tile_merge(coord_tensor_map, grid_spec)

def spatial_tiled_encode(self, x: torch.Tensor):
if not self.is_distributed_enabled():
return super().spatial_tiled_encode(x)

logger.debug("Encode running with distributed executor")
return self.distributed_executor.execute(
x,
DistributedOperator(
split=self.encode_tile_split,
exec=self.encode_tile_exec,
merge=self.encode_tile_merge,
),
broadcast_result=True,
)

def spatial_tiled_decode(self, z: torch.Tensor):
if not self.is_distributed_enabled():
return super().spatial_tiled_decode(z)

logger.debug("Decode running with distributed executor")
return self.distributed_executor.execute(
z,
DistributedOperator(split=self.tile_split, exec=self.tile_exec, merge=self.tile_merge),
broadcast_result=True,
)
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from vllm_omni.inputs.data import OmniTextPrompt
from vllm_omni.model_executor.models.hunyuan_image3.siglip2 import Siglip2VisionTransformer

from .autoencoder import AutoencoderKLConv3D
from .hunyuan_image3_tokenizer import TokenizerWrapper
from .hunyuan_image3_transformer import (
CausalMMOutputWithPast,
Expand Down Expand Up @@ -343,7 +342,13 @@ def __init__(self, od_config: OmniDiffusionConfig) -> None:
quant_config = od_config.quantization_config
self.model = HunyuanImage3Model(self.hf_config, quant_config=quant_config)
self.transformer = self.model
self.vae = AutoencoderKLConv3D.from_config(self.hf_config.vae)
# Lazy import to break circular dependency:
# autoencoder_kl_hunyuan -> hunyuan_image3/__init__ -> pipeline_hunyuan_image3 -> autoencoder_kl_hunyuan
from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuan import ( # noqa: PLC0415
DistributedAutoencoderKLHunyuan,
)

self.vae = DistributedAutoencoderKLHunyuan.from_config(self.hf_config.vae)
self.vae.use_spatial_tiling = self.od_config.vae_use_tiling
self._pipeline = None
self._tkwrapper = TokenizerWrapper(od_config.model)
Expand Down
Loading