Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
b7ba59d
init
eitanturok Mar 5, 2026
e7f629c
fix from #19964
eitanturok Mar 5, 2026
cfa1e57
Update python/sglang/multimodal_gen/runtime/cache/teacache.py
eitanturok Mar 5, 2026
23c7ee4
don't need params here
eitanturok Mar 5, 2026
a82e797
Merge branch 'sgl-project:main' into teacache-refactor
eitanturok Mar 5, 2026
3ed92f8
Merge branch 'main' into teacache-refactor
eitanturok Mar 7, 2026
57689f5
update
eitanturok Mar 7, 2026
520a101
update
eitanturok Mar 7, 2026
bf765a2
Merge branch 'main' into teacache-refactor
eitanturok Mar 9, 2026
6c70b43
update
eitanturok Mar 9, 2026
eb0e774
Merge branch 'teacache-refactor' of https://github.com/eitanturok/sgl…
eitanturok Mar 9, 2026
61825d5
precommit + start, end skipping steps
eitanturok Mar 10, 2026
47c4b66
update assert
eitanturok Mar 10, 2026
4bafebc
Merge branch 'main' into teacache-refactor
eitanturok Mar 10, 2026
037d347
remove icecream
eitanturok Mar 10, 2026
b16a1ef
better docstring
eitanturok Mar 10, 2026
99904e2
inherit from diffusioncache
eitanturok Mar 10, 2026
98b0bd5
better docs
eitanturok Mar 10, 2026
f0518b5
track cnt in state
eitanturok Mar 10, 2026
b035ba8
teacache takes in modulated_input
eitanturok Mar 10, 2026
8230b23
better docs, comments
eitanturok Mar 10, 2026
1d1ef72
no enable_cache
eitanturok Mar 10, 2026
728d89a
teacache for hunyuanvideo
eitanturok Mar 10, 2026
853607f
better docs
eitanturok Mar 10, 2026
bebdb92
better docs
eitanturok Mar 10, 2026
0d687e9
better docs
eitanturok Mar 10, 2026
c9ae288
Merge branch 'main' into teacache-refactor
eitanturok Mar 11, 2026
ba322aa
Merge branch 'main' into teacache-refactor
eitanturok Mar 11, 2026
3430d0f
Merge branch 'sgl-project:main' into teacache-refactor
eitanturok Mar 12, 2026
1da6ed0
fix start_skip > end_skip
eitanturok Mar 12, 2026
cc64799
Merge branch 'main' into teacache-refactor
eitanturok Mar 12, 2026
069b532
update perf
eitanturok Mar 12, 2026
062c1e8
Merge branch 'main' into teacache-refactor
eitanturok Mar 12, 2026
462bd5e
Merge branch 'main' into teacache-refactor
eitanturok Mar 12, 2026
207d2e7
Merge branch 'sgl-project:main' into teacache-refactor
eitanturok Mar 13, 2026
28e38ed
state.step starts at 0
eitanturok Mar 13, 2026
1c23e44
rm comment
eitanturok Mar 13, 2026
667623c
Merge branch 'sgl-project:main' into teacache-refactor
eitanturok Mar 15, 2026
cb5f2b5
Merge branch 'sgl-project:main' into teacache-refactor
eitanturok Mar 16, 2026
4f1f83d
Merge branch 'main' into teacache-refactor
eitanturok Mar 28, 2026
cdabdab
update
eitanturok Mar 28, 2026
1cdd5f4
Merge branch 'main' into teacache-refactor
eitanturok Apr 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/diffusion/performance/cache/teacache.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ TeaCache is configured via `TeaCacheParams` in the sampling parameters:
from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams

params = TeaCacheParams(
teacache_thresh=0.1, # Threshold for accumulated L1 distance
rel_l1_thresh=0.1, # Threshold for accumulated L1 distance
coefficients=[1.0, 0.0, 0.0], # Polynomial coefficients for L1 rescaling
)
```
Expand All @@ -59,7 +59,7 @@ params = TeaCacheParams(

| Parameter | Type | Description |
|-----------|------|-------------|
| `teacache_thresh` | float | Threshold for accumulated L1 distance. Lower = more caching, faster but potentially lower quality |
| `rel_l1_thresh` | float | Threshold for accumulated L1 distance. Lower = more caching, faster but potentially lower quality |
| `coefficients` | list[float] | Polynomial coefficients for L1 rescaling. Model-specific tuning |

### Model-Specific Configurations
Expand All @@ -73,7 +73,7 @@ TeaCache is built into the following model families:
| Model Family | CFG Cache Separation | Notes |
|--------------|---------------------|-------|
| Wan (wan2.1, wan2.2) | Yes | Full support |
| Hunyuan (HunyuanVideo) | Yes | To be supported |
| Hunyuan (HunyuanVideo) | Yes | Full support |
| Z-Image | Yes | To be supported |
| Flux | No | To be supported |
| Qwen | No | To be supported |
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/multimodal_gen/configs/sample/hunyuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class HunyuanSamplingParams(SamplingParams):

teacache_params: TeaCacheParams = field(
default_factory=lambda: TeaCacheParams(
teacache_thresh=0.15,
rel_l1_thresh=0.15,
# from https://github.com/ali-vilab/TeaCache/blob/7c10efc4702c6b619f47805f7abe4a7a08085aa0/TeaCache4HunyuanVideo/teacache_sample_video.py#L222
coefficients=[
7.33226126e02,
Expand Down
13 changes: 9 additions & 4 deletions python/sglang/multimodal_gen/configs/sample/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,8 @@ class SamplingParams:

# TeaCache parameters
enable_teacache: bool = False
teacache_params: Any = (
None # TeaCacheParams or WanTeaCacheParams, set by model-specific subclass
)
cache_params: Any | None = None
calibrate_cache: bool = False

# Profiling
profile: bool = False
Expand Down Expand Up @@ -615,6 +614,12 @@ def add_argument(*name_or_flags, **kwargs):
"--enable-teacache",
action="store_true",
)
parser.add_argument(
"--calibrate-cache",
action="store_true",
default=SamplingParams.calibrate_cache,
help="Run in calibration mode: collect magnitude ratio statistics instead of skipping steps.",
)

# profiling
add_argument(
Expand Down Expand Up @@ -971,4 +976,4 @@ def n_tokens(self) -> int:

@dataclass
class CacheParams:
cache_type: str = "none"
pass
4 changes: 2 additions & 2 deletions python/sglang/multimodal_gen/configs/sample/teacache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class TeaCacheParams(CacheParams):
Attributes:
cache_type: (`str`, defaults to `teacache`):
A string labeling these parameters as belonging to teacache.
teacache_thresh (`float`, defaults to `0.0`):
rel_l1_thresh (`float`, defaults to `0.0`):
Threshold for accumulated relative L1 distance. When below this threshold, the
forward pass is skipped. Recommended values: 0.25 for ~1.5x speedup, 0.4 for ~1.8x,
0.6 for ~2.0x.
Expand Down Expand Up @@ -48,7 +48,7 @@ class TeaCacheParams(CacheParams):
"""

cache_type: str = "teacache"
teacache_thresh: float = 0.0
rel_l1_thresh: float = 0.0
start_skipping: int | float = 5
end_skipping: int | float = -1
coefficients: list[float] = field(default_factory=list)
Expand Down
8 changes: 4 additions & 4 deletions python/sglang/multimodal_gen/configs/sample/wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class WanT2V_1_3B_SamplingParams(SamplingParams):

teacache_params: TeaCacheParams = field(
default_factory=lambda: TeaCacheParams(
teacache_thresh=0.08,
rel_l1_thresh=0.08,
use_ret_steps=True,
coefficients_callback=_wan_1_3b_coefficients,
start_skipping=5,
Expand Down Expand Up @@ -102,7 +102,7 @@ class WanT2V_14B_SamplingParams(SamplingParams):

teacache_params: TeaCacheParams = field(
default_factory=lambda: TeaCacheParams(
teacache_thresh=0.20,
rel_l1_thresh=0.20,
use_ret_steps=False,
coefficients_callback=_wan_14b_coefficients,
start_skipping=1,
Expand All @@ -128,7 +128,7 @@ class WanI2V_14B_480P_SamplingParam(WanT2V_1_3B_SamplingParams):

teacache_params: TeaCacheParams = field(
default_factory=lambda: TeaCacheParams(
teacache_thresh=0.26,
rel_l1_thresh=0.26,
use_ret_steps=True,
coefficients_callback=_wan_14b_coefficients,
start_skipping=5,
Expand Down Expand Up @@ -156,7 +156,7 @@ class WanI2V_14B_720P_SamplingParam(WanT2V_14B_SamplingParams):

teacache_params: TeaCacheParams = field(
default_factory=lambda: TeaCacheParams(
teacache_thresh=0.3,
rel_l1_thresh=0.3,
use_ret_steps=True,
coefficients_callback=_wan_14b_coefficients,
start_skipping=5,
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/multimodal_gen/configs/sample/zimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class ZImageTurboSamplingParams(SamplingParams):

teacache_params: TeaCacheParams = field(
default_factory=lambda: TeaCacheParams(
teacache_thresh=0.15,
rel_l1_thresh=0.15,
coefficients=[
7.33226126e02,
-4.01131952e02,
Expand Down
12 changes: 9 additions & 3 deletions python/sglang/multimodal_gen/runtime/cache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,24 @@

"""

from sglang.multimodal_gen.runtime.cache.base import DiffusionCache
from sglang.multimodal_gen.runtime.cache.cache_dit_integration import (
CacheDitConfig,
enable_cache_on_dual_transformer,
enable_cache_on_transformer,
get_scm_mask,
)
from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheContext, TeaCacheMixin
from sglang.multimodal_gen.runtime.cache.teacache import (
TeaCacheState,
TeaCacheStrategy,
)

__all__ = [
# Base
"DiffusionCache",
# TeaCache (always available)
"TeaCacheContext",
"TeaCacheMixin",
"TeaCacheState",
"TeaCacheStrategy",
# cache-dit integration (lazy-loaded, requires cache-dit package)
"CacheDitConfig",
"enable_cache_on_transformer",
Expand Down
66 changes: 66 additions & 0 deletions python/sglang/multimodal_gen/runtime/cache/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod

import torch


class DiffusionCache(ABC):
    """Abstract interface for timestep-level caching in diffusion models.

    A concrete subclass decides when a denoising step's computation may be
    skipped, stores the outputs of full forward passes, and reconstructs an
    approximation of the output from the stored data on skipped steps.
    """

    @abstractmethod
    def maybe_reset(self, **kwargs) -> None:
        """Reset internal cache state before a new generation sequence.

        Args:
            **kwargs: Strategy-specific parameters that may be helpful.
        """

    @abstractmethod
    def should_skip(self, **kwargs) -> bool:
        """Decide whether the current timestep's computation can be skipped.

        Args:
            **kwargs: Strategy-specific parameters that may be helpful.

        Returns:
            bool: True when the timestep should be skipped, False otherwise.
        """

    @abstractmethod
    def write(
        self,
        hidden_states: torch.Tensor,
        original_hidden_states: torch.Tensor,
        **kwargs,
    ) -> None:
        """Store the result of a full forward pass in the cache.

        Args:
            hidden_states: Output produced by the transformer blocks.
            original_hidden_states: Input as it was before the transformer blocks.
            **kwargs: Strategy-specific parameters that may be helpful.
        """

    @abstractmethod
    def read(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        """Approximate a forward pass from cached data (reads the cache).

        Args:
            hidden_states: Current input/intermediate hidden states.
            **kwargs: Strategy-specific parameters for retrieval.

        Returns:
            torch.Tensor: Approximated output of the skipped forward pass.
        """

    def calibrate(self, **kwargs) -> None:
        """Optional calibration hook to learn cache thresholds or values.

        The default implementation is a no-op; subclasses override as needed.

        Args:
            **kwargs: Strategy-specific parameters that may be helpful.
        """
Loading
Loading