Merged
47 commits
c63e428
mv how to tutorial to developer guide
wtomin Feb 3, 2026
87c6ecb
auto doc
wtomin Feb 3, 2026
25976d1
manual tutorial
wtomin Feb 3, 2026
5c6ee0e
auto updates
wtomin Feb 3, 2026
6df7ee5
auto updates for cfg paralle and sp
wtomin Feb 3, 2026
0fe38d1
update doc
wtomin Feb 3, 2026
6ff5515
updates
wtomin Feb 3, 2026
abb1384
udpate main doc
wtomin Feb 4, 2026
e31e638
tensor parallel doc
wtomin Feb 4, 2026
3c09d52
update cache_dit md
wtomin Feb 4, 2026
139c1f5
update teacache
wtomin Feb 4, 2026
987b5d5
update sp doc
wtomin Feb 4, 2026
89e530c
update cfg parallel
wtomin Feb 4, 2026
ce48837
update two docs
wtomin Feb 4, 2026
39a47e2
structure
wtomin Feb 4, 2026
34d03d4
update main doc
wtomin Feb 4, 2026
2894f15
correct path
wtomin Feb 4, 2026
7f614c3
updates
wtomin Feb 4, 2026
33bcf11
Update docs/contributing/features/tensor_parallel.md
wtomin Feb 5, 2026
d21b7c8
Update docs/contributing/features/sequence_parallel.md
wtomin Feb 5, 2026
b023111
Update docs/contributing/model/adding_diffusion_model.md
wtomin Feb 5, 2026
bc7bfa2
Update docs/contributing/model/adding_diffusion_model.md
wtomin Feb 5, 2026
a5ba519
Update docs/contributing/features/cache_dit.md
wtomin Feb 5, 2026
36ff9e6
Update docs/contributing/model/adding_diffusion_model.md
wtomin Feb 5, 2026
b2acae6
Update docs/contributing/model/adding_diffusion_model.md
wtomin Feb 5, 2026
657124f
Update docs/contributing/features/cfg_parallel.md
wtomin Feb 5, 2026
e09521d
Update docs/contributing/features/cfg_parallel.md
wtomin Feb 5, 2026
2b7bb18
Update docs/contributing/model/adding_diffusion_model.md
wtomin Feb 5, 2026
0fdadb0
update numbering
wtomin Feb 5, 2026
bfae334
Apply suggestion from @Copilot
wtomin Feb 5, 2026
1742f00
documentation updates
wtomin Feb 5, 2026
aa24d80
update tp
wtomin Feb 5, 2026
422f4f4
impr trouble shooting
wtomin Feb 5, 2026
af5f2e6
sp_refine
wtomin Feb 5, 2026
925fb4f
impr sp steps
wtomin Feb 5, 2026
2e43083
Apply suggestion from @dongbo910220
wtomin Feb 5, 2026
eaf0f53
update cache_dit doc
wtomin Feb 5, 2026
cae0d5c
update tea_cache doc
wtomin Feb 5, 2026
dfc52d9
update cfg-parallel
wtomin Feb 5, 2026
e1b13e5
update all
wtomin Feb 5, 2026
da68983
compress code snippet
wtomin Feb 6, 2026
cc1aae0
add parameter in doc
wtomin Feb 6, 2026
3b273df
pre-commit error
wtomin Feb 6, 2026
014514f
Update docs/.nav.yml
wtomin Feb 9, 2026
87ef30d
Attention recursion fix
wtomin Feb 9, 2026
24d1fe1
example request in dictionary
wtomin Feb 9, 2026
fb5faa7
shorten titles
wtomin Feb 9, 2026
6 changes: 6 additions & 0 deletions docs/.nav.yml
@@ -47,6 +47,12 @@ nav:
- contributing/model/README.md
- contributing/model/adding_omni_model.md
- contributing/model/adding_diffusion_model.md
- Advanced Features:
- contributing/features/cfg_parallel.md
- contributing/features/sequence_parallel.md
- contributing/features/tensor_parallel.md
- contributing/features/cache_dit.md
- contributing/features/teacache.md
- CI: contributing/ci
- Design Documents:
- design/index.md
286 changes: 286 additions & 0 deletions docs/contributing/features/cache_dit.md
@@ -0,0 +1,286 @@
# Support Cache-DiT

This section describes how to add cache-dit acceleration to a new diffusion pipeline. We use the Qwen-Image pipeline and LongCat-Image pipeline as reference implementations.

---

## Table of Contents

- [Overview](#overview)
- [Standard Models: Automatic Support](#standard-models-automatic-support)
- [Custom Architectures: Writing Custom Implementation](#custom-architectures-writing-custom-implementation)
- [Testing](#testing)
- [Troubleshooting](#troubleshooting)
- [Reference Implementations](#reference-implementations)
- [Summary](#summary)

---

## Overview

### What is Cache-DiT?

Cache-DiT is an acceleration library for Diffusion Transformers (DiT) that caches intermediate computation results across denoising steps. The core insight is that adjacent denoising steps often produce similar intermediate features, so we can skip redundant computations by reusing cached results.
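
The intuition can be illustrated with a toy check on how much a feature vector changes between two adjacent steps. This is a minimal sketch under stated assumptions, not cache-dit's actual implementation: `relative_l1_diff`, `should_reuse_cache`, and the sample values are hypothetical, and real features are tensors rather than short Python lists.

```python
from typing import Sequence


def relative_l1_diff(prev: Sequence[float], curr: Sequence[float]) -> float:
    """Total absolute change between two feature vectors, relative to the previous magnitude."""
    num = sum(abs(c - p) for p, c in zip(prev, curr))
    den = sum(abs(p) for p in prev) or 1e-12  # guard against division by zero
    return num / den


def should_reuse_cache(prev: Sequence[float], curr: Sequence[float], threshold: float) -> bool:
    """Reuse the cached result when the features barely changed between steps."""
    return relative_l1_diff(prev, curr) < threshold


step_t = [1.00, -2.00, 0.50]
step_t_plus_1 = [1.02, -1.98, 0.51]  # nearly identical features
step_far = [0.20, -0.50, 1.50]       # clearly different features

print(should_reuse_cache(step_t, step_t_plus_1, threshold=0.24))
print(should_reuse_cache(step_t, step_far, threshold=0.24))
```

A lower threshold makes the check stricter, so fewer steps reuse cached results.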

The library supports three main caching strategies:

- **DBCache:** Dynamic block-level caching that selectively computes or caches transformer blocks based on residual differences
- **TaylorSeer:** Calibration-based prediction that estimates block outputs using Taylor expansion
- **SCM (Step Computation Masking):** Dynamic step skipping based on configurable policies
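
To make the Taylor-expansion idea concrete, the sketch below extrapolates the next output from the two most recently computed ones using a first-order finite difference. It is a simplified illustration only: `taylor_predict` is a hypothetical helper, and the order of expansion and the tensors TaylorSeer actually operates on are not modelled here.

```python
from typing import List


def taylor_predict(history: List[List[float]], steps_ahead: int = 1) -> List[float]:
    """First-order Taylor extrapolation from the two most recent feature vectors.

    history[-1] is the latest computed output, history[-2] the one before it.
    """
    if len(history) < 2:
        raise ValueError("need at least two computed steps to estimate a derivative")
    prev, last = history[-2], history[-1]
    # Finite-difference estimate of the per-step derivative.
    derivative = [cur - old for old, cur in zip(prev, last)]
    return [cur + steps_ahead * d for cur, d in zip(last, derivative)]


computed = [[1.0, 4.0], [1.5, 3.0]]
predicted = taylor_predict(computed, steps_ahead=1)
print(predicted)
```

The prediction replaces a full block computation for the skipped step, so its accuracy depends on how smoothly the features evolve between steps.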

### Architecture

vLLM-omni integrates cache-dit through the `CacheDiTBackend` class, which provides a unified interface for managing cache-dit acceleration on diffusion models.

| Method/Class | Purpose | Behavior |
|--------------|---------|----------|
| [`CacheDiTBackend`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/cache/#vllm_omni.diffusion.cache.CacheBackend) | Unified backend interface | Automatically handles enabler selection and cache refresh |
| [`enable_cache_for_dit()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/cache/cache_dit_backend/#vllm_omni.diffusion.cache.cache_dit_backend.enable_cache_for_dit) | Apply caching to transformer | Configures DBCache on transformer blocks |

**Key APIs from Cache-DiT:**

[Cache-DiT API Reference](https://cache-dit.readthedocs.io/en/latest/user_guide/CACHE_API/)

| API | Description |
|-----|-------------|
| `BlockAdapter` | Core abstraction for applying cache-dit to transformers. Specifies transformer module(s), block list(s), and forward signature pattern(s). |
| `ForwardPattern` | Defines block forward signature patterns: `Pattern_0`, `Pattern_1`, `Pattern_2` |
| `ParamsModifier` | Per-transformer or per-block-list cache configuration customization |
| `DBCacheConfig` | Configuration for DBCache parameters (warmup steps, cached steps, thresholds) |
| `refresh_context()` | Updates the cache context. Called when `num_inference_steps` changes. |

---

## Standard Models: Automatic Support

Most DiT models follow this pattern:

- Single transformer with one `ModuleList` of blocks
- Standard forward signature
- Compatible with cache-dit's automatic detection

**Examples:** Qwen-Image, Z-Image

For standard single-transformer models, **no code changes are needed**. The `CacheDiTBackend` automatically uses `enable_cache_for_dit()`:

```python
from vllm_omni import Omni

# Works automatically for standard models
omni = Omni(
    model="Qwen/Qwen-Image",  # Standard single-transformer model
    cache_backend="cache_dit",
    cache_config={
        "Fn_compute_blocks": 1,
        "Bn_compute_blocks": 0,
        "max_warmup_steps": 4,
    },
)
```
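
As a rough mental model for `Fn_compute_blocks` and `Bn_compute_blocks`, the sketch below assumes that the first `Fn` blocks are always computed (to measure how much the features changed), the last `Bn` blocks are always computed (to refine the output), and only the blocks in between are candidates for caching. This reading of the parameters is an assumption, and `partition_blocks` is a hypothetical helper that exists in neither cache-dit nor vLLM-omni.

```python
from typing import List, Tuple


def partition_blocks(num_blocks: int, fn: int, bn: int) -> Tuple[List[int], List[int], List[int]]:
    """Split block indices into (always-computed front, cacheable middle, always-computed back)."""
    if fn < 0 or bn < 0 or fn + bn > num_blocks:
        raise ValueError("fn and bn must be non-negative and fit within num_blocks")
    indices = list(range(num_blocks))
    front = indices[:fn]
    middle = indices[fn:num_blocks - bn]
    back = indices[num_blocks - bn:]
    return front, middle, back


# Mirrors the example config above: Fn_compute_blocks=1, Bn_compute_blocks=0.
front, middle, back = partition_blocks(num_blocks=8, fn=1, bn=0)
print(front, middle, back)
```

Under this assumption, raising either value computes more blocks on every step, trading some of the speedup for accuracy.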

**What happens automatically:**

```python
def enable_cache_for_dit(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
    """Default enabler for standard single-transformer DiT models."""

    # Build cache configuration
    db_cache_config = DBCacheConfig(
        num_inference_steps=None,  # Will be set during first inference
        Fn_compute_blocks=cache_config.Fn_compute_blocks,
        Bn_compute_blocks=cache_config.Bn_compute_blocks,
        max_warmup_steps=cache_config.max_warmup_steps,
        max_cached_steps=cache_config.max_cached_steps,
        max_continuous_cached_steps=cache_config.max_continuous_cached_steps,
        residual_diff_threshold=cache_config.residual_diff_threshold,
    )

    # Enable cache-dit on transformer
    cache_dit.enable_cache(
        pipeline.transformer,
        cache_config=db_cache_config,
    )

    # Return refresh function for dynamic num_inference_steps updates
    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True):
        cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)

    return refresh_cache_context
```
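
The enabler returns a refresh function so that the caller can reconfigure the cache context when a request uses a different `num_inference_steps`. The sketch below shows one way a caller could invoke it only when the step count changes. How `CacheDiTBackend` tracks this internally is not shown in this guide, so `StepAwareRefresher` and its method are hypothetical names used purely for illustration.

```python
from typing import Callable, List, Optional


class StepAwareRefresher:
    """Calls the refresh function only when num_inference_steps differs from the last request."""

    def __init__(self, refresh_fn: Callable[[object, int], None]):
        self._refresh_fn = refresh_fn
        self._last_steps: Optional[int] = None

    def maybe_refresh(self, pipeline: object, num_inference_steps: int) -> bool:
        if num_inference_steps == self._last_steps:
            return False  # cache context is already configured for this step count
        self._refresh_fn(pipeline, num_inference_steps)
        self._last_steps = num_inference_steps
        return True


calls: List[int] = []
refresher = StepAwareRefresher(lambda pipeline, steps: calls.append(steps))
for requested_steps in (50, 50, 30):
    refresher.maybe_refresh(pipeline=None, num_inference_steps=requested_steps)
print(calls)
```

Here the stand-in refresh function only records the step counts it receives; a real enabler would call `cache_dit.refresh_context` instead.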

---

## Custom Architectures: Writing Custom Implementation

Some models require custom handling:

- **Dual-transformer:** Models with separate high-noise and low-noise transformers (e.g., Wan2.2)
- **Multi-block-list:** Models with multiple block lists in one transformer (e.g., LongCatImage with `transformer_blocks` + `single_transformer_blocks`)
- **Special forward patterns:** Models with non-standard block execution patterns

### Example 1: Dual-Transformer Model (Wan2.2)

Wan2.2 uses two transformers: one for high-noise steps and one for low-noise steps.

**Key difference:** Use `BlockAdapter` to wrap multiple transformers with separate configurations.

```python
# Standard: cache_dit.enable_cache(pipeline.transformer, ...)
# Custom: Use BlockAdapter to handle multiple transformers
cache_dit.enable_cache(
    BlockAdapter(
        transformer=[pipeline.transformer, pipeline.transformer_2],  # Multiple transformers
        blocks=[pipeline.transformer.blocks, pipeline.transformer_2.blocks],
        forward_pattern=[ForwardPattern.Pattern_2, ForwardPattern.Pattern_2],
        params_modifiers=[
            ParamsModifier(...),  # Config for high-noise transformer
            ParamsModifier(...),  # Config for low-noise transformer (different params)
        ],
    ),
    cache_config=db_cache_config,
)
```

**Key difference:** `refresh_context` must be called on each transformer separately.

```python
# Standard: cache_dit.refresh_context(pipeline.transformer, num_inference_steps=N)
# Custom: Refresh each transformer with its own step count
def refresh_cache_context(pipeline, num_inference_steps, verbose=True):
    high_steps, low_steps = _split_inference_steps(num_inference_steps)
    cache_dit.refresh_context(pipeline.transformer, num_inference_steps=high_steps, ...)
    cache_dit.refresh_context(pipeline.transformer_2, num_inference_steps=low_steps, ...)
```
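
The body of `_split_inference_steps` is not shown in this guide. As a purely hypothetical stand-in, the sketch below divides the total step count by a fixed fraction so that the two parts always sum to the total. The real helper may instead derive the split from the scheduler's timesteps and the model's noise boundary; the `high_noise_fraction` parameter and its default are assumptions made for illustration.

```python
from typing import Tuple


def split_inference_steps(num_inference_steps: int, high_noise_fraction: float = 0.5) -> Tuple[int, int]:
    """Divide the total step count between the high-noise and low-noise transformers."""
    if num_inference_steps < 0 or not 0.0 <= high_noise_fraction <= 1.0:
        raise ValueError("invalid step count or fraction")
    high_steps = round(num_inference_steps * high_noise_fraction)
    low_steps = num_inference_steps - high_steps  # remainder keeps the total consistent
    return high_steps, low_steps


high, low = split_inference_steps(40, high_noise_fraction=0.25)
print(high, low)
```

Whatever rule is used, each transformer's cache context should be refreshed with the number of steps that transformer actually runs, not the overall total.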

### Example 2: Multi-Block-List Model (LongCatImage)

LongCatImage has a single transformer with two block lists: `transformer_blocks` and `single_transformer_blocks`.

**Key difference:** Use `BlockAdapter` to specify multiple block lists within one transformer.

```python
# Standard: cache_dit.enable_cache(pipeline.transformer, ...)
# - Automatically detects single block list
# Custom: Use BlockAdapter to specify multiple block lists
cache_dit.enable_cache(
    BlockAdapter(
        transformer=pipeline.transformer,  # Single transformer
        blocks=[
            pipeline.transformer.transformer_blocks,  # Block list 1
            pipeline.transformer.single_transformer_blocks,  # Block list 2
        ],
        forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_1],
        params_modifiers=[modifier],
    ),
    cache_config=db_cache_config,
)
```

> **Note:** For a single transformer with multiple block lists, `refresh_context` works the same way as for standard models.

### Registering Custom Implementations

After writing your custom enabler, register it in `CUSTOM_DIT_ENABLERS` in `vllm_omni/diffusion/cache/cache_dit_backend.py`:

```python
CUSTOM_DIT_ENABLERS = {
    "Wan22Pipeline": enable_cache_for_wan22,
    "LongCatImagePipeline": enable_cache_for_longcat_image,
    "YourCustomPipeline": enable_cache_for_your_model,  # Add here
}
```
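
Because the registry is keyed by pipeline class name, enabler selection can be pictured as a dictionary lookup with a fallback to the default enabler. The sketch below is self-contained and uses toy names (`TOY_ENABLERS`, `select_enabler`, and the `Toy*` classes are hypothetical); the exact lookup logic inside `CacheDiTBackend` is an assumption.

```python
from typing import Callable, Dict


def enable_default(pipeline: object) -> str:
    return "default"


def enable_custom(pipeline: object) -> str:
    return "custom"


# Toy stand-in for CUSTOM_DIT_ENABLERS; keys are pipeline class names.
TOY_ENABLERS: Dict[str, Callable[[object], str]] = {"ToyDualPipeline": enable_custom}


def select_enabler(pipeline: object) -> Callable[[object], str]:
    """Pick the enabler registered for this pipeline class, else the default."""
    return TOY_ENABLERS.get(pipeline.__class__.__name__, enable_default)


class ToyDualPipeline:
    pass


class ToyStandardPipeline:
    pass


print(select_enabler(ToyDualPipeline()).__name__)
print(select_enabler(ToyStandardPipeline()).__name__)
```

A key that does not exactly match `pipeline.__class__.__name__` silently falls through to the default in this sketch, which is why a misspelled registry key can look like "cache not applied correctly" rather than an error.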

---

## Testing

After adding cache-dit support, test with:

```python
from vllm_omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams

# Test your custom model
omni = Omni(
    model="your-model-name",
    cache_backend="cache_dit",
    cache_config={
        "Fn_compute_blocks": 1,
        "Bn_compute_blocks": 0,
        "max_warmup_steps": 4,
        "residual_diff_threshold": 0.24,
    },
)

images = omni.generate(
    "a beautiful landscape",
    OmniDiffusionSamplingParams(num_inference_steps=50),
)
```

**Verify:**

1. Cache is applied (check the logs for "Cache-dit enabled successfully on xxx")
2. Performance improves (expect a speedup of roughly 1.5x-2x)
3. Image quality is preserved (compare against a run with `cache_backend=None`)
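
The speed and quality checks can be made quantitative with a few small helpers. The sketch below uses only the standard library and stand-in numbers; `time_call`, `speedup`, and `mean_abs_diff` are hypothetical helpers, and in practice the timed callable would wrap `omni.generate(...)` with the same prompt and settings for both the cached and uncached runs.

```python
import time
from typing import Callable, Sequence


def time_call(fn: Callable[[], object], repeats: int = 3) -> float:
    """Return the best wall-clock time in seconds over several runs."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        best = min(best, time.perf_counter() - start)
    return best


def speedup(baseline_seconds: float, cached_seconds: float) -> float:
    """Ratio of uncached to cached latency; values above 1.0 mean caching helped."""
    return baseline_seconds / cached_seconds


def mean_abs_diff(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean absolute per-pixel difference between two flattened images."""
    if len(a) != len(b):
        raise ValueError("images must have the same number of pixels")
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


# Stand-in measurements, not real benchmark results.
print(speedup(baseline_seconds=20.0, cached_seconds=12.5))
print(mean_abs_diff([0.0, 0.5, 1.0], [0.0, 0.25, 1.0]))
```

A pixel-difference number is only a coarse signal, so visual inspection of the two images remains the final check.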

---

## Troubleshooting

### Issue: Cache not applied

**Symptoms:** No speedup observed, no cache-related log messages.

**Causes & Solutions:**

- **Enabler not registered:**

    **Problem:** Pipeline name not in `CUSTOM_DIT_ENABLERS` registry.

    **Solution:** Verify `pipeline.__class__.__name__` matches the registry key and add your enabler to `CUSTOM_DIT_ENABLERS`.

### Issue: Quality degradation

**Symptoms:** Generated images have artifacts or lower quality compared to non-cached inference.

**Causes & Solutions:**

- **Cache parameters too aggressive:**

    **Solution:**

    ```python
    cache_config={
        "residual_diff_threshold": 0.12,  # Lower from 0.24 (try 0.12-0.18)
        "max_warmup_steps": 6,  # Increase from 4 (try 6-8)
        "max_continuous_cached_steps": 2,  # Reduce if higher
    }
    ```

Check the [user guide for cache_dit](../../user_guide/diffusion/cache_dit_acceleration.md) for more adjustable parameters.
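
To see why these three parameters reduce caching, the toy simulation below labels each step as computed or cached under a simplified policy: never cache during warmup, cache only when the residual difference is under the threshold, and cap the number of consecutive cached steps. This is an assumption-laden sketch, not cache-dit's real scheduling logic; `plan_steps` and the `toy_diffs` values are made up for illustration.

```python
from typing import List


def plan_steps(diffs: List[float], threshold: float, warmup: int, max_continuous: int) -> List[str]:
    """Label each step 'compute' or 'cache' under a simplified caching policy.

    diffs[i] is the residual difference measured at step i (toy values here).
    """
    plan: List[str] = []
    streak = 0  # number of consecutive cached steps so far
    for step, diff in enumerate(diffs):
        can_cache = step >= warmup and diff < threshold and streak < max_continuous
        if can_cache:
            plan.append("cache")
            streak += 1
        else:
            plan.append("compute")
            streak = 0
    return plan


toy_diffs = [0.9, 0.5, 0.2, 0.15, 0.1, 0.1, 0.2, 0.1]
aggressive = plan_steps(toy_diffs, threshold=0.24, warmup=2, max_continuous=4)
conservative = plan_steps(toy_diffs, threshold=0.12, warmup=4, max_continuous=2)
print(aggressive.count("cache"), conservative.count("cache"))
```

With these toy values the conservative settings cache fewer steps, which is the intended trade: less speedup in exchange for output closer to uncached inference.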

---

## Reference Implementations

Complete examples in the codebase:

| Model | Path | Pattern | Notes |
|-------|------|---------|-------|
| **Standard DiT** | [`cache_dit_backend.py::enable_cache_for_dit`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/cache/cache_dit_backend/#vllm_omni.diffusion.cache.cache_dit_backend.enable_cache_for_dit) | Default enabler | Single transformer, automatic |
| **Wan2.2** | [`cache_dit_backend.py::enable_cache_for_wan22`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/cache/cache_dit_backend/#vllm_omni.diffusion.cache.cache_dit_backend.enable_cache_for_wan22) | Dual-transformer | Separate high/low noise transformers |
| **LongCat** | [`cache_dit_backend.py::enable_cache_for_longcat_image`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/cache/cache_dit_backend/#vllm_omni.diffusion.cache.cache_dit_backend.enable_cache_for_longcat_image) | Multi-block-list | Two block lists in one transformer |
| **BAGEL** | [`cache_dit_backend.py::enable_cache_for_bagel`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/cache/cache_dit_backend/#vllm_omni.diffusion.cache.cache_dit_backend.enable_cache_for_bagel) | Omni model | Complex architecture |

---

## Summary

Adding cache-dit support:

1. ✅ **Check model type** - Standard models work automatically, custom architectures need enablers
2. ✅ **Write enabler** (if needed) - Use `BlockAdapter` for complex architectures
3. ✅ **Register enabler** (if needed) - Add to `CUSTOM_DIT_ENABLERS` dictionary
4. ✅ **Return refresh function** (if needed) - Handle `num_inference_steps` changes
5. ✅ **Test** - Verify with `cache_backend="cache_dit"`

For most models, the default enabler is sufficient. Only write custom enablers for complex architectures!