From 4e17c90bf2c3075b1e79abc0219d19fdd79f6961 Mon Sep 17 00:00:00 2001 From: KexiongYu Date: Fri, 24 Apr 2026 11:01:43 +0800 Subject: [PATCH 1/8] [WIP][Feat.] Enable VAE parallel in HunyuanImage3 Signed-off-by: KexiongYu --- .../autoencoders/autoencoder_kl_hunyuan.py | 151 ++++++++++++++++++ .../hunyuan_image3/pipeline_hunyuan_image3.py | 6 +- 2 files changed, 155 insertions(+), 2 deletions(-) create mode 100644 vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_hunyuan.py diff --git a/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_hunyuan.py b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_hunyuan.py new file mode 100644 index 00000000000..6295bbd1746 --- /dev/null +++ b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_hunyuan.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any + +import torch +from vllm.logger import init_logger + +from vllm_omni.diffusion.distributed.autoencoders.distributed_vae_executor import ( + DistributedOperator, + DistributedVaeMixin, + GridSpec, + TileTask, +) +from vllm_omni.diffusion.models.hunyuan_image3.autoencoder import AutoencoderKLConv3D + +logger = init_logger(__name__) + + +class DistributedAutoencoderKLHunyuan(AutoencoderKLConv3D, DistributedVaeMixin): + @classmethod + def from_config(cls, config: Any, **kwargs: Any): + model = super().from_config(config, **kwargs) + model.init_distributed() + return model + + @property + def use_tiling(self) -> bool: + return self.use_spatial_tiling + + @use_tiling.setter + def use_tiling(self, use_tiling: bool) -> None: + self.use_spatial_tiling = use_tiling + + def tile_split(self, z: torch.Tensor) -> tuple[list[TileTask], GridSpec]: + _, _, _, height, width = z.shape + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = 
int(self.tile_sample_min_size - blend_extent) + + tiletask_list = [] + for i in range(0, height, overlap_size): + for j in range(0, width, overlap_size): + tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + tiletask_list.append( + TileTask( + len(tiletask_list), + (i // overlap_size, j // overlap_size), + tile, + workload=tile.shape[3] * tile.shape[4], + ) + ) + + grid_spec = GridSpec( + split_dims=(3, 4), + grid_shape=(tiletask_list[-1].grid_coord[0] + 1, tiletask_list[-1].grid_coord[1] + 1), + tile_spec={"blend_extent": blend_extent, "row_limit": row_limit}, + output_dtype=self.dtype, + ) + return tiletask_list, grid_spec + + def tile_exec(self, task: TileTask) -> torch.Tensor: + return self.decoder(task.tensor) + + def tile_merge(self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec) -> torch.Tensor: + grid_h, grid_w = grid_spec.grid_shape + result_rows = [] + for i in range(grid_h): + result_row = [] + for j in range(grid_w): + tile = coord_tensor_map[(i, j)] + if i > 0: + tile = self.blend_v(coord_tensor_map[(i - 1, j)], tile, grid_spec.tile_spec["blend_extent"]) + if j > 0: + tile = self.blend_h(coord_tensor_map[(i, j - 1)], tile, grid_spec.tile_spec["blend_extent"]) + result_row.append(tile[:, :, :, : grid_spec.tile_spec["row_limit"], : grid_spec.tile_spec["row_limit"]]) + result_rows.append(torch.cat(result_row, dim=-1)) + return torch.cat(result_rows, dim=-2) + + def encode_tile_split(self, x: torch.Tensor) -> tuple[list[TileTask], GridSpec]: + _, _, _, height, width = x.shape + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = int(self.tile_latent_min_size - blend_extent) + + tiletask_list = [] + for i in range(0, height, overlap_size): + for j in range(0, width, overlap_size): + tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + 
tiletask_list.append( + TileTask( + len(tiletask_list), + (i // overlap_size, j // overlap_size), + tile, + workload=tile.shape[3] * tile.shape[4], + ) + ) + + grid_spec = GridSpec( + split_dims=(3, 4), + grid_shape=(tiletask_list[-1].grid_coord[0] + 1, tiletask_list[-1].grid_coord[1] + 1), + tile_spec={"blend_extent": blend_extent, "row_limit": row_limit}, + output_dtype=self.dtype, + ) + return tiletask_list, grid_spec + + def encode_tile_exec(self, task: TileTask) -> torch.Tensor: + return self.encoder(task.tensor) + + def encode_tile_merge( + self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec + ) -> torch.Tensor: + grid_h, grid_w = grid_spec.grid_shape + result_rows = [] + for i in range(grid_h): + result_row = [] + for j in range(grid_w): + tile = coord_tensor_map[(i, j)] + if i > 0: + tile = self.blend_v(coord_tensor_map[(i - 1, j)], tile, grid_spec.tile_spec["blend_extent"]) + if j > 0: + tile = self.blend_h(coord_tensor_map[(i, j - 1)], tile, grid_spec.tile_spec["blend_extent"]) + result_row.append(tile[:, :, :, : grid_spec.tile_spec["row_limit"], : grid_spec.tile_spec["row_limit"]]) + result_rows.append(torch.cat(result_row, dim=-1)) + return torch.cat(result_rows, dim=-2) + + def spatial_tiled_encode(self, x: torch.Tensor): + if not self.is_distributed_enabled(): + return super().spatial_tiled_encode(x) + + logger.debug("Encode running with distributed executor") + return self.distributed_executor.execute( + x, + DistributedOperator( + split=self.encode_tile_split, + exec=self.encode_tile_exec, + merge=self.encode_tile_merge, + ), + broadcast_result=True, + ) + + def spatial_tiled_decode(self, z: torch.Tensor): + if not self.is_distributed_enabled(): + return super().spatial_tiled_decode(z) + + logger.debug("Decode running with distributed executor") + return self.distributed_executor.execute( + z, + DistributedOperator(split=self.tile_split, exec=self.tile_exec, merge=self.tile_merge), + broadcast_result=True, + ) diff --git 
a/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py b/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py index 73b89bb11b0..228e5b4f93f 100644 --- a/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py +++ b/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py @@ -20,6 +20,9 @@ from vllm.transformers_utils.config import get_config from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuan import ( + DistributedAutoencoderKLHunyuan, +) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportImageInput @@ -30,7 +33,6 @@ from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.model_executor.models.hunyuan_image3.siglip2 import Siglip2VisionTransformer -from .autoencoder import AutoencoderKLConv3D from .hunyuan_image3_tokenizer import TokenizerWrapper from .hunyuan_image3_transformer import ( CausalMMOutputWithPast, @@ -343,7 +345,7 @@ def __init__(self, od_config: OmniDiffusionConfig) -> None: quant_config = od_config.quantization_config self.model = HunyuanImage3Model(self.hf_config, quant_config=quant_config) self.transformer = self.model - self.vae = AutoencoderKLConv3D.from_config(self.hf_config.vae) + self.vae = DistributedAutoencoderKLHunyuan.from_config(self.hf_config.vae) self.vae.use_spatial_tiling = self.od_config.vae_use_tiling self._pipeline = None self._tkwrapper = TokenizerWrapper(od_config.model) From 272bc98c2ce346707925f959585abd6f8823fabf Mon Sep 17 00:00:00 2001 From: KexiongYu Date: Fri, 15 May 2026 15:01:51 +0800 Subject: [PATCH 2/8] [UT] Add Hunyuan distributed VAE tests Signed-off-by: KexiongYu --- .../test_autoencoder_kl_hunyuan.py | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 
tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py diff --git a/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py b/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py new file mode 100644 index 00000000000..2c645da1d69 --- /dev/null +++ b/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuan import ( + DistributedAutoencoderKLHunyuan, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +class _DummyDistributedAutoencoderKLHunyuan(DistributedAutoencoderKLHunyuan): + def __init__(self): + torch.nn.Module.__init__(self) + self.tile_latent_min_size = 2 + self.tile_sample_min_size = 2 + self.tile_overlap_factor = 0.0 + self.use_spatial_tiling = False + + @property + def dtype(self) -> torch.dtype: + return torch.float32 + + def decoder(self, x: torch.Tensor) -> torch.Tensor: + return x + 10 + + def encoder(self, x: torch.Tensor) -> torch.Tensor: + return x + 20 + + def blend_v(self, _a: torch.Tensor, b: torch.Tensor, _blend_extent: int) -> torch.Tensor: + return b + + def blend_h(self, _a: torch.Tensor, b: torch.Tensor, _blend_extent: int) -> torch.Tensor: + return b + + +def test_hunyuan_vae_use_tiling_aliases_spatial_tiling(): + vae = _DummyDistributedAutoencoderKLHunyuan() + + assert not vae.use_tiling + + vae.use_tiling = True + + assert vae.use_spatial_tiling + + +def test_hunyuan_vae_decode_tiles_round_trip(): + vae = _DummyDistributedAutoencoderKLHunyuan() + z = torch.arange(16, dtype=torch.float32).reshape(1, 1, 1, 4, 4) + + tile_tasks, grid_spec = vae.tile_split(z) + decoded_tiles = {task.grid_coord: vae.tile_exec(task) for task in tile_tasks} + output = vae.tile_merge(decoded_tiles, grid_spec) + + assert grid_spec.split_dims == (3, 4) + assert grid_spec.grid_shape == (2, 2) + assert 
grid_spec.tile_spec == {"blend_extent": 0, "row_limit": 2} + assert [task.grid_coord for task in tile_tasks] == [(0, 0), (0, 1), (1, 0), (1, 1)] + assert torch.equal(output, z + 10) + + +def test_hunyuan_vae_encode_tiles_round_trip(): + vae = _DummyDistributedAutoencoderKLHunyuan() + x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 1, 4, 4) + + tile_tasks, grid_spec = vae.encode_tile_split(x) + encoded_tiles = {task.grid_coord: vae.encode_tile_exec(task) for task in tile_tasks} + output = vae.encode_tile_merge(encoded_tiles, grid_spec) + + assert grid_spec.split_dims == (3, 4) + assert grid_spec.grid_shape == (2, 2) + assert grid_spec.tile_spec == {"blend_extent": 0, "row_limit": 2} + assert [task.grid_coord for task in tile_tasks] == [(0, 0), (0, 1), (1, 0), (1, 1)] + assert torch.equal(output, x + 20) From 61fc92a1a39763e2cc40cbd94fee7cbb800463d2 Mon Sep 17 00:00:00 2001 From: zzh <943967662@qq.com> Date: Tue, 19 May 2026 11:31:27 +0800 Subject: [PATCH 3/8] Add concise per-test comments in tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py to clarify what each case validates, without changing test behavior Signed-off-by: zzh <943967662@qq.com> --- tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py b/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py index 2c645da1d69..9767f5c8e71 100644 --- a/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py +++ b/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py @@ -37,6 +37,7 @@ def blend_h(self, _a: torch.Tensor, b: torch.Tensor, _blend_extent: int) -> torc def test_hunyuan_vae_use_tiling_aliases_spatial_tiling(): + # Verify use_tiling property maps to use_spatial_tiling. 
vae = _DummyDistributedAutoencoderKLHunyuan() assert not vae.use_tiling @@ -47,6 +48,7 @@ def test_hunyuan_vae_use_tiling_aliases_spatial_tiling(): def test_hunyuan_vae_decode_tiles_round_trip(): + # Validate decode tile split/exec/merge returns expected reconstructed tensor. vae = _DummyDistributedAutoencoderKLHunyuan() z = torch.arange(16, dtype=torch.float32).reshape(1, 1, 1, 4, 4) @@ -62,6 +64,7 @@ def test_hunyuan_vae_decode_tiles_round_trip(): def test_hunyuan_vae_encode_tiles_round_trip(): + # Validate encode tile split/exec/merge returns expected latent tensor. vae = _DummyDistributedAutoencoderKLHunyuan() x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 1, 4, 4) From 07027f8f9d71b46a159be4eb8f2b060a2a652d80 Mon Sep 17 00:00:00 2001 From: zzh <943967662@qq.com> Date: Tue, 19 May 2026 14:22:38 +0800 Subject: [PATCH 4/8] fix(diffusion): break circular import in pipeline_hunyuan_image3 autoencoder_kl_hunyuan imports AutoencoderKLConv3D from the hunyuan_image3 package, which triggers hunyuan_image3/__init__.py to execute and import pipeline_hunyuan_image3, which in turn imported DistributedAutoencoderKLHunyuan back from autoencoder_kl_hunyuan before it finished initializing, causing a circular import error during test collection. Fix by moving the top-level import of DistributedAutoencoderKLHunyuan into HunyuanImage3Pipeline.__init__ as a lazy import, so it is only resolved at call time when both modules are fully initialized. 
Signed-off-by: zzh <943967662@qq.com> --- .../models/hunyuan_image3/pipeline_hunyuan_image3.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py b/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py index 228e5b4f93f..12d3af4dcb6 100644 --- a/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py +++ b/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py @@ -20,9 +20,6 @@ from vllm.transformers_utils.config import get_config from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuan import ( - DistributedAutoencoderKLHunyuan, -) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportImageInput @@ -345,6 +342,11 @@ def __init__(self, od_config: OmniDiffusionConfig) -> None: quant_config = od_config.quantization_config self.model = HunyuanImage3Model(self.hf_config, quant_config=quant_config) self.transformer = self.model + # Lazy import to break circular dependency: + # autoencoder_kl_hunyuan -> hunyuan_image3/__init__ -> pipeline_hunyuan_image3 -> autoencoder_kl_hunyuan + from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuan import ( # noqa: PLC0415 + DistributedAutoencoderKLHunyuan, + ) self.vae = DistributedAutoencoderKLHunyuan.from_config(self.hf_config.vae) self.vae.use_spatial_tiling = self.od_config.vae_use_tiling self._pipeline = None From 7d4882036997f289ff4dd4d846ccb708bbe6c3ca Mon Sep 17 00:00:00 2001 From: zzh <943967662@qq.com> Date: Tue, 19 May 2026 14:29:24 +0800 Subject: [PATCH 5/8] fix(diffusion): break circular import in pipeline_hunyuan_image3 autoencoder_kl_hunyuan imports AutoencoderKLConv3D from the hunyuan_image3 package, which triggers 
hunyuan_image3/__init__.py to execute and import pipeline_hunyuan_image3, which in turn imported DistributedAutoencoderKLHunyuan back from autoencoder_kl_hunyuan before it finished initializing, causing a circular import error during test collection. Fix by moving the top-level import of DistributedAutoencoderKLHunyuan into HunyuanImage3Pipeline.__init__ as a lazy import, so it is only resolved at call time when both modules are fully initialized. Note: the import move itself landed in the previous patch (4/8); this follow-up only adds a blank line after the lazy import to separate it from the following statement. Signed-off-by: zzh <943967662@qq.com> --- .../diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py b/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py index 12d3af4dcb6..677d67796be 100644 --- a/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py +++ b/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py @@ -347,6 +347,7 @@ def __init__(self, od_config: OmniDiffusionConfig) -> None: from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_hunyuan import ( # noqa: PLC0415 DistributedAutoencoderKLHunyuan, ) + self.vae = DistributedAutoencoderKLHunyuan.from_config(self.hf_config.vae) self.vae.use_spatial_tiling = self.od_config.vae_use_tiling self._pipeline = None From b52f18b327232053b0675d18e51f8856ad121db3 Mon Sep 17 00:00:00 2001 From: zzh <943967662@qq.com> Date: Tue, 19 May 2026 18:54:56 +0800 Subject: [PATCH 6/8] feat(diffusion): add from_pretrained to DistributedAutoencoderKLHunyuan and deduplicate tile merge - Add missing from_pretrained classmethod for consistency with other distributed autoencoders (KL, Wan, QwenImage) - Delegate encode_tile_merge to tile_merge to eliminate byte-for-byte duplicate code - Change tile_overlap_factor in the dummy test VAE from 0.0 to 0.25 Signed-off-by: zzh <943967662@qq.com> --- .../test_autoencoder_kl_hunyuan.py | 2 +- .../autoencoders/autoencoder_kl_hunyuan.py | 20 +++++++------------ 2 files changed, 8 insertions(+), 14 deletions(-) diff --git 
a/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py b/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py index 9767f5c8e71..8bf59fb288e 100644 --- a/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py +++ b/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py @@ -16,7 +16,7 @@ def __init__(self): torch.nn.Module.__init__(self) self.tile_latent_min_size = 2 self.tile_sample_min_size = 2 - self.tile_overlap_factor = 0.0 + self.tile_overlap_factor = 0.25 self.use_spatial_tiling = False @property diff --git a/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_hunyuan.py b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_hunyuan.py index 6295bbd1746..c0e3b4d6cc7 100644 --- a/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_hunyuan.py +++ b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_hunyuan.py @@ -24,6 +24,12 @@ def from_config(cls, config: Any, **kwargs: Any): model.init_distributed() return model + @classmethod + def from_pretrained(cls, *args: Any, **kwargs: Any): + model = super().from_pretrained(*args, **kwargs) + model.init_distributed() + return model + @property def use_tiling(self) -> bool: return self.use_spatial_tiling @@ -110,19 +116,7 @@ def encode_tile_exec(self, task: TileTask) -> torch.Tensor: def encode_tile_merge( self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec ) -> torch.Tensor: - grid_h, grid_w = grid_spec.grid_shape - result_rows = [] - for i in range(grid_h): - result_row = [] - for j in range(grid_w): - tile = coord_tensor_map[(i, j)] - if i > 0: - tile = self.blend_v(coord_tensor_map[(i - 1, j)], tile, grid_spec.tile_spec["blend_extent"]) - if j > 0: - tile = self.blend_h(coord_tensor_map[(i, j - 1)], tile, grid_spec.tile_spec["blend_extent"]) - result_row.append(tile[:, :, :, : grid_spec.tile_spec["row_limit"], : grid_spec.tile_spec["row_limit"]]) - result_rows.append(torch.cat(result_row, dim=-1)) - return torch.cat(result_rows, dim=-2) 
+ return self.tile_merge(coord_tensor_map, grid_spec) def spatial_tiled_encode(self, x: torch.Tensor): if not self.is_distributed_enabled(): From 85c590c69aa146c516f7b113edae667998a1fe21 Mon Sep 17 00:00:00 2001 From: zzh <943967662@qq.com> Date: Tue, 19 May 2026 20:30:41 +0800 Subject: [PATCH 7/8] update tile grid assertions for tile_overlap_factor=0.25 in Hunyuan VAE tests Adjust grid_shape from (2,2) to (4,4) and tile count from 4 to 16. When tile_overlap_factor=0.25, overlap_size becomes 1 instead of 2, producing a denser 4x4 tile grid on the 4x4 input. Signed-off-by: zzh <943967662@qq.com> --- .../diffusion/distributed/test_autoencoder_kl_hunyuan.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py b/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py index 8bf59fb288e..17e3961805e 100644 --- a/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py +++ b/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py @@ -57,9 +57,9 @@ def test_hunyuan_vae_decode_tiles_round_trip(): output = vae.tile_merge(decoded_tiles, grid_spec) assert grid_spec.split_dims == (3, 4) - assert grid_spec.grid_shape == (2, 2) + assert grid_spec.grid_shape == (4, 4) assert grid_spec.tile_spec == {"blend_extent": 0, "row_limit": 2} - assert [task.grid_coord for task in tile_tasks] == [(0, 0), (0, 1), (1, 0), (1, 1)] + assert len(tile_tasks) == 16 assert torch.equal(output, z + 10) @@ -73,7 +73,7 @@ def test_hunyuan_vae_encode_tiles_round_trip(): output = vae.encode_tile_merge(encoded_tiles, grid_spec) assert grid_spec.split_dims == (3, 4) - assert grid_spec.grid_shape == (2, 2) + assert grid_spec.grid_shape == (4, 4) assert grid_spec.tile_spec == {"blend_extent": 0, "row_limit": 2} - assert [task.grid_coord for task in tile_tasks] == [(0, 0), (0, 1), (1, 0), (1, 1)] + assert len(tile_tasks) == 16 assert torch.equal(output, x + 20) From 4f024a9169cc50a21fc233a02619d74fe28707e0 Mon Sep 17 
00:00:00 2001 From: zzh <943967662@qq.com> Date: Tue, 19 May 2026 21:38:40 +0800 Subject: [PATCH 8/8] test(diffusion): use tile min size 8 and 12x12 inputs in Hunyuan VAE tests With tile_latent_min_size=2 and tile_overlap_factor=0.25, blend_extent truncates to int(0.5)=0 while overlap_size becomes 1, so the overlapping tiles are cropped to row_limit=2 without any blending and the merged output is a misaligned 7x7 instead of the expected 4x4. Increasing tile_latent_min_size and tile_sample_min_size to 8 and the test input to 12x12 gives blend_extent=2, overlap_size=6 and row_limit=6, which keeps the tile pipeline's math self-consistent (a 2x2 grid that merges back to 12x12) while preserving tile_overlap_factor=0.25. Signed-off-by: zzh <943967662@qq.com> --- .../test_autoencoder_kl_hunyuan.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py b/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py index 17e3961805e..58ce14546e1 100644 --- a/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py +++ b/tests/diffusion/distributed/test_autoencoder_kl_hunyuan.py @@ -14,8 +14,8 @@ class _DummyDistributedAutoencoderKLHunyuan(DistributedAutoencoderKLHunyuan): def __init__(self): torch.nn.Module.__init__(self) - self.tile_latent_min_size = 2 - self.tile_sample_min_size = 2 + self.tile_latent_min_size = 8 + self.tile_sample_min_size = 8 self.tile_overlap_factor = 0.25 self.use_spatial_tiling = False @@ -50,30 +50,30 @@ def test_hunyuan_vae_use_tiling_aliases_spatial_tiling(): def test_hunyuan_vae_decode_tiles_round_trip(): # Validate decode tile split/exec/merge returns expected reconstructed tensor. 
vae = _DummyDistributedAutoencoderKLHunyuan() - z = torch.arange(16, dtype=torch.float32).reshape(1, 1, 1, 4, 4) + z = torch.arange(144, dtype=torch.float32).reshape(1, 1, 1, 12, 12) tile_tasks, grid_spec = vae.tile_split(z) decoded_tiles = {task.grid_coord: vae.tile_exec(task) for task in tile_tasks} output = vae.tile_merge(decoded_tiles, grid_spec) assert grid_spec.split_dims == (3, 4) - assert grid_spec.grid_shape == (4, 4) - assert grid_spec.tile_spec == {"blend_extent": 0, "row_limit": 2} - assert len(tile_tasks) == 16 + assert grid_spec.grid_shape == (2, 2) + assert grid_spec.tile_spec == {"blend_extent": 2, "row_limit": 6} + assert len(tile_tasks) == 4 assert torch.equal(output, z + 10) def test_hunyuan_vae_encode_tiles_round_trip(): # Validate encode tile split/exec/merge returns expected latent tensor. vae = _DummyDistributedAutoencoderKLHunyuan() - x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 1, 4, 4) + x = torch.arange(144, dtype=torch.float32).reshape(1, 1, 1, 12, 12) tile_tasks, grid_spec = vae.encode_tile_split(x) encoded_tiles = {task.grid_coord: vae.encode_tile_exec(task) for task in tile_tasks} output = vae.encode_tile_merge(encoded_tiles, grid_spec) assert grid_spec.split_dims == (3, 4) - assert grid_spec.grid_shape == (4, 4) - assert grid_spec.tile_spec == {"blend_extent": 0, "row_limit": 2} - assert len(tile_tasks) == 16 + assert grid_spec.grid_shape == (2, 2) + assert grid_spec.tile_spec == {"blend_extent": 2, "row_limit": 6} + assert len(tile_tasks) == 4 assert torch.equal(output, x + 20)