Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 188 additions & 18 deletions docs/design/feature/vae_parallel.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# VAE Patch Parallelism

This document describes how to add **VAE Patch Parallelism** support to a diffusion model.
We use **Qwen-Image** as the reference implementation.
We use **Qwen-Image** as the reference implementation for decode parallel, and **Wan2.2** for encode parallel.

---

## Table of Contents

- [Overview](#overview)
- [Step-by-Step Implementation](#step-by-step-implementation)
- [Step-by-Step Implementation (Decode)](#step-by-step-implementation-decode)
- [Encode Parallel Implementation](#encode-parallel-implementation)
- [Testing](#testing)
- [Reference Implementations](#reference-implementations)
- [Summary](#summary)
Expand All @@ -19,13 +20,13 @@ We use **Qwen-Image** as the reference implementation.

### What is Vae Patch parallel?

**VAE Patch Parallelism** is a decoding acceleration technique. Instead of decoding the entire latent tensor at once, the latent tensor is:
**VAE Patch Parallelism** is an acceleration technique for both **encoding** and **decoding**. Instead of processing the entire tensor at once, the tensor is:

+ Split into multiple spatial tiles

+ Distributed across multiple ranks

+ Decoded in parallel
+ Encoded/Decoded in parallel

+ Merged to reconstruct the final output

Expand All @@ -35,10 +36,17 @@ This approach:

+ Reduces peak memory usage per device

+ Accelerates decoding latency
+ Accelerates encoding/decoding latency

### When to Use Encode vs Decode Parallel

| Operation | Use Case | Example |
|-----------|----------|---------|
| **Decode Parallel** | Text-to-Image, Text-to-Video | Latent → Image/Video |
| **Encode Parallel** | Image-to-Video (I2V) | Image → Latent (for conditioning) |

### Architecture
We introduce **DistributedVaeExecutor** as the core component responsible for distributed VAE decoding.
We introduce **DistributedVaeExecutor** as the core component responsible for distributed VAE encoding/decoding.

The executor is model-agnostic and accepts three function parameters:

Expand Down Expand Up @@ -84,7 +92,7 @@ Therefore:

+ Merge must perform blending to avoid seams

## Step-by-Step Implementation
## Step-by-Step Implementation (Decode)

### Step 1: Implement DistributedAutoencoderKLQwenImage
`QwenImagePipeline` use `AutoencoderKLQwenImage` for vae, so implement a distributed version:
Expand Down Expand Up @@ -205,14 +213,14 @@ def tile_merge(self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid
We need to override tiled_decode, the main logic is:
+ check distributed is enabled
+ select split/exec/merge
+ Invoke self.distributed_decoder.execute to decode
+ Invoke self.distributed_executor.execute to decode
```
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True):
if not self.is_distributed_enabled():
return super().tiled_decode(z, return_dict=return_dict)

logger.info("Decode run with distributed executor")
result = self.distributed_decoder.execute(
result = self.distributed_executor.execute(
z,
DistributedOperator(split=self.tile_split, exec=self.tile_exec, merge=self.tile_merge),
broadcast_result=True,
Expand Down Expand Up @@ -243,6 +251,166 @@ class YourModelPipeline(nn.Module):
+ ).to(self.device)
```

## Encode Parallel Implementation

For models that require VAE encoding (e.g., Image-to-Video), you can also parallelize the encode operation. We use **Wan2.2** as the reference implementation.

### Step 1: Implement encode_tile_split

Similar to decode, split the input tensor into tiles. Key considerations:

+ **Patchify handling**: If the model uses `patch_size`, scale tile parameters accordingly
+ **Temporal chunking**: Video VAEs may have temporal compression (e.g., 4x)

```python
def encode_tile_split(self, x: torch.Tensor) -> tuple[list[TileTask], GridSpec]:
_, _, num_frames, height, width = x.shape
encode_spatial_compression_ratio = self.spatial_compression_ratio

# Scale tile parameters for patchified coordinate system
tile_sample_min_height = self.tile_sample_min_height
tile_sample_min_width = self.tile_sample_min_width
tile_sample_stride_height = self.tile_sample_stride_height
tile_sample_stride_width = self.tile_sample_stride_width

if self.config.patch_size is not None:
# When input is patchified, scale tile parameters accordingly
encode_spatial_compression_ratio = self.spatial_compression_ratio // self.config.patch_size
tile_sample_min_height = tile_sample_min_height // self.config.patch_size
tile_sample_min_width = tile_sample_min_width // self.config.patch_size
tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size

latent_height = height // encode_spatial_compression_ratio
latent_width = width // encode_spatial_compression_ratio

tile_latent_min_height = tile_sample_min_height // encode_spatial_compression_ratio
tile_latent_min_width = tile_sample_min_width // encode_spatial_compression_ratio
tile_latent_stride_height = tile_sample_stride_height // encode_spatial_compression_ratio
tile_latent_stride_width = tile_sample_stride_width // encode_spatial_compression_ratio

blend_height = tile_latent_min_height - tile_latent_stride_height
blend_width = tile_latent_min_width - tile_latent_stride_width

tiletask_list = []
# Use temporal compression ratio from config instead of hardcoding
temporal_compression = self.config.scale_factor_temporal

for i in range(0, height, tile_sample_stride_height):
for j in range(0, width, tile_sample_stride_width):
time_list = []
frame_range = 1 + (num_frames - 1) // temporal_compression
for k in range(frame_range):
if k == 0:
tile = x[:, :, :1, i : i + tile_sample_min_height, j : j + tile_sample_min_width]
else:
tile = x[
:, :,
1 + temporal_compression * (k - 1) : 1 + temporal_compression * k,
i : i + tile_sample_min_height,
j : j + tile_sample_min_width,
]
time_list.append(tile)
tiletask_list.append(
TileTask(len(tiletask_list), (i // tile_sample_stride_height, j // tile_sample_stride_width),
time_list, workload=time_list[0].shape[3] * time_list[0].shape[4])
)

grid_spec = GridSpec(
split_dims=(3, 4),
grid_shape=(tiletask_list[-1].grid_coord[0] + 1, tiletask_list[-1].grid_coord[1] + 1),
tile_spec={
"latent_height": latent_height, "latent_width": latent_width,
"blend_height": blend_height, "blend_width": blend_width,
"tile_latent_stride_height": tile_latent_stride_height,
"tile_latent_stride_width": tile_latent_stride_width,
},
output_dtype=self.dtype,
)
return tiletask_list, grid_spec
```

### Step 2: Implement encode_tile_exec

```python
def encode_tile_exec(self, task: TileTask) -> torch.Tensor:
"""Encode a single sample tile into latent space."""
self.clear_cache()
time = []
for k, tile in enumerate(task.tensor):
self._enc_conv_idx = [0]
encoded = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
encoded = self.quant_conv(encoded)
time.append(encoded)
result = torch.cat(time, dim=2)
self.clear_cache()
return result
```

### Step 3: Implement encode_tile_merge

```python
def encode_tile_merge(
self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec
) -> torch.Tensor:
"""Merge encoded tiles into a full latent tensor."""
grid_h, grid_w = grid_spec.grid_shape
result_rows = []
for i in range(grid_h):
result_row = []
for j in range(grid_w):
tile = coord_tensor_map[(i, j)]
if i > 0:
tile = self.blend_v(coord_tensor_map[(i - 1, j)], tile, grid_spec.tile_spec["blend_height"])
if j > 0:
tile = self.blend_h(coord_tensor_map[(i, j - 1)], tile, grid_spec.tile_spec["blend_width"])
result_row.append(tile[:, :, :,
: grid_spec.tile_spec["tile_latent_stride_height"],
: grid_spec.tile_spec["tile_latent_stride_width"]])
result_rows.append(torch.cat(result_row, dim=-1))

enc = torch.cat(result_rows, dim=3)[
:, :, :, : grid_spec.tile_spec["latent_height"], : grid_spec.tile_spec["latent_width"]
]
return enc
```

### Step 4: Override tiled_encode method

Override `tiled_encode` instead of `encode`. The parent's `_encode()` handles patchify before calling `tiled_encode()`, so input `x` is already patchified.

```python
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
"""
Encode using distributed VAE executor.

Note: x is already patchified by parent's _encode() before calling this method.
"""
if not self.is_distributed_enabled():
return super().tiled_encode(x)

self.clear_cache()
result = self.distributed_executor.execute(
x,
DistributedOperator(
split=self.encode_tile_split,
exec=self.encode_tile_exec,
merge=self.encode_tile_merge,
),
broadcast_result=True, # Latents needed by all ranks for diffusion
)
self.clear_cache()
return result
```

**Key differences from decode parallel:**

| Aspect | Decode Parallel | Encode Parallel |
|--------|-----------------|-----------------|
| `broadcast_result` | Often `False` (only rank 0 needs output) | `True` (all ranks need latents for diffusion) |
| Patchify | Applied in merge (unpatchify) | Handled by parent `_encode()` before `tiled_encode()` |
| Temporal chunking | Frame-by-frame | Chunk-based (e.g., 1 + 4n frames) |

## Testing
Verify numerical consistency between:
+ vae_patch_parallel_size = 1
Expand Down Expand Up @@ -272,18 +440,20 @@ When vae_patch_parallel_size is larger than the DiT world size, it will automati

Complete examples in the codebase:

| Model | Path | Notes |
|-------|------|-------|
| **Z-Image** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl.py` | Distributed AutoencoderKL |
| **Wan2.2** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_wan.py` | Distributed AutoencoderKLWan |
| **Qwen-Image** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_qwenimage.py` | Distributed AutoencoderKLQwenImage |
| Model | Path | Decode Parallel | Encode Parallel |
|-------|------|-----------------|-----------------|
| **Z-Image** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl.py` | ✅ | ❌ |
| **Wan2.2** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_wan.py` | ✅ | ✅ |
| **Qwen-Image** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_qwenimage.py` | ✅ | ❌ |

---

## Summary

Adding Vae Patch Parallel support to diffusion model:
Adding VAE Patch Parallel support to diffusion model:

1. **Implement Distributed Vae** - mainly copy from `diffusers` tiled_decode, and refactor into split/exec/merge
2. **Change vae model in pipeline to Distributed Vae**
3. **Test** - Verify with `tensor_parallel_size=N` quality
1. **Implement Distributed VAE** - Inherit from base VAE class and `DistributedVaeMixin`
2. **Decode Parallel** - Refactor `tiled_decode` into `tile_split`/`tile_exec`/`tile_merge`
3. **Encode Parallel** (optional) - Implement `encode_tile_split`/`encode_tile_exec`/`encode_tile_merge` for I2V models
4. **Change VAE model in pipeline** - Use the distributed version
5. **Test** - Verify numerical consistency with `vae_patch_parallel_size=1` vs `N`
Loading