206 changes: 188 additions & 18 deletions docs/design/feature/vae_parallel.md
@@ -1,14 +1,15 @@
# VAE Patch Parallelism

This document describes how to add **VAE Patch Parallelism** support to a diffusion model.
We use **Qwen-Image** as the reference implementation.
We use **Qwen-Image** as the reference implementation for decode parallel, and **Wan2.2** for encode parallel.

---

## Table of Contents

- [Overview](#overview)
- [Step-by-Step Implementation](#step-by-step-implementation)
- [Step-by-Step Implementation (Decode)](#step-by-step-implementation-decode)
- [Encode Parallel Implementation](#encode-parallel-implementation)
- [Testing](#testing)
- [Reference Implementations](#reference-implementations)
- [Summary](#summary)
@@ -19,13 +20,13 @@ We use **Qwen-Image** as the reference implementation.

### What is VAE Patch Parallelism?

**VAE Patch Parallelism** is a decoding acceleration technique. Instead of decoding the entire latent tensor at once, the latent tensor is:
**VAE Patch Parallelism** is an acceleration technique for both **encoding** and **decoding**. Instead of processing the entire tensor at once, the tensor is:

+ Split into multiple spatial tiles

+ Distributed across multiple ranks

+ Decoded in parallel
+ Encoded/Decoded in parallel

+ Merged to reconstruct the final output

@@ -35,10 +36,17 @@ This approach:

+ Reduces peak memory usage per device

+ Accelerates decoding latency
+ Reduces encoding/decoding latency

### When to Use Encode vs Decode Parallel

| Operation | Use Case | Example |
|-----------|----------|---------|
| **Decode Parallel** | Text-to-Image, Text-to-Video | Latent → Image/Video |
| **Encode Parallel** | Image-to-Video (I2V) | Image → Latent (for conditioning) |

### Architecture
We introduce **DistributedVaeExecutor** as the core component responsible for distributed VAE decoding.
We introduce **DistributedVaeExecutor** as the core component responsible for distributed VAE encoding/decoding.

The executor is model-agnostic and accepts three function parameters:

@@ -84,7 +92,7 @@ Therefore:

+ Merge must perform blending to avoid seams
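
Blending can be illustrated in one dimension. The sketch below assumes a linear cross-fade over the overlap region, similar in spirit to the `blend_v`/`blend_h` helpers of tiled VAEs in `diffusers`; the function name `blend_1d` and the list-based data are hypothetical.

```python
def blend_1d(prev_tile, tile, blend_extent):
    """Cross-fade the first `blend_extent` values of `tile` with the tail of `prev_tile`.

    The weight ramps linearly from 0 (all previous tile) towards 1 (all current tile).
    """
    blend_extent = min(blend_extent, len(prev_tile), len(tile))
    out = list(tile)
    for x in range(blend_extent):
        w = x / blend_extent
        out[x] = prev_tile[-blend_extent + x] * (1 - w) + tile[x] * w
    return out


# Two neighbouring tiles that disagree in their overlap region.
left = [1.0] * 6
right = [0.0] * 6
blended = blend_1d(left, right, blend_extent=2)
```

After blending, the merge step keeps only the first `stride` values of each tile, so the cross-faded region replaces what would otherwise be a hard seam.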

## Step-by-Step Implementation
## Step-by-Step Implementation (Decode)

### Step 1: Implement DistributedAutoencoderKLQwenImage
`QwenImagePipeline` uses `AutoencoderKLQwenImage` as its VAE, so implement a distributed version:
@@ -205,14 +213,14 @@ def tile_merge(self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid
Override `tiled_decode`. The main logic is:
+ Check whether distributed execution is enabled
+ Select the split/exec/merge functions
+ Invoke self.distributed_decoder.execute to decode
+ Invoke self.distributed_executor.execute to decode
```
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True):
    if not self.is_distributed_enabled():
        return super().tiled_decode(z, return_dict=return_dict)

    logger.info("Decode run with distributed executor")
    result = self.distributed_decoder.execute(
    result = self.distributed_executor.execute(
        z,
        DistributedOperator(split=self.tile_split, exec=self.tile_exec, merge=self.tile_merge),
        broadcast_result=True,
@@ -243,6 +251,166 @@ class YourModelPipeline(nn.Module):
+ ).to(self.device)
```

## Encode Parallel Implementation

For models that require VAE encoding (e.g., Image-to-Video), you can also parallelize the encode operation. We use **Wan2.2** as the reference implementation.

### Step 1: Implement encode_tile_split

Similar to decode, split the input tensor into tiles. Key considerations:

+ **Patchify handling**: If the model uses `patch_size`, scale tile parameters accordingly
+ **Temporal chunking**: Video VAEs may have temporal compression (e.g., 4x)

```python
def encode_tile_split(self, x: torch.Tensor) -> tuple[list[TileTask], GridSpec]:
    _, _, num_frames, height, width = x.shape
    encode_spatial_compression_ratio = self.spatial_compression_ratio

    # Scale tile parameters for patchified coordinate system
    tile_sample_min_height = self.tile_sample_min_height
    tile_sample_min_width = self.tile_sample_min_width
    tile_sample_stride_height = self.tile_sample_stride_height
    tile_sample_stride_width = self.tile_sample_stride_width

    if self.config.patch_size is not None:
        # When input is patchified, scale tile parameters accordingly
        encode_spatial_compression_ratio = self.spatial_compression_ratio // self.config.patch_size
        tile_sample_min_height = tile_sample_min_height // self.config.patch_size
        tile_sample_min_width = tile_sample_min_width // self.config.patch_size
        tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
        tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size

    latent_height = height // encode_spatial_compression_ratio
    latent_width = width // encode_spatial_compression_ratio

    tile_latent_min_height = tile_sample_min_height // encode_spatial_compression_ratio
    tile_latent_min_width = tile_sample_min_width // encode_spatial_compression_ratio
    tile_latent_stride_height = tile_sample_stride_height // encode_spatial_compression_ratio
    tile_latent_stride_width = tile_sample_stride_width // encode_spatial_compression_ratio

    blend_height = tile_latent_min_height - tile_latent_stride_height
    blend_width = tile_latent_min_width - tile_latent_stride_width

    tiletask_list = []
    # Use temporal compression ratio from config instead of hardcoding
    temporal_compression = self.config.scale_factor_temporal

    for i in range(0, height, tile_sample_stride_height):
        for j in range(0, width, tile_sample_stride_width):
            time_list = []
            frame_range = 1 + (num_frames - 1) // temporal_compression
            for k in range(frame_range):
                if k == 0:
                    tile = x[:, :, :1, i : i + tile_sample_min_height, j : j + tile_sample_min_width]
                else:
                    tile = x[
                        :, :,
                        1 + temporal_compression * (k - 1) : 1 + temporal_compression * k,
                        i : i + tile_sample_min_height,
                        j : j + tile_sample_min_width,
                    ]
                time_list.append(tile)
            tiletask_list.append(
                TileTask(len(tiletask_list), (i // tile_sample_stride_height, j // tile_sample_stride_width),
                         time_list, workload=time_list[0].shape[3] * time_list[0].shape[4])
            )

    grid_spec = GridSpec(
        split_dims=(3, 4),
        grid_shape=(tiletask_list[-1].grid_coord[0] + 1, tiletask_list[-1].grid_coord[1] + 1),
        tile_spec={
            "latent_height": latent_height, "latent_width": latent_width,
            "blend_height": blend_height, "blend_width": blend_width,
            "tile_latent_stride_height": tile_latent_stride_height,
            "tile_latent_stride_width": tile_latent_stride_width,
        },
        output_dtype=self.dtype,
    )
    return tiletask_list, grid_spec
```
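
To make the index arithmetic in the split concrete, the following sketch reproduces only the temporal chunk ranges and the spatial tile origins, without any tensors. The helper names and the example sizes (9 frames, 480×832 input, stride 192) are hypothetical values chosen for illustration, not defaults of any model.

```python
def temporal_chunks(num_frames, temporal_compression):
    """Return (start, stop) frame ranges: one leading frame, then fixed-size chunks."""
    chunks = [(0, 1)]
    frame_range = 1 + (num_frames - 1) // temporal_compression
    for k in range(1, frame_range):
        start = 1 + temporal_compression * (k - 1)
        chunks.append((start, start + temporal_compression))
    return chunks


def tile_origins(size, stride):
    """Top (or left) coordinate of every tile along one spatial axis."""
    return list(range(0, size, stride))


chunks = temporal_chunks(num_frames=9, temporal_compression=4)
rows = tile_origins(480, 192)
cols = tile_origins(832, 192)
grid_shape = (len(rows), len(cols))
```

With these example values the frames are grouped as `[0:1]`, `[1:5]`, `[5:9]`, and the grid has 3 rows and 5 columns, so 15 tile tasks are created, each carrying three temporal chunks.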

### Step 2: Implement encode_tile_exec

```python
def encode_tile_exec(self, task: TileTask) -> torch.Tensor:
    """Encode a single sample tile into latent space."""
    self.clear_cache()
    time = []
    for k, tile in enumerate(task.tensor):
        self._enc_conv_idx = [0]
        encoded = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
        encoded = self.quant_conv(encoded)
        time.append(encoded)
    result = torch.cat(time, dim=2)
    self.clear_cache()
    return result
```

### Step 3: Implement encode_tile_merge

```python
def encode_tile_merge(
    self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec
) -> torch.Tensor:
    """Merge encoded tiles into a full latent tensor."""
    grid_h, grid_w = grid_spec.grid_shape
    result_rows = []
    for i in range(grid_h):
        result_row = []
        for j in range(grid_w):
            tile = coord_tensor_map[(i, j)]
            if i > 0:
                tile = self.blend_v(coord_tensor_map[(i - 1, j)], tile, grid_spec.tile_spec["blend_height"])
            if j > 0:
                tile = self.blend_h(coord_tensor_map[(i, j - 1)], tile, grid_spec.tile_spec["blend_width"])
            result_row.append(tile[:, :, :,
                                   : grid_spec.tile_spec["tile_latent_stride_height"],
                                   : grid_spec.tile_spec["tile_latent_stride_width"]])
        result_rows.append(torch.cat(result_row, dim=-1))

    enc = torch.cat(result_rows, dim=3)[
        :, :, :, : grid_spec.tile_spec["latent_height"], : grid_spec.tile_spec["latent_width"]
    ]
    return enc
```

### Step 4: Override tiled_encode method

Override `tiled_encode` instead of `encode`. The parent's `_encode()` handles patchify before calling `tiled_encode()`, so input `x` is already patchified.

```python
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
    """
    Encode using distributed VAE executor.

    Note: x is already patchified by parent's _encode() before calling this method.
    """
    if not self.is_distributed_enabled():
        return super().tiled_encode(x)

    self.clear_cache()
    result = self.distributed_executor.execute(
        x,
        DistributedOperator(
            split=self.encode_tile_split,
            exec=self.encode_tile_exec,
            merge=self.encode_tile_merge,
        ),
        broadcast_result=True,  # Latents needed by all ranks for diffusion
    )
    self.clear_cache()
    return result
```

**Key differences from decode parallel:**

| Aspect | Decode Parallel | Encode Parallel |
|--------|-----------------|-----------------|
| `broadcast_result` | Often `False` (only rank 0 needs output) | `True` (all ranks need latents for diffusion) |
| Patchify | Applied in merge (unpatchify) | Handled by parent `_encode()` before `tiled_encode()` |
| Temporal chunking | Frame-by-frame | Chunk-based (e.g., 1 + 4n frames) |

## Testing
Verify numerical consistency between:
+ vae_patch_parallel_size = 1
@@ -272,18 +440,20 @@ When vae_patch_parallel_size is larger than the DiT world size, it will automati

Complete examples in the codebase:

| Model | Path | Notes |
|-------|------|-------|
| **Z-Image** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl.py` | Distributed AutoencoderKL |
| **Wan2.2** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_wan.py` | Distributed AutoencoderKLWan |
| **Qwen-Image** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_qwenimage.py` | Distributed AutoencoderKLQwenImage |
| Model | Path | Decode Parallel | Encode Parallel |
|-------|------|-----------------|-----------------|
| **Z-Image** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl.py` | ✅ | ❌ |
| **Wan2.2** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_wan.py` | ✅ | ✅ |
| **Qwen-Image** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_qwenimage.py` | ✅ | ❌ |

---

## Summary

Adding Vae Patch Parallel support to diffusion model:
Adding VAE Patch Parallelism support to a diffusion model:

1. **Implement Distributed Vae** - mainly copy from `diffusers` tiled_decode, and refactor into split/exec/merge
2. **Change vae model in pipeline to Distributed Vae**
3. **Test** - Verify with `tensor_parallel_size=N` quality
1. **Implement Distributed VAE** - Inherit from base VAE class and `DistributedVaeMixin`
2. **Decode Parallel** - Refactor `tiled_decode` into `tile_split`/`tile_exec`/`tile_merge`
3. **Encode Parallel** (optional) - Implement `encode_tile_split`/`encode_tile_exec`/`encode_tile_merge` for I2V models
4. **Change VAE model in pipeline** - Use the distributed version
5. **Test** - Verify numerical consistency with `vae_patch_parallel_size=1` vs `N`
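
The consistency check in step 5 could look roughly like the sketch below. It is a minimal illustration on flat Python lists; the helper names and the tolerance of `1e-3` are assumptions, not values taken from the test suite.

```python
def max_abs_diff(reference, candidate):
    """Largest element-wise absolute difference between two equally sized sequences."""
    if len(reference) != len(candidate):
        raise ValueError("outputs must have the same number of elements")
    return max((abs(r - c) for r, c in zip(reference, candidate)), default=0.0)


def is_consistent(reference, candidate, atol=1e-3):
    """True when the parallel output stays within `atol` of the single-rank output."""
    return max_abs_diff(reference, candidate) <= atol


# Hypothetical flattened outputs for vae_patch_parallel_size=1 and =N.
single_rank = [0.10, 0.20, 0.30]
multi_rank = [0.10, 0.2004, 0.30]
ok = is_consistent(single_rank, multi_rank)
```

With real tensors, `torch.testing.assert_close` with explicit `atol`/`rtol` serves the same purpose.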
20 changes: 10 additions & 10 deletions docs/user_guide/diffusion_features.md
@@ -114,13 +114,13 @@ The following tables show which models support each feature:
| **Nextstep_1(T2I)** | ❓ | ❓ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **OmniGen2** | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **Ovis-Image** | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **Qwen-Image** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| **Qwen-Image-2512** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| **Qwen-Image-Edit** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| **Qwen-Image-Edit-2509** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| **Qwen-Image-Layered** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| **Stable-Diffusion3.5** | ❌ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| **Z-Image** | ✅ | ✅ | ✅ | ❓ | ✅ (TP=2 only) | ✅ | ❌ | ✅ | ✅ | ❌ |
| **Qwen-Image** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ✅ |
| **Qwen-Image-2512** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ✅ |
| **Qwen-Image-Edit** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ❌ | ❌ |
| **Qwen-Image-Edit-2509** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ❌ | ❌ |
| **Qwen-Image-Layered** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ❌ | ❌ |
| **Stable-Diffusion3.5** | ❌ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ (decode) | ❌ | ❌ |
| **Z-Image** | ✅ | ✅ | ✅ | ❓ | ✅ (TP=2 only) | ✅ | ❌ | ✅ (decode) | ✅ | ❌ |

> Notes:
> 1. Nextstep_1(T2I) does not support cache acceleration methods such as TeaCache or Cache-DiT.
@@ -130,11 +130,11 @@ The following tables show which models support each feature:

| Model | ⚡TeaCache | ⚡Cache-DiT | 🔀SP (Ulysses & Ring) | 🔀CFG-Parallel | 🔀Tensor-Parallel | 🔀HSDP | 💾CPU Offload (Layerwise) | 💾VAE-Patch-Parallel | 💾Quantization | 🔄Step Execution |
|-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:|
| **Wan2.2** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| **Wan2.1-VACE** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| **Wan2.2** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (encode/decode) | ❌ | ❌ |
| **Wan2.1-VACE** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ❌ | ❌ |
| **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **Helios** | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| **HunyuanVideo-1.5 T2V I2V** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| **HunyuanVideo-1.5 T2V I2V** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ❌ |
| **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |

### AudioGen