From c98eb8493219cf793a668d0587339c4685c9a8e4 Mon Sep 17 00:00:00 2001
From: Chen Yang <2082464740@qq.com>
Date: Tue, 24 Mar 2026 20:06:38 +0800
Subject: [PATCH 1/8] [Feature]Adding vae patch parallel supports for VideoGen
Signed-off-by: Chen Yang <2082464740@qq.com>
---
.../autoencoders/autoencoder_kl_ltx2.py | 143 ++++++++++++++++++
.../diffusion/models/ltx2/pipeline_ltx2.py | 9 +-
2 files changed, 149 insertions(+), 3 deletions(-)
create mode 100644 vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_ltx2.py
diff --git a/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_ltx2.py b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_ltx2.py
new file mode 100644
index 00000000000..3c1436d6716
--- /dev/null
+++ b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_ltx2.py
@@ -0,0 +1,143 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Any
+
+import torch
+from diffusers import AutoencoderKLLTX2Video
+from diffusers.models.autoencoders.vae import DecoderOutput
+from vllm.logger import init_logger
+
+from vllm_omni.diffusion.distributed.autoencoders.distributed_vae_executor import (
+ DistributedOperator,
+ DistributedVaeMixin,
+ GridSpec,
+ TileTask,
+)
+
+logger = init_logger(__name__)
+
+
+class DistributedAutoencoderKLLTX2Video(AutoencoderKLLTX2Video, DistributedVaeMixin):
+ @classmethod
+ def from_pretrained(cls, *args: Any, **kwargs: Any):
+ model = super().from_pretrained(*args, **kwargs)
+ model.init_distributed()
+ return model
+
+ def tile_split(self, z: torch.Tensor) -> tuple[list[TileTask], GridSpec]:
+ _, _, num_frames, height, width = z.shape
+ sample_height = height * self.spatial_compression_ratio
+ sample_width = width * self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+ tile_sample_stride_height = self.tile_sample_stride_height
+ tile_sample_stride_width = self.tile_sample_stride_width
+ # `super().decode(...)` already returns fully decoded pixel tiles, so
+ # no extra patch-size downscaling/unpatchifying should be applied here.
+ blend_height = self.tile_sample_min_height - tile_sample_stride_height
+ blend_width = self.tile_sample_min_width - tile_sample_stride_width
+
+ tiletask_list = []
+ for i in range(0, height, tile_latent_stride_height):
+ for j in range(0, width, tile_latent_stride_width):
+ tile = z[:, :, :num_frames, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+ tiletask_list.append(
+ TileTask(
+ len(tiletask_list),
+ (i // tile_latent_stride_height, j // tile_latent_stride_width),
+ tile,
+ workload=tile.shape[2] * tile.shape[3] * tile.shape[4],
+ )
+ )
+
+ tile_spec = {
+ "sample_height": sample_height,
+ "sample_width": sample_width,
+ "blend_height": blend_height,
+ "blend_width": blend_width,
+ "tile_sample_stride_height": tile_sample_stride_height,
+ "tile_sample_stride_width": tile_sample_stride_width,
+ }
+ grid_spec = GridSpec(
+ split_dims=(3, 4),
+ grid_shape=(tiletask_list[-1].grid_coord[0] + 1, tiletask_list[-1].grid_coord[1] + 1),
+ tile_spec=tile_spec,
+ output_dtype=self.dtype,
+ )
+ return tiletask_list, grid_spec
+
+ def tile_exec(
+ self,
+ task: TileTask,
+ timestep: torch.Tensor | None = None,
+ *args: Any,
+ **kwargs: Any,
+ ) -> torch.Tensor:
+ """Decode a single latent tile into video space."""
+ tile = task.tensor
+ if hasattr(self, "clear_cache"):
+ self.clear_cache()
+ return super().decode(tile, timestep, return_dict=False, *args, **kwargs)[0]
+
+ def tile_merge(self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec) -> torch.Tensor:
+ """Merge decoded tiles into a full video."""
+ grid_h, grid_w = grid_spec.grid_shape
+ result_rows = []
+
+ if hasattr(self, "clear_cache"):
+ self.clear_cache()
+
+ for i in range(grid_h):
+ result_row = []
+ for j in range(grid_w):
+ tile = coord_tensor_map[(i, j)]
+ if i > 0:
+ tile = self.blend_v(coord_tensor_map[(i - 1, j)], tile, grid_spec.tile_spec["blend_height"])
+ if j > 0:
+ tile = self.blend_h(coord_tensor_map[(i, j - 1)], tile, grid_spec.tile_spec["blend_width"])
+ result_row.append(
+ tile[
+ :,
+ :,
+ :,
+ : grid_spec.tile_spec["tile_sample_stride_height"],
+ : grid_spec.tile_spec["tile_sample_stride_width"],
+ ]
+ )
+ result_rows.append(torch.cat(result_row, dim=-1))
+
+ dec = torch.cat(result_rows, dim=3)[
+ :, :, :, : grid_spec.tile_spec["sample_height"], : grid_spec.tile_spec["sample_width"]
+ ]
+ dec = torch.clamp(dec, min=-1.0, max=1.0)
+ return dec
+
+ def decode(
+ self,
+ z: torch.Tensor,
+ timestep: torch.Tensor | None = None,
+ return_dict: bool = True,
+ *args: Any,
+ **kwargs: Any,
+ ):
+ if not self.is_distributed_enabled():
+ return super().decode(z, timestep, return_dict=return_dict, *args, **kwargs)
+
+ logger.info("Decode run with distributed executor")
+ result = self.distributed_decoder.execute(
+ z,
+ DistributedOperator(
+ split=self.tile_split,
+ exec=lambda task: self.tile_exec(task, timestep=timestep, *args, **kwargs),
+ merge=self.tile_merge,
+ ),
+ broadcast_result=True,
+ )
+ if not return_dict:
+ return (result,)
+
+ return DecoderOutput(sample=result)
\ No newline at end of file
diff --git a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py
index 34263e217e9..f86666c6dd4 100644
--- a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py
+++ b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py
@@ -13,7 +13,7 @@
import numpy as np
import torch
-from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, FlowMatchEulerDiscreteScheduler
+from diffusers import AutoencoderKLLTX2Audio, FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.ltx2 import LTX2TextConnectors
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg, retrieve_timesteps
@@ -24,6 +24,9 @@
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
+from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_ltx2 import (
+ DistributedAutoencoderKLLTX2Video,
+)
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
from vllm_omni.diffusion.distributed.parallel_state import (
get_cfg_group,
@@ -158,7 +161,7 @@ def __init__(
local_files_only=local_files_only,
).to(self.device)
- self.vae = AutoencoderKLLTX2Video.from_pretrained(
+ self.vae = DistributedAutoencoderKLLTX2Video.from_pretrained(
model,
subfolder="vae",
torch_dtype=dtype,
@@ -1173,4 +1176,4 @@ def forward(
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
- return loader.load_weights(weights)
+ return loader.load_weights(weights)
\ No newline at end of file
From c760d7a0dfc92de0a269e2d547a5c9f55fe08f61 Mon Sep 17 00:00:00 2001
From: Chen Yang <2082464740@qq.com>
Date: Tue, 24 Mar 2026 20:20:01 +0800
Subject: [PATCH 2/8] fix
Signed-off-by: Chen Yang <2082464740@qq.com>
---
.../diffusion/distributed/autoencoders/autoencoder_kl_ltx2.py | 2 +-
vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_ltx2.py b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_ltx2.py
index 3c1436d6716..70f7377d6c1 100644
--- a/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_ltx2.py
+++ b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_ltx2.py
@@ -140,4 +140,4 @@ def decode(
if not return_dict:
return (result,)
- return DecoderOutput(sample=result)
\ No newline at end of file
+ return DecoderOutput(sample=result)
diff --git a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py
index f86666c6dd4..52fefd92138 100644
--- a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py
+++ b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py
@@ -1176,4 +1176,4 @@ def forward(
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
- return loader.load_weights(weights)
\ No newline at end of file
+ return loader.load_weights(weights)
From d3d6d8d11b72218dea01cf0808383303d0bd4a79 Mon Sep 17 00:00:00 2001
From: erfgss <97771661+erfgss@users.noreply.github.com>
Date: Thu, 26 Mar 2026 15:07:25 +0800
Subject: [PATCH 3/8] Update autoencoder_kl_ltx2.py
Signed-off-by: erfgss <97771661+erfgss@users.noreply.github.com>
---
.../autoencoders/autoencoder_kl_ltx2.py | 132 +++++++++++++++++-
1 file changed, 128 insertions(+), 4 deletions(-)
diff --git a/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_ltx2.py b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_ltx2.py
index 70f7377d6c1..cab09e7fb87 100644
--- a/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_ltx2.py
+++ b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_ltx2.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import math
from typing import Any
import torch
@@ -116,6 +117,124 @@ def tile_merge(self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid
dec = torch.clamp(dec, min=-1.0, max=1.0)
return dec
+ def patch_split(self, z: torch.Tensor) -> tuple[list[TileTask], GridSpec]:
+ _, _, _, latent_h, latent_w = z.shape
+
+ overlap_sample_h = max(0, self.tile_sample_min_height - self.tile_sample_stride_height)
+ overlap_sample_w = max(0, self.tile_sample_min_width - self.tile_sample_stride_width)
+ overlap_latent_h = overlap_sample_h // self.spatial_compression_ratio
+ overlap_latent_w = overlap_sample_w // self.spatial_compression_ratio
+ halo_base_h = max(0, overlap_latent_h // 2)
+ halo_base_w = max(0, overlap_latent_w // 2)
+
+ max_parallel_size = self.distributed_decoder.parallel_size
+ root = int(math.sqrt(max_parallel_size))
+ for rows in range(root, 0, -1):
+ if max_parallel_size % rows == 0:
+ grid_rows, grid_cols = rows, max_parallel_size // rows
+ break
+
+ tiletask_list = []
+ halo_size = {}
+ for i in range(grid_rows):
+ for j in range(grid_cols):
+ h0 = (i * latent_h) // grid_rows
+ h1 = ((i + 1) * latent_h) // grid_rows
+ w0 = (j * latent_w) // grid_cols
+ w1 = ((j + 1) * latent_w) // grid_cols
+
+ core_h = max(0, h1 - h0)
+ core_w = max(0, w1 - w0)
+ halo_h = max(halo_base_h, core_h // 2)
+ halo_w = max(halo_base_w, core_w // 2)
+
+ ph0 = max(0, h0 - halo_h)
+ ph1 = min(latent_h, h1 + halo_h)
+ pw0 = max(0, w0 - halo_w)
+ pw1 = min(latent_w, w1 + halo_w)
+
+ tile = z[:, :, :, ph0:ph1, pw0:pw1]
+ tiletask_list.append(
+ TileTask(
+ len(tiletask_list),
+ (i, j),
+ tile,
+ workload=tile.shape[2] * tile.shape[3] * tile.shape[4],
+ )
+ )
+ halo_size[(i, j)] = {
+ "up": h0 - ph0,
+ "down": ph1 - h1,
+ "left": w0 - pw0,
+ "right": pw1 - w1,
+ }
+
+ grid_spec = GridSpec(
+ split_dims=(3, 4),
+ grid_shape=(tiletask_list[-1].grid_coord[0] + 1, tiletask_list[-1].grid_coord[1] + 1),
+ tile_spec={
+ "halo_size": halo_size,
+ "scale": self.spatial_compression_ratio,
+ },
+ output_dtype=self.dtype,
+ )
+ return tiletask_list, grid_spec
+
+ def patch_exec(
+ self,
+ task: TileTask,
+ timestep: torch.Tensor | None = None,
+ *args: Any,
+ **kwargs: Any,
+ ) -> torch.Tensor:
+ return self.tile_exec(task, timestep=timestep, *args, **kwargs)
+
+ def patch_merge(self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec) -> torch.Tensor:
+ grid_h, grid_w = grid_spec.grid_shape
+ result_rows = []
+ scale = grid_spec.tile_spec["scale"]
+
+ if hasattr(self, "clear_cache"):
+ self.clear_cache()
+
+ for i in range(grid_h):
+ result_row = []
+ for j in range(grid_w):
+ halo = grid_spec.tile_spec["halo_size"][(i, j)]
+ tile = coord_tensor_map[(i, j)]
+
+ halo_up = halo["up"] * scale
+ halo_down = halo["down"] * scale
+ halo_left = halo["left"] * scale
+ halo_right = halo["right"] * scale
+
+ core_tile = tile[
+ :,
+ :,
+ :,
+ halo_up : (None if halo_down == 0 else -halo_down),
+ halo_left : (None if halo_right == 0 else -halo_right),
+ ]
+ result_row.append(core_tile)
+ result_rows.append(torch.cat(result_row, dim=4))
+
+ dec = torch.cat(result_rows, dim=3)
+ dec = torch.clamp(dec, min=-1.0, max=1.0)
+ return dec
+
+ def _strategy_select(self, z: torch.Tensor):
+ tile_latent_min_height = getattr(self, "tile_sample_min_height", None)
+ tile_latent_min_width = getattr(self, "tile_sample_min_width", None)
+ if tile_latent_min_height is None or tile_latent_min_width is None:
+ return None, None, None
+
+ tile_latent_min_height = tile_latent_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = tile_latent_min_width // self.spatial_compression_ratio
+ if z.shape[-2] > tile_latent_min_height or z.shape[-1] > tile_latent_min_width:
+ return self.tile_split, self.tile_exec, self.tile_merge
+
+ return self.patch_split, self.patch_exec, self.patch_merge
+
def decode(
self,
z: torch.Tensor,
@@ -127,13 +246,18 @@ def decode(
if not self.is_distributed_enabled():
return super().decode(z, timestep, return_dict=return_dict, *args, **kwargs)
- logger.info("Decode run with distributed executor")
+ split, exec, merge = self._strategy_select(z)
+ if split is None:
+ return super().decode(z, timestep, return_dict=return_dict, *args, **kwargs)
+
+ strategy = "tile" if split == self.tile_split else "patch"
+ logger.info(f"Decode run with distributed executor, split strategy is {strategy}")
result = self.distributed_decoder.execute(
z,
DistributedOperator(
- split=self.tile_split,
- exec=lambda task: self.tile_exec(task, timestep=timestep, *args, **kwargs),
- merge=self.tile_merge,
+ split=split,
+ exec=lambda task: exec(task, timestep=timestep, *args, **kwargs),
+ merge=merge,
),
broadcast_result=True,
)
From a490461b83282b6a288beb462dc742351993667f Mon Sep 17 00:00:00 2001
From: erfgss <97771661+erfgss@users.noreply.github.com>
Date: Thu, 26 Mar 2026 15:16:20 +0800
Subject: [PATCH 4/8] Update parallelism_acceleration.md
Signed-off-by: erfgss <97771661+erfgss@users.noreply.github.com>
---
docs/user_guide/diffusion/parallelism_acceleration.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md
index 0d6903a6a35..5e662b7fea8 100644
--- a/docs/user_guide/diffusion/parallelism_acceleration.md
+++ b/docs/user_guide/diffusion/parallelism_acceleration.md
@@ -64,7 +64,7 @@ The following table shows which models are currently supported by parallelism me
| **Wan2.1** | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` | ✅ | ✅ | ✅ | ✅ | ✅ |
| **Wan2.1** | `Wan-AI/Wan2.1-T2V-14B-Diffusers` | ✅ | ✅ | ✅ | ✅ | ✅ |
| **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ✅ | ✅ | ✅ | ✅ | ✅ |
-| **LTX-2** | `Lightricks/LTX-2` | ✅ | ✅ | ✅ | ❌ | ❌ |
+| **LTX-2** | `Lightricks/LTX-2` | ✅ | ✅ | ✅ | ❌ | ✅ |
### Tensor Parallelism
From 47c1acf9ff18a83acb102e3e37b7bbcd8c3c161b Mon Sep 17 00:00:00 2001
From: Chen Yang <2082464740@qq.com>
Date: Tue, 31 Mar 2026 10:13:50 +0800
Subject: [PATCH 5/8] fix
Signed-off-by: Chen Yang <2082464740@qq.com>
---
.../diffusion/parallelism_acceleration.md | 477 ------------------
1 file changed, 477 deletions(-)
delete mode 100644 docs/user_guide/diffusion/parallelism_acceleration.md
diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md
deleted file mode 100644
index 1340c1da0ef..00000000000
--- a/docs/user_guide/diffusion/parallelism_acceleration.md
+++ /dev/null
@@ -1,477 +0,0 @@
-# Parallelism Acceleration Guide
-
-This guide includes how to use parallelism methods in vLLM-Omni to speed up diffusion model inference as well as reduce the memory requirement on each device.
-
-## Overview
-
-The following parallelism methods are currently supported in vLLM-Omni:
-
-1. DeepSpeed Ulysses Sequence Parallel (DeepSpeed Ulysses-SP) ([arxiv paper](https://arxiv.org/pdf/2309.14509)): Ulysses-SP splits the input along the sequence dimension and uses all-to-all communication to allow each device to compute only a subset of attention heads.
-
-2. [Ring-Attention](#ring-attention) - splits the input along the sequence dimension and uses ring-based P2P communication to accumulate attention results, keeping the sequence dimension sharded
-
-3. Classifier-Free-Guidance Parallel (CFG-Parallel): CFG-Parallel runs the positive/negative prompts of classifier-free guidance (CFG) on different devices, then merges on a single device to perform the scheduler step.
-
-4. [Tensor Parallelism](#tensor-parallelism): Tensor parallelism shards model weights across devices. This can reduce per-GPU memory usage. Note that for diffusion models we currently shard the majority of layers within the DiT.
-
-5. [VAE Patch Parallelism](#vae-patch-parallelism): VAE patch parallelism shards VAE decode spatially across ranks. This can reduce the peak memory of VAE decode and (depending on resolution and communication overhead) speed up VAE decode.
-
-6. [HSDP](#hsdp): Hybrid Sharded Data Parallel shards model weights across GPUs using PyTorch FSDP2. This reduces per-GPU memory usage, enabling inference of large models on GPUs with limited memory.
-
-7. [Expert Parallel](#expert-parallelism): Expert Parallelism shards the Experts of a Mixture-of-Experts (MoE) layer across multiple devices. During the forward, a gating mechanism routes tokens to their designated experts, necessitating cross-cards communication(all-to-all) to dispatch tokens to the correct ranks and combine the results. This parallelism allows for massive scaling of model parameters without a proportional increase in the computational load per device.
-
-The following table shows which models are currently supported by parallelism method:
-
-### ImageGen
-
-| Model | Model Identifier | Ulysses-SP | Ring-SP | CFG-Parallel | Tensor-Parallel | VAE-Patch-Parallel | Expert-Parallel | HSDP |
-|--------------------------|--------------------------------------|:----------:|:-------:|:------------:|:---------------:|:------------------:|:---------------:|:----:|
-| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ✅ | ✅ | ❌ | ✅ | ❌ | N/A | ❌ |
-| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ✅ | ✅ | ❌ | ✅ | ❌ | N/A | ❌ |
-| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ❌ | ❌ | ❌ | ❌ | N/A | ❌ |
-| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | ✅ | N/A | ❌ |
-| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ✅ | ❌ | N/A | ❌ |
-| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ✅ | ❌ | N/A | ❌ |
-| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ✅ | ✅ | ✅ | ✅ | ❌ | N/A | ❌ |
-| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ✅ (TP=2 only) | ✅ | N/A | ❌ |
-| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ✅ | ✅ | N/A | ❌ |
-| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ✅ | ✅ | ❌ | ✅ | ❌ | N/A | ✅ |
-| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | ✅ | ✅ | ❌ | N/A | ✅ |
-| **FLUX.2-dev** | `black-forest-labs/FLUX.2-dev` | ❌ | ❌ | ❌ | ✅ | ❌ | N/A | ✅ |
-| **HunyuanImage3.0** | `tencent/HunyuanImage-3.0`, `tencent/HunyuanImage-3.0-Instruct` | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ |
-| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ✅ | ✅ | ❌ | N/A | ❌ |
-| **DreamID-Omni** | `XuGuo699/DreamID-Omni` | ❌ | ❌ | ✅ | ❌ | ❌ | N/A | ❌ |
-| **FLUX.1-Kontext-dev** | `black-forest-labs/FLUX.1-Kontext-dev` | ❌ | ❌ | ❌ | ✅ | ❌ | N/A | ✅ |
-| **OmniGen2** | `OmniGen2/OmniGen2` | ❌ | ❌ | ❌ | ✅ | ❌ | N/A | ❌ |
-
-!!! note "TP Limitations for Diffusion Models"
- We currently implement Tensor Parallelism (TP) only for the DiT (Diffusion Transformer) blocks. This is because the `text_encoder` component in vLLM-Omni uses the original Transformers implementation, which does not yet support TP.
-
- - Good news: The text_encoder typically has minimal impact on overall inference performance.
- - Bad news: When TP is enabled, every TP process retains a full copy of the text_encoder weights, leading to significant GPU memory waste.
-
- We are actively refactoring this design to address this. For details and progress, please refer to [Issue #771](https://github.com/vllm-project/vllm-omni/issues/771).
-
-
-!!! note "Why Z-Image is TP=2 only"
- Z-Image Turbo is currently limited to `tensor_parallel_size` of **1 or 2** due to model shape divisibility constraints.
- For example, the model has `n_heads=30` and a final projection out dimension of `64`, so valid TP sizes must divide both 30 and 64; the only common divisors are **1 and 2**.
-
-### VideoGen
-
-| Model | Model Identifier | Ulysses-SP | Ring-Attention | Tensor-Parallel | HSDP | VAE-Patch-Parallel |
-|-------|------------------|:----------:|:--------------:|:---------------:|:----:| :----:|
-| **Wan2.1** | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` | ✅ | ✅ | ✅ | ✅ | ✅ |
-| **Wan2.1** | `Wan-AI/Wan2.1-T2V-14B-Diffusers` | ✅ | ✅ | ✅ | ✅ | ✅ |
-| **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ✅ | ✅ | ✅ | ✅ | ✅ |
-| **LTX-2** | `Lightricks/LTX-2` | ✅ | ✅ | ✅ | ❌ | ✅ |
-
-### Tensor Parallelism
-
-Tensor parallelism splits model parameters across GPUs. In vLLM-Omni, tensor parallelism is configured via `DiffusionParallelConfig.tensor_parallel_size`.
-
-#### Offline Inference
-
-```python
-from vllm_omni import Omni
-from vllm_omni.diffusion.data import DiffusionParallelConfig
-
-omni = Omni(
- model="Tongyi-MAI/Z-Image-Turbo",
- parallel_config=DiffusionParallelConfig(tensor_parallel_size=2),
-)
-
-outputs = omni.generate(
- "a cat reading a book",
- OmniDiffusionSamplingParams(
- num_inference_steps=9,
- width=512,
- height=512,
- ),
-)
-```
-
-### VAE Patch Parallelism
-
-VAE patch parallelism distributes the VAE decode workload across multiple ranks by splitting the latent spatially. It is configured via `DiffusionParallelConfig.vae_patch_parallel_size` and can be combined with other parallelism methods (e.g., TP).
-
-!!! note "Enablement and feature gate"
- - VAE patch parallelism is currently **enabled only for validated pipelines** (check [ImageGen](#imagegen) and [VideoGen](#videogen) for more information).
- - If `vae_patch_parallel_size > 1` is set for a validated pipeline, vLLM-Omni will automatically enable `vae_use_tiling` as a safety gate. (We use `vae_use_tiling` because it indicates the VAE supports diffusers tiling parameters like `tile_latent_min_size` and `tile_overlap_factor`.)
-
-#### Offline Inference
-
-```python
-from vllm_omni import Omni
-from vllm_omni.diffusion.data import DiffusionParallelConfig
-
-omni = Omni(
- model="Tongyi-MAI/Z-Image-Turbo",
- parallel_config=DiffusionParallelConfig(
- tensor_parallel_size=2,
- vae_patch_parallel_size=2,
- ),
- vae_use_tiling=True,
-)
-
-outputs = omni.generate(
- prompt="a cat reading a book",
- num_inference_steps=9,
- width=1024,
- height=1024,
-)
-```
-
-#### How it works (method selection)
-
-VAE patch parallelism automatically selects between two internal decode methods based on whether diffusers tiling would kick in:
-
-- `_distributed_tiled_decode`: Used when the latent spatial size exceeds `vae.tile_latent_min_size` (i.e., diffusers tiled decode). Each rank decodes a subset of tiles; rank0 gathers and runs the same overlap+blend+stitch logic as diffusers. This matches the single-rank diffusers tiled output.
-
-- `_distributed_patch_decode`: Used when diffusers tiling would not kick in. Each rank decodes a grid patch expanded with a latent-space halo; then rank0 gathers the cropped core patches and stitches them into the full image. This path has no blending and can introduce small numerical differences compared to the non-parallel decode.
-
-### Sequence Parallelism
-
-#### Ulysses-SP
-
-!!! note "Experimental UAA mode"
- `ulysses_mode="advanced_uaa"` is an experimental extension to Ulysses-SP. It lets Ulysses attention handle arbitrary sequence lengths and arbitrary attention head counts without relying on `attention_mask`-based token padding.
-
- In hybrid Ulysses + Ring mode, Ring still requires every rank in the same ring group to observe the same post-Ulysses sequence length. If that condition is not met, vLLM-Omni raises a validation error instead of entering the ring kernel with inconsistent shapes.
-
-##### Offline Inference
-
-An example of offline inference script using [Ulysses-SP](https://arxiv.org/pdf/2309.14509) is shown below:
-```python
-from vllm_omni import Omni
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-from vllm_omni.diffusion.data import DiffusionParallelConfig
-ulysses_degree = 2
-
-omni = Omni(
- model="Qwen/Qwen-Image",
- parallel_config=DiffusionParallelConfig(ulysses_degree=2)
-)
-
-outputs = omni.generate(
- "A cat sitting on a windowsill",
- OmniDiffusionSamplingParams(num_inference_steps=50, width=2048, height=2048),
-)
-```
-
-See `examples/offline_inference/text_to_image/text_to_image.py` for a complete working example.
-
-To enable the experimental UAA mode explicitly, use a model/configuration that actually requires it. For example, `Tongyi-MAI/Z-Image-Turbo` has 30 attention heads, so `ulysses_degree=4` requires UAA because 30 is not divisible by 4:
-
-```python
-omni = Omni(
- model="Tongyi-MAI/Z-Image-Turbo",
- parallel_config=DiffusionParallelConfig(
- ulysses_degree=4,
- ulysses_mode="advanced_uaa",
- ),
-)
-```
-
-##### Online Serving
-
-You can enable Ulysses-SP in online serving for diffusion models via `--usp`:
-
-```bash
-# Text-to-image (requires >= 2 GPUs)
-vllm serve Qwen/Qwen-Image --omni --port 8091 --usp 2
-
-# Experimental UAA mode for a model with 30 attention heads
-vllm serve Tongyi-MAI/Z-Image-Turbo --omni --port 8091 --usp 4 --ulysses-mode advanced_uaa
-```
-
-##### Benchmarks
-!!! note "Benchmark Disclaimer"
- These benchmarks are provided for **general reference only**. The configurations shown use default or common parameter settings and have not been exhaustively optimized for maximum performance. Actual performance may vary based on:
-
- - Specific model and use case
- - Hardware configuration
- - Careful parameter tuning
- - Different inference settings (e.g., number of steps, image resolution)
-
-
-To measure the parallelism methods, we run benchmarks with **Qwen/Qwen-Image** model generating images (**2048x2048** as long sequence input) with 50 inference steps. The hardware devices are NVIDIA H800 GPUs. `sdpa` is the attention backends.
-
-| Configuration | Ulysses degree |Generation Time | Speedup |
-|---------------|----------------|---------|---------|
-| **Baseline (diffusers)** | - | 112.5s | 1.0x |
-| Ulysses-SP | 2 | 65.2s | 1.73x |
-| Ulysses-SP | 4 | 39.6s | 2.84x |
-| Ulysses-SP | 8 | 30.8s | 3.65x |
-
-#### Ring-Attention
-
-Ring-Attention ([arxiv paper](https://arxiv.org/abs/2310.01889)) splits the input along the sequence dimension and uses ring-based P2P communication to accumulate attention results. Unlike Ulysses-SP which uses all-to-all communication, Ring-Attention keeps the sequence dimension sharded throughout the computation and circulates Key/Value blocks through a ring topology.
-
-##### Offline Inference
-
-An example of offline inference script using Ring-Attention is shown below:
-```python
-from vllm_omni import Omni
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-from vllm_omni.diffusion.data import DiffusionParallelConfig
-ring_degree = 2
-
-omni = Omni(
- model="Qwen/Qwen-Image",
- parallel_config=DiffusionParallelConfig(ring_degree=2)
-)
-
-outputs = omni.generate(
- "A cat sitting on a windowsill",
- OmniDiffusionSamplingParams(num_inference_steps=50, width=2048, height=2048),
-)
-```
-
-See `examples/offline_inference/text_to_image/text_to_image.py` for a complete working example.
-
-
-##### Online Serving
-
-You can enable Ring-Attention in online serving for diffusion models via `--ring`:
-
-```bash
-# Text-to-image (requires >= 2 GPUs)
-vllm serve Qwen/Qwen-Image --omni --port 8091 --ring 2
-```
-
-##### Benchmarks
-!!! note "Benchmark Disclaimer"
- These benchmarks are provided for **general reference only**. The configurations shown use default or common parameter settings and have not been exhaustively optimized for maximum performance. Actual performance may vary based on:
-
- - Specific model and use case
- - Hardware configuration
- - Careful parameter tuning
- - Different inference settings (e.g., number of steps, image resolution)
-
-
-To measure the parallelism methods, we run benchmarks with **Qwen/Qwen-Image** model generating images (**1024x1024** as long sequence input) with 50 inference steps. The hardware devices are NVIDIA A100 GPUs. `flash_attn` is the attention backends.
-
-| Configuration | Ring degree |Generation Time | Speedup |
-|---------------|----------------|---------|---------|
-| **Baseline (diffusers)** | - | 45.2s | 1.0x |
-| Ring-Attention | 2 | 29.9s | 1.51x |
-| Ring-Attention | 4 | 23.3s | 1.94x |
-
-
-#### Hybrid Ulysses + Ring
-
-You can combine both Ulysses-SP and Ring-Attention for larger scale parallelism. The total sequence parallel size equals `ulysses_degree × ring_degree`.
-
-!!! note "Experimental UAA in hybrid mode"
- `ulysses_mode="advanced_uaa"` can also be used with hybrid Ulysses + Ring, but this does not remove Ring's shape requirement. Every rank in the same ring group must still have the same post-Ulysses sequence length.
-
-##### Offline Inference
-
-```python
-from vllm_omni import Omni
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-from vllm_omni.diffusion.data import DiffusionParallelConfig
-
-# Hybrid: 2 Ulysses × 2 Ring = 4 GPUs total
-omni = Omni(
- model="Qwen/Qwen-Image",
- parallel_config=DiffusionParallelConfig(
- ulysses_degree=2,
- ring_degree=2,
- )
-)
-
-outputs = omni.generate(
- "A cat sitting on a windowsill",
- OmniDiffusionSamplingParams(num_inference_steps=50, width=2048, height=2048),
-)
-```
-
-##### Online Serving
-
-```bash
-# Text-to-image (requires >= 4 GPUs)
-vllm serve Qwen/Qwen-Image --omni --port 8091 --usp 2 --ring 2
-```
-
-##### Benchmarks
-!!! note "Benchmark Disclaimer"
- These benchmarks are provided for **general reference only**. The configurations shown use default or common parameter settings and have not been exhaustively optimized for maximum performance. Actual performance may vary based on:
-
- - Specific model and use case
- - Hardware configuration
- - Careful parameter tuning
- - Different inference settings (e.g., number of steps, image resolution)
-
-
-To measure the parallelism methods, we run benchmarks with **Qwen/Qwen-Image** model generating images (**1024x1024** as long sequence input) with 50 inference steps. The hardware devices are NVIDIA A100 GPUs. `flash_attn` is the attention backends.
-
-| Configuration | Ulysses degree | Ring degree | Generation Time | Speedup |
-|---------------|----------------|-------------|-----------------|---------|
-| **Baseline (diffusers)** | - | - | 45.2s | 1.0x |
-| Hybrid Ulysses + Ring | 2 | 2 | 24.3s | 1.87x |
-
-
-### CFG-Parallel
-
-#### Offline Inference
-
-CFG-Parallel is enabled through `DiffusionParallelConfig(cfg_parallel_size=2)`, which runs one rank for the positive branch and one rank for the negative branch.
-
-An example of offline inference using CFG-Parallel (image-to-image) is shown below:
-
-```python
-from vllm_omni import Omni
-from vllm_omni.diffusion.data import DiffusionParallelConfig
-
-image_path = "path_to_image.png"
-omni = Omni(
- model="Qwen/Qwen-Image-Edit",
- parallel_config=DiffusionParallelConfig(cfg_parallel_size=2),
-)
-input_image = Image.open(image_path).convert("RGB")
-
-outputs = omni.generate(
- {
- "prompt": "turn this cat to a dog",
- "negative_prompt": "low quality, blurry",
- "multi_modal_data": {"image": input_image},
- },
- OmniDiffusionSamplingParams(
- true_cfg_scale=4.0,
- num_inference_steps=50,
- ),
-)
-```
-
-Notes:
-
-- CFG-Parallel is only effective when a `negative_prompt` is provided AND a guidance scale (or `cfg_scale`) is greater than 1.
-
-See `examples/offline_inference/image_to_image/image_edit.py` for a complete working example.
-```bash
-cd examples/offline_inference/image_to_image/
-python image_edit.py \
- --model "Qwen/Qwen-Image-Edit" \
- --image "qwen_image_output.png" \
- --prompt "turn this cat to a dog" \
- --negative-prompt "low quality, blurry" \
- --cfg-scale 4.0 \
- --output "edited_image.png" \
- --cfg-parallel-size 2
-```
-
-#### Online Serving
-
-You can enable CFG-Parallel in online serving for diffusion models via `--cfg-parallel-size`:
-
-```bash
-vllm serve Qwen/Qwen-Image-Edit --omni --port 8091 --cfg-parallel-size 2
-```
-
-### HSDP
-
-HSDP (Hybrid Sharded Data Parallel) shards model weights across GPUs to reduce per-GPU memory usage. This enables inference of large models (e.g., Wan2.2 14B) on GPUs with limited memory.
-
-Unlike Tensor Parallelism which splits computation, HSDP uses PyTorch's FSDP2 to shard and redistribute weights at runtime. Each GPU only holds a fraction of the model weights, and weights are gathered on-demand during forward passes.
-
-#### Configuration
-
-HSDP is configured via `DiffusionParallelConfig`:
-
-| Parameter | Type | Default | Description |
-|-----------|------|---------|-------------|
-| `use_hsdp` | bool | False | Enable HSDP |
-| `hsdp_shard_size` | int | -1 | Number of GPUs to shard weights across. -1 = auto (requires other parallelism > 1) |
-| `hsdp_replicate_size` | int | 1 | Number of replica groups. Each group holds a full sharded copy |
-
-**Constraints:**
-
-- `hsdp_replicate_size × hsdp_shard_size == world_size`
-- HSDP cannot be used with Tensor Parallelism (`tensor_parallel_size` must be 1)
-
-#### Operating Modes
-
-HSDP can work in two modes:
-
-- **Standalone Mode**: HSDP alone without other parallelism. Must specify `hsdp_shard_size` explicitly.
-- **Combined Mode**: HSDP overlays on top of other parallelism (Ulysses Sequence Parallel, CFG Parallel). HSDP dimensions must match world_size.
-
-#### Offline Inference
-
-**Standalone HSDP** (shard across 4 GPUs, no other parallelism):
-
-```python
-from vllm_omni import Omni
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-from vllm_omni.diffusion.data import DiffusionParallelConfig
-
-omni = Omni(
- model="Wan-AI/Wan2.2-T2V-A14B-Diffusers",
- parallel_config=DiffusionParallelConfig(
- use_hsdp=True,
- hsdp_shard_size=4, # Shard across 4 GPUs
- ),
-)
-
-outputs = omni.generate(
- "A cat playing piano",
- OmniDiffusionSamplingParams(num_inference_steps=50),
-)
-```
-
-**Combined HSDP + Sequence Parallel**:
-
-```python
-omni = Omni(
- model="Wan-AI/Wan2.2-T2V-A14B-Diffusers",
- parallel_config=DiffusionParallelConfig(
- ulysses_degree=4, # Sequence parallel
- use_hsdp=True, # HSDP overlays on SP
- ),
-)
-```
-
-#### Online Serving
-
-**Standalone HSDP** (shard model across 4 GPUs):
-
-```bash
-vllm serve Wan-AI/Wan2.2-T2V-A14B-Diffusers --omni --port 8091 --use-hsdp --hsdp-shard-size 4
-```
-
-**Combined with Sequence Parallel**:
-
-```bash
-vllm serve Wan-AI/Wan2.2-T2V-A14B-Diffusers --omni --port 8091 --use-hsdp --usp 4
-```
-
-#### Adding HSDP Support to New Models
-
-For detailed instructions on adding HSDP support to new models, see the [HSDP Contributing Guide](../../design/feature/hsdp.md).
-
-### Expert Parallelism
-
-Unlike Tensor Parallelism which shards every layer's weights, EP only shards the MoE expert MLP blocks. This significantly reduces the memory footprint of MoE models (e.g., HunyuanImage3.0) while maintaining constant dense-equivalent compute efficiency. Expert Parallelism is enabled via `DiffusionParallelConfig.enable_expert_parallel`. And `self.ep = tp * sp * cfg * dp` for now, so at least one of TP/SP/CFG/DP should set when EP enabled.
-
-#### Offline Inference
-
-```python
-from vllm_omni import Omni
-from vllm_omni.diffusion.data import DiffusionParallelConfig
-
-omni = Omni(
- model="tencent/HunyuanImage-3.0",
- parallel_config=DiffusionParallelConfig(tensor_parallel_size=8, enable_expert_parallel=True),
-)
-
-outputs = omni.generate(
- "A brown and white dog is running on the grass",
- OmniDiffusionSamplingParams(
- num_inference_steps=50,
- width=1024,
- height=1024,
- ),
-)
-```
From 980b87c22a680bf62ad7a7541a826f012c5f9886 Mon Sep 17 00:00:00 2001
From: Chen Yang <2082464740@qq.com>
Date: Tue, 31 Mar 2026 10:31:53 +0800
Subject: [PATCH 6/8] fix
Signed-off-by: Chen Yang <2082464740@qq.com>
---
.../diffusion/cache_acceleration/cache_dit.md | 285 ++++++++++++++++++
.../diffusion/cache_acceleration/teacache.md | 194 ++++++++++++
.../diffusion/cache_dit_acceleration.md | 228 --------------
.../diffusion/parallelism/cfg_parallel.md | 169 +++++++++++
.../diffusion/parallelism/expert_parallel.md | 87 ++++++
docs/user_guide/diffusion/parallelism/hsdp.md | 149 +++++++++
.../diffusion/parallelism/overview.md | 16 +
.../parallelism/sequence_parallel.md | 233 ++++++++++++++
.../diffusion/parallelism/tensor_parallel.md | 151 ++++++++++
.../parallelism/vae_patch_parallel.md | 200 ++++++++++++
docs/user_guide/diffusion/quantization/fp8.md | 1 +
docs/user_guide/diffusion/step_execution.md | 2 +-
docs/user_guide/diffusion/teacache.md | 145 ---------
13 files changed, 1486 insertions(+), 374 deletions(-)
create mode 100644 docs/user_guide/diffusion/cache_acceleration/cache_dit.md
create mode 100644 docs/user_guide/diffusion/cache_acceleration/teacache.md
delete mode 100644 docs/user_guide/diffusion/cache_dit_acceleration.md
create mode 100644 docs/user_guide/diffusion/parallelism/cfg_parallel.md
create mode 100644 docs/user_guide/diffusion/parallelism/expert_parallel.md
create mode 100644 docs/user_guide/diffusion/parallelism/hsdp.md
create mode 100644 docs/user_guide/diffusion/parallelism/overview.md
create mode 100644 docs/user_guide/diffusion/parallelism/sequence_parallel.md
create mode 100644 docs/user_guide/diffusion/parallelism/tensor_parallel.md
create mode 100644 docs/user_guide/diffusion/parallelism/vae_patch_parallel.md
delete mode 100644 docs/user_guide/diffusion/teacache.md
diff --git a/docs/user_guide/diffusion/cache_acceleration/cache_dit.md b/docs/user_guide/diffusion/cache_acceleration/cache_dit.md
new file mode 100644
index 00000000000..dec52b9d6b6
--- /dev/null
+++ b/docs/user_guide/diffusion/cache_acceleration/cache_dit.md
@@ -0,0 +1,285 @@
+# Cache-DiT Guide
+
+
+## Table of Content
+
+- [Overview](#overview)
+- [Quick Start](#quick-start)
+- [Example Script](#example-script)
+- [Acceleration Methods](#acceleration-methods)
+- [Configuration Parameters](#configuration-parameters)
+- [Best Practices](#best-practices)
+- [Troubleshooting](#troubleshooting)
+- [Summary](#summary)
+- [Additional Resources](#additional-resources)
+
+---
+
+## Overview
+
+Cache-DiT accelerates diffusion transformer models through intelligent caching mechanisms, providing significant speedup with minimal quality loss. It supports multiple acceleration techniques that can be combined for optimal performance:
+
+- **DBCache**: Dual Block Cache for reducing redundant computations
+- **TaylorSeer**: Taylor expansion-based forecasting for faster inference
+- **SCM**: Step Computation Masking for selective step computation
+
+See supported models list in [Supported Models](../../diffusion_features.md#supported-models).
+
+---
+
+## Quick Start
+
+### Basic Usage
+
+Enable cache-dit acceleration by simply setting `cache_backend="cache_dit"`:
+
+```python
+from vllm_omni import Omni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+
+omni = Omni(
+ model="Qwen/Qwen-Image",
+ cache_backend="cache_dit", # Enable Cache-DiT with defaults
+)
+
+outputs = omni.generate(
+ "a beautiful landscape",
+ OmniDiffusionSamplingParams(num_inference_steps=50),
+)
+```
+
+**Note**: When `cache_config` is not provided, Cache-DiT uses optimized default values. See the [Configuration Parameters](#configuration-parameters) section for details.
+
+### Custom Configuration
+
+To customize cache-dit settings, provide a `cache_config` dictionary, for example:
+
+```python
+omni = Omni(
+ model="Qwen/Qwen-Image",
+ cache_backend="cache_dit",
+ cache_config={
+ "Fn_compute_blocks": 1,
+ "Bn_compute_blocks": 0,
+ "max_warmup_steps": 4,
+ "residual_diff_threshold": 0.12,
+ },
+)
+```
+
+---
+
+## Example Script
+
+### Offline Inference
+
+Use the example script under `examples/offline_inference/text_to_image`:
+
+```bash
+cd examples/offline_inference/text_to_image
+python text_to_image.py \
+ --model Qwen/Qwen-Image \
+ --prompt "a cup of coffee on the table" \
+ --cache-backend cache_dit \
+ --num-inference-steps 50
+```
+
+See the [text_to_image.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/text_to_image/text_to_image.py) for detailed configuration options.
+
+The script uses cache-dit acceleration with a hybrid configuration combining DBCache, SCM, and TaylorSeer:
+
+```python
+omni = Omni(
+ model="Qwen/Qwen-Image",
+ cache_backend="cache_dit",
+ cache_config={
+ # Scheme: Hybrid DBCache + SCM + TaylorSeer
+ "Fn_compute_blocks": 1, # Optimized for single-transformer models
+ "Bn_compute_blocks": 0, # Number of backward compute blocks
+ "max_warmup_steps": 4, # Maximum warmup steps (works for few-step models)
+ "residual_diff_threshold": 0.24, # Higher threshold for more aggressive caching
+ "max_continuous_cached_steps": 3, # Limit to prevent precision degradation
+ # TaylorSeer parameters [cache-dit only]
+ "enable_taylorseer": False, # Disabled by default (not suitable for few-step models)
+ "taylorseer_order": 1, # TaylorSeer polynomial order
+ # SCM (Step Computation Masking) parameters [cache-dit only]
+ "scm_steps_mask_policy": None, # SCM mask policy: None (disabled), "slow", "medium", "fast", "ultra"
+ "scm_steps_policy": "dynamic", # SCM steps policy: "dynamic" or "static"
+ }
+)
+```
+
+You can customize the configuration by modifying the `cache_config` dictionary to use only specific methods (e.g., DBCache only, DBCache + SCM, etc.) based on your quality and speed requirements.
+
+For image-to-image tasks, use the example script under `examples/offline_inference/image_to_image`:
+
+```bash
+cd examples/offline_inference/image_to_image
+python image_edit.py \
+ --model Qwen/Qwen-Image-Edit \
+ --prompt "make the sky more colorful" \
+ --image path/to/input/image.jpg \
+ --cache-backend cache_dit \
+ --num-inference-steps 50 \
+ --cache-dit-max-continuous-cached-steps 3 \
+ --cache-dit-residual-diff-threshold 0.24 \
+ --cache-dit-enable-taylorseer
+```
+
+See the [image_edit.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/image_to_image/image_edit.py) for detailed configuration options.
+
+### Online Serving
+
+```bash
+# Default configuration (recommended)
+vllm serve Qwen/Qwen-Image --omni --port 8091 --cache-backend cache_dit
+
+# Custom configuration
+vllm serve Qwen/Qwen-Image --omni --port 8091 \
+ --cache-backend cache_dit \
+ --cache-config '{"Fn_compute_blocks": 1, "residual_diff_threshold": 0.12}'
+```
+
+---
+
+## Acceleration Methods
+
+For comprehensive illustration, please view Cache-DiT [User Guide](https://cache-dit.readthedocs.io/en/latest/user_guide/OVERVIEWS/).
+
+### 1. DBCache (Dual Block Cache)
+
+DBCache intelligently caches intermediate transformer block outputs when the residual differences between consecutive steps are small, reducing redundant computations without sacrificing quality.
+
+**Example Configuration**:
+
+```python
+cache_config={
+ "Fn_compute_blocks": 8, # Use first 8 blocks for difference computation
+ "Bn_compute_blocks": 0, # No additional fusion blocks
+ "max_warmup_steps": 8, # Cache after 8 warmup steps
+ "residual_diff_threshold": 0.12, # Lower threshold for faster inference
+ "max_cached_steps": -1, # No limit on cached steps
+}
+```
+
+**Performance Tips**:
+
+- Default `Fn_compute_blocks=1` works well for most cases. Some models (e.g., [FLUX.2-klein](https://github.com/wtomin/vllm-omni/blob/main/vllm_omni/diffusion/cache/cache_dit_backend.py#L363)) use a larger value for `Fn_compute_blocks` for a balanced performance.
+- Increase `residual_diff_threshold` (e.g., 0.12-0.15) for faster inference with slight quality trade-off, or decrease from default 0.24 for higher quality.
+- Default `max_warmup_steps=4` is optimized for few-step models. Increase to 6-8 for more steps if needed.
+
+### 2. TaylorSeer
+
+TaylorSeer uses Taylor expansion to forecast future hidden states, allowing the model to skip some computation steps while maintaining quality.
+
+**Example Configuration**:
+
+```python
+cache_config={
+ "enable_taylorseer": True,
+ "taylorseer_order": 1, # First-order Taylor expansion
+}
+```
+
+**Performance Tips**:
+
+- TaylorSeer is **not suitable for few-step distilled models**.
+- Use `taylorseer_order=1` for most cases (good balance of speed and quality).
+- Combine with DBCache for maximum acceleration.
+- Higher orders (2-3) may improve quality but reduce speed gains.
+
+### 3. SCM (Step Computation Masking)
+
+SCM allows you to specify which steps must be computed and which can use cached results, similar to LeMiCa/EasyCache style acceleration.
+
+`scm_steps_mask_policy` options (number of compute steps out of 28):
+
+| Policy | Compute Steps | Speed | Quality |
+|--------|--------------|-------|---------|
+| `None` (default) | All | Baseline | Best |
+| `"slow"` | 18 / 28 | Moderate | High |
+| `"medium"` | 15 / 28 | Balanced | Good |
+| `"fast"` | 11 / 28 | Fast | Moderate |
+| `"ultra"` | 8 / 28 | Fastest | Lower |
+
+**Example Configuration**:
+
+```python
+cache_config={
+ "scm_steps_mask_policy": "medium", # Balanced speed/quality
+ "scm_steps_policy": "dynamic", # Use dynamic cache
+}
+```
+
+**Performance Tips**:
+
+- SCM is disabled by default. Enable it by setting a policy value if you need additional acceleration.
+- Start with `"medium"` policy and adjust based on quality requirements.
+- Use `"fast"` or `"ultra"` for maximum speed when quality can be slightly compromised.
+- `"dynamic"` policy generally provides better quality than `"static"`.
+- SCM mask is automatically regenerated when `num_inference_steps` changes during inference.
+
+---
+
+## Configuration Parameters
+
+In `cache_config` passed to `Omni` constructor, it accepts the arguments of `DBCacheConfig` ([Cache-DiT API Reference](https://cache-dit.readthedocs.io/en/latest/user_guide/CACHE_API/)). Key parameters are listed below:
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `Fn_compute_blocks` | int | 1 | First n blocks for difference computation (optimized for single-transformer models) |
+| `Bn_compute_blocks` | int | 0 | Last n blocks for fusion |
+| `max_warmup_steps` | int | 4 | Steps before caching starts (optimized for few-step distilled models) |
+| `max_cached_steps` | int | -1 | Max cached steps (-1 = unlimited) |
+| `max_continuous_cached_steps` | int | 3 | Max consecutive cached steps (prevents precision degradation) |
+| `residual_diff_threshold` | float | 0.24 | Residual difference threshold (higher for more aggressive caching) |
+| `num_inference_steps` | int \| None | None | Initial inference steps for SCM mask generation (optional, auto-refreshed during inference) |
+| `enable_taylorseer` | bool | False | Enable TaylorSeer acceleration (not suitable for few-step distilled models) |
+| `taylorseer_order` | int | 1 | Taylor expansion order |
+| `scm_steps_mask_policy` | str \| None | None | SCM mask policy (None, "slow", "medium", "fast", "ultra") |
+| `scm_steps_policy` | str | "dynamic" | SCM computation policy ("dynamic" or "static") |
+
+---
+
+## Best Practices
+
+### When to Use
+
+**Good for:**
+
+- Production deployments requiring fast inference
+- Diffusion transformer models (DiT architecture)
+- Scenarios where 1.5x-3x speedup is valuable
+
+**Not for:**
+
+- Non-DiT architectures (use model-specific acceleration instead)
+- Models already using few-step distillation (< 10 steps)
+
+---
+
+## Troubleshooting
+
+### Common Issue 1: Quality Degradation
+
+**Symptoms**: Generated images have visible artifacts or lower quality
+
+**Solution**:
+```python
+# Reduce aggressiveness - use more conservative settings
+cache_config={
+ "residual_diff_threshold": 0.20, # Lower threshold (closer to default 0.24)
+ "Fn_compute_blocks": 8, # Use more blocks for better decisions
+ "max_warmup_steps": 6, # Longer warmup
+ "scm_steps_mask_policy": "slow", # More compute steps
+}
+```
+
+---
+
+## Summary
+
+Using Cache-DiT acceleration:
+
+1. ✅ **Enable Cache-DiT** - Set `cache_backend="cache_dit"` to get 1.5x-3x speedup with optimized defaults
+2. ✅ **(Optional) Customize** - Adjust `cache_config` parameters for specific speed/quality trade-offs
diff --git a/docs/user_guide/diffusion/cache_acceleration/teacache.md b/docs/user_guide/diffusion/cache_acceleration/teacache.md
new file mode 100644
index 00000000000..026b86ec7f9
--- /dev/null
+++ b/docs/user_guide/diffusion/cache_acceleration/teacache.md
@@ -0,0 +1,194 @@
+# TeaCache Guide
+
+
+## Table of Content
+
+- [Overview](#overview)
+- [Quick Start](#quick-start)
+- [Example Script](#example-script)
+- [Configuration Parameters](#configuration-parameters)
+- [Best Practices](#best-practices)
+- [Troubleshooting](#troubleshooting)
+- [Summary](#summary)
+
+---
+
+## Overview
+
+TeaCache accelerates diffusion model inference by caching transformer computations when consecutive timesteps are similar, providing **1.5x-2.0x speedup** with minimal quality loss. It dynamically decides whether to reuse cached outputs based on input similarity, making it ideal for production deployments where inference speed matters without sacrificing generation quality.
+
+See supported models list in [Supported Models](../../diffusion_features.md#supported-models).
+
+---
+
+## Quick Start
+
+
+
+### Basic Usage
+
+
+```python
+from vllm_omni import Omni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+
+omni = Omni(
+ model="Qwen/Qwen-Image",
+ cache_backend="tea_cache",
+)
+
+outputs = omni.generate(
+ "A cat sitting on a windowsill",
+ OmniDiffusionSamplingParams(num_inference_steps=50),
+)
+```
+
+### Custom Configuration
+
+```python
+omni = Omni(
+ model="Qwen/Qwen-Image",
+ cache_backend="tea_cache",
+ cache_config={
+ "rel_l1_thresh": 0.2, # Controls speed/quality tradeoff
+ },
+)
+```
+
+### Using Environment Variable
+
+You can also enable TeaCache via environment variable:
+
+```bash
+export DIFFUSION_CACHE_BACKEND=tea_cache
+```
+
+Then initialize without explicitly setting `cache_backend`:
+
+```python
+from vllm_omni import Omni
+
+omni = Omni(
+ model="Qwen/Qwen-Image",
+ cache_config={"rel_l1_thresh": 0.2}
+)
+```
+
+---
+
+## Example Script
+
+### Offline Inference
+
+Use python script under `examples/offline_inference/text_to_image/` or `examples/offline_inference/image_to_image/` with CLI:
+
+```bash
+# Text-to-image example
+python examples/offline_inference/text_to_image/text_to_image.py \
+ --model Qwen/Qwen-Image \
+ --cache-backend tea_cache
+
+# Image-to-image example
+python examples/offline_inference/image_to_image/image_edit.py \
+ --model Qwen/Qwen-Image-Edit \
+ --image input.png \
+ --prompt "Edit description" \
+ --cache-backend tea_cache \
+ --tea-cache-rel-l1-thresh 0.25
+```
+
+See the [text_to_image.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/text_to_image/text_to_image.py) or [image_edit.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/image_to_image/image_edit.py) for detailed configuration options.
+
+### Online Serving
+
+```bash
+# Default configuration
+vllm serve Qwen/Qwen-Image --omni --port 8091 --cache-backend tea_cache
+
+# Custom configuration
+vllm serve Qwen/Qwen-Image --omni --port 8091 \
+ --cache-backend tea_cache \
+ --cache-config '{"rel_l1_thresh": 0.2}'
+```
+
+---
+
+## Configuration Parameters
+
+In `OmniDiffusionConfig`
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `rel_l1_thresh` | float | `0.2` | Similarity threshold for cache reuse. Lower values prioritize quality (less caching), higher values prioritize speed (more caching). Suggested range: 0.1-0.8 |
+| `coefficients` | list[float] \| None | `None` | Polynomial coefficients for rescaling L1 distance. Must contain exactly 5 elements if provided. If `None`, uses model-specific defaults based on transformer type. |
+
+Users can find the default model coefficients in [`vllm_omni/diffusion/cache/teacache/config.py`](https://github.com/vllm-project/vllm-omni/blob/main/vllm_omni/diffusion/cache/teacache/config.py), for example:
+
+```python
+_MODEL_COEFFICIENTS = {
+ # Qwen-Image transformer coefficients from ComfyUI-TeaCache
+ # Tuned specifically for Qwen's dual-stream transformer architecture
+ # Used for all Qwen-Image Family pipelines, in general
+ "QwenImageTransformer2DModel": [
+ -4.50000000e02,
+ 2.80000000e02,
+ -4.50000000e01,
+ 3.20000000e00,
+ -2.00000000e-02,
+ ],
+ ...
+}
+```
+
+---
+
+## Best Practices
+
+### When to Use
+
+**Good for:**
+
+- Production deployments requiring faster inference, tolerant of minimal quality loss
+- Scenarios where 1.5-2x speedup is valuable
+- Useful for single-card acceleration
+
+**Not for:**
+
+- Maximum quality requirements where no degradation is acceptable
+- Very short inference runs (< 20 steps) where caching overhead may outweigh benefits
+
+
+---
+
+## Troubleshooting
+
+### Common Issue 1: Quality Degradation
+
+**Symptoms**: Generated images show artifacts, reduced detail, or inconsistent quality compared to non-cached results
+
+**Solution**:
+
+```python
+# Lower the threshold for more conservative caching
+cache_config={"rel_l1_thresh": 0.1}
+```
+
+### Common Issue 2: Limited Speedup
+
+**Symptoms**: Actual speedup is less than expected (< 1.3x)
+
+**Solutions**:
+1. Increase the threshold to enable more aggressive caching:
+ ```python
+ cache_config={"rel_l1_thresh": 0.8}
+ ```
+2. Ensure you're using sufficient inference steps (35+ recommended)
+3. Check that your model architecture is supported (see Supported Models section)
+
+---
+
+
+## Summary
+
+1. ✅ **Enable TeaCache** - Set `cache_backend="tea_cache"` to get 1.5x-2.0x speedup with optimized defaults
+2. ✅ **(Optional) Customize** - Adjust thresholds and polynomial coefficients for specific speed/quality trade-offs
diff --git a/docs/user_guide/diffusion/cache_dit_acceleration.md b/docs/user_guide/diffusion/cache_dit_acceleration.md
deleted file mode 100644
index c51ecca1e1c..00000000000
--- a/docs/user_guide/diffusion/cache_dit_acceleration.md
+++ /dev/null
@@ -1,228 +0,0 @@
-# Cache-DiT Acceleration Guide
-
-This guide explains how to use cache-dit acceleration in vLLM-Omni to speed up diffusion model inference.
-
-## Overview
-
-Cache-dit is a library that accelerates diffusion transformer models through intelligent caching mechanisms. It supports multiple acceleration techniques that can be combined for optimal performance:
-
-- **DBCache**: Dual Block Cache for reducing redundant computations
-- **TaylorSeer**: Taylor expansion-based forecasting for faster inference
-- **SCM**: Step Computation Masking for selective step computation
-
-## Quick Start
-
-### Basic Usage
-
-Enable cache-dit acceleration by simply setting `cache_backend="cache_dit"`. Cache-dit will use its recommended default parameters:
-
-```python
-from vllm_omni.entrypoints.omni import Omni
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
-# Simplest way: just enable cache-dit with default parameters
-omni = Omni(
- model="Qwen/Qwen-Image",
- cache_backend="cache_dit",
-)
-
-images = omni.generate(
- "a beautiful landscape",
- OmniDiffusionSamplingParams(num_inference_steps=50),
-)
-```
-
-**Default Parameters**: When `cache_config` is not provided, cache-dit uses optimized default values. See the [Configuration Reference](#configuration-reference) section for a complete list of all parameters and their default values.
-
-### Custom Configuration
-
-To customize cache-dit settings, provide a `cache_config` dictionary, for example:
-
-```python
-omni = Omni(
- model="Qwen/Qwen-Image",
- cache_backend="cache_dit",
- cache_config={
- "Fn_compute_blocks": 1,
- "Bn_compute_blocks": 0,
- "max_warmup_steps": 4,
- "residual_diff_threshold": 0.12,
- },
-)
-```
-
-## Online Serving (OpenAI-Compatible)
-
-Enable Cache-DiT for online serving by passing `--cache-backend cache_dit` when starting the server:
-
-```bash
-# Use Cache-DiT default (recommended) parameters
-vllm serve Qwen/Qwen-Image --omni --port 8091 --cache-backend cache_dit
-```
-
-To customize Cache-DiT settings for online serving, pass a JSON string via `--cache-config`:
-
-```bash
-vllm serve Qwen/Qwen-Image --omni --port 8091 \
- --cache-backend cache_dit \
- --cache-config '{"Fn_compute_blocks": 1, "Bn_compute_blocks": 0, "max_warmup_steps": 4, "residual_diff_threshold": 0.12}'
-```
-
-## Acceleration Methods
-
-For comprehensive illustration, please view cache-dit [User_Guide](https://cache-dit.readthedocs.io/en/latest/user_guide/OVERVIEWS/)
-
-### 1. DBCache (Dual Block Cache)
-
-DBCache intelligently caches intermediate transformer block outputs when the residual differences between consecutive steps are small, reducing redundant computations without sacrificing quality.
-
-**Key Parameters**:
-
-- `Fn_compute_blocks` (int, default: 1): Number of **first n** transformer blocks used to compute stable feature differences. Higher values provide more accurate caching decisions but increase computation.
-- `Bn_compute_blocks` (int, default: 0): Number of **last n** transformer blocks used for additional fusion. These blocks act as an auto-scaler for approximate hidden states.
-- `max_warmup_steps` (int, default: 4): Number of initial steps where caching is disabled to ensure the model learns sufficient features before caching begins. Optimized for few-step distilled models.
-- `residual_diff_threshold` (float, default: 0.24): Threshold for residual difference. Higher values lead to faster performance but may reduce precision. Default uses a relatively higher threshold for more aggressive caching.
-- `max_cached_steps` (int, default: -1): Maximum number of cached steps. Set to -1 for unlimited caching.
-- `max_continuous_cached_steps` (int, default: 3): Maximum number of consecutive cached steps. Limits consecutive caching to prevent precision degradation.
-
-**Example Configuration**:
-
-```python
-cache_config={
- "Fn_compute_blocks": 8, # Use first 8 blocks for difference computation
- "Bn_compute_blocks": 0, # No additional fusion blocks
- "max_warmup_steps": 8, # Cache after 8 warmup steps
- "residual_diff_threshold": 0.12, # Higher threshold for faster inference
- "max_cached_steps": -1, # No limit on cached steps
-}
-```
-
-**Performance Tips**:
-
-- Default `Fn_compute_blocks=1` works well for most cases. Increase to 8-12 for larger models or when more accuracy is needed
-- Increase `residual_diff_threshold` (e.g., 0.12-0.15) for faster inference with slight quality trade-off, or decrease from default 0.24 for higher quality
-- Default `max_warmup_steps=4` is optimized for few-step models. Increase to 6-8 for more steps if needed
-
-### 2. TaylorSeer
-
-TaylorSeer uses Taylor expansion to forecast future hidden states, allowing the model to skip some computation steps while maintaining quality.
-
-**Key Parameters**:
-
-- `enable_taylorseer` (bool, default: False): Enable TaylorSeer acceleration
-- `taylorseer_order` (int, default: 1): Order of Taylor expansion. Higher orders provide better accuracy but require more computation.
-
-**Example Configuration**:
-
-```python
-cache_config={
- "enable_taylorseer": True,
- "taylorseer_order": 1, # First-order Taylor expansion
-}
-```
-
-**Performance Tips**:
-
-- Use `taylorseer_order=1` for most cases (good balance of speed and quality)
-- Combine with DBCache for maximum acceleration
-- Higher orders (2-3) may improve quality but reduce speed gains
-
-### 3. SCM (Step Computation Masking)
-
-SCM allows you to specify which steps must be computed and which can use cached results, similar to LeMiCa/EasyCache style acceleration.
-
-**Key Parameters**:
-
-- `scm_steps_mask_policy` (str | None, default: None): Predefined mask policy. Options:
- - `None`: SCM disabled (default)
- - `"slow"`: More compute steps, higher quality (18 compute steps out of 28)
- - `"medium"`: Balanced (15 compute steps out of 28)
- - `"fast"`: More cache steps, faster inference (11 compute steps out of 28)
- - `"ultra"`: Maximum speed (8 compute steps out of 28)
-- `scm_steps_policy` (str, default: "dynamic"): Policy for cached steps:
- - `"dynamic"`: Use dynamic cache for masked steps (recommended)
- - `"static"`: Use static cache for masked steps
-
-**Example Configuration**:
-
-```python
-cache_config={
- "scm_steps_mask_policy": "medium", # Balanced speed/quality
- "scm_steps_policy": "dynamic", # Use dynamic cache
-}
-```
-
-**Performance Tips**:
-
-- SCM is disabled by default (`scm_steps_mask_policy=None`). Enable it by setting a policy value if you need additional acceleration
-- Start with `"medium"` policy and adjust based on quality requirements
-- Use `"fast"` or `"ultra"` for maximum speed when quality can be slightly compromised
-- `"dynamic"` policy generally provides better quality than `"static"`
-- SCM mask is automatically regenerated when `num_inference_steps` changes during inference
-
-## Configuration Reference
-
-### DiffusionCacheConfig Parameters
-
-| Parameter | Type | Default | Description |
-|-----------|------|---------|-------------|
-| `Fn_compute_blocks` | int | 1 | First n blocks for difference computation (optimized for single-transformer models) |
-| `Bn_compute_blocks` | int | 0 | Last n blocks for fusion |
-| `max_warmup_steps` | int | 4 | Steps before caching starts (optimized for few-step distilled models) |
-| `max_cached_steps` | int | -1 | Max cached steps (-1 = unlimited) |
-| `max_continuous_cached_steps` | int | 3 | Max consecutive cached steps (prevents precision degradation) |
-| `residual_diff_threshold` | float | 0.24 | Residual difference threshold (higher for more aggressive caching) |
-| `num_inference_steps` | int \| None | None | Initial inference steps for SCM mask generation (optional, auto-refreshed during inference) |
-| `enable_taylorseer` | bool | False | Enable TaylorSeer acceleration (not suitable for few-step distilled models) |
-| `taylorseer_order` | int | 1 | Taylor expansion order |
-| `scm_steps_mask_policy` | str \| None | None | SCM mask policy (None, "slow", "medium", "fast", "ultra") |
-| `scm_steps_policy` | str | "dynamic" | SCM computation policy ("dynamic" or "static") |
-
-## Example: Accelerate Text-to-Image Generation with CacheDiT
-
-See `examples/offline_inference/text_to_image/text_to_image.py` for a complete working example with cache-dit acceleration.
-
-```bash
-# Enable cache-dit with hybrid acceleration
-cd examples/offline_inference/text_to_image
-python text_to_image.py \
- --model Qwen/Qwen-Image \
- --prompt "a cup of coffee on the table" \
- --cache-backend cache_dit \
- --num-inference-steps 50
-```
-
-
-The script uses cache-dit acceleration with a hybrid configuration combining DBCache, SCM, and TaylorSeer:
-
-```python
-omni = Omni(
- model="Qwen/Qwen-Image",
- cache_backend="cache_dit",
- cache_config={
- # Scheme: Hybrid DBCache + SCM + TaylorSeer
- # DBCache
- "Fn_compute_blocks": 8,
- "Bn_compute_blocks": 0,
- "max_warmup_steps": 4,
- "residual_diff_threshold": 0.12,
- # TaylorSeer
- "enable_taylorseer": True,
- "taylorseer_order": 1,
- # SCM
- "scm_steps_mask_policy": "fast", # Set to None to disable SCM
- "scm_steps_policy": "dynamic",
- },
-)
-```
-
-You can customize the configuration by modifying the `cache_config` dictionary to use only specific methods (e.g., DBCache only, DBCache + SCM, etc.) based on your quality and speed requirements.
-
-To test another model, you can modify `--model` with the target model identifier like `Tongyi-MAI/Z-Image-Turbo` and update `cache_config` according the model architecture (e.g., number of transformer blocks).
-
-
-## Additional Resources
-
-- [Cache-DiT User Guide](https://cache-dit.readthedocs.io/en/latest/user_guide/OVERVIEWS/)
-- [Cache-DiT Benchmark](https://cache-dit.readthedocs.io/en/latest/benchmark/HYBRID_CACHE/)
-- [DBCache Technical Details](https://cache-dit.readthedocs.io/en/latest/user_guide/CACHE_API/)
diff --git a/docs/user_guide/diffusion/parallelism/cfg_parallel.md b/docs/user_guide/diffusion/parallelism/cfg_parallel.md
new file mode 100644
index 00000000000..5541106680a
--- /dev/null
+++ b/docs/user_guide/diffusion/parallelism/cfg_parallel.md
@@ -0,0 +1,169 @@
+# CFG-Parallel Guide
+
+
+## Table of Content
+
+- [Overview](#overview)
+- [Quick Start](#quick-start)
+- [Example Script](#example-script)
+- [Configuration Parameters](#configuration-parameters)
+- [Best Practices](#best-practices)
+- [Troubleshooting](#troubleshooting)
+- [Summary](#summary)
+
+---
+
+## Overview
+
+CFG-Parallel accelerates diffusion models by distributing positive and negative classifier-free guidance (CFG) passes across different GPUs, providing ~1.8x speedup when CFG is enabled. It's ideal for image editing tasks that require guidance scales greater than 1.0.
+
+See supported models list in [Supported Models](../../diffusion_features.md#supported-models).
+
+---
+
+## Quick Start
+
+### Basic Usage
+
+Simplest working example:
+
+```python
+from vllm_omni import Omni
+from vllm_omni.diffusion.data import DiffusionParallelConfig
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+from PIL import Image
+
+omni = Omni(
+ model="Qwen/Qwen-Image-Edit",
+ parallel_config=DiffusionParallelConfig(cfg_parallel_size=2), # Enable CFG-Parallel
+)
+
+input_image = Image.open("input.png").convert("RGB")
+outputs = omni.generate(
+ {
+ "prompt": "turn this cat to a dog",
+ "negative_prompt": "low quality, blurry",
+ "multi_modal_data": {"image": input_image},
+ },
+ OmniDiffusionSamplingParams(
+ true_cfg_scale=4.0,
+ num_inference_steps=50,
+ ),
+)
+```
+
+---
+
+## Example Script
+
+### Offline Inference
+
+Use python script under `examples/offline_inference/image_to_image/image_edit.py`:
+
+```bash
+cd examples/offline_inference/image_to_image/
+python image_edit.py \
+ --model "Qwen/Qwen-Image-Edit" \
+ --image "input.png" \
+ --prompt "turn this cat to a dog" \
+ --negative-prompt "low quality, blurry" \
+ --cfg-scale 4.0 \
+ --output "edited_image.png" \
+ --cfg-parallel-size 2
+```
+
+### Online Serving
+
+Enable CFG-Parallel in online serving:
+
+```bash
+# Default configuration
+vllm serve Qwen/Qwen-Image-Edit --omni --port 8091 --cfg-parallel-size 2
+
+```
+
+---
+
+## Configuration Parameters
+
+In `DiffusionParallelConfig`
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `cfg_parallel_size` | int | 1 | Number of GPUs for CFG parallelism. Set to 2 to enable CFG-Parallel (rank 0 for positive, rank 1 for negative branch) |
+
+
+!!! info
+ Most models support `cfg_parallel_size=2` (positive branch on rank 0, negative branch on rank 1). **Bagel** is an exception: it supports `cfg_parallel_size=3`, which adds a third branch on rank 2 for full three-way CFG parallelism.
+
+
+---
+
+## Best Practices
+
+### When to Use
+
+**Good for:**
+
+- Tasks requiring classifier-free guidance
+- Multi-GPU setups (at least 2 GPUs available)
+- Combining with other parallelism methods (sequence/tensor parallel)
+
+**Not for:**
+
+- Single GPU setups
+- Models that don't support CFG-Parallel (check [supported models](../../diffusion_features.md#supported-models))
+- Workloads without negative prompts or classifier-free guidance
+- Very short inference runs (< 10 steps) where parallelism overhead may outweigh benefits
+
+### Expected Performance
+
+| Configuration | Speedup | Quality | Use Case |
+|--------------|---------|---------|----------|
+| CFG-Parallel (2 GPUs) | 1.5~1.8x | No degradation | Large model, VRAM limited |
+
+---
+
+## Troubleshooting
+
+### Common Issue 1: No Speedup with CFG-Parallel
+
+**Symptoms**: CFG-Parallel enabled but no performance improvement
+
+**Solutions**:
+
+1. **Ensure CFG scale is set correctly:**
+```python
+# Bad: No CFG effect
+sampling_params = OmniDiffusionSamplingParams(num_inference_steps=50)
+
+# Good: CFG-Parallel will work
+sampling_params = OmniDiffusionSamplingParams(
+ num_inference_steps=50,
+ true_cfg_scale=4.0 # Must be > 1.0
+)
+```
+
+2. **Add negative prompt:**
+```python
+outputs = omni.generate(
+ {
+ "prompt": "beautiful landscape",
+ "negative_prompt": "low quality, blurry", # Required for best results
+ "multi_modal_data": {"image": input_image}
+ },
+ sampling_params
+)
+```
+
+3. **Check model support:**
+ - Verify your model in [supported models](../../diffusion_features.md#supported-models)
+ - Some models don't support CFG-Parallel
+
+---
+
+## Summary
+
+1. ✅ **Enable CFG-Parallel** - Set `cfg_parallel_size=2` in `DiffusionParallelConfig` to get speedup when using CFG
+2. ✅ **Set CFG Scale** - Ensure `true_cfg_scale > 1.0` in `OmniDiffusionSamplingParams` for CFG-Parallel to take effect
+3. ✅ **Check Model Support** - Verify your model supports CFG-Parallel in [supported models](../../diffusion_features.md#supported-models)
diff --git a/docs/user_guide/diffusion/parallelism/expert_parallel.md b/docs/user_guide/diffusion/parallelism/expert_parallel.md
new file mode 100644
index 00000000000..7d26d1e5c4f
--- /dev/null
+++ b/docs/user_guide/diffusion/parallelism/expert_parallel.md
@@ -0,0 +1,87 @@
+# Expert Parallelism Guide
+
+
+## Table of Content
+
+- [Overview](#overview)
+- [Quick Start](#quick-start)
+- [Configuration Parameters](#configuration-parameters)
+- [Best Practices](#best-practices)
+- [Summary](#summary)
+
+---
+
+## Overview
+
+Unlike Tensor Parallelism which shards every layer's weights, Expert Parallelism (EP) only shards the MoE expert MLP blocks. This significantly reduces the memory footprint of MoE models (e.g., HunyuanImage3.0) while maintaining constant dense-equivalent compute efficiency.
+
+During the forward pass, a gating mechanism routes tokens to their designated experts, requiring all-to-all communication to dispatch tokens to the correct ranks and combine results.
+
+See supported models list in [Supported Models](../../diffusion_features.md#supported-models).
+
+!!! note "EP Size Constraint"
+ The effective EP size equals `tp × sp × cfg × dp`. At least one of TP/SP/CFG/DP must be set when EP is enabled.
+
+---
+
+## Quick Start
+
+### Basic Usage
+
+```python
+from vllm_omni import Omni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+from vllm_omni.diffusion.data import DiffusionParallelConfig
+
+omni = Omni(
+ model="tencent/HunyuanImage-3.0",
+ parallel_config=DiffusionParallelConfig(
+ tensor_parallel_size=8,
+ enable_expert_parallel=True,
+ ),
+)
+
+outputs = omni.generate(
+ "A brown and white dog is running on the grass",
+ OmniDiffusionSamplingParams(
+ num_inference_steps=50,
+ width=1024,
+ height=1024,
+ ),
+)
+```
+
+---
+
+## Configuration Parameters
+
+In `DiffusionParallelConfig`:
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `enable_expert_parallel` | bool | False | Enable Expert Parallelism for MoE models |
+
+EP size is derived automatically as `tp × sp × cfg × dp` — configure at least one of those to set the EP degree.
+
+---
+
+## Best Practices
+
+### When to Use
+
+**Good for:**
+
+- MoE models (e.g., HunyuanImage3.0) with numbers of experts
+- Memory-constrained multi-GPU setups where only expert blocks need sharding
+
+**Not for:**
+
+- Dense models (no MoE layers) — EP has no effect
+- Single GPU setups
+
+---
+
+## Summary
+
+1. ✅ **Enable EP** - Set `enable_expert_parallel=True` in `DiffusionParallelConfig` for MoE models
+2. ✅ **Set parallelism degree** - At least one of `tensor_parallel_size` / `ulysses_degree` / `cfg_parallel_size` must be > 1 to define the EP size
diff --git a/docs/user_guide/diffusion/parallelism/hsdp.md b/docs/user_guide/diffusion/parallelism/hsdp.md
new file mode 100644
index 00000000000..96a357c86b3
--- /dev/null
+++ b/docs/user_guide/diffusion/parallelism/hsdp.md
@@ -0,0 +1,149 @@
+# HSDP Guide
+
+
+## Table of Content
+
+- [Overview](#overview)
+- [Quick Start](#quick-start)
+- [Example Script](#example-script)
+- [Configuration Parameters](#configuration-parameters)
+- [Best Practices](#best-practices)
+- [Summary](#summary)
+
+---
+
+## Overview
+
+HSDP (Hybrid Sharded Data Parallel) shards model weights across GPUs to reduce per-GPU memory usage. This enables inference of large models (e.g., Wan2.2 14B) on GPUs with limited memory.
+
+Unlike Tensor Parallelism which splits computation, HSDP uses PyTorch's FSDP2 to shard and redistribute weights at runtime. Each GPU only holds a fraction of the model weights, and weights are gathered on-demand during forward passes.
+
+See supported models list in [Supported Models](../../diffusion_features.md#supported-models).
+
+**Operating Modes:**
+
+- **Standalone Mode**: HSDP alone without other parallelism. Must specify `hsdp_shard_size` explicitly.
+- **Combined Mode**: HSDP overlays on top of other parallelism (Ulysses-SP, CFG-Parallel). HSDP dimensions must match world_size.
+
+---
+
+## Quick Start
+
+### Basic Usage
+
+Simplest working example (standalone HSDP, shard across 4 GPUs):
+
+```python
+from vllm_omni import Omni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+from vllm_omni.diffusion.data import DiffusionParallelConfig
+
+omni = Omni(
+ model="Wan-AI/Wan2.2-T2V-A14B-Diffusers",
+ parallel_config=DiffusionParallelConfig(
+ use_hsdp=True,
+ hsdp_shard_size=4, # Shard across 4 GPUs
+ ),
+)
+
+outputs = omni.generate(
+ "A cat playing piano",
+ OmniDiffusionSamplingParams(num_inference_steps=50),
+)
+```
+
+### Combined with Sequence Parallel
+
+```python
+omni = Omni(
+ model="Wan-AI/Wan2.2-T2V-A14B-Diffusers",
+ parallel_config=DiffusionParallelConfig(
+ ulysses_degree=4, # Sequence parallel
+ use_hsdp=True, # HSDP overlays on SP
+ ),
+)
+```
+
+---
+
+## Example Script
+
+### Offline Inference
+
+Use Python script under `examples/offline_inference/image_to_video/`:
+
+```bash
+# Standalone HSDP: shard across 4 GPUs
+python examples/offline_inference/image_to_video/image_to_video.py \
+ --model Wan-AI/Wan2.2-T2V-A14B-Diffusers \
+ --use-hsdp \
+ --hsdp-shard-size 4
+
+# Combined HSDP + Sequence Parallel
+python examples/offline_inference/image_to_video/image_to_video.py \
+ --model Wan-AI/Wan2.2-T2V-A14B-Diffusers \
+ --ulysses-degree 4 \
+ --use-hsdp
+```
+
+### Online Serving
+
+**Standalone HSDP** (shard model across 4 GPUs):
+
+```bash
+vllm serve Wan-AI/Wan2.2-T2V-A14B-Diffusers --omni --port 8091 \
+ --use-hsdp --hsdp-shard-size 4
+```
+
+**Combined with Sequence Parallel**:
+
+```bash
+vllm serve Wan-AI/Wan2.2-T2V-A14B-Diffusers --omni --port 8091 \
+ --use-hsdp --usp 4
+```
+
+---
+
+## Configuration Parameters
+
+In `DiffusionParallelConfig`:
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `use_hsdp` | bool | False | Enable HSDP |
+| `hsdp_shard_size` | int | -1 | Number of GPUs to shard weights across. `-1` = auto (requires other parallelism > 1) |
+| `hsdp_replicate_size` | int | 1 | Number of replica groups. Each group holds a full sharded copy |
+
+**Constraints:**
+
+- `hsdp_replicate_size × hsdp_shard_size == world_size`
+- HSDP cannot be used with Tensor Parallelism (`tensor_parallel_size` must be 1)
+
+---
+
+## Best Practices
+
+### When to Use
+
+**Good for:**
+
+- Very large models (e.g., Wan2.2 14B)
+- Multi-GPU setups where memory reduction is the primary goal
+- Combining with Sequence Parallelism for large video models
+
+**Not for:**
+
+- Models that fit comfortably in single-GPU memory
+- Use cases requiring Tensor Parallelism (HSDP and TP are mutually exclusive)
+
+### Adding HSDP Support to New Models
+
+For detailed instructions on adding HSDP support to new models, see the [HSDP Contributing Guide](../../../design/feature/hsdp.md).
+
+---
+
+## Summary
+
+1. ✅ **Enable HSDP** - Set `use_hsdp=True` and `hsdp_shard_size` to reduce per-GPU memory for large models
+2. ✅ **Combine with SP** - Use together with `ulysses_degree` for video models requiring both memory reduction and sequence parallelism
+3. ⚠️ **Incompatible with TP** - `tensor_parallel_size` must be 1 when HSDP is enabled
diff --git a/docs/user_guide/diffusion/parallelism/overview.md b/docs/user_guide/diffusion/parallelism/overview.md
new file mode 100644
index 00000000000..90d0b9660ef
--- /dev/null
+++ b/docs/user_guide/diffusion/parallelism/overview.md
@@ -0,0 +1,16 @@
+# Parallelism Acceleration Guide
+
+This guide covers the parallelism methods in vLLM-Omni for speeding up diffusion model inference and reducing per-device memory requirements.
+
+## Supported Methods
+
+| Method | Description |
+|--------|-------------|
+| **[Tensor Parallelism](tensor_parallel.md)** | Shards DiT weights across GPUs to reduce per-GPU memory |
+| **[Sequence Parallelism](sequence_parallel.md)** | Splits sequence dimension across GPUs (Ulysses-SP, Ring-Attention, or hybrid) for high-resolution images and videos |
+| **[CFG-Parallel](cfg_parallel.md)** | Runs CFG positive/negative branches on separate GPUs for ~1.8x speedup on guided generation |
+| **[VAE Patch Parallelism](vae_patch_parallel.md)** | Distributes VAE decode spatially across GPUs to reduce peak VAE memory |
+| **[HSDP](hsdp.md)** | Shards full model weights via PyTorch FSDP2 to enable large-model inference on memory-constrained GPUs |
+| **[Expert Parallelism](expert_parallel.md)** | Shards MoE expert blocks across GPUs for MoE models (e.g. HunyuanImage3.0) |
+
+See [Supported Models](../../diffusion_features.md#supported-models) for per-model compatibility.
diff --git a/docs/user_guide/diffusion/parallelism/sequence_parallel.md b/docs/user_guide/diffusion/parallelism/sequence_parallel.md
new file mode 100644
index 00000000000..e69b541f2ed
--- /dev/null
+++ b/docs/user_guide/diffusion/parallelism/sequence_parallel.md
@@ -0,0 +1,233 @@
+# Sequence Parallelism Guide
+
+
+## Table of Content
+
+- [Overview](#overview)
+- [Quick Start](#quick-start)
+- [Example Script](#example-script)
+- [Best Practices](#best-practices)
+- [Troubleshooting](#troubleshooting)
+- [Summary](#summary)
+
+---
+
+## Overview
+
+Sequence parallelism splits the input along the sequence dimension across multiple GPUs, allowing each device to process only a portion of the sequence. vLLM-Omni provides 1.5x-3.6x speedup for large images and videos using DeepSpeed Ulysses, Ring-Attention, or hybrid approaches. Use sequence parallelism when generating high-resolution images/videos that don't fit on a single GPU or require faster inference.
+
+See supported models list in [Diffusion Features - Supported Models](../../diffusion_features.md#supported-models).
+
+**Supported Methods:**
+
+- **DeepSpeed Ulysses Sequence Parallel (Ulysses-SP)** ([paper](https://arxiv.org/pdf/2309.14509)): Uses all-to-all communication for subset of attention heads per device
+- **Ring-Attention** ([paper](https://arxiv.org/abs/2310.01889)): Uses ring-based P2P communication with sharded sequence dimension throughout
+- **Hybrid Ulysses + Ring**: Combines both for larger scale parallelism (`ulysses_degree × ring_degree`)
+
+---
+
+## Quick Start
+
+### Basic Usage - Ulysses-SP
+
+Simplest working example with Ulysses Sequence Parallel:
+
+```python
+from vllm_omni import Omni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+from vllm_omni.diffusion.data import DiffusionParallelConfig
+
+omni = Omni(
+ model="Qwen/Qwen-Image",
+ parallel_config=DiffusionParallelConfig(ulysses_degree=2) # Enable Ulysses-SP
+)
+
+outputs = omni.generate(
+ "A cat sitting on a windowsill",
+ OmniDiffusionSamplingParams(num_inference_steps=50, width=1024, height=1024),
+)
+```
+
+!!! note "Experimental UAA mode"
+ `ulysses_mode="advanced_uaa"` is an experimental extension to Ulysses-SP. It lets Ulysses attention handle arbitrary sequence lengths and arbitrary attention head counts without relying on `attention_mask`-based token padding.
+
+ In hybrid Ulysses + Ring mode, Ring still requires every rank in the same ring group to observe the same post-Ulysses sequence length. If that condition is not met, vLLM-Omni raises a validation error instead of entering the ring kernel with inconsistent shapes.
+
+To enable the experimental UAA mode, use a model/configuration that requires it. For example, `Tongyi-MAI/Z-Image-Turbo` has 30 attention heads, so `ulysses_degree=4` requires UAA because 30 is not divisible by 4:
+
+```python
+omni = Omni(
+ model="Tongyi-MAI/Z-Image-Turbo",
+ parallel_config=DiffusionParallelConfig(
+ ulysses_degree=4,
+ ulysses_mode="advanced_uaa",
+ ),
+)
+```
+
+### Alternative Methods
+
+**Ring-Attention** (better for very long sequences):
+
+```python
+omni = Omni(
+ model="Qwen/Qwen-Image",
+ parallel_config=DiffusionParallelConfig(ring_degree=2) # Enable Ring-Attention
+)
+```
+
+**Hybrid Ulysses + Ring** (for larger scale):
+
+```python
+omni = Omni(
+ model="Qwen/Qwen-Image",
+ parallel_config=DiffusionParallelConfig(ulysses_degree=2, ring_degree=2) # 4 GPUs total
+)
+```
+
+---
+
+## Example Script
+
+### Offline Inference
+
+Use Python script under `examples/offline_inference/text_to_image/text_to_image.py`:
+
+**Ulysses-SP:**
+
+```bash
+python examples/offline_inference/text_to_image/text_to_image.py \
+ --model Qwen/Qwen-Image \
+ --prompt "A cat sitting on a windowsill" \
+ --ulysses-degree 2 \
+ --width 1024 --height 1024
+```
+
+**Ring-Attention:**
+
+```bash
+python examples/offline_inference/text_to_image/text_to_image.py \
+ --model Qwen/Qwen-Image \
+ --prompt "A cat sitting on a windowsill" \
+ --ring-degree 2 \
+ --width 1024 --height 1024
+```
+
+**Hybrid Ulysses + Ring:**
+
+```bash
+# Hybrid: 2 Ulysses × 2 Ring = 4 GPUs total
+python examples/offline_inference/text_to_image/text_to_image.py \
+ --model Qwen/Qwen-Image \
+ --prompt "A cat sitting on a windowsill" \
+ --ulysses-degree 2 --ring-degree 2 \
+ --width 1024 --height 1024
+```
+
+### Online Serving
+
+**Ulysses-SP:**
+
+```bash
+# Text-to-image (requires >= 2 GPUs)
+vllm serve Qwen/Qwen-Image --omni --port 8091 --usp 2
+```
+
+**Ulysses-SP with UAA mode** (for models with non-divisible head counts):
+
+```bash
+vllm serve Tongyi-MAI/Z-Image-Turbo --omni --port 8091 --usp 4 --ulysses-mode advanced_uaa
+```
+
+**Ring-Attention:**
+
+```bash
+# Text-to-image (requires >= 2 GPUs)
+vllm serve Qwen/Qwen-Image --omni --port 8091 --ring 2
+```
+
+**Hybrid Ulysses + Ring:**
+
+```bash
+# Text-to-image (requires >= 4 GPUs)
+vllm serve Qwen/Qwen-Image --omni --port 8091 --usp 2 --ring 2
+```
+
+---
+
+## Configuration Parameters
+
+In `DiffusionParallelConfig`:
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `ulysses_degree` | int | 1 | Number of GPUs for Ulysses-SP. Uses all-to-all communication. |
+| `ring_degree` | int | 1 | Number of GPUs for Ring-Attention. Uses P2P ring communication. |
+| `ulysses_mode` | str | `"default"` | Ulysses attention mode. Set to `"advanced_uaa"` to handle arbitrary sequence lengths and head counts without padding. |
+
+**Notes:**
+- Total sequence parallel size equals to `ulysses_degree × ring_degree`
+- Degrees must evenly divide the sequence length for optimal performance (or use `ulysses_mode="advanced_uaa"` for Ulysses-SP)
+
+
+## Best Practices
+
+### When to Use
+
+**Good for:**
+
+- Large images (1024x1024 or higher) or videos
+- Fast inter-GPU communication, larger bandwidth (e.g., NVLink)
+
+**Not for:**
+
+- Small images (<1024px) - overhead exceeds benefit, use single GPU with cache instead
+
+
+---
+
+## Troubleshooting
+
+### Common Issue 1: Performance Not Scaling
+
+**Symptoms**: Adding GPUs doesn't improve speed proportionally, or higher parallelism degree is slower
+
+**Diagnosis:**
+```bash
+# Check GPU topology
+nvidia-smi topo -m
+
+```
+
+**Solutions:**
+
+1. Check inter-GPU communication - NVLink is better than PCIe
+2. Reduce parallelism degree if over-parallelized:
+```python
+# If 4 GPUs is slower than 2
+parallel_config=DiffusionParallelConfig(ulysses_degree=2)
+```
+3. Try to switch between Ring-Attention and Ulysses-SP
+
+- Ring-Attention has advantages, like communication-computation overlap, but the block-wise loop overhead is relatively higher, especially for short sequences
+- Ulysses-SP: can benefit from larger bandwidth (such as NVLink), with two major constraints, the sequence length should be divisible by usp size, and the number of heads should be divisible by usp size (or use `ulysses_mode="advanced_uaa"`)
+
+
+### Common Issue 2: Out of Memory (OOM)
+
+**Symptoms**: CUDA OOM errors or process crashes with memory errors
+
+**Solutions:**
+
+1. Increase parallelism degree to split sequence more:
+```python
+parallel_config=DiffusionParallelConfig(ulysses_degree=4) # From 2
+```
+2. Combine with other parallelism method, e.g., tensor parallel, and memory optimization methods, e.g., cpu offloading.
+
+
+## Summary
+
+1. ✅ **Enable Sequence Parallelism** - Set `ulysses_degree` or `ring_degree` for long sequence generation
+2. ✅ **UAA mode** - Use `ulysses_mode="advanced_uaa"` when head count is not divisible by `ulysses_degree`
+3. ✅ **Troubleshooting** - Check GPU topology with `nvidia-smi topo -m`, reduce degree if performance doesn't scale
diff --git a/docs/user_guide/diffusion/parallelism/tensor_parallel.md b/docs/user_guide/diffusion/parallelism/tensor_parallel.md
new file mode 100644
index 00000000000..8e6851412cf
--- /dev/null
+++ b/docs/user_guide/diffusion/parallelism/tensor_parallel.md
@@ -0,0 +1,151 @@
+# Tensor Parallelism Guide
+
+
+## Table of Content
+
+- [Overview](#overview)
+- [Quick Start](#quick-start)
+- [Example Script](#example-script)
+- [Configuration Parameters](#configuration-parameters)
+- [Best Practices](#best-practices)
+- [Troubleshooting](#troubleshooting)
+- [Summary](#summary)
+
+---
+
+## Overview
+
+Tensor Parallelism (TP) shards some model weights across multiple GPUs, usually the Linear layers. This enables running large models that don't fit on a single GPU. It's essential for memory-constrained setups or very large models.
+
+See supported models list in [Supported Models](../../diffusion_features.md#supported-models).
+
+!!! note "TP Limitations for Diffusion Models"
+ We currently implement Tensor Parallelism (TP) only for the DiT (Diffusion Transformer) blocks. This is because the `text_encoder` component in vLLM-Omni uses the original Transformers implementation, which does not yet support TP.
+
+ - Good news: The text_encoder typically has minimal impact on overall inference performance.
+ - Bad news: When TP is enabled, every TP process retains a full copy of the text_encoder weights, leading to significant GPU memory waste.
+
+ We are actively refactoring this design to address this. For details and progress, please refer to [Issue #771](https://github.com/vllm-project/vllm-omni/issues/771).
+
+---
+
+## Quick Start
+
+
+### Basic Usage
+
+Simplest working example:
+
+```python
+from vllm_omni import Omni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+from vllm_omni.diffusion.data import DiffusionParallelConfig
+
+omni = Omni(
+ model="Tongyi-MAI/Z-Image-Turbo",
+ parallel_config=DiffusionParallelConfig(tensor_parallel_size=2), # Enable TP
+)
+
+outputs = omni.generate(
+ "a cat reading a book",
+ OmniDiffusionSamplingParams(num_inference_steps=9),
+)
+```
+
+---
+
+## Example Script
+
+### Offline Inference
+
+Use Python script under `examples/offline_inference`, and enable TP:
+
+```bash
+# Text-to-Image with Qwen-Image
+python examples/offline_inference/text_to_image/text_to_image.py \
+ --model Qwen/Qwen-Image \
+ --tensor-parallel-size 2
+
+# Image Editing with Qwen-Image-Edit
+python examples/offline_inference/image_to_image/image_edit.py \
+ --model Qwen/Qwen-Image-Edit \
+ --image input.png \
+ --prompt "Edit description" \
+ --tensor-parallel-size 2
+```
+
+### Online Serving
+
+You can enable tensor parallelism in online serving via `--tensor-parallel-size`:
+
+```bash
+# Text-to-Image with Qwen-Image on 2 GPUs
+vllm serve Qwen/Qwen-Image --omni --port 8091 \
+ --tensor-parallel-size 2
+
+# Text-to-Image with Z-Image (TP=2 only)
+vllm serve Tongyi-MAI/Z-Image-Turbo --omni --port 8091 \
+ --tensor-parallel-size 2
+```
+
+---
+
+## Configuration Parameters
+
+In `DiffusionParallelConfig`:
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `tensor_parallel_size` | int | 1 | Number of GPUs to shard model weights across. Must divide number of heads. |
+
+
+---
+
+## Best Practices
+
+### When to Use
+
+**Good for:**
+
+- Large models that don't fit on a single GPU, especially for models with large DiT blocks (transformer layers)
+- Memory-constrained environments
+
+**Not for:**
+
+- When maximum throughput is needed and memory is sufficient
+- Models with incompatible dimensions (e.g., Z-Image `num_heads=30`, which now supports `tensor_parallel_size=2`)
+
+
+## Troubleshooting
+
+### Common Issue 1: Out of Memory (OOM)
+
+**Symptoms**: CUDA OOM errors during model loading or inference, process crashes with memory errors
+
+**Solution**:
+```python
+# Step 1: Enable TP with smallest degree
+parallel_config=DiffusionParallelConfig(tensor_parallel_size=2)
+
+# Step 2: If still OOM, increase TP degree
+parallel_config=DiffusionParallelConfig(tensor_parallel_size=4)
+
+```
+
+### Common Issue 2: Divisibility Error
+
+**Symptoms**: Error like "Model dimension X not divisible by tensor_parallel_size Y"
+
+**Solutions**:
+1. Check model-specific constraints (e.g., Z-Image only supports TP=2)
+2. Use a smaller TP size that divides model dimensions
+3. Consult [Supported Models](../../diffusion_features.md#supported-models) for compatible TP sizes
+
+
+---
+
+## Summary
+
+1. ✅ **Enable TP** - Set `--tensor-parallel-size` to reduce per-GPU memory
+2. ✅ **Increase TP size** - Only increase if OOM persists
+3. ⚠️ **Text encoder not sharded** - Known limitation
diff --git a/docs/user_guide/diffusion/parallelism/vae_patch_parallel.md b/docs/user_guide/diffusion/parallelism/vae_patch_parallel.md
new file mode 100644
index 00000000000..4e8513eabf4
--- /dev/null
+++ b/docs/user_guide/diffusion/parallelism/vae_patch_parallel.md
@@ -0,0 +1,200 @@
+# VAE Patch Parallelism Guide
+
+
+## Table of Content
+
+- [Overview](#overview)
+- [Quick Start](#quick-start)
+- [Example Script](#example-script)
+- [Configuration Parameters](#configuration-parameters)
+- [Best Practices](#best-practices)
+- [Troubleshooting](#troubleshooting)
+- [Summary](#summary)
+
+---
+
+## Overview
+
+VAE Patch Parallelism distributes the VAE (Variational AutoEncoder) decode/encode computation across multiple GPUs by splitting the latent space into spatial tiles or patches. Each GPU processes a subset of tiles in parallel, significantly reducing peak memory consumption during the VAE decode stage while maintaining output quality.
+
+This is particularly useful for:
+- **High-resolution image generation** where VAE decode can become a memory bottleneck
+- **Memory-constrained environments** where the VAE decode activation peak exceeds available VRAM
+- **Multi-GPU setups** where you want to leverage distributed resources for the VAE stage
+
+See supported models list in [Supported Models](../../diffusion_features.md#supported-models).
+
+
+VAE Patch Parallelism uses two strategies based on image size:
+
+| Strategy | Use Case | How It Works | Overlap Handling | Output Quality |
+|----------|----------|--------------|------------------|----------------|
+| **Tiled Decode** | Large images (triggers VAE tiling) | Distributes existing VAE tiling computation across ranks. Each rank decodes a subset of overlapping tiles. | Uses VAE's native `blend_v` and `blend_h` functions to seamlessly merge overlapping regions | Bit-identical (same logic as single-GPU tiling) |
+| **Patch Decode** | Small images (no VAE tiling) | Splits latent into spatial patches with halos. Each rank decodes one patch with boundary context. | Halo regions provide edge context; core regions are directly stitched without blending | Near-identical (diff < 0.5%, visually imperceptible) |
+
+
+VAE Patch Parallelism **reuses the DiT process group** (`dit_group`) and does not initialize a separate ProcessGroup. This means:
+
+- **Shared ranks**: VAE patch parallelism uses the same GPU ranks as DiT parallelism (Tensor Parallel, Sequence Parallel, etc.)
+- **Combined usage**: VAE patch parallelism is typically used together with other parallelism methods
+- **Configuration alignment**: The `vae_patch_parallel_size` should be no greater than the size of your DiT process group
+
+---
+
+## Quick Start
+
+### Basic Usage
+
+Simplest working example:
+
+```python
+from vllm_omni import Omni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+from vllm_omni.diffusion.data import DiffusionParallelConfig
+
+# TP=2 for DiT, VAE patch parallel also uses these 2 GPUs
+omni = Omni(
+ model="Tongyi-MAI/Z-Image-Turbo",
+ parallel_config=DiffusionParallelConfig(
+ tensor_parallel_size=2, # Enable tensor parallelism for DiT
+ vae_patch_parallel_size=2, # Enable VAE patch parallelism
+ ),
+ vae_use_tiling=True, # Required for VAE patch parallelism
+)
+
+outputs = omni.generate(
+ "a futuristic city at sunset, high resolution, 8k",
+ OmniDiffusionSamplingParams(
+ num_inference_steps=9,
+ height=1152, # High resolution benefits from VAE patch parallel
+ width=1152,
+ ),
+)
+```
+
+---
+
+## Example Script
+
+### Offline Inference
+
+Use Python script under `examples/offline_inference/text_to_image/`:
+
+```bash
+# Text-to-Image with Z-Image
+python examples/offline_inference/text_to_image/text_to_image.py \
+ --model Tongyi-MAI/Z-Image-Turbo \
+ --prompt "a futuristic city at sunset" \
+ --height 1152 \
+ --width 1152 \
+ --tensor-parallel-size 2 \
+ --vae-patch-parallel-size 2 \
+ --vae-use-tiling
+```
+
+### Online Serving
+
+You can enable VAE patch parallelism in online serving via `--vae-patch-parallel-size`:
+
+```bash
+# Text-to-Image with Z-Image, TP=2 + VAE patch parallel=2
+vllm serve Tongyi-MAI/Z-Image-Turbo --omni --port 8091 \
+ --tensor-parallel-size 2 \
+ --vae-patch-parallel-size 2 \
+ --vae-use-tiling
+```
+
+---
+
+## Configuration Parameters
+
+In `DiffusionParallelConfig`:
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `vae_patch_parallel_size` | int | 1 | Number of GPUs for VAE patch/tile parallelism. Set to 2 or higher to enable. Should typically match `tensor_parallel_size` as they share the same process group. |
+
+Additional requirements:
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `vae_use_tiling` | bool | False | Must be set to `True` when using VAE patch parallelism. |
+
+!!! note "Automatic VAE Tiling"
+ When `vae_patch_parallel_size > 1` and the model has a distributed VAE (`DistributedVaeMixin`), the system automatically sets `vae_use_tiling=True` if not already enabled.
+
+---
+
+## Best Practices
+
+### When to Use
+
+**Good for:**
+
+- High-resolution image generation and long video generation
+- Memory-constrained setups where VAE decode causes OOM
+- Multi-GPU environments
+
+**Not for:**
+
+- Low-resolution images/videos where VAE decode is not a bottleneck
+- Single GPU setups should use vae tiling decode, but not parallel vae tiling decode
+- Models that do not support vae patch parallel
+
+---
+
+## Troubleshooting
+
+### Common Issue 1: Model Not Support VAE Patch Parallel
+
+**Symptoms**:
+```
+WARNING: vae_patch_parallel_size=2 is set but VAE patch parallelism is NOT enabled for xxxPipeline; ignoring.
+```
+
+**Root Cause**: VAE Patch Parallelism requires the model's VAE to implement `DistributedVaeMixin`. At startup, `vllm_omni/diffusion/registry.py` checks whether the instantiated pipeline has a `.vae` attribute that is an instance of `DistributedVaeMixin`. If it does not, the setting is silently ignored:
+
+```python
+vae_pp_size = od_config.parallel_config.vae_patch_parallel_size
+is_distributed_vae = hasattr(model, "vae") and isinstance(model.vae, DistributedVaeMixin)
+if vae_pp_size > 1 and not is_distributed_vae:
+ logger.warning(
+ "vae_patch_parallel_size=%d is set but VAE patch parallelism is NOT enabled for %s; ignoring.",
+ vae_pp_size,
+ od_config.model_class_name,
+ )
+```
+
+**Solutions**:
+
+1. **Use a supported model** (recommended): check [Supported Models](../../diffusion_features.md#supported-models) for the VAE-Patch-Parallel column.
+
+2. To add support for a new model, implement `DistributedVaeMixin` on its VAE class (contributions are welcome).
+
+
+### Common Issue 2: `vae_patch_parallel_size` Exceeds DiT Process Group Size
+
+**Symptoms**: Shows warning message, and vae patch parallel size is resized to DiT process group size
+
+**Root Cause**: VAE Patch Parallelism reuses the DiT process group.
+
+**Recommendation**: Always set `vae_patch_parallel_size` to be no greater than your DiT process group size.
+
+Note that the size of DiT process group size equals to:
+```text
+dit_parallel_size = data_parallel_size
+ × cfg_parallel_size
+ × sequence_parallel_size
+ × pipeline_parallel_size
+ × tensor_parallel_size
+
+```
+_sequence_parallel_size = ulysses_degree × ring_degree_
+
+---
+
+## Summary
+
+1. ✅ **Enable VAE Patch Parallelism** - Set `vae_patch_parallel_size`, `vae_use_tiling=True` in `DiffusionParallelConfig` to reduce VAE decode peak memory
+2. ✅ **Use Long Sequence** - VAE patch parallelism benefits are most apparent at long sequence decoding
+3. ✅ **Combine with other parallelism methods** - Suggest to use together with Tensor Parallel or CFG-Parallel for maximum memory savings
diff --git a/docs/user_guide/diffusion/quantization/fp8.md b/docs/user_guide/diffusion/quantization/fp8.md
index 338781fa22c..9906631b625 100644
--- a/docs/user_guide/diffusion/quantization/fp8.md
+++ b/docs/user_guide/diffusion/quantization/fp8.md
@@ -65,6 +65,7 @@ The available `ignored_layers` names depend on the model architecture (e.g., `to
| Flux | `black-forest-labs/FLUX.1-dev` | All layers | None |
| HunyuanImage-3 | `tencent/HunyuanImage3` | All layers | None |
| HunyuanVideo-1.5 | `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v`, `720p_t2v`, `480p_i2v` | All layers | None |
+| Helios | `BestWishYsh/Helios-Base`, `BestWishYsh/Helios-Mid`, `BestWishYsh/Helios-Distilled` | All layers | None |
## Combining with Other Features
diff --git a/docs/user_guide/diffusion/step_execution.md b/docs/user_guide/diffusion/step_execution.md
index 99c2878506e..f8c9fa8ddb2 100644
--- a/docs/user_guide/diffusion/step_execution.md
+++ b/docs/user_guide/diffusion/step_execution.md
@@ -46,7 +46,7 @@ its stepwise request state machine. For normal diffusion inference, leave it
disabled unless your workflow depends on this mode.
If you are looking for general diffusion speedups, see
-[Diffusion Acceleration Overview](../diffusion_acceleration.md).
+[Diffusion Features Overview](../diffusion_features.md).
## Troubleshooting
diff --git a/docs/user_guide/diffusion/teacache.md b/docs/user_guide/diffusion/teacache.md
deleted file mode 100644
index 40dafeb88ad..00000000000
--- a/docs/user_guide/diffusion/teacache.md
+++ /dev/null
@@ -1,145 +0,0 @@
-# TeaCache Configuration Guide
-
-TeaCache speeds up diffusion model inference by caching transformer computations when consecutive timesteps are similar. This typically provides **1.5x-2.0x speedup** with minimal quality loss.
-
-## Quick Start
-
-Enable TeaCache by setting `cache_backend` to `"tea_cache"`:
-
-```python
-from vllm_omni import Omni
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
-# Simple configuration - model_type is automatically extracted from pipeline.__class__.__name__
-omni = Omni(
- model="Qwen/Qwen-Image",
- cache_backend="tea_cache",
- cache_config={
- "rel_l1_thresh": 0.2 # Optional, defaults to 0.2
- }
-)
-outputs = omni.generate(
- "A cat sitting on a windowsill",
- OmniDiffusionSamplingParams(
- num_inference_steps=50,
- ),
-)
-```
-
-### Using Environment Variable
-
-You can also enable TeaCache via environment variable:
-
-```bash
-export DIFFUSION_CACHE_BACKEND=tea_cache
-```
-
-Then initialize without explicitly setting `cache_backend`:
-
-```python
-from vllm_omni import Omni
-
-omni = Omni(
- model="Qwen/Qwen-Image",
- cache_config={"rel_l1_thresh": 0.2} # Optional
-)
-```
-
-## Online Serving (OpenAI-Compatible)
-
-Enable TeaCache for online serving by passing `--cache-backend tea_cache` when starting the server:
-
-```bash
-vllm serve Qwen/Qwen-Image --omni --port 8091 \
- --cache-backend tea_cache \
- --cache-config '{"rel_l1_thresh": 0.2}'
-```
-
-## Configuration Parameters
-
-### `rel_l1_thresh` (float, default: `0.2`)
-
-Controls the balance between speed and quality. Lower values prioritize quality, higher values prioritize speed.
-
-**Recommended values:**
-
-- `0.2` - **~1.5x speedup** with minimal quality loss (recommended)
-- `0.4` - **~1.8x speedup** with slight quality loss
-- `0.6` - **~2.0x speedup** with noticeable quality loss
-- `0.8` - **~2.25x speedup** with significant quality loss
-
-## Examples
-
-### Python API
-
-```python
-from vllm_omni import Omni
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
-omni = Omni(
- model="Qwen/Qwen-Image",
- cache_backend="tea_cache",
- cache_config={"rel_l1_thresh": 0.2}
-)
-outputs = omni.generate(
- "A cat sitting on a windowsill",
- OmniDiffusionSamplingParams(
- num_inference_steps=50,
- ),
-)
-```
-
-## Performance Tuning
-
-Start with the default `rel_l1_thresh=0.2` and adjust based on your needs:
-
-- **Maximum quality**: Use `0.1-0.2`
-- **Balanced**: Use `0.2-0.4` (recommended)
-- **Maximum speed**: Use `0.6-0.8` (may reduce quality)
-
-## Troubleshooting
-
-### Quality Degradation
-
-If you notice quality issues, lower the threshold:
-
-```python
-cache_config={"rel_l1_thresh": 0.1} # More conservative caching
-```
-
-## Supported Models
-
-### ImageGen
-
-
-
-| Architecture | Models | Example HF Models |
-|--------------|--------|-------------------|
-| `QwenImagePipeline` | Qwen-Image | `Qwen/Qwen-Image` |
-| `QwenImageEditPipeline` | Qwen-Image-Edit | `Qwen/Qwen-Image-Edit` |
-| `QwenImageEditPlusPipeline` | Qwen-Image-Edit-2509 | `Qwen/Qwen-Image-Edit-2509` |
-| `QwenImageLayeredPipeline` | Qwen-Image-Layered | `Qwen/Qwen-Image-Layered` |
-| `BagelForConditionalGeneration` | BAGEL (DiT-only) | `ByteDance-Seed/BAGEL-7B-MoT` |
-
-### VideoGen
-
-No VideoGen models are supported by TeaCache yet.
-
-### Coming Soon
-
-
-
-| Architecture | Models | Example HF Models |
-|--------------|--------|-------------------|
-| `FluxPipeline` | Flux | - |
-| `CogVideoXPipeline` | CogVideoX | - |
From d662ca162604d94434591a923414ec1b745f2a80 Mon Sep 17 00:00:00 2001
From: Chen Yang <2082464740@qq.com>
Date: Fri, 3 Apr 2026 11:31:55 +0800
Subject: [PATCH 7/8] add test
Signed-off-by: Chen Yang <2082464740@qq.com>
---
docs/user_guide/diffusion_features.md | 2 +-
.../e2e/online_serving/test_ltx2_expansion.py | 139 ++++++++++++++++++
2 files changed, 140 insertions(+), 1 deletion(-)
create mode 100644 tests/e2e/online_serving/test_ltx2_expansion.py
diff --git a/docs/user_guide/diffusion_features.md b/docs/user_guide/diffusion_features.md
index 9cd407d377a..f7c1d071f90 100644
--- a/docs/user_guide/diffusion_features.md
+++ b/docs/user_guide/diffusion_features.md
@@ -131,7 +131,7 @@ The following tables show which models support each feature:
|-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:|
| **Wan2.2** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| **Wan2.1-VACE** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
-| **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| **Helios** | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| **HunyuanVideo-1.5 T2V I2V** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
diff --git a/tests/e2e/online_serving/test_ltx2_expansion.py b/tests/e2e/online_serving/test_ltx2_expansion.py
new file mode 100644
index 00000000000..7792feb33db
--- /dev/null
+++ b/tests/e2e/online_serving/test_ltx2_expansion.py
@@ -0,0 +1,139 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+L4 e2e tests for LTX-2 in online serving mode.
+
+Coverage:
+- Cache-DiT (1 GPU)
+- Cache-DiT + TP=2 + VAE patch parallel=2 (2 GPUs)
+
+LTX-2 is served through the async video API (/v1/videos) in online serving mode.
+"""
+
+import time
+
+import pytest
+import requests
+
+# Disable proxy for local test server requests
+NO_PROXY = {"http": None, "https": None}
+
+from tests.conftest import (
+ OmniServer,
+ OmniServerParams,
+)
+from tests.utils import hardware_marks
+
+MODEL = "Lightricks/LTX-2"
+PROMPT = "A cinematic close-up of ocean waves at golden hour."
+NEGATIVE_PROMPT = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+SINGLE_CARD_FEATURE_MARKS = hardware_marks(res={"cuda": "L4"})
+PARALLEL_MARKS = hardware_marks(res={"cuda": "L4"}, num_cards=2)
+
+VIDEO_TIMEOUT_S = 900.0
+VIDEO_POLL_INTERVAL_S = 2.0
+
+
+def _get_diffusion_feature_cases(model: str):
+ """Return L4 diffusion feature cases for LTX-2."""
+ return [
+ # (1 GPU) Cache-DiT
+ pytest.param(
+ OmniServerParams(
+ model=model,
+ server_args=[
+ "--cache-backend",
+ "cache_dit",
+ ],
+ ),
+ id="single_card_cachedit",
+ marks=SINGLE_CARD_FEATURE_MARKS,
+ ),
+ # (2 GPUs) Cache-DiT + TP=2 + VAE patch parallel=2
+ pytest.param(
+ OmniServerParams(
+ model=model,
+ server_args=[
+ "--cache-backend",
+ "cache_dit",
+ "--tensor-parallel-size",
+ "2",
+ "--vae-patch-parallel-size",
+ "2",
+ "--vae-use-tiling",
+ ],
+ ),
+ id="parallel_cachedit_tp2_vae2",
+ marks=PARALLEL_MARKS,
+ ),
+ ]
+
+
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
+@pytest.mark.parametrize(
+ "omni_server",
+ _get_diffusion_feature_cases(MODEL),
+ indirect=True,
+)
+def test_ltx2(
+ omni_server: OmniServer,
+):
+ """L4 diffusion feature coverage for LTX-2 on L4."""
+ url = f"http://{omni_server.host}:{omni_server.port}/v1/videos"
+
+ payload = {
+ "prompt": PROMPT,
+ "height": 512,
+ "width": 768,
+ "num_frames": 9,
+ "num_inference_steps": 2,
+ "negative_prompt": NEGATIVE_PROMPT,
+ "guidance_scale": 4.0,
+ "fps": 24,
+ "seed": 42,
+ }
+
+ files = [(k, (None, str(v))) for k, v in payload.items()]
+
+ create_resp = requests.post(url, files=files, timeout=VIDEO_TIMEOUT_S, proxies=NO_PROXY)
+ assert create_resp.status_code == 200, create_resp.text
+
+ data = create_resp.json()
+ video_id = data["id"]
+ assert data["status"] == "queued"
+ assert data["model"] == omni_server.model
+
+ # Poll for completion
+ deadline = time.time() + VIDEO_TIMEOUT_S
+ last_status = None
+ while time.time() < deadline:
+ status_resp = requests.get(f"{url}/{video_id}", timeout=30, proxies=NO_PROXY)
+ assert status_resp.status_code == 200, status_resp.text
+ status_data = status_resp.json()
+ last_status = status_data["status"]
+ if last_status == "completed":
+ break
+ if last_status == "failed":
+ raise AssertionError(f"Video generation failed: {status_data}")
+ time.sleep(VIDEO_POLL_INTERVAL_S)
+ else:
+ raise AssertionError(
+ f"Timed out waiting for video generation. Last status: {last_status}"
+ )
+
+ # Verify download returns a valid MP4
+ download_resp = requests.get(
+ f"{url}/{video_id}/content",
+ timeout=VIDEO_TIMEOUT_S,
+ proxies=NO_PROXY,
+ )
+ assert download_resp.status_code == 200, download_resp.text
+ assert download_resp.headers["content-type"].startswith("video/mp4")
+ assert len(download_resp.content) > 32, (
+ f"Downloaded video payload is unexpectedly small: {len(download_resp.content)} bytes"
+ )
+ assert download_resp.content[4:8] == b"ftyp", (
+ "Downloaded payload does not look like an MP4 file."
+ )
\ No newline at end of file
From 495a0cb0eaf5a4197326a423d2b6f247c7fcf67a Mon Sep 17 00:00:00 2001
From: Chen Yang <2082464740@qq.com>
Date: Fri, 3 Apr 2026 11:37:27 +0800
Subject: [PATCH 8/8] fix test
Signed-off-by: Chen Yang <2082464740@qq.com>
---
tests/e2e/online_serving/test_ltx2_expansion.py | 14 +++++---------
1 file changed, 5 insertions(+), 9 deletions(-)
diff --git a/tests/e2e/online_serving/test_ltx2_expansion.py b/tests/e2e/online_serving/test_ltx2_expansion.py
index 7792feb33db..4112b1b4fa5 100644
--- a/tests/e2e/online_serving/test_ltx2_expansion.py
+++ b/tests/e2e/online_serving/test_ltx2_expansion.py
@@ -15,15 +15,15 @@
import pytest
import requests
-# Disable proxy for local test server requests
-NO_PROXY = {"http": None, "https": None}
-
from tests.conftest import (
OmniServer,
OmniServerParams,
)
from tests.utils import hardware_marks
+# Disable proxy for local test server requests
+NO_PROXY = {"http": None, "https": None}
+
MODEL = "Lightricks/LTX-2"
PROMPT = "A cinematic close-up of ocean waves at golden hour."
NEGATIVE_PROMPT = "worst quality, inconsistent motion, blurry, jittery, distorted"
@@ -119,9 +119,7 @@ def test_ltx2(
raise AssertionError(f"Video generation failed: {status_data}")
time.sleep(VIDEO_POLL_INTERVAL_S)
else:
- raise AssertionError(
- f"Timed out waiting for video generation. Last status: {last_status}"
- )
+ raise AssertionError(f"Timed out waiting for video generation. Last status: {last_status}")
# Verify download returns a valid MP4
download_resp = requests.get(
@@ -134,6 +132,4 @@ def test_ltx2(
assert len(download_resp.content) > 32, (
f"Downloaded video payload is unexpectedly small: {len(download_resp.content)} bytes"
)
- assert download_resp.content[4:8] == b"ftyp", (
- "Downloaded payload does not look like an MP4 file."
- )
\ No newline at end of file
+ assert download_resp.content[4:8] == b"ftyp", "Downloaded payload does not look like an MP4 file."