diff --git a/docs/design/feature/vae_parallel.md b/docs/design/feature/vae_parallel.md index 9009ece72a5..e330b41a68f 100644 --- a/docs/design/feature/vae_parallel.md +++ b/docs/design/feature/vae_parallel.md @@ -1,14 +1,15 @@ # VAE Patch Parallelism This document describes how to add **VAE Patch Parallelism** support to a diffusion model. -We use **Qwen-Image** as the reference implementation. +We use **Qwen-Image** as the reference implementation for decode parallel, and **Wan2.2** for encode parallel. --- ## Table of Contents - [Overview](#overview) -- [Step-by-Step Implementation](#step-by-step-implementation) +- [Step-by-Step Implementation (Decode)](#step-by-step-implementation-decode) +- [Encode Parallel Implementation](#encode-parallel-implementation) - [Testing](#testing) - [Reference Implementations](#reference-implementations) - [Summary](#summary) @@ -19,13 +20,13 @@ We use **Qwen-Image** as the reference implementation. ### What is Vae Patch parallel? -**VAE Patch Parallelism** is a decoding acceleration technique. Instead of decoding the entire latent tensor at once, the latent tensor is: +**VAE Patch Parallelism** is an acceleration technique for both **encoding** and **decoding**. Instead of processing the entire tensor at once, the tensor is: + Split into multiple spatial tiles + Distributed across multiple ranks -+ Decoded in parallel ++ Encoded/Decoded in parallel + Merged to reconstruct the final output @@ -35,10 +36,17 @@ This approach: + Reduces peak memory usage per device -+ Accelerates decoding latency ++ Accelerates encoding/decoding latency + +### When to Use Encode vs Decode Parallel + +| Operation | Use Case | Example | +|-----------|----------|---------| +| **Decode Parallel** | Text-to-Image, Text-to-Video | Latent → Image/Video | +| **Encode Parallel** | Image-to-Video (I2V) | Image → Latent (for conditioning) | ### Architecture -We introduce **DistributedVaeExecutor** as the core component responsible for distributed VAE decoding. +We introduce **DistributedVaeExecutor** as the core component responsible for distributed VAE encoding/decoding. The executor is model-agnostic and accepts three function parameters: @@ -84,7 +92,7 @@ Therefore: + Merge must perform blending to avoid seams -## Step-by-Step Implementation +## Step-by-Step Implementation (Decode) ### Step 1: Implement DistributedAutoencoderKLQwenImage `QwenImagePipeline` use `AutoencoderKLQwenImage` for vae, so implement a distributed version: @@ -205,14 +213,14 @@ def tile_merge(self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid We need to override tiled_decode, the main logic is: + check distributed is enabled + select split/exec/merge -+ Invoke self.distributed_decoder.execute to decode ++ Invoke self.distributed_executor.execute to decode ``` def tiled_decode(self, z: torch.Tensor, return_dict: bool = True): if not self.is_distributed_enabled(): return super().tiled_decode(z, return_dict=return_dict) logger.info("Decode run with distributed executor") - result = self.distributed_decoder.execute( + result = self.distributed_executor.execute( z, DistributedOperator(split=self.tile_split, exec=self.tile_exec, merge=self.tile_merge), broadcast_result=True, @@ -243,6 +251,166 @@ class YourModelPipeline(nn.Module): + ).to(self.device) ``` +## Encode Parallel Implementation + +For models that require VAE encoding (e.g., Image-to-Video), you can also parallelize the encode operation. We use **Wan2.2** as the reference implementation. + +### Step 1: Implement encode_tile_split + +Similar to decode, split the input tensor into tiles. Key considerations: + ++ **Patchify handling**: If the model uses `patch_size`, scale tile parameters accordingly ++ **Temporal chunking**: Video VAEs may have temporal compression (e.g., 4x) + +```python +def encode_tile_split(self, x: torch.Tensor) -> tuple[list[TileTask], GridSpec]: + _, _, num_frames, height, width = x.shape + encode_spatial_compression_ratio = self.spatial_compression_ratio + + # Scale tile parameters for patchified coordinate system + tile_sample_min_height = self.tile_sample_min_height + tile_sample_min_width = self.tile_sample_min_width + tile_sample_stride_height = self.tile_sample_stride_height + tile_sample_stride_width = self.tile_sample_stride_width + + if self.config.patch_size is not None: + # When input is patchified, scale tile parameters accordingly + encode_spatial_compression_ratio = self.spatial_compression_ratio // self.config.patch_size + tile_sample_min_height = tile_sample_min_height // self.config.patch_size + tile_sample_min_width = tile_sample_min_width // self.config.patch_size + tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size + tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size + + latent_height = height // encode_spatial_compression_ratio + latent_width = width // encode_spatial_compression_ratio + + tile_latent_min_height = tile_sample_min_height // encode_spatial_compression_ratio + tile_latent_min_width = tile_sample_min_width // encode_spatial_compression_ratio + tile_latent_stride_height = tile_sample_stride_height // encode_spatial_compression_ratio + tile_latent_stride_width = tile_sample_stride_width // encode_spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + tiletask_list = [] + # Use temporal compression ratio from config instead of hardcoding + temporal_compression = self.config.scale_factor_temporal + + for i in range(0, height, tile_sample_stride_height): + for j in range(0, width, tile_sample_stride_width): + time_list = [] + frame_range = 1 + (num_frames - 1) // temporal_compression + for k in range(frame_range): + if k == 0: + tile = x[:, :, :1, i : i + tile_sample_min_height, j : j + tile_sample_min_width] + else: + tile = x[ + :, :, + 1 + temporal_compression * (k - 1) : 1 + temporal_compression * k, + i : i + tile_sample_min_height, + j : j + tile_sample_min_width, + ] + time_list.append(tile) + tiletask_list.append( + TileTask(len(tiletask_list), (i // tile_sample_stride_height, j // tile_sample_stride_width), + time_list, workload=time_list[0].shape[3] * time_list[0].shape[4]) + ) + + grid_spec = GridSpec( + split_dims=(3, 4), + grid_shape=(tiletask_list[-1].grid_coord[0] + 1, tiletask_list[-1].grid_coord[1] + 1), + tile_spec={ + "latent_height": latent_height, "latent_width": latent_width, + "blend_height": blend_height, "blend_width": blend_width, + "tile_latent_stride_height": tile_latent_stride_height, + "tile_latent_stride_width": tile_latent_stride_width, + }, + output_dtype=self.dtype, + ) + return tiletask_list, grid_spec +``` + +### Step 2: Implement encode_tile_exec + +```python +def encode_tile_exec(self, task: TileTask) -> torch.Tensor: + """Encode a single sample tile into latent space.""" + self.clear_cache() + time = [] + for k, tile in enumerate(task.tensor): + self._enc_conv_idx = [0] + encoded = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + encoded = self.quant_conv(encoded) + time.append(encoded) + result = torch.cat(time, dim=2) + self.clear_cache() + return result +``` + +### Step 3: Implement encode_tile_merge + +```python +def encode_tile_merge( + self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec +) -> torch.Tensor: + """Merge encoded tiles into a full latent tensor.""" + grid_h, grid_w = grid_spec.grid_shape + result_rows = [] + for i in range(grid_h): + result_row = [] + for j in range(grid_w): + tile = coord_tensor_map[(i, j)] + if i > 0: + tile = self.blend_v(coord_tensor_map[(i - 1, j)], tile, grid_spec.tile_spec["blend_height"]) + if j > 0: + tile = self.blend_h(coord_tensor_map[(i, j - 1)], tile, grid_spec.tile_spec["blend_width"]) + result_row.append(tile[:, :, :, + : grid_spec.tile_spec["tile_latent_stride_height"], + : grid_spec.tile_spec["tile_latent_stride_width"]]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[ + :, :, :, : grid_spec.tile_spec["latent_height"], : grid_spec.tile_spec["latent_width"] + ] + return enc +``` + +### Step 4: Override tiled_encode method + +Override `tiled_encode` instead of `encode`. The parent's `_encode()` handles patchify before calling `tiled_encode()`, so input `x` is already patchified. + +```python +def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + """ + Encode using distributed VAE executor. + + Note: x is already patchified by parent's _encode() before calling this method. + """ + if not self.is_distributed_enabled(): + return super().tiled_encode(x) + + self.clear_cache() + result = self.distributed_executor.execute( + x, + DistributedOperator( + split=self.encode_tile_split, + exec=self.encode_tile_exec, + merge=self.encode_tile_merge, + ), + broadcast_result=True, # Latents needed by all ranks for diffusion + ) + self.clear_cache() + return result +``` + +**Key differences from decode parallel:** + +| Aspect | Decode Parallel | Encode Parallel | +|--------|-----------------|-----------------| +| `broadcast_result` | Often `False` (only rank 0 needs output) | `True` (all ranks need latents for diffusion) | +| Patchify | Applied in merge (unpatchify) | Handled by parent `_encode()` before `tiled_encode()` | +| Temporal chunking | Frame-by-frame | Chunk-based (e.g., 1 + 4n frames) | + ## Testing Verify numerical consistency between: + vae_patch_parallel_size = 1 @@ -272,18 +440,20 @@ When vae_patch_parallel_size is larger than the DiT world size, it will automati Complete examples in the codebase: -| Model | Path | Notes | -|-------|------|-------| -| **Z-Image** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl.py` | Distributed AutoencoderKL | -| **Wan2.2** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_wan.py` | Distributed AutoencoderKLWan | -| **Qwen-Image** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_qwenimage.py` | Distributed AutoencoderKLQwenImage | +| Model | Path | Decode Parallel | Encode Parallel | +|-------|------|-----------------|-----------------| +| **Z-Image** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl.py` | ✅ | ❌ | +| **Wan2.2** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_wan.py` | ✅ | ✅ | +| **Qwen-Image** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_qwenimage.py` | ✅ | ❌ | --- ## Summary -Adding Vae Patch Parallel support to diffusion model: +Adding VAE Patch Parallel support to diffusion model: -1. **Implement Distributed Vae** - mainly copy from `diffusers` tiled_decode, and refactor into split/exec/merge -2. **Change vae model in pipeline to Distributed Vae** -3. **Test** - Verify with `tensor_parallel_size=N` quality +1. **Implement Distributed VAE** - Inherit from base VAE class and `DistributedVaeMixin` +2. **Decode Parallel** - Refactor `tiled_decode` into `tile_split`/`tile_exec`/`tile_merge` +3. **Encode Parallel** (optional) - Implement `encode_tile_split`/`encode_tile_exec`/`encode_tile_merge` for I2V models +4. **Change VAE model in pipeline** - Use the distributed version +5. **Test** - Verify numerical consistency with `vae_patch_parallel_size=1` vs `N` diff --git a/docs/user_guide/diffusion_features.md b/docs/user_guide/diffusion_features.md index e7f33306ec6..c151164ca0e 100644 --- a/docs/user_guide/diffusion_features.md +++ b/docs/user_guide/diffusion_features.md @@ -114,13 +114,13 @@ The following tables show which models support each feature: | **Nextstep_1(T2I)** | ❓ | ❓ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | | **OmniGen2** | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | **Ovis-Image** | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| **Qwen-Image** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-2512** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-Edit** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | -| **Qwen-Image-Edit-2509** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | -| **Qwen-Image-Layered** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | -| **Stable-Diffusion3.5** | ❌ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | -| **Z-Image** | ✅ | ✅ | ✅ | ❓ | ✅ (TP=2 only) | ✅ | ❌ | ✅ | ✅ | ❌ | +| **Qwen-Image** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ✅ | +| **Qwen-Image-2512** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ✅ | +| **Qwen-Image-Edit** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ❌ | ❌ | +| **Qwen-Image-Edit-2509** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ❌ | ❌ | +| **Qwen-Image-Layered** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ❌ | ❌ | +| **Stable-Diffusion3.5** | ❌ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ (decode) | ❌ | ❌ | +| **Z-Image** | ✅ | ✅ | ✅ | ❓ | ✅ (TP=2 only) | ✅ | ❌ | ✅ (decode) | ✅ | ❌ | > Notes: > 1. Nextstep_1(T2I) does not support cache acceleration methods such as TeaCache or Cache-DiT. @@ -130,11 +130,11 @@ The following tables show which models support each feature: | Model | ⚡TeaCache | ⚡Cache-DiT | 🔀SP (Ulysses & Ring) | 🔀CFG-Parallel | 🔀Tensor-Parallel | 🔀HSDP | 💾CPU Offload (Layerwise) | 💾VAE-Patch-Parallel | 💾Quantization | 🔄Step Execution | |-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:| -| **Wan2.2** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | -| **Wan2.1-VACE** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | +| **Wan2.2** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (encode/decode) | ❌ | ❌ | +| **Wan2.1-VACE** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ❌ | ❌ | | **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | | **Helios** | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | -| **HunyuanVideo-1.5 T2V I2V** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | +| **HunyuanVideo-1.5 T2V I2V** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ❌ | | **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ### AudioGen diff --git a/tests/diffusion/distributed/test_autoencoder_kl_wan_encode.py b/tests/diffusion/distributed/test_autoencoder_kl_wan_encode.py new file mode 100644 index 00000000000..7a18fa66da3 --- /dev/null +++ b/tests/diffusion/distributed/test_autoencoder_kl_wan_encode.py @@ -0,0 +1,273 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Unit tests for DistributedAutoencoderKLWan encode parallel (CPU-only).""" + +import pytest +import torch + +pytestmark = [pytest.mark.cpu, pytest.mark.core_model] + + +class _DummyConfig: + def __init__(self, patch_size=None, scale_factor_temporal=4): + self.patch_size = patch_size + self.scale_factor_temporal = scale_factor_temporal + + +class _DummyWanVae: + """Minimal mock of DistributedAutoencoderKLWan for testing encode_tile_split.""" + + def __init__( + self, + config=None, + spatial_compression_ratio=8, + tile_sample_min_height=256, + tile_sample_min_width=256, + tile_sample_stride_height=192, + tile_sample_stride_width=192, + ): + self.config = config or _DummyConfig() + self.spatial_compression_ratio = spatial_compression_ratio + self.tile_sample_min_height = tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width + self.dtype = torch.float32 + + # Mock caches + self._enc_feat_map = None + self._enc_conv_idx = [0] + + def clear_cache(self): + self._enc_feat_map = None + self._enc_conv_idx = [0] + + def encoder(self, x, feat_cache=None, feat_idx=None): # noqa: ARG002 + # Simple mock: just return the input + return x + + def quant_conv(self, x): + return x + + def blend_v(self, _a, b, _blend_extent): + return b + + def blend_h(self, _a, b, _blend_extent): + return b + + +def _import_encode_tile_split(): + """Import the encode_tile_split method from the module.""" + from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_wan import ( + DistributedAutoencoderKLWan, + ) + + return DistributedAutoencoderKLWan.encode_tile_split + + +def _import_encode_tile_exec(): + from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_wan import ( + DistributedAutoencoderKLWan, + ) + + return DistributedAutoencoderKLWan.encode_tile_exec + + +def _import_encode_tile_merge(): + from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_wan import ( + DistributedAutoencoderKLWan, + ) + + return DistributedAutoencoderKLWan.encode_tile_merge + + +class TestEncodeTileSplit: + """Tests for encode_tile_split method.""" + + def test_basic_split_without_patch_size(self): + """Test basic tile splitting without patch_size.""" + encode_tile_split = _import_encode_tile_split() + + vae = _DummyWanVae( + config=_DummyConfig(patch_size=None, scale_factor_temporal=4), + spatial_compression_ratio=8, + tile_sample_min_height=256, + tile_sample_min_width=256, + tile_sample_stride_height=192, + tile_sample_stride_width=192, + ) + + # Input: (B, C, T, H, W) = (1, 3, 5, 256, 256) + x = torch.randn(1, 3, 5, 256, 256) + + tiletask_list, grid_spec = encode_tile_split(vae, x) + + # With stride 192 and input size 256, we should get: + # Height: ceil(256/192) = 2 positions (0, 192) but 192+256 > 256, so only 1 + # Actually for i in range(0, 256, 192): i = 0, 192 but 192 is out of bounds + # So we get 1x1 grid + assert len(tiletask_list) >= 1 + assert grid_spec.grid_shape[0] >= 1 + assert grid_spec.grid_shape[1] >= 1 + + # Check temporal chunking: 5 frames -> 1 + (5-1)//4 = 2 chunks + first_task = tiletask_list[0] + assert len(first_task.tensor) == 2 # 2 temporal chunks + + def test_split_with_patch_size_scales_coordinates(self): + """Test that patch_size properly scales tile coordinates.""" + encode_tile_split = _import_encode_tile_split() + + # Without patch_size + vae_no_patch = _DummyWanVae( + config=_DummyConfig(patch_size=None, scale_factor_temporal=4), + spatial_compression_ratio=8, + tile_sample_min_height=256, + tile_sample_min_width=256, + tile_sample_stride_height=128, + tile_sample_stride_width=128, + ) + + # With patch_size=2 (simulating patchified input) + vae_with_patch = _DummyWanVae( + config=_DummyConfig(patch_size=2, scale_factor_temporal=4), + spatial_compression_ratio=8, + tile_sample_min_height=256, + tile_sample_min_width=256, + tile_sample_stride_height=128, + tile_sample_stride_width=128, + ) + + # Same patchified input size + x = torch.randn(1, 3, 5, 256, 256) + + tasks_no_patch, _ = encode_tile_split(vae_no_patch, x) + tasks_with_patch, _ = encode_tile_split(vae_with_patch, x) + + # With patch_size=2, stride becomes 128//2=64, so more tiles + assert len(tasks_with_patch) >= len(tasks_no_patch) + + def test_temporal_compression_from_config(self): + """Test that temporal compression ratio is read from config.""" + encode_tile_split = _import_encode_tile_split() + + # temporal_compression=4 (default) + vae_4x = _DummyWanVae( + config=_DummyConfig(scale_factor_temporal=4), + tile_sample_min_height=512, + tile_sample_min_width=512, + tile_sample_stride_height=512, + tile_sample_stride_width=512, + ) + + # temporal_compression=2 + vae_2x = _DummyWanVae( + config=_DummyConfig(scale_factor_temporal=2), + tile_sample_min_height=512, + tile_sample_min_width=512, + tile_sample_stride_height=512, + tile_sample_stride_width=512, + ) + + # 9 frames input + x = torch.randn(1, 3, 9, 512, 512) + + tasks_4x, _ = encode_tile_split(vae_4x, x) + tasks_2x, _ = encode_tile_split(vae_2x, x) + + # With 4x compression: 1 + (9-1)//4 = 3 chunks + assert len(tasks_4x[0].tensor) == 3 + + # With 2x compression: 1 + (9-1)//2 = 5 chunks + assert len(tasks_2x[0].tensor) == 5 + + def test_grid_spec_latent_dimensions(self): + """Test that grid_spec contains correct latent dimensions.""" + encode_tile_split = _import_encode_tile_split() + + vae = _DummyWanVae( + config=_DummyConfig(patch_size=None), + spatial_compression_ratio=8, + tile_sample_min_height=512, + tile_sample_min_width=512, + tile_sample_stride_height=512, + tile_sample_stride_width=512, + ) + + # Input: 512x512 with compression 8 -> 64x64 latent + x = torch.randn(1, 3, 5, 512, 512) + + _, grid_spec = encode_tile_split(vae, x) + + assert grid_spec.tile_spec["latent_height"] == 64 + assert grid_spec.tile_spec["latent_width"] == 64 + + +class TestEncodeTileExec: + """Tests for encode_tile_exec method.""" + + def test_basic_exec(self): + """Test basic tile execution.""" + encode_tile_exec = _import_encode_tile_exec() + + vae = _DummyWanVae() + + from vllm_omni.diffusion.distributed.autoencoders.distributed_vae_executor import ( + TileTask, + ) + + # Create a simple task with 2 temporal chunks + tile1 = torch.randn(1, 3, 1, 32, 32) + tile2 = torch.randn(1, 3, 4, 32, 32) + task = TileTask(tile_id=0, grid_coord=(0, 0), tensor=[tile1, tile2]) + + result = encode_tile_exec(vae, task) + + # Result should concatenate temporal dimension + assert result.shape[2] == 5 # 1 + 4 frames + + +class TestEncodeTileMerge: + """Tests for encode_tile_merge method.""" + + def test_basic_merge(self): + """Test basic tile merging.""" + encode_tile_merge = _import_encode_tile_merge() + + vae = _DummyWanVae() + + from vllm_omni.diffusion.distributed.autoencoders.distributed_vae_executor import ( + GridSpec, + ) + + # Create 2x2 grid of tiles + tile_00 = torch.ones(1, 16, 2, 32, 32) * 0 + tile_01 = torch.ones(1, 16, 2, 32, 32) * 1 + tile_10 = torch.ones(1, 16, 2, 32, 32) * 2 + tile_11 = torch.ones(1, 16, 2, 32, 32) * 3 + + coord_tensor_map = { + (0, 0): tile_00, + (0, 1): tile_01, + (1, 0): tile_10, + (1, 1): tile_11, + } + + grid_spec = GridSpec( + split_dims=(3, 4), + grid_shape=(2, 2), + tile_spec={ + "latent_height": 48, + "latent_width": 48, + "blend_height": 8, + "blend_width": 8, + "tile_latent_stride_height": 24, + "tile_latent_stride_width": 24, + }, + ) + + result = encode_tile_merge(vae, coord_tensor_map, grid_spec) + + # Output should be (1, 16, 2, 48, 48) + assert result.shape == (1, 16, 2, 48, 48) diff --git a/tests/diffusion/distributed/test_distributed_vae_executor.py b/tests/diffusion/distributed/test_distributed_vae_executor.py index 42e9f3300bc..93cf3d195f5 100644 --- a/tests/diffusion/distributed/test_distributed_vae_executor.py +++ b/tests/diffusion/distributed/test_distributed_vae_executor.py @@ -59,9 +59,9 @@ def merge(self, coord_tensor_map, grid_spec): class DummyMixin(DistributedVaeMixin): def __init__(self): self.use_tiling = True - self.distributed_decoder = MagicMock() - self.distributed_decoder.parallel_size = 2 - self.distributed_decoder.group = None + self.distributed_executor = MagicMock() + self.distributed_executor.parallel_size = 2 + self.distributed_executor.group = None @pytest.fixture(autouse=True) diff --git a/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl.py b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl.py index 7df2d6a8add..0084719a8ab 100644 --- a/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl.py +++ b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl.py @@ -93,7 +93,7 @@ def patch_split(self, z: torch.Tensor) -> tuple[list[TileTask], GridSpec]: _, _, latent_h, latent_w = z.shape scale = int(2 ** (len(self.config.block_out_channels) - 1)) - max_parallel_size = self.distributed_decoder.parallel_size + max_parallel_size = self.distributed_executor.parallel_size root = int(math.sqrt(max_parallel_size)) for rows in range(root, 0, -1): @@ -187,7 +187,7 @@ def decode(self, z: torch.Tensor, return_dict: bool = True, *args: Any, **kwargs if split is not None: strategy = "tile" if split == self.tile_split else "patch" logger.info(f"Decode run with distributed executor, split strategy is {strategy}") - result = self.distributed_decoder.execute( + result = self.distributed_executor.execute( z, DistributedOperator(split=split, exec=exec, merge=merge), broadcast_result=False ) if not return_dict: diff --git a/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_qwenimage.py b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_qwenimage.py index 7549bbd3d5a..f9dea8a36d9 100644 --- a/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_qwenimage.py +++ b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_qwenimage.py @@ -108,8 +108,8 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True): if not self.is_distributed_enabled(): return super().tiled_decode(z, return_dict=return_dict) - logger.info("Decode run with distributed executor") - result = self.distributed_decoder.execute( + logger.debug("Decode running with distributed executor") + result = self.distributed_executor.execute( z, DistributedOperator(split=self.tile_split, exec=self.tile_exec, merge=self.tile_merge), broadcast_result=True, diff --git a/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_wan.py b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_wan.py index 7defbae79b7..027991c3f26 100644 --- a/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_wan.py +++ b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_wan.py @@ -92,6 +92,119 @@ def tile_exec(self, task: TileTask) -> torch.Tensor: result = torch.cat(time, dim=2) return result + def encode_tile_split(self, x: torch.Tensor) -> tuple[list[TileTask], GridSpec]: + _, _, num_frames, height, width = x.shape + encode_spatial_compression_ratio = self.spatial_compression_ratio + # Scale tile parameters for patchified coordinate system + tile_sample_min_height = self.tile_sample_min_height + tile_sample_min_width = self.tile_sample_min_width + tile_sample_stride_height = self.tile_sample_stride_height + tile_sample_stride_width = self.tile_sample_stride_width + if self.config.patch_size is not None: + assert encode_spatial_compression_ratio % self.config.patch_size == 0 + encode_spatial_compression_ratio = self.spatial_compression_ratio // self.config.patch_size + # When input is patchified, scale tile parameters accordingly + tile_sample_min_height = tile_sample_min_height // self.config.patch_size + tile_sample_min_width = tile_sample_min_width // self.config.patch_size + tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size + tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size + + latent_height = height // encode_spatial_compression_ratio + latent_width = width // encode_spatial_compression_ratio + + tile_latent_min_height = tile_sample_min_height // encode_spatial_compression_ratio + tile_latent_min_width = tile_sample_min_width // encode_spatial_compression_ratio + tile_latent_stride_height = tile_sample_stride_height // encode_spatial_compression_ratio + tile_latent_stride_width = tile_sample_stride_width // encode_spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + tiletask_list = [] + temporal_compression = self.config.scale_factor_temporal + for i in range(0, height, tile_sample_stride_height): + for j in range(0, width, tile_sample_stride_width): + time_list = [] + frame_range = 1 + (num_frames - 1) // temporal_compression + for k in range(frame_range): + if k == 0: + tile = x[:, :, :1, i : i + tile_sample_min_height, j : j + tile_sample_min_width] + else: + tile = x[ + :, + :, + 1 + temporal_compression * (k - 1) : 1 + temporal_compression * k, + i : i + tile_sample_min_height, + j : j + tile_sample_min_width, + ] + time_list.append(tile) + tiletask_list.append( + TileTask( + len(tiletask_list), + (i // tile_sample_stride_height, j // tile_sample_stride_width), + time_list, + workload=time_list[0].shape[3] * time_list[0].shape[4], + ) + ) + + grid_spec = GridSpec( + split_dims=(3, 4), + grid_shape=(tiletask_list[-1].grid_coord[0] + 1, tiletask_list[-1].grid_coord[1] + 1), + tile_spec={ + "latent_height": latent_height, + "latent_width": latent_width, + "blend_height": blend_height, + "blend_width": blend_width, + "tile_latent_stride_height": tile_latent_stride_height, + "tile_latent_stride_width": tile_latent_stride_width, + }, + output_dtype=self.dtype, + ) + return tiletask_list, grid_spec + + def encode_tile_exec(self, task: TileTask) -> torch.Tensor: + """Encode a single sample tile into latent space.""" + self.clear_cache() + time = [] + for k, tile in enumerate(task.tensor): + self._enc_conv_idx = [0] + encoded = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + encoded = self.quant_conv(encoded) + time.append(encoded) + result = torch.cat(time, dim=2) + self.clear_cache() + return result + + def encode_tile_merge( + self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec + ) -> torch.Tensor: + """Merge encoded tiles into a full latent tensor.""" + grid_h, grid_w = grid_spec.grid_shape + result_rows = [] + for i in range(grid_h): + result_row = [] + for j in range(grid_w): + tile = coord_tensor_map[(i, j)] + if i > 0: + tile = self.blend_v(coord_tensor_map[(i - 1, j)], tile, grid_spec.tile_spec["blend_height"]) + if j > 0: + tile = self.blend_h(coord_tensor_map[(i, j - 1)], tile, grid_spec.tile_spec["blend_width"]) + result_row.append( + tile[ + :, + :, + :, + : grid_spec.tile_spec["tile_latent_stride_height"], + : grid_spec.tile_spec["tile_latent_stride_width"], + ] + ) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[ + :, :, :, : grid_spec.tile_spec["latent_height"], : grid_spec.tile_spec["latent_width"] + ] + return enc + def tile_merge(self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec) -> torch.Tensor: """Merge decoded tiles into a full image.""" grid_h, grid_w = grid_spec.grid_shape @@ -130,8 +243,8 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True): if not self.is_distributed_enabled(): return super().tiled_decode(z, return_dict=return_dict) - logger.info("Decode run with distributed executor") - result = self.distributed_decoder.execute( + logger.debug("Decode running with distributed executor") + result = self.distributed_executor.execute( z, DistributedOperator(split=self.tile_split, exec=self.tile_exec, merge=self.tile_merge), broadcast_result=False, @@ -140,3 +253,26 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True): return (result,) return DecoderOutput(sample=result) + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + """ + Encode using distributed VAE executor. + + Note: x is already patchified by parent's _encode() before calling this method. + """ + if not self.is_distributed_enabled(): + return super().tiled_encode(x) + + logger.debug("Encode running with distributed executor") + self.clear_cache() + result = self.distributed_executor.execute( + x, + DistributedOperator( + split=self.encode_tile_split, + exec=self.encode_tile_exec, + merge=self.encode_tile_merge, + ), + broadcast_result=True, + ) + self.clear_cache() + return result diff --git a/vllm_omni/diffusion/distributed/autoencoders/distributed_vae_executor.py b/vllm_omni/diffusion/distributed/autoencoders/distributed_vae_executor.py index bdf664741db..ad60d164aae 100644 --- a/vllm_omni/diffusion/distributed/autoencoders/distributed_vae_executor.py +++ b/vllm_omni/diffusion/distributed/autoencoders/distributed_vae_executor.py @@ -168,25 +168,25 @@ def _sync_final_result(self, rank0_result, output_ndim, output_device, output_dt class DistributedVaeMixin: def init_distributed(self): - self.distributed_decoder = DistributedVaeExecutor() + self.distributed_executor = DistributedVaeExecutor() - def set_parallel_size(self, parallel_size: int) -> bool: - return self.distributed_decoder.set_parallel_size(parallel_size) + def set_parallel_size(self, parallel_size: int) -> None: + self.distributed_executor.set_parallel_size(parallel_size) def is_distributed_enabled(self) -> bool: if ( - self.distributed_decoder.parallel_size <= 1 + self.distributed_executor.parallel_size <= 1 or not dist.is_initialized() or not getattr(self, "use_tiling", False) ): return False - world_size = dist.get_world_size(group=self.distributed_decoder.group) - pp_size = min(int(self.distributed_decoder.parallel_size), int(world_size)) + world_size = dist.get_world_size(group=self.distributed_executor.group) + pp_size = min(int(self.distributed_executor.parallel_size), int(world_size)) if pp_size <= 1: return False - if self.distributed_decoder.parallel_size > pp_size: + if self.distributed_executor.parallel_size > pp_size: logger.warning( - f"vae_patch_parallel_size={self.distributed_decoder.parallel_size} " + f"vae_patch_parallel_size={self.distributed_executor.parallel_size} " f"is greater than dit_group={world_size};" f" using dit_group size={world_size}" )