Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/sglang/multimodal_gen/configs/sample/hunyuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class HunyuanSamplingParams(SamplingParams):
teacache_params: TeaCacheParams = field(
default_factory=lambda: TeaCacheParams(
teacache_thresh=0.15,
# from https://github.com/ali-vilab/TeaCache/blob/7c10efc4702c6b619f47805f7abe4a7a08085aa0/TeaCache4HunyuanVideo/teacache_sample_video.py#L222
coefficients=[
7.33226126e02,
-4.01131952e02,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,16 @@ def _json_safe(obj: Any):
"""
Recursively convert objects to JSON-serializable forms.
- Enums -> their name
- Callables -> stable module-qualified name
- Sets/Tuples -> lists
- Dicts/Lists -> recursively processed
"""
if isinstance(obj, Enum):
return obj.name
if callable(obj):
module = getattr(obj, "__module__", None)
qualname = getattr(obj, "__qualname__", getattr(obj, "__name__", repr(obj)))
return f"{module}.{qualname}" if module else qualname
if isinstance(obj, dict):
return {k: _json_safe(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple, set)):
Expand Down
95 changes: 67 additions & 28 deletions python/sglang/multimodal_gen/configs/sample/teacache.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,82 @@
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo

# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable

from sglang.multimodal_gen.configs.sample.sampling_params import CacheParams


@dataclass
class TeaCacheParams(CacheParams):
"""
Parameters for [TeaCache](https://arxiv.org/abs/2411.14324).

Attributes:
cache_type: (`str`, defaults to `teacache`):
A string labeling these parameters as belonging to teacache.
teacache_thresh (`float`, defaults to `0.0`):
Threshold for accumulated relative L1 distance. When below this threshold, the
forward pass is skipped. Recommended values: 0.25 for ~1.5x speedup, 0.4 for ~1.8x,
0.6 for ~2.0x.
start_skipping (`int` or `float`, defaults to `5`):
The number of timesteps after which we may skip a forward pass. These early
steps define the global structure and are too critical to skip.
int: The number of timesteps after which we can skip. If negative,
this is an offset from the end of the schedule.
float (0.0 - 1.0): A percentage of the total steps (e.g., 0.1
computes the first 10%).
end_skipping (`int` or `float`, defaults to `-1`):
The number of timesteps after which we are no longer able to skip
forward passes. The last steps refine fine textures and details.
int: The number of timesteps after which skipping ends. If negative,
this is an offset from the total number of steps.
float (0.0 - 1.0): A fraction of the total steps (e.g., 0.9
stops skipping for the last 10% of steps).
coefficients (`List[float]`, defaults to `[]`):
Polynomial coefficients for rescaling the raw relative L1 distance,
evaluated as `c[0]*x**4 + c[1]*x**3 + c[2]*x**2 + c[3]*x + c[4]`.
coefficients_callback (`Callable[[TeaCacheParams], List[float]]`, *optional*):
A function that receives this `TeaCacheParams` instance and returns
the polynomial coefficients to use. When set, it takes precedence over
the `coefficients` field, allowing dynamic coefficient selection based
on any property of the params (e.g., `use_ret_steps` for Wan models).
use_ret_steps: (`bool`, `None`, defaults to `None`):
Used exclusively for wanvideo models to select different modulated inputs.
"""

cache_type: str = "teacache"
teacache_thresh: float = 0.0
start_skipping: int | float = 5
end_skipping: int | float = -1
coefficients: list[float] = field(default_factory=list)
coefficients_callback: Callable[[TeaCacheParams], list[float]] | None = field(
default=None, repr=False
)
use_ret_steps: bool | None = None

def get_coefficients(self) -> list[float]:
    """Return the polynomial coefficients to use for rescaling.

    When ``coefficients_callback`` is set it is called with this instance
    and its result is returned; otherwise the static ``coefficients`` field
    is returned as-is.
    """
    resolver = self.coefficients_callback
    return self.coefficients if resolver is None else resolver(self)

@dataclass
class WanTeaCacheParams(CacheParams):
    """TeaCache parameters for Wan models.

    Wan keeps two coefficient sets and picks between them, and between two
    sets of step boundaries, based on ``use_ret_steps``.
    """

    # Unfortunately, TeaCache is very different for Wan than other models
    cache_type: str = "teacache"
    teacache_thresh: float = 0.0
    use_ret_steps: bool = True
    ret_steps_coeffs: list[float] = field(default_factory=list)
    non_ret_steps_coeffs: list[float] = field(default_factory=list)

    @property
    def coefficients(self) -> list[float]:
        """Coefficient list matching the current ``use_ret_steps`` setting."""
        return (
            self.ret_steps_coeffs
            if self.use_ret_steps
            else self.non_ret_steps_coeffs
        )

    @property
    def ret_steps(self) -> int:
        """Counter value below which steps are treated as boundary steps."""
        # 5 steps with use_ret_steps, otherwise 1; both carry a factor of 2.
        return 5 * 2 if self.use_ret_steps else 1 * 2

    def get_cutoff_steps(self, num_inference_steps: int) -> int:
        """Counter value at and above which steps are treated as boundary steps."""
        doubled = num_inference_steps * 2
        return doubled if self.use_ret_steps else doubled - 2
def get_skip_boundaries(
    self, num_inference_steps: int, do_cfg: bool
) -> tuple[int, int]:
    """Resolve ``start_skipping`` / ``end_skipping`` into absolute values.

    Each setting is resolved independently:
      * a float is treated as a fraction of ``num_inference_steps`` and
        truncated to an int;
      * a negative int is an offset from ``num_inference_steps``;
      * any other int is used unchanged.

    When ``do_cfg`` is true both resolved values are doubled.
    NOTE(review): presumably because the caller's step counter advances
    twice per inference step under CFG -- confirm against the caller.

    Returns:
        A ``(start_skipping, end_skipping)`` tuple.
    """
    resolved = []
    for raw in (self.start_skipping, self.end_skipping):
        if isinstance(raw, float):
            boundary = int(num_inference_steps * raw)
        elif raw < 0:
            boundary = num_inference_steps + raw
        else:
            boundary = raw
        resolved.append(boundary)

    first, last = resolved
    if do_cfg:
        first, last = first * 2, last * 2
    return first, last
123 changes: 58 additions & 65 deletions python/sglang/multimodal_gen/configs/sample/wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,41 @@
from dataclasses import dataclass, field

from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams
from sglang.multimodal_gen.configs.sample.teacache import WanTeaCacheParams
from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams


def _wan_1_3b_coefficients(p: TeaCacheParams) -> list[float]:
    """Return the TeaCache polynomial coefficients for Wan 1.3B.

    The set is chosen by the truthiness of ``p.use_ret_steps``. A new list
    is built on every call.
    """
    if not p.use_ret_steps:
        # from https://github.com/ali-vilab/TeaCache/blob/7c10efc4702c6b619f47805f7abe4a7a08085aa0/TeaCache4Wan2.1/teacache_generate.py#L890
        return [
            2.39676752e03,
            -1.31110545e03,
            2.01331979e02,
            -8.29855975e00,
            1.37887774e-01,
        ]

    # from https://github.com/ali-vilab/TeaCache/blob/7c10efc4702c6b619f47805f7abe4a7a08085aa0/TeaCache4Wan2.1/teacache_generate.py#L883
    return [
        -5.21862437e04,
        9.23041404e03,
        -5.28275948e02,
        1.36987616e01,
        -4.99875664e-02,
    ]


def _wan_14b_coefficients(p: TeaCacheParams) -> list[float]:
    """Return the TeaCache polynomial coefficients for Wan 14B.

    The set is chosen by the truthiness of ``p.use_ret_steps``. A new list
    is built on every call.
    """
    if not p.use_ret_steps:
        # from https://github.com/ali-vilab/TeaCache/blob/7c10efc4702c6b619f47805f7abe4a7a08085aa0/TeaCache4Wan2.1/teacache_generate.py#L892
        return [
            -5784.54975374,
            5449.50911966,
            -1811.16591783,
            256.27178429,
            -13.02252404,
        ]

    # from https://github.com/ali-vilab/TeaCache/blob/7c10efc4702c6b619f47805f7abe4a7a08085aa0/TeaCache4Wan2.1/teacache_generate.py#L885
    return [
        -3.03318725e05,
        4.90537029e04,
        -2.65530556e03,
        5.87365115e01,
        -3.15583525e-01,
    ]


@dataclass
Expand All @@ -30,23 +64,13 @@ class WanT2V_1_3B_SamplingParams(SamplingParams):
]
)

teacache_params: WanTeaCacheParams = field(
default_factory=lambda: WanTeaCacheParams(
teacache_params: TeaCacheParams = field(
default_factory=lambda: TeaCacheParams(
teacache_thresh=0.08,
ret_steps_coeffs=[
-5.21862437e04,
9.23041404e03,
-5.28275948e02,
1.36987616e01,
-4.99875664e-02,
],
non_ret_steps_coeffs=[
2.39676752e03,
-1.31110545e03,
2.01331979e02,
-8.29855975e00,
1.37887774e-01,
],
use_ret_steps=True,
coefficients_callback=_wan_1_3b_coefficients,
start_skipping=5,
end_skipping=1.0,
)
)

Expand Down Expand Up @@ -76,24 +100,13 @@ class WanT2V_14B_SamplingParams(SamplingParams):
]
)

teacache_params: WanTeaCacheParams = field(
default_factory=lambda: WanTeaCacheParams(
teacache_params: TeaCacheParams = field(
default_factory=lambda: TeaCacheParams(
teacache_thresh=0.20,
use_ret_steps=False,
ret_steps_coeffs=[
-3.03318725e05,
4.90537029e04,
-2.65530556e03,
5.87365115e01,
-3.15583525e-01,
],
non_ret_steps_coeffs=[
-5784.54975374,
5449.50911966,
-1811.16591783,
256.27178429,
-13.02252404,
],
coefficients_callback=_wan_14b_coefficients,
start_skipping=1,
end_skipping=-1,
)
)

Expand All @@ -113,23 +126,13 @@ class WanI2V_14B_480P_SamplingParam(WanT2V_1_3B_SamplingParams):
]
)

teacache_params: WanTeaCacheParams = field(
default_factory=lambda: WanTeaCacheParams(
teacache_params: TeaCacheParams = field(
default_factory=lambda: TeaCacheParams(
teacache_thresh=0.26,
ret_steps_coeffs=[
-3.03318725e05,
4.90537029e04,
-2.65530556e03,
5.87365115e01,
-3.15583525e-01,
],
non_ret_steps_coeffs=[
-5784.54975374,
5449.50911966,
-1811.16591783,
256.27178429,
-13.02252404,
],
use_ret_steps=True,
coefficients_callback=_wan_14b_coefficients,
start_skipping=5,
end_skipping=1.0,
)
)

Expand All @@ -151,23 +154,13 @@ class WanI2V_14B_720P_SamplingParam(WanT2V_14B_SamplingParams):
]
)

teacache_params: WanTeaCacheParams = field(
default_factory=lambda: WanTeaCacheParams(
teacache_params: TeaCacheParams = field(
default_factory=lambda: TeaCacheParams(
teacache_thresh=0.3,
ret_steps_coeffs=[
-3.03318725e05,
4.90537029e04,
-2.65530556e03,
5.87365115e01,
-3.15583525e-01,
],
non_ret_steps_coeffs=[
-5784.54975374,
5449.50911966,
-1811.16591783,
256.27178429,
-13.02252404,
],
use_ret_steps=True,
coefficients_callback=_wan_14b_coefficients,
start_skipping=5,
end_skipping=1.0,
)
)

Expand Down
2 changes: 1 addition & 1 deletion python/sglang/multimodal_gen/runtime/cache/teacache.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def _get_teacache_context(self) -> TeaCacheContext | None:
do_cfg=do_cfg,
is_cfg_negative=is_cfg_negative,
teacache_thresh=teacache_params.teacache_thresh,
coefficients=teacache_params.coefficients,
coefficients=teacache_params.get_coefficients(),
teacache_params=teacache_params,
)

Expand Down
23 changes: 6 additions & 17 deletions python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import torch.nn as nn

from sglang.multimodal_gen.configs.models.dits import WanVideoConfig
from sglang.multimodal_gen.configs.sample.wan import WanTeaCacheParams
from sglang.multimodal_gen.runtime.distributed import (
divide,
get_sp_group,
Expand Down Expand Up @@ -1170,32 +1169,22 @@ def should_skip_forward_for_cached_states(self, **kwargs) -> bool:
if ctx is None:
return False

# Wan uses WanTeaCacheParams with additional fields
teacache_params = ctx.teacache_params
assert isinstance(
teacache_params, WanTeaCacheParams
), "teacache_params is not a WanTeaCacheParams"

# Initialize Wan-specific parameters
teacache_params = ctx.teacache_params
use_ret_steps = teacache_params.use_ret_steps
cutoff_steps = teacache_params.get_cutoff_steps(ctx.num_inference_steps)
ret_steps = teacache_params.ret_steps
start_skipping, end_skipping = teacache_params.get_skip_boundaries(
ctx.num_inference_steps, ctx.do_cfg
)

# Adjust ret_steps and cutoff_steps for non-CFG mode
# (WanTeaCacheParams uses *2 factor assuming CFG)
if not ctx.do_cfg:
ret_steps = ret_steps // 2
cutoff_steps = cutoff_steps // 2
# Determine boundary step
is_boundary_step = self.cnt < start_skipping or self.cnt >= end_skipping

timestep_proj = kwargs["timestep_proj"]
temb = kwargs["temb"]
modulated_inp = timestep_proj if use_ret_steps else temb

self.is_cfg_negative = ctx.is_cfg_negative

# Wan uses ret_steps/cutoff_steps for boundary detection
is_boundary_step = self.cnt < ret_steps or self.cnt >= cutoff_steps

# Use shared helper to compute cache decision
should_calc = self._compute_teacache_decision(
modulated_inp=modulated_inp,
Expand Down
Loading
Loading