Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 180 additions & 0 deletions tests/diffusion/distributed/test_autoencoder_kl_hunyuanvideo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Unit tests for DistributedAutoencoderKLHunyuanVideo tile split/merge/blend (CPU-only)."""

import pytest
import torch

pytestmark = [pytest.mark.cpu, pytest.mark.core_model]


class _DummyHunyuanVae:
"""Minimal mock of DistributedAutoencoderKLHunyuanVideo for unit testing."""

def __init__(
self,
tile_sample_min_height=256,
tile_sample_min_width=256,
tile_overlap_factor=0.25,
spatial_ratio=8,
):
self.tile_sample_min_height = tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width
self.tile_overlap_factor = tile_overlap_factor
self.tile_sample_stride_height = int(tile_sample_min_height * (1 - tile_overlap_factor))
self.tile_sample_stride_width = int(tile_sample_min_width * (1 - tile_overlap_factor))
self.tile_latent_min_height = tile_sample_min_height // spatial_ratio
self.tile_latent_min_width = tile_sample_min_width // spatial_ratio
self.tile_latent_stride_height = int(self.tile_latent_min_height * (1 - tile_overlap_factor))
self.tile_latent_stride_width = int(self.tile_latent_min_width * (1 - tile_overlap_factor))
self.dtype = torch.float32

def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b

def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b

def decoder(self, z: torch.Tensor) -> torch.Tensor:
# Mock: upsample latent by spatial_ratio=8 along H and W
return z.repeat_interleave(8, dim=-2).repeat_interleave(8, dim=-1)


def _import_tile_split():
    """Import the distributed VAE class lazily and return its unbound ``tile_split``."""
    from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuanvideo import (
        DistributedAutoencoderKLHunyuanVideo as vae_cls,
    )

    split_fn = vae_cls.tile_split
    return split_fn


def _import_tile_exec():
    """Import the distributed VAE class lazily and return its unbound ``tile_exec``."""
    from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuanvideo import (
        DistributedAutoencoderKLHunyuanVideo as vae_cls,
    )

    exec_fn = vae_cls.tile_exec
    return exec_fn


def _import_tile_merge():
    """Import the distributed VAE class lazily and return its unbound ``tile_merge``."""
    from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuanvideo import (
        DistributedAutoencoderKLHunyuanVideo as vae_cls,
    )

    merge_fn = vae_cls.tile_merge
    return merge_fn


class TestTileSplit:
    """Checks on how ``tile_split`` partitions a latent into tile tasks."""

    @staticmethod
    def _split(latent_height, latent_width):
        """Split an all-zero latent of the given spatial size with a default dummy VAE."""
        split_fn = _import_tile_split()
        dummy = _DummyHunyuanVae()
        latent = torch.zeros(1, 16, 4, latent_height, latent_width)
        task_list, spec = split_fn(dummy, latent)
        return dummy, task_list, spec

    def test_single_tile(self):
        _, task_list, spec = self._split(16, 16)
        assert len(task_list) == 1
        assert spec.grid_shape == (1, 1)

    def test_multiple_tiles_480p(self):
        _, task_list, spec = self._split(60, 104)
        assert len(task_list) > 1
        rows, cols = spec.grid_shape
        assert rows * cols == len(task_list)

    def test_grid_coords_are_unique(self):
        _, task_list, _ = self._split(60, 104)
        all_coords = [task.grid_coord for task in task_list]
        distinct_coords = set(all_coords)
        assert len(all_coords) == len(distinct_coords)

    def test_tile_ids_are_sequential(self):
        _, task_list, _ = self._split(60, 104)
        observed_ids = [task.tile_id for task in task_list]
        assert observed_ids == list(range(len(task_list)))

    def test_tile_shape(self):
        dummy, task_list, _ = self._split(60, 104)
        for task in task_list:
            tile_h, tile_w = task.tensor.shape[-2], task.tensor.shape[-1]
            assert tile_h <= dummy.tile_latent_min_height
            assert tile_w <= dummy.tile_latent_min_width


class TestTileMerge:
    """End-to-end split -> exec -> merge checks driven by the dummy VAE."""

    def _run_split_exec_merge(self, z):
        """Run the three tile operators on ``z`` sequentially and return the merged output."""
        tile_split = _import_tile_split()
        tile_exec = _import_tile_exec()
        tile_merge = _import_tile_merge()
        vae = _DummyHunyuanVae()
        tasks, grid_spec = tile_split(vae, z)
        coord_tensor_map = {t.grid_coord: tile_exec(vae, t) for t in tasks}
        return tile_merge(vae, coord_tensor_map, grid_spec)

    def test_output_shape_single_tile(self):
        z = torch.zeros(1, 16, 4, 16, 16)
        result = self._run_split_exec_merge(z)
        assert result.shape[-2] == 16 * 8
        assert result.shape[-1] == 16 * 8

    def test_output_shape_480p(self):
        z = torch.ones(1, 4, 4, 60, 104)
        result = self._run_split_exec_merge(z)
        # The merged output must cover the full decoded resolution: the dummy
        # decoder upsamples H and W by 8 and leaves the other dims untouched.
        # A bare "> 0" check would pass even if tiles were dropped or cropped wrongly.
        assert result.shape == (1, 4, 4, 60 * 8, 104 * 8)

    def test_uniform_latent_produces_uniform_output(self):
        """A constant latent should produce a constant output (blend seams vanish)."""
        z = torch.ones(1, 4, 2, 60, 104) * 0.5
        result = self._run_split_exec_merge(z)
        # Blending two equal constants linearly yields the same constant, so
        # the value itself (not just uniformity) must survive the merge.
        assert torch.allclose(result, torch.full_like(result, 0.5), atol=1e-5)


class TestBlend:
    """Sanity checks for the dummy VAE's linear seam blending helpers."""

    _SHAPE = (1, 4, 2, 32, 32)

    def test_blend_v_boundary(self):
        extent = 8
        prev_tile = torch.zeros(self._SHAPE)
        next_tile = torch.ones(self._SHAPE)
        blended = _DummyHunyuanVae().blend_v(prev_tile, next_tile, extent)
        first_row = blended[..., 0, :]
        last_blended_row = blended[..., extent - 1, :]
        assert first_row.mean() < last_blended_row.mean()

    def test_blend_h_boundary(self):
        extent = 8
        prev_tile = torch.zeros(self._SHAPE)
        next_tile = torch.ones(self._SHAPE)
        blended = _DummyHunyuanVae().blend_h(prev_tile, next_tile, extent)
        first_col = blended[..., 0]
        last_blended_col = blended[..., extent - 1]
        assert first_col.mean() < last_blended_col.mean()

    def test_blend_v_no_change_beyond_extent(self):
        prev_tile = torch.zeros(self._SHAPE)
        next_tile = torch.full(self._SHAPE, 2.0)
        blended = _DummyHunyuanVae().blend_v(prev_tile, next_tile, blend_extent=4)
        untouched_rows = blended[..., 4:, :]
        assert torch.all(untouched_rows == 2.0)

    def test_blend_h_no_change_beyond_extent(self):
        prev_tile = torch.zeros(self._SHAPE)
        next_tile = torch.full(self._SHAPE, 2.0)
        blended = _DummyHunyuanVae().blend_h(prev_tile, next_tile, blend_extent=4)
        untouched_cols = blended[..., 4:]
        assert torch.all(untouched_cols == 2.0)
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any

import torch
from diffusers.models.autoencoders import AutoencoderKLHunyuanVideo15
from diffusers.models.autoencoders.vae import DecoderOutput
from vllm.logger import init_logger

from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl import DistributedAutoencoderKL_base
from vllm_omni.diffusion.distributed.autoencoders.distributed_vae_executor import (
DistributedOperator,
GridSpec,
TileTask,
)

logger = init_logger(__name__)


class DistributedAutoencoderKLHunyuanVideo(DistributedAutoencoderKL_base, AutoencoderKLHunyuanVideo15):
"""Distributed VAE for HunyuanVideo 1.5 (T2V and I2V).

Uses diffusers-style overlapping tile split with linear blending for
single-GPU and distributed decode.
"""

def init_distributed(self):
    """Set up distributed decoding and derive the tile geometry.

    After delegating to the parent initializer, this computes the
    sample-space tile strides and the latent-space tile sizes and strides
    from ``tile_sample_min_height`` / ``tile_sample_min_width``,
    ``tile_overlap_factor`` and the spatial compression ratio. The ratio is
    read from ``self.config.spatial_compression_ratio`` and falls back to 8
    when the config has no such attribute.
    """
    super().init_distributed()

    ratio = getattr(self.config, "spatial_compression_ratio", 8)
    sample_h = self.tile_sample_min_height
    sample_w = self.tile_sample_min_width
    # Fraction of each tile that is NOT shared with its neighbour.
    keep = 1 - self.tile_overlap_factor

    # The stride attributes are defined here because the parent VAE class
    # only provides tile_overlap_factor, not tile_sample_stride_*.
    self.tile_sample_stride_height = int(sample_h * keep)
    self.tile_sample_stride_width = int(sample_w * keep)

    latent_h = sample_h // ratio
    latent_w = sample_w // ratio
    self.tile_latent_min_height = latent_h
    self.tile_latent_min_width = latent_w
    self.tile_latent_stride_height = int(latent_h * keep)
    self.tile_latent_stride_width = int(latent_w * keep)

# ---- tile-based split (diffusers-style overlapping tiles) ----

def tile_split(self, z: torch.Tensor) -> tuple[list[TileTask], GridSpec]:
"""Split latent tensor into overlapping spatial tiles along H, W."""
_, _, num_frames, height, width = z.shape

stride_h = self.tile_latent_stride_height
stride_w = self.tile_latent_stride_width
blend_h = self.tile_sample_min_height - self.tile_sample_stride_height
blend_w = self.tile_sample_min_width - self.tile_sample_stride_width
row_limit_h = self.tile_sample_stride_height
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: overlap_h/overlap_w are strides, not overlaps — rename to stride_h/stride_w.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed, renamed to stride_h/stride_w.

row_limit_w = self.tile_sample_stride_width

tiletask_list = []
tile_id = 0
for i in range(0, height, stride_h):
for j in range(0, width, stride_w):
tile = z[:, :, :, i : i + self.tile_latent_min_height, j : j + self.tile_latent_min_width]
tiletask_list.append(
TileTask(tile_id, (i // stride_h, j // stride_w), tile, workload=tile.shape[-2] * tile.shape[-1])
)
tile_id += 1

tile_spec = {
"blend_h": blend_h,
"blend_w": blend_w,
"row_limit_h": row_limit_h,
"row_limit_w": row_limit_w,
}
grid_spec = GridSpec(
split_dims=(3, 4),
grid_shape=(tiletask_list[-1].grid_coord[0] + 1, tiletask_list[-1].grid_coord[1] + 1),
tile_spec=tile_spec,
output_dtype=self.dtype,
)
return tiletask_list, grid_spec

def tile_exec(self, task: TileTask) -> torch.Tensor:
    """Decode a single latent tile.

    The tile tensor is made contiguous and handed directly to
    ``self.decoder``; no other transformation is applied here.
    """
    latent_tile = task.tensor.contiguous()
    decoded_tile = self.decoder(latent_tile)
    return decoded_tile

def tile_merge(self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec) -> torch.Tensor:
grid_h, grid_w = grid_spec.grid_shape
blend_h = grid_spec.tile_spec["blend_h"]
blend_w = grid_spec.tile_spec["blend_w"]
row_limit_h = grid_spec.tile_spec["row_limit_h"]
row_limit_w = grid_spec.tile_spec["row_limit_w"]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Parent tile_exec applies post_quant_conv before decoding. Here and in the decode() fallback (line 124: result = self.decoder(z)) it's skipped. Is this intentional for HunyuanVideo's VAE? If use_post_quant_conv is False in the config it's fine, but worth confirming since it's a silent difference from the base class behavior.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is intentional. AutoencoderKLHunyuanVideo15 does not use post_quant_conv in its decode path -- both _decode() and tiled_decode() call self.decoder(z) directly without post_quant_conv. Our tile_exec mirrors that behavior exactly.

# Build a 2D list mirroring diffusers' rows[][] so that in-place
# blending on previous tiles is visible to later iterations.
rows: list[list[torch.Tensor]] = []
for i in range(grid_h):
rows.append([coord_tensor_map[(i, j)] for j in range(grid_w)])

result_rows = []
for i in range(grid_h):
result_row = []
for j in range(grid_w):
tile = rows[i][j]
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_h)
if j > 0:
tile = self.blend_h(rows[i][j - 1], tile, blend_w)
rows[i][j] = tile
crop_h = min(row_limit_h, tile.shape[-2])
crop_w = min(row_limit_w, tile.shape[-1])
result_row.append(tile[:, :, :, :crop_h, :crop_w])
result_rows.append(torch.cat(result_row, dim=-1))
return torch.cat(result_rows, dim=-2)

# ---- decode override ----

def decode(self, z: torch.Tensor, return_dict: bool = True, *args: Any, **kwargs: Any):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you refer to #2368 to implement vae encode parallel for HunyuanVideo?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @gcanlin, thanks for the suggestion. We looked into this carefully but concluded that encode parallel is not needed for HunyuanVideo-1.5.

In Wan I2V (#2368), the VAE encode input is a full-length video condition tensor [B, C, num_frames, H, W] — the same spatial resolution and temporal length as the generated video — so encode is a genuine bottleneck worth parallelizing.

In HunyuanVideo-1.5 I2V, the VAE encode input is a single reference frame [B, C, 1, H, W]. The compute and memory cost is negligible compared to decode, so tiled encode parallel would add code complexity with no practical benefit.

For T2V there is no encode at all. So encode parallel has no meaningful use case for HunyuanVideo-1.5, and we have kept the implementation focused on decode parallel only.

if not self.is_distributed_enabled():
return super().decode(z, return_dict=return_dict, *args, **kwargs)

logger.debug("HunyuanVideo VAE: distributed tiled decode with overlap blending")
result = self.distributed_executor.execute(
z,
DistributedOperator(split=self.tile_split, exec=self.tile_exec, merge=self.tile_merge),
broadcast_result=False,
)

if not return_dict:
return (result,)
return DecoderOutput(sample=result)
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def execute(self, z: torch.Tensor, operator: DistributedOperator, broadcast_resu

# 2. local decode
assigned = self._balance_tasks(tiletask_list, pp_size)
local_tasks = assigned[self.rank] if pp_size <= self.world_size else []
local_tasks = assigned[self.rank] if self.rank < pp_size else []
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this line need to change?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As @JiwaniZakir noted in the earlier review:

The old condition was checking a global property rather than whether the current rank falls within the active pipeline stage range.

To be more specific: pp_size is computed as min(self.parallel_size, self.world_size), so pp_size <= self.world_size is always True — meaning every rank would unconditionally receive tasks, making vae_patch_parallel_size effectively a no-op. The fix self.rank < pp_size correctly gates task assignment on whether the current rank is within the active VAE parallel group, so ranks outside the group get an empty task list instead.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fix itself is good, but the PR description says this removes "unnecessary pp_size = min(parallel_size, world_size) indirection" — the min() is fine, the old condition was the bug.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, thanks. The PR description has been updated to clarify: the `min()` itself is correct and necessary; the bug was the guard condition `pp_size <= self.world_size`, which is always `True` by construction.

local_results = [(t.tile_id, operator.exec(t)) for t in local_tasks]

# 3. compute shape per rank
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import numpy as np
import torch
from diffusers import AutoencoderKLHunyuanVideo15
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
Expand All @@ -20,6 +19,9 @@
from vllm.model_executor.models.utils import AutoWeightsLoader

from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuanvideo import (
DistributedAutoencoderKLHunyuanVideo,
)
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
Expand Down Expand Up @@ -110,7 +112,7 @@ def __init__(
t5_config = AutoConfig.from_pretrained(model, subfolder="text_encoder_2", local_files_only=local_files_only)
self.text_encoder_2 = T5EncoderModel(t5_config, prefix="text_encoder_2").to(dtype=dtype, device=self.device)

self.vae = AutoencoderKLHunyuanVideo15.from_pretrained(
self.vae = DistributedAutoencoderKLHunyuanVideo.from_pretrained(
model, subfolder="vae", torch_dtype=torch.float32, local_files_only=local_files_only
).to(self.device)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import numpy as np
import PIL.Image
import torch
from diffusers import AutoencoderKLHunyuanVideo15
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
Expand All @@ -27,6 +26,9 @@
from vllm.model_executor.models.utils import AutoWeightsLoader

from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuanvideo import (
DistributedAutoencoderKLHunyuanVideo,
)
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
Expand Down Expand Up @@ -139,7 +141,7 @@ def __init__(
model, subfolder="feature_extractor", local_files_only=local_files_only
)

self.vae = AutoencoderKLHunyuanVideo15.from_pretrained(
self.vae = DistributedAutoencoderKLHunyuanVideo.from_pretrained(
model, subfolder="vae", torch_dtype=torch.float32, local_files_only=local_files_only
).to(self.device)

Expand Down
Loading