From b604eb45700425e29a5ced98de02a79eadc47f78 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Sat, 28 Mar 2026 21:50:29 +0000 Subject: [PATCH 01/34] calibrate cache --- .../multimodal_gen/configs/sample/sampling_params.py | 11 ++++++++--- .../multimodal_gen/test/unit/test_sampling_params.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/python/sglang/multimodal_gen/configs/sample/sampling_params.py b/python/sglang/multimodal_gen/configs/sample/sampling_params.py index f4d353874b44..df5d32220887 100644 --- a/python/sglang/multimodal_gen/configs/sample/sampling_params.py +++ b/python/sglang/multimodal_gen/configs/sample/sampling_params.py @@ -161,9 +161,8 @@ class SamplingParams: # TeaCache parameters enable_teacache: bool = False - teacache_params: Any = ( - None # TeaCacheParams or WanTeaCacheParams, set by model-specific subclass - ) + cache_params: Any | None = None + calibrate_cache: bool = False # Profiling profile: bool = False @@ -615,6 +614,12 @@ def add_argument(*name_or_flags, **kwargs): "--enable-teacache", action="store_true", ) + add_argument( + "--calibrate-cache", + action="store_true", + default=SamplingParams.calibrate_cache, + help="Compute the values needed for caching.", + ) # profiling add_argument( diff --git a/python/sglang/multimodal_gen/test/unit/test_sampling_params.py b/python/sglang/multimodal_gen/test/unit/test_sampling_params.py index 4ac66ec175e1..18787dc13661 100644 --- a/python/sglang/multimodal_gen/test/unit/test_sampling_params.py +++ b/python/sglang/multimodal_gen/test/unit/test_sampling_params.py @@ -84,7 +84,7 @@ def coefficients_callback(_: TeaCacheParams) -> list[float]: params = SamplingParams( prompt="callable teacache", - teacache_params=TeaCacheParams( + cache_params=TeaCacheParams( coefficients_callback=coefficients_callback, ), ) From d814bc145ec2c6b4259e90fe8cb90ad404942832 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Sat, 28 Mar 2026 21:54:03 +0000 Subject: [PATCH 02/34] undo --- .../multimodal_gen/configs/sample/sampling_params.py | 9 +-------- .../multimodal_gen/test/unit/test_sampling_params.py | 2 +- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/python/sglang/multimodal_gen/configs/sample/sampling_params.py b/python/sglang/multimodal_gen/configs/sample/sampling_params.py index df5d32220887..3cce48ae0bb2 100644 --- a/python/sglang/multimodal_gen/configs/sample/sampling_params.py +++ b/python/sglang/multimodal_gen/configs/sample/sampling_params.py @@ -161,8 +161,7 @@ class SamplingParams: # TeaCache parameters enable_teacache: bool = False - cache_params: Any | None = None - calibrate_cache: bool = False + tecache_params: Any | None = None # Profiling profile: bool = False @@ -614,12 +613,6 @@ def add_argument(*name_or_flags, **kwargs): "--enable-teacache", action="store_true", ) - add_argument( - "--calibrate-cache", - action="store_true", - default=SamplingParams.calibrate_cache, - help="Compute the values needed for caching.", - ) # profiling add_argument( diff --git a/python/sglang/multimodal_gen/test/unit/test_sampling_params.py b/python/sglang/multimodal_gen/test/unit/test_sampling_params.py index 18787dc13661..4ac66ec175e1 100644 --- a/python/sglang/multimodal_gen/test/unit/test_sampling_params.py +++ b/python/sglang/multimodal_gen/test/unit/test_sampling_params.py @@ -84,7 +84,7 @@ def coefficients_callback(_: TeaCacheParams) -> list[float]: params = SamplingParams( prompt="callable teacache", - cache_params=TeaCacheParams( + teacache_params=TeaCacheParams( coefficients_callback=coefficients_callback, ), ) From 5c975dabf115cd9d3e0cdd1a90fee46b335669d1 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Sat, 28 Mar 2026 22:03:25 +0000 Subject: [PATCH 03/34] refactor teacache --- .../multimodal_gen/runtime/cache/__init__.py | 6 +- .../multimodal_gen/runtime/cache/teacache.py | 471 ++++++++---------- .../runtime/loader/weights_updater.py | 4 +- .../runtime/models/dits/base.py | 60 ++- .../runtime/models/dits/wanvideo.py | 93 +--- 5 files changed, 277 insertions(+), 357 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/cache/__init__.py b/python/sglang/multimodal_gen/runtime/cache/__init__.py index 62f0f8457f8f..9e756ab2aef3 100644 --- a/python/sglang/multimodal_gen/runtime/cache/__init__.py +++ b/python/sglang/multimodal_gen/runtime/cache/__init__.py @@ -16,12 +16,12 @@ enable_cache_on_transformer, get_scm_mask, ) -from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheContext, TeaCacheMixin +from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheState, TeaCacheStrategy __all__ = [ # TeaCache (always available) - "TeaCacheContext", - "TeaCacheMixin", + "TeaCacheState", + "TeaCacheStrategy", # cache-dit integration (lazy-loaded, requires cache-dit package) "CacheDitConfig", "enable_cache_on_transformer", diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index 8830f7ec20c4..144576444d84 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -1,316 +1,237 @@ # SPDX-License-Identifier: Apache-2.0 """ -TeaCache: Temporal similarity-based caching for diffusion models. - -TeaCache accelerates diffusion inference by selectively skipping redundant -computation when consecutive diffusion steps are similar enough. This is -achieved by tracking the L1 distance between modulated inputs across timesteps. - -Key concepts: -- Modulated input: The input to transformer blocks after timestep conditioning -- L1 distance: Measures how different consecutive timesteps are -- Threshold: When accumulated L1 distance exceeds threshold, force computation -- CFG support: Separate caches for positive and negative branches +TeaCache accelerates diffusion inference by skipping redundant forward +passes when consecutive denoising steps are sufficiently similar, as measured +by the accumulated relative L1 distance of modulated inputs. References: - TeaCache: Accelerating Diffusion Models with Temporal Similarity https://arxiv.org/abs/2411.14324 """ -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from __future__ import annotations -import numpy as np -import torch +import logging +from typing import TYPE_CHECKING -from sglang.multimodal_gen.configs.models import DiTConfig +logger = logging.getLogger(__name__) + +import torch if TYPE_CHECKING: from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams -@dataclass -class TeaCacheContext: - """Common context extracted for TeaCache skip decision. - - This context is populated from the forward_batch and forward_context - during each denoising step, providing all information needed to make - cache decisions. - - Attributes: - current_timestep: Current denoising timestep index (0-indexed). - num_inference_steps: Total number of inference steps. - do_cfg: Whether classifier-free guidance is enabled. - is_cfg_negative: True if currently processing negative CFG branch. - teacache_thresh: Threshold for accumulated L1 distance. - coefficients: Polynomial coefficients for L1 rescaling. - teacache_params: Full TeaCacheParams for model-specific access. - """ - - current_timestep: int - num_inference_steps: int - do_cfg: bool - is_cfg_negative: bool # For CFG branch selection - teacache_thresh: float - coefficients: list[float] - teacache_params: "TeaCacheParams" # Full params for model-specific access - - -class TeaCacheMixin: - """ - Mixin class providing TeaCache optimization functionality. - - TeaCache accelerates diffusion inference by selectively skipping redundant - computation when consecutive diffusion steps are similar enough. - - This mixin should be inherited by DiT model classes that want to support - TeaCache optimization. It provides: - - State management for tracking L1 distances - - CFG-aware caching (separate caches for positive/negative branches) - - Decision logic for when to compute vs. use cache - - Example usage in a DiT model: - class MyDiT(TeaCacheMixin, BaseDiT): - def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) - self._init_teacache_state() - - def forward(self, hidden_states, timestep, ...): - ctx = self._get_teacache_context() - if ctx is not None: - # Compute modulated input (model-specific, e.g., after timestep embedding) - modulated_input = self._compute_modulated_input(hidden_states, timestep) - is_boundary = (ctx.current_timestep == 0 or - ctx.current_timestep >= ctx.num_inference_steps - 1) - - should_calc = self._compute_teacache_decision( - modulated_inp=modulated_input, - is_boundary_step=is_boundary, - coefficients=ctx.coefficients, - teacache_thresh=ctx.teacache_thresh, - ) - - if not should_calc: - # Use cached residual (must implement retrieve_cached_states) - return self.retrieve_cached_states(hidden_states) - - # Normal forward pass... - output = self._transformer_forward(hidden_states, timestep, ...) - - # Cache states for next step - if ctx is not None: - self.maybe_cache_states(output, hidden_states) - - return output - - Subclass implementation notes: - - `_compute_modulated_input()`: Model-specific method to compute the input - after timestep conditioning (used for L1 distance calculation) - - `retrieve_cached_states()`: Must be overridden to return cached output - - `maybe_cache_states()`: Override to store states for cache retrieval - - Attributes: - cnt: Counter for tracking steps. - enable_teacache: Whether TeaCache is enabled. - previous_modulated_input: Cached modulated input for positive branch. - previous_residual: Cached residual for positive branch. - accumulated_rel_l1_distance: Accumulated L1 distance for positive branch. - is_cfg_negative: Whether currently processing negative CFG branch. - _supports_cfg_cache: Whether this model supports CFG cache separation. - - CFG-specific attributes (only when _supports_cfg_cache is True): - previous_modulated_input_negative: Cached input for negative branch. - previous_residual_negative: Cached residual for negative branch. - accumulated_rel_l1_distance_negative: L1 distance for negative branch. - """ - - # Models that support CFG cache separation (wan/hunyuan/zimage) - # Models not in this set (flux/qwen) auto-disable TeaCache when CFG is enabled - _CFG_SUPPORTED_PREFIXES: set[str] = {"wan", "hunyuan", "zimage"} - config: DiTConfig - - def _init_teacache_state(self) -> None: - """Initialize TeaCache state. Call this in subclass __init__.""" - # Common TeaCache state - self.cnt = 0 - self.enable_teacache = True - # Flag indicating if this model supports CFG cache separation - self._supports_cfg_cache = ( - self.config.prefix.lower() in self._CFG_SUPPORTED_PREFIXES - ) - - # Always initialize positive cache fields (used in all modes) +def _rescale_distance_tensor( + coefficients: list[float], x: torch.Tensor +) -> torch.Tensor: + """Polynomial rescaling using tensor operations (torch.compile friendly).""" + c = coefficients + return c[0] * x**4 + c[1] * x**3 + c[2] * x**2 + c[3] * x + c[4] + + +def _compute_rel_l1_distance_tensor( + current: torch.Tensor, previous: torch.Tensor +) -> torch.Tensor: + """Compute relative L1 distance as a tensor (torch.compile friendly).""" + prev_mean = previous.abs().mean() + curr_diff_mean = (current - previous).abs().mean() + rel_distance = torch.where( + prev_mean > 1e-9, + curr_diff_mean / prev_mean, + torch.where( + current.abs().mean() < 1e-9, + torch.zeros(1, device=current.device, dtype=current.dtype), + torch.full((1,), float("inf"), device=current.device, dtype=current.dtype), + ), + ) + return rel_distance.squeeze() + + +class TeaCacheState: + """Tracks step progress, cached tensors, and L1 distances for a single CFG path. Updated every timestep.""" + + def __init__(self) -> None: + self.step: int = 0 self.previous_modulated_input: torch.Tensor | None = None self.previous_residual: torch.Tensor | None = None - self.accumulated_rel_l1_distance: float = 0.0 - - self.is_cfg_negative = False - # CFG-specific fields initialized to None (created when CFG is used) - # These are only used when _supports_cfg_cache is True AND do_cfg is True - if self._supports_cfg_cache: - self.previous_modulated_input_negative: torch.Tensor | None = None - self.previous_residual_negative: torch.Tensor | None = None - self.accumulated_rel_l1_distance_negative: float = 0.0 - - def reset_teacache_state(self) -> None: - """Reset all TeaCache state at the start of each generation task.""" - self.cnt = 0 + self.accumulated_rel_l1_distance: torch.Tensor | None = None - # Primary cache fields (always present) + def reset(self) -> None: + """Clear all cached tensors and reset the step counter for a new generation.""" + self.step = 0 self.previous_modulated_input = None self.previous_residual = None - self.accumulated_rel_l1_distance = 0.0 - self.is_cfg_negative = False - self.enable_teacache = True - # CFG negative cache fields (always reset, may be unused) - if self._supports_cfg_cache: - self.previous_modulated_input_negative = None - self.previous_residual_negative = None - self.accumulated_rel_l1_distance_negative = 0.0 - - def _compute_l1_and_decide( - self, - modulated_inp: torch.Tensor, - coefficients: list[float], - teacache_thresh: float, - ) -> tuple[float, bool]: - """ - Compute L1 distance and decide whether to calculate or use cache. + self.accumulated_rel_l1_distance = None - Args: - modulated_inp: Current timestep's modulated input. - coefficients: Polynomial coefficients for L1 rescaling. - teacache_thresh: Threshold for cache decision. + def update( + self, modulated_inp: torch.Tensor | None, previous_residual: torch.Tensor | None + ) -> None: + """Store the current modulated input and its computed residual for possible future reuse.""" + self.previous_modulated_input = modulated_inp + self.previous_residual = previous_residual - Returns: - Tuple of (new_accumulated_distance, should_calc). - """ - prev_modulated_inp = ( - self.previous_modulated_input_negative - if self.is_cfg_negative - else self.previous_modulated_input - ) + def __repr__(self): + return f"TeaCacheState(step={self.step}, accumulated_rel_l1_distance={self.accumulated_rel_l1_distance})" - # Defensive check: if previous input is not set, force calculation - if prev_modulated_inp is None: - return 0.0, True - # Compute relative L1 distance - diff = modulated_inp - prev_modulated_inp - rel_l1 = (diff.abs().mean() / prev_modulated_inp.abs().mean()).cpu().item() +class TeaCacheStrategy: + """Implements TeaCache to skip redundant diffusion forward passes. - # Apply polynomial rescaling - rescale_func = np.poly1d(coefficients) + TeaCacheStrategy manages two TeaCacheState objects (positive + optional + negative CFG branch) and stores parameters needed to make the skip decision. + """ - accumulated_rel_l1_distance = ( - self.accumulated_rel_l1_distance_negative - if self.is_cfg_negative - else self.accumulated_rel_l1_distance + def __init__(self, supports_cfg: bool) -> None: + """Initialize cache states for positive and optional negative CFG branches.""" + # params updated every forward pass + self.state = TeaCacheState() + self.state_neg = TeaCacheState() if supports_cfg else None + # params updated at the start of each new generation + # set in maybe_reset() + self.cache_params: TeaCacheParams | None = None + self.coefficients: list[float] = [] + self.num_steps: int = 0 + self.start_skipping: int | None = None + self.end_skipping: int | None = None + + def _get_state(self) -> TeaCacheState: + """Select the appropriate cache state (positive/negative cfg) based on the forward context.""" + from sglang.multimodal_gen.runtime.managers.forward_context import ( + get_forward_context, ) - accumulated_rel_l1_distance = accumulated_rel_l1_distance + rescale_func(rel_l1) - if accumulated_rel_l1_distance >= teacache_thresh: - # Threshold exceeded: force compute and reset accumulator - return 0.0, True - # Cache hit: keep accumulated distance - return accumulated_rel_l1_distance, False + fb = get_forward_context().forward_batch + is_cfg_negative = fb.is_cfg_negative if fb is not None else False + if is_cfg_negative and self.state_neg is not None: + return self.state_neg + return self.state - def _compute_teacache_decision( - self, - modulated_inp: torch.Tensor, - is_boundary_step: bool, - coefficients: list[float], - teacache_thresh: float, - ) -> bool: - """ - Compute cache decision for TeaCache. + def maybe_reset(self, **kwargs) -> None: + """Maybe reset the TeaCacheState by doing three things: - Args: - modulated_inp: Current timestep's modulated input. - is_boundary_step: True for boundary timesteps that always compute. - coefficients: Polynomial coefficients for L1 rescaling. - teacache_thresh: Threshold for cache decision. + 1. Reset TeaCacheState if the previous generation is complete + 2. Initialize parameters if at the start of a new generation. + 3. Increment the state's timestep counter (always) - Returns: - True if forward computation is needed, False to use cache. - """ - if not self.enable_teacache: - return True - - if is_boundary_step: - new_accum, should_calc = 0.0, True - else: - new_accum, should_calc = self._compute_l1_and_decide( - modulated_inp=modulated_inp, - coefficients=coefficients, - teacache_thresh=teacache_thresh, - ) - - # Advance baseline and accumulator for the active branch - if not self.is_cfg_negative: - self.previous_modulated_input = modulated_inp.clone() - self.accumulated_rel_l1_distance = new_accum - elif self._supports_cfg_cache: - self.previous_modulated_input_negative = modulated_inp.clone() - self.accumulated_rel_l1_distance_negative = new_accum - - return should_calc - - def _get_teacache_context(self) -> TeaCacheContext | None: - """ - Check TeaCache preconditions and extract common context. - - Returns: - TeaCacheContext if TeaCache is enabled and properly configured, - None if should skip TeaCache logic entirely. + Called on every forward pass before should_skip(). """ from sglang.multimodal_gen.runtime.managers.forward_context import ( get_forward_context, ) - forward_context = get_forward_context() - forward_batch = forward_context.forward_batch - - # Early return checks - if ( - forward_batch is None - or not forward_batch.enable_teacache - or forward_batch.teacache_params is None - ): - return None - - teacache_params = forward_batch.teacache_params - - # Extract common values - current_timestep = forward_context.current_timestep - num_inference_steps = forward_batch.num_inference_steps - do_cfg = forward_batch.do_classifier_free_guidance - is_cfg_negative = forward_batch.is_cfg_negative - - # Reset at first timestep - if current_timestep == 0 and not self.is_cfg_negative: - self.reset_teacache_state() - - return TeaCacheContext( - current_timestep=current_timestep, - num_inference_steps=num_inference_steps, - do_cfg=do_cfg, - is_cfg_negative=is_cfg_negative, - teacache_thresh=teacache_params.teacache_thresh, - coefficients=teacache_params.get_coefficients(), - teacache_params=teacache_params, + state = self._get_state() + + # Reset state if we completed a generation + if state.step == self.num_steps and state.step > 0: + state.reset() + + # Initialize values if at the start of each new generation + if state.step == 0: + + # set the teacache parameters + fb = get_forward_context().forward_batch + assert ( + fb is not None + ), "TeaCacheStrategy required the forward_batch not be None" + self.cache_params = getattr(fb.sampling_params, "teacache_params", None) + + # set the number of inference steps + assert ( + self.cache_params is not None + ), "TeaCacheStrategy requires teacache_params in sampling_params" + self.num_steps = int(fb.num_inference_steps) + + # set the teacache coefficients + if self.cache_params.coefficients_callback: + self.coefficients = self.cache_params.coefficients_callback( + self.cache_params + ) + else: + self.coefficients = self.cache_params.coefficients + + # set the start and end skippable steps + if isinstance(self.cache_params.start_skipping, float): + start_skipping = int(self.num_steps * self.cache_params.start_skipping) + elif self.cache_params.start_skipping < 0: + start_skipping = self.num_steps + self.cache_params.start_skipping + else: + start_skipping = self.cache_params.start_skipping + + if isinstance(self.cache_params.end_skipping, float): + end_skipping = int(self.num_steps * self.cache_params.end_skipping) + elif self.cache_params.end_skipping < 0: + end_skipping = self.num_steps + self.cache_params.end_skipping + else: + end_skipping = self.cache_params.end_skipping + + if start_skipping > end_skipping: + logger.warning( + f"TeaCache skip window is invalid (start_skipping={self.start_skipping} > " + f"end_skipping={self.end_skipping}) for num_inference_steps={self.num_steps}. " + "This can happen during warmup runs with very few steps. TeaCache is disabled." + ) + self.start_skipping = self.end_skipping = None + else: + self.start_skipping, self.end_skipping = start_skipping, end_skipping + + # increment the number of steps always + state.step += 1 + + def should_skip( + self, modulated_input: torch.Tensor | None = None, **kwargs + ) -> bool: + """Decide whether this forward pass can be skipped based on the accumulated L1 distance of the modulated input.""" + state = self._get_state() + assert self.cache_params is not None + + # No valid skip window for this generation + if self.start_skipping is None or self.end_skipping is None: + return False + + # Boundary steps always compute + if state.step < self.start_skipping or state.step >= self.end_skipping: + return False + + # First time computing, no previous input to compare against + if state.accumulated_rel_l1_distance is None: + state.accumulated_rel_l1_distance = torch.zeros( + 1, device=modulated_input.device, dtype=modulated_input.dtype + ) + return False + + # compute the accumulated relative l1 distance + assert state.previous_modulated_input is not None + assert modulated_input is not None + rel_l1 = _compute_rel_l1_distance_tensor( + modulated_input, state.previous_modulated_input ) + rescaled = _rescale_distance_tensor(self.coefficients, rel_l1) + state.accumulated_rel_l1_distance += rescaled - def maybe_cache_states( - self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor - ) -> None: - """Cache states for later retrieval. Override in subclass if needed.""" - pass + # If below threshold, skip the forward pass + if state.accumulated_rel_l1_distance < self.cache_params.rel_l1_thresh: + return True - def should_skip_forward_for_cached_states(self, **kwargs: dict[str, Any]) -> bool: - """Check if forward can be skipped using cached states.""" + # If threshold exceeded, reset accumulated so next window starts fresh + state.accumulated_rel_l1_distance = torch.zeros( + 1, device=modulated_input.device, dtype=modulated_input.dtype + ) return False - def retrieve_cached_states(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Retrieve cached states. Must be implemented by subclass.""" - raise NotImplementedError("retrieve_cached_states is not implemented") + def write( + self, + hidden_states: torch.Tensor, + original_hidden_states: torch.Tensor, + modulated_input: torch.Tensor | None = None, + **kwargs, + ) -> None: + """After the forward pass, cache the residual and the current modulated input.""" + assert self.cache_params is not None + residual = hidden_states.squeeze(0) - original_hidden_states + state = self._get_state() + state.update(modulated_input, residual) + + def read(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: + """Before the forward pass, read from the cache and apply it to the current hidden states.""" + return hidden_states + self._get_state().previous_residual diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py index f170809a738e..9ba1d5fc0a02 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -46,7 +46,7 @@ import torch from torch.distributed.tensor import DTensor, distribute_tensor -from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheMixin +from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheStrategy from sglang.multimodal_gen.runtime.loader.utils import ( _list_safetensors_files, ) @@ -213,7 +213,7 @@ def update_weights_from_disk( if success and flush_cache: for _, module in modules_to_update: - if isinstance(module, TeaCacheMixin): + if isinstance(module, TeaCacheStrategy): module.reset_teacache_state() logger.info(message) diff --git a/python/sglang/multimodal_gen/runtime/models/dits/base.py b/python/sglang/multimodal_gen/runtime/models/dits/base.py index 9816f5fb03a7..f9d5692b06de 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/base.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/base.py @@ -8,16 +8,13 @@ from torch import nn from sglang.multimodal_gen.configs.models import DiTConfig - -# NOTE: TeaCacheContext and TeaCacheMixin have been moved to -# sglang.multimodal_gen.runtime.cache.teacache -# For backwards compatibility, re-export from the new location -from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheContext # noqa: F401 -from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheMixin +from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheStrategy from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) -# TODO class BaseDiT(nn.Module, ABC): _fsdp_shard_conditions: list = [] _compile_conditions: list = [] @@ -87,11 +84,15 @@ def device(self) -> torch.device: return next(self.parameters()).device -class CachableDiT(TeaCacheMixin, BaseDiT): +_CFG_SUPPORTED_PREFIXES: set[str] = {"wan", "hunyuan", "zimage"} + + +class CachableDiT(BaseDiT): """ - An intermediate base class that adds TeaCache optimization functionality to DiT models. + An intermediate base class that adds timestep-caching support for DiT models + such as TeaCache. - Inherits TeaCacheMixin for cache logic and BaseDiT for core DiT functionality. + Inherits `BaseDiT` for core DiT functionality and stores cache logic in `self.cache`. """ # These are required class attributes that should be overridden by concrete implementations @@ -109,8 +110,45 @@ class CachableDiT(TeaCacheMixin, BaseDiT): ) def __init__(self, config: DiTConfig, **kwargs) -> None: + """ + Args: + config: DiT model configuration. + **kwargs: Passed through to BaseDiT (e.g. hf_config). + + Attributes: + cache: Active cache strategy, or a sentinel: + - None: uninitialized; init_cache() has not been called yet. + - False: no cache strategy requested. + - DiffusionCache: an active cache strategy (e.g. TeaCacheStrategy). + calibrate_cache: When True, run every forward pass to calibrate + the values needed for caching. + """ super().__init__(config, **kwargs) - self._init_teacache_state() + self.cache: TeaCacheStrategy | bool | None = None + self.calibrate_cache: bool = False + + def init_cache(self) -> None: + """Construct the cache strategy from the current forward_batch context. + + Called lazily on the first forward pass because sampling params + (e.g. `enable_teacache`) are only available then. + """ + from sglang.multimodal_gen.runtime.managers.forward_context import ( + get_forward_context, + ) + + fb = get_forward_context().forward_batch + if fb is None: + return + + # caching strategies may handle pos/neg cfg separately + supports_cfg = self.config.prefix.lower() in _CFG_SUPPORTED_PREFIXES + + # select caching strategy + if fb.enable_teacache: + self.cache = TeaCacheStrategy(supports_cfg) + else: + self.cache = False @classmethod def get_nunchaku_quant_rules(cls) -> dict[str, dict[str, Any]]: diff --git a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py index b193bf808324..06b4e1761e3d 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -900,8 +900,6 @@ def __init__( ) # For type checking - - self.cnt = 0 self.__post_init__() # misc @@ -955,16 +953,22 @@ def forward( guidance=None, **kwargs, ) -> torch.Tensor: - forward_batch = get_forward_context().forward_batch + + # if caching is enabled, we might initialize or reset the cache state + if self.cache is None: + self.init_cache() + if self.cache: + self.cache.maybe_reset() + + forward_context = get_forward_context() + forward_batch = forward_context.forward_batch + if forward_batch is not None: sequence_shard_enabled = ( forward_batch.enable_sequence_shard and self.sp_size > 1 ) else: sequence_shard_enabled = False - self.enable_teacache = ( - forward_batch is not None and forward_batch.enable_teacache - ) orig_dtype = hidden_states.dtype if not isinstance(encoder_hidden_states, torch.Tensor): @@ -1096,25 +1100,32 @@ def forward( # 4. Transformer blocks # if caching is enabled, we might be able to skip the forward pass - should_skip_forward = self.should_skip_forward_for_cached_states( - timestep_proj=timestep_proj, temb=temb - ) + should_skip_forward = False + if self.cache: + modulated_input = ( + timestep_proj if self.cache.cache_params.use_ret_steps else temb + ) + should_skip_forward = self.cache.should_skip(modulated_input) if should_skip_forward: - hidden_states = self.retrieve_cached_states(hidden_states) + # compute hidden_states using the cached state + hidden_states = self.cache.read(hidden_states) else: - # if teacache is enabled, we need to cache the original hidden states - if self.enable_teacache: + if self.cache: original_hidden_states = hidden_states.clone() for block in self.blocks: hidden_states = block( hidden_states, encoder_hidden_states, timestep_proj, freqs_cis ) - # if teacache is enabled, we need to cache the original hidden states - if self.enable_teacache: - self.maybe_cache_states(hidden_states, original_hidden_states) - self.cnt += 1 + + if self.cache: + # update the cache with the new hidden states + self.cache.write( + hidden_states, + original_hidden_states, + modulated_input=modulated_input, + ) if sequence_shard_enabled: hidden_states = hidden_states.contiguous() @@ -1152,55 +1163,5 @@ def forward( return output - def maybe_cache_states( - self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor - ) -> None: - """Cache residual with CFG positive/negative separation.""" - residual = hidden_states.squeeze(0) - original_hidden_states - if not self.is_cfg_negative: - self.previous_residual = residual - else: - self.previous_residual_negative = residual - - def should_skip_forward_for_cached_states(self, **kwargs) -> bool: - if not self.enable_teacache: - return False - ctx = self._get_teacache_context() - if ctx is None: - return False - - # Initialize Wan-specific parameters - teacache_params = ctx.teacache_params - use_ret_steps = teacache_params.use_ret_steps - start_skipping, end_skipping = teacache_params.get_skip_boundaries( - ctx.num_inference_steps, ctx.do_cfg - ) - - # Determine boundary step - is_boundary_step = self.cnt < start_skipping or self.cnt >= end_skipping - - timestep_proj = kwargs["timestep_proj"] - temb = kwargs["temb"] - modulated_inp = timestep_proj if use_ret_steps else temb - - self.is_cfg_negative = ctx.is_cfg_negative - - # Use shared helper to compute cache decision - should_calc = self._compute_teacache_decision( - modulated_inp=modulated_inp, - is_boundary_step=is_boundary_step, - coefficients=ctx.coefficients, - teacache_thresh=ctx.teacache_thresh, - ) - - return not should_calc - - def retrieve_cached_states(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Retrieve cached residual with CFG positive/negative separation.""" - if not self.is_cfg_negative: - return hidden_states + self.previous_residual - else: - return hidden_states + self.previous_residual_negative - EntryClass = WanTransformer3DModel From 6b053aaab14469bb1d95ec19ed433ccba7f4ee08 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Sat, 28 Mar 2026 22:16:06 +0000 Subject: [PATCH 04/34] teacacheparms is stateless --- .../multimodal_gen/configs/sample/teacache.py | 23 ----- .../multimodal_gen/runtime/cache/teacache.py | 88 +++++++++++-------- 2 files changed, 49 insertions(+), 62 deletions(-) diff --git a/python/sglang/multimodal_gen/configs/sample/teacache.py b/python/sglang/multimodal_gen/configs/sample/teacache.py index e20df6e67b93..e861a5383401 100644 --- a/python/sglang/multimodal_gen/configs/sample/teacache.py +++ b/python/sglang/multimodal_gen/configs/sample/teacache.py @@ -57,26 +57,3 @@ class TeaCacheParams(CacheParams): ) use_ret_steps: bool | None = None - def get_coefficients(self) -> list[float]: - if self.coefficients_callback is not None: - return self.coefficients_callback(self) - return self.coefficients - - def get_skip_boundaries( - self, num_inference_steps: int, do_cfg: bool - ) -> tuple[int, int]: - def _resolve_boundary(value: int | float) -> int: - if isinstance(value, float): - return int(num_inference_steps * value) - if value < 0: - return num_inference_steps + value - return value - - start_skipping = _resolve_boundary(self.start_skipping) - end_skipping = _resolve_boundary(self.end_skipping) - - if do_cfg: - start_skipping *= 2 - end_skipping *= 2 - - return start_skipping, end_skipping diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index 144576444d84..8ae6f3560d9d 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -89,7 +89,7 @@ def __init__(self, supports_cfg: bool) -> None: self.state_neg = TeaCacheState() if supports_cfg else None # params updated at the start of each new generation # set in maybe_reset() - self.cache_params: TeaCacheParams | None = None + self.teacache_params: TeaCacheParams | None = None self.coefficients: list[float] = [] self.num_steps: int = 0 self.start_skipping: int | None = None @@ -101,12 +101,39 @@ def _get_state(self) -> TeaCacheState: get_forward_context, ) - fb = get_forward_context().forward_batch - is_cfg_negative = fb.is_cfg_negative if fb is not None else False + forward_batch = get_forward_context().forward_batch + is_cfg_negative = forward_batch.is_cfg_negative if forward_batch is not None else False if is_cfg_negative and self.state_neg is not None: return self.state_neg return self.state + def get_skip_boundaries( + self, start_skipping, end_skipping, num_inference_steps: int, do_cfg: bool + ) -> tuple[int, int]: + def _resolve_boundary(value: int | float) -> int: + if isinstance(value, float): + return int(num_inference_steps * value) + if value < 0: + return num_inference_steps + value + return value + + start_skipping = _resolve_boundary(start_skipping) + end_skipping = _resolve_boundary(end_skipping) + + if do_cfg: + start_skipping *= 2 + end_skipping *= 2 + + if start_skipping > end_skipping: + logger.warning( + f"TeaCache skip window is invalid. Expected start_skipping<=end_skipping but got {start_skipping=}" + f" > {end_skipping=})for {num_inference_steps=}. This can happen during warmup runs with very few" + " steps. TeaCache is disabled." + ) + start_skipping, end_skipping = None, None + + return start_skipping, end_skipping + def maybe_reset(self, **kwargs) -> None: """Maybe reset the TeaCacheState by doing three things: @@ -126,54 +153,37 @@ def maybe_reset(self, **kwargs) -> None: if state.step == self.num_steps and state.step > 0: state.reset() - # Initialize values if at the start of each new generation + # Initialize values at the start of each new generation if state.step == 0: # set the teacache parameters - fb = get_forward_context().forward_batch + forward_batch = get_forward_context().forward_batch assert ( - fb is not None + forward_batch is not None ), "TeaCacheStrategy required the forward_batch not be None" - self.cache_params = getattr(fb.sampling_params, "teacache_params", None) + self.teacache_params = getattr(forward_batch.sampling_params, "teacache_params", None) # set the number of inference steps assert ( - self.cache_params is not None + self.teacache_params is not None ), "TeaCacheStrategy requires teacache_params in sampling_params" - self.num_steps = int(fb.num_inference_steps) + self.num_steps = int(forward_batch.num_inference_steps) # set the teacache coefficients - if self.cache_params.coefficients_callback: - self.coefficients = self.cache_params.coefficients_callback( - self.cache_params + if self.teacache_params.coefficients_callback: + self.coefficients = self.teacache_params.coefficients_callback( + self.teacache_params ) else: - self.coefficients = self.cache_params.coefficients + self.coefficients = self.teacache_params.coefficients # set the start and end skippable steps - if isinstance(self.cache_params.start_skipping, float): - start_skipping = int(self.num_steps * self.cache_params.start_skipping) - elif self.cache_params.start_skipping < 0: - start_skipping = self.num_steps + self.cache_params.start_skipping - else: - start_skipping = self.cache_params.start_skipping - - if isinstance(self.cache_params.end_skipping, float): - end_skipping = int(self.num_steps * self.cache_params.end_skipping) - elif self.cache_params.end_skipping < 0: - end_skipping = self.num_steps + self.cache_params.end_skipping - else: - end_skipping = self.cache_params.end_skipping - - if start_skipping > end_skipping: - logger.warning( - f"TeaCache skip window is invalid (start_skipping={self.start_skipping} > " - f"end_skipping={self.end_skipping}) for num_inference_steps={self.num_steps}. " - "This can happen during warmup runs with very few steps. TeaCache is disabled." - ) - self.start_skipping = self.end_skipping = None - else: - self.start_skipping, self.end_skipping = start_skipping, end_skipping + self.start_skipping, self.end_skipping = self.get_skip_boundaries( + self.teacache_params.start_skipping, + self.teacache_params.end_skipping, + self.num_steps, + do_cfg=forward_batch.do_classifier_free_guidance, + ) # increment the number of steps always state.step += 1 @@ -183,7 +193,7 @@ def should_skip( ) -> bool: """Decide whether this forward pass can be skipped based on the accumulated L1 distance of the modulated input.""" state = self._get_state() - assert self.cache_params is not None + assert self.teacache_params is not None # No valid skip window for this generation if self.start_skipping is None or self.end_skipping is None: @@ -210,7 +220,7 @@ def should_skip( state.accumulated_rel_l1_distance += rescaled # If below threshold, skip the forward pass - if state.accumulated_rel_l1_distance < self.cache_params.rel_l1_thresh: + if state.accumulated_rel_l1_distance < self.teacache_params.rel_l1_thresh: return True # If threshold exceeded, reset accumulated so next window starts fresh @@ -227,7 +237,7 @@ def write( **kwargs, ) -> None: """After the forward pass, cache the residual and the current modulated input.""" - assert self.cache_params is not None + assert self.teacache_params is not None residual = hidden_states.squeeze(0) - original_hidden_states state = self._get_state() state.update(modulated_input, residual) From 9f8f300b704ab9849e79a8303bb0e867ec55c9aa Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Sat, 28 Mar 2026 22:49:15 +0000 Subject: [PATCH 05/34] fix spelling in params --- .../configs/sample/sampling_params.py | 2 +- .../multimodal_gen/runtime/cache/teacache.py | 26 +++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/python/sglang/multimodal_gen/configs/sample/sampling_params.py b/python/sglang/multimodal_gen/configs/sample/sampling_params.py index 3cce48ae0bb2..c87d9d73c1e5 100644 --- a/python/sglang/multimodal_gen/configs/sample/sampling_params.py +++ b/python/sglang/multimodal_gen/configs/sample/sampling_params.py @@ -161,7 +161,7 @@ class SamplingParams: # TeaCache parameters enable_teacache: bool = False - tecache_params: Any | None = None + teacache_params: Any | None = None # Profiling profile: bool = False diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index 8ae6f3560d9d..ac67c049e1a0 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -89,7 +89,7 @@ def __init__(self, supports_cfg: bool) -> None: self.state_neg = TeaCacheState() if supports_cfg else None # params updated at the start of each new generation # set in maybe_reset() - self.teacache_params: TeaCacheParams | None = None + self.cache_params: TeaCacheParams | None = None self.coefficients: list[float] = [] self.num_steps: int = 0 self.start_skipping: int | None = None @@ -161,26 +161,26 @@ def maybe_reset(self, **kwargs) -> None: assert ( forward_batch is not None ), "TeaCacheStrategy required the forward_batch not be None" - self.teacache_params = getattr(forward_batch.sampling_params, "teacache_params", None) + self.cache_params = getattr(forward_batch.sampling_params, "teacache_params", None) # set the number of inference steps assert ( - self.teacache_params is not None - ), "TeaCacheStrategy requires teacache_params in sampling_params" + self.cache_params is not None + ), "TeaCacheStrategy requires cache_params in sampling_params" self.num_steps = int(forward_batch.num_inference_steps) # set the teacache coefficients - if self.teacache_params.coefficients_callback: - self.coefficients = self.teacache_params.coefficients_callback( - self.teacache_params + if self.cache_params.coefficients_callback: + self.coefficients = self.cache_params.coefficients_callback( + self.cache_params ) else: - self.coefficients = self.teacache_params.coefficients + self.coefficients = self.cache_params.coefficients # set the start and end skippable steps self.start_skipping, self.end_skipping = self.get_skip_boundaries( - self.teacache_params.start_skipping, - self.teacache_params.end_skipping, + self.cache_params.start_skipping, + self.cache_params.end_skipping, self.num_steps, do_cfg=forward_batch.do_classifier_free_guidance, ) @@ -193,7 +193,7 @@ def should_skip( ) -> bool: """Decide whether this forward pass can be skipped based on the accumulated L1 distance of the modulated input.""" state = self._get_state() - assert self.teacache_params is not None + assert self.cache_params is not None # No valid skip window for this generation if self.start_skipping is None or self.end_skipping is None: @@ -220,7 +220,7 @@ def should_skip( state.accumulated_rel_l1_distance += rescaled # If below threshold, skip the forward pass - if state.accumulated_rel_l1_distance < self.teacache_params.rel_l1_thresh: + if state.accumulated_rel_l1_distance < self.cache_params.rel_l1_thresh: return True # If threshold exceeded, reset accumulated so next window starts fresh @@ -237,7 +237,7 @@ def write( **kwargs, ) -> None: """After the forward pass, cache the residual and the current modulated input.""" - assert self.teacache_params is not None + assert self.cache_params is not None residual = hidden_states.squeeze(0) - original_hidden_states state = self._get_state() state.update(modulated_input, residual) From a5b765a5f3733dc8d648ed369e345dc6b350f2f2 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Sat, 28 Mar 2026 23:00:26 +0000 Subject: [PATCH 06/34] precommit --- python/sglang/multimodal_gen/configs/sample/teacache.py | 1 - python/sglang/multimodal_gen/runtime/cache/teacache.py | 8 ++++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/sglang/multimodal_gen/configs/sample/teacache.py b/python/sglang/multimodal_gen/configs/sample/teacache.py index e861a5383401..43cc1b256129 100644 --- a/python/sglang/multimodal_gen/configs/sample/teacache.py +++ b/python/sglang/multimodal_gen/configs/sample/teacache.py @@ -56,4 +56,3 @@ class TeaCacheParams(CacheParams): default=None, repr=False ) use_ret_steps: bool | None = None - diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index ac67c049e1a0..f3b5edd24dd3 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -102,7 +102,9 @@ def _get_state(self) -> TeaCacheState: ) forward_batch = get_forward_context().forward_batch - is_cfg_negative = forward_batch.is_cfg_negative if forward_batch is not None else False + is_cfg_negative = ( + forward_batch.is_cfg_negative if forward_batch is not None else False + ) if is_cfg_negative and self.state_neg is not None: return self.state_neg return self.state @@ -161,7 +163,9 @@ def maybe_reset(self, **kwargs) -> None: assert ( forward_batch is not None ), "TeaCacheStrategy required the forward_batch not be None" - self.cache_params = getattr(forward_batch.sampling_params, "teacache_params", None) + self.cache_params = getattr( + forward_batch.sampling_params, "teacache_params", None + ) # set the number of inference steps assert ( From d15d4b9e5e67ad79835d1afcf607934dda9a5ff3 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Sat, 28 Mar 2026 23:23:24 +0000 Subject: [PATCH 07/34] make method private --- python/sglang/multimodal_gen/runtime/cache/teacache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index f3b5edd24dd3..7b3d376ca243 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -109,7 +109,7 @@ def _get_state(self) -> TeaCacheState: return self.state_neg return self.state - def get_skip_boundaries( + def _get_skip_boundaries( self, start_skipping, end_skipping, num_inference_steps: int, do_cfg: bool ) -> tuple[int, int]: def _resolve_boundary(value: int | float) -> int: @@ -182,7 +182,7 @@ def maybe_reset(self, **kwargs) -> None: self.coefficients = self.cache_params.coefficients # set the start and end skippable steps - self.start_skipping, self.end_skipping = self.get_skip_boundaries( + self.start_skipping, self.end_skipping = self._get_skip_boundaries( self.cache_params.start_skipping, self.cache_params.end_skipping, self.num_steps, From db02c5db0d001944c43a24379f9eef03f52fb56d Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Mon, 13 Apr 2026 19:18:24 +0300 Subject: [PATCH 08/34] update docs --- docs/diffusion/performance/cache/teacache.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/diffusion/performance/cache/teacache.md b/docs/diffusion/performance/cache/teacache.md index dd9691c43a4a..abeabc6703b8 100644 --- a/docs/diffusion/performance/cache/teacache.md +++ b/docs/diffusion/performance/cache/teacache.md @@ -36,9 +36,7 @@ accumulated += poly(coefficients)(rel_l1) ### CFG Support -For models that support CFG cache separation (Wan, Hunyuan, Z-Image), TeaCache maintains separate caches for positive and negative branches: -- `previous_modulated_input` / `previous_residual` for positive branch -- `previous_modulated_input_negative` / `previous_residual_negative` for negative branch +For models that support CFG cache separation (Wan, Hunyuan, Z-Image), `TeaCacheStrategy` maintains two separate `TeaCacheState` objects for the positive and negative branches. For models that don't support CFG separation (Flux, Qwen), TeaCache is automatically disabled when CFG is enabled. From 8f846785f13eed13591f3bdc38ba7e7013a7eed2 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Mon, 13 Apr 2026 19:25:01 +0300 Subject: [PATCH 09/34] better param name; better docs --- docs/diffusion/performance/cache/teacache.md | 20 +++++++++++++++---- .../multimodal_gen/configs/sample/hunyuan.py | 2 +- .../multimodal_gen/configs/sample/teacache.py | 4 ++-- .../multimodal_gen/configs/sample/wan.py | 8 ++++---- .../multimodal_gen/configs/sample/zimage.py | 2 +- .../runtime/models/dits/hunyuanvideo.py | 2 +- .../runtime/models/dits/mova_audio_dit.py | 2 +- .../runtime/models/dits/mova_video_dit.py | 2 +- 8 files changed, 27 insertions(+), 15 deletions(-) diff --git a/docs/diffusion/performance/cache/teacache.md b/docs/diffusion/performance/cache/teacache.md index abeabc6703b8..f600e68bf974 100644 --- a/docs/diffusion/performance/cache/teacache.md +++ b/docs/diffusion/performance/cache/teacache.md @@ -13,7 +13,19 @@ TeaCache works by: 3. When accumulated distance is below a threshold, reusing the cached residual 4. Supporting CFG (Classifier-Free Guidance) with separate positive/negative caches -## How It Works +## Implementation + +TeaCache is split into three classes: + +- **`TeaCacheParams`** — pure data class holding user-set parameters (`rel_l1_thresh`, `coefficients`, `start_skipping`, `end_skipping`). Set once per request, never mutated during inference. +- **`TeaCacheState`** — runtime state for one CFG branch: `step`, `previous_modulated_input`, `previous_residual`, `accumulated_rel_l1_distance`. Reset at the start of each generation. +- **`TeaCacheStrategy`** — all the logic. Owns two `TeaCacheState` objects (positive + optional negative CFG branch) and reads from `TeaCacheParams` to decide when to skip. + +At each denoising step, `TeaCacheStrategy` calls: +1. `maybe_reset()` — resets state if a generation just finished, initializes params at step 0, increments step counter +2. `should_skip()` — computes accumulated L1 distance and returns whether to skip +3. `read()` — if skipping, reads the cached residual and adds it to hidden states +4. `write()` — if computing, writes the new residual and modulated input to the cache ### L1 Distance Tracking @@ -36,7 +48,7 @@ accumulated += poly(coefficients)(rel_l1) ### CFG Support -For models that support CFG cache separation (Wan, Hunyuan, Z-Image), `TeaCacheStrategy` maintains two separate `TeaCacheState` objects for the positive and negative branches. +For models that support CFG separation (Wan, Hunyuan, Z-Image), `TeaCacheStrategy` maintains separate `TeaCacheState` objects for the positive and negative branches. For models that don't support CFG separation (Flux, Qwen), TeaCache is automatically disabled when CFG is enabled. @@ -48,7 +60,7 @@ TeaCache is configured via `TeaCacheParams` in the sampling parameters: from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams params = TeaCacheParams( - teacache_thresh=0.1, # Threshold for accumulated L1 distance + rel_l1_thresh=0.1, # Threshold for accumulated L1 distance coefficients=[1.0, 0.0, 0.0], # Polynomial coefficients for L1 rescaling ) ``` @@ -57,7 +69,7 @@ params = TeaCacheParams( | Parameter | Type | Description | |-----------|------|-------------| -| `teacache_thresh` | float | Threshold for accumulated L1 distance. Lower = more caching, faster but potentially lower quality | +| `rel_l1_thresh` | float | Threshold for accumulated L1 distance. Lower = more caching, faster but potentially lower quality | | `coefficients` | list[float] | Polynomial coefficients for L1 rescaling. Model-specific tuning | ### Model-Specific Configurations diff --git a/python/sglang/multimodal_gen/configs/sample/hunyuan.py b/python/sglang/multimodal_gen/configs/sample/hunyuan.py index c60b856630f0..61733d24868d 100644 --- a/python/sglang/multimodal_gen/configs/sample/hunyuan.py +++ b/python/sglang/multimodal_gen/configs/sample/hunyuan.py @@ -38,7 +38,7 @@ class HunyuanSamplingParams(SamplingParams): teacache_params: TeaCacheParams = field( default_factory=lambda: TeaCacheParams( - teacache_thresh=0.15, + rel_l1_thresh=0.15, # from https://github.com/ali-vilab/TeaCache/blob/7c10efc4702c6b619f47805f7abe4a7a08085aa0/TeaCache4HunyuanVideo/teacache_sample_video.py#L222 coefficients=[ 7.33226126e02, diff --git a/python/sglang/multimodal_gen/configs/sample/teacache.py b/python/sglang/multimodal_gen/configs/sample/teacache.py index 43cc1b256129..d2a70507e89e 100644 --- a/python/sglang/multimodal_gen/configs/sample/teacache.py +++ b/python/sglang/multimodal_gen/configs/sample/teacache.py @@ -17,7 +17,7 @@ class TeaCacheParams(CacheParams): Attributes: cache_type: (`str`, defaults to `teacache`): A string labeling these parameters as belonging to teacache. - teacache_thresh (`float`, defaults to `0.0`): + rel_l1_thresh (`float`, defaults to `0.0`): Threshold for accumulated relative L1 distance. When below this threshold, the forward pass is skipped. Recommended values: 0.25 for ~1.5x speedup, 0.4 for ~1.8x, 0.6 for ~2.0x. @@ -48,7 +48,7 @@ class TeaCacheParams(CacheParams): """ cache_type: str = "teacache" - teacache_thresh: float = 0.0 + rel_l1_thresh: float = 0.0 start_skipping: int | float = 5 end_skipping: int | float = -1 coefficients: list[float] = field(default_factory=list) diff --git a/python/sglang/multimodal_gen/configs/sample/wan.py b/python/sglang/multimodal_gen/configs/sample/wan.py index 0f147a9dcede..184fe7034e68 100644 --- a/python/sglang/multimodal_gen/configs/sample/wan.py +++ b/python/sglang/multimodal_gen/configs/sample/wan.py @@ -66,7 +66,7 @@ class WanT2V_1_3B_SamplingParams(SamplingParams): teacache_params: TeaCacheParams = field( default_factory=lambda: TeaCacheParams( - teacache_thresh=0.08, + rel_l1_thresh=0.08, use_ret_steps=True, coefficients_callback=_wan_1_3b_coefficients, start_skipping=5, @@ -102,7 +102,7 @@ class WanT2V_14B_SamplingParams(SamplingParams): teacache_params: TeaCacheParams = field( default_factory=lambda: TeaCacheParams( - teacache_thresh=0.20, + rel_l1_thresh=0.20, use_ret_steps=False, coefficients_callback=_wan_14b_coefficients, start_skipping=1, @@ -128,7 +128,7 @@ class WanI2V_14B_480P_SamplingParam(WanT2V_1_3B_SamplingParams): teacache_params: TeaCacheParams = field( default_factory=lambda: TeaCacheParams( - teacache_thresh=0.26, + rel_l1_thresh=0.26, use_ret_steps=True, coefficients_callback=_wan_14b_coefficients, start_skipping=5, @@ -156,7 +156,7 @@ class WanI2V_14B_720P_SamplingParam(WanT2V_14B_SamplingParams): teacache_params: TeaCacheParams = field( default_factory=lambda: TeaCacheParams( - teacache_thresh=0.3, + rel_l1_thresh=0.3, use_ret_steps=True, coefficients_callback=_wan_14b_coefficients, start_skipping=5, diff --git a/python/sglang/multimodal_gen/configs/sample/zimage.py b/python/sglang/multimodal_gen/configs/sample/zimage.py index 77a9dabf90de..9c728e182ab3 100644 --- a/python/sglang/multimodal_gen/configs/sample/zimage.py +++ b/python/sglang/multimodal_gen/configs/sample/zimage.py @@ -22,7 +22,7 @@ class ZImageTurboSamplingParams(SamplingParams): teacache_params: TeaCacheParams = field( default_factory=lambda: TeaCacheParams( - teacache_thresh=0.15, + rel_l1_thresh=0.15, coefficients=[ 7.33226126e02, -4.01131952e02, diff --git a/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py index 09a233ec9176..c36445359ff5 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py @@ -700,7 +700,7 @@ def should_skip_forward_for_cached_states(self, **kwargs) -> bool: teacache_params, TeaCacheParams ), "teacache_params is not a TeaCacheParams" num_inference_steps = forward_batch.num_inference_steps - teache_thresh = teacache_params.teacache_thresh + teache_thresh = teacache_params.rel_l1_thresh coefficients = teacache_params.coefficients diff --git a/python/sglang/multimodal_gen/runtime/models/dits/mova_audio_dit.py b/python/sglang/multimodal_gen/runtime/models/dits/mova_audio_dit.py index f3deaf649e9a..d4db6546a499 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/mova_audio_dit.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/mova_audio_dit.py @@ -182,7 +182,7 @@ def __init__( self.num_channels_latents = out_dim self.layer_names = ["blocks"] self.cnt = 0 - self.teacache_thresh = 0 + self.rel_l1_thresh = 0 self.coefficients = [] self.accumulated_rel_l1_distance = 0 self.previous_modulated_input = None diff --git a/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py b/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py index f6f520690e85..5950e1a0da27 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py @@ -497,7 +497,7 @@ def __init__( self.num_channels_latents = out_dim self.layer_names = ["blocks"] self.cnt = 0 - self.teacache_thresh = 0 + self.rel_l1_thresh = 0 self.coefficients = [] self.accumulated_rel_l1_distance = 0 self.previous_modulated_input = None From 3c54d34c81004820419f868d3c2e112259406d85 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 14 Apr 2026 13:19:47 +0300 Subject: [PATCH 10/34] move get_coefficents, get_skip_boundaries to TeaCacheParams --- .../multimodal_gen/configs/sample/teacache.py | 35 +++++++++++++ .../multimodal_gen/runtime/cache/teacache.py | 50 +++---------------- 2 files changed, 43 insertions(+), 42 deletions(-) diff --git a/python/sglang/multimodal_gen/configs/sample/teacache.py b/python/sglang/multimodal_gen/configs/sample/teacache.py index d2a70507e89e..fb47dd8eb6d7 100644 --- a/python/sglang/multimodal_gen/configs/sample/teacache.py +++ b/python/sglang/multimodal_gen/configs/sample/teacache.py @@ -3,11 +3,14 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +import logging from dataclasses import dataclass, field from typing import Callable from sglang.multimodal_gen.configs.sample.sampling_params import CacheParams +logger = logging.getLogger(__name__) + @dataclass class TeaCacheParams(CacheParams): @@ -56,3 +59,35 @@ class TeaCacheParams(CacheParams): default=None, repr=False ) use_ret_steps: bool | None = None + + def _get_coefficients(self) -> list[float]: + if self.coefficients_callback is not None: + return self.coefficients_callback(self) + return self.coefficients + + def _get_skip_boundaries( + self, num_inference_steps: int, do_cfg: bool + ) -> tuple[int | None, int | None]: + def _resolve_boundary(value: int | float) -> int: + if isinstance(value, float): + return int(num_inference_steps * value) + if value < 0: + return num_inference_steps + value + return value + + start_skipping = _resolve_boundary(self.start_skipping) + end_skipping = _resolve_boundary(self.end_skipping) + + if do_cfg: + start_skipping *= 2 + end_skipping *= 2 + + if start_skipping > end_skipping: + logger.warning( + f"TeaCache skip window is invalid. Expected start_skipping<=end_skipping but got {start_skipping=}" + f" > {end_skipping=})for {num_inference_steps=}. This can happen during warmup runs with very few" + " steps. TeaCache is disabled." + ) + start_skipping, end_skipping = None, None + + return start_skipping, end_skipping diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index 7b3d376ca243..20c380433f6b 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -109,33 +109,6 @@ def _get_state(self) -> TeaCacheState: return self.state_neg return self.state - def _get_skip_boundaries( - self, start_skipping, end_skipping, num_inference_steps: int, do_cfg: bool - ) -> tuple[int, int]: - def _resolve_boundary(value: int | float) -> int: - if isinstance(value, float): - return int(num_inference_steps * value) - if value < 0: - return num_inference_steps + value - return value - - start_skipping = _resolve_boundary(start_skipping) - end_skipping = _resolve_boundary(end_skipping) - - if do_cfg: - start_skipping *= 2 - end_skipping *= 2 - - if start_skipping > end_skipping: - logger.warning( - f"TeaCache skip window is invalid. Expected start_skipping<=end_skipping but got {start_skipping=}" - f" > {end_skipping=})for {num_inference_steps=}. This can happen during warmup runs with very few" - " steps. TeaCache is disabled." - ) - start_skipping, end_skipping = None, None - - return start_skipping, end_skipping - def maybe_reset(self, **kwargs) -> None: """Maybe reset the TeaCacheState by doing three things: @@ -173,24 +146,17 @@ def maybe_reset(self, **kwargs) -> None: ), "TeaCacheStrategy requires cache_params in sampling_params" self.num_steps = int(forward_batch.num_inference_steps) - # set the teacache coefficients - if self.cache_params.coefficients_callback: - self.coefficients = self.cache_params.coefficients_callback( - self.cache_params + # get teacache coefficients and skip boundaries + self.coefficients = self.cache_params._get_coefficients() + self.start_skipping, self.end_skipping = ( + self.cache_params._get_skip_boundaries( + self.num_steps, + forward_batch.do_classifier_free_guidance, ) - else: - self.coefficients = self.cache_params.coefficients - - # set the start and end skippable steps - self.start_skipping, self.end_skipping = self._get_skip_boundaries( - self.cache_params.start_skipping, - self.cache_params.end_skipping, - self.num_steps, - do_cfg=forward_batch.do_classifier_free_guidance, ) - # increment the number of steps always - state.step += 1 + # always increment the number of steps + state.step += 1 def should_skip( self, modulated_input: torch.Tensor | None = None, **kwargs From 3a23fc79ba59f3e8778fbade0142ed43916b691e Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 14 Apr 2026 13:32:54 +0300 Subject: [PATCH 11/34] update --- python/sglang/multimodal_gen/configs/sample/teacache.py | 6 +----- python/sglang/multimodal_gen/runtime/cache/teacache.py | 5 +---- .../sglang/multimodal_gen/runtime/loader/weights_updater.py | 5 ++--- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/python/sglang/multimodal_gen/configs/sample/teacache.py b/python/sglang/multimodal_gen/configs/sample/teacache.py index fb47dd8eb6d7..aca2fb1408f3 100644 --- a/python/sglang/multimodal_gen/configs/sample/teacache.py +++ b/python/sglang/multimodal_gen/configs/sample/teacache.py @@ -66,7 +66,7 @@ def _get_coefficients(self) -> list[float]: return self.coefficients def _get_skip_boundaries( - self, num_inference_steps: int, do_cfg: bool + self, num_inference_steps: int ) -> tuple[int | None, int | None]: def _resolve_boundary(value: int | float) -> int: if isinstance(value, float): @@ -78,10 +78,6 @@ def _resolve_boundary(value: int | float) -> int: start_skipping = _resolve_boundary(self.start_skipping) end_skipping = _resolve_boundary(self.end_skipping) - if do_cfg: - start_skipping *= 2 - end_skipping *= 2 - if start_skipping > end_skipping: logger.warning( f"TeaCache skip window is invalid. Expected start_skipping<=end_skipping but got {start_skipping=}" diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index 20c380433f6b..f53eca4c477e 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -149,10 +149,7 @@ def maybe_reset(self, **kwargs) -> None: # get teacache coefficients and skip boundaries self.coefficients = self.cache_params._get_coefficients() self.start_skipping, self.end_skipping = ( - self.cache_params._get_skip_boundaries( - self.num_steps, - forward_batch.do_classifier_free_guidance, - ) + self.cache_params._get_skip_boundaries(self.num_steps) ) # always increment the number of steps diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py index 9ba1d5fc0a02..d1451c9b80ad 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -46,7 +46,6 @@ import torch from torch.distributed.tensor import DTensor, distribute_tensor -from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheStrategy from sglang.multimodal_gen.runtime.loader.utils import ( _list_safetensors_files, ) @@ -213,8 +212,8 @@ def update_weights_from_disk( if success and flush_cache: for _, module in modules_to_update: - if isinstance(module, TeaCacheStrategy): - module.reset_teacache_state() + if hasattr(module, "cache"): + module.cache.reset() logger.info(message) return success, message From 56a00c189b3dacfa7e29fa7606a9406f85374d71 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 14 Apr 2026 15:42:57 +0300 Subject: [PATCH 12/34] init changes after first generation --- .../runtime/models/dits/base.py | 20 ++++++++++--------- .../runtime/models/dits/wanvideo.py | 5 +++-- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/models/dits/base.py b/python/sglang/multimodal_gen/runtime/models/dits/base.py index f9d5692b06de..6e4d722668f4 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/base.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/base.py @@ -124,10 +124,10 @@ def __init__(self, config: DiTConfig, **kwargs) -> None: the values needed for caching. """ super().__init__(config, **kwargs) - self.cache: TeaCacheStrategy | bool | None = None + self.cache: TeaCacheStrategy | None = None self.calibrate_cache: bool = False - def init_cache(self) -> None: + def maybe_init_cache(self, timestep: int) -> None: """Construct the cache strategy from the current forward_batch context. Called lazily on the first forward pass because sampling params @@ -137,18 +137,20 @@ def init_cache(self) -> None: get_forward_context, ) - fb = get_forward_context().forward_batch - if fb is None: + forward_batch = get_forward_context().forward_batch + if forward_batch is None: return # caching strategies may handle pos/neg cfg separately supports_cfg = self.config.prefix.lower() in _CFG_SUPPORTED_PREFIXES - # select caching strategy - if fb.enable_teacache: - self.cache = TeaCacheStrategy(supports_cfg) - else: - self.cache = False + # initialize cache at the start of each new generation (timestep == 0 and cfg is positive for cfg-supporting models) + if timestep == 0 and ((supports_cfg and not forward_batch.is_cfg_negative) or not supports_cfg): + # select caching strategy + if forward_batch.enable_teacache: + self.cache = TeaCacheStrategy(supports_cfg) + else: + self.cache = None @classmethod def get_nunchaku_quant_rules(cls) -> dict[str, dict[str, Any]]: diff --git a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py index 5ed64fdeaf79..b7270cb86172 100755 --- a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -944,6 +944,7 @@ def __init__( ) # For type checking + self.cnt = 0 self.__post_init__() # misc @@ -999,8 +1000,7 @@ def forward( ) -> torch.Tensor: # if caching is enabled, we might initialize or reset the cache state - if self.cache is None: - self.init_cache() + self.maybe_init_cache(timestep) if self.cache: self.cache.maybe_reset() @@ -1170,6 +1170,7 @@ def forward( original_hidden_states, modulated_input=modulated_input, ) + self.cnt += 1 if sequence_shard_enabled: hidden_states = hidden_states.contiguous() From 25a34860db17ea979e6de7590e0e71c9e3ba60c5 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 14 Apr 2026 13:49:06 +0000 Subject: [PATCH 13/34] fix --- .../multimodal_gen/runtime/cache/teacache.py | 86 ++++--------------- .../runtime/loader/weights_updater.py | 2 +- .../runtime/models/dits/base.py | 26 ++++-- .../runtime/models/dits/wanvideo.py | 9 +- 4 files changed, 42 insertions(+), 81 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index f53eca4c477e..ae0591562402 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -12,15 +12,11 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING logger = logging.getLogger(__name__) import torch -if TYPE_CHECKING: - from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams - def _rescale_distance_tensor( coefficients: list[float], x: torch.Tensor @@ -57,13 +53,6 @@ def __init__(self) -> None: self.previous_residual: torch.Tensor | None = None self.accumulated_rel_l1_distance: torch.Tensor | None = None - def reset(self) -> None: - """Clear all cached tensors and reset the step counter for a new generation.""" - self.step = 0 - self.previous_modulated_input = None - self.previous_residual = None - self.accumulated_rel_l1_distance = None - def update( self, modulated_inp: torch.Tensor | None, previous_residual: torch.Tensor | None ) -> None: @@ -82,18 +71,21 @@ class TeaCacheStrategy: negative CFG branch) and stores parameters needed to make the skip decision. """ - def __init__(self, supports_cfg: bool) -> None: - """Initialize cache states for positive and optional negative CFG branches.""" - # params updated every forward pass + def __init__( + self, + supports_cfg: bool, + coefficients: list[float], + rel_l1_thresh: float, + start_skipping: int | None, + end_skipping: int | None, + ) -> None: + """Initialize cache states and all generation parameters.""" self.state = TeaCacheState() self.state_neg = TeaCacheState() if supports_cfg else None - # params updated at the start of each new generation - # set in maybe_reset() - self.cache_params: TeaCacheParams | None = None - self.coefficients: list[float] = [] - self.num_steps: int = 0 - self.start_skipping: int | None = None - self.end_skipping: int | None = None + self.coefficients = coefficients + self.rel_l1_thresh = rel_l1_thresh + self.start_skipping = start_skipping + self.end_skipping = end_skipping def _get_state(self) -> TeaCacheState: """Select the appropriate cache state (positive/negative cfg) based on the forward context.""" @@ -109,58 +101,15 @@ def _get_state(self) -> TeaCacheState: return self.state_neg return self.state - def maybe_reset(self, **kwargs) -> None: - """Maybe reset the TeaCacheState by doing three things: - - 1. Reset TeaCacheState if the previous generation is complete - 2. Initialize parameters if at the start of a new generation. - 3. Increment the state's timestep counter (always) - - Called on every forward pass before should_skip(). - """ - from sglang.multimodal_gen.runtime.managers.forward_context import ( - get_forward_context, - ) - - state = self._get_state() - - # Reset state if we completed a generation - if state.step == self.num_steps and state.step > 0: - state.reset() - - # Initialize values at the start of each new generation - if state.step == 0: - - # set the teacache parameters - forward_batch = get_forward_context().forward_batch - assert ( - forward_batch is not None - ), "TeaCacheStrategy required the forward_batch not be None" - self.cache_params = getattr( - forward_batch.sampling_params, "teacache_params", None - ) - - # set the number of inference steps - assert ( - self.cache_params is not None - ), "TeaCacheStrategy requires cache_params in sampling_params" - self.num_steps = int(forward_batch.num_inference_steps) - - # get teacache coefficients and skip boundaries - self.coefficients = self.cache_params._get_coefficients() - self.start_skipping, self.end_skipping = ( - self.cache_params._get_skip_boundaries(self.num_steps) - ) - - # always increment the number of steps - state.step += 1 + def maybe_reset(self) -> None: + """Increment the step counter.""" + self._get_state().step += 1 def should_skip( self, modulated_input: torch.Tensor | None = None, **kwargs ) -> bool: """Decide whether this forward pass can be skipped based on the accumulated L1 distance of the modulated input.""" state = self._get_state() - assert self.cache_params is not None # No valid skip window for this generation if self.start_skipping is None or self.end_skipping is None: @@ -187,7 +136,7 @@ def should_skip( state.accumulated_rel_l1_distance += rescaled # If below threshold, skip the forward pass - if state.accumulated_rel_l1_distance < self.cache_params.rel_l1_thresh: + if state.accumulated_rel_l1_distance < self.rel_l1_thresh: return True # If threshold exceeded, reset accumulated so next window starts fresh @@ -204,7 +153,6 @@ def write( **kwargs, ) -> None: """After the forward pass, cache the residual and the current modulated input.""" - assert self.cache_params is not None residual = hidden_states.squeeze(0) - original_hidden_states state = self._get_state() state.update(modulated_input, residual) diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py index d1451c9b80ad..08cbe1367f57 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -212,7 +212,7 @@ def update_weights_from_disk( if success and flush_cache: for _, module in modules_to_update: - if hasattr(module, "cache"): + if hasattr(module, "cache") and module.cache is not None: module.cache.reset() logger.info(message) diff --git a/python/sglang/multimodal_gen/runtime/models/dits/base.py b/python/sglang/multimodal_gen/runtime/models/dits/base.py index 6e4d722668f4..727108e5fb8c 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/base.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/base.py @@ -127,7 +127,7 @@ def __init__(self, config: DiTConfig, **kwargs) -> None: self.cache: TeaCacheStrategy | None = None self.calibrate_cache: bool = False - def maybe_init_cache(self, timestep: int) -> None: + def maybe_init_cache(self) -> None: """Construct the cache strategy from the current forward_batch context. Called lazily on the first forward pass because sampling params @@ -144,11 +144,27 @@ def maybe_init_cache(self, timestep: int) -> None: # caching strategies may handle pos/neg cfg separately supports_cfg = self.config.prefix.lower() in _CFG_SUPPORTED_PREFIXES - # initialize cache at the start of each new generation (timestep == 0 and cfg is positive for cfg-supporting models) - if timestep == 0 and ((supports_cfg and not forward_batch.is_cfg_negative) or not supports_cfg): + # initialize cache at the start of each new generation (step index == 0 and cfg is positive for cfg-supporting models) + current_timestep = get_forward_context().current_timestep + if current_timestep == 0 and ( + (supports_cfg and not forward_batch.is_cfg_negative) or not supports_cfg + ): # select caching strategy - if forward_batch.enable_teacache: - self.cache = TeaCacheStrategy(supports_cfg) + cache_params = getattr( + forward_batch.sampling_params, "teacache_params", None + ) + if forward_batch.enable_teacache and cache_params is not None: + num_steps = int(forward_batch.num_inference_steps) + start_skipping, end_skipping = cache_params._get_skip_boundaries( + num_steps + ) + self.cache = TeaCacheStrategy( + supports_cfg, + cache_params._get_coefficients(), + cache_params.rel_l1_thresh, + start_skipping, + end_skipping, + ) else: self.cache = None diff --git a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py index b7270cb86172..350a05f58f1d 100755 --- a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -944,7 +944,6 @@ def __init__( ) # For type checking - self.cnt = 0 self.__post_init__() # misc @@ -1000,7 +999,7 @@ def forward( ) -> torch.Tensor: # if caching is enabled, we might initialize or reset the cache state - self.maybe_init_cache(timestep) + self.maybe_init_cache() if self.cache: self.cache.maybe_reset() @@ -1146,9 +1145,8 @@ def forward( # if caching is enabled, we might be able to skip the forward pass should_skip_forward = False if self.cache: - modulated_input = ( - timestep_proj if self.cache.cache_params.use_ret_steps else temb - ) + use_ret_steps = forward_batch.sampling_params.teacache_params.use_ret_steps + modulated_input = timestep_proj if use_ret_steps else temb should_skip_forward = self.cache.should_skip(modulated_input) if should_skip_forward: @@ -1170,7 +1168,6 @@ def forward( original_hidden_states, modulated_input=modulated_input, ) - self.cnt += 1 if sequence_shard_enabled: hidden_states = hidden_states.contiguous() From 01ab38713dfe120b88a672dde2a64412c33d0fc8 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 14 Apr 2026 13:53:17 +0000 Subject: [PATCH 14/34] remove maybe_reset --- python/sglang/multimodal_gen/runtime/cache/teacache.py | 8 +++----- .../sglang/multimodal_gen/runtime/models/dits/wanvideo.py | 2 -- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index ae0591562402..5d1664caa4cf 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -101,22 +101,20 @@ def _get_state(self) -> TeaCacheState: return self.state_neg return self.state - def maybe_reset(self) -> None: - """Increment the step counter.""" - self._get_state().step += 1 - def should_skip( self, modulated_input: torch.Tensor | None = None, **kwargs ) -> bool: """Decide whether this forward pass can be skipped based on the accumulated L1 distance of the modulated input.""" state = self._get_state() + step = state.step + state.step += 1 # advance before returning, regardless of outcome # No valid skip window for this generation if self.start_skipping is None or self.end_skipping is None: return False # Boundary steps always compute - if state.step < self.start_skipping or state.step >= self.end_skipping: + if step < self.start_skipping or step >= self.end_skipping: return False # First time computing, no previous input to compare against diff --git a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py index 350a05f58f1d..8cbfe7b1f605 100755 --- a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -1000,8 +1000,6 @@ def forward( # if caching is enabled, we might initialize or reset the cache state self.maybe_init_cache() - if self.cache: - self.cache.maybe_reset() forward_context = get_forward_context() forward_batch = forward_context.forward_batch From 54ff4bf2e6e31b81db833d0a7c6d160ff7545a03 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 14 Apr 2026 13:58:20 +0000 Subject: [PATCH 15/34] state is dataclass, better skip step parsing --- .../multimodal_gen/configs/sample/teacache.py | 19 ++----- .../multimodal_gen/runtime/cache/teacache.py | 51 ++++++++----------- 2 files changed, 26 insertions(+), 44 deletions(-) diff --git a/python/sglang/multimodal_gen/configs/sample/teacache.py b/python/sglang/multimodal_gen/configs/sample/teacache.py index aca2fb1408f3..5424b3f0efc0 100644 --- a/python/sglang/multimodal_gen/configs/sample/teacache.py +++ b/python/sglang/multimodal_gen/configs/sample/teacache.py @@ -65,9 +65,7 @@ def _get_coefficients(self) -> list[float]: return self.coefficients_callback(self) return self.coefficients - def _get_skip_boundaries( - self, num_inference_steps: int - ) -> tuple[int | None, int | None]: + def _get_skip_boundaries(self, num_inference_steps: int) -> tuple[int, int]: def _resolve_boundary(value: int | float) -> int: if isinstance(value, float): return int(num_inference_steps * value) @@ -75,15 +73,6 @@ def _resolve_boundary(value: int | float) -> int: return num_inference_steps + value return value - start_skipping = _resolve_boundary(self.start_skipping) - end_skipping = _resolve_boundary(self.end_skipping) - - if start_skipping > end_skipping: - logger.warning( - f"TeaCache skip window is invalid. Expected start_skipping<=end_skipping but got {start_skipping=}" - f" > {end_skipping=})for {num_inference_steps=}. This can happen during warmup runs with very few" - " steps. TeaCache is disabled." - ) - start_skipping, end_skipping = None, None - - return start_skipping, end_skipping + return _resolve_boundary(self.start_skipping), _resolve_boundary( + self.end_skipping + ) diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index 5d1664caa4cf..7177f1b47eec 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -12,12 +12,23 @@ from __future__ import annotations import logging +from dataclasses import dataclass logger = logging.getLogger(__name__) import torch +@dataclass +class TeaCacheState: + """Tracks step progress, cached tensors, and L1 distances for a single CFG path.""" + + step: int = 0 + previous_modulated_input: torch.Tensor | None = None + previous_residual: torch.Tensor | None = None + accumulated_rel_l1_distance: torch.Tensor | None = None + + def _rescale_distance_tensor( coefficients: list[float], x: torch.Tensor ) -> torch.Tensor: @@ -44,26 +55,6 @@ def _compute_rel_l1_distance_tensor( return rel_distance.squeeze() -class TeaCacheState: - """Tracks step progress, cached tensors, and L1 distances for a single CFG path. Updated every timestep.""" - - def __init__(self) -> None: - self.step: int = 0 - self.previous_modulated_input: torch.Tensor | None = None - self.previous_residual: torch.Tensor | None = None - self.accumulated_rel_l1_distance: torch.Tensor | None = None - - def update( - self, modulated_inp: torch.Tensor | None, previous_residual: torch.Tensor | None - ) -> None: - """Store the current modulated input and its computed residual for possible future reuse.""" - self.previous_modulated_input = modulated_inp - self.previous_residual = previous_residual - - def __repr__(self): - return f"TeaCacheState(step={self.step}, accumulated_rel_l1_distance={self.accumulated_rel_l1_distance})" - - class TeaCacheStrategy: """Implements TeaCache to skip redundant diffusion forward passes. @@ -76,8 +67,8 @@ def __init__( supports_cfg: bool, coefficients: list[float], rel_l1_thresh: float, - start_skipping: int | None, - end_skipping: int | None, + start_skipping: int, + end_skipping: int, ) -> None: """Initialize cache states and all generation parameters.""" self.state = TeaCacheState() @@ -86,6 +77,12 @@ def __init__( self.rel_l1_thresh = rel_l1_thresh self.start_skipping = start_skipping self.end_skipping = end_skipping + if start_skipping >= end_skipping: + logger.warning( + f"TeaCache skip window is invalid (start_skipping={start_skipping} >= " + f"end_skipping={end_skipping}). This can happen during warmup runs with " + "very few steps. Skipping disabled." + ) def _get_state(self) -> TeaCacheState: """Select the appropriate cache state (positive/negative cfg) based on the forward context.""" @@ -109,11 +106,7 @@ def should_skip( step = state.step state.step += 1 # advance before returning, regardless of outcome - # No valid skip window for this generation - if self.start_skipping is None or self.end_skipping is None: - return False - - # Boundary steps always compute + # Boundary steps always compute (also handles invalid window where start >= end) if step < self.start_skipping or step >= self.end_skipping: return False @@ -151,9 +144,9 @@ def write( **kwargs, ) -> None: """After the forward pass, cache the residual and the current modulated input.""" - residual = hidden_states.squeeze(0) - original_hidden_states state = self._get_state() - state.update(modulated_input, residual) + state.previous_residual = hidden_states.squeeze(0) - original_hidden_states + state.previous_modulated_input = modulated_input def read(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: """Before the forward pass, read from the cache and apply it to the current hidden states.""" From cdba9e0ad80230a9792b192c9c977a99cb2ce550 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 14 Apr 2026 14:03:47 +0000 Subject: [PATCH 16/34] update doc --- docs/diffusion/performance/cache/teacache.md | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/docs/diffusion/performance/cache/teacache.md b/docs/diffusion/performance/cache/teacache.md index f600e68bf974..04dd9539d258 100644 --- a/docs/diffusion/performance/cache/teacache.md +++ b/docs/diffusion/performance/cache/teacache.md @@ -18,14 +18,13 @@ TeaCache works by: TeaCache is split into three classes: - **`TeaCacheParams`** — pure data class holding user-set parameters (`rel_l1_thresh`, `coefficients`, `start_skipping`, `end_skipping`). Set once per request, never mutated during inference. -- **`TeaCacheState`** — runtime state for one CFG branch: `step`, `previous_modulated_input`, `previous_residual`, `accumulated_rel_l1_distance`. Reset at the start of each generation. -- **`TeaCacheStrategy`** — all the logic. Owns two `TeaCacheState` objects (positive + optional negative CFG branch) and reads from `TeaCacheParams` to decide when to skip. - -At each denoising step, `TeaCacheStrategy` calls: -1. `maybe_reset()` — resets state if a generation just finished, initializes params at step 0, increments step counter -2. `should_skip()` — computes accumulated L1 distance and returns whether to skip -3. `read()` — if skipping, reads the cached residual and adds it to hidden states -4. `write()` — if computing, writes the new residual and modulated input to the cache +- **`TeaCacheState`** — dataclass holding runtime state for one CFG branch: `step`, `previous_modulated_input`, `previous_residual`, `accumulated_rel_l1_distance`. +- **`TeaCacheStrategy`** — all the logic. Owns two `TeaCacheState` objects (positive + optional negative CFG branch). Constructed once per generation by `CachableDiT.maybe_init_cache()` with all parameters resolved upfront. + +At each denoising step, the model calls: +1. `cache.should_skip(modulated_input)` — advances the step counter, computes accumulated L1 distance, returns whether to skip +2. `cache.read()` — if skipping, reads the residual from the cache and applies it to hidden states +3. `cache.write()` — if computing, stores the new residual and modulated input in the cache ### L1 Distance Tracking From ab61887fd5cfb0859a7db4010ba3d8b0cca28002 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 14 Apr 2026 14:08:58 +0000 Subject: [PATCH 17/34] cleaner --- python/sglang/multimodal_gen/configs/sample/teacache.py | 6 +++--- python/sglang/multimodal_gen/runtime/cache/teacache.py | 2 +- .../sglang/multimodal_gen/runtime/models/dits/wanvideo.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/sglang/multimodal_gen/configs/sample/teacache.py b/python/sglang/multimodal_gen/configs/sample/teacache.py index 5424b3f0efc0..d434572ad1b2 100644 --- a/python/sglang/multimodal_gen/configs/sample/teacache.py +++ b/python/sglang/multimodal_gen/configs/sample/teacache.py @@ -73,6 +73,6 @@ def _resolve_boundary(value: int | float) -> int: return num_inference_steps + value return value - return _resolve_boundary(self.start_skipping), _resolve_boundary( - self.end_skipping - ) + start_skipping = _resolve_boundary(self.start_skipping) + end_skipping = _resolve_boundary(self.end_skipping) + return start_skipping, end_skipping diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index 7177f1b47eec..3cfdcc1a3430 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -104,7 +104,7 @@ def should_skip( """Decide whether this forward pass can be skipped based on the accumulated L1 distance of the modulated input.""" state = self._get_state() step = state.step - state.step += 1 # advance before returning, regardless of outcome + state.step += 1 # always advance the step regardless of outcome # Boundary steps always compute (also handles invalid window where start >= end) if step < self.start_skipping or step >= self.end_skipping: diff --git a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py index 8cbfe7b1f605..e7d8b7544b62 100755 --- a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -1148,7 +1148,7 @@ def forward( should_skip_forward = self.cache.should_skip(modulated_input) if should_skip_forward: - # compute hidden_states using the cached state + # compute hidden_states by reading from the cached state hidden_states = self.cache.read(hidden_states) else: if self.cache: From af18a7acd1285ee2fc373ff654ae3dd3beae024f Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 14 Apr 2026 14:56:45 +0000 Subject: [PATCH 18/34] variable length polynomial --- python/sglang/multimodal_gen/runtime/cache/teacache.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index 3cfdcc1a3430..a42595645a3d 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -33,8 +33,10 @@ def _rescale_distance_tensor( coefficients: list[float], x: torch.Tensor ) -> torch.Tensor: """Polynomial rescaling using tensor operations (torch.compile friendly).""" - c = coefficients - return c[0] * x**4 + c[1] * x**3 + c[2] * x**2 + c[3] * x + c[4] + result = torch.zeros_like(x) + for i, c in enumerate(coefficients): + result = result + c * x ** (len(coefficients) - 1 - i) + return result def _compute_rel_l1_distance_tensor( From c1d4b4c84d2472ffcfb6386b2822240232b4bf66 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 14 Apr 2026 21:16:04 +0300 Subject: [PATCH 19/34] make reset_states and fix weight updater --- docs/diffusion/performance/cache/teacache.md | 1 + python/sglang/multimodal_gen/runtime/cache/teacache.py | 6 ++++++ .../sglang/multimodal_gen/runtime/loader/weights_updater.py | 2 +- 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/diffusion/performance/cache/teacache.md b/docs/diffusion/performance/cache/teacache.md index 04dd9539d258..db828be7f6fb 100644 --- a/docs/diffusion/performance/cache/teacache.md +++ b/docs/diffusion/performance/cache/teacache.md @@ -25,6 +25,7 @@ At each denoising step, the model calls: 1. `cache.should_skip(modulated_input)` — advances the step counter, computes accumulated L1 distance, returns whether to skip 2. `cache.read()` — if skipping, reads the residual from the cache and applies it to hidden states 3. `cache.write()` — if computing, stores the new residual and modulated input in the cache +4. `cache.reset_states()` — resets `state` and optionally `state_neg`, discarding any stale tensors ### L1 Distance Tracking diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index a42595645a3d..c493228aad26 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -73,6 +73,7 @@ def __init__( end_skipping: int, ) -> None: """Initialize cache states and all generation parameters.""" + self.supports_cfg = supports_cfg self.state = TeaCacheState() self.state_neg = TeaCacheState() if supports_cfg else None self.coefficients = coefficients @@ -86,6 +87,11 @@ def __init__( "very few steps. Skipping disabled." ) + def reset_states(self) -> None: + """Reset cache states, discarding any stale tensors from a previous generation.""" + self.state = TeaCacheState() + self.state_neg = TeaCacheState() if self.supports_cfg else None + def _get_state(self) -> TeaCacheState: """Select the appropriate cache state (positive/negative cfg) based on the forward context.""" from sglang.multimodal_gen.runtime.managers.forward_context import ( diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py index 08cbe1367f57..365c34cdd0b6 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -213,7 +213,7 @@ def update_weights_from_disk( if success and flush_cache: for _, module in modules_to_update: if hasattr(module, "cache") and module.cache is not None: - module.cache.reset() + module.cache.reset_states() logger.info(message) return success, message From 6f160e093039cb93da8ccfb26ccf2e2007a9ae03 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 14 Apr 2026 22:19:46 +0300 Subject: [PATCH 20/34] remove unused logging --- python/sglang/multimodal_gen/configs/sample/teacache.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/sglang/multimodal_gen/configs/sample/teacache.py b/python/sglang/multimodal_gen/configs/sample/teacache.py index d434572ad1b2..8fa248ab0924 100644 --- a/python/sglang/multimodal_gen/configs/sample/teacache.py +++ b/python/sglang/multimodal_gen/configs/sample/teacache.py @@ -3,14 +3,11 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations -import logging from dataclasses import dataclass, field from typing import Callable from sglang.multimodal_gen.configs.sample.sampling_params import CacheParams -logger = logging.getLogger(__name__) - @dataclass class TeaCacheParams(CacheParams): From 3b16cdbb27fe21cab04a7262b8d5bf357b2511a0 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 14 Apr 2026 22:30:52 +0300 Subject: [PATCH 21/34] better docs --- .../multimodal_gen/configs/sample/teacache.py | 4 ++-- .../runtime/models/dits/base.py | 22 +++++++++---------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/python/sglang/multimodal_gen/configs/sample/teacache.py b/python/sglang/multimodal_gen/configs/sample/teacache.py index 8fa248ab0924..78994c545278 100644 --- a/python/sglang/multimodal_gen/configs/sample/teacache.py +++ b/python/sglang/multimodal_gen/configs/sample/teacache.py @@ -57,12 +57,12 @@ class TeaCacheParams(CacheParams): ) use_ret_steps: bool | None = None - def _get_coefficients(self) -> list[float]: + def get_coefficients(self) -> list[float]: if self.coefficients_callback is not None: return self.coefficients_callback(self) return self.coefficients - def _get_skip_boundaries(self, num_inference_steps: int) -> tuple[int, int]: + def get_skip_boundaries(self, num_inference_steps: int) -> tuple[int, int]: def _resolve_boundary(value: int | float) -> int: if isinstance(value, float): return int(num_inference_steps * value) diff --git a/python/sglang/multimodal_gen/runtime/models/dits/base.py b/python/sglang/multimodal_gen/runtime/models/dits/base.py index 727108e5fb8c..e4efa2bf8bfa 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/base.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/base.py @@ -110,28 +110,26 @@ class CachableDiT(BaseDiT): ) def __init__(self, config: DiTConfig, **kwargs) -> None: - """ + """Initialize cache state for a DiT model with caching support. + Args: config: DiT model configuration. **kwargs: Passed through to BaseDiT (e.g. hf_config). Attributes: - cache: Active cache strategy, or a sentinel: - - None: uninitialized; init_cache() has not been called yet. - - False: no cache strategy requested. - - DiffusionCache: an active cache strategy (e.g. TeaCacheStrategy). - calibrate_cache: When True, run every forward pass to calibrate - the values needed for caching. + cache: None when uninitialized or when no caching was requested; otherwise an active TeaCacheStrategy. + calibrate_cache: When True, runs every forward pass to gather calibration data. """ super().__init__(config, **kwargs) self.cache: TeaCacheStrategy | None = None self.calibrate_cache: bool = False def maybe_init_cache(self) -> None: - """Construct the cache strategy from the current forward_batch context. + """Initialize the cache strategy at the start of each new generation + (when timestep == 0 and cfg is positive for cfg-supporting models). - Called lazily on the first forward pass because sampling params - (e.g. `enable_teacache`) are only available then. + Since the cache parameters are contained in the sampling parameters which is only + accessible during the first forward pass, we cannot initialize the cache in CachableDiT.__init__. """ from sglang.multimodal_gen.runtime.managers.forward_context import ( get_forward_context, @@ -155,12 +153,12 @@ def maybe_init_cache(self) -> None: ) if forward_batch.enable_teacache and cache_params is not None: num_steps = int(forward_batch.num_inference_steps) - start_skipping, end_skipping = cache_params._get_skip_boundaries( + start_skipping, end_skipping = cache_params.get_skip_boundaries( num_steps ) self.cache = TeaCacheStrategy( supports_cfg, - cache_params._get_coefficients(), + cache_params.get_coefficients(), cache_params.rel_l1_thresh, start_skipping, end_skipping, From 5f48522839ff3360d6a7383cb095654faa79db47 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Wed, 15 Apr 2026 17:09:33 +0300 Subject: [PATCH 22/34] update --- .../test/unit/test_sampling_params.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/python/sglang/multimodal_gen/test/unit/test_sampling_params.py b/python/sglang/multimodal_gen/test/unit/test_sampling_params.py index 4ccc7fcb8128..094b6d879726 100644 --- a/python/sglang/multimodal_gen/test/unit/test_sampling_params.py +++ b/python/sglang/multimodal_gen/test/unit/test_sampling_params.py @@ -126,24 +126,23 @@ def coefficients_callback(_: TeaCacheParams) -> list[float]: def test_wan_teacache_boundaries_match_legacy_behavior(self): legacy_equivalent_cases = [ - (WanT2V_1_3B_SamplingParams().teacache_params, False, (5, 50)), - (WanT2V_1_3B_SamplingParams().teacache_params, True, (10, 100)), - (WanT2V_14B_SamplingParams().teacache_params, False, (1, 49)), - (WanT2V_14B_SamplingParams().teacache_params, True, (2, 98)), - (WanI2V_14B_480P_SamplingParam().teacache_params, False, (5, 50)), - (WanI2V_14B_480P_SamplingParam().teacache_params, True, (10, 100)), - (WanI2V_14B_720P_SamplingParam().teacache_params, False, (5, 50)), - (WanI2V_14B_720P_SamplingParam().teacache_params, True, (10, 100)), + (WanT2V_1_3B_SamplingParams().teacache_params, (5, 50)), + (WanT2V_1_3B_SamplingParams().teacache_params, (10, 100)), + (WanT2V_14B_SamplingParams().teacache_params, (1, 49)), + (WanT2V_14B_SamplingParams().teacache_params, (2, 98)), + (WanI2V_14B_480P_SamplingParam().teacache_params, (5, 50)), + (WanI2V_14B_480P_SamplingParam().teacache_params, (10, 100)), + (WanI2V_14B_720P_SamplingParam().teacache_params, (5, 50)), + (WanI2V_14B_720P_SamplingParam().teacache_params, (10, 100)), ] - for teacache_params, do_cfg, expected in legacy_equivalent_cases: + for teacache_params, expected in legacy_equivalent_cases: with self.subTest( use_ret_steps=teacache_params.use_ret_steps, - do_cfg=do_cfg, expected=expected, ): self.assertEqual( - teacache_params.get_skip_boundaries(50, do_cfg), + teacache_params.get_skip_boundaries(50), expected, ) From 268344d9caf3110f422f1ea656b7230eebae25bf Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 17 Apr 2026 17:42:24 +0300 Subject: [PATCH 23/34] fix test_wan_teacache_boundaries_match_legacy_behavior: remove stale do_cfg=True cases MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The old test used get_skip_boundaries(50, do_cfg=True) which doubled the step count to 100, yielding (10,100) and (2,98). Since do_cfg was removed from the API, those entries are now redundant and wrong — drop them. Co-Authored-By: Claude Sonnet 4.6 --- .../sglang/multimodal_gen/test/unit/test_sampling_params.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/sglang/multimodal_gen/test/unit/test_sampling_params.py b/python/sglang/multimodal_gen/test/unit/test_sampling_params.py index 094b6d879726..69a53deff455 100644 --- a/python/sglang/multimodal_gen/test/unit/test_sampling_params.py +++ b/python/sglang/multimodal_gen/test/unit/test_sampling_params.py @@ -127,13 +127,9 @@ def coefficients_callback(_: TeaCacheParams) -> list[float]: def test_wan_teacache_boundaries_match_legacy_behavior(self): legacy_equivalent_cases = [ (WanT2V_1_3B_SamplingParams().teacache_params, (5, 50)), - (WanT2V_1_3B_SamplingParams().teacache_params, (10, 100)), (WanT2V_14B_SamplingParams().teacache_params, (1, 49)), - (WanT2V_14B_SamplingParams().teacache_params, (2, 98)), (WanI2V_14B_480P_SamplingParam().teacache_params, (5, 50)), - (WanI2V_14B_480P_SamplingParam().teacache_params, (10, 100)), (WanI2V_14B_720P_SamplingParam().teacache_params, (5, 50)), - (WanI2V_14B_720P_SamplingParam().teacache_params, (10, 100)), ] for teacache_params, expected in legacy_equivalent_cases: From 82657c7c53d6bf732da3e89673009e12a8e3cc3e Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 17 Apr 2026 17:49:40 +0300 Subject: [PATCH 24/34] fix TeaCacheStrategy: always advance previous_modulated_input each step MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The old TeaCacheMixin updated previous_modulated_input on every step (skipped or not), so L1 distance was always computed against the immediately preceding step. The refactored TeaCacheStrategy only updated it in write(), which is only called when actually computing — so skipped steps left the reference stale. On the next step the comparison was against a potentially much older input, producing larger distances, hitting the threshold sooner, and forcing more computation than expected. Fix: update state.previous_modulated_input in should_skip() unconditionally, and drop it from write() which only needs to cache the residual. Co-Authored-By: Claude Sonnet 4.6 --- python/sglang/multimodal_gen/runtime/cache/teacache.py | 9 ++++++--- .../multimodal_gen/runtime/models/dits/wanvideo.py | 6 +----- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index c493228aad26..6fc0d008848c 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -116,6 +116,7 @@ def should_skip( # Boundary steps always compute (also handles invalid window where start >= end) if step < self.start_skipping or step >= self.end_skipping: + state.previous_modulated_input = modulated_input return False # First time computing, no previous input to compare against @@ -123,6 +124,7 @@ def should_skip( state.accumulated_rel_l1_distance = torch.zeros( 1, device=modulated_input.device, dtype=modulated_input.dtype ) + state.previous_modulated_input = modulated_input return False # compute the accumulated relative l1 distance @@ -134,6 +136,9 @@ def should_skip( rescaled = _rescale_distance_tensor(self.coefficients, rel_l1) state.accumulated_rel_l1_distance += rescaled + # Always advance the reference input (matching legacy per-step update behavior) + state.previous_modulated_input = modulated_input + # If below threshold, skip the forward pass if state.accumulated_rel_l1_distance < self.rel_l1_thresh: return True @@ -148,13 +153,11 @@ def write( self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor, - modulated_input: torch.Tensor | None = None, **kwargs, ) -> None: - """After the forward pass, cache the residual and the current modulated input.""" + """After the forward pass, cache the residual.""" state = self._get_state() state.previous_residual = hidden_states.squeeze(0) - original_hidden_states - state.previous_modulated_input = modulated_input def read(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: """Before the forward pass, read from the cache and apply it to the current hidden states.""" diff --git a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py index e7d8b7544b62..0b376ab498e6 100755 --- a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -1161,11 +1161,7 @@ def forward( if self.cache: # update the cache with the new hidden states - self.cache.write( - hidden_states, - original_hidden_states, - modulated_input=modulated_input, - ) + self.cache.write(hidden_states, original_hidden_states) if sequence_shard_enabled: hidden_states = hidden_states.contiguous() From cd2e907bac26ef3cba4006ffbbff4b3f9b2f0a37 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 17 Apr 2026 17:57:22 +0300 Subject: [PATCH 25/34] refactor TeaCacheStrategy: split should_skip into advance() + should_skip() advance(modulated_input) - always called, owns all state mutation: step counter, previous_modulated_input, accumulated L1, skippable flag should_skip() - pure read of state.skippable Co-Authored-By: Claude Sonnet 4.6 --- .../multimodal_gen/runtime/cache/teacache.py | 57 ++++++++----------- .../runtime/models/dits/wanvideo.py | 3 +- 2 files changed, 27 insertions(+), 33 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index 6fc0d008848c..f03d1589eafe 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -27,6 +27,7 @@ class TeaCacheState: previous_modulated_input: torch.Tensor | None = None previous_residual: torch.Tensor | None = None accumulated_rel_l1_distance: torch.Tensor | None = None + skippable: bool = False def _rescale_distance_tensor( @@ -106,48 +107,40 @@ def _get_state(self) -> TeaCacheState: return self.state_neg return self.state - def should_skip( - self, modulated_input: torch.Tensor | None = None, **kwargs - ) -> bool: - """Decide whether this forward pass can be skipped based on the accumulated L1 distance of the modulated input.""" + def advance(self, modulated_input: torch.Tensor) -> None: + """Advance state by one step: update the step counter, modulated input reference, + and accumulated L1 distance. Always call once per forward pass before should_skip. + """ state = self._get_state() + prev = state.previous_modulated_input step = state.step - state.step += 1 # always advance the step regardless of outcome - - # Boundary steps always compute (also handles invalid window where start >= end) - if step < self.start_skipping or step >= self.end_skipping: - state.previous_modulated_input = modulated_input - return False + state.step += 1 + state.previous_modulated_input = modulated_input - # First time computing, no previous input to compare against - if state.accumulated_rel_l1_distance is None: + in_window = self.start_skipping <= step < self.end_skipping + if not in_window or prev is None or state.accumulated_rel_l1_distance is None: state.accumulated_rel_l1_distance = torch.zeros( 1, device=modulated_input.device, dtype=modulated_input.dtype ) - state.previous_modulated_input = modulated_input - return False - - # compute the accumulated relative l1 distance - assert state.previous_modulated_input is not None - assert modulated_input is not None - rel_l1 = _compute_rel_l1_distance_tensor( - modulated_input, state.previous_modulated_input - ) - rescaled = _rescale_distance_tensor(self.coefficients, rel_l1) - state.accumulated_rel_l1_distance += rescaled + state.skippable = False + return - # Always advance the reference input (matching legacy per-step update behavior) - state.previous_modulated_input = modulated_input + rel_l1 = _compute_rel_l1_distance_tensor(modulated_input, prev) + state.accumulated_rel_l1_distance += _rescale_distance_tensor( + self.coefficients, rel_l1 + ) - # If below threshold, skip the forward pass if state.accumulated_rel_l1_distance < self.rel_l1_thresh: - return True + state.skippable = True + else: + state.accumulated_rel_l1_distance = torch.zeros( + 1, device=modulated_input.device, dtype=modulated_input.dtype + ) + state.skippable = False - # If threshold exceeded, reset accumulated so next window starts fresh - state.accumulated_rel_l1_distance = torch.zeros( - 1, device=modulated_input.device, dtype=modulated_input.dtype - ) - return False + def should_skip(self) -> bool: + """Return whether this forward pass can be skipped. Call after advance().""" + return self._get_state().skippable def write( self, diff --git a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py index 0b376ab498e6..86ae99c5c941 100755 --- a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -1145,7 +1145,8 @@ def forward( if self.cache: use_ret_steps = forward_batch.sampling_params.teacache_params.use_ret_steps modulated_input = timestep_proj if use_ret_steps else temb - should_skip_forward = self.cache.should_skip(modulated_input) + self.cache.advance(modulated_input) + should_skip_forward = self.cache.should_skip() if should_skip_forward: # compute hidden_states by reading from the cached state From a272b4a5bb05a3a74be7a317a9ec7e027ad3b9b1 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 17 Apr 2026 18:00:25 +0300 Subject: [PATCH 26/34] simplify TeaCacheStrategy: use None as sentinel, drop skippable flag None means must-compute (out of window, no prev, or threshold exceeded). advance() sets accumulated to None in those cases; should_skip() is a one-liner: accumulated is not None. Co-Authored-By: Claude Sonnet 4.6 --- .../multimodal_gen/runtime/cache/teacache.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index f03d1589eafe..b1e38dae462b 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -27,7 +27,6 @@ class TeaCacheState: previous_modulated_input: torch.Tensor | None = None previous_residual: torch.Tensor | None = None accumulated_rel_l1_distance: torch.Tensor | None = None - skippable: bool = False def _rescale_distance_tensor( @@ -108,8 +107,11 @@ def _get_state(self) -> TeaCacheState: return self.state def advance(self, modulated_input: torch.Tensor) -> None: - """Advance state by one step: update the step counter, modulated input reference, + """Advance state by one step: update step counter, modulated input reference, and accumulated L1 distance. Always call once per forward pass before should_skip. + + accumulated_rel_l1_distance is None iff this step must be computed (out of window, + no previous reference, or threshold just exceeded). should_skip reads this directly. """ state = self._get_state() prev = state.previous_modulated_input @@ -118,29 +120,26 @@ def advance(self, modulated_input: torch.Tensor) -> None: state.previous_modulated_input = modulated_input in_window = self.start_skipping <= step < self.end_skipping - if not in_window or prev is None or state.accumulated_rel_l1_distance is None: + if not in_window or prev is None: + state.accumulated_rel_l1_distance = None + return + + if state.accumulated_rel_l1_distance is None: state.accumulated_rel_l1_distance = torch.zeros( 1, device=modulated_input.device, dtype=modulated_input.dtype ) - state.skippable = False - return rel_l1 = _compute_rel_l1_distance_tensor(modulated_input, prev) state.accumulated_rel_l1_distance += _rescale_distance_tensor( self.coefficients, rel_l1 ) - if state.accumulated_rel_l1_distance < self.rel_l1_thresh: - state.skippable = True - else: - state.accumulated_rel_l1_distance = torch.zeros( - 1, device=modulated_input.device, dtype=modulated_input.dtype - ) - state.skippable = False + if state.accumulated_rel_l1_distance >= self.rel_l1_thresh: + state.accumulated_rel_l1_distance = None # reset: force compute next step def should_skip(self) -> bool: """Return whether this forward pass can be skipped. Call after advance().""" - return self._get_state().skippable + return self._get_state().accumulated_rel_l1_distance is not None def write( self, From d2dc7936302b8268a9e00ccec4bf11b51300939a Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 17 Apr 2026 18:03:26 +0300 Subject: [PATCH 27/34] redesign TeaCacheStrategy: should_skip decides, advance updates should_skip(modulated_input): pure decision, reads state only advance(modulated_input, skipped): updates step/prev/accumulated based on outcome Co-Authored-By: Claude Sonnet 4.6 --- .../multimodal_gen/runtime/cache/teacache.py | 54 ++++++++++--------- .../runtime/models/dits/wanvideo.py | 4 +- 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index b1e38dae462b..b819cba43f4d 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -106,40 +106,44 @@ def _get_state(self) -> TeaCacheState: return self.state_neg return self.state - def advance(self, modulated_input: torch.Tensor) -> None: - """Advance state by one step: update step counter, modulated input reference, - and accumulated L1 distance. Always call once per forward pass before should_skip. + def should_skip(self, modulated_input: torch.Tensor) -> bool: + """Decide whether this step can be skipped. Does not mutate state.""" + state = self._get_state() + step = state.step + if step < self.start_skipping or step >= self.end_skipping: + return False + if ( + state.previous_modulated_input is None + or state.accumulated_rel_l1_distance is None + ): + return False + rel_l1 = _compute_rel_l1_distance_tensor( + modulated_input, state.previous_modulated_input + ) + rescaled = _rescale_distance_tensor(self.coefficients, rel_l1) + return (state.accumulated_rel_l1_distance + rescaled) < self.rel_l1_thresh - accumulated_rel_l1_distance is None iff this step must be computed (out of window, - no previous reference, or threshold just exceeded). should_skip reads this directly. - """ + def advance(self, modulated_input: torch.Tensor, skipped: bool) -> None: + """Update state after a step. Always call once per forward pass after should_skip.""" state = self._get_state() - prev = state.previous_modulated_input step = state.step state.step += 1 - state.previous_modulated_input = modulated_input - in_window = self.start_skipping <= step < self.end_skipping - if not in_window or prev is None: - state.accumulated_rel_l1_distance = None - return - if state.accumulated_rel_l1_distance is None: + if not in_window: + state.accumulated_rel_l1_distance = None + elif skipped: + rel_l1 = _compute_rel_l1_distance_tensor( + modulated_input, state.previous_modulated_input + ) + state.accumulated_rel_l1_distance += _rescale_distance_tensor( + self.coefficients, rel_l1 + ) + else: state.accumulated_rel_l1_distance = torch.zeros( 1, device=modulated_input.device, dtype=modulated_input.dtype ) - - rel_l1 = _compute_rel_l1_distance_tensor(modulated_input, prev) - state.accumulated_rel_l1_distance += _rescale_distance_tensor( - self.coefficients, rel_l1 - ) - - if state.accumulated_rel_l1_distance >= self.rel_l1_thresh: - state.accumulated_rel_l1_distance = None # reset: force compute next step - - def should_skip(self) -> bool: - """Return whether this forward pass can be skipped. Call after advance().""" - return self._get_state().accumulated_rel_l1_distance is not None + state.previous_modulated_input = modulated_input def write( self, diff --git a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py index 86ae99c5c941..81f1c0c0b384 100755 --- a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -1145,8 +1145,8 @@ def forward( if self.cache: use_ret_steps = forward_batch.sampling_params.teacache_params.use_ret_steps modulated_input = timestep_proj if use_ret_steps else temb - self.cache.advance(modulated_input) - should_skip_forward = self.cache.should_skip() + should_skip_forward = self.cache.should_skip(modulated_input) + self.cache.advance(modulated_input, should_skip_forward) if should_skip_forward: # compute hidden_states by reading from the cached state From 20402c858e3700356c0d82b519837c06370633dc Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 17 Apr 2026 18:05:57 +0300 Subject: [PATCH 28/34] simplify: always reset accumulated to zeros, drop None sentinel None was only meaningful as the uninitialized state. advance() now always resets to zeros when not skipping (regardless of window). should_skip() only needs to guard on previous_modulated_input being None. Initialize accumulated_rel_l1_distance to 0.0 so no None check needed in should_skip. Co-Authored-By: Claude Sonnet 4.6 --- .../multimodal_gen/runtime/cache/teacache.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index b819cba43f4d..d02e20ef0137 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -26,7 +26,7 @@ class TeaCacheState: step: int = 0 previous_modulated_input: torch.Tensor | None = None previous_residual: torch.Tensor | None = None - accumulated_rel_l1_distance: torch.Tensor | None = None + accumulated_rel_l1_distance: torch.Tensor | float = 0.0 def _rescale_distance_tensor( @@ -109,13 +109,9 @@ def _get_state(self) -> TeaCacheState: def should_skip(self, modulated_input: torch.Tensor) -> bool: """Decide whether this step can be skipped. Does not mutate state.""" state = self._get_state() - step = state.step - if step < self.start_skipping or step >= self.end_skipping: + if state.step < self.start_skipping or state.step >= self.end_skipping: return False - if ( - state.previous_modulated_input is None - or state.accumulated_rel_l1_distance is None - ): + if state.previous_modulated_input is None: return False rel_l1 = _compute_rel_l1_distance_tensor( modulated_input, state.previous_modulated_input @@ -128,11 +124,7 @@ def advance(self, modulated_input: torch.Tensor, skipped: bool) -> None: state = self._get_state() step = state.step state.step += 1 - in_window = self.start_skipping <= step < self.end_skipping - - if not in_window: - state.accumulated_rel_l1_distance = None - elif skipped: + if skipped: rel_l1 = _compute_rel_l1_distance_tensor( modulated_input, state.previous_modulated_input ) From 9ceff5927a4391066a1191d284e87753652771e0 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 17 Apr 2026 18:17:33 +0300 Subject: [PATCH 29/34] [diffusion] Merge TeaCache should_skip + advance into single step() Eliminates the footgun of calling two functions in the right order, avoids computing L1 twice, and keeps all state transitions in one place. Co-Authored-By: Claude Sonnet 4.6 --- .../multimodal_gen/runtime/cache/teacache.py | 43 ++++++++++--------- .../runtime/models/dits/wanvideo.py | 3 +- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index d02e20ef0137..0d28cdca3564 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -106,37 +106,38 @@ def _get_state(self) -> TeaCacheState: return self.state_neg return self.state - def should_skip(self, modulated_input: torch.Tensor) -> bool: - """Decide whether this step can be skipped. Does not mutate state.""" + def step(self, modulated_input: torch.Tensor) -> bool: + """Advance state and return whether this forward pass can be skipped.""" state = self._get_state() - if state.step < self.start_skipping or state.step >= self.end_skipping: + step = state.step + state.step += 1 + + if step < self.start_skipping or step >= self.end_skipping: + state.previous_modulated_input = modulated_input return False + if state.previous_modulated_input is None: + state.accumulated_rel_l1_distance = torch.zeros( + 1, device=modulated_input.device, dtype=modulated_input.dtype + ) + state.previous_modulated_input = modulated_input return False + rel_l1 = _compute_rel_l1_distance_tensor( modulated_input, state.previous_modulated_input ) rescaled = _rescale_distance_tensor(self.coefficients, rel_l1) - return (state.accumulated_rel_l1_distance + rescaled) < self.rel_l1_thresh - - def advance(self, modulated_input: torch.Tensor, skipped: bool) -> None: - """Update state after a step. Always call once per forward pass after should_skip.""" - state = self._get_state() - step = state.step - state.step += 1 - if skipped: - rel_l1 = _compute_rel_l1_distance_tensor( - modulated_input, state.previous_modulated_input - ) - state.accumulated_rel_l1_distance += _rescale_distance_tensor( - self.coefficients, rel_l1 - ) - else: - state.accumulated_rel_l1_distance = torch.zeros( - 1, device=modulated_input.device, dtype=modulated_input.dtype - ) + state.accumulated_rel_l1_distance += rescaled state.previous_modulated_input = modulated_input + if state.accumulated_rel_l1_distance < self.rel_l1_thresh: + return True + + state.accumulated_rel_l1_distance = torch.zeros( + 1, device=modulated_input.device, dtype=modulated_input.dtype + ) + return False + def write( self, hidden_states: torch.Tensor, diff --git a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py index 81f1c0c0b384..319a22ca20c8 100755 --- a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -1145,8 +1145,7 @@ def forward( if self.cache: use_ret_steps = forward_batch.sampling_params.teacache_params.use_ret_steps modulated_input = timestep_proj if use_ret_steps else temb - should_skip_forward = self.cache.should_skip(modulated_input) - self.cache.advance(modulated_input, should_skip_forward) + should_skip_forward = self.cache.step(modulated_input) if should_skip_forward: # compute hidden_states by reading from the cached state From 9a7de9cbea876a982b17d4171046a0763c046160 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 17 Apr 2026 18:19:55 +0300 Subject: [PATCH 30/34] [diffusion] Add inline comments to TeaCache step() and update docs Co-Authored-By: Claude Sonnet 4.6 --- docs/diffusion/performance/cache/teacache.md | 6 +++--- python/sglang/multimodal_gen/runtime/cache/teacache.py | 6 ++++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/docs/diffusion/performance/cache/teacache.md b/docs/diffusion/performance/cache/teacache.md index db828be7f6fb..186ebc59d996 100644 --- a/docs/diffusion/performance/cache/teacache.md +++ b/docs/diffusion/performance/cache/teacache.md @@ -22,9 +22,9 @@ TeaCache is split into three classes: - **`TeaCacheStrategy`** — all the logic. Owns two `TeaCacheState` objects (positive + optional negative CFG branch). Constructed once per generation by `CachableDiT.maybe_init_cache()` with all parameters resolved upfront. At each denoising step, the model calls: -1. `cache.should_skip(modulated_input)` — advances the step counter, computes accumulated L1 distance, returns whether to skip -2. `cache.read()` — if skipping, reads the residual from the cache and applies it to hidden states -3. `cache.write()` — if computing, stores the new residual and modulated input in the cache +1. `cache.step(modulated_input)` — advances the step counter, accumulates the rescaled L1 distance, returns `True` if the forward pass can be skipped +2. `cache.read()` — if skipping, reads the cached residual and applies it to hidden states +3. `cache.write()` — if computing, stores the new residual in the cache 4. `cache.reset_states()` — resets `state` and optionally `state_neg`, discarding any stale tensors ### L1 Distance Tracking diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index 0d28cdca3564..fdf57d97f252 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -112,10 +112,13 @@ def step(self, modulated_input: torch.Tensor) -> bool: step = state.step state.step += 1 + # Outside the skipping window — always compute, just track the reference. if step < self.start_skipping or step >= self.end_skipping: state.previous_modulated_input = modulated_input return False + # First step inside the window — no previous reference yet, so initialize + # the accumulator and record the reference; must compute this step. if state.previous_modulated_input is None: state.accumulated_rel_l1_distance = torch.zeros( 1, device=modulated_input.device, dtype=modulated_input.dtype @@ -123,6 +126,7 @@ def step(self, modulated_input: torch.Tensor) -> bool: state.previous_modulated_input = modulated_input return False + # Accumulate how much the modulated input has drifted since the last compute step. rel_l1 = _compute_rel_l1_distance_tensor( modulated_input, state.previous_modulated_input ) @@ -130,9 +134,11 @@ def step(self, modulated_input: torch.Tensor) -> bool: state.accumulated_rel_l1_distance += rescaled state.previous_modulated_input = modulated_input + # Drift is still small — skip the forward pass. if state.accumulated_rel_l1_distance < self.rel_l1_thresh: return True + # Drift exceeded threshold — compute this step and reset the accumulator. state.accumulated_rel_l1_distance = torch.zeros( 1, device=modulated_input.device, dtype=modulated_input.dtype ) From 4b26c092052975666147a1bcb4cc6d99189566cb Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Mon, 20 Apr 2026 11:50:37 +0300 Subject: [PATCH 31/34] Fix TeaCache step(): use accumulated_rel_l1_distance=None as window sentinel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit previous_modulated_input is set during outside-window steps, so using it as the "first step in window" sentinel caused the first in-window step to compute L1 distance against a pre-window input and potentially skip — producing different outputs than the original should_skip/advance pair. Fix: mirror the original advance() behavior — set accumulated_rel_l1_distance=None when outside the window, and check for None (not previous_modulated_input is None) to detect the first in-window step. Also update TeaCacheState type annotation. Co-Authored-By: Claude Sonnet 4.6 --- .../sglang/multimodal_gen/runtime/cache/teacache.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index fdf57d97f252..6365875cc83b 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -26,7 +26,7 @@ class TeaCacheState: step: int = 0 previous_modulated_input: torch.Tensor | None = None previous_residual: torch.Tensor | None = None - accumulated_rel_l1_distance: torch.Tensor | float = 0.0 + accumulated_rel_l1_distance: torch.Tensor | None = None def _rescale_distance_tensor( @@ -112,14 +112,17 @@ def step(self, modulated_input: torch.Tensor) -> bool: step = state.step state.step += 1 - # Outside the skipping window — always compute, just track the reference. + # Outside the skipping window — always compute, track reference, and mark + # accumulator as uninitialized so the first in-window step doesn't skip. if step < self.start_skipping or step >= self.end_skipping: state.previous_modulated_input = modulated_input + state.accumulated_rel_l1_distance = None return False - # First step inside the window — no previous reference yet, so initialize - # the accumulator and record the reference; must compute this step. - if state.previous_modulated_input is None: + # First step inside the window — accumulator not yet initialized; must compute. + # (previous_modulated_input may already be set from outside-window steps, but the + # accumulator is None, which is the correct sentinel for "no in-window reference yet".) + if state.accumulated_rel_l1_distance is None: state.accumulated_rel_l1_distance = torch.zeros( 1, device=modulated_input.device, dtype=modulated_input.dtype ) From 4d7d64f613e489aa4325f8c693b7ec7c4a63f2ae Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 24 Apr 2026 09:45:17 +0000 Subject: [PATCH 32/34] fix --- .../multimodal_gen/runtime/cache/teacache.py | 36 ++++++++----------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index 6365875cc83b..c1c6aec843ae 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -112,39 +112,31 @@ def step(self, modulated_input: torch.Tensor) -> bool: step = state.step state.step += 1 - # Outside the skipping window — always compute, track reference, and mark - # accumulator as uninitialized so the first in-window step doesn't skip. - if step < self.start_skipping or step >= self.end_skipping: - state.previous_modulated_input = modulated_input + # Do not skip on the first step or if we are outside the skipping window + in_skip_window = self.start_skipping <= step < self.end_skipping + if state.previous_modulated_input is None or not in_skip_window: state.accumulated_rel_l1_distance = None + state.previous_modulated_input = modulated_input.clone() return False - # First step inside the window — accumulator not yet initialized; must compute. - # (previous_modulated_input may already be set from outside-window steps, but the - # accumulator is None, which is the correct sentinel for "no in-window reference yet".) - if state.accumulated_rel_l1_distance is None: - state.accumulated_rel_l1_distance = torch.zeros( - 1, device=modulated_input.device, dtype=modulated_input.dtype - ) - state.previous_modulated_input = modulated_input - return False - - # Accumulate how much the modulated input has drifted since the last compute step. + # Compute the relative L1 distance and update the state rel_l1 = _compute_rel_l1_distance_tensor( modulated_input, state.previous_modulated_input ) rescaled = _rescale_distance_tensor(self.coefficients, rel_l1) - state.accumulated_rel_l1_distance += rescaled - state.previous_modulated_input = modulated_input + state.accumulated_rel_l1_distance = ( + rescaled + if state.accumulated_rel_l1_distance is None + else state.accumulated_rel_l1_distance + rescaled + ) + state.previous_modulated_input = modulated_input.clone() - # Drift is still small — skip the forward pass. + # Skip if accumulated rel l1 is small if state.accumulated_rel_l1_distance < self.rel_l1_thresh: return True - # Drift exceeded threshold — compute this step and reset the accumulator. - state.accumulated_rel_l1_distance = torch.zeros( - 1, device=modulated_input.device, dtype=modulated_input.dtype - ) + # Otherwise reset the accumulator and do not skip + state.accumulated_rel_l1_distance = None return False def write( From d51aab3f40e196808133b9e8eac6b7d6645d3243 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 24 Apr 2026 12:23:55 +0000 Subject: [PATCH 33/34] numerical precision --- python/sglang/multimodal_gen/runtime/cache/teacache.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index c1c6aec843ae..f90a2bcfcd2a 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -33,6 +33,9 @@ def _rescale_distance_tensor( coefficients: list[float], x: torch.Tensor ) -> torch.Tensor: """Polynomial rescaling using tensor operations (torch.compile friendly).""" + x = ( + x.float() + ) # upcast to float32 for numerical stability, especially with higher degree polynomials result = torch.zeros_like(x) for i, c in enumerate(coefficients): result = result + c * x ** (len(coefficients) - 1 - i) @@ -43,6 +46,10 @@ def _compute_rel_l1_distance_tensor( current: torch.Tensor, previous: torch.Tensor ) -> torch.Tensor: """Compute relative L1 distance as a tensor (torch.compile friendly).""" + current, previous = ( + current.float(), + previous.float(), + ) # upcast to float32 for numerical stability prev_mean = previous.abs().mean() curr_diff_mean = (current - previous).abs().mean() rel_distance = torch.where( From adf6d013987cc498b8ea21264391967b78870a87 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Mon, 4 May 2026 19:16:24 +0000 Subject: [PATCH 34/34] update perf --- .../test/server/perf_baselines.json | 123 +++++++++--------- 1 file changed, 62 insertions(+), 61 deletions(-) diff --git a/python/sglang/multimodal_gen/test/server/perf_baselines.json b/python/sglang/multimodal_gen/test/server/perf_baselines.json index dacfe45088ad..5390b4afdba3 100644 --- a/python/sglang/multimodal_gen/test/server/perf_baselines.json +++ b/python/sglang/multimodal_gen/test/server/perf_baselines.json @@ -821,69 +821,70 @@ }, "wan2_1_t2v_1.3b_teacache_enabled": { "stages_ms": { - "TextEncodingStage": 1088.18, - "DenoisingStage": 3926.63, - "InputValidationStage": 0.07, - "LatentPreparationStage": 0.14, - "TimestepPreparationStage": 2.31, - "DecodingStage": 711.47 + "InputValidationStage": 0.04, + "TextEncodingStage": 1218.71, + "LatentPreparationStage": 0.13, + "TimestepPreparationStage": 2.3, + "DenoisingStage": 3968.66, + "DecodingStage": 632.27, + "per_frame_generation": null }, "denoise_step_ms": { - "0": 80.49, - "1": 147.49, - "2": 144.75, - "3": 144.41, - "4": 143.82, - "5": 138.89, - "6": 36.55, - "7": 102.12, - "8": 39.18, - "9": 104.07, - "10": 37.91, - "11": 107.4, - "12": 2.52, - "13": 36.6, - "14": 106.91, - "15": 3.17, - "16": 41.13, - "17": 95.25, - "18": 2.67, - "19": 37.88, - "20": 110.46, - "21": 2.66, - "22": 38.29, - "23": 108.86, - "24": 2.7, - "25": 38.46, - "26": 108.62, - "27": 2.71, - "28": 41.23, - "29": 107.12, - "30": 2.72, - "31": 36.52, - "32": 112.05, - "33": 2.5, - "34": 110.6, - "35": 38.24, - "36": 110.83, - "37": 37.63, - "38": 111.91, - "39": 37.14, - "40": 111.39, - "41": 38.64, - "42": 110.57, - "43": 38.09, - "44": 111.32, - "45": 148.48, - "46": 149.96, - "47": 143.63, - "48": 148.49, - "49": 151.0 - }, - "expected_e2e_ms": 6012.96, - "expected_avg_denoise_ms": 78.45, - "expected_median_denoise_ms": 102.49, - "estimated_full_test_time_s": 126.0 + "0": 84.07, + "1": 153.06, + "2": 149.78, + "3": 149.08, + "4": 149.29, + "5": 66.61, + "6": 110.3, + "7": 52.51, + "8": 114.69, + "9": 44.93, + "10": 109.02, + "11": 43.79, + "12": 2.58, + "13": 109.2, + "14": 43.64, + "15": 2.55, + "16": 107.91, + "17": 43.19, + "18": 2.58, + "19": 107.48, + "20": 42.33, + "21": 2.56, + "22": 107.52, + "23": 48.05, + "24": 2.58, + "25": 117.42, + "26": 40.87, + "27": 2.57, + "28": 109.58, + "29": 45.79, + "30": 2.57, + "31": 109.49, + "32": 44.01, + "33": 2.57, + "34": 109.47, + "35": 43.15, + "36": 107.93, + "37": 43.33, + "38": 109.47, + "39": 44.34, + "40": 110.87, + "41": 45.92, + "42": 116.65, + "43": 38.6, + "44": 108.5, + "45": 153.14, + "46": 152.2, + "47": 152.05, + "48": 151.15, + "49": 153.32 + }, + "expected_e2e_ms": 6182.06, + "expected_avg_denoise_ms": 79.29, + "expected_median_denoise_ms": 95.78, + "estimated_full_test_time_s": 84.6 }, "wan2_1_t2v_1.3b": { "stages_ms": {