-
Notifications
You must be signed in to change notification settings - Fork 1.1k
[Feat][HunyuanVideo-1.5]Support vae-patch-parallel #2418
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,180 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| """Unit tests for DistributedAutoencoderKLHunyuanVideo tile split/merge/blend (CPU-only).""" | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| pytestmark = [pytest.mark.cpu, pytest.mark.core_model] | ||
|
|
||
|
|
||
class _DummyHunyuanVae:
    """Minimal mock of DistributedAutoencoderKLHunyuanVideo for unit testing.

    Carries only the attributes and methods that the tile split / exec /
    merge code paths read: the sample- and latent-space tile sizes and
    strides, ``dtype``, the two blend helpers and a fake ``decoder``.

    Args:
        tile_sample_min_height: Tile height in sample (pixel) space.
        tile_sample_min_width: Tile width in sample (pixel) space.
        tile_overlap_factor: Fraction of a tile shared with its neighbour;
            strides are ``int(size * (1 - tile_overlap_factor))``.
        spatial_ratio: Sample-to-latent spatial ratio. It sizes the latent
            tiles and is also the upsample factor used by ``decoder``.
    """

    def __init__(
        self,
        tile_sample_min_height=256,
        tile_sample_min_width=256,
        tile_overlap_factor=0.25,
        spatial_ratio=8,
    ):
        self.tile_sample_min_height = tile_sample_min_height
        self.tile_sample_min_width = tile_sample_min_width
        self.tile_overlap_factor = tile_overlap_factor
        # Kept so that decoder() upsamples by the same ratio that was used to
        # derive the latent tile sizes below.
        self.spatial_ratio = spatial_ratio
        self.tile_sample_stride_height = int(tile_sample_min_height * (1 - tile_overlap_factor))
        self.tile_sample_stride_width = int(tile_sample_min_width * (1 - tile_overlap_factor))
        self.tile_latent_min_height = tile_sample_min_height // spatial_ratio
        self.tile_latent_min_width = tile_sample_min_width // spatial_ratio
        self.tile_latent_stride_height = int(self.tile_latent_min_height * (1 - tile_overlap_factor))
        self.tile_latent_stride_width = int(self.tile_latent_min_width * (1 - tile_overlap_factor))
        self.dtype = torch.float32

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """Linearly blend the last rows of ``a`` into the first rows of ``b``.

        ``b`` is modified in place and returned. The extent is clamped to the
        height (dim -2) of both tensors. Row ``y`` of ``b`` receives weight
        ``y / blend_extent``; the matching row of ``a`` gets the remainder.
        """
        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
                y / blend_extent
            )
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """Linearly blend the last columns of ``a`` into the first columns of ``b``.

        Horizontal counterpart of ``blend_v``: operates along the width
        (dim -1), modifies ``b`` in place and returns it.
        """
        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
                x / blend_extent
            )
        return b

    def decoder(self, z: torch.Tensor) -> torch.Tensor:
        """Fake decoder: repeat every latent element ``spatial_ratio`` times along H and W."""
        ratio = self.spatial_ratio
        return z.repeat_interleave(ratio, dim=-2).repeat_interleave(ratio, dim=-1)
|
|
||
|
|
||
def _import_tile_split():
    """Return ``DistributedAutoencoderKLHunyuanVideo.tile_split`` as a plain function.

    The import happens inside the function body, so the project module is
    only loaded when a test actually calls this helper.
    """
    from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuanvideo import (
        DistributedAutoencoderKLHunyuanVideo as vae_cls,
    )

    return vae_cls.tile_split
|
|
||
|
|
||
def _import_tile_exec():
    """Return ``DistributedAutoencoderKLHunyuanVideo.tile_exec`` as a plain function.

    The import happens inside the function body, so the project module is
    only loaded when a test actually calls this helper.
    """
    from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuanvideo import (
        DistributedAutoencoderKLHunyuanVideo as vae_cls,
    )

    return vae_cls.tile_exec
|
|
||
|
|
||
def _import_tile_merge():
    """Return ``DistributedAutoencoderKLHunyuanVideo.tile_merge`` as a plain function.

    The import happens inside the function body, so the project module is
    only loaded when a test actually calls this helper.
    """
    from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuanvideo import (
        DistributedAutoencoderKLHunyuanVideo as vae_cls,
    )

    return vae_cls.tile_merge
|
|
||
|
|
||
class TestTileSplit:
    """Checks on the task list and grid metadata produced by ``tile_split``."""

    @staticmethod
    def _split(latent_h, latent_w):
        """Run ``tile_split`` on a zero latent of shape (1, 16, 4, latent_h, latent_w).

        Returns the mock VAE together with the tasks and grid spec so that
        individual tests can inspect whichever part they need.
        """
        split_fn = _import_tile_split()
        mock_vae = _DummyHunyuanVae()
        latent = torch.zeros(1, 16, 4, latent_h, latent_w)
        tile_tasks, spec = split_fn(mock_vae, latent)
        return mock_vae, tile_tasks, spec

    def test_single_tile(self):
        # A 16x16 latent fits inside one 32x32 latent tile.
        _, tile_tasks, spec = self._split(16, 16)
        assert len(tile_tasks) == 1
        assert spec.grid_shape == (1, 1)

    def test_multiple_tiles_480p(self):
        _, tile_tasks, spec = self._split(60, 104)
        assert len(tile_tasks) > 1
        n_rows, n_cols = spec.grid_shape
        assert n_rows * n_cols == len(tile_tasks)

    def test_grid_coords_are_unique(self):
        _, tile_tasks, _ = self._split(60, 104)
        seen = [task.grid_coord for task in tile_tasks]
        assert len(seen) == len(set(seen))

    def test_tile_ids_are_sequential(self):
        _, tile_tasks, _ = self._split(60, 104)
        ids = [task.tile_id for task in tile_tasks]
        assert ids == list(range(len(tile_tasks)))

    def test_tile_shape(self):
        mock_vae, tile_tasks, _ = self._split(60, 104)
        max_h = mock_vae.tile_latent_min_height
        max_w = mock_vae.tile_latent_min_width
        for task in tile_tasks:
            assert task.tensor.shape[-2] <= max_h
            assert task.tensor.shape[-1] <= max_w
|
|
||
|
|
||
class TestTileMerge:
    """End-to-end checks of split -> exec -> merge using the mock VAE."""

    def _run_split_exec_merge(self, z):
        """Split ``z`` into tiles, decode each with the mock decoder, and merge."""
        tile_split = _import_tile_split()
        tile_exec = _import_tile_exec()
        tile_merge = _import_tile_merge()
        vae = _DummyHunyuanVae()
        tasks, grid_spec = tile_split(vae, z)
        coord_tensor_map = {t.grid_coord: tile_exec(vae, t) for t in tasks}
        return tile_merge(vae, coord_tensor_map, grid_spec)

    def test_output_shape_single_tile(self):
        z = torch.zeros(1, 16, 4, 16, 16)
        result = self._run_split_exec_merge(z)
        assert result.shape[-2] == 16 * 8
        assert result.shape[-1] == 16 * 8

    def test_output_shape_480p(self):
        z = torch.ones(1, 4, 4, 60, 104)
        result = self._run_split_exec_merge(z)
        # The mock decoder upsamples H and W by 8 and leaves the other dims
        # alone. With 32-wide latent tiles at stride 24, rows contribute
        # 192 + 192 + 96 pixels and columns 4 * 192 + 64, so the merged
        # output must be exactly 8x the latent size with no seam duplication.
        assert tuple(result.shape) == (1, 4, 4, 60 * 8, 104 * 8)

    def test_uniform_latent_produces_uniform_output(self):
        """A constant latent should produce a constant output (blend seams vanish)."""
        z = torch.ones(1, 4, 2, 60, 104) * 0.5
        result = self._run_split_exec_merge(z)
        # Compare against the known input value, not against the output's own
        # first element, so a uniformly wrong result is caught as well.
        assert torch.allclose(result, torch.full_like(result, 0.5), atol=1e-5)
|
|
||
|
|
||
class TestBlend:
    """Checks of the linear blend helpers on the mock VAE."""

    def test_blend_v_boundary(self):
        # Blending zeros into ones must ramp upwards across the blend rows.
        mock_vae = _DummyHunyuanVae()
        upper = torch.zeros(1, 4, 2, 32, 32)
        lower = torch.ones(1, 4, 2, 32, 32)
        extent = 8
        blended = mock_vae.blend_v(upper, lower, extent)
        first_row_mean = blended[:, :, :, 0, :].mean()
        last_row_mean = blended[:, :, :, extent - 1, :].mean()
        assert first_row_mean < last_row_mean

    def test_blend_h_boundary(self):
        # Same ramp check along the width axis.
        mock_vae = _DummyHunyuanVae()
        left = torch.zeros(1, 4, 2, 32, 32)
        right = torch.ones(1, 4, 2, 32, 32)
        extent = 8
        blended = mock_vae.blend_h(left, right, extent)
        first_col_mean = blended[:, :, :, :, 0].mean()
        last_col_mean = blended[:, :, :, :, extent - 1].mean()
        assert first_col_mean < last_col_mean

    def test_blend_v_no_change_beyond_extent(self):
        # Rows past the blend extent must keep the second tile's value.
        mock_vae = _DummyHunyuanVae()
        upper = torch.zeros(1, 4, 2, 32, 32)
        lower = torch.full((1, 4, 2, 32, 32), 2.0)
        blended = mock_vae.blend_v(upper, lower, blend_extent=4)
        untouched = blended[:, :, :, 4:, :]
        assert torch.all(untouched == 2.0)

    def test_blend_h_no_change_beyond_extent(self):
        # Columns past the blend extent must keep the second tile's value.
        mock_vae = _DummyHunyuanVae()
        left = torch.zeros(1, 4, 2, 32, 32)
        right = torch.full((1, 4, 2, 32, 32), 2.0)
        blended = mock_vae.blend_h(left, right, blend_extent=4)
        untouched = blended[:, :, :, :, 4:]
        assert torch.all(untouched == 2.0)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,128 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| from typing import Any | ||
|
|
||
| import torch | ||
| from diffusers.models.autoencoders import AutoencoderKLHunyuanVideo15 | ||
| from diffusers.models.autoencoders.vae import DecoderOutput | ||
| from vllm.logger import init_logger | ||
|
|
||
| from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl import DistributedAutoencoderKL_base | ||
| from vllm_omni.diffusion.distributed.autoencoders.distributed_vae_executor import ( | ||
| DistributedOperator, | ||
| GridSpec, | ||
| TileTask, | ||
| ) | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
||
class DistributedAutoencoderKLHunyuanVideo(DistributedAutoencoderKL_base, AutoencoderKLHunyuanVideo15):
    """Distributed VAE for HunyuanVideo 1.5 (T2V and I2V).

    Uses diffusers-style overlapping tile split with linear blending for
    single-GPU and distributed decode.

    Only decode is parallelized. Review note from the PR discussion: a
    parallel encode path (as done for Wan I2V in #2368) was judged
    unnecessary because HunyuanVideo-1.5 I2V encodes a single reference frame
    and T2V does not encode at all.
    """

    def init_distributed(self):
        """Initialize distributed VAE and compute latent tile sizes."""
        super().init_distributed()

        spatial_ratio = getattr(self.config, "spatial_compression_ratio", 8)

        # Derive stride from tile_overlap_factor (set by parent __init__ / enable_tiling).
        # AutoencoderKLHunyuanVideo15 does not have tile_sample_stride_* attributes.
        self.tile_sample_stride_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor))
        self.tile_sample_stride_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor))

        self.tile_latent_min_height = self.tile_sample_min_height // spatial_ratio
        self.tile_latent_min_width = self.tile_sample_min_width // spatial_ratio
        self.tile_latent_stride_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor))
        self.tile_latent_stride_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor))

    # ---- tile-based split (diffusers-style overlapping tiles) ----

    def tile_split(self, z: torch.Tensor) -> tuple[list[TileTask], GridSpec]:
        """Split latent tensor into overlapping spatial tiles along H, W.

        Args:
            z: 5-D latent; the last two dims are treated as height and width.

        Returns:
            A ``(tasks, grid_spec)`` pair. Tasks are listed in row-major
            order with sequential ``tile_id`` and ``(row, col)`` grid
            coordinates. ``grid_spec.tile_spec`` carries the sample-space
            blend extents and per-tile crop limits used by ``tile_merge``.

        Raises:
            ValueError: If the latent has an empty height or width, in which
                case no tile can be produced.
        """
        _, _, _, height, width = z.shape
        if height <= 0 or width <= 0:
            raise ValueError(f"Cannot tile a latent with empty spatial dims: height={height}, width={width}")

        stride_h = self.tile_latent_stride_height
        stride_w = self.tile_latent_stride_width
        tile_h = self.tile_latent_min_height
        tile_w = self.tile_latent_min_width

        # Top-left latent coordinates of every tile row / column.
        row_starts = range(0, height, stride_h)
        col_starts = range(0, width, stride_w)

        tiletask_list = []
        for row, top in enumerate(row_starts):
            for col, left in enumerate(col_starts):
                # Slicing clamps at the tensor edge, so border tiles may be
                # smaller than (tile_h, tile_w).
                tile = z[:, :, :, top : top + tile_h, left : left + tile_w]
                tiletask_list.append(
                    TileTask(len(tiletask_list), (row, col), tile, workload=tile.shape[-2] * tile.shape[-1])
                )

        tile_spec = {
            # Overlap between neighbouring decoded tiles, in sample space.
            "blend_h": self.tile_sample_min_height - self.tile_sample_stride_height,
            "blend_w": self.tile_sample_min_width - self.tile_sample_stride_width,
            # Portion of each decoded tile kept in the merged output.
            "row_limit_h": self.tile_sample_stride_height,
            "row_limit_w": self.tile_sample_stride_width,
        }
        grid_spec = GridSpec(
            split_dims=(3, 4),
            grid_shape=(len(row_starts), len(col_starts)),
            tile_spec=tile_spec,
            output_dtype=self.dtype,
        )
        return tiletask_list, grid_spec

    def tile_exec(self, task: TileTask) -> torch.Tensor:
        """Decode a single latent tile.

        No ``post_quant_conv`` is applied here. Review note from the PR
        discussion: this is intentional, because AutoencoderKLHunyuanVideo15
        calls ``self.decoder(z)`` directly in both ``_decode()`` and
        ``tiled_decode()``, and this method mirrors that behavior.
        """
        return self.decoder(task.tensor.contiguous())

    def tile_merge(self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec) -> torch.Tensor:
        """Blend decoded tiles with their top/left neighbours and stitch them together.

        Args:
            coord_tensor_map: Decoded tile for every ``(row, col)`` grid coordinate.
            grid_spec: Grid description produced by ``tile_split``.

        Returns:
            The merged tensor: tiles of each row concatenated along the last
            dim, then rows concatenated along dim -2.
        """
        grid_h, grid_w = grid_spec.grid_shape
        spec = grid_spec.tile_spec
        blend_h = spec["blend_h"]
        blend_w = spec["blend_w"]
        row_limit_h = spec["row_limit_h"]
        row_limit_w = spec["row_limit_w"]

        # Build a 2D list mirroring diffusers' rows[][] so that in-place
        # blending on previous tiles is visible to later iterations.
        rows: list[list[torch.Tensor]] = [
            [coord_tensor_map[(i, j)] for j in range(grid_w)] for i in range(grid_h)
        ]

        result_rows = []
        for i in range(grid_h):
            result_row = []
            for j in range(grid_w):
                tile = rows[i][j]
                # Blend with the tile above first, then with the tile to the left.
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_h)
                if j > 0:
                    tile = self.blend_h(rows[i][j - 1], tile, blend_w)
                rows[i][j] = tile
                # Keep only the stride-sized part; border tiles may be smaller.
                crop_h = min(row_limit_h, tile.shape[-2])
                crop_w = min(row_limit_w, tile.shape[-1])
                result_row.append(tile[:, :, :, :crop_h, :crop_w])
            result_rows.append(torch.cat(result_row, dim=-1))
        return torch.cat(result_rows, dim=-2)

    # ---- decode override ----

    def decode(self, z: torch.Tensor, return_dict: bool = True, *args: Any, **kwargs: Any):
        """Decode latents, using the distributed tiled path when it is enabled.

        Args:
            z: Latent tensor to decode.
            return_dict: If false, return a 1-tuple ``(sample,)`` instead of
                a ``DecoderOutput``.
            *args: Forwarded to the parent ``decode`` on the non-distributed path.
            **kwargs: Forwarded to the parent ``decode`` on the non-distributed path.

        Returns:
            The parent's ``decode`` result when distributed decoding is
            disabled; otherwise ``DecoderOutput(sample=result)`` or
            ``(result,)`` depending on ``return_dict``.
        """
        if not self.is_distributed_enabled():
            # Positional extras are unpacked before the keyword argument so
            # the written order matches how Python binds the call.
            return super().decode(z, *args, return_dict=return_dict, **kwargs)

        logger.debug("HunyuanVideo VAE: distributed tiled decode with overlap blending")
        operator = DistributedOperator(split=self.tile_split, exec=self.tile_exec, merge=self.tile_merge)
        result = self.distributed_executor.execute(
            z,
            operator,
            broadcast_result=False,
        )

        if not return_dict:
            return (result,)
        return DecoderOutput(sample=result)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -125,7 +125,7 @@ def execute(self, z: torch.Tensor, operator: DistributedOperator, broadcast_resu | |
|
|
||
| # 2. local decode | ||
| assigned = self._balance_tasks(tiletask_list, pp_size) | ||
| local_tasks = assigned[self.rank] if pp_size <= self.world_size else [] | ||
| local_tasks = assigned[self.rank] if self.rank < pp_size else [] | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why need to change this line?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As @JiwaniZakir noted in the earlier review:
To be more specific:
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The fix itself is good, but the PR description says this removes "unnecessary pp_size = min(parallel_size, world_size) indirection" — the min() is fine, the old condition was the bug.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good catch, thanks. The PR description has been updated to clarify: the `min()` itself is correct and necessary; the bug was the guard condition `pp_size <= self.world_size`, which is always true by construction. |
||
| local_results = [(t.tile_id, operator.exec(t)) for t in local_tasks] | ||
|
|
||
| # 3. compute shape per rank | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit:
`overlap_h`/`overlap_w` are strides, not overlaps — rename to `stride_h`/`stride_w`. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed, renamed to
stride_h/stride_w.