-
Notifications
You must be signed in to change notification settings - Fork 949
[Feat] Enable VAE parallel in HunyuanImage3 #3091
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Fishermanykx
wants to merge
8
commits into
vllm-project:main
Choose a base branch
from
Fishermanykx:yukexiong/hunyuan_vae_opt
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+231
−2
Open
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
4e17c90
[WIP][Feat.] Enable VAE parallel in HunyuanImage3
Fishermanykx 272bc98
[UT] Add Hunyuan distributed VAE tests
Fishermanykx 61fc92a
Add concise per-test comments in tests/diffusion/distributed/test_aut…
BLANKETusers 07027f8
fix(diffusion): break circular import in pipeline_hunyuan_image3
BLANKETusers 7d48820
fix(diffusion): break circular import in pipeline_hunyuan_image3
BLANKETusers b52f18b
feat(diffusion): add from_pretrained to DistributedAutoencoderKLHunyu…
BLANKETusers 85c590c
update tile grid assertions for tile_overlap_factor=0.25 in Hunyuan V…
BLANKETusers 4f024a9
update tile grid assertions for tile_overlap_factor=0.25 in Hunyuan V…
BLANKETusers File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
79 changes: 79 additions & 0 deletions
79
tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuan import ( | ||
| DistributedAutoencoderKLHunyuan, | ||
| ) | ||
|
|
||
| pytestmark = [pytest.mark.core_model, pytest.mark.cpu] | ||
|
|
||
|
|
||
| class _DummyDistributedAutoencoderKLHunyuan(DistributedAutoencoderKLHunyuan): | ||
| def __init__(self): | ||
| torch.nn.Module.__init__(self) | ||
| self.tile_latent_min_size = 8 | ||
| self.tile_sample_min_size = 8 | ||
| self.tile_overlap_factor = 0.25 | ||
| self.use_spatial_tiling = False | ||
|
|
||
| @property | ||
| def dtype(self) -> torch.dtype: | ||
| return torch.float32 | ||
|
|
||
| def decoder(self, x: torch.Tensor) -> torch.Tensor: | ||
| return x + 10 | ||
|
|
||
| def encoder(self, x: torch.Tensor) -> torch.Tensor: | ||
| return x + 20 | ||
|
|
||
| def blend_v(self, _a: torch.Tensor, b: torch.Tensor, _blend_extent: int) -> torch.Tensor: | ||
| return b | ||
|
|
||
| def blend_h(self, _a: torch.Tensor, b: torch.Tensor, _blend_extent: int) -> torch.Tensor: | ||
| return b | ||
|
|
||
|
|
||
| def test_hunyuan_vae_use_tiling_aliases_spatial_tiling(): | ||
| # Verify use_tiling property maps to use_spatial_tiling. | ||
| vae = _DummyDistributedAutoencoderKLHunyuan() | ||
|
|
||
| assert not vae.use_tiling | ||
|
|
||
| vae.use_tiling = True | ||
|
|
||
| assert vae.use_spatial_tiling | ||
|
|
||
|
|
||
| def test_hunyuan_vae_decode_tiles_round_trip(): | ||
| # Validate decode tile split/exec/merge returns expected reconstructed tensor. | ||
| vae = _DummyDistributedAutoencoderKLHunyuan() | ||
| z = torch.arange(144, dtype=torch.float32).reshape(1, 1, 1, 12, 12) | ||
|
|
||
| tile_tasks, grid_spec = vae.tile_split(z) | ||
| decoded_tiles = {task.grid_coord: vae.tile_exec(task) for task in tile_tasks} | ||
| output = vae.tile_merge(decoded_tiles, grid_spec) | ||
|
|
||
| assert grid_spec.split_dims == (3, 4) | ||
| assert grid_spec.grid_shape == (2, 2) | ||
| assert grid_spec.tile_spec == {"blend_extent": 2, "row_limit": 6} | ||
| assert len(tile_tasks) == 4 | ||
| assert torch.equal(output, z + 10) | ||
|
|
||
|
|
||
| def test_hunyuan_vae_encode_tiles_round_trip(): | ||
| # Validate encode tile split/exec/merge returns expected latent tensor. | ||
| vae = _DummyDistributedAutoencoderKLHunyuan() | ||
| x = torch.arange(144, dtype=torch.float32).reshape(1, 1, 1, 12, 12) | ||
|
|
||
| tile_tasks, grid_spec = vae.encode_tile_split(x) | ||
| encoded_tiles = {task.grid_coord: vae.encode_tile_exec(task) for task in tile_tasks} | ||
| output = vae.encode_tile_merge(encoded_tiles, grid_spec) | ||
|
|
||
| assert grid_spec.split_dims == (3, 4) | ||
| assert grid_spec.grid_shape == (2, 2) | ||
| assert grid_spec.tile_spec == {"blend_extent": 2, "row_limit": 6} | ||
| assert len(tile_tasks) == 4 | ||
| assert torch.equal(output, x + 20) |
145 changes: 145 additions & 0 deletions
145
vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_hunyuan.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,145 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| from typing import Any | ||
|
|
||
| import torch | ||
| from vllm.logger import init_logger | ||
|
|
||
| from vllm_omni.diffusion.distributed.autoencoders.distributed_vae_executor import ( | ||
| DistributedOperator, | ||
| DistributedVaeMixin, | ||
| GridSpec, | ||
| TileTask, | ||
| ) | ||
| from vllm_omni.diffusion.models.hunyuan_image3.autoencoder import AutoencoderKLConv3D | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
||
| class DistributedAutoencoderKLHunyuan(AutoencoderKLConv3D, DistributedVaeMixin): | ||
| @classmethod | ||
| def from_config(cls, config: Any, **kwargs: Any): | ||
| model = super().from_config(config, **kwargs) | ||
| model.init_distributed() | ||
| return model | ||
|
|
||
| @classmethod | ||
| def from_pretrained(cls, *args: Any, **kwargs: Any): | ||
| model = super().from_pretrained(*args, **kwargs) | ||
| model.init_distributed() | ||
| return model | ||
|
|
||
| @property | ||
| def use_tiling(self) -> bool: | ||
| return self.use_spatial_tiling | ||
|
|
||
| @use_tiling.setter | ||
| def use_tiling(self, use_tiling: bool) -> None: | ||
| self.use_spatial_tiling = use_tiling | ||
|
|
||
| def tile_split(self, z: torch.Tensor) -> tuple[list[TileTask], GridSpec]: | ||
| _, _, _, height, width = z.shape | ||
| overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) | ||
| blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) | ||
| row_limit = int(self.tile_sample_min_size - blend_extent) | ||
|
|
||
| tiletask_list = [] | ||
| for i in range(0, height, overlap_size): | ||
| for j in range(0, width, overlap_size): | ||
| tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] | ||
| tiletask_list.append( | ||
| TileTask( | ||
| len(tiletask_list), | ||
| (i // overlap_size, j // overlap_size), | ||
| tile, | ||
| workload=tile.shape[3] * tile.shape[4], | ||
| ) | ||
| ) | ||
|
|
||
| grid_spec = GridSpec( | ||
| split_dims=(3, 4), | ||
| grid_shape=(tiletask_list[-1].grid_coord[0] + 1, tiletask_list[-1].grid_coord[1] + 1), | ||
| tile_spec={"blend_extent": blend_extent, "row_limit": row_limit}, | ||
| output_dtype=self.dtype, | ||
| ) | ||
| return tiletask_list, grid_spec | ||
|
|
||
| def tile_exec(self, task: TileTask) -> torch.Tensor: | ||
| return self.decoder(task.tensor) | ||
|
|
||
| def tile_merge(self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec) -> torch.Tensor: | ||
| grid_h, grid_w = grid_spec.grid_shape | ||
| result_rows = [] | ||
| for i in range(grid_h): | ||
| result_row = [] | ||
| for j in range(grid_w): | ||
| tile = coord_tensor_map[(i, j)] | ||
| if i > 0: | ||
| tile = self.blend_v(coord_tensor_map[(i - 1, j)], tile, grid_spec.tile_spec["blend_extent"]) | ||
| if j > 0: | ||
| tile = self.blend_h(coord_tensor_map[(i, j - 1)], tile, grid_spec.tile_spec["blend_extent"]) | ||
| result_row.append(tile[:, :, :, : grid_spec.tile_spec["row_limit"], : grid_spec.tile_spec["row_limit"]]) | ||
| result_rows.append(torch.cat(result_row, dim=-1)) | ||
| return torch.cat(result_rows, dim=-2) | ||
|
|
||
| def encode_tile_split(self, x: torch.Tensor) -> tuple[list[TileTask], GridSpec]: | ||
| _, _, _, height, width = x.shape | ||
| overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) | ||
| blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) | ||
| row_limit = int(self.tile_latent_min_size - blend_extent) | ||
|
|
||
| tiletask_list = [] | ||
| for i in range(0, height, overlap_size): | ||
| for j in range(0, width, overlap_size): | ||
| tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] | ||
| tiletask_list.append( | ||
| TileTask( | ||
| len(tiletask_list), | ||
| (i // overlap_size, j // overlap_size), | ||
| tile, | ||
| workload=tile.shape[3] * tile.shape[4], | ||
| ) | ||
| ) | ||
|
|
||
| grid_spec = GridSpec( | ||
| split_dims=(3, 4), | ||
| grid_shape=(tiletask_list[-1].grid_coord[0] + 1, tiletask_list[-1].grid_coord[1] + 1), | ||
| tile_spec={"blend_extent": blend_extent, "row_limit": row_limit}, | ||
| output_dtype=self.dtype, | ||
| ) | ||
| return tiletask_list, grid_spec | ||
|
|
||
| def encode_tile_exec(self, task: TileTask) -> torch.Tensor: | ||
| return self.encoder(task.tensor) | ||
|
|
||
| def encode_tile_merge( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. tile_merge and encode_tile_merge are byte-for-byte identical. Could extract helper There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Modified |
||
| self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec | ||
| ) -> torch.Tensor: | ||
| return self.tile_merge(coord_tensor_map, grid_spec) | ||
|
|
||
| def spatial_tiled_encode(self, x: torch.Tensor): | ||
| if not self.is_distributed_enabled(): | ||
| return super().spatial_tiled_encode(x) | ||
|
|
||
| logger.debug("Encode running with distributed executor") | ||
| return self.distributed_executor.execute( | ||
| x, | ||
| DistributedOperator( | ||
| split=self.encode_tile_split, | ||
| exec=self.encode_tile_exec, | ||
| merge=self.encode_tile_merge, | ||
| ), | ||
| broadcast_result=True, | ||
| ) | ||
|
|
||
| def spatial_tiled_decode(self, z: torch.Tensor): | ||
| if not self.is_distributed_enabled(): | ||
| return super().spatial_tiled_decode(z) | ||
|
|
||
| logger.debug("Decode running with distributed executor") | ||
| return self.distributed_executor.execute( | ||
| z, | ||
| DistributedOperator(split=self.tile_split, exec=self.tile_exec, merge=self.tile_merge), | ||
| broadcast_result=True, | ||
| ) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
missing from_pretrained, consistency suggests adding it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Modified