diff --git a/python/sglang/multimodal_gen/configs/sample/magcache.py b/python/sglang/multimodal_gen/configs/sample/magcache.py new file mode 100644 index 000000000000..8ed098aaa3f5 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/sample/magcache.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +from sglang.multimodal_gen.configs.sample.sampling_params import CacheParams + + +@dataclass +class MagCacheParams(CacheParams): +    """ +    MagCache configuration for magnitude-ratio-based caching. + +    MagCache accelerates diffusion inference by skipping forward passes when +    magnitude ratios of consecutive residuals are predictably similar. + +    Attributes: +        threshold: Accumulated error threshold (default 0.12; the paper uses 0.06). +            Lower = higher quality but slower. Higher = faster but lower quality. +        max_skip_steps: Maximum consecutive skips allowed (default 4). +            Prevents infinite skipping even if error is low. +        skip_start_step: Number of denoising steps at the start where skipping is disabled. +        skip_end_step: Number of denoising steps at the end where skipping is disabled (0 = active until last step). 
+ """ + + cache_type: str = "magcache" + threshold: float = 0.12 + max_skip_steps: int = 4 + skip_start_step: int = 10 + skip_end_step: int = 0 + mag_ratios: list[float] | None = None diff --git a/python/sglang/multimodal_gen/configs/sample/sampling_params.py b/python/sglang/multimodal_gen/configs/sample/sampling_params.py index 42eba884de3b..05be29e05672 100644 --- a/python/sglang/multimodal_gen/configs/sample/sampling_params.py +++ b/python/sglang/multimodal_gen/configs/sample/sampling_params.py @@ -22,6 +22,8 @@ logger = init_logger(__name__) if TYPE_CHECKING: + from sglang.multimodal_gen.configs.sample.magcache import MagCacheParams + from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams from sglang.multimodal_gen.runtime.server_args import ServerArgs @@ -154,8 +156,12 @@ class SamplingParams: cfg_normalization: float | bool = 0.0 boundary_ratio: float | None = None - # TeaCache parameters + # Cache acceleration enable_teacache: bool = False + teacache_params: "TeaCacheParams | None" = None + enable_magcache: bool = False + magcache_params: "MagCacheParams | None" = None + calibrate_cache: bool = False # Profiling profile: bool = False @@ -601,6 +607,37 @@ def add_cli_args(parser: Any) -> Any: "--enable-teacache", action="store_true", default=SamplingParams.enable_teacache, + help="Enable TeaCache acceleration for diffusion inference.", + ) + parser.add_argument( + "--teacache-params", + type=json.loads, + default=None, + help=( + 'TeaCache params as a JSON object, e.g. \'{"teacache_thresh": 0.08, "coefficients": [1.0, 2.0]}\'. ' + "Fields map directly to TeaCacheParams dataclass fields." + ), + ) + parser.add_argument( + "--enable-magcache", + action="store_true", + default=SamplingParams.enable_magcache, + help="Enable MagCache acceleration for diffusion inference.", + ) + parser.add_argument( + "--magcache-params", + type=json.loads, + default=None, + help=( + 'MagCache params as a JSON object, e.g. 
\'{"threshold": 0.12, "max_skip_steps": 4}\'. ' + "Fields map directly to MagCacheParams dataclass fields." + ), + ) + parser.add_argument( + "--calibrate-cache", + action="store_true", + default=SamplingParams.calibrate_cache, + help="Run in calibration mode: collect magnitude ratio statistics instead of skipping steps.", ) # profiling diff --git a/python/sglang/multimodal_gen/configs/sample/teacache.py b/python/sglang/multimodal_gen/configs/sample/teacache.py index ada71d0b3618..3d6b3ea99389 100644 --- a/python/sglang/multimodal_gen/configs/sample/teacache.py +++ b/python/sglang/multimodal_gen/configs/sample/teacache.py @@ -10,17 +10,38 @@ class TeaCacheParams(CacheParams): cache_type: str = "teacache" teacache_thresh: float = 0.0 + skip_start_step: int = 5 + skip_end_step: int = 0 coefficients: list[float] = field(default_factory=list) @dataclass class WanTeaCacheParams(CacheParams): - # Unfortunately, TeaCache is very different for Wan than other models + # Default threshold and coefficients are for Wan T2V 1.3B (use_ret_steps=True). + # For other Wan variants, override these values via --teacache-params. 
cache_type: str = "teacache" - teacache_thresh: float = 0.0 + teacache_thresh: float = 0.08 + skip_start_step: int = 5 + skip_end_step: int = 0 use_ret_steps: bool = True - ret_steps_coeffs: list[float] = field(default_factory=list) - non_ret_steps_coeffs: list[float] = field(default_factory=list) + ret_steps_coeffs: list[float] = field( + default_factory=lambda: [ + -5.21862437e04, + 9.23041404e03, + -5.28275948e02, + 1.36987616e01, + -4.99875664e-02, + ] + ) + non_ret_steps_coeffs: list[float] = field( + default_factory=lambda: [ + 2.39676752e03, + -1.31110545e03, + 2.01331979e02, + -8.29855975e00, + 1.37887774e-01, + ] + ) @property def coefficients(self) -> list[float]: @@ -28,16 +49,3 @@ def coefficients(self) -> list[float]: return self.ret_steps_coeffs else: return self.non_ret_steps_coeffs - - @property - def ret_steps(self) -> int: - if self.use_ret_steps: - return 5 * 2 - else: - return 1 * 2 - - def get_cutoff_steps(self, num_inference_steps: int) -> int: - if self.use_ret_steps: - return num_inference_steps * 2 - else: - return num_inference_steps * 2 - 2 diff --git a/python/sglang/multimodal_gen/configs/sample/wan.py b/python/sglang/multimodal_gen/configs/sample/wan.py index 2c405b2f050b..28656688f620 100644 --- a/python/sglang/multimodal_gen/configs/sample/wan.py +++ b/python/sglang/multimodal_gen/configs/sample/wan.py @@ -3,9 +3,115 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass, field +from sglang.multimodal_gen.configs.sample.magcache import MagCacheParams from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams from sglang.multimodal_gen.configs.sample.teacache import WanTeaCacheParams +# Magnitude ratio arrays from the reference implementation: +# https://github.com/Zehong-Ma/MagCache/blob/df81cb181776c2c61477c08e1d21f87fda1cd938/MagCache4Wan2.1/magcache_generate.py +T2V_13B_MAG_RATIOS = [ + 1.0, + 1.0, + 1.0124, + 1.02213, + 1.00166, + 1.0041, + 0.99791, + 1.00061, + 0.99682, + 0.99762, + 
0.99634, + 0.99685, + 0.99567, + 0.99586, + 0.99416, + 0.99422, + 0.99578, + 0.99575, + 0.9957, + 0.99563, + 0.99511, + 0.99506, + 0.99535, + 0.99531, + 0.99552, + 0.99549, + 0.99541, + 0.99539, + 0.9954, + 0.99536, + 0.99489, + 0.99485, + 0.99518, + 0.99514, + 0.99484, + 0.99478, + 0.99481, + 0.99479, + 0.99415, + 0.99413, + 0.99419, + 0.99416, + 0.99396, + 0.99393, + 0.99388, + 0.99386, + 0.99349, + 0.99349, + 0.99309, + 0.99304, + 0.9927, + 0.9927, + 0.99228, + 0.99226, + 0.99171, + 0.9917, + 0.99137, + 0.99135, + 0.99068, + 0.99063, + 0.99005, + 0.99003, + 0.98944, + 0.98942, + 0.98849, + 0.98849, + 0.98758, + 0.98757, + 0.98644, + 0.98643, + 0.98504, + 0.98503, + 0.9836, + 0.98359, + 0.98202, + 0.98201, + 0.97977, + 0.97978, + 0.97717, + 0.97718, + 0.9741, + 0.97411, + 0.97003, + 0.97002, + 0.96538, + 0.96541, + 0.9593, + 0.95933, + 0.95086, + 0.95089, + 0.94013, + 0.94019, + 0.92402, + 0.92414, + 0.90241, + 0.9026, + 0.86821, + 0.86868, + 0.81838, + 0.81939, +] + @dataclass class WanT2V_1_3B_SamplingParams(SamplingParams): @@ -50,6 +156,16 @@ class WanT2V_1_3B_SamplingParams(SamplingParams): ) ) + magcache_params: MagCacheParams = field( + default_factory=lambda: MagCacheParams( + threshold=0.12, + max_skip_steps=4, + skip_start_step=10, + skip_end_step=0, + mag_ratios=T2V_13B_MAG_RATIOS, + ) + ) + @dataclass class WanT2V_14B_SamplingParams(SamplingParams): diff --git a/python/sglang/multimodal_gen/runtime/cache/__init__.py b/python/sglang/multimodal_gen/runtime/cache/__init__.py index 62f0f8457f8f..3d4c381d29e7 100644 --- a/python/sglang/multimodal_gen/runtime/cache/__init__.py +++ b/python/sglang/multimodal_gen/runtime/cache/__init__.py @@ -6,22 +6,40 @@ diffusion transformer (DiT) inference: - TeaCache: Temporal similarity-based caching for diffusion models +- MagCache: Magnitude-ratio-based caching for diffusion models - cache-dit integration: Block-level caching with DBCache and TaylorSeer """ +from sglang.multimodal_gen.runtime.cache.base import 
DiffusionCache from sglang.multimodal_gen.runtime.cache.cache_dit_integration import ( CacheDitConfig, enable_cache_on_dual_transformer, enable_cache_on_transformer, get_scm_mask, ) -from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheContext, TeaCacheMixin +from sglang.multimodal_gen.runtime.cache.magcache import ( + MagCacheContext, + MagCacheState, + MagCacheStrategy, +) +from sglang.multimodal_gen.runtime.cache.teacache import ( + TeaCacheContext, + TeaCacheState, + TeaCacheStrategy, +) __all__ = [ - # TeaCache (always available) + # Base + "DiffusionCache", + # TeaCache "TeaCacheContext", - "TeaCacheMixin", + "TeaCacheState", + "TeaCacheStrategy", + # MagCache + "MagCacheContext", + "MagCacheState", + "MagCacheStrategy", # cache-dit integration (lazy-loaded, requires cache-dit package) "CacheDitConfig", "enable_cache_on_transformer", diff --git a/python/sglang/multimodal_gen/runtime/cache/base.py b/python/sglang/multimodal_gen/runtime/cache/base.py new file mode 100644 index 000000000000..cf07a8c36bde --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/cache/base.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Base class for diffusion model cache strategies (TeaCache, MagCache, etc.). +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +if TYPE_CHECKING: + from sglang.multimodal_gen.runtime.cache.magcache import MagCacheState + from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheState + + +class DiffusionCache: + """ + Base class for diffusion model caching strategies. + + Each subclass owns its own state (positive + negative CFG branch) and + context extraction logic. CachableDiT holds a single + `self.cache: DiffusionCache | None` and delegates all decisions here. + + Subclasses must implement: reset, get_context, should_skip. + maybe_cache, retrieve, and calibrate have default implementations. 
+ + Subclasses set self.state / self.state_neg to strategy-specific state + objects (positive and negative CFG branches respectively). state_neg is + None when CFG negative-branch caching is disabled. + + Typical forward pass usage in CachableDiT: + + ctx = self.cache.get_context(self.cnt) + if ctx and self.cache.should_skip(ctx, timestep_proj=..., temb=...): + hidden_states = self.cache.retrieve(hidden_states, ctx) + else: + original_hidden_states = hidden_states.clone() + # ... run transformer blocks ... + if calibrate_cache: + self.cache.calibrate(hidden_states, original_hidden_states, ctx) + else: + self.cache.maybe_cache(hidden_states, original_hidden_states, ctx) + """ + + def __init__(self) -> None: + self.state: MagCacheState | TeaCacheState | None = None + self.state_neg: MagCacheState | TeaCacheState | None = None + + def reset(self) -> None: + """Reset all state at the start of a new generation.""" + raise NotImplementedError + + def get_context(self, cnt: int): + """ + Read the global forward_context / forward_batch and return a + strategy-specific context dataclass, or None to bypass caching. + + cnt is the monotonically increasing forward-call index owned by the + model (model.cnt), incremented on every call regardless of whether + the forward pass was skipped. + """ + raise NotImplementedError + + def should_skip(self, ctx, **kwargs) -> bool: + """ + Decide whether to skip the transformer forward pass and reuse the + cached residual. kwargs carries model-specific tensors (e.g. + timestep_proj, temb) needed by some strategies. 
+ """ + raise NotImplementedError + + def maybe_cache( + self, + hidden_states: torch.Tensor, + original_hidden_states: torch.Tensor, + ctx, + ) -> None: + """Store residual after a full forward pass for future reuse.""" + state = ( + self.state_neg + if (ctx.is_cfg_negative and self.state_neg is not None) + else self.state + ) + state.previous_residual = hidden_states.squeeze(0) - original_hidden_states + + def retrieve(self, hidden_states: torch.Tensor, ctx) -> torch.Tensor: + """Reconstruct output from cached residual.""" + state = ( + self.state_neg + if (ctx.is_cfg_negative and self.state_neg is not None) + else self.state + ) + return hidden_states + state.previous_residual + + def calibrate( + self, + hidden_states: torch.Tensor, + original_hidden_states: torch.Tensor, + ctx, + ) -> None: + """Log calibration metrics. No-op by default.""" + pass diff --git a/python/sglang/multimodal_gen/runtime/cache/magcache.py b/python/sglang/multimodal_gen/runtime/cache/magcache.py new file mode 100644 index 000000000000..6a3a4fbbf2d0 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/cache/magcache.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +MagCache: Magnitude-ratio-based caching for diffusion models. + +Skips redundant transformer forward passes when magnitude ratios of +consecutive residuals are predictably similar. 
+ +References: +- MagCache: https://openreview.net/forum?id=KZn7TDOL4J +""" + +import json +import os +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +import torch +import torch.nn.functional as F + +from sglang.multimodal_gen.runtime.cache.base import DiffusionCache + +if TYPE_CHECKING: + from sglang.multimodal_gen.configs.sample.magcache import MagCacheParams + + +@dataclass +class MagCacheState: + """Per-CFG-branch state for MagCache.""" + + norm_ratio: float = 1.0 + accumulated_error: float = 0.0 + consecutive_skips: int = 0 + previous_residual: torch.Tensor | None = field(default=None, repr=False) + + def reset(self) -> None: + self.norm_ratio = 1.0 + self.accumulated_error = 0.0 + self.consecutive_skips = 0 + + +@dataclass +class MagCacheContext: + """Per-step snapshot for MagCache decisions. + + cnt is the forward-call index: timestep * 2 + cfg_offset when CFG is on, + so min_cnt/max_cnt boundary checks are scaled accordingly. + """ + + cnt: int + num_inference_steps: int + do_cfg: bool + is_cfg_negative: bool + params: "MagCacheParams" + + +class MagCacheStrategy(DiffusionCache): + """MagCache caching strategy. + + Constructed by CachableDiT.init_cache() once per generation when + magcache is selected. Owns both CFG-branch states. 
+ """ + + def __init__(self, supports_cfg_cache: bool) -> None: + self.calibration_path = None + self.state = MagCacheState() + self.state_neg = MagCacheState() if supports_cfg_cache else None + + def reset(self) -> None: + assert isinstance(self.state, MagCacheState) + self.state.reset() + self.state.previous_residual = None + if self.state_neg is not None: + self.state_neg.reset() + self.state_neg.previous_residual = None + + def get_context(self, cnt: int) -> MagCacheContext | None: + from sglang.multimodal_gen.runtime.managers.forward_context import ( + get_forward_context, + ) + + forward_context = get_forward_context() + fb = forward_context.forward_batch + if fb is None: + return None + + steps = fb.num_inference_steps + do_cfg = fb.do_classifier_free_guidance + is_neg = fb.is_cfg_negative + params = getattr(fb.sampling_params, "magcache_params", None) + assert ( + params is not None + ), "MagCacheStrategy requires magcache_params in sampling_params" + + return MagCacheContext(cnt, steps, do_cfg, is_neg, params) + + def should_skip(self, ctx: MagCacheContext, **kwargs) -> bool: + state = ( + self.state_neg + if (ctx.is_cfg_negative and self.state_neg is not None) + else self.state + ) + assert isinstance(state, MagCacheState) and isinstance(ctx, MagCacheContext) + + # Never skip on boundary steps + min_cnt = ( + ctx.params.skip_start_step * 2 if ctx.do_cfg else ctx.params.skip_start_step + ) + max_cnt = ( + (ctx.num_inference_steps - ctx.params.skip_end_step) * 2 + if ctx.do_cfg + else (ctx.num_inference_steps - ctx.params.skip_end_step) + ) + if ctx.cnt < min_cnt or ctx.cnt >= max_cnt: + state.reset() + return False + + if ctx.params.mag_ratios is None: + return False + + cur_ratio = ctx.params.mag_ratios[ctx.cnt] + state.norm_ratio *= cur_ratio + state.consecutive_skips += 1 + state.accumulated_error += abs(1 - state.norm_ratio) + + if ( + state.accumulated_error < ctx.params.threshold + and state.consecutive_skips <= ctx.params.max_skip_steps + ): + 
return True + state.reset() + return False + + def calibrate( + self, hidden_states, original_hidden_states, ctx: MagCacheContext + ) -> None: + state = ( + self.state_neg + if (ctx.is_cfg_negative and self.state_neg is not None) + else self.state + ) + assert isinstance(state, MagCacheState) and isinstance(ctx, MagCacheContext) + + prev = state.previous_residual + curr = hidden_states.squeeze(0) - original_hidden_states + + if prev is None: + mag_ratio, mag_std, cos_dis = 1.0, 0.0, 0.0 + else: + norms = curr.norm(dim=-1) / prev.norm(dim=-1) + mag_ratio = norms.mean().item() + mag_std = norms.std().item() + cos_dis = ( + (1 - F.cosine_similarity(curr, prev, dim=-1, eps=1e-8)).mean().item() + ) + + state.previous_residual = curr + + if self.calibration_path is None: + from sglang.multimodal_gen.envs import SGLANG_DIFFUSION_CACHE_ROOT + from sglang.multimodal_gen.runtime.server_args import get_global_server_args + + cache_dir = os.path.join( + SGLANG_DIFFUSION_CACHE_ROOT, "magcache_calibration" + ) + os.makedirs(cache_dir, exist_ok=True) + model_name = get_global_server_args().model_path.replace("/", "--") + self.calibration_path = os.path.join(cache_dir, f"{model_name}.jsonl") + + with open(self.calibration_path, "a") as f: + f.write( + json.dumps( + { + "cnt": ctx.cnt, + "mag_ratio": mag_ratio, + "mag_std": mag_std, + "cos_dis": cos_dis, + "negative": ctx.is_cfg_negative, + } + ) + + "\n" + ) diff --git a/python/sglang/multimodal_gen/runtime/cache/teacache.py b/python/sglang/multimodal_gen/runtime/cache/teacache.py index 5cdafd08bc04..5f8384388c46 100644 --- a/python/sglang/multimodal_gen/runtime/cache/teacache.py +++ b/python/sglang/multimodal_gen/runtime/cache/teacache.py @@ -2,315 +2,139 @@ """ TeaCache: Temporal similarity-based caching for diffusion models. -TeaCache accelerates diffusion inference by selectively skipping redundant -computation when consecutive diffusion steps are similar enough. 
This is -achieved by tracking the L1 distance between modulated inputs across timesteps. - -Key concepts: -- Modulated input: The input to transformer blocks after timestep conditioning -- L1 distance: Measures how different consecutive timesteps are -- Threshold: When accumulated L1 distance exceeds threshold, force computation -- CFG support: Separate caches for positive and negative branches +Skips redundant transformer forward passes by tracking the accumulated L1 +distance between modulated inputs across consecutive diffusion steps. References: -- TeaCache: Accelerating Diffusion Models with Temporal Similarity - https://arxiv.org/abs/2411.14324 +- TeaCache: https://arxiv.org/abs/2411.14324 """ -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from dataclasses import dataclass, field +from typing import TYPE_CHECKING import numpy as np import torch -from sglang.multimodal_gen.configs.models import DiTConfig +from sglang.multimodal_gen.runtime.cache.base import DiffusionCache if TYPE_CHECKING: - from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams + from sglang.multimodal_gen.configs.sample.teacache import ( + TeaCacheParams, + WanTeaCacheParams, + ) @dataclass -class TeaCacheContext: - """Common context extracted for TeaCache skip decision. - - This context is populated from the forward_batch and forward_context - during each denoising step, providing all information needed to make - cache decisions. - - Attributes: - current_timestep: Current denoising timestep index (0-indexed). - num_inference_steps: Total number of inference steps. - do_cfg: Whether classifier-free guidance is enabled. - is_cfg_negative: True if currently processing negative CFG branch. - teacache_thresh: Threshold for accumulated L1 distance. - coefficients: Polynomial coefficients for L1 rescaling. - teacache_params: Full TeaCacheParams for model-specific access. 
- """ - - current_timestep: int - num_inference_steps: int - do_cfg: bool - is_cfg_negative: bool # For CFG branch selection - teacache_thresh: float - coefficients: list[float] - teacache_params: "TeaCacheParams" # Full params for model-specific access - - -class TeaCacheMixin: - """ - Mixin class providing TeaCache optimization functionality. - - TeaCache accelerates diffusion inference by selectively skipping redundant - computation when consecutive diffusion steps are similar enough. - - This mixin should be inherited by DiT model classes that want to support - TeaCache optimization. It provides: - - State management for tracking L1 distances - - CFG-aware caching (separate caches for positive/negative branches) - - Decision logic for when to compute vs. use cache +class TeaCacheState: + """Per-CFG-branch state for TeaCache.""" - Example usage in a DiT model: - class MyDiT(TeaCacheMixin, BaseDiT): - def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) - self._init_teacache_state() - - def forward(self, hidden_states, timestep, ...): - ctx = self._get_teacache_context() - if ctx is not None: - # Compute modulated input (model-specific, e.g., after timestep embedding) - modulated_input = self._compute_modulated_input(hidden_states, timestep) - is_boundary = (ctx.current_timestep == 0 or - ctx.current_timestep >= ctx.num_inference_steps - 1) - - should_calc = self._compute_teacache_decision( - modulated_inp=modulated_input, - is_boundary_step=is_boundary, - coefficients=ctx.coefficients, - teacache_thresh=ctx.teacache_thresh, - ) - - if not should_calc: - # Use cached residual (must implement retrieve_cached_states) - return self.retrieve_cached_states(hidden_states) - - # Normal forward pass... - output = self._transformer_forward(hidden_states, timestep, ...) 
- - # Cache states for next step - if ctx is not None: - self.maybe_cache_states(output, hidden_states) - - return output - - Subclass implementation notes: - - `_compute_modulated_input()`: Model-specific method to compute the input - after timestep conditioning (used for L1 distance calculation) - - `retrieve_cached_states()`: Must be overridden to return cached output - - `maybe_cache_states()`: Override to store states for cache retrieval - - Attributes: - cnt: Counter for tracking steps. - enable_teacache: Whether TeaCache is enabled. - previous_modulated_input: Cached modulated input for positive branch. - previous_residual: Cached residual for positive branch. - accumulated_rel_l1_distance: Accumulated L1 distance for positive branch. - is_cfg_negative: Whether currently processing negative CFG branch. - _supports_cfg_cache: Whether this model supports CFG cache separation. - - CFG-specific attributes (only when _supports_cfg_cache is True): - previous_modulated_input_negative: Cached input for negative branch. - previous_residual_negative: Cached residual for negative branch. - accumulated_rel_l1_distance_negative: L1 distance for negative branch. - """ - - # Models that support CFG cache separation (wan/hunyuan/zimage) - # Models not in this set (flux/qwen) auto-disable TeaCache when CFG is enabled - _CFG_SUPPORTED_PREFIXES: set[str] = {"wan", "hunyuan", "zimage"} - config: DiTConfig - - def _init_teacache_state(self) -> None: - """Initialize TeaCache state. 
Call this in subclass __init__.""" - # Common TeaCache state - self.cnt = 0 - self.enable_teacache = True - # Flag indicating if this model supports CFG cache separation - self._supports_cfg_cache = ( - self.config.prefix.lower() in self._CFG_SUPPORTED_PREFIXES - ) + previous_modulated_input: torch.Tensor | None = field(default=None, repr=False) + previous_residual: torch.Tensor | None = field(default=None, repr=False) + accumulated_rel_l1_distance: float = 0.0 - # Always initialize positive cache fields (used in all modes) - self.previous_modulated_input: torch.Tensor | None = None - self.previous_residual: torch.Tensor | None = None - self.accumulated_rel_l1_distance: float = 0.0 - - self.is_cfg_negative = False - # CFG-specific fields initialized to None (created when CFG is used) - # These are only used when _supports_cfg_cache is True AND do_cfg is True - if self._supports_cfg_cache: - self.previous_modulated_input_negative: torch.Tensor | None = None - self.previous_residual_negative: torch.Tensor | None = None - self.accumulated_rel_l1_distance_negative: float = 0.0 - - def reset_teacache_state(self) -> None: - """Reset all TeaCache state at the start of each generation task.""" - self.cnt = 0 - - # Primary cache fields (always present) + def reset(self) -> None: self.previous_modulated_input = None self.previous_residual = None self.accumulated_rel_l1_distance = 0.0 - self.is_cfg_negative = False - self.enable_teacache = True - # CFG negative cache fields (always reset, may be unused) - if self._supports_cfg_cache: - self.previous_modulated_input_negative = None - self.previous_residual_negative = None - self.accumulated_rel_l1_distance_negative = 0.0 - - def _compute_l1_and_decide( - self, - modulated_inp: torch.Tensor, - coefficients: list[float], - teacache_thresh: float, - ) -> tuple[float, bool]: - """ - Compute L1 distance and decide whether to calculate or use cache. - - Args: - modulated_inp: Current timestep's modulated input. 
- coefficients: Polynomial coefficients for L1 rescaling. - teacache_thresh: Threshold for cache decision. - - Returns: - Tuple of (new_accumulated_distance, should_calc). - """ - prev_modulated_inp = ( - self.previous_modulated_input_negative - if self.is_cfg_negative - else self.previous_modulated_input - ) - # Defensive check: if previous input is not set, force calculation - if prev_modulated_inp is None: - return 0.0, True - # Compute relative L1 distance - diff = modulated_inp - prev_modulated_inp - rel_l1 = (diff.abs().mean() / prev_modulated_inp.abs().mean()).cpu().item() - - # Apply polynomial rescaling - rescale_func = np.poly1d(coefficients) - - accumulated_rel_l1_distance = ( - self.accumulated_rel_l1_distance_negative - if self.is_cfg_negative - else self.accumulated_rel_l1_distance - ) - accumulated_rel_l1_distance = accumulated_rel_l1_distance + rescale_func(rel_l1) - - if accumulated_rel_l1_distance >= teacache_thresh: - # Threshold exceeded: force compute and reset accumulator - return 0.0, True - # Cache hit: keep accumulated distance - return accumulated_rel_l1_distance, False +@dataclass +class TeaCacheContext: + """Per-step snapshot for TeaCache decisions. - def _compute_teacache_decision( - self, - modulated_inp: torch.Tensor, - is_boundary_step: bool, - coefficients: list[float], - teacache_thresh: float, - ) -> bool: - """ - Compute cache decision for TeaCache. + cnt is the forward-call index: timestep * 2 + cfg_offset when CFG is on, + so min_cnt/max_cnt boundary checks are scaled accordingly. + """ - Args: - modulated_inp: Current timestep's modulated input. - is_boundary_step: True for boundary timesteps that always compute. - coefficients: Polynomial coefficients for L1 rescaling. - teacache_thresh: Threshold for cache decision. + cnt: int + num_inference_steps: int + do_cfg: bool + is_cfg_negative: bool + params: "TeaCacheParams|WanTeaCacheParams" - Returns: - True if forward computation is needed, False to use cache. 
- """ - if not self.enable_teacache: - return True - if is_boundary_step: - new_accum, should_calc = 0.0, True - else: - new_accum, should_calc = self._compute_l1_and_decide( - modulated_inp=modulated_inp, - coefficients=coefficients, - teacache_thresh=teacache_thresh, - ) +class TeaCacheStrategy(DiffusionCache): + """TeaCache caching strategy. - # Advance baseline and accumulator for the active branch - if not self.is_cfg_negative: - self.previous_modulated_input = modulated_inp.clone() - self.accumulated_rel_l1_distance = new_accum - elif self._supports_cfg_cache: - self.previous_modulated_input_negative = modulated_inp.clone() - self.accumulated_rel_l1_distance_negative = new_accum + Constructed by CachableDiT.init_cache() once per generation when + teacache is selected. Owns both CFG-branch states. + """ - return should_calc + def __init__(self, supports_cfg_cache: bool) -> None: + self.state = TeaCacheState() + self.state_neg = TeaCacheState() if supports_cfg_cache else None - def _get_teacache_context(self) -> TeaCacheContext | None: - """ - Check TeaCache preconditions and extract common context. + def reset(self) -> None: + assert isinstance(self.state, TeaCacheState) + self.state.reset() + if self.state_neg is not None: + self.state_neg.reset() - Returns: - TeaCacheContext if TeaCache is enabled and properly configured, - None if should skip TeaCache logic entirely. 
- """ + def get_context(self, cnt: int) -> TeaCacheContext | None: from sglang.multimodal_gen.runtime.managers.forward_context import ( get_forward_context, ) forward_context = get_forward_context() - forward_batch = forward_context.forward_batch - - # Early return checks - if ( - forward_batch is None - or not forward_batch.enable_teacache - or forward_batch.teacache_params is None - ): + fb = forward_context.forward_batch + if fb is None: return None - teacache_params = forward_batch.teacache_params + steps = fb.num_inference_steps + do_cfg = fb.do_classifier_free_guidance + is_neg = fb.is_cfg_negative + params = getattr(fb.sampling_params, "teacache_params", None) + assert ( + params is not None + ), "TeaCacheStrategy requires teacache_params in sampling_params" + + return TeaCacheContext(cnt, steps, do_cfg, is_neg, params) + + def should_skip(self, ctx: TeaCacheContext, **kwargs) -> bool: + state = ( + self.state_neg + if (ctx.is_cfg_negative and self.state_neg is not None) + else self.state + ) + assert isinstance(state, TeaCacheState) and isinstance(ctx, TeaCacheContext) - # Extract common values - current_timestep = forward_context.current_timestep - num_inference_steps = forward_batch.num_inference_steps - do_cfg = forward_batch.do_classifier_free_guidance - is_cfg_negative = forward_batch.is_cfg_negative + # Cannot skip on boundary steps + min_cnt = ( + ctx.params.skip_start_step * 2 if ctx.do_cfg else ctx.params.skip_start_step + ) + max_cnt = ( + (ctx.num_inference_steps - ctx.params.skip_end_step) * 2 + if ctx.do_cfg + else (ctx.num_inference_steps - ctx.params.skip_end_step) + ) + if ctx.cnt < min_cnt or ctx.cnt >= max_cnt: + state.reset() + return False - # Reset at first timestep - if current_timestep == 0 and not self.is_cfg_negative: - self.reset_teacache_state() + modulated_inp = ( + kwargs["timestep_proj"] if ctx.params.use_ret_steps else kwargs["temb"] + ) - return TeaCacheContext( - current_timestep=current_timestep, - 
num_inference_steps=num_inference_steps, - do_cfg=do_cfg, - is_cfg_negative=is_cfg_negative, - teacache_thresh=teacache_params.teacache_thresh, - coefficients=teacache_params.coefficients, - teacache_params=teacache_params, + # Cannot skip when have no previous input + if state.previous_modulated_input is None: + state.previous_modulated_input = modulated_inp.clone() + return False + + # Accumulate relative L1 distance + diff = modulated_inp - state.previous_modulated_input + rel_l1 = ( + (diff.abs().mean() / state.previous_modulated_input.abs().mean()) + .cpu() + .item() ) + accumulated = state.accumulated_rel_l1_distance + np.poly1d( + ctx.params.coefficients + )(rel_l1) - def maybe_cache_states( - self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor - ) -> None: - """Cache states for later retrieval. Override in subclass if needed.""" - pass + state.accumulated_rel_l1_distance = accumulated + state.previous_modulated_input = modulated_inp.clone() - def should_skip_forward_for_cached_states(self, **kwargs: dict[str, Any]) -> bool: - """Check if forward can be skipped using cached states.""" + if accumulated < ctx.params.teacache_thresh: + return True + state.reset() return False - - def retrieve_cached_states(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Retrieve cached states. 
Must be implemented by subclass.""" - raise NotImplementedError("retrieve_cached_states is not implemented") diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py b/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py index 2f71ad1d1fd5..bb69248cbaee 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py @@ -216,7 +216,6 @@ def generate( sampling_params=sampling_params, ) requests.append(req) - results: list[GenerationResult] = [] total_start_time = time.perf_counter() diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py b/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py index 1162084515f5..d1f28c618aa5 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py @@ -131,6 +131,7 @@ async def generations( true_cfg_scale=request.true_cfg_scale, negative_prompt=request.negative_prompt, enable_teacache=request.enable_teacache, + enable_magcache=request.enable_magcache, output_compression=request.output_compression, output_quality=request.output_quality, enable_upscaling=request.enable_upscaling, @@ -210,6 +211,8 @@ async def edits( output_quality: Optional[str] = Form("default"), output_compression: Optional[int] = Form(None), enable_teacache: Optional[bool] = Form(False), + enable_magcache: Optional[bool] = Form(False), + calibrate_cache: Optional[bool] = Form(False), enable_upscaling: Optional[bool] = Form(False), upscaling_model_path: Optional[str] = Form(None), upscaling_scale: Optional[int] = Form(4), @@ -265,6 +268,7 @@ async def edits( true_cfg_scale=true_cfg_scale, num_inference_steps=num_inference_steps, enable_teacache=enable_teacache, + enable_magcache=enable_magcache, num_frames=num_frames, output_compression=output_compression, output_quality=output_quality, 
diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py b/python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py index 425a17f7b3ba..35451d00e7e7 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py @@ -48,6 +48,8 @@ class ImageGenerationsRequest(BaseModel): output_quality: Optional[str] = "default" output_compression: Optional[int] = None enable_teacache: Optional[bool] = False + enable_magcache: Optional[bool] = False + calibrate_cache: Optional[bool] = False # Upscaling enable_upscaling: Optional[bool] = False upscaling_model_path: Optional[str] = None @@ -96,6 +98,8 @@ class VideoGenerationsRequest(BaseModel): ) negative_prompt: Optional[str] = None enable_teacache: Optional[bool] = False + enable_magcache: Optional[bool] = False + calibrate_cache: Optional[bool] = False # Frame interpolation enable_frame_interpolation: Optional[bool] = False frame_interpolation_exp: Optional[int] = 1 # 1=2×, 2=4× diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py b/python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py index d85ce7d40b51..69867b5d2d80 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py @@ -71,6 +71,8 @@ def _build_video_sampling_params(request_id: str, request: VideoGenerationsReque guidance_scale_2=request.guidance_scale_2, negative_prompt=request.negative_prompt, enable_teacache=request.enable_teacache, + enable_magcache=request.enable_magcache, + calibrate_cache=request.calibrate_cache, enable_frame_interpolation=request.enable_frame_interpolation, frame_interpolation_exp=request.frame_interpolation_exp, frame_interpolation_scale=request.frame_interpolation_scale, @@ -180,6 +182,8 @@ async def create_video( guidance_scale: Optional[float] = Form(None), num_inference_steps: 
Optional[int] = Form(None), enable_teacache: Optional[bool] = Form(False), + enable_magcache: Optional[bool] = Form(False), + calibrate_cache: Optional[bool] = Form(False), enable_frame_interpolation: Optional[bool] = Form(False), frame_interpolation_exp: Optional[int] = Form(1), frame_interpolation_scale: Optional[float] = Form(1.0), @@ -258,6 +262,8 @@ async def create_video( negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, enable_teacache=enable_teacache, + enable_magcache=enable_magcache, + calibrate_cache=calibrate_cache, enable_frame_interpolation=enable_frame_interpolation, frame_interpolation_exp=frame_interpolation_exp, frame_interpolation_scale=frame_interpolation_scale, diff --git a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py index f170809a738e..e0f59f43f6ac 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weights_updater.py +++ b/python/sglang/multimodal_gen/runtime/loader/weights_updater.py @@ -46,13 +46,13 @@ import torch from torch.distributed.tensor import DTensor, distribute_tensor -from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheMixin from sglang.multimodal_gen.runtime.loader.utils import ( _list_safetensors_files, ) from sglang.multimodal_gen.runtime.loader.weight_utils import ( safetensors_weights_iterator, ) +from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT from sglang.multimodal_gen.runtime.pipelines.diffusers_pipeline import DiffusersPipeline from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin @@ -213,8 +213,8 @@ def update_weights_from_disk( if success and flush_cache: for _, module in modules_to_update: - if isinstance(module, TeaCacheMixin): - module.reset_teacache_state() + if isinstance(module, CachableDiT) and module.cache is not None: + module.cache.reset() 
logger.info(message) return success, message diff --git a/python/sglang/multimodal_gen/runtime/models/dits/base.py b/python/sglang/multimodal_gen/runtime/models/dits/base.py index 1048a5196630..dc3a79ebbb6c 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/base.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/base.py @@ -8,16 +8,10 @@ from torch import nn from sglang.multimodal_gen.configs.models import DiTConfig - -# NOTE: TeaCacheContext and TeaCacheMixin have been moved to -# sglang.multimodal_gen.runtime.cache.teacache -# For backwards compatibility, re-export from the new location -from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheContext # noqa: F401 -from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheMixin +from sglang.multimodal_gen.runtime.cache.base import DiffusionCache from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum -# TODO class BaseDiT(nn.Module, ABC): _fsdp_shard_conditions: list = [] _compile_conditions: list = [] @@ -83,30 +77,83 @@ def device(self) -> torch.device: return next(self.parameters()).device -class CachableDiT(TeaCacheMixin, BaseDiT): +_CFG_SUPPORTED_PREFIXES: set[str] = {"wan", "hunyuan", "zimage"} + + +class CachableDiT(BaseDiT): """ - An intermediate base class that adds TeaCache optimization functionality to DiT models. + BaseDiT subclass that adds pluggable step-caching support (TeaCache / MagCache). - Inherits TeaCacheMixin for cache logic and BaseDiT for core DiT functionality. + self.cache is None until init_cache() is called on the first forward pass, + at which point the correct DiffusionCache strategy is constructed from the + runtime forward_batch context. 
""" - # These are required class attributes that should be overridden by concrete implementations _fsdp_shard_conditions = [] param_names_mapping = {} reverse_param_names_mapping = {} lora_param_names_mapping: dict = {} - # Ensure these instance attributes are properly defined in subclasses hidden_size: int num_attention_heads: int num_channels_latents: int - # always supports torch_sdpa _supported_attention_backends: set[AttentionBackendEnum] = ( DiTConfig()._supported_attention_backends ) def __init__(self, config: DiTConfig, **kwargs) -> None: super().__init__(config, **kwargs) - self._init_teacache_state() + self.cache: DiffusionCache | None = None + self.calibrate_cache: bool = False + self.cnt: int = 0 + + def init_cache(self) -> None: + """Construct the cache strategy from the current forward_batch context. + + Called lazily on the first forward pass because sampling params + (teacache_params, magcache_params, num_steps) are only available then. + """ + from sglang.multimodal_gen.runtime.cache.magcache import MagCacheStrategy + from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheStrategy + from sglang.multimodal_gen.runtime.managers.forward_context import ( + get_forward_context, + ) + + fb = get_forward_context().forward_batch + if fb is None: + return + + if fb.enable_teacache and fb.enable_magcache: + raise ValueError("TeaCache and MagCache cannot both be enabled") + + self.calibrate_cache = fb.calibrate_cache + supports_cfg = self.config.prefix.lower() in _CFG_SUPPORTED_PREFIXES + + if fb.enable_teacache: + self.cache = TeaCacheStrategy(supports_cfg_cache=supports_cfg) + elif fb.enable_magcache: + self.cache = MagCacheStrategy(supports_cfg_cache=supports_cfg) + else: + self.cache = None + + # todo: only used in hunyuanvideo.py; refactor and remove this method + def maybe_cache_states( + self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor, ctx + ) -> None: + if self.cache is not None: + 
self.cache.maybe_cache(hidden_states, original_hidden_states, ctx) + + # todo: only used in hunyuanvideo.py; refactor and remove this method + def retrieve_cached_states(self, hidden_states: torch.Tensor, ctx) -> torch.Tensor: + return self.cache.retrieve(hidden_states, ctx) + + # todo: only used in hunyuanvideo.py; refactor and remove this method + def should_skip_forward_for_cached_states(self, **kwargs) -> bool: + if self.cache is None or self.calibrate_cache: + return False + ctx = self.cache.get_context(self.cnt) + if ctx is None: + return False + return self.cache.should_skip(ctx, **kwargs) @classmethod def get_nunchaku_quant_rules(cls) -> dict[str, dict[str, Any]]: diff --git a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py index 4a2798a4a934..4e417d02dd29 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -10,7 +10,6 @@ import torch.nn as nn from sglang.multimodal_gen.configs.models.dits import WanVideoConfig -from sglang.multimodal_gen.configs.sample.wan import WanTeaCacheParams from sglang.multimodal_gen.runtime.distributed import ( divide, get_sp_group, @@ -901,7 +900,6 @@ def __init__( ) # For type checking - self.cnt = 0 self.__post_init__() @@ -956,16 +954,26 @@ def forward( guidance=None, **kwargs, ) -> torch.Tensor: - forward_batch = get_forward_context().forward_batch + + forward_context = get_forward_context() + forward_batch = forward_context.forward_batch + current_timestep = forward_context.current_timestep + is_cfg_negative = ( + forward_batch.is_cfg_negative if forward_batch is not None else False + ) + + if self.cache is None: + self.init_cache() + if self.cache is not None and current_timestep == 0 and not is_cfg_negative: + self.cache.reset() + self.cnt = 0 + if forward_batch is not None: sequence_shard_enabled = ( forward_batch.enable_sequence_shard and self.sp_size > 1 ) 
else: sequence_shard_enabled = False - self.enable_teacache = ( - forward_batch is not None and forward_batch.enable_teacache - ) orig_dtype = hidden_states.dtype if not isinstance(encoder_hidden_states, torch.Tensor): @@ -1097,24 +1105,29 @@ def forward( # 4. Transformer blocks # if caching is enabled, we might be able to skip the forward pass - should_skip_forward = self.should_skip_forward_for_cached_states( - timestep_proj=timestep_proj, temb=temb + ctx = self.cache.get_context(self.cnt) if self.cache is not None else None + should_skip_forward = ( + ctx is not None + and not self.calibrate_cache + and self.cache.should_skip(ctx, timestep_proj=timestep_proj, temb=temb) ) if should_skip_forward: - hidden_states = self.retrieve_cached_states(hidden_states) + hidden_states = self.cache.retrieve(hidden_states, ctx) else: - # if teacache is enabled, we need to cache the original hidden states - if self.enable_teacache: + if self.cache is not None: original_hidden_states = hidden_states.clone() for block in self.blocks: hidden_states = block( hidden_states, encoder_hidden_states, timestep_proj, freqs_cis ) - # if teacache is enabled, we need to cache the original hidden states - if self.enable_teacache: - self.maybe_cache_states(hidden_states, original_hidden_states) + + if self.cache is not None: + if self.calibrate_cache: + self.cache.calibrate(hidden_states, original_hidden_states, ctx) + else: + self.cache.maybe_cache(hidden_states, original_hidden_states, ctx) self.cnt += 1 if sequence_shard_enabled: @@ -1153,65 +1166,5 @@ def forward( return output - def maybe_cache_states( - self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor - ) -> None: - """Cache residual with CFG positive/negative separation.""" - residual = hidden_states.squeeze(0) - original_hidden_states - if not self.is_cfg_negative: - self.previous_residual = residual - else: - self.previous_residual_negative = residual - - def should_skip_forward_for_cached_states(self, **kwargs) 
-> bool: - if not self.enable_teacache: - return False - ctx = self._get_teacache_context() - if ctx is None: - return False - - # Wan uses WanTeaCacheParams with additional fields - teacache_params = ctx.teacache_params - assert isinstance( - teacache_params, WanTeaCacheParams - ), "teacache_params is not a WanTeaCacheParams" - - # Initialize Wan-specific parameters - use_ret_steps = teacache_params.use_ret_steps - cutoff_steps = teacache_params.get_cutoff_steps(ctx.num_inference_steps) - ret_steps = teacache_params.ret_steps - - # Adjust ret_steps and cutoff_steps for non-CFG mode - # (WanTeaCacheParams uses *2 factor assuming CFG) - if not ctx.do_cfg: - ret_steps = ret_steps // 2 - cutoff_steps = cutoff_steps // 2 - - timestep_proj = kwargs["timestep_proj"] - temb = kwargs["temb"] - modulated_inp = timestep_proj if use_ret_steps else temb - - self.is_cfg_negative = ctx.is_cfg_negative - - # Wan uses ret_steps/cutoff_steps for boundary detection - is_boundary_step = self.cnt < ret_steps or self.cnt >= cutoff_steps - - # Use shared helper to compute cache decision - should_calc = self._compute_teacache_decision( - modulated_inp=modulated_inp, - is_boundary_step=is_boundary_step, - coefficients=ctx.coefficients, - teacache_thresh=ctx.teacache_thresh, - ) - - return not should_calc - - def retrieve_cached_states(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Retrieve cached residual with CFG positive/negative separation.""" - if not self.is_cfg_negative: - return hidden_states + self.previous_residual - else: - return hidden_states + self.previous_residual_negative - EntryClass = WanTransformer3DModel