Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
d6a1a27
prints for teacache
eitanturok Feb 8, 2026
e007e8d
AI slop, but I unserstand parts of it
eitanturok Feb 8, 2026
dbb9b92
magcache should_skip_forward is sometimes true
eitanturok Feb 9, 2026
0cff136
faster but wrong
eitanturok Feb 9, 2026
93b0f3a
i think it wokrs
eitanturok Feb 9, 2026
f9323bb
update
eitanturok Feb 16, 2026
12247da
Merge branch 'main' into magcache2
eitanturok Feb 18, 2026
161635d
update
eitanturok Feb 18, 2026
9995c51
update
eitanturok Feb 18, 2026
905617a
update
eitanturok Feb 19, 2026
1b0ed0c
magcachemixin2
eitanturok Feb 19, 2026
c4df4f4
init calibration
eitanturok Feb 19, 2026
219273f
teacache works again
eitanturok Feb 19, 2026
2600608
clean a little
eitanturok Feb 19, 2026
4d858b8
clean up
eitanturok Feb 19, 2026
995105f
update
eitanturok Feb 20, 2026
b461250
magcache_param is sample param, not server arg
eitanturok Feb 20, 2026
ac5f832
teacache is sample param, not server arg
eitanturok Feb 20, 2026
cbb46a5
start adding context
eitanturok Feb 20, 2026
2b6ca62
update
eitanturok Feb 20, 2026
e357f33
cleanup
eitanturok Feb 20, 2026
f104d36
update
eitanturok Feb 20, 2026
5031758
update
eitanturok Feb 20, 2026
ee618b9
cache
eitanturok Feb 20, 2026
a606ae2
Merge branch 'main' into magcache2
eitanturok Feb 22, 2026
7ac16de
fix cache directory for mag cache
eitanturok Feb 22, 2026
c749010
enable calibration from cmd line
eitanturok Feb 22, 2026
551a36f
refactor teacache
eitanturok Feb 23, 2026
2ea1fee
Merge branch 'main' into magcache2
eitanturok Mar 1, 2026
748f862
add cache state
eitanturok Mar 1, 2026
63a0692
cleanup init
eitanturok Mar 1, 2026
f220720
cleanup
eitanturok Mar 1, 2026
b84a771
update
eitanturok Mar 2, 2026
109e960
update
eitanturok Mar 2, 2026
9a74833
change api
eitanturok Mar 4, 2026
2a57813
add calibrate cache to cmd args
eitanturok Mar 4, 2026
f89d726
remove cache
eitanturok Mar 4, 2026
cd2572d
update
eitanturok Mar 4, 2026
79ea570
use hardcoded params by default
eitanturok Mar 4, 2026
267adaa
cleanup
eitanturok Mar 4, 2026
35a3434
clenaup context, typcheck
eitanturok Mar 4, 2026
44c64b4
skip_start_step, skip_end_step
eitanturok Mar 4, 2026
0633ba5
Merge branch 'main' into magcache2
eitanturok Mar 4, 2026
e0ab2c1
Merge branch 'main' into magcache2
eitanturok Mar 4, 2026
d03159c
update
eitanturok Mar 4, 2026
5389755
update
eitanturok Mar 4, 2026
b8b3f8f
Merge branch 'sgl-project:main' into magcache2
eitanturok Mar 5, 2026
12bd5c7
Merge branch 'sgl-project:main' into magcache2
eitanturok Mar 5, 2026
07b2d5f
run precommit
eitanturok Mar 6, 2026
9134dd4
run black-jupyter
eitanturok Mar 6, 2026
68e2823
update
eitanturok Mar 6, 2026
e70b788
Merge branch 'main' into magcache2
eitanturok Mar 9, 2026
9cda094
Merge branch 'main' into magcache2
eitanturok Mar 9, 2026
6d4be99
Merge branch 'main' into magcache2
eitanturok Mar 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions python/sglang/multimodal_gen/configs/sample/magcache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass

from sglang.multimodal_gen.configs.sample.sampling_params import CacheParams


@dataclass
class MagCacheParams(CacheParams):
"""
MagCache configuration for magnitude-ratio-based caching.

MagCache accelerates diffusion inference by skipping forward passes when
magnitude ratios of consecutive residuals are predictably similar.

Attributes:
threshold: Accumulated error threshold (default 0.06 from paper).
Lower = higher quality but slower. Higher = faster but lower quality.
max_skip_steps: Maximum consecutive skips allowed (default 3).
Prevents infinite skipping even if error is low.
skip_start_step: Number of denoising steps at the start where skipping is disabled.
skip_end_step: Number of denoising steps at the end where skipping is disabled (0 = active until last step).
"""

cache_type: str = "magcache"
threshold: float = 0.12
max_skip_steps: int = 4
skip_start_step: int = 10
skip_end_step: int = 0
mag_ratios: list[float] | None = None
39 changes: 38 additions & 1 deletion python/sglang/multimodal_gen/configs/sample/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
logger = init_logger(__name__)

if TYPE_CHECKING:
from sglang.multimodal_gen.configs.sample.magcache import MagCacheParams
from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams
from sglang.multimodal_gen.runtime.server_args import ServerArgs


Expand Down Expand Up @@ -154,8 +156,12 @@ class SamplingParams:
cfg_normalization: float | bool = 0.0
boundary_ratio: float | None = None

# TeaCache parameters
# Cache acceleration
enable_teacache: bool = False
teacache_params: "TeaCacheParams | None" = None
enable_magcache: bool = False
magcache_params: "MagCacheParams | None" = None
calibrate_cache: bool = False

# Profiling
profile: bool = False
Expand Down Expand Up @@ -601,6 +607,37 @@ def add_cli_args(parser: Any) -> Any:
"--enable-teacache",
action="store_true",
default=SamplingParams.enable_teacache,
help="Enable TeaCache acceleration for diffusion inference.",
)
parser.add_argument(
"--teacache-params",
type=json.loads,
default=None,
help=(
'TeaCache params as a JSON object, e.g. \'{"teacache_thresh": 0.08, "coefficients": [1.0, 2.0]}\'. '
"Fields map directly to TeaCacheParams dataclass fields."
),
)
parser.add_argument(
"--enable-magcache",
action="store_true",
default=SamplingParams.enable_magcache,
help="Enable MagCache acceleration for diffusion inference.",
)
parser.add_argument(
"--magcache-params",
type=json.loads,
default=None,
help=(
'MagCache params as a JSON object, e.g. \'{"threshold": 0.12, "max_skip_steps": 4}\'. '
"Fields map directly to MagCacheParams dataclass fields."
),
)
parser.add_argument(
"--calibrate-cache",
action="store_true",
default=SamplingParams.calibrate_cache,
help="Run in calibration mode: collect magnitude ratio statistics instead of skipping steps.",
)

# profiling
Expand Down
42 changes: 25 additions & 17 deletions python/sglang/multimodal_gen/configs/sample/teacache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,42 @@
class TeaCacheParams(CacheParams):
cache_type: str = "teacache"
teacache_thresh: float = 0.0
skip_start_step: int = 5
skip_end_step: int = 0
coefficients: list[float] = field(default_factory=list)


@dataclass
class WanTeaCacheParams(CacheParams):
# Unfortunately, TeaCache is very different for Wan than other models
# Default threshold and coefficients are for Wan T2V 1.3B (use_ret_steps=True).
# For other Wan variants, override these values via --teacache-params.
cache_type: str = "teacache"
teacache_thresh: float = 0.0
teacache_thresh: float = 0.08
skip_start_step: int = 5
skip_end_step: int = 0
use_ret_steps: bool = True
ret_steps_coeffs: list[float] = field(default_factory=list)
non_ret_steps_coeffs: list[float] = field(default_factory=list)
ret_steps_coeffs: list[float] = field(
default_factory=lambda: [
-5.21862437e04,
9.23041404e03,
-5.28275948e02,
1.36987616e01,
-4.99875664e-02,
]
)
non_ret_steps_coeffs: list[float] = field(
default_factory=lambda: [
2.39676752e03,
-1.31110545e03,
2.01331979e02,
-8.29855975e00,
1.37887774e-01,
]
)

@property
def coefficients(self) -> list[float]:
if self.use_ret_steps:
return self.ret_steps_coeffs
else:
return self.non_ret_steps_coeffs

@property
def ret_steps(self) -> int:
if self.use_ret_steps:
return 5 * 2
else:
return 1 * 2

def get_cutoff_steps(self, num_inference_steps: int) -> int:
if self.use_ret_steps:
return num_inference_steps * 2
else:
return num_inference_steps * 2 - 2
116 changes: 116 additions & 0 deletions python/sglang/multimodal_gen/configs/sample/wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,115 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field

from sglang.multimodal_gen.configs.sample.magcache import MagCacheParams
from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams
from sglang.multimodal_gen.configs.sample.teacache import WanTeaCacheParams

# Magnitude ratio arrays from the reference implementation:
# https://github.com/Zehong-Ma/MagCache/blob/df81cb181776c2c61477c08e1d21f87fda1cd938/MagCache4Wan2.1/magcache_generate.py
T2V_13B_MAG_RATIOS = [
1.0,
1.0,
1.0124,
1.02213,
1.00166,
1.0041,
0.99791,
1.00061,
0.99682,
0.99762,
0.99634,
0.99685,
0.99567,
0.99586,
0.99416,
0.99422,
0.99578,
0.99575,
0.9957,
0.99563,
0.99511,
0.99506,
0.99535,
0.99531,
0.99552,
0.99549,
0.99541,
0.99539,
0.9954,
0.99536,
0.99489,
0.99485,
0.99518,
0.99514,
0.99484,
0.99478,
0.99481,
0.99479,
0.99415,
0.99413,
0.99419,
0.99416,
0.99396,
0.99393,
0.99388,
0.99386,
0.99349,
0.99349,
0.99309,
0.99304,
0.9927,
0.9927,
0.99228,
0.99226,
0.99171,
0.9917,
0.99137,
0.99135,
0.99068,
0.99063,
0.99005,
0.99003,
0.98944,
0.98942,
0.98849,
0.98849,
0.98758,
0.98757,
0.98644,
0.98643,
0.98504,
0.98503,
0.9836,
0.98359,
0.98202,
0.98201,
0.97977,
0.97978,
0.97717,
0.97718,
0.9741,
0.97411,
0.97003,
0.97002,
0.96538,
0.96541,
0.9593,
0.95933,
0.95086,
0.95089,
0.94013,
0.94019,
0.92402,
0.92414,
0.90241,
0.9026,
0.86821,
0.86868,
0.81838,
0.81939,
]


@dataclass
class WanT2V_1_3B_SamplingParams(SamplingParams):
Expand Down Expand Up @@ -50,6 +156,16 @@ class WanT2V_1_3B_SamplingParams(SamplingParams):
)
)

magcache_params: MagCacheParams = field(
default_factory=lambda: MagCacheParams(
threshold=0.12,
max_skip_steps=4,
skip_start_step=10,
skip_end_step=0,
mag_ratios=T2V_13B_MAG_RATIOS,
)
)


@dataclass
class WanT2V_14B_SamplingParams(SamplingParams):
Expand Down
24 changes: 21 additions & 3 deletions python/sglang/multimodal_gen/runtime/cache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,40 @@
diffusion transformer (DiT) inference:

- TeaCache: Temporal similarity-based caching for diffusion models
- MagCache: Magnitude-ratio-based caching for diffusion models
- cache-dit integration: Block-level caching with DBCache and TaylorSeer

"""

from sglang.multimodal_gen.runtime.cache.base import DiffusionCache
from sglang.multimodal_gen.runtime.cache.cache_dit_integration import (
CacheDitConfig,
enable_cache_on_dual_transformer,
enable_cache_on_transformer,
get_scm_mask,
)
from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheContext, TeaCacheMixin
from sglang.multimodal_gen.runtime.cache.magcache import (
MagCacheContext,
MagCacheState,
MagCacheStrategy,
)
from sglang.multimodal_gen.runtime.cache.teacache import (
TeaCacheContext,
TeaCacheState,
TeaCacheStrategy,
)

__all__ = [
# TeaCache (always available)
# Base
"DiffusionCache",
# TeaCache
"TeaCacheContext",
"TeaCacheMixin",
"TeaCacheState",
"TeaCacheStrategy",
# MagCache
"MagCacheContext",
"MagCacheState",
"MagCacheStrategy",
# cache-dit integration (lazy-loaded, requires cache-dit package)
"CacheDitConfig",
"enable_cache_on_transformer",
Expand Down
Loading
Loading