diff --git a/docs/diffusion/performance/cache/teacache.md b/docs/diffusion/performance/cache/teacache.md index dd9691c43a4a..186ebc59d996 100644 --- a/docs/diffusion/performance/cache/teacache.md +++ b/docs/diffusion/performance/cache/teacache.md @@ -13,7 +13,19 @@ TeaCache works by: 3. When accumulated distance is below a threshold, reusing the cached residual 4. Supporting CFG (Classifier-Free Guidance) with separate positive/negative caches -## How It Works +## Implementation + +TeaCache is split into three classes: + +- **`TeaCacheParams`** — pure data class holding user-set parameters (`rel_l1_thresh`, `coefficients`, `start_skipping`, `end_skipping`). Set once per request, never mutated during inference. +- **`TeaCacheState`** — dataclass holding runtime state for one CFG branch: `step`, `previous_modulated_input`, `previous_residual`, `accumulated_rel_l1_distance`. +- **`TeaCacheStrategy`** — all the logic. Owns two `TeaCacheState` objects (positive + optional negative CFG branch). Constructed once per generation by `CachableDiT.maybe_init_cache()` with all parameters resolved upfront. + +At each denoising step, the model calls: +1. `cache.step(modulated_input)` — advances the step counter, accumulates the rescaled L1 distance, returns `True` if the forward pass can be skipped +2. `cache.read()` — if skipping, reads the cached residual and applies it to hidden states +3. `cache.write()` — if computing, stores the new residual in the cache +4. 
`cache.reset_states()` — resets `state` and optionally `state_neg`, discarding any stale tensors (not called on every step; invoked when the cache is flushed, e.g. after a weight update) ### L1 Distance Tracking @@ -36,9 +48,7 @@ accumulated += poly(coefficients)(rel_l1) ### CFG Support -For models that support CFG cache separation (Wan, Hunyuan, Z-Image), TeaCache maintains separate caches for positive and negative branches: -- `previous_modulated_input` / `previous_residual` for positive branch -- `previous_modulated_input_negative` / `previous_residual_negative` for negative branch +For models that support CFG separation (Wan, Hunyuan, Z-Image), `TeaCacheStrategy` maintains separate `TeaCacheState` objects for the positive and negative branches. For models that don't support CFG separation (Flux, Qwen), TeaCache is automatically disabled when CFG is enabled. @@ -50,7 +60,7 @@ TeaCache is configured via `TeaCacheParams` in the sampling parameters: from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams params = TeaCacheParams( - teacache_thresh=0.1, # Threshold for accumulated L1 distance + rel_l1_thresh=0.1, # Threshold for accumulated L1 distance coefficients=[1.0, 0.0, 0.0], # Polynomial coefficients for L1 rescaling ) ``` @@ -59,7 +69,7 @@ params = TeaCacheParams( | Parameter | Type | Description | |-----------|------|-------------| -| `teacache_thresh` | float | Threshold for accumulated L1 distance. Lower = more caching, faster but potentially lower quality | +| `rel_l1_thresh` | float | Threshold for accumulated L1 distance. Higher = more caching, faster but potentially lower quality | | `coefficients` | list[float] | Polynomial coefficients for L1 rescaling. 
Model-specific tuning | ### Model-Specific Configurations diff --git a/python/sglang/multimodal_gen/configs/sample/hunyuan.py b/python/sglang/multimodal_gen/configs/sample/hunyuan.py index c60b856630f0..61733d24868d 100644 --- a/python/sglang/multimodal_gen/configs/sample/hunyuan.py +++ b/python/sglang/multimodal_gen/configs/sample/hunyuan.py @@ -38,7 +38,7 @@ class HunyuanSamplingParams(SamplingParams): teacache_params: TeaCacheParams = field( default_factory=lambda: TeaCacheParams( - teacache_thresh=0.15, + rel_l1_thresh=0.15, # from https://github.com/ali-vilab/TeaCache/blob/7c10efc4702c6b619f47805f7abe4a7a08085aa0/TeaCache4HunyuanVideo/teacache_sample_video.py#L222 coefficients=[ 7.33226126e02, diff --git a/python/sglang/multimodal_gen/configs/sample/sampling_params.py b/python/sglang/multimodal_gen/configs/sample/sampling_params.py index 123bf5234a46..d981659c38ff 100644 --- a/python/sglang/multimodal_gen/configs/sample/sampling_params.py +++ b/python/sglang/multimodal_gen/configs/sample/sampling_params.py @@ -162,9 +162,7 @@ class SamplingParams: # TeaCache parameters enable_teacache: bool = False - teacache_params: Any = ( - None # TeaCacheParams or WanTeaCacheParams, set by model-specific subclass - ) + teacache_params: Any | None = None # Profiling profile: bool = False diff --git a/python/sglang/multimodal_gen/configs/sample/teacache.py b/python/sglang/multimodal_gen/configs/sample/teacache.py index e20df6e67b93..78994c545278 100644 --- a/python/sglang/multimodal_gen/configs/sample/teacache.py +++ b/python/sglang/multimodal_gen/configs/sample/teacache.py @@ -17,7 +17,7 @@ class TeaCacheParams(CacheParams): Attributes: cache_type: (`str`, defaults to `teacache`): A string labeling these parameters as belonging to teacache. - teacache_thresh (`float`, defaults to `0.0`): + rel_l1_thresh (`float`, defaults to `0.0`): Threshold for accumulated relative L1 distance. When below this threshold, the forward pass is skipped. 
Recommended values: 0.25 for ~1.5x speedup, 0.4 for ~1.8x, 0.6 for ~2.0x. @@ -48,7 +48,7 @@ class TeaCacheParams(CacheParams): """ cache_type: str = "teacache" - teacache_thresh: float = 0.0 + rel_l1_thresh: float = 0.0 start_skipping: int | float = 5 end_skipping: int | float = -1 coefficients: list[float] = field(default_factory=list) @@ -62,9 +62,7 @@ def get_coefficients(self) -> list[float]: return self.coefficients_callback(self) return self.coefficients - def get_skip_boundaries( - self, num_inference_steps: int, do_cfg: bool - ) -> tuple[int, int]: + def get_skip_boundaries(self, num_inference_steps: int) -> tuple[int, int]: def _resolve_boundary(value: int | float) -> int: if isinstance(value, float): return int(num_inference_steps * value) @@ -74,9 +72,4 @@ def _resolve_boundary(value: int | float) -> int: start_skipping = _resolve_boundary(self.start_skipping) end_skipping = _resolve_boundary(self.end_skipping) - - if do_cfg: - start_skipping *= 2 - end_skipping *= 2 - return start_skipping, end_skipping diff --git a/python/sglang/multimodal_gen/configs/sample/wan.py b/python/sglang/multimodal_gen/configs/sample/wan.py index 0f147a9dcede..184fe7034e68 100644 --- a/python/sglang/multimodal_gen/configs/sample/wan.py +++ b/python/sglang/multimodal_gen/configs/sample/wan.py @@ -66,7 +66,7 @@ class WanT2V_1_3B_SamplingParams(SamplingParams): teacache_params: TeaCacheParams = field( default_factory=lambda: TeaCacheParams( - teacache_thresh=0.08, + rel_l1_thresh=0.08, use_ret_steps=True, coefficients_callback=_wan_1_3b_coefficients, start_skipping=5, @@ -102,7 +102,7 @@ class WanT2V_14B_SamplingParams(SamplingParams): teacache_params: TeaCacheParams = field( default_factory=lambda: TeaCacheParams( - teacache_thresh=0.20, + rel_l1_thresh=0.20, use_ret_steps=False, coefficients_callback=_wan_14b_coefficients, start_skipping=1, @@ -128,7 +128,7 @@ class WanI2V_14B_480P_SamplingParam(WanT2V_1_3B_SamplingParams): teacache_params: TeaCacheParams = field( 
default_factory=lambda: TeaCacheParams( - teacache_thresh=0.26, + rel_l1_thresh=0.26, use_ret_steps=True, coefficients_callback=_wan_14b_coefficients, start_skipping=5, @@ -156,7 +156,7 @@ class WanI2V_14B_720P_SamplingParam(WanT2V_14B_SamplingParams): teacache_params: TeaCacheParams = field( default_factory=lambda: TeaCacheParams( - teacache_thresh=0.3, + rel_l1_thresh=0.3, use_ret_steps=True, coefficients_callback=_wan_14b_coefficients, start_skipping=5, diff --git a/python/sglang/multimodal_gen/configs/sample/zimage.py b/python/sglang/multimodal_gen/configs/sample/zimage.py index 77a9dabf90de..9c728e182ab3 100644 --- a/python/sglang/multimodal_gen/configs/sample/zimage.py +++ b/python/sglang/multimodal_gen/configs/sample/zimage.py @@ -22,7 +22,7 @@ class ZImageTurboSamplingParams(SamplingParams): teacache_params: TeaCacheParams = field( default_factory=lambda: TeaCacheParams( - teacache_thresh=0.15, + rel_l1_thresh=0.15, coefficients=[ 7.33226126e02, -4.01131952e02, diff --git a/python/sglang/multimodal_gen/runtime/cache/__init__.py b/python/sglang/multimodal_gen/runtime/cache/__init__.py index 62f0f8457f8f..9e756ab2aef3 100644 --- a/python/sglang/multimodal_gen/runtime/cache/__init__.py +++ b/python/sglang/multimodal_gen/runtime/cache/__init__.py @@ -16,12 +16,12 @@ enable_cache_on_transformer, get_scm_mask, ) -from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheContext, TeaCacheMixin +from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheState, TeaCacheStrategy __all__ = [ # TeaCache (always available) - "TeaCacheContext", - "TeaCacheMixin", + "TeaCacheState", + "TeaCacheStrategy", # cache-dit integration (lazy-loaded, requires cache-dit package) "CacheDitConfig", "enable_cache_on_transformer", diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index 8830f7ec20c4..f90a2bcfcd2a 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ 
b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -1,316 +1,161 @@ # SPDX-License-Identifier: Apache-2.0 """ -TeaCache: Temporal similarity-based caching for diffusion models. - -TeaCache accelerates diffusion inference by selectively skipping redundant -computation when consecutive diffusion steps are similar enough. This is -achieved by tracking the L1 distance between modulated inputs across timesteps. - -Key concepts: -- Modulated input: The input to transformer blocks after timestep conditioning -- L1 distance: Measures how different consecutive timesteps are -- Threshold: When accumulated L1 distance exceeds threshold, force computation -- CFG support: Separate caches for positive and negative branches +TeaCache accelerates diffusion inference by skipping redundant forward +passes when consecutive denoising steps are sufficiently similar, as measured +by the accumulated relative L1 distance of modulated inputs. References: - TeaCache: Accelerating Diffusion Models with Temporal Similarity https://arxiv.org/abs/2411.14324 """ -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from __future__ import annotations -import numpy as np -import torch +import logging +from dataclasses import dataclass -from sglang.multimodal_gen.configs.models import DiTConfig +logger = logging.getLogger(__name__) -if TYPE_CHECKING: - from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams +import torch @dataclass -class TeaCacheContext: - """Common context extracted for TeaCache skip decision. - - This context is populated from the forward_batch and forward_context - during each denoising step, providing all information needed to make - cache decisions. - - Attributes: - current_timestep: Current denoising timestep index (0-indexed). - num_inference_steps: Total number of inference steps. - do_cfg: Whether classifier-free guidance is enabled. - is_cfg_negative: True if currently processing negative CFG branch. 
- teacache_thresh: Threshold for accumulated L1 distance. - coefficients: Polynomial coefficients for L1 rescaling. - teacache_params: Full TeaCacheParams for model-specific access. - """ - - current_timestep: int - num_inference_steps: int - do_cfg: bool - is_cfg_negative: bool # For CFG branch selection - teacache_thresh: float - coefficients: list[float] - teacache_params: "TeaCacheParams" # Full params for model-specific access - - -class TeaCacheMixin: - """ - Mixin class providing TeaCache optimization functionality. - - TeaCache accelerates diffusion inference by selectively skipping redundant - computation when consecutive diffusion steps are similar enough. - - This mixin should be inherited by DiT model classes that want to support - TeaCache optimization. It provides: - - State management for tracking L1 distances - - CFG-aware caching (separate caches for positive/negative branches) - - Decision logic for when to compute vs. use cache - - Example usage in a DiT model: - class MyDiT(TeaCacheMixin, BaseDiT): - def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) - self._init_teacache_state() - - def forward(self, hidden_states, timestep, ...): - ctx = self._get_teacache_context() - if ctx is not None: - # Compute modulated input (model-specific, e.g., after timestep embedding) - modulated_input = self._compute_modulated_input(hidden_states, timestep) - is_boundary = (ctx.current_timestep == 0 or - ctx.current_timestep >= ctx.num_inference_steps - 1) - - should_calc = self._compute_teacache_decision( - modulated_inp=modulated_input, - is_boundary_step=is_boundary, - coefficients=ctx.coefficients, - teacache_thresh=ctx.teacache_thresh, - ) - - if not should_calc: - # Use cached residual (must implement retrieve_cached_states) - return self.retrieve_cached_states(hidden_states) - - # Normal forward pass... - output = self._transformer_forward(hidden_states, timestep, ...) 
- - # Cache states for next step - if ctx is not None: - self.maybe_cache_states(output, hidden_states) - - return output - - Subclass implementation notes: - - `_compute_modulated_input()`: Model-specific method to compute the input - after timestep conditioning (used for L1 distance calculation) - - `retrieve_cached_states()`: Must be overridden to return cached output - - `maybe_cache_states()`: Override to store states for cache retrieval - - Attributes: - cnt: Counter for tracking steps. - enable_teacache: Whether TeaCache is enabled. - previous_modulated_input: Cached modulated input for positive branch. - previous_residual: Cached residual for positive branch. - accumulated_rel_l1_distance: Accumulated L1 distance for positive branch. - is_cfg_negative: Whether currently processing negative CFG branch. - _supports_cfg_cache: Whether this model supports CFG cache separation. - - CFG-specific attributes (only when _supports_cfg_cache is True): - previous_modulated_input_negative: Cached input for negative branch. - previous_residual_negative: Cached residual for negative branch. - accumulated_rel_l1_distance_negative: L1 distance for negative branch. 
+class TeaCacheState: + """Tracks step progress, cached tensors, and L1 distances for a single CFG path.""" + + step: int = 0 + previous_modulated_input: torch.Tensor | None = None + previous_residual: torch.Tensor | None = None + accumulated_rel_l1_distance: torch.Tensor | None = None + + +def _rescale_distance_tensor( + coefficients: list[float], x: torch.Tensor +) -> torch.Tensor: + """Polynomial rescaling using tensor operations (torch.compile friendly).""" + x = ( + x.float() + ) # upcast to float32 for numerical stability, especially with higher degree polynomials + result = torch.zeros_like(x) + for i, c in enumerate(coefficients): + result = result + c * x ** (len(coefficients) - 1 - i) + return result + + +def _compute_rel_l1_distance_tensor( + current: torch.Tensor, previous: torch.Tensor +) -> torch.Tensor: + """Compute relative L1 distance as a tensor (torch.compile friendly).""" + current, previous = ( + current.float(), + previous.float(), + ) # upcast to float32 for numerical stability + prev_mean = previous.abs().mean() + curr_diff_mean = (current - previous).abs().mean() + rel_distance = torch.where( + prev_mean > 1e-9, + curr_diff_mean / prev_mean, + torch.where( + current.abs().mean() < 1e-9, + torch.zeros(1, device=current.device, dtype=current.dtype), + torch.full((1,), float("inf"), device=current.device, dtype=current.dtype), + ), + ) + return rel_distance.squeeze() + + +class TeaCacheStrategy: + """Implements TeaCache to skip redundant diffusion forward passes. + + TeaCacheStrategy manages two TeaCacheState objects (positive + optional + negative CFG branch) and stores parameters needed to make the skip decision. """ - # Models that support CFG cache separation (wan/hunyuan/zimage) - # Models not in this set (flux/qwen) auto-disable TeaCache when CFG is enabled - _CFG_SUPPORTED_PREFIXES: set[str] = {"wan", "hunyuan", "zimage"} - config: DiTConfig - - def _init_teacache_state(self) -> None: - """Initialize TeaCache state. 
Call this in subclass __init__.""" - # Common TeaCache state - self.cnt = 0 - self.enable_teacache = True - # Flag indicating if this model supports CFG cache separation - self._supports_cfg_cache = ( - self.config.prefix.lower() in self._CFG_SUPPORTED_PREFIXES - ) - - # Always initialize positive cache fields (used in all modes) - self.previous_modulated_input: torch.Tensor | None = None - self.previous_residual: torch.Tensor | None = None - self.accumulated_rel_l1_distance: float = 0.0 - - self.is_cfg_negative = False - # CFG-specific fields initialized to None (created when CFG is used) - # These are only used when _supports_cfg_cache is True AND do_cfg is True - if self._supports_cfg_cache: - self.previous_modulated_input_negative: torch.Tensor | None = None - self.previous_residual_negative: torch.Tensor | None = None - self.accumulated_rel_l1_distance_negative: float = 0.0 - - def reset_teacache_state(self) -> None: - """Reset all TeaCache state at the start of each generation task.""" - self.cnt = 0 - - # Primary cache fields (always present) - self.previous_modulated_input = None - self.previous_residual = None - self.accumulated_rel_l1_distance = 0.0 - self.is_cfg_negative = False - self.enable_teacache = True - # CFG negative cache fields (always reset, may be unused) - if self._supports_cfg_cache: - self.previous_modulated_input_negative = None - self.previous_residual_negative = None - self.accumulated_rel_l1_distance_negative = 0.0 - - def _compute_l1_and_decide( + def __init__( self, - modulated_inp: torch.Tensor, + supports_cfg: bool, coefficients: list[float], - teacache_thresh: float, - ) -> tuple[float, bool]: - """ - Compute L1 distance and decide whether to calculate or use cache. - - Args: - modulated_inp: Current timestep's modulated input. - coefficients: Polynomial coefficients for L1 rescaling. - teacache_thresh: Threshold for cache decision. - - Returns: - Tuple of (new_accumulated_distance, should_calc). 
- """ - prev_modulated_inp = ( - self.previous_modulated_input_negative - if self.is_cfg_negative - else self.previous_modulated_input - ) - - # Defensive check: if previous input is not set, force calculation - if prev_modulated_inp is None: - return 0.0, True - - # Compute relative L1 distance - diff = modulated_inp - prev_modulated_inp - rel_l1 = (diff.abs().mean() / prev_modulated_inp.abs().mean()).cpu().item() - - # Apply polynomial rescaling - rescale_func = np.poly1d(coefficients) - - accumulated_rel_l1_distance = ( - self.accumulated_rel_l1_distance_negative - if self.is_cfg_negative - else self.accumulated_rel_l1_distance - ) - accumulated_rel_l1_distance = accumulated_rel_l1_distance + rescale_func(rel_l1) - - if accumulated_rel_l1_distance >= teacache_thresh: - # Threshold exceeded: force compute and reset accumulator - return 0.0, True - # Cache hit: keep accumulated distance - return accumulated_rel_l1_distance, False - - def _compute_teacache_decision( - self, - modulated_inp: torch.Tensor, - is_boundary_step: bool, - coefficients: list[float], - teacache_thresh: float, - ) -> bool: - """ - Compute cache decision for TeaCache. - - Args: - modulated_inp: Current timestep's modulated input. - is_boundary_step: True for boundary timesteps that always compute. - coefficients: Polynomial coefficients for L1 rescaling. - teacache_thresh: Threshold for cache decision. - - Returns: - True if forward computation is needed, False to use cache. 
- """ - if not self.enable_teacache: - return True - - if is_boundary_step: - new_accum, should_calc = 0.0, True - else: - new_accum, should_calc = self._compute_l1_and_decide( - modulated_inp=modulated_inp, - coefficients=coefficients, - teacache_thresh=teacache_thresh, + rel_l1_thresh: float, + start_skipping: int, + end_skipping: int, + ) -> None: + """Initialize cache states and all generation parameters.""" + self.supports_cfg = supports_cfg + self.state = TeaCacheState() + self.state_neg = TeaCacheState() if supports_cfg else None + self.coefficients = coefficients + self.rel_l1_thresh = rel_l1_thresh + self.start_skipping = start_skipping + self.end_skipping = end_skipping + if start_skipping >= end_skipping: + logger.warning( + f"TeaCache skip window is invalid (start_skipping={start_skipping} >= " + f"end_skipping={end_skipping}). This can happen during warmup runs with " + "very few steps. Skipping disabled." ) - # Advance baseline and accumulator for the active branch - if not self.is_cfg_negative: - self.previous_modulated_input = modulated_inp.clone() - self.accumulated_rel_l1_distance = new_accum - elif self._supports_cfg_cache: - self.previous_modulated_input_negative = modulated_inp.clone() - self.accumulated_rel_l1_distance_negative = new_accum - - return should_calc + def reset_states(self) -> None: + """Reset cache states, discarding any stale tensors from a previous generation.""" + self.state = TeaCacheState() + self.state_neg = TeaCacheState() if self.supports_cfg else None - def _get_teacache_context(self) -> TeaCacheContext | None: - """ - Check TeaCache preconditions and extract common context. - - Returns: - TeaCacheContext if TeaCache is enabled and properly configured, - None if should skip TeaCache logic entirely. 
- """ + def _get_state(self) -> TeaCacheState: + """Select the appropriate cache state (positive/negative cfg) based on the forward context.""" from sglang.multimodal_gen.runtime.managers.forward_context import ( get_forward_context, ) - forward_context = get_forward_context() - forward_batch = forward_context.forward_batch - - # Early return checks - if ( - forward_batch is None - or not forward_batch.enable_teacache - or forward_batch.teacache_params is None - ): - return None - - teacache_params = forward_batch.teacache_params - - # Extract common values - current_timestep = forward_context.current_timestep - num_inference_steps = forward_batch.num_inference_steps - do_cfg = forward_batch.do_classifier_free_guidance - is_cfg_negative = forward_batch.is_cfg_negative - - # Reset at first timestep - if current_timestep == 0 and not self.is_cfg_negative: - self.reset_teacache_state() - - return TeaCacheContext( - current_timestep=current_timestep, - num_inference_steps=num_inference_steps, - do_cfg=do_cfg, - is_cfg_negative=is_cfg_negative, - teacache_thresh=teacache_params.teacache_thresh, - coefficients=teacache_params.get_coefficients(), - teacache_params=teacache_params, + forward_batch = get_forward_context().forward_batch + is_cfg_negative = ( + forward_batch.is_cfg_negative if forward_batch is not None else False ) + if is_cfg_negative and self.state_neg is not None: + return self.state_neg + return self.state + + def step(self, modulated_input: torch.Tensor) -> bool: + """Advance state and return whether this forward pass can be skipped.""" + state = self._get_state() + step = state.step + state.step += 1 + + # Do not skip on the first step or if we are outside the skipping window + in_skip_window = self.start_skipping <= step < self.end_skipping + if state.previous_modulated_input is None or not in_skip_window: + state.accumulated_rel_l1_distance = None + state.previous_modulated_input = modulated_input.clone() + return False + + # Compute the relative L1 
distance and update the state + rel_l1 = _compute_rel_l1_distance_tensor( + modulated_input, state.previous_modulated_input + ) + rescaled = _rescale_distance_tensor(self.coefficients, rel_l1) + state.accumulated_rel_l1_distance = ( + rescaled + if state.accumulated_rel_l1_distance is None + else state.accumulated_rel_l1_distance + rescaled + ) + state.previous_modulated_input = modulated_input.clone() - def maybe_cache_states( - self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor - ) -> None: - """Cache states for later retrieval. Override in subclass if needed.""" - pass + # Skip if accumulated rel l1 is small + if state.accumulated_rel_l1_distance < self.rel_l1_thresh: + return True - def should_skip_forward_for_cached_states(self, **kwargs: dict[str, Any]) -> bool: - """Check if forward can be skipped using cached states.""" + # Otherwise reset the accumulator and do not skip + state.accumulated_rel_l1_distance = None return False - def retrieve_cached_states(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Retrieve cached states. 
Must be implemented by subclass.""" - raise NotImplementedError("retrieve_cached_states is not implemented") + def write( + self, + hidden_states: torch.Tensor, + original_hidden_states: torch.Tensor, + **kwargs, + ) -> None: + """After the forward pass, cache the residual.""" + state = self._get_state() + state.previous_residual = hidden_states.squeeze(0) - original_hidden_states + + def read(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: + """Before the forward pass, read from the cache and apply it to the current hidden states.""" + return hidden_states + self._get_state().previous_residual diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py index 1007e9b41476..5a2f5c5c46d0 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -46,7 +46,6 @@ import torch from torch.distributed.tensor import DTensor, distribute_tensor -from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheMixin from sglang.multimodal_gen.runtime.loader.utils import ( _list_safetensors_files, ) @@ -213,8 +212,8 @@ def update_weights_from_disk( if success and flush_cache: for _, module in modules_to_update: - if isinstance(module, TeaCacheMixin): - module.reset_teacache_state() + if hasattr(module, "cache") and module.cache is not None: + module.cache.reset_states() logger.info(message) return success, message diff --git a/python/sglang/multimodal_gen/runtime/models/dits/base.py b/python/sglang/multimodal_gen/runtime/models/dits/base.py index 9816f5fb03a7..e4efa2bf8bfa 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/base.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/base.py @@ -8,16 +8,13 @@ from torch import nn from sglang.multimodal_gen.configs.models import DiTConfig - -# NOTE: TeaCacheContext and TeaCacheMixin have been moved to -# 
sglang.multimodal_gen.runtime.cache.teacache -# For backwards compatibility, re-export from the new location -from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheContext # noqa: F401 -from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheMixin +from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheStrategy from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) -# TODO class BaseDiT(nn.Module, ABC): _fsdp_shard_conditions: list = [] _compile_conditions: list = [] @@ -87,11 +84,15 @@ def device(self) -> torch.device: return next(self.parameters()).device -class CachableDiT(TeaCacheMixin, BaseDiT): +_CFG_SUPPORTED_PREFIXES: set[str] = {"wan", "hunyuan", "zimage"} + + +class CachableDiT(BaseDiT): """ - An intermediate base class that adds TeaCache optimization functionality to DiT models. + An intermediate base class that adds timestep-caching support for DiT models + such as TeaCache. - Inherits TeaCacheMixin for cache logic and BaseDiT for core DiT functionality. + Inherits `BaseDiT` for core DiT functionality and stores cache logic in `self.cache`. """ # These are required class attributes that should be overridden by concrete implementations @@ -109,8 +110,61 @@ class CachableDiT(TeaCacheMixin, BaseDiT): ) def __init__(self, config: DiTConfig, **kwargs) -> None: + """Initialize cache state for a DiT model with caching support. + + Args: + config: DiT model configuration. + **kwargs: Passed through to BaseDiT (e.g. hf_config). + + Attributes: + cache: None when uninitialized or when no caching was requested; otherwise an active TeaCacheStrategy. + calibrate_cache: When True, runs every forward pass to gather calibration data. 
+ """ super().__init__(config, **kwargs) - self._init_teacache_state() + self.cache: TeaCacheStrategy | None = None + self.calibrate_cache: bool = False + + def maybe_init_cache(self) -> None: + """Initialize the cache strategy at the start of each new generation + (when timestep == 0 and cfg is positive for cfg-supporting models). + + Since the cache parameters are contained in the sampling parameters which is only + accessible during the first forward pass, we cannot initialize the cache in CachableDiT.__init__. + """ + from sglang.multimodal_gen.runtime.managers.forward_context import ( + get_forward_context, + ) + + forward_batch = get_forward_context().forward_batch + if forward_batch is None: + return + + # caching strategies may handle pos/neg cfg separately + supports_cfg = self.config.prefix.lower() in _CFG_SUPPORTED_PREFIXES + + # initialize cache at the start of each new generation (step index == 0 and cfg is positive for cfg-supporting models) + current_timestep = get_forward_context().current_timestep + if current_timestep == 0 and ( + (supports_cfg and not forward_batch.is_cfg_negative) or not supports_cfg + ): + # select caching strategy + cache_params = getattr( + forward_batch.sampling_params, "teacache_params", None + ) + if forward_batch.enable_teacache and cache_params is not None: + num_steps = int(forward_batch.num_inference_steps) + start_skipping, end_skipping = cache_params.get_skip_boundaries( + num_steps + ) + self.cache = TeaCacheStrategy( + supports_cfg, + cache_params.get_coefficients(), + cache_params.rel_l1_thresh, + start_skipping, + end_skipping, + ) + else: + self.cache = None @classmethod def get_nunchaku_quant_rules(cls) -> dict[str, dict[str, Any]]: diff --git a/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py index 9decdf49e435..b2c793133db3 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py +++ 
b/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py @@ -700,7 +700,7 @@ def should_skip_forward_for_cached_states(self, **kwargs) -> bool: teacache_params, TeaCacheParams ), "teacache_params is not a TeaCacheParams" num_inference_steps = forward_batch.num_inference_steps - teache_thresh = teacache_params.teacache_thresh + teache_thresh = teacache_params.rel_l1_thresh coefficients = teacache_params.coefficients diff --git a/python/sglang/multimodal_gen/runtime/models/dits/mova_audio_dit.py b/python/sglang/multimodal_gen/runtime/models/dits/mova_audio_dit.py index cbe8489fc9bc..a36e4ba304cc 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/mova_audio_dit.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/mova_audio_dit.py @@ -182,7 +182,7 @@ def __init__( self.num_channels_latents = out_dim self.layer_names = ["blocks"] self.cnt = 0 - self.teacache_thresh = 0 + self.rel_l1_thresh = 0 self.coefficients = [] self.accumulated_rel_l1_distance = 0 self.previous_modulated_input = None diff --git a/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py b/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py index 7e10767f5b53..d097ee8775a4 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py @@ -497,7 +497,7 @@ def __init__( self.num_channels_latents = out_dim self.layer_names = ["blocks"] self.cnt = 0 - self.teacache_thresh = 0 + self.rel_l1_thresh = 0 self.coefficients = [] self.accumulated_rel_l1_distance = 0 self.previous_modulated_input = None diff --git a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py index eff9835b9342..458e20aefa50 100755 --- a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -944,8 +944,6 @@ def __init__( ) # For type checking - - 
self.cnt = 0 self.__post_init__() # misc @@ -999,16 +997,19 @@ def forward( guidance=None, **kwargs, ) -> torch.Tensor: - forward_batch = get_forward_context().forward_batch + + # if caching is enabled, we might initialize or reset the cache state + self.maybe_init_cache() + + forward_context = get_forward_context() + forward_batch = forward_context.forward_batch + if forward_batch is not None: sequence_shard_enabled = ( forward_batch.enable_sequence_shard and self.sp_size > 1 ) else: sequence_shard_enabled = False - self.enable_teacache = ( - forward_batch is not None and forward_batch.enable_teacache - ) orig_dtype = hidden_states.dtype if not isinstance(encoder_hidden_states, torch.Tensor): @@ -1140,25 +1141,27 @@ def forward( # 4. Transformer blocks # if caching is enabled, we might be able to skip the forward pass - should_skip_forward = self.should_skip_forward_for_cached_states( - timestep_proj=timestep_proj, temb=temb - ) + should_skip_forward = False + if self.cache: + use_ret_steps = forward_batch.sampling_params.teacache_params.use_ret_steps + modulated_input = timestep_proj if use_ret_steps else temb + should_skip_forward = self.cache.step(modulated_input) if should_skip_forward: - hidden_states = self.retrieve_cached_states(hidden_states) + # compute hidden_states by reading from the cached state + hidden_states = self.cache.read(hidden_states) else: - # if teacache is enabled, we need to cache the original hidden states - if self.enable_teacache: + if self.cache: original_hidden_states = hidden_states.clone() for block in self.blocks: hidden_states = block( hidden_states, encoder_hidden_states, timestep_proj, freqs_cis ) - # if teacache is enabled, we need to cache the original hidden states - if self.enable_teacache: - self.maybe_cache_states(hidden_states, original_hidden_states) - self.cnt += 1 + + if self.cache: + # update the cache with the new hidden states + self.cache.write(hidden_states, original_hidden_states) if sequence_shard_enabled: 
hidden_states = hidden_states.contiguous() @@ -1196,55 +1199,5 @@ def forward( return output - def maybe_cache_states( - self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor - ) -> None: - """Cache residual with CFG positive/negative separation.""" - residual = hidden_states.squeeze(0) - original_hidden_states - if not self.is_cfg_negative: - self.previous_residual = residual - else: - self.previous_residual_negative = residual - - def should_skip_forward_for_cached_states(self, **kwargs) -> bool: - if not self.enable_teacache: - return False - ctx = self._get_teacache_context() - if ctx is None: - return False - - # Initialize Wan-specific parameters - teacache_params = ctx.teacache_params - use_ret_steps = teacache_params.use_ret_steps - start_skipping, end_skipping = teacache_params.get_skip_boundaries( - ctx.num_inference_steps, ctx.do_cfg - ) - - # Determine boundary step - is_boundary_step = self.cnt < start_skipping or self.cnt >= end_skipping - - timestep_proj = kwargs["timestep_proj"] - temb = kwargs["temb"] - modulated_inp = timestep_proj if use_ret_steps else temb - - self.is_cfg_negative = ctx.is_cfg_negative - - # Use shared helper to compute cache decision - should_calc = self._compute_teacache_decision( - modulated_inp=modulated_inp, - is_boundary_step=is_boundary_step, - coefficients=ctx.coefficients, - teacache_thresh=ctx.teacache_thresh, - ) - - return not should_calc - - def retrieve_cached_states(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Retrieve cached residual with CFG positive/negative separation.""" - if not self.is_cfg_negative: - return hidden_states + self.previous_residual - else: - return hidden_states + self.previous_residual_negative - EntryClass = WanTransformer3DModel diff --git a/python/sglang/multimodal_gen/test/server/perf_baselines.json b/python/sglang/multimodal_gen/test/server/perf_baselines.json index dacfe45088ad..5390b4afdba3 100644 --- 
a/python/sglang/multimodal_gen/test/server/perf_baselines.json +++ b/python/sglang/multimodal_gen/test/server/perf_baselines.json @@ -821,69 +821,70 @@ }, "wan2_1_t2v_1.3b_teacache_enabled": { "stages_ms": { - "TextEncodingStage": 1088.18, - "DenoisingStage": 3926.63, - "InputValidationStage": 0.07, - "LatentPreparationStage": 0.14, - "TimestepPreparationStage": 2.31, - "DecodingStage": 711.47 + "InputValidationStage": 0.04, + "TextEncodingStage": 1218.71, + "LatentPreparationStage": 0.13, + "TimestepPreparationStage": 2.3, + "DenoisingStage": 3968.66, + "DecodingStage": 632.27, + "per_frame_generation": null }, "denoise_step_ms": { - "0": 80.49, - "1": 147.49, - "2": 144.75, - "3": 144.41, - "4": 143.82, - "5": 138.89, - "6": 36.55, - "7": 102.12, - "8": 39.18, - "9": 104.07, - "10": 37.91, - "11": 107.4, - "12": 2.52, - "13": 36.6, - "14": 106.91, - "15": 3.17, - "16": 41.13, - "17": 95.25, - "18": 2.67, - "19": 37.88, - "20": 110.46, - "21": 2.66, - "22": 38.29, - "23": 108.86, - "24": 2.7, - "25": 38.46, - "26": 108.62, - "27": 2.71, - "28": 41.23, - "29": 107.12, - "30": 2.72, - "31": 36.52, - "32": 112.05, - "33": 2.5, - "34": 110.6, - "35": 38.24, - "36": 110.83, - "37": 37.63, - "38": 111.91, - "39": 37.14, - "40": 111.39, - "41": 38.64, - "42": 110.57, - "43": 38.09, - "44": 111.32, - "45": 148.48, - "46": 149.96, - "47": 143.63, - "48": 148.49, - "49": 151.0 - }, - "expected_e2e_ms": 6012.96, - "expected_avg_denoise_ms": 78.45, - "expected_median_denoise_ms": 102.49, - "estimated_full_test_time_s": 126.0 + "0": 84.07, + "1": 153.06, + "2": 149.78, + "3": 149.08, + "4": 149.29, + "5": 66.61, + "6": 110.3, + "7": 52.51, + "8": 114.69, + "9": 44.93, + "10": 109.02, + "11": 43.79, + "12": 2.58, + "13": 109.2, + "14": 43.64, + "15": 2.55, + "16": 107.91, + "17": 43.19, + "18": 2.58, + "19": 107.48, + "20": 42.33, + "21": 2.56, + "22": 107.52, + "23": 48.05, + "24": 2.58, + "25": 117.42, + "26": 40.87, + "27": 2.57, + "28": 109.58, + "29": 45.79, + "30": 2.57, 
+ "31": 109.49, + "32": 44.01, + "33": 2.57, + "34": 109.47, + "35": 43.15, + "36": 107.93, + "37": 43.33, + "38": 109.47, + "39": 44.34, + "40": 110.87, + "41": 45.92, + "42": 116.65, + "43": 38.6, + "44": 108.5, + "45": 153.14, + "46": 152.2, + "47": 152.05, + "48": 151.15, + "49": 153.32 + }, + "expected_e2e_ms": 6182.06, + "expected_avg_denoise_ms": 79.29, + "expected_median_denoise_ms": 95.78, + "estimated_full_test_time_s": 84.6 }, "wan2_1_t2v_1.3b": { "stages_ms": { diff --git a/python/sglang/multimodal_gen/test/unit/test_sampling_params.py b/python/sglang/multimodal_gen/test/unit/test_sampling_params.py index d5b463f8bfe0..f4fca9a7082c 100644 --- a/python/sglang/multimodal_gen/test/unit/test_sampling_params.py +++ b/python/sglang/multimodal_gen/test/unit/test_sampling_params.py @@ -134,24 +134,19 @@ def coefficients_callback(_: TeaCacheParams) -> list[float]: def test_wan_teacache_boundaries_match_legacy_behavior(self): legacy_equivalent_cases = [ - (WanT2V_1_3B_SamplingParams().teacache_params, False, (5, 50)), - (WanT2V_1_3B_SamplingParams().teacache_params, True, (10, 100)), - (WanT2V_14B_SamplingParams().teacache_params, False, (1, 49)), - (WanT2V_14B_SamplingParams().teacache_params, True, (2, 98)), - (WanI2V_14B_480P_SamplingParam().teacache_params, False, (5, 50)), - (WanI2V_14B_480P_SamplingParam().teacache_params, True, (10, 100)), - (WanI2V_14B_720P_SamplingParam().teacache_params, False, (5, 50)), - (WanI2V_14B_720P_SamplingParam().teacache_params, True, (10, 100)), + (WanT2V_1_3B_SamplingParams().teacache_params, (5, 50)), + (WanT2V_14B_SamplingParams().teacache_params, (1, 49)), + (WanI2V_14B_480P_SamplingParam().teacache_params, (5, 50)), + (WanI2V_14B_720P_SamplingParam().teacache_params, (5, 50)), ] - for teacache_params, do_cfg, expected in legacy_equivalent_cases: + for teacache_params, expected in legacy_equivalent_cases: with self.subTest( use_ret_steps=teacache_params.use_ret_steps, - do_cfg=do_cfg, expected=expected, ): 
self.assertEqual( - teacache_params.get_skip_boundaries(50, do_cfg), + teacache_params.get_skip_boundaries(50), expected, )