diff --git a/python/sglang/multimodal_gen/configs/sample/hunyuan.py b/python/sglang/multimodal_gen/configs/sample/hunyuan.py index ae69dbd62ccd..c60b856630f0 100644 --- a/python/sglang/multimodal_gen/configs/sample/hunyuan.py +++ b/python/sglang/multimodal_gen/configs/sample/hunyuan.py @@ -39,6 +39,7 @@ class HunyuanSamplingParams(SamplingParams): teacache_params: TeaCacheParams = field( default_factory=lambda: TeaCacheParams( teacache_thresh=0.15, + # from https://github.com/ali-vilab/TeaCache/blob/7c10efc4702c6b619f47805f7abe4a7a08085aa0/TeaCache4HunyuanVideo/teacache_sample_video.py#L222 coefficients=[ 7.33226126e02, -4.01131952e02, diff --git a/python/sglang/multimodal_gen/configs/sample/sampling_params.py b/python/sglang/multimodal_gen/configs/sample/sampling_params.py index 7dcf9bf1dc88..f4d353874b44 100644 --- a/python/sglang/multimodal_gen/configs/sample/sampling_params.py +++ b/python/sglang/multimodal_gen/configs/sample/sampling_params.py @@ -29,11 +29,16 @@ def _json_safe(obj: Any): """ Recursively convert objects to JSON-serializable forms. 
- Enums -> their name + - Callables -> stable module-qualified name - Sets/Tuples -> lists - Dicts/Lists -> recursively processed """ if isinstance(obj, Enum): return obj.name + if callable(obj): + module = getattr(obj, "__module__", None) + qualname = getattr(obj, "__qualname__", getattr(obj, "__name__", repr(obj))) + return f"{module}.{qualname}" if module else qualname if isinstance(obj, dict): return {k: _json_safe(v) for k, v in obj.items()} if isinstance(obj, (list, tuple, set)): diff --git a/python/sglang/multimodal_gen/configs/sample/teacache.py b/python/sglang/multimodal_gen/configs/sample/teacache.py index ada71d0b3618..e20df6e67b93 100644 --- a/python/sglang/multimodal_gen/configs/sample/teacache.py +++ b/python/sglang/multimodal_gen/configs/sample/teacache.py @@ -1,43 +1,82 @@ # Copied and adapted from: https://github.com/hao-ai-lab/FastVideo # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from dataclasses import dataclass, field +from typing import Callable from sglang.multimodal_gen.configs.sample.sampling_params import CacheParams @dataclass class TeaCacheParams(CacheParams): + """ + Parameters for [TeaCache](https://arxiv.org/abs/2411.14324). + + Attributes: + cache_type (`str`, defaults to `teacache`): + A string labeling these parameters as belonging to teacache. + teacache_thresh (`float`, defaults to `0.0`): + Threshold for accumulated relative L1 distance. When below this threshold, the + forward pass is skipped. Recommended values: 0.25 for ~1.5x speedup, 0.4 for ~1.8x, + 0.6 for ~2.0x. + start_skipping (`int` or `float`, defaults to `5`): + The number of timesteps after which we may skip a forward pass. These early + steps define the global structure and are too critical to skip. + int: The number of timesteps after which we can skip. If negative, + this is an offset from the end of the schedule. + float (0.0 - 1.0): A percentage of the total steps (e.g., 0.1 + computes the first 10%). 
+ end_skipping (`int` or `float`, defaults to `-1`): + The number of timesteps after which we are no longer able to skip + forward passes. The last steps refine fine textures and details. + int: The number of timesteps after which skipping ends. If negative, + this is an offset from the total number of steps. + float (0.0 - 1.0): A percentage of the total steps (e.g., 0.9 + ends skipping at 90% of the total steps). + coefficients (`List[float]`, defaults to `[]`): + Polynomial coefficients for rescaling the raw relative L1 distance, + evaluated as `c[0]*x**4 + c[1]*x**3 + c[2]*x**2 + c[3]*x + c[4]`. + coefficients_callback (`Callable[[TeaCacheParams], List[float]]`, *optional*): + A function that receives this `TeaCacheParams` instance and returns + the polynomial coefficients to use. When set, it takes precedence over + the `coefficients` field, allowing dynamic coefficient selection based + on any property of the params (e.g., `use_ret_steps` for Wan models). + use_ret_steps (`bool`, `None`, defaults to `None`): + Used exclusively for wanvideo models to select different modulated inputs. 
+ """ + cache_type: str = "teacache" teacache_thresh: float = 0.0 + start_skipping: int | float = 5 + end_skipping: int | float = -1 coefficients: list[float] = field(default_factory=list) + coefficients_callback: Callable[[TeaCacheParams], list[float]] | None = field( + default=None, repr=False + ) + use_ret_steps: bool | None = None + def get_coefficients(self) -> list[float]: + if self.coefficients_callback is not None: + return self.coefficients_callback(self) + return self.coefficients -@dataclass -class WanTeaCacheParams(CacheParams): - # Unfortunately, TeaCache is very different for Wan than other models - cache_type: str = "teacache" - teacache_thresh: float = 0.0 - use_ret_steps: bool = True - ret_steps_coeffs: list[float] = field(default_factory=list) - non_ret_steps_coeffs: list[float] = field(default_factory=list) - - @property - def coefficients(self) -> list[float]: - if self.use_ret_steps: - return self.ret_steps_coeffs - else: - return self.non_ret_steps_coeffs - - @property - def ret_steps(self) -> int: - if self.use_ret_steps: - return 5 * 2 - else: - return 1 * 2 - - def get_cutoff_steps(self, num_inference_steps: int) -> int: - if self.use_ret_steps: - return num_inference_steps * 2 - else: - return num_inference_steps * 2 - 2 + def get_skip_boundaries( + self, num_inference_steps: int, do_cfg: bool + ) -> tuple[int, int]: + def _resolve_boundary(value: int | float) -> int: + if isinstance(value, float): + return int(num_inference_steps * value) + if value < 0: + return num_inference_steps + value + return value + + start_skipping = _resolve_boundary(self.start_skipping) + end_skipping = _resolve_boundary(self.end_skipping) + + if do_cfg: + start_skipping *= 2 + end_skipping *= 2 + + return start_skipping, end_skipping diff --git a/python/sglang/multimodal_gen/configs/sample/wan.py b/python/sglang/multimodal_gen/configs/sample/wan.py index a5faf50214f0..0f147a9dcede 100644 --- a/python/sglang/multimodal_gen/configs/sample/wan.py +++ 
b/python/sglang/multimodal_gen/configs/sample/wan.py @@ -4,7 +4,41 @@ from dataclasses import dataclass, field from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams -from sglang.multimodal_gen.configs.sample.teacache import WanTeaCacheParams +from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams + + +def _wan_1_3b_coefficients(p: TeaCacheParams) -> list[float]: + if p.use_ret_steps: + # from https://github.com/ali-vilab/TeaCache/blob/7c10efc4702c6b619f47805f7abe4a7a08085aa0/TeaCache4Wan2.1/teacache_generate.py#L883 + return [ + -5.21862437e04, + 9.23041404e03, + -5.28275948e02, + 1.36987616e01, + -4.99875664e-02, + ] + # from https://github.com/ali-vilab/TeaCache/blob/7c10efc4702c6b619f47805f7abe4a7a08085aa0/TeaCache4Wan2.1/teacache_generate.py#L890 + return [ + 2.39676752e03, + -1.31110545e03, + 2.01331979e02, + -8.29855975e00, + 1.37887774e-01, + ] + + +def _wan_14b_coefficients(p: TeaCacheParams) -> list[float]: + if p.use_ret_steps: + # from https://github.com/ali-vilab/TeaCache/blob/7c10efc4702c6b619f47805f7abe4a7a08085aa0/TeaCache4Wan2.1/teacache_generate.py#L885 + return [ + -3.03318725e05, + 4.90537029e04, + -2.65530556e03, + 5.87365115e01, + -3.15583525e-01, + ] + # from https://github.com/ali-vilab/TeaCache/blob/7c10efc4702c6b619f47805f7abe4a7a08085aa0/TeaCache4Wan2.1/teacache_generate.py#L892 + return [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404] @dataclass @@ -30,23 +64,13 @@ class WanT2V_1_3B_SamplingParams(SamplingParams): ] ) - teacache_params: WanTeaCacheParams = field( - default_factory=lambda: WanTeaCacheParams( + teacache_params: TeaCacheParams = field( + default_factory=lambda: TeaCacheParams( teacache_thresh=0.08, - ret_steps_coeffs=[ - -5.21862437e04, - 9.23041404e03, - -5.28275948e02, - 1.36987616e01, - -4.99875664e-02, - ], - non_ret_steps_coeffs=[ - 2.39676752e03, - -1.31110545e03, - 2.01331979e02, - -8.29855975e00, - 1.37887774e-01, - ], + use_ret_steps=True, + 
coefficients_callback=_wan_1_3b_coefficients, + start_skipping=5, + end_skipping=1.0, ) ) @@ -76,24 +100,13 @@ class WanT2V_14B_SamplingParams(SamplingParams): ] ) - teacache_params: WanTeaCacheParams = field( - default_factory=lambda: WanTeaCacheParams( + teacache_params: TeaCacheParams = field( + default_factory=lambda: TeaCacheParams( teacache_thresh=0.20, use_ret_steps=False, - ret_steps_coeffs=[ - -3.03318725e05, - 4.90537029e04, - -2.65530556e03, - 5.87365115e01, - -3.15583525e-01, - ], - non_ret_steps_coeffs=[ - -5784.54975374, - 5449.50911966, - -1811.16591783, - 256.27178429, - -13.02252404, - ], + coefficients_callback=_wan_14b_coefficients, + start_skipping=1, + end_skipping=-1, ) ) @@ -113,23 +126,13 @@ class WanI2V_14B_480P_SamplingParam(WanT2V_1_3B_SamplingParams): ] ) - teacache_params: WanTeaCacheParams = field( - default_factory=lambda: WanTeaCacheParams( + teacache_params: TeaCacheParams = field( + default_factory=lambda: TeaCacheParams( teacache_thresh=0.26, - ret_steps_coeffs=[ - -3.03318725e05, - 4.90537029e04, - -2.65530556e03, - 5.87365115e01, - -3.15583525e-01, - ], - non_ret_steps_coeffs=[ - -5784.54975374, - 5449.50911966, - -1811.16591783, - 256.27178429, - -13.02252404, - ], + use_ret_steps=True, + coefficients_callback=_wan_14b_coefficients, + start_skipping=5, + end_skipping=1.0, ) ) @@ -151,23 +154,13 @@ class WanI2V_14B_720P_SamplingParam(WanT2V_14B_SamplingParams): ] ) - teacache_params: WanTeaCacheParams = field( - default_factory=lambda: WanTeaCacheParams( + teacache_params: TeaCacheParams = field( + default_factory=lambda: TeaCacheParams( teacache_thresh=0.3, - ret_steps_coeffs=[ - -3.03318725e05, - 4.90537029e04, - -2.65530556e03, - 5.87365115e01, - -3.15583525e-01, - ], - non_ret_steps_coeffs=[ - -5784.54975374, - 5449.50911966, - -1811.16591783, - 256.27178429, - -13.02252404, - ], + use_ret_steps=True, + coefficients_callback=_wan_14b_coefficients, + start_skipping=5, + end_skipping=1.0, ) ) diff --git 
a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index 5cdafd08bc04..8830f7ec20c4 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -297,7 +297,7 @@ def _get_teacache_context(self) -> TeaCacheContext | None: do_cfg=do_cfg, is_cfg_negative=is_cfg_negative, teacache_thresh=teacache_params.teacache_thresh, - coefficients=teacache_params.coefficients, + coefficients=teacache_params.get_coefficients(), teacache_params=teacache_params, ) diff --git a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py index 4a2798a4a934..b193bf808324 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -10,7 +10,6 @@ import torch.nn as nn from sglang.multimodal_gen.configs.models.dits import WanVideoConfig -from sglang.multimodal_gen.configs.sample.wan import WanTeaCacheParams from sglang.multimodal_gen.runtime.distributed import ( divide, get_sp_group, @@ -1170,22 +1169,15 @@ def should_skip_forward_for_cached_states(self, **kwargs) -> bool: if ctx is None: return False - # Wan uses WanTeaCacheParams with additional fields - teacache_params = ctx.teacache_params - assert isinstance( - teacache_params, WanTeaCacheParams - ), "teacache_params is not a WanTeaCacheParams" - # Initialize Wan-specific parameters + teacache_params = ctx.teacache_params use_ret_steps = teacache_params.use_ret_steps - cutoff_steps = teacache_params.get_cutoff_steps(ctx.num_inference_steps) - ret_steps = teacache_params.ret_steps + start_skipping, end_skipping = teacache_params.get_skip_boundaries( + ctx.num_inference_steps, ctx.do_cfg + ) - # Adjust ret_steps and cutoff_steps for non-CFG mode - # (WanTeaCacheParams uses *2 factor assuming CFG) - if not ctx.do_cfg: - ret_steps = ret_steps // 2 - cutoff_steps 
= cutoff_steps // 2 + # Determine boundary step + is_boundary_step = self.cnt < start_skipping or self.cnt >= end_skipping timestep_proj = kwargs["timestep_proj"] temb = kwargs["temb"] @@ -1193,9 +1185,6 @@ def should_skip_forward_for_cached_states(self, **kwargs) -> bool: self.is_cfg_negative = ctx.is_cfg_negative - # Wan uses ret_steps/cutoff_steps for boundary detection - is_boundary_step = self.cnt < ret_steps or self.cnt >= cutoff_steps - # Use shared helper to compute cache decision should_calc = self._compute_teacache_decision( modulated_inp=modulated_inp, diff --git a/python/sglang/multimodal_gen/test/unit/test_sampling_params.py b/python/sglang/multimodal_gen/test/unit/test_sampling_params.py index 722d4d12ab26..4ac66ec175e1 100644 --- a/python/sglang/multimodal_gen/test/unit/test_sampling_params.py +++ b/python/sglang/multimodal_gen/test/unit/test_sampling_params.py @@ -7,7 +7,17 @@ ) from sglang.multimodal_gen.configs.sample.flux import FluxSamplingParams from sglang.multimodal_gen.configs.sample.qwenimage import QwenImageSamplingParams -from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams +from sglang.multimodal_gen.configs.sample.sampling_params import ( + SamplingParams, + _json_safe, +) +from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams +from sglang.multimodal_gen.configs.sample.wan import ( + WanI2V_14B_480P_SamplingParam, + WanI2V_14B_720P_SamplingParam, + WanT2V_1_3B_SamplingParams, + WanT2V_14B_SamplingParams, +) class TestSamplingParamsValidate(unittest.TestCase): @@ -68,6 +78,59 @@ def test_diffusers_generic_calls_base_post_init(self): with self.assertRaises(AssertionError): DiffusersGenericSamplingParams(num_frames=0) + def test_output_file_name_supports_callable_teacache_params(self): + def coefficients_callback(_: TeaCacheParams) -> list[float]: + return [1.0, 2.0, 3.0, 4.0, 5.0] + + params = SamplingParams( + prompt="callable teacache", + teacache_params=TeaCacheParams( + 
coefficients_callback=coefficients_callback, + ), + ) + + params._set_output_file_name() + + self.assertTrue(params.output_file_name.endswith(".mp4")) + self.assertIn( + "test_sampling_params.TestSamplingParamsSubclass.test_output_file_name_supports_callable_teacache_params", + _json_safe(coefficients_callback), + ) + + def test_teacache_callback_takes_precedence_over_static_coefficients(self): + def coefficients_callback(_: TeaCacheParams) -> list[float]: + return [9.0, 8.0, 7.0, 6.0, 5.0] + + params = TeaCacheParams( + coefficients=[1.0, 2.0, 3.0, 4.0, 5.0], + coefficients_callback=coefficients_callback, + ) + + self.assertEqual(params.get_coefficients(), [9.0, 8.0, 7.0, 6.0, 5.0]) + + def test_wan_teacache_boundaries_match_legacy_behavior(self): + legacy_equivalent_cases = [ + (WanT2V_1_3B_SamplingParams().teacache_params, False, (5, 50)), + (WanT2V_1_3B_SamplingParams().teacache_params, True, (10, 100)), + (WanT2V_14B_SamplingParams().teacache_params, False, (1, 49)), + (WanT2V_14B_SamplingParams().teacache_params, True, (2, 98)), + (WanI2V_14B_480P_SamplingParam().teacache_params, False, (5, 50)), + (WanI2V_14B_480P_SamplingParam().teacache_params, True, (10, 100)), + (WanI2V_14B_720P_SamplingParam().teacache_params, False, (5, 50)), + (WanI2V_14B_720P_SamplingParam().teacache_params, True, (10, 100)), + ] + + for teacache_params, do_cfg, expected in legacy_equivalent_cases: + with self.subTest( + use_ret_steps=teacache_params.use_ret_steps, + do_cfg=do_cfg, + expected=expected, + ): + self.assertEqual( + teacache_params.get_skip_boundaries(50, do_cfg), + expected, + ) + class TestSamplingParamsCliArgs(unittest.TestCase): def _parse_cli_kwargs(self, argv: list[str]) -> dict: