Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/user_guide/diffusion_features.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ The following tables show which models support each feature:
| **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| **LTX-2.3** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **Helios** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| **HunyuanVideo-1.5 T2V I2V** | ❌ | ✅ | | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ (decode) | ✅ | ❌ |
| **HunyuanVideo-1.5 T2V I2V** | ❌ | ✅ | | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ (encode/decode) | ✅ | ❌ |
| **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |

**Frame Interpolation Support**
Expand Down
70 changes: 70 additions & 0 deletions tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuan import (
DistributedAutoencoderKLHunyuan,
)
from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuan_video_15 import (
DistributedAutoencoderKLHunyuanVideo15,
)

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]

Expand Down Expand Up @@ -36,6 +39,33 @@ def blend_h(self, _a: torch.Tensor, b: torch.Tensor, _blend_extent: int) -> torc
return b


class _DummyDistributedAutoencoderKLHunyuanVideo15(DistributedAutoencoderKLHunyuanVideo15):
    """Weight-free stand-in for the distributed HunyuanVideo-1.5 VAE.

    Only the tiling attributes read by the split/merge helpers are set, and the
    encoder, decoder and blend hooks are replaced by trivial functions so the
    tests can predict the merged output exactly.
    """

    def __init__(self):
        # Initialise only the nn.Module base; the constructors of the classes
        # in between are intentionally not called.
        torch.nn.Module.__init__(self)
        self.use_tiling = False
        self.tile_overlap_factor = 0.25
        # Sample-space and latent-space tiles share the same 8 x 12 size here.
        self.tile_sample_min_height = self.tile_latent_min_height = 8
        self.tile_sample_min_width = self.tile_latent_min_width = 12

    @property
    def dtype(self) -> torch.dtype:
        """Report a fixed float32 dtype."""
        return torch.float32

    def decoder(self, x: torch.Tensor) -> torch.Tensor:
        """Fake decoder: shift every element by +10."""
        shifted = x + 10
        return shifted

    def encoder(self, x: torch.Tensor) -> torch.Tensor:
        """Fake encoder: shift every element by +20."""
        shifted = x + 20
        return shifted

    def blend_v(self, _a: torch.Tensor, b: torch.Tensor, _blend_extent: int) -> torch.Tensor:
        """No-op vertical blend: hand back the current tile untouched."""
        return b

    def blend_h(self, _a: torch.Tensor, b: torch.Tensor, _blend_extent: int) -> torch.Tensor:
        """No-op horizontal blend: hand back the current tile untouched."""
        return b


def test_hunyuan_vae_use_tiling_aliases_spatial_tiling():
# Verify use_tiling property maps to use_spatial_tiling.
vae = _DummyDistributedAutoencoderKLHunyuan()
Expand Down Expand Up @@ -77,3 +107,43 @@ def test_hunyuan_vae_encode_tiles_round_trip():
assert grid_spec.tile_spec == {"blend_extent": 2, "row_limit": 6}
assert len(tile_tasks) == 4
assert torch.equal(output, x + 20)


def test_hunyuan_video15_vae_decode_tiles_round_trip():
    """Split, decode and merge a 12 x 16 input; the result must equal ``z + 10``."""
    dummy_vae = _DummyDistributedAutoencoderKLHunyuanVideo15()
    latents = torch.arange(192, dtype=torch.float32).reshape(1, 1, 1, 12, 16)

    tasks, spec = dummy_vae.tile_split(latents)
    tiles_by_coord = {}
    for task in tasks:
        tiles_by_coord[task.grid_coord] = dummy_vae.tile_exec(task)
    merged = dummy_vae.tile_merge(tiles_by_coord, spec)

    expected_tile_spec = {
        "blend_height": 2,
        "blend_width": 3,
        "row_limit_height": 6,
        "row_limit_width": 9,
    }
    assert spec.split_dims == (3, 4)
    assert spec.grid_shape == (2, 2)
    assert spec.tile_spec == expected_tile_spec
    assert len(tasks) == 4
    # The dummy decoder adds 10 and blending is a no-op, so merging must
    # reproduce the input shifted by 10.
    assert torch.equal(merged, latents + 10)


def test_hunyuan_video15_vae_encode_tiles_round_trip():
    """Split, encode and merge a 12 x 16 input; the result must equal ``x + 20``."""
    dummy_vae = _DummyDistributedAutoencoderKLHunyuanVideo15()
    samples = torch.arange(192, dtype=torch.float32).reshape(1, 1, 1, 12, 16)

    tasks, spec = dummy_vae.encode_tile_split(samples)
    tiles_by_coord = {}
    for task in tasks:
        tiles_by_coord[task.grid_coord] = dummy_vae.encode_tile_exec(task)
    merged = dummy_vae.tile_merge(tiles_by_coord, spec)

    expected_tile_spec = {
        "blend_height": 2,
        "blend_width": 3,
        "row_limit_height": 6,
        "row_limit_width": 9,
    }
    assert spec.split_dims == (3, 4)
    assert spec.grid_shape == (2, 2)
    assert spec.tile_spec == expected_tile_spec
    assert len(tasks) == 4
    # The dummy encoder adds 20 and blending is a no-op, so merging must
    # reproduce the input shifted by 20.
    assert torch.equal(merged, samples + 20)
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any

import torch
from diffusers import AutoencoderKLHunyuanVideo15
from vllm.logger import init_logger

from vllm_omni.diffusion.distributed.autoencoders.distributed_vae_executor import (
DistributedOperator,
DistributedVaeMixin,
GridSpec,
TileTask,
)

logger = init_logger(__name__)


class DistributedAutoencoderKLHunyuanVideo15(AutoencoderKLHunyuanVideo15, DistributedVaeMixin):
    """HunyuanVideo-1.5 VAE whose tiled encode/decode can run on a distributed executor.

    ``tiled_encode`` / ``tiled_decode`` fall back to the parent implementation
    unless ``is_distributed_enabled()`` returns true. Otherwise the input is
    cut into overlapping spatial tiles (split), each tile is passed through the
    encoder or decoder (exec), and the results are blended, cropped and
    concatenated back together (merge).

    NOTE(review): ``init_distributed``, ``is_distributed_enabled`` and
    ``distributed_executor`` are presumably provided by ``DistributedVaeMixin``;
    confirm against that class.
    """

    @classmethod
    def from_pretrained(cls, *args: Any, **kwargs: Any):
        """Load the model through the parent loader, then call ``init_distributed()``.

        All positional and keyword arguments are forwarded unchanged.
        """
        model = super().from_pretrained(*args, **kwargs)
        model.init_distributed()
        return model

    def _split_into_tiles(
        self,
        tensor: torch.Tensor,
        in_tile_height: int,
        in_tile_width: int,
        out_tile_height: int,
        out_tile_width: int,
    ) -> tuple[list[TileTask], GridSpec]:
        """Cut ``tensor`` into overlapping tiles over its last two dimensions.

        Shared implementation behind ``tile_split`` and ``encode_tile_split``.

        Args:
            tensor: 5-D tensor; dims 3 and 4 are treated as height and width.
            in_tile_height: Height of each slice taken from ``tensor``.
            in_tile_width: Width of each slice taken from ``tensor``.
            out_tile_height: Tile height used to derive the blend extent and
                the number of rows kept per tile at merge time.
            out_tile_width: Tile width used to derive the blend extent and the
                number of columns kept per tile at merge time.

        Returns:
            The tile tasks in row-major order and the ``GridSpec`` that
            ``tile_merge`` needs to reassemble them.
        """
        _, _, _, height, width = tensor.shape
        # Distance between the origins of neighbouring tiles.
        stride_height = int(in_tile_height * (1 - self.tile_overlap_factor))
        stride_width = int(in_tile_width * (1 - self.tile_overlap_factor))
        # Extent handed to blend_v / blend_h, and the part of each processed
        # tile that survives cropping in tile_merge.
        blend_height = int(out_tile_height * self.tile_overlap_factor)
        blend_width = int(out_tile_width * self.tile_overlap_factor)
        row_limit_height = out_tile_height - blend_height
        row_limit_width = out_tile_width - blend_width

        tasks: list[TileTask] = []
        for top in range(0, height, stride_height):
            for left in range(0, width, stride_width):
                tile = tensor[:, :, :, top : top + in_tile_height, left : left + in_tile_width]
                tasks.append(
                    TileTask(
                        len(tasks),
                        (top // stride_height, left // stride_width),
                        tile,
                        # Tile area; edge tiles may be smaller than the nominal size.
                        workload=tile.shape[3] * tile.shape[4],
                    )
                )

        # Tasks are row-major, so the last one carries the largest grid coordinate.
        last_row, last_col = tasks[-1].grid_coord
        grid_spec = GridSpec(
            split_dims=(3, 4),
            grid_shape=(last_row + 1, last_col + 1),
            tile_spec={
                "blend_height": blend_height,
                "blend_width": blend_width,
                "row_limit_height": row_limit_height,
                "row_limit_width": row_limit_width,
            },
            output_dtype=self.dtype,
        )
        return tasks, grid_spec

    def tile_split(self, z: torch.Tensor) -> tuple[list[TileTask], GridSpec]:
        """Split ``z`` into decode tiles.

        Slices use the latent tile size; blend and crop extents are derived
        from the sample tile size, since they are applied to decoder output.
        """
        return self._split_into_tiles(
            z,
            self.tile_latent_min_height,
            self.tile_latent_min_width,
            self.tile_sample_min_height,
            self.tile_sample_min_width,
        )

    def tile_exec(self, task: TileTask) -> torch.Tensor:
        """Run the decoder on one tile."""
        return self.decoder(task.tensor)

    def tile_merge(self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec) -> torch.Tensor:
        """Blend neighbouring tiles, crop each one and stitch the grid together.

        Used for both encode and decode results.

        Args:
            coord_tensor_map: Processed tile for every ``(row, col)`` grid coordinate.
            grid_spec: Grid shape and blend/crop sizes produced by the split step.

        Returns:
            Rows concatenated along the last dimension, then stacked along the
            second-to-last dimension.
        """
        grid_h, grid_w = grid_spec.grid_shape
        spec = grid_spec.tile_spec
        blend_height = spec["blend_height"]
        blend_width = spec["blend_width"]
        keep_height = spec["row_limit_height"]
        keep_width = spec["row_limit_width"]

        result_rows = []
        for i in range(grid_h):
            result_row = []
            for j in range(grid_w):
                tile = coord_tensor_map[(i, j)]
                # Blend with the tile above first, then with the tile to the
                # left; this order is kept exactly as in the original loop.
                if i > 0:
                    tile = self.blend_v(coord_tensor_map[(i - 1, j)], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(coord_tensor_map[(i, j - 1)], tile, blend_width)
                # Keep only the leading part of the tile in both spatial dims.
                result_row.append(tile[:, :, :, :keep_height, :keep_width])
            result_rows.append(torch.cat(result_row, dim=-1))
        return torch.cat(result_rows, dim=-2)

    def encode_tile_split(self, x: torch.Tensor) -> tuple[list[TileTask], GridSpec]:
        """Split ``x`` into encode tiles.

        Slices use the sample tile size; blend and crop extents are derived
        from the latent tile size, since they are applied to encoder output.
        """
        return self._split_into_tiles(
            x,
            self.tile_sample_min_height,
            self.tile_sample_min_width,
            self.tile_latent_min_height,
            self.tile_latent_min_width,
        )

    def encode_tile_exec(self, task: TileTask) -> torch.Tensor:
        """Run the encoder on one tile."""
        return self.encoder(task.tensor)

    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` tile by tile, via the distributed executor when enabled.

        Falls back to the parent ``tiled_encode`` when distribution is disabled.
        """
        if not self.is_distributed_enabled():
            return super().tiled_encode(x)

        logger.debug("Encode running with distributed executor")
        return self.distributed_executor.execute(
            x,
            DistributedOperator(
                split=self.encode_tile_split,
                exec=self.encode_tile_exec,
                merge=self.tile_merge,
            ),
            broadcast_result=True,
        )

    def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode ``z`` tile by tile, via the distributed executor when enabled.

        Falls back to the parent ``tiled_decode`` when distribution is disabled.
        """
        if not self.is_distributed_enabled():
            return super().tiled_decode(z)

        logger.debug("Decode running with distributed executor")
        return self.distributed_executor.execute(
            z,
            DistributedOperator(split=self.tile_split, exec=self.tile_exec, merge=self.tile_merge),
            broadcast_result=True,
        )
Original file line number Diff line number Diff line change
Expand Up @@ -762,15 +762,14 @@ def forward(
encoder_hidden_states = torch.stack(new_encoder_hidden_states)
encoder_attention_mask = torch.stack(new_encoder_attention_mask)

# Create explicit attn_mask for image tokens when SP auto_pad is active.
max_valid_encoder_tokens = int(encoder_attention_mask.sum(dim=1).max().item())
if max_valid_encoder_tokens < encoder_attention_mask.shape[1]:
encoder_hidden_states = encoder_hidden_states[:, :max_valid_encoder_tokens]
encoder_attention_mask = encoder_attention_mask[:, :max_valid_encoder_tokens]
if encoder_attention_mask.all():
encoder_attention_mask = None

ctx = get_forward_context()
if not ctx.sp_active:
max_valid_encoder_tokens = int(encoder_attention_mask.sum(dim=1).max().item())
if max_valid_encoder_tokens < encoder_attention_mask.shape[1]:
encoder_hidden_states = encoder_hidden_states[:, :max_valid_encoder_tokens]
encoder_attention_mask = encoder_attention_mask[:, :max_valid_encoder_tokens]
if encoder_attention_mask.all():
encoder_attention_mask = None
hidden_states_mask = None
if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0:
padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import numpy as np
import torch
from diffusers import AutoencoderKLHunyuanVideo15
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
Expand All @@ -20,6 +19,9 @@
from vllm.model_executor.models.utils import AutoWeightsLoader

from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuan_video_15 import (
DistributedAutoencoderKLHunyuanVideo15,
)
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
Expand Down Expand Up @@ -110,7 +112,7 @@ def __init__(
t5_config = AutoConfig.from_pretrained(model, subfolder="text_encoder_2", local_files_only=local_files_only)
self.text_encoder_2 = T5EncoderModel(t5_config, prefix="text_encoder_2").to(dtype=dtype, device=self.device)

self.vae = AutoencoderKLHunyuanVideo15.from_pretrained(
self.vae = DistributedAutoencoderKLHunyuanVideo15.from_pretrained(
model, subfolder="vae", torch_dtype=torch.float32, local_files_only=local_files_only
).to(self.device)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import numpy as np
import PIL.Image
import torch
from diffusers import AutoencoderKLHunyuanVideo15
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
Expand All @@ -27,6 +26,9 @@
from vllm.model_executor.models.utils import AutoWeightsLoader

from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuan_video_15 import (
DistributedAutoencoderKLHunyuanVideo15,
)
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
Expand Down Expand Up @@ -139,7 +141,7 @@ def __init__(
model, subfolder="feature_extractor", local_files_only=local_files_only
)

self.vae = AutoencoderKLHunyuanVideo15.from_pretrained(
self.vae = DistributedAutoencoderKLHunyuanVideo15.from_pretrained(
model, subfolder="vae", torch_dtype=torch.float32, local_files_only=local_files_only
).to(self.device)

Expand Down
Loading