diff --git a/docs/diffusion/performance/cache/teacache.md b/docs/diffusion/performance/cache/teacache.md index dd9691c43a4a..192fe33ad5d6 100644 --- a/docs/diffusion/performance/cache/teacache.md +++ b/docs/diffusion/performance/cache/teacache.md @@ -50,7 +50,7 @@ TeaCache is configured via `TeaCacheParams` in the sampling parameters: from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams params = TeaCacheParams( - teacache_thresh=0.1, # Threshold for accumulated L1 distance + rel_l1_thresh=0.1, # Threshold for accumulated L1 distance coefficients=[1.0, 0.0, 0.0], # Polynomial coefficients for L1 rescaling ) ``` @@ -59,7 +59,7 @@ params = TeaCacheParams( | Parameter | Type | Description | |-----------|------|-------------| -| `teacache_thresh` | float | Threshold for accumulated L1 distance. Lower = more caching, faster but potentially lower quality | +| `rel_l1_thresh` | float | Threshold for accumulated L1 distance. Lower = more caching, faster but potentially lower quality | | `coefficients` | list[float] | Polynomial coefficients for L1 rescaling. 
Model-specific tuning | ### Model-Specific Configurations @@ -73,7 +73,7 @@ TeaCache is built into the following model families: | Model Family | CFG Cache Separation | Notes | |--------------|---------------------|-------| | Wan (wan2.1, wan2.2) | Yes | Full support | -| Hunyuan (HunyuanVideo) | Yes | To be supported | +| Hunyuan (HunyuanVideo) | Yes | Full support | | Z-Image | Yes | To be supported | | Flux | No | To be supported | | Qwen | No | To be supported | diff --git a/python/sglang/multimodal_gen/configs/sample/hunyuan.py b/python/sglang/multimodal_gen/configs/sample/hunyuan.py index c60b856630f0..61733d24868d 100644 --- a/python/sglang/multimodal_gen/configs/sample/hunyuan.py +++ b/python/sglang/multimodal_gen/configs/sample/hunyuan.py @@ -38,7 +38,7 @@ class HunyuanSamplingParams(SamplingParams): teacache_params: TeaCacheParams = field( default_factory=lambda: TeaCacheParams( - teacache_thresh=0.15, + rel_l1_thresh=0.15, # from https://github.com/ali-vilab/TeaCache/blob/7c10efc4702c6b619f47805f7abe4a7a08085aa0/TeaCache4HunyuanVideo/teacache_sample_video.py#L222 coefficients=[ 7.33226126e02, diff --git a/python/sglang/multimodal_gen/configs/sample/sampling_params.py b/python/sglang/multimodal_gen/configs/sample/sampling_params.py index 617e2b5e2f7b..713d30cc31c9 100644 --- a/python/sglang/multimodal_gen/configs/sample/sampling_params.py +++ b/python/sglang/multimodal_gen/configs/sample/sampling_params.py @@ -161,9 +161,8 @@ class SamplingParams: # TeaCache parameters enable_teacache: bool = False - teacache_params: Any = ( - None # TeaCacheParams or WanTeaCacheParams, set by model-specific subclass - ) + cache_params: Any | None = None + calibrate_cache: bool = False # Profiling profile: bool = False @@ -615,6 +614,12 @@ def add_argument(*name_or_flags, **kwargs): "--enable-teacache", action="store_true", ) + parser.add_argument( + "--calibrate-cache", + action="store_true", + default=SamplingParams.calibrate_cache, + help="Run in calibration mode: 
collect magnitude ratio statistics instead of skipping steps.", + ) # profiling add_argument( @@ -971,4 +976,4 @@ def n_tokens(self) -> int: @dataclass class CacheParams: - cache_type: str = "none" + pass diff --git a/python/sglang/multimodal_gen/configs/sample/teacache.py b/python/sglang/multimodal_gen/configs/sample/teacache.py index e20df6e67b93..6ef641eacaa8 100644 --- a/python/sglang/multimodal_gen/configs/sample/teacache.py +++ b/python/sglang/multimodal_gen/configs/sample/teacache.py @@ -17,7 +17,7 @@ class TeaCacheParams(CacheParams): Attributes: cache_type: (`str`, defaults to `teacache`): A string labeling these parameters as belonging to teacache. - teacache_thresh (`float`, defaults to `0.0`): + rel_l1_thresh (`float`, defaults to `0.0`): Threshold for accumulated relative L1 distance. When below this threshold, the forward pass is skipped. Recommended values: 0.25 for ~1.5x speedup, 0.4 for ~1.8x, 0.6 for ~2.0x. @@ -48,7 +48,7 @@ class TeaCacheParams(CacheParams): """ cache_type: str = "teacache" - teacache_thresh: float = 0.0 + rel_l1_thresh: float = 0.0 start_skipping: int | float = 5 end_skipping: int | float = -1 coefficients: list[float] = field(default_factory=list) diff --git a/python/sglang/multimodal_gen/configs/sample/wan.py b/python/sglang/multimodal_gen/configs/sample/wan.py index 0f147a9dcede..184fe7034e68 100644 --- a/python/sglang/multimodal_gen/configs/sample/wan.py +++ b/python/sglang/multimodal_gen/configs/sample/wan.py @@ -66,7 +66,7 @@ class WanT2V_1_3B_SamplingParams(SamplingParams): teacache_params: TeaCacheParams = field( default_factory=lambda: TeaCacheParams( - teacache_thresh=0.08, + rel_l1_thresh=0.08, use_ret_steps=True, coefficients_callback=_wan_1_3b_coefficients, start_skipping=5, @@ -102,7 +102,7 @@ class WanT2V_14B_SamplingParams(SamplingParams): teacache_params: TeaCacheParams = field( default_factory=lambda: TeaCacheParams( - teacache_thresh=0.20, + rel_l1_thresh=0.20, use_ret_steps=False, 
coefficients_callback=_wan_14b_coefficients, start_skipping=1, @@ -128,7 +128,7 @@ class WanI2V_14B_480P_SamplingParam(WanT2V_1_3B_SamplingParams): teacache_params: TeaCacheParams = field( default_factory=lambda: TeaCacheParams( - teacache_thresh=0.26, + rel_l1_thresh=0.26, use_ret_steps=True, coefficients_callback=_wan_14b_coefficients, start_skipping=5, @@ -156,7 +156,7 @@ class WanI2V_14B_720P_SamplingParam(WanT2V_14B_SamplingParams): teacache_params: TeaCacheParams = field( default_factory=lambda: TeaCacheParams( - teacache_thresh=0.3, + rel_l1_thresh=0.3, use_ret_steps=True, coefficients_callback=_wan_14b_coefficients, start_skipping=5, diff --git a/python/sglang/multimodal_gen/configs/sample/zimage.py b/python/sglang/multimodal_gen/configs/sample/zimage.py index 77a9dabf90de..9c728e182ab3 100644 --- a/python/sglang/multimodal_gen/configs/sample/zimage.py +++ b/python/sglang/multimodal_gen/configs/sample/zimage.py @@ -22,7 +22,7 @@ class ZImageTurboSamplingParams(SamplingParams): teacache_params: TeaCacheParams = field( default_factory=lambda: TeaCacheParams( - teacache_thresh=0.15, + rel_l1_thresh=0.15, coefficients=[ 7.33226126e02, -4.01131952e02, diff --git a/python/sglang/multimodal_gen/runtime/cache/__init__.py b/python/sglang/multimodal_gen/runtime/cache/__init__.py index 62f0f8457f8f..89f95bc4117d 100644 --- a/python/sglang/multimodal_gen/runtime/cache/__init__.py +++ b/python/sglang/multimodal_gen/runtime/cache/__init__.py @@ -10,18 +10,24 @@ """ +from sglang.multimodal_gen.runtime.cache.base import DiffusionCache from sglang.multimodal_gen.runtime.cache.cache_dit_integration import ( CacheDitConfig, enable_cache_on_dual_transformer, enable_cache_on_transformer, get_scm_mask, ) -from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheContext, TeaCacheMixin +from sglang.multimodal_gen.runtime.cache.teacache import ( + TeaCacheState, + TeaCacheStrategy, +) __all__ = [ + # Base + "DiffusionCache", # TeaCache (always available) - 
"TeaCacheContext", - "TeaCacheMixin", + "TeaCacheState", + "TeaCacheStrategy", # cache-dit integration (lazy-loaded, requires cache-dit package) "CacheDitConfig", "enable_cache_on_transformer", diff --git a/python/sglang/multimodal_gen/runtime/cache/base.py b/python/sglang/multimodal_gen/runtime/cache/base.py new file mode 100644 index 000000000000..ecdd4d058232 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/cache/base.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +from abc import ABC, abstractmethod + +import torch + + +class DiffusionCache(ABC): + """Base class for managing diffusion timestep caching. + + Subclasses define specific strategies for deciding when to skip + computation and how to store/retrieve hidden states. + """ + + @abstractmethod + def maybe_reset(self, **kwargs) -> None: + """Resets the internal cache state for a new generation sequence. + + Args: + **kwargs: Additional parameters that may be helpful. + """ + + @abstractmethod + def should_skip(self, **kwargs) -> bool: + """Determines if the current timestep computation can be skipped. + + Args: + **kwargs: Additional parameters that may be helpful. + + Returns: + bool: True if the timestep should be skipped, False otherwise. + """ + + @abstractmethod + def write( + self, + hidden_states: torch.Tensor, + original_hidden_states: torch.Tensor, + **kwargs + ) -> None: + """Cache the result of a full forward pass to the cache state. + + Args: + hidden_states: Output of the transformer blocks. + original_hidden_states: Input from before the transformer blocks. + **kwargs: Additional parameters that may be helpful. + """ + + @abstractmethod + def read(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: + """Computes an approximation of the forward pass using cached data. Reads from the cache. + + Args: + hidden_states: The current input/intermediate hidden states. + **kwargs: Additional parameters for the retrieval strategy. 
+ + Returns: + torch.Tensor: The approximated output of the forward pass. + """ + + def calibrate(self, **kwargs) -> None: + """Performs a calibration step to learn cache thresholds or values. + + Args: + **kwargs: Additional parameters that may be helpful. + """ + pass diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index 8830f7ec20c4..0d7931125160 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -1,316 +1,240 @@ # SPDX-License-Identifier: Apache-2.0 """ -TeaCache: Temporal similarity-based caching for diffusion models. - -TeaCache accelerates diffusion inference by selectively skipping redundant -computation when consecutive diffusion steps are similar enough. This is -achieved by tracking the L1 distance between modulated inputs across timesteps. - -Key concepts: -- Modulated input: The input to transformer blocks after timestep conditioning -- L1 distance: Measures how different consecutive timesteps are -- Threshold: When accumulated L1 distance exceeds threshold, force computation -- CFG support: Separate caches for positive and negative branches +TeaCache accelerates diffusion inference by skipping redundant forward +passes when consecutive denoising steps are sufficiently similar, as measured +by the accumulated relative L1 distance of modulated inputs. 
References: - TeaCache: Accelerating Diffusion Models with Temporal Similarity https://arxiv.org/abs/2411.14324 """ -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +logger = logging.getLogger(__name__) -import numpy as np import torch -from sglang.multimodal_gen.configs.models import DiTConfig +from sglang.multimodal_gen.runtime.cache import DiffusionCache if TYPE_CHECKING: from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams -@dataclass -class TeaCacheContext: - """Common context extracted for TeaCache skip decision. - - This context is populated from the forward_batch and forward_context - during each denoising step, providing all information needed to make - cache decisions. - - Attributes: - current_timestep: Current denoising timestep index (0-indexed). - num_inference_steps: Total number of inference steps. - do_cfg: Whether classifier-free guidance is enabled. - is_cfg_negative: True if currently processing negative CFG branch. - teacache_thresh: Threshold for accumulated L1 distance. - coefficients: Polynomial coefficients for L1 rescaling. - teacache_params: Full TeaCacheParams for model-specific access. - """ - - current_timestep: int - num_inference_steps: int - do_cfg: bool - is_cfg_negative: bool # For CFG branch selection - teacache_thresh: float - coefficients: list[float] - teacache_params: "TeaCacheParams" # Full params for model-specific access - - -class TeaCacheMixin: - """ - Mixin class providing TeaCache optimization functionality. - - TeaCache accelerates diffusion inference by selectively skipping redundant - computation when consecutive diffusion steps are similar enough. - - This mixin should be inherited by DiT model classes that want to support - TeaCache optimization. 
It provides: - - State management for tracking L1 distances - - CFG-aware caching (separate caches for positive/negative branches) - - Decision logic for when to compute vs. use cache - - Example usage in a DiT model: - class MyDiT(TeaCacheMixin, BaseDiT): - def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) - self._init_teacache_state() - - def forward(self, hidden_states, timestep, ...): - ctx = self._get_teacache_context() - if ctx is not None: - # Compute modulated input (model-specific, e.g., after timestep embedding) - modulated_input = self._compute_modulated_input(hidden_states, timestep) - is_boundary = (ctx.current_timestep == 0 or - ctx.current_timestep >= ctx.num_inference_steps - 1) - - should_calc = self._compute_teacache_decision( - modulated_inp=modulated_input, - is_boundary_step=is_boundary, - coefficients=ctx.coefficients, - teacache_thresh=ctx.teacache_thresh, - ) - - if not should_calc: - # Use cached residual (must implement retrieve_cached_states) - return self.retrieve_cached_states(hidden_states) - - # Normal forward pass... - output = self._transformer_forward(hidden_states, timestep, ...) - - # Cache states for next step - if ctx is not None: - self.maybe_cache_states(output, hidden_states) - - return output - - Subclass implementation notes: - - `_compute_modulated_input()`: Model-specific method to compute the input - after timestep conditioning (used for L1 distance calculation) - - `retrieve_cached_states()`: Must be overridden to return cached output - - `maybe_cache_states()`: Override to store states for cache retrieval - - Attributes: - cnt: Counter for tracking steps. - enable_teacache: Whether TeaCache is enabled. - previous_modulated_input: Cached modulated input for positive branch. - previous_residual: Cached residual for positive branch. - accumulated_rel_l1_distance: Accumulated L1 distance for positive branch. - is_cfg_negative: Whether currently processing negative CFG branch. 
- _supports_cfg_cache: Whether this model supports CFG cache separation. - - CFG-specific attributes (only when _supports_cfg_cache is True): - previous_modulated_input_negative: Cached input for negative branch. - previous_residual_negative: Cached residual for negative branch. - accumulated_rel_l1_distance_negative: L1 distance for negative branch. - """ - - # Models that support CFG cache separation (wan/hunyuan/zimage) - # Models not in this set (flux/qwen) auto-disable TeaCache when CFG is enabled - _CFG_SUPPORTED_PREFIXES: set[str] = {"wan", "hunyuan", "zimage"} - config: DiTConfig - - def _init_teacache_state(self) -> None: - """Initialize TeaCache state. Call this in subclass __init__.""" - # Common TeaCache state - self.cnt = 0 - self.enable_teacache = True - # Flag indicating if this model supports CFG cache separation - self._supports_cfg_cache = ( - self.config.prefix.lower() in self._CFG_SUPPORTED_PREFIXES - ) - - # Always initialize positive cache fields (used in all modes) +def _rescale_distance_tensor( + coefficients: list[float], x: torch.Tensor +) -> torch.Tensor: + """Polynomial rescaling using tensor operations (torch.compile friendly).""" + c = coefficients + return c[0] * x**4 + c[1] * x**3 + c[2] * x**2 + c[3] * x + c[4] + + +def _compute_rel_l1_distance_tensor( + current: torch.Tensor, previous: torch.Tensor +) -> torch.Tensor: + """Compute relative L1 distance as a tensor (torch.compile friendly).""" + prev_mean = previous.abs().mean() + curr_diff_mean = (current - previous).abs().mean() + rel_distance = torch.where( + prev_mean > 1e-9, + curr_diff_mean / prev_mean, + torch.where( + current.abs().mean() < 1e-9, + torch.zeros(1, device=current.device, dtype=current.dtype), + torch.full((1,), float("inf"), device=current.device, dtype=current.dtype), + ), + ) + return rel_distance.squeeze() + + +class TeaCacheState: + """Tracks step progress, cached tensors, and L1 distances for a single CFG path. 
Updated every timestep.""" + + def __init__(self) -> None: + self.step: int = 0 self.previous_modulated_input: torch.Tensor | None = None self.previous_residual: torch.Tensor | None = None - self.accumulated_rel_l1_distance: float = 0.0 + self.accumulated_rel_l1_distance: torch.Tensor | None = None - self.is_cfg_negative = False - # CFG-specific fields initialized to None (created when CFG is used) - # These are only used when _supports_cfg_cache is True AND do_cfg is True - if self._supports_cfg_cache: - self.previous_modulated_input_negative: torch.Tensor | None = None - self.previous_residual_negative: torch.Tensor | None = None - self.accumulated_rel_l1_distance_negative: float = 0.0 - - def reset_teacache_state(self) -> None: - """Reset all TeaCache state at the start of each generation task.""" - self.cnt = 0 - - # Primary cache fields (always present) + def reset(self) -> None: + """Clear all cached tensors and reset the step counter for a new generation.""" + self.step = 0 self.previous_modulated_input = None self.previous_residual = None - self.accumulated_rel_l1_distance = 0.0 - self.is_cfg_negative = False - self.enable_teacache = True - # CFG negative cache fields (always reset, may be unused) - if self._supports_cfg_cache: - self.previous_modulated_input_negative = None - self.previous_residual_negative = None - self.accumulated_rel_l1_distance_negative = 0.0 - - def _compute_l1_and_decide( - self, - modulated_inp: torch.Tensor, - coefficients: list[float], - teacache_thresh: float, - ) -> tuple[float, bool]: - """ - Compute L1 distance and decide whether to calculate or use cache. + self.accumulated_rel_l1_distance = None - Args: - modulated_inp: Current timestep's modulated input. - coefficients: Polynomial coefficients for L1 rescaling. - teacache_thresh: Threshold for cache decision. 
+ def update( + self, modulated_inp: torch.Tensor | None, previous_residual: torch.Tensor | None + ) -> None: + """Store the current modulated input and its computed residual for possible future reuse.""" + self.previous_modulated_input = modulated_inp + self.previous_residual = previous_residual - Returns: - Tuple of (new_accumulated_distance, should_calc). - """ - prev_modulated_inp = ( - self.previous_modulated_input_negative - if self.is_cfg_negative - else self.previous_modulated_input - ) + def __repr__(self): + return f"TeaCacheState(step={self.step}, accumulated_rel_l1_distance={self.accumulated_rel_l1_distance})" - # Defensive check: if previous input is not set, force calculation - if prev_modulated_inp is None: - return 0.0, True - # Compute relative L1 distance - diff = modulated_inp - prev_modulated_inp - rel_l1 = (diff.abs().mean() / prev_modulated_inp.abs().mean()).cpu().item() +class TeaCacheStrategy(DiffusionCache): + """Implements TeaCache to skip redundant diffusion forward passes. - # Apply polynomial rescaling - rescale_func = np.poly1d(coefficients) + TeaCacheStrategy implements teacache as a `DiffusionCache` object. It + manages two TeaCacheState objects (positive + optional negative CFG branch) + and stores parameters needed to make skipping decisions. 
+ """ - accumulated_rel_l1_distance = ( - self.accumulated_rel_l1_distance_negative - if self.is_cfg_negative - else self.accumulated_rel_l1_distance + def __init__(self, supports_cfg: bool) -> None: + """Initialize cache states for positive and optional negative CFG branches.""" + # params updated every forward pass + self.state = TeaCacheState() + self.state_neg = TeaCacheState() if supports_cfg else None + # params updated at the start of each new generation + # set in maybe_reset() + self.cache_params: TeaCacheParams | None = None + self.coefficients: list[float] = [] + self.num_steps: int = 0 + self.start_skipping: int | None = None + self.end_skipping: int | None = None + + def _get_state(self) -> TeaCacheState: + """Select the appropriate cache state (positive/negative cfg) based on the forward context.""" + from sglang.multimodal_gen.runtime.managers.forward_context import ( + get_forward_context, ) - accumulated_rel_l1_distance = accumulated_rel_l1_distance + rescale_func(rel_l1) - - if accumulated_rel_l1_distance >= teacache_thresh: - # Threshold exceeded: force compute and reset accumulator - return 0.0, True - # Cache hit: keep accumulated distance - return accumulated_rel_l1_distance, False - - def _compute_teacache_decision( - self, - modulated_inp: torch.Tensor, - is_boundary_step: bool, - coefficients: list[float], - teacache_thresh: float, - ) -> bool: - """ - Compute cache decision for TeaCache. - - Args: - modulated_inp: Current timestep's modulated input. - is_boundary_step: True for boundary timesteps that always compute. - coefficients: Polynomial coefficients for L1 rescaling. - teacache_thresh: Threshold for cache decision. - - Returns: - True if forward computation is needed, False to use cache. 
- """ - if not self.enable_teacache: - return True - if is_boundary_step: - new_accum, should_calc = 0.0, True - else: - new_accum, should_calc = self._compute_l1_and_decide( - modulated_inp=modulated_inp, - coefficients=coefficients, - teacache_thresh=teacache_thresh, - ) - - # Advance baseline and accumulator for the active branch - if not self.is_cfg_negative: - self.previous_modulated_input = modulated_inp.clone() - self.accumulated_rel_l1_distance = new_accum - elif self._supports_cfg_cache: - self.previous_modulated_input_negative = modulated_inp.clone() - self.accumulated_rel_l1_distance_negative = new_accum + fb = get_forward_context().forward_batch + is_cfg_negative = fb.is_cfg_negative if fb is not None else False + if is_cfg_negative and self.state_neg is not None: + return self.state_neg + return self.state - return should_calc + def maybe_reset(self, **kwargs) -> None: + """Maybe reset the TeaCacheState by doing three things: - def _get_teacache_context(self) -> TeaCacheContext | None: - """ - Check TeaCache preconditions and extract common context. + 1. Reset TeaCacheState if the previous generation is complete + 2. Initialize parameters if at the start of a new generation. + 3. Increment the state's timestep counter (always) - Returns: - TeaCacheContext if TeaCache is enabled and properly configured, - None if should skip TeaCache logic entirely. + Called on every forward pass before should_skip(). 
""" from sglang.multimodal_gen.runtime.managers.forward_context import ( get_forward_context, ) - forward_context = get_forward_context() - forward_batch = forward_context.forward_batch - - # Early return checks - if ( - forward_batch is None - or not forward_batch.enable_teacache - or forward_batch.teacache_params is None - ): - return None - - teacache_params = forward_batch.teacache_params - - # Extract common values - current_timestep = forward_context.current_timestep - num_inference_steps = forward_batch.num_inference_steps - do_cfg = forward_batch.do_classifier_free_guidance - is_cfg_negative = forward_batch.is_cfg_negative - - # Reset at first timestep - if current_timestep == 0 and not self.is_cfg_negative: - self.reset_teacache_state() - - return TeaCacheContext( - current_timestep=current_timestep, - num_inference_steps=num_inference_steps, - do_cfg=do_cfg, - is_cfg_negative=is_cfg_negative, - teacache_thresh=teacache_params.teacache_thresh, - coefficients=teacache_params.get_coefficients(), - teacache_params=teacache_params, + state = self._get_state() + + # Reset state if we completed a generation + if state.step == self.num_steps and state.step > 0: + state.reset() + + # Initialize values if at the start of each new generation + if state.step == 0: + + # set the teacache parameters + fb = get_forward_context().forward_batch + assert ( + fb is not None + ), "TeaCacheStrategy required the forward_batch not be None" + self.cache_params = getattr(fb.sampling_params, "teacache_params", None) + + # set the number of inference steps + assert ( + self.cache_params is not None + ), "TeaCacheStrategy requires teacache_params in sampling_params" + self.num_steps = int(fb.num_inference_steps) + + # set the teacache coefficients + if self.cache_params.coefficients_callback: + self.coefficients = self.cache_params.coefficients_callback( + self.cache_params + ) + else: + self.coefficients = self.cache_params.coefficients + + # set the start and end skippable steps + 
if isinstance(self.cache_params.start_skipping, float): + start_skipping = int(self.num_steps * self.cache_params.start_skipping) + elif self.cache_params.start_skipping < 0: + start_skipping = self.num_steps + self.cache_params.start_skipping + else: + start_skipping = self.cache_params.start_skipping + + if isinstance(self.cache_params.end_skipping, float): + end_skipping = int(self.num_steps * self.cache_params.end_skipping) + elif self.cache_params.end_skipping < 0: + end_skipping = self.num_steps + self.cache_params.end_skipping + else: + end_skipping = self.cache_params.end_skipping + + if start_skipping > end_skipping: + logger.warning( + f"TeaCache skip window is invalid (start_skipping={start_skipping} > " + f"end_skipping={end_skipping}) for num_inference_steps={self.num_steps}. " + "This can happen during warmup runs with very few steps. TeaCache is disabled." + ) + self.start_skipping = self.end_skipping = None + else: + self.start_skipping, self.end_skipping = start_skipping, end_skipping + + # increment the number of steps always + state.step += 1 + + def should_skip( + self, modulated_input: torch.Tensor | None = None, **kwargs + ) -> bool: + """Decide whether this forward pass can be skipped based on the accumulated L1 distance of the modulated input.""" + state = self._get_state() + assert self.cache_params is not None + + # No valid skip window for this generation + if self.start_skipping is None or self.end_skipping is None: + return False + + # Boundary steps always compute + if state.step < self.start_skipping or state.step >= self.end_skipping: + return False + + # First time computing, no previous input to compare against + if state.accumulated_rel_l1_distance is None: + state.accumulated_rel_l1_distance = torch.zeros( + 1, device=modulated_input.device, dtype=modulated_input.dtype + ) + return False + + # compute the accumulated relative l1 distance + assert state.previous_modulated_input is not None + assert modulated_input is not 
None + rel_l1 = _compute_rel_l1_distance_tensor( + modulated_input, state.previous_modulated_input ) + rescaled = _rescale_distance_tensor(self.coefficients, rel_l1) + state.accumulated_rel_l1_distance += rescaled - def maybe_cache_states( - self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor - ) -> None: - """Cache states for later retrieval. Override in subclass if needed.""" - pass + # If below threshold, skip the forward pass + if state.accumulated_rel_l1_distance < self.cache_params.rel_l1_thresh: + return True - def should_skip_forward_for_cached_states(self, **kwargs: dict[str, Any]) -> bool: - """Check if forward can be skipped using cached states.""" + # If threshold exceeded, reset accumulated so next window starts fresh + state.accumulated_rel_l1_distance = torch.zeros( + 1, device=modulated_input.device, dtype=modulated_input.dtype + ) return False - def retrieve_cached_states(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Retrieve cached states. 
Must be implemented by subclass.""" - raise NotImplementedError("retrieve_cached_states is not implemented") + def write( + self, + hidden_states: torch.Tensor, + original_hidden_states: torch.Tensor, + modulated_input: torch.Tensor | None = None, + **kwargs, + ) -> None: + """After the forward pass, cache the residual and the current modulated input.""" + assert self.cache_params is not None + residual = hidden_states.squeeze(0) - original_hidden_states + state = self._get_state() + state.update(modulated_input, residual) + + def read(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: + """Before the forward pass, read from the cache and apply it to the current hidden states.""" + return hidden_states + self._get_state().previous_residual diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py b/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py index 18d5ed983510..ce7b0117e544 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py @@ -131,6 +131,7 @@ async def generations( true_cfg_scale=request.true_cfg_scale, negative_prompt=request.negative_prompt, enable_teacache=request.enable_teacache, + calibrate_cache=request.calibrate_cache, output_compression=request.output_compression, output_quality=request.output_quality, enable_upscaling=request.enable_upscaling, @@ -211,6 +212,7 @@ async def edits( output_quality: Optional[str] = Form("default"), output_compression: Optional[int] = Form(None), enable_teacache: Optional[bool] = Form(False), + calibrate_cache: Optional[bool] = Form(False), enable_upscaling: Optional[bool] = Form(False), upscaling_model_path: Optional[str] = Form(None), upscaling_scale: Optional[int] = Form(4), @@ -266,6 +268,7 @@ async def edits( true_cfg_scale=true_cfg_scale, num_inference_steps=num_inference_steps, enable_teacache=enable_teacache, + calibrate_cache=calibrate_cache, 
num_frames=num_frames, output_compression=output_compression, output_quality=output_quality, diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py b/python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py index b326a295097f..1d320c1a4fe6 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py @@ -48,6 +48,7 @@ class ImageGenerationsRequest(BaseModel): output_quality: Optional[str] = "default" output_compression: Optional[int] = None enable_teacache: Optional[bool] = False + calibrate_cache: Optional[bool] = False # Upscaling enable_upscaling: Optional[bool] = False upscaling_model_path: Optional[str] = None @@ -100,6 +101,7 @@ class VideoGenerationsRequest(BaseModel): ) negative_prompt: Optional[str] = None enable_teacache: Optional[bool] = False + calibrate_cache: Optional[bool] = False # Frame interpolation enable_frame_interpolation: Optional[bool] = False frame_interpolation_exp: Optional[int] = 1 # 1=2×, 2=4× diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py b/python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py index abccf31bb2d6..eadc2b37d3a0 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py @@ -73,6 +73,7 @@ def _build_video_sampling_params(request_id: str, request: VideoGenerationsReque guidance_scale_2=request.guidance_scale_2, negative_prompt=request.negative_prompt, enable_teacache=request.enable_teacache, + calibrate_cache=request.calibrate_cache, enable_frame_interpolation=request.enable_frame_interpolation, frame_interpolation_exp=request.frame_interpolation_exp, frame_interpolation_scale=request.frame_interpolation_scale, @@ -183,6 +184,7 @@ async def create_video( guidance_scale: Optional[float] = Form(None), num_inference_steps: Optional[int] = Form(None), 
enable_teacache: Optional[bool] = Form(False), + calibrate_cache: Optional[bool] = Form(False), enable_frame_interpolation: Optional[bool] = Form(False), frame_interpolation_exp: Optional[int] = Form(1), frame_interpolation_scale: Optional[float] = Form(1.0), @@ -261,6 +263,7 @@ async def create_video( negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, enable_teacache=enable_teacache, + calibrate_cache=calibrate_cache, enable_frame_interpolation=enable_frame_interpolation, frame_interpolation_exp=frame_interpolation_exp, frame_interpolation_scale=frame_interpolation_scale, diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py index f170809a738e..d926ab152c07 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -46,13 +46,13 @@ import torch from torch.distributed.tensor import DTensor, distribute_tensor -from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheMixin from sglang.multimodal_gen.runtime.loader.utils import ( _list_safetensors_files, ) from sglang.multimodal_gen.runtime.loader.weight_utils import ( safetensors_weights_iterator, ) +from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT from sglang.multimodal_gen.runtime.pipelines.diffusers_pipeline import DiffusersPipeline from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin @@ -213,8 +213,8 @@ def update_weights_from_disk( if success and flush_cache: for _, module in modules_to_update: - if isinstance(module, TeaCacheMixin): - module.reset_teacache_state() + if isinstance(module, CachableDiT) and module.cache: + module.cache.reset() logger.info(message) return success, message diff --git a/python/sglang/multimodal_gen/runtime/models/dits/base.py 
b/python/sglang/multimodal_gen/runtime/models/dits/base.py index 9816f5fb03a7..71a4dd732541 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/base.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/base.py @@ -8,16 +8,13 @@ from torch import nn from sglang.multimodal_gen.configs.models import DiTConfig - -# NOTE: TeaCacheContext and TeaCacheMixin have been moved to -# sglang.multimodal_gen.runtime.cache.teacache -# For backwards compatibility, re-export from the new location -from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheContext # noqa: F401 -from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheMixin +from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheStrategy from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) -# TODO class BaseDiT(nn.Module, ABC): _fsdp_shard_conditions: list = [] _compile_conditions: list = [] @@ -87,11 +84,15 @@ def device(self) -> torch.device: return next(self.parameters()).device -class CachableDiT(TeaCacheMixin, BaseDiT): +_CFG_SUPPORTED_PREFIXES: set[str] = {"wan", "hunyuan", "zimage"} + + +class CachableDiT(BaseDiT): """ - An intermediate base class that adds TeaCache optimization functionality to DiT models. + An intermediate base class that adds timestep-caching support for DiT models + such as TeaCache. - Inherits TeaCacheMixin for cache logic and BaseDiT for core DiT functionality. + Inherits `BaseDiT` for core DiT functionality and stores cache logic in `self.cache`. """ # These are required class attributes that should be overridden by concrete implementations @@ -109,8 +110,53 @@ class CachableDiT(TeaCacheMixin, BaseDiT): ) def __init__(self, config: DiTConfig, **kwargs) -> None: + """ + Args: + config: DiT model configuration. + **kwargs: Passed through to BaseDiT (e.g. hf_config). 
+ + Attributes: + cache: Active cache strategy, or a sentinel: + - None: uninitialized; init_cache() has not been called yet. + - False: no cache strategy requested. + - DiffusionCache: an active cache strategy (e.g. TeaCacheStrategy). + calibrate_cache: When True, run every forward pass to calibrate + the values needed for caching. + """ super().__init__(config, **kwargs) - self._init_teacache_state() + self.cache: TeaCacheStrategy | bool | None = None + self.calibrate_cache: bool = False + + def init_cache(self) -> None: + """Construct the cache strategy from the current forward_batch context. + + Called lazily on the first forward pass because sampling params + (e.g. `enable_teacache`) are only available then. + """ + from sglang.multimodal_gen.runtime.managers.forward_context import ( + get_forward_context, + ) + + fb = get_forward_context().forward_batch + if fb is None: + return + + # caching strategies may handle pos/neg cfg separately + supports_cfg = self.config.prefix.lower() in _CFG_SUPPORTED_PREFIXES + + # select caching strategy + if fb.enable_teacache: + self.cache = TeaCacheStrategy(supports_cfg) + else: + self.cache = False + + if fb.calibrate_cache: + if self.cache: + self.calibrate_cache = fb.calibrate_cache + else: + logger.warning( + "Calibrate cache is set to True but no cache is defined." 
+ ) @classmethod def get_nunchaku_quant_rules(cls) -> dict[str, dict[str, Any]]: diff --git a/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py index 09a233ec9176..9e41656eec6b 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py @@ -4,12 +4,10 @@ from typing import Any -import numpy as np import torch import torch.nn as nn from sglang.multimodal_gen.configs.models.dits import HunyuanVideoConfig -from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams from sglang.multimodal_gen.runtime.distributed.parallel_state import get_sp_world_size from sglang.multimodal_gen.runtime.layers.attention import ( LocalAttention, @@ -36,12 +34,10 @@ TimestepEmbedder, unpatchify, ) -from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT from sglang.multimodal_gen.runtime.models.utils import modulate from sglang.multimodal_gen.runtime.platforms import ( AttentionBackendEnum, - current_platform, ) from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin @@ -580,9 +576,12 @@ def forward( Returns: Tuple of (output) """ - forward_context = get_forward_context() - forward_batch = forward_context.forward_batch - enable_teacache = forward_batch is not None and forward_batch.enable_teacache + + # if caching is enabled, we might initialize or reset the cache state + if self.cache is None: + self.init_cache() + if self.cache: + self.cache.maybe_reset() if guidance is None: guidance = torch.tensor( @@ -635,14 +634,17 @@ def forward( freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None - should_skip_forward = self.should_skip_forward_for_cached_states( - img=img, vec=vec - ) + # if caching is enabled, we might be able to skip the forward pass + 
should_skip_forward = False + if self.cache and not self.calibrate_cache: + modulated_input = self._get_modulated_input(img.clone(), vec.clone()) + should_skip_forward = self.cache.should_skip(modulated_input) if should_skip_forward: - img = self.retrieve_cached_states(img) + # compute img using the cached state + img = self.cache.read(img) else: - if enable_teacache: + if self.cache and not self.calibrate_cache: original_img = img.clone() # Process through double stream blocks @@ -666,8 +668,12 @@ def forward( # Extract image features img = x[:, :img_seq_len, ...] - if enable_teacache: - self.maybe_cache_states(img, original_img) + if self.cache and not self.calibrate_cache: + self.cache.write( + img, + original_img, + modulated_input=modulated_input, + ) # Final layer processing img = self.final_layer(img, vec) @@ -676,112 +682,13 @@ def forward( return img - def maybe_cache_states( - self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor - ) -> None: - self.previous_residual = hidden_states - original_hidden_states - - def should_skip_forward_for_cached_states(self, **kwargs) -> bool: - - forward_context = get_forward_context() - forward_batch = forward_context.forward_batch - if forward_batch is None: - return False - current_timestep = forward_context.current_timestep - enable_teacache = forward_batch.enable_teacache - - if not enable_teacache: - return False - raise NotImplementedError("teacache is not supported yet for HunyuanVideo") - - teacache_params = forward_batch.teacache_params - assert teacache_params is not None, "teacache_params is not initialized" - assert isinstance( - teacache_params, TeaCacheParams - ), "teacache_params is not a TeaCacheParams" - num_inference_steps = forward_batch.num_inference_steps - teache_thresh = teacache_params.teacache_thresh - - coefficients = teacache_params.coefficients - - if current_timestep == 0: - self.cnt = 0 - - inp = kwargs["img"].clone() - vec_ = kwargs["vec"].clone() - # convert to DTensor - 
vec_ = torch.distributed.tensor.DTensor.from_local( - vec_, - torch.distributed.DeviceMesh( - current_platform.device_type, - list(range(get_sp_world_size())), - mesh_dim_names=("dp",), - ), - [torch.distributed.tensor.Replicate()], - ) - - inp = torch.distributed.tensor.DTensor.from_local( - inp, - torch.distributed.DeviceMesh( - current_platform.device_type, - list(range(get_sp_world_size())), - mesh_dim_names=("dp",), - ), - [torch.distributed.tensor.Replicate()], - ) - - # txt_ = kwargs["txt"].clone() - - # inp = img.clone() - # vec_ = vec.clone() - # txt_ = txt.clone() - ( - img_mod1_shift, - img_mod1_scale, - img_mod1_gate, - img_mod2_shift, - img_mod2_scale, - img_mod2_gate, - ) = ( - self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1) + def _get_modulated_input(self, inp, vec) -> torch.Tensor: + img_mod1_shift, img_mod1_scale = ( + self.double_blocks[0].img_mod(vec).chunk(6, dim=-1)[:2] ) normed_inp = self.double_blocks[0].img_attn_norm.norm(inp) modulated_inp = modulate(normed_inp, shift=img_mod1_shift, scale=img_mod1_scale) - if self.cnt == 0 or self.cnt == num_inference_steps - 1: - should_calc = True - self.accumulated_rel_l1_distance = 0 - else: - coefficients = [ - 7.33226126e02, - -4.01131952e02, - 6.75869174e01, - -3.14987800e00, - 9.61237896e-02, - ] - rescale_func = np.poly1d(coefficients) - assert ( - self.previous_modulated_input is not None - ), "previous_modulated_input is not initialized" - self.accumulated_rel_l1_distance += rescale_func( - ( - (modulated_inp - self.previous_modulated_input).abs().mean() - / self.previous_modulated_input.abs().mean() - ) - .cpu() - .item() - ) - if self.accumulated_rel_l1_distance < teache_thresh: - should_calc = False - else: - should_calc = True - self.accumulated_rel_l1_distance = 0 - self.previous_modulated_input = modulated_inp - self.cnt += 1 - - return not should_calc - - def retrieve_cached_states(self, hidden_states: torch.Tensor) -> torch.Tensor: - return hidden_states + self.previous_residual + return
modulated_inp class SingleTokenRefiner(nn.Module): diff --git a/python/sglang/multimodal_gen/runtime/models/dits/mova_audio_dit.py b/python/sglang/multimodal_gen/runtime/models/dits/mova_audio_dit.py index f3deaf649e9a..d4db6546a499 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/mova_audio_dit.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/mova_audio_dit.py @@ -182,7 +182,7 @@ def __init__( self.num_channels_latents = out_dim self.layer_names = ["blocks"] self.cnt = 0 - self.teacache_thresh = 0 + self.rel_l1_thresh = 0 self.coefficients = [] self.accumulated_rel_l1_distance = 0 self.previous_modulated_input = None diff --git a/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py b/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py index f6f520690e85..5950e1a0da27 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py @@ -497,7 +497,7 @@ def __init__( self.num_channels_latents = out_dim self.layer_names = ["blocks"] self.cnt = 0 - self.teacache_thresh = 0 + self.rel_l1_thresh = 0 self.coefficients = [] self.accumulated_rel_l1_distance = 0 self.previous_modulated_input = None diff --git a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py index 4f3080c50fc2..a467d7d41927 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -900,8 +900,6 @@ def __init__( ) # For type checking - - self.cnt = 0 self.__post_init__() # misc @@ -955,16 +953,22 @@ def forward( guidance=None, **kwargs, ) -> torch.Tensor: - forward_batch = get_forward_context().forward_batch + + # if caching is enabled, we might initialize or reset the cache state + if self.cache is None: + self.init_cache() + if self.cache: + self.cache.maybe_reset() + + forward_context = get_forward_context() + 
forward_batch = forward_context.forward_batch + if forward_batch is not None: sequence_shard_enabled = ( forward_batch.enable_sequence_shard and self.sp_size > 1 ) else: sequence_shard_enabled = False - self.enable_teacache = ( - forward_batch is not None and forward_batch.enable_teacache - ) orig_dtype = hidden_states.dtype if not isinstance(encoder_hidden_states, torch.Tensor): @@ -1096,25 +1100,31 @@ def forward( # 4. Transformer blocks # if caching is enabled, we might be able to skip the forward pass - should_skip_forward = self.should_skip_forward_for_cached_states( - timestep_proj=timestep_proj, temb=temb - ) + should_skip_forward = False + if self.cache and not self.calibrate_cache: + modulated_input = ( + timestep_proj if self.cache.cache_params.use_ret_steps else temb + ) + should_skip_forward = self.cache.should_skip(modulated_input) if should_skip_forward: - hidden_states = self.retrieve_cached_states(hidden_states) + # compute hidden_states using the cached state + hidden_states = self.cache.read(hidden_states) else: - # if teacache is enabled, we need to cache the original hidden states - if self.enable_teacache: + if self.cache and not self.calibrate_cache: original_hidden_states = hidden_states.clone() for block in self.blocks: hidden_states = block( hidden_states, encoder_hidden_states, timestep_proj, freqs_cis ) - # if teacache is enabled, we need to cache the original hidden states - if self.enable_teacache: - self.maybe_cache_states(hidden_states, original_hidden_states) - self.cnt += 1 + + if self.cache and not self.calibrate_cache: + self.cache.write( + hidden_states, + original_hidden_states, + modulated_input=modulated_input, + ) if sequence_shard_enabled: hidden_states = hidden_states.contiguous() @@ -1152,55 +1162,5 @@ def forward( return output - def maybe_cache_states( - self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor - ) -> None: - """Cache residual with CFG positive/negative separation.""" - residual = 
hidden_states.squeeze(0) - original_hidden_states - if not self.is_cfg_negative: - self.previous_residual = residual - else: - self.previous_residual_negative = residual - - def should_skip_forward_for_cached_states(self, **kwargs) -> bool: - if not self.enable_teacache: - return False - ctx = self._get_teacache_context() - if ctx is None: - return False - - # Initialize Wan-specific parameters - teacache_params = ctx.teacache_params - use_ret_steps = teacache_params.use_ret_steps - start_skipping, end_skipping = teacache_params.get_skip_boundaries( - ctx.num_inference_steps, ctx.do_cfg - ) - - # Determine boundary step - is_boundary_step = self.cnt < start_skipping or self.cnt >= end_skipping - - timestep_proj = kwargs["timestep_proj"] - temb = kwargs["temb"] - modulated_inp = timestep_proj if use_ret_steps else temb - - self.is_cfg_negative = ctx.is_cfg_negative - - # Use shared helper to compute cache decision - should_calc = self._compute_teacache_decision( - modulated_inp=modulated_inp, - is_boundary_step=is_boundary_step, - coefficients=ctx.coefficients, - teacache_thresh=ctx.teacache_thresh, - ) - - return not should_calc - - def retrieve_cached_states(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Retrieve cached residual with CFG positive/negative separation.""" - if not self.is_cfg_negative: - return hidden_states + self.previous_residual - else: - return hidden_states + self.previous_residual_negative - EntryClass = WanTransformer3DModel diff --git a/python/sglang/multimodal_gen/test/server/perf_baselines.json b/python/sglang/multimodal_gen/test/server/perf_baselines.json index 99926bd78f4d..f446e1e1d1de 100644 --- a/python/sglang/multimodal_gen/test/server/perf_baselines.json +++ b/python/sglang/multimodal_gen/test/server/perf_baselines.json @@ -807,69 +807,69 @@ }, "wan2_1_t2v_1.3b_teacache_enabled": { "stages_ms": { - "DenoisingStage": 4598.36, - "InputValidationStage": 0.07, - "DecodingStage": 552.92, - "LatentPreparationStage": 0.26, - 
"per_frame_generation": null, - "TextEncodingStage": 1114.01, - "TimestepPreparationStage": 2.1 + "InputValidationStage": 0.06, + "TextEncodingStage": 1130.34, + "LatentPreparationStage": 0.16, + "TimestepPreparationStage": 2.09, + "DenoisingStage": 5211.97, + "DecodingStage": 544.5, + "per_frame_generation": null }, "denoise_step_ms": { - "0": 94.24, - "1": 172.68, - "2": 169.48, - "3": 169.08, - "4": 168.38, - "5": 167.27, - "6": 62.95, - "7": 119.56, - "8": 53.34, - "9": 121.85, - "10": 47.64, - "11": 125.75, - "12": 3.24, - "13": 48.21, - "14": 125.17, - "15": 3.71, - "16": 48.15, - "17": 124.61, - "18": 3.3, - "19": 47.25, - "20": 129.33, - "21": 3.11, - "22": 48.03, - "23": 127.46, - "24": 3.37, - "25": 45.6, - "26": 127.17, - "27": 3.35, - "28": 49.83, - "29": 125.42, - "30": 3.19, - "31": 42.76, - "32": 131.19, - "33": 2.93, - "34": 130.04, - "35": 44.77, - "36": 131.45, - "37": 44.06, - "38": 131.02, - "39": 43.48, - "40": 130.42, - "41": 45.24, - "42": 129.46, - "43": 44.6, - "44": 130.33, - "45": 173.84, - "46": 175.58, - "47": 168.16, - "48": 173.85, - "49": 177.56 - }, - "expected_e2e_ms": 6497.84, - "expected_avg_denoise_ms": 91.85, - "expected_median_denoise_ms": 120.7 + "0": 96.66, + "1": 174.85, + "2": 169.04, + "3": 169.53, + "4": 169.79, + "5": 165.17, + "6": 118.69, + "7": 55.2, + "8": 122.32, + "9": 54.8, + "10": 116.48, + "11": 55.44, + "12": 116.39, + "13": 55.61, + "14": 121.37, + "15": 51.24, + "16": 117.63, + "17": 56.92, + "18": 117.3, + "19": 59.31, + "20": 116.67, + "21": 56.57, + "22": 115.76, + "23": 57.19, + "24": 116.64, + "25": 55.82, + "26": 121.76, + "27": 52.48, + "28": 119.55, + "29": 56.15, + "30": 118.49, + "31": 57.23, + "32": 116.67, + "33": 57.67, + "34": 116.47, + "35": 56.77, + "36": 125.22, + "37": 47.68, + "38": 121.38, + "39": 52.61, + "40": 121.45, + "41": 52.05, + "42": 118.42, + "43": 56.53, + "44": 118.46, + "45": 169.22, + "46": 172.07, + "47": 177.58, + "48": 147.35, + "49": 171.58 + }, + "expected_e2e_ms": 
7196.49, + "expected_avg_denoise_ms": 104.14, + "expected_median_denoise_ms": 116.65 }, "wan2_1_t2v_1.3b": { "stages_ms": { diff --git a/python/sglang/multimodal_gen/test/unit/test_sampling_params.py b/python/sglang/multimodal_gen/test/unit/test_sampling_params.py index ea818b82ed3a..5f686c7c90ca 100644 --- a/python/sglang/multimodal_gen/test/unit/test_sampling_params.py +++ b/python/sglang/multimodal_gen/test/unit/test_sampling_params.py @@ -93,7 +93,7 @@ def coefficients_callback(_: TeaCacheParams) -> list[float]: params = SamplingParams( prompt="callable teacache", - teacache_params=TeaCacheParams( + cache_params=TeaCacheParams( coefficients_callback=coefficients_callback, ), )