From 5a9c66c8ff31c590f1f2c7e69e4cbdc0d28ea05f Mon Sep 17 00:00:00 2001 From: Lancer Date: Mon, 9 Feb 2026 02:45:10 +0800 Subject: [PATCH 01/11] [Feat] support MagCache Signed-off-by: Lancer --- vllm_omni/diffusion/cache/__init__.py | 7 +- .../diffusion/cache/magcache/__init__.py | 31 ++ vllm_omni/diffusion/cache/magcache/backend.py | 190 ++++++++ vllm_omni/diffusion/cache/magcache/config.py | 103 ++++ .../diffusion/cache/magcache/examples.py | 279 +++++++++++ vllm_omni/diffusion/cache/magcache/hook.py | 451 ++++++++++++++++++ .../diffusion/cache/magcache/strategy.py | 440 +++++++++++++++++ vllm_omni/diffusion/cache/selector.py | 22 +- vllm_omni/diffusion/data.py | 18 +- vllm_omni/diffusion/hooks/base.py | 90 +++- vllm_omni/entrypoints/cli/serve.py | 7 +- 11 files changed, 1618 insertions(+), 20 deletions(-) create mode 100644 vllm_omni/diffusion/cache/magcache/__init__.py create mode 100644 vllm_omni/diffusion/cache/magcache/backend.py create mode 100644 vllm_omni/diffusion/cache/magcache/config.py create mode 100644 vllm_omni/diffusion/cache/magcache/examples.py create mode 100644 vllm_omni/diffusion/cache/magcache/hook.py create mode 100644 vllm_omni/diffusion/cache/magcache/strategy.py diff --git a/vllm_omni/diffusion/cache/__init__.py b/vllm_omni/diffusion/cache/__init__.py index a5968f612a4..dc544ea73e9 100644 --- a/vllm_omni/diffusion/cache/__init__.py +++ b/vllm_omni/diffusion/cache/__init__.py @@ -5,23 +5,24 @@ This module provides a unified cache backend system for different caching strategies: - TeaCache: Timestep Embedding Aware Cache for adaptive transformer caching +- MagCache: Magnitude-based Cache for adaptive transformer caching - cache-dit: DBCache, SCM, and TaylorSeer caching strategies Cache backends are instantiated directly via their constructors and configured via OmniDiffusionConfig. 
""" from vllm_omni.diffusion.cache.base import CacheBackend +from vllm_omni.diffusion.cache.selector import get_cache_backend from vllm_omni.diffusion.cache.teacache import ( CacheContext, TeaCacheConfig, apply_teacache_hook, ) -from vllm_omni.diffusion.cache.teacache.backend import TeaCacheBackend __all__ = [ "CacheBackend", - "TeaCacheConfig", "CacheContext", - "TeaCacheBackend", + "get_cache_backend", + "TeaCacheConfig", "apply_teacache_hook", ] diff --git a/vllm_omni/diffusion/cache/magcache/__init__.py b/vllm_omni/diffusion/cache/magcache/__init__.py new file mode 100644 index 00000000000..f5ab42812d7 --- /dev/null +++ b/vllm_omni/diffusion/cache/magcache/__init__.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.diffusion.cache.magcache.backend import CUSTOM_MAG_CACHE_ENABLERS +from vllm_omni.diffusion.cache.magcache.config import FLUX_MAG_RATIOS, MagCacheConfig +from vllm_omni.diffusion.cache.magcache.hook import ( + MagCacheBlockHook, + MagCacheHeadHook, + MagCacheState, + apply_mag_cache_hook, +) +from vllm_omni.diffusion.cache.magcache.strategy import ( + MagCacheStrategy, + MagCacheStrategyRegistry, + MagCacheContext, + FluxMagCacheStrategy, +) + +__all__ = [ + "CUSTOM_MAG_CACHE_ENABLERS", + "FLUX_MAG_RATIOS", + "MagCacheBlockHook", + "MagCacheConfig", + "MagCacheContext", + "MagCacheHeadHook", + "MagCacheState", + "MagCacheStrategy", + "MagCacheStrategyRegistry", + "FluxMagCacheStrategy", + "apply_mag_cache_hook", +] diff --git a/vllm_omni/diffusion/cache/magcache/backend.py b/vllm_omni/diffusion/cache/magcache/backend.py new file mode 100644 index 00000000000..e711d90ea50 --- /dev/null +++ b/vllm_omni/diffusion/cache/magcache/backend.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +MagCache backend implementation. 
+ +This module provides the MagCache backend that implements the CacheBackend +interface using the hooks-based MagCache system. +""" + +from typing import Any + +import torch +from vllm.logger import init_logger + +from vllm_omni.diffusion.cache.base import CacheBackend +from vllm_omni.diffusion.cache.magcache.config import MagCacheConfig +from vllm_omni.diffusion.cache.magcache.hook import ( + MagCacheState, + apply_mag_cache_hook, +) +from vllm_omni.diffusion.hooks.base import StateManager + +logger = init_logger(__name__) + +CUSTOM_MAG_CACHE_ENABLERS = {} + + +def _register_pipeline_magcache( + pipeline: Any, + magcache_config: MagCacheConfig, +) -> None: + """Apply MagCache hooks to transformer using pre-built MagCacheConfig. + + Args: + pipeline: Diffusion pipeline instance. + magcache_config: Pre-configured MagCacheConfig with all parameters set. + """ + transformer = pipeline.transformer + apply_mag_cache_hook(transformer, magcache_config) + + +class MagCacheBackend(CacheBackend): + """ + MagCache implementation using hooks. + + MagCache (Magnitude-based Cache) is an adaptive caching technique that + speeds up diffusion inference by reusing transformer block computations + based on accumulated magnitude error between timesteps. + + The backend applies MagCache hooks to the transformer which intercept the + forward pass and implement the caching logic transparently. + + Example: + >>> from vllm_omni.diffusion.data import DiffusionCacheConfig + >>> from vllm_omni.diffusion.cache.magcache import MagCacheConfig, FLUX_MAG_RATIOS + >>> cache_config = DiffusionCacheConfig( + ... mag_ratios=FLUX_MAG_RATIOS, + ... num_inference_steps=28, + ... threshold=0.06, + ... max_skip_steps=3, + ... retention_ratio=0.2, + ... ) + >>> backend = MagCacheBackend(cache_config) + >>> backend.enable(pipeline) + >>> backend.refresh(pipeline, num_inference_steps=50) + """ + + def enable(self, pipeline: Any) -> None: + """Enable MagCache on transformer using hooks. 
+ + This creates a MagCacheConfig from the backend's DiffusionCacheConfig + and applies the MagCache hook to the transformer. + + Args: + pipeline: Diffusion pipeline instance. Extracts transformer and transformer_type: + - transformer: pipeline.transformer + - transformer_type: pipeline.transformer.__class__.__name__ + """ + from vllm_omni.diffusion.cache.magcache.strategy import ( + MagCacheStrategyRegistry, + ) + + pipeline_type = pipeline.__class__.__name__ + transformer = pipeline.transformer + transformer_type = transformer.__class__.__name__ + + num_inference_steps = self.config.num_inference_steps + if num_inference_steps is None: + num_inference_steps = 28 + + mag_ratios = self.config.mag_ratios + if mag_ratios is None: + strategy = MagCacheStrategyRegistry.get_if_exists(transformer_type) + if strategy is not None: + mag_ratios = strategy.mag_ratios + logger.info( + f"MagCache: Using default mag_ratios from strategy '{transformer_type}'" + ) + + if mag_ratios is None and not self.config.calibrate: + raise ValueError( + f"mag_ratios must be provided for MagCache. " + f"For {transformer_type}, you need to provide mag_ratios or run in calibrate mode." 
+ ) + + magcache_config = MagCacheConfig( + transformer_type=transformer_type, + threshold=self.config.threshold, + max_skip_steps=self.config.max_skip_steps, + retention_ratio=self.config.retention_ratio, + num_inference_steps=num_inference_steps, + calibrate=self.config.calibrate, + mag_ratios=mag_ratios if not self.config.calibrate else None, + ) + + self._registered = False + self._magcache_config = magcache_config + self._transformer_id = id(transformer) + + if pipeline_type in CUSTOM_MAG_CACHE_ENABLERS: + logger.info(f"Using custom MagCache enabler for model: {pipeline_type}") + CUSTOM_MAG_CACHE_ENABLERS[pipeline_type](pipeline, magcache_config) + else: + _register_pipeline_magcache(pipeline, magcache_config) + + self._registered = True + self.enabled = True + + def refresh(self, pipeline: Any, num_inference_steps: int) -> None: + """Refresh MagCache state for new generation. + + Clears all cached residuals and resets counters/accumulators. + Should be called before each generation to ensure clean state. + + Args: + pipeline: Diffusion pipeline instance. Extracts transformer via pipeline.transformer. + num_inference_steps: Number of inference steps for the current generation. + May be used for cache context updates. 
+ """ + from diffusers.hooks._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS + + transformer = pipeline.transformer + transformer_type = transformer.__class__.__name__ + current_transformer_id = id(transformer) + + needs_re_register = False + + if self._registered and hasattr(self, '_transformer_id'): + if current_transformer_id != self._transformer_id: + logger.warning( + f"Transformer was replaced (id changed from {self._transformer_id} " + f"to {current_transformer_id}), re-registering hooks" + ) + needs_re_register = True + + if not self._registered or needs_re_register: + self.enable(pipeline) + return + + state_manager = StateManager(MagCacheState, (), {}) + + blocks_with_hooks = [] + + for name, submodule in transformer.named_children(): + if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList): + continue + for index, block in enumerate(submodule): + registry = getattr(block, "_hook_registry", None) + if registry is not None and len(registry._hooks) > 0: + blocks_with_hooks.append((f"{name}.{index}", block, registry)) + + if not blocks_with_hooks: + logger.warning("No hooks found on transformer blocks, re-registering") + _register_pipeline_magcache(pipeline, self._magcache_config) + self._transformer_id = current_transformer_id + else: + for name, block, registry in blocks_with_hooks: + if hasattr(block, "do_true_cfg"): + delattr(block, "do_true_cfg") + + state_manager.reset() + + def is_enabled(self) -> bool: + """Check if MagCache is enabled. + + Returns: + True if enabled, False otherwise. 
+ """ + return self.enabled diff --git a/vllm_omni/diffusion/cache/magcache/config.py b/vllm_omni/diffusion/cache/magcache/config.py new file mode 100644 index 00000000000..227da641a2e --- /dev/null +++ b/vllm_omni/diffusion/cache/magcache/config.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import Optional, Union + +import torch + + +@dataclass +class MagCacheConfig: + """ + Configuration for MagCache applied to transformer models. + + MagCache (Magnitude-based Cache) is an adaptive caching technique that speeds up + diffusion model inference by reusing transformer block computations based on + magnitude ratios between consecutive timesteps. + + Reference: https://github.com/Zehong-Ma/MagCache + + Args: + threshold: The threshold for the accumulated error. If the accumulated error + is below this threshold, the block computation is skipped. A higher threshold + allows for more aggressive skipping (faster) but may degrade quality. + Default: 0.06 + max_skip_steps: The maximum number of consecutive steps that can be skipped (K). + Default: 3 + retention_ratio: The fraction of initial steps during which skipping is disabled + to ensure stability. For example, if num_inference_steps is 28 and + retention_ratio is 0.2, the first 6 steps will never be skipped. + Default: 0.2 + num_inference_steps: The number of inference steps used in the pipeline. + This is required to interpolate mag_ratios correctly. + Default: 28 + mag_ratios: The pre-computed magnitude ratios for the model. These are + checkpoint-dependent. If not provided, you must set calibrate=True to + calculate them for your specific model. For Flux models, you can use + FLUX_MAG_RATIOS. + Default: None + calibrate: If True, enables calibration mode. In this mode, no blocks are skipped. 
+ Instead, the hook calculates the magnitude ratios for the current run and logs + them at the end. Use this to obtain mag_ratios for new models or schedulers. + Default: False + transformer_type: Transformer class name for logging and identification. + Auto-detected from pipeline.transformer.__class__.__name__ in backend. + Default: "FluxTransformer2DModel" + """ + + threshold: float = 0.06 + max_skip_steps: int = 3 + retention_ratio: float = 0.2 + num_inference_steps: int = 28 + mag_ratios: Optional[Union[torch.Tensor, list[float]]] = None + calibrate: bool = False + transformer_type: str = "FluxTransformer2DModel" + + def __post_init__(self) -> None: + """Validate and set default coefficients.""" + if self.threshold <= 0: + raise ValueError(f"threshold must be positive, got {self.threshold}") + + if self.max_skip_steps <= 0: + raise ValueError(f"max_skip_steps must be positive, got {self.max_skip_steps}") + + if not 0 < self.retention_ratio < 1: + raise ValueError(f"retention_ratio must be in (0, 1), got {self.retention_ratio}") + + if self.num_inference_steps is None: + raise ValueError( + "num_inference_steps must be provided for MagCache. " + "This is required to determine retention steps and interpolate mag_ratios. " + "For Flux models, use num_inference_steps=28." + ) + + if self.num_inference_steps <= 0: + raise ValueError(f"num_inference_steps must be positive, got {self.num_inference_steps}") + + if not self.calibrate and self.mag_ratios is None: + raise ValueError( + "mag_ratios must be provided for MagCache inference because these ratios " + "are model-dependent. To get them for your model:\n" + "1. Initialize MagCacheConfig(calibrate=True, ...)\n" + "2. Run inference on your model once.\n" + "3. Copy the printed ratios array and pass it to mag_ratios in the config.\n" + "For Flux models, you can import FLUX_MAG_RATIOS from vllm_omni.diffusion.cache.magcache.strategy." 
+ ) + + if not self.calibrate and self.mag_ratios is not None: + if not torch.is_tensor(self.mag_ratios): + self.mag_ratios = torch.tensor(self.mag_ratios) + + +FLUX_MAG_RATIOS = None + + +def get_flux_mag_ratios() -> torch.Tensor: + """Get FLUX_MAG_RATIOS from FluxMagCacheStrategy, importing only when needed.""" + global FLUX_MAG_RATIOS + if FLUX_MAG_RATIOS is None: + from vllm_omni.diffusion.cache.magcache.strategy import FluxMagCacheStrategy + + FLUX_MAG_RATIOS = FluxMagCacheStrategy.FLUX_MAG_RATIOS + return FLUX_MAG_RATIOS diff --git a/vllm_omni/diffusion/cache/magcache/examples.py b/vllm_omni/diffusion/cache/magcache/examples.py new file mode 100644 index 00000000000..84eb05869db --- /dev/null +++ b/vllm_omni/diffusion/cache/magcache/examples.py @@ -0,0 +1,279 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +MagCache Integration Template for New Models. + +This module provides a complete template for integrating new diffusion models +into MagCache. Copy this file and modify according to your model's architecture. + +Integration Steps: + 1. Analyze model architecture (block structure, I/O format) + 2. Create Strategy class (inherit from MagCacheStrategy) + 3. Implement required methods + 4. Register the strategy + 5. 
Test and calibrate mag_ratios +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Callable + +import torch +from diffusers.hooks._helpers import TransformerBlockRegistry, TransformerBlockMetadata + +from vllm_omni.diffusion.cache.magcache.strategy import ( + MagCacheContext, + MagCacheStrategy, + MagCacheStrategyRegistry, +) + + +def register_transformer_block( + model_class, + return_hidden_states_index: int = 1, + return_encoder_hidden_states_index: int = 0, +) -> None: + """Register a transformer block class with the TransformerBlockRegistry.""" + try: + TransformerBlockRegistry.get(model_class) + except ValueError: + TransformerBlockRegistry.register( + model_class=model_class, + metadata=TransformerBlockMetadata( + return_hidden_states_index=return_hidden_states_index, + return_encoder_hidden_states_index=return_encoder_hidden_states_index, + ), + ) + + +# ============================================================================= +# EXAMPLE: SD3 (Stable Diffusion 3) Integration +# ============================================================================= + +class SD3MagCacheStrategy(MagCacheStrategy): + """ + MagCache strategy for SD3 (Stable Diffusion 3). + + SD3 Architecture Analysis: + - Single stream: hidden_states only (no encoder_hidden_states separation) + - Block structure: transformer_blocks (nn.ModuleList) + - Output: tuple of (hidden_states,) + - Residual: output - input (simple subtraction) + + Integration Steps: + 1. Identify block name: "transformer_blocks" + 2. Determine I/O indices: return_hidden_states_index=0 + 3. 
Define residual: output - input + """ + + @property + def transformer_type(self) -> str: + return "SD3Transformer2DModel" + + @property + def mag_ratios(self) -> torch.Tensor: + """Return default mag_ratios for SD3 model.""" + return self.SD3_MAG_RATIOS + + SD3_MAG_RATIOS = torch.tensor( + [ + 1.0, + 1.15, + 1.10, + 1.05, + 1.02, + 1.00, + 0.98, + 0.95, + 0.92, + 0.90, + 0.88, + 0.85, + 0.82, + 0.80, + 0.78, + 0.75, + 0.72, + 0.70, + 0.68, + 0.65, + 0.62, + 0.60, + 0.58, + 0.55, + 0.52, + 0.50, + 0.48, + 0.45, + ] + ) + + @staticmethod + def nearest_interp(src_array: torch.Tensor, target_length: int) -> torch.Tensor: + """Interpolate mag_ratios to target length using nearest neighbor.""" + src_length = len(src_array) + if target_length == 1: + return src_array[-1:] + + scale = (src_length - 1) / (target_length - 1) + grid = torch.arange(target_length, device=src_array.device, dtype=torch.float32) + mapped_indices = torch.round(grid * scale).long() + + return src_array[mapped_indices] + + @staticmethod + def register_blocks() -> None: + """Register SD3 transformer blocks. + + SD3 uses a simple transformer block structure. 
+ """ + try: + from diffusers.models.transformers.sd_transformer_2d import ( + SD3TransformerBlock, + ) + + register_transformer_block( + SD3TransformerBlock, + return_hidden_states_index=0, + ) + except ImportError: + pass + + def create_context( + self, + module: torch.nn.Module, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None, + timestep: torch.Tensor, + guidance: torch.Tensor | None, + **kwargs, + ) -> MagCacheContext: + """Create context for SD3 model. + + Computes the timestep embedding via module.time_embed and packages the + block runners and postprocess callables into a MagCacheContext. + """ + temb = module.time_embed(timestep) + + def run_transformer_blocks(): + # Thread activations through a separate local: assigning to the enclosing + # `hidden_states` here would make that name local to this closure and the + # first read would raise UnboundLocalError. + h = hidden_states + for block in module.transformer_blocks: + h = block(h, temb=temb) + return h + + def run_single_transformer_blocks(h): + return h + + def postprocess(h: torch.Tensor, e: torch.Tensor) -> Any: + return h + + return MagCacheContext( + hidden_states=hidden_states, + encoder_hidden_states=None, + temb=temb, + head_block_input=None, + run_transformer_blocks=run_transformer_blocks, + run_single_transformer_blocks=run_single_transformer_blocks, + postprocess=postprocess, + ) + + def get_head_block_input(self, context: MagCacheContext) -> torch.Tensor: + """Get input to the first transformer block.""" + return context.hidden_states + + def compute_residual( + self, + output: torch.Tensor, + head_input: torch.Tensor, + context: MagCacheContext | None, + ) -> torch.Tensor: + """Compute residual for SD3: output - input.""" + return output - head_input + + def apply_residual( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + ) -> torch.Tensor: + """Apply residual for SD3: hidden_states + residual.""" + return hidden_states + residual + + +# Register SD3 strategy +MagCacheStrategyRegistry.register(SD3MagCacheStrategy()) + + +# ============================================================================= +# TEMPLATE: Copy and modify for your model +# ============================================================================= + +# class 
YourModelMagCacheStrategy(MagCacheStrategy): +# """ +# MagCache strategy for YourModel. +# +# Model Architecture Analysis: +# - [Describe the architecture] +# - Block name: [e.g., "transformer_blocks", "layers", "blocks"] +# - I/O format: [Describe input/output format] +# - Residual: [Describe how to compute residual] +# """ +# +# @property +# def transformer_type(self) -> str: +# return "YourTransformerModel" +# +# YOUR_MAG_RATIOS = torch.tensor([...]) # Replace with your model's ratios +# +# @staticmethod +# def nearest_interp(src_array: torch.Tensor, target_length: int) -> torch.Tensor: +# """Interpolate mag_ratios to target length.""" +# src_length = len(src_array) +# if target_length == 1: +# return src_array[-1:] +# +# scale = (src_length - 1) / (target_length - 1) +# grid = torch.arange(target_length, device=src_array.device, dtype=torch.float32) +# mapped_indices = torch.round(grid * scale).long() +# +# return src_array[mapped_indices] +# +# @staticmethod +# def register_blocks() -> None: +# """Register your model's transformer blocks.""" +# try: +# from your_model import YourTransformerBlock +# +# register_transformer_block( +# YourTransformerBlock, +# return_hidden_states_index=0, # Adjust based on your model's output +# ) +# except ImportError: +# pass +# +# def compute_residual( +# self, +# output: torch.Tensor, +# head_input: torch.Tensor, +# context: MagCacheContext | None, +# ) -> torch.Tensor: +# """Compute residual for your model. +# +# Common patterns: +# - Simple: output - head_input +# - Complex: output - head_input (with shape adjustments) +# """ +# return output - head_input +# +# def apply_residual( +# self, +# hidden_states: torch.Tensor, +# residual: torch.Tensor, +# ) -> torch.Tensor: +# """Apply residual for your model. 
+# +# Common patterns: +# - Simple: hidden_states + residual +# - Complex: hidden_states + residual (with shape adjustments) +# """ +# return hidden_states + residual +# +# +# # Register your strategy +# MagCacheStrategyRegistry.register(YourModelMagCacheStrategy()) diff --git a/vllm_omni/diffusion/cache/magcache/hook.py b/vllm_omni/diffusion/cache/magcache/hook.py new file mode 100644 index 00000000000..5ba4ebf7ac5 --- /dev/null +++ b/vllm_omni/diffusion/cache/magcache/hook.py @@ -0,0 +1,451 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Hook-based MagCache implementation for vLLM-Omni. + +This module implements a diffusers-style hook system for MagCache (Magnitude-based Cache), +providing adaptive caching for diffusion model inference. + +MagCache speeds up inference by skipping transformer block computations when the accumulated +magnitude error is below a threshold, reusing cached residuals instead. + +Based on: https://github.com/Zehong-Ma/MagCache +Reference: diffusers/src/diffusers/hooks/mag_cache.py + +Architecture: +- MagCacheStrategy: Model-specific strategy for preprocessing/postprocessing +- MagCacheState: Per-step state tracking residuals and accumulated error +- MagCacheHeadHook: Decides whether to skip based on accumulated error +- MagCacheBlockHook: Computes and stores residuals at tail block +""" + +from __future__ import annotations + +from typing import Any + +import torch +from diffusers.hooks._helpers import TransformerBlockRegistry +from diffusers.hooks._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS +from diffusers.utils.torch_utils import unwrap_module +from vllm_omni.logger import init_logger + +from vllm_omni.diffusion.cache.magcache.config import MagCacheConfig +from vllm_omni.diffusion.cache.magcache.strategy import MagCacheStrategy, MagCacheStrategyRegistry, FluxMagCacheStrategy +from vllm_omni.diffusion.hooks.base import HookRegistry, ModelHook, StateManager + 
+logger = init_logger(__name__) + +_MAG_CACHE_LEADER_BLOCK_HOOK = "mag_cache_leader_block_hook" +_MAG_CACHE_BLOCK_HOOK = "mag_cache_block_hook" + + +class MagCacheState: + """State management for MagCache hook.""" + + def __init__(self) -> None: + """Initialize empty MagCache state.""" + self.previous_residual: torch.Tensor | None = None + self.head_block_input: torch.Tensor | tuple | None = None + self.should_compute: bool = True + self.accumulated_ratio: float = 1.0 + self.accumulated_err: float = 0.0 + self.accumulated_steps: int = 0 + self.step_index: int = 0 + self.calibration_ratios: list[float] = [] + + def reset(self) -> None: + """Reset all state variables for a new inference run.""" + self.previous_residual = None + self.should_compute = True + self.accumulated_ratio = 1.0 + self.accumulated_err = 0.0 + self.accumulated_steps = 0 + self.step_index = 0 + self.calibration_ratios = [] + + +class MagCacheHeadHook(ModelHook): + """Head block hook for MagCache - decides whether to skip computation.""" + + _HOOK_NAME = "mag_cache_head" + + def __init__(self, state_manager: StateManager, config: MagCacheConfig, strategy: MagCacheStrategy | None = None): + super().__init__() + self.state_manager = state_manager + self.config = config + self._strategy = strategy + self._metadata = None + + def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: + unwrapped_module = unwrap_module(module) + self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__) + self.state_manager.set_context("inference") + return module + + @torch.compiler.disable + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if self.state_manager._current_context is None: + self.state_manager.set_context("inference") + + if hasattr(self._metadata, 'hidden_states_argument_name'): + arg_name = self._metadata.hidden_states_argument_name + else: + arg_name = "hidden_states" + hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs) 
+ + state: MagCacheState = self.state_manager.get_state() + state.head_block_input = hidden_states + + should_compute = True + + if self.config.calibrate: + should_compute = True + else: + current_step = state.step_index + if current_step >= len(self.config.mag_ratios): + current_scale = 1.0 + else: + # mag_ratios is stored as a tensor; convert the element so the accumulators + # on MagCacheState stay plain Python floats as declared. + current_scale = float(self.config.mag_ratios[current_step]) + + retention_step = int(self.config.retention_ratio * self.config.num_inference_steps + 0.5) + + if current_step >= retention_step: + state.accumulated_ratio *= current_scale + state.accumulated_steps += 1 + state.accumulated_err += abs(1.0 - state.accumulated_ratio) + + if ( + state.previous_residual is not None + and state.accumulated_err <= self.config.threshold + and state.accumulated_steps <= self.config.max_skip_steps + ): + should_compute = False + else: + state.accumulated_ratio = 1.0 + state.accumulated_steps = 0 + state.accumulated_err = 0.0 + + state.should_compute = should_compute + + if not should_compute: + res = state.previous_residual + # Default to passing the input through unchanged so `output` is always bound, + # even when none of the residual-application rules below matches. + output = hidden_states + + if isinstance(res, tuple): + res = tuple(r.to(hidden_states.device) for r in res) + + if self._strategy is not None: + original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs( + "encoder_hidden_states", args, kwargs + ) + if original_encoder_hidden_states.device != res[1].device: + original_encoder_hidden_states = original_encoder_hidden_states.to(res[1].device) + h_res, e_res = res + output, enc_output = self._strategy.apply_residual_tuple( + hidden_states, original_encoder_hidden_states, res + ) + ret_list = [None] * 2 + ret_list[self._metadata.return_hidden_states_index] = output + ret_list[self._metadata.return_encoder_hidden_states_index] = enc_output + return self._log_cache_hit(state, output, ret_list) + else: + # The transformer class name lives on the MagCache config, not on the block metadata. + raise RuntimeError( + f"MagCache residual is tuple but no strategy available for {self.config.transformer_type}. " + f"Please register a MagCacheStrategy for this model." 
+ ) + elif res.device != hidden_states.device: + res = res.to(hidden_states.device) + + if self._strategy is not None: + output = self._strategy.apply_residual(hidden_states, res) + elif res.shape == hidden_states.shape: + output = hidden_states + res + elif ( + hidden_states.ndim == 3 + and res.ndim == 3 + and hidden_states.shape[0] == res.shape[0] + and hidden_states.shape[2] == res.shape[2] + ): + diff = hidden_states.shape[1] - res.shape[1] + if diff > 0: + output = hidden_states.clone() + output[:, diff:, :] = output[:, diff:, :] + res + + if self._metadata.return_encoder_hidden_states_index is not None: + original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs( + "encoder_hidden_states", args, kwargs + ) + max_idx = max( + self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index + ) + ret_list = [None] * (max_idx + 1) + ret_list[self._metadata.return_hidden_states_index] = output + ret_list[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states + return self._log_cache_hit(state, output, ret_list) + else: + return self._log_cache_hit(state, output, None) + else: + output = self.fn_ref.original_forward(*args, **kwargs) + return self._log_cache_miss(state, output) + + def _log_cache_hit(self, state: MagCacheState, output, ret): + """Log a skipped step and return the value to hand back to the caller. + + Returns ``ret`` as a tuple when it is given (matching the tuple returned by + MagCacheBlockHook on its skip path), otherwise ``output``. + """ + step = state.step_index + if state.previous_residual is not None: + if isinstance(state.previous_residual, tuple): + residual_shape = tuple(r.shape for r in state.previous_residual) + else: + residual_shape = state.previous_residual.shape + else: + residual_shape = "None" + logger.debug( + f"[MagCache][HEAD] STEP={step}: CACHE_HIT (err={state.accumulated_err:.6f}, " + f"steps_skipped={state.accumulated_steps}, residual_shape={residual_shape})" + ) + return tuple(ret) if ret is not None else output + + def _log_cache_miss(self, state: MagCacheState, output): + step = state.step_index + residual_norm = 0.0 + if state.previous_residual is not None: + if 
isinstance(state.previous_residual, tuple): + residual_norm = sum(float(torch.norm(r).item()) for r in state.previous_residual) + else: + residual_norm = float(torch.norm(state.previous_residual).item()) + logger.debug( + f"[MagCache][HEAD] STEP={step}: CACHE_MISS (err={state.accumulated_err:.6f}, " + f"acc_ratio={state.accumulated_ratio:.6f}, residual_norm={residual_norm:.6f}, threshold={self.config.threshold}, max_skip={self.config.max_skip_steps})" + ) + return output + + def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: + self.state_manager.reset() + return module + + +class MagCacheBlockHook(ModelHook): + """Block hook for MagCache - computes residuals at tail block.""" + + _HOOK_NAME = "mag_cache_block" + + def __init__( + self, + state_manager: StateManager, + is_tail: bool = False, + config: MagCacheConfig | None = None, + strategy: MagCacheStrategy | None = None, + ): + super().__init__() + self.state_manager = state_manager + self.is_tail = is_tail + self.config = config + self._strategy = strategy + self._metadata = None + + def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: + unwrapped_module = unwrap_module(module) + self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__) + return module + + def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: + self.state_manager.reset() + return module + + @torch.compiler.disable + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if self.state_manager._current_context is None: + self.state_manager.set_context("inference") + state: MagCacheState = self.state_manager.get_state() + + if not state.should_compute: + if hasattr(self._metadata, 'hidden_states_argument_name'): + arg_name = self._metadata.hidden_states_argument_name + else: + arg_name = "hidden_states" + hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs) + + if self.is_tail: + self._advance_step(state) + + if 
self._metadata.return_encoder_hidden_states_index is not None: + encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs( + "encoder_hidden_states", args, kwargs + ) + max_idx = max( + self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index + ) + ret_list = [None] * (max_idx + 1) + ret_list[self._metadata.return_hidden_states_index] = hidden_states + ret_list[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states + return tuple(ret_list) + + return hidden_states + + output = self.fn_ref.original_forward(*args, **kwargs) + + if self.is_tail: + if isinstance(output, tuple): + out_hidden = output[self._metadata.return_hidden_states_index] + else: + out_hidden = output + + in_hidden = state.head_block_input + + if in_hidden is None: + return output + + if self._strategy is not None: + residual = self._strategy.compute_residual(output, in_hidden, None) + elif out_hidden.shape == in_hidden.shape: + residual = out_hidden - in_hidden + elif out_hidden.ndim == 3 and in_hidden.ndim == 3 and out_hidden.shape[2] == in_hidden.shape[2]: + diff = in_hidden.shape[1] - out_hidden.shape[1] + if diff == 0: + residual = out_hidden - in_hidden + else: + residual = out_hidden - in_hidden + else: + residual = out_hidden + + if self.config.calibrate: + self._perform_calibration_step(state, residual) + + state.previous_residual = residual + self._advance_step(state) + + self._log_residual_computed(state, residual) + + return output + + def _log_residual_computed(self, state: MagCacheState, residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor]) -> None: + step = state.step_index + if residual is None: + residual_norm = 0.0 + residual_shape = "None" + elif isinstance(residual, tuple): + residual_norm = sum(float(torch.norm(r).item()) for r in residual) + residual_shape = tuple(r.shape for r in residual) + else: + residual_norm = float(torch.norm(residual).item()) + residual_shape = residual.shape + logger.debug( + 
f"[MagCache][TAIL] STEP={step}: RESIDUAL_COMPUTED (norm={residual_norm:.6f}, " + f"shape={residual_shape})" + ) + + def _perform_calibration_step(self, state: MagCacheState, current_residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor]) -> None: + def _get_norm(residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + if isinstance(residual, tuple): + return sum(torch.linalg.norm(r.float(), dim=-1) for r in residual) + return torch.linalg.norm(residual.float(), dim=-1) + + if state.previous_residual is None: + ratio = 1.0 + else: + curr_norm = _get_norm(current_residual) + prev_norm = _get_norm(state.previous_residual) + ratio = (curr_norm / (prev_norm + 1e-8)).mean().item() + + state.calibration_ratios.append(ratio) + + def _advance_step(self, state: MagCacheState) -> None: + state.step_index += 1 + if state.step_index >= self.config.num_inference_steps: + if self.config.calibrate: + logger.info( + f"MagCache calibration complete. mag_ratios={state.calibration_ratios}" + ) + logger.info( + "Copy these values to DiffusionCacheConfig(mag_ratios=...) for production use" + ) + + state.step_index = 0 + state.accumulated_ratio = 1.0 + state.accumulated_steps = 0 + state.accumulated_err = 0.0 + state.previous_residual = None + state.calibration_ratios = [] + + +def apply_mag_cache_hook(module: torch.nn.Module, config: MagCacheConfig) -> None: + """Apply MagCache optimization to a transformer module. + + Args: + module: Transformer model to optimize (e.g., FluxTransformer2DModel) + config: MagCacheConfig specifying caching parameters + """ + HookRegistry.check_if_exists_or_initialize(module) + + transformer_type = config.transformer_type + strategy = MagCacheStrategyRegistry.get_if_exists(transformer_type) + if strategy is None: + logger.warning( + f"MagCache: No strategy found for '{transformer_type}'. " + f"Using default behavior. 
Available strategies: {list(MagCacheStrategyRegistry._registry.keys())}" + ) + else: + logger.info(f"MagCache: Using strategy '{transformer_type}' for optimization") + if hasattr(strategy, 'register_blocks'): + strategy.register_blocks() + + state_manager = StateManager(MagCacheState, (), {}) + remaining_blocks = [] + + for name, submodule in module.named_children(): + if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList): + continue + for index, block in enumerate(submodule): + remaining_blocks.append((f"{name}.{index}", block)) + + if not remaining_blocks: + logger.warning("MagCache: No transformer blocks found to apply hooks.") + return + + if len(remaining_blocks) == 1: + name, block = remaining_blocks[0] + logger.info(f"MagCache: Applying Head+Tail Hooks to single block '{name}'") + _apply_mag_cache_block_hook(block, state_manager, config, is_tail=True, strategy=strategy) + _apply_mag_cache_head_hook(block, state_manager, config, strategy) + return + + head_block_name, head_block = remaining_blocks.pop(0) + tail_block_name, tail_block = remaining_blocks.pop(-1) + + logger.info(f"MagCache: Applying Head Hook to {head_block_name}") + _apply_mag_cache_head_hook(head_block, state_manager, config, strategy) + + for name, block in remaining_blocks: + _apply_mag_cache_block_hook(block, state_manager, config, strategy=strategy) + + logger.info(f"MagCache: Applying Tail Hook to {tail_block_name}") + _apply_mag_cache_block_hook(tail_block, state_manager, config, is_tail=True, strategy=strategy) + + +def _apply_mag_cache_head_hook( + block: torch.nn.Module, + state_manager: StateManager, + config: MagCacheConfig, + strategy: MagCacheStrategy | None = None, +) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + + if registry.get_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) is not None: + registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) + + hook = MagCacheHeadHook(state_manager, config, strategy) + 
registry.register_hook(hook, _MAG_CACHE_LEADER_BLOCK_HOOK) + + +def _apply_mag_cache_block_hook( + block: torch.nn.Module, + state_manager: StateManager, + config: MagCacheConfig, + is_tail: bool = False, + strategy: MagCacheStrategy | None = None, +) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + + if registry.get_hook(_MAG_CACHE_BLOCK_HOOK) is not None: + registry.remove_hook(_MAG_CACHE_BLOCK_HOOK) + + hook = MagCacheBlockHook(state_manager, is_tail, config, strategy) + registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK) diff --git a/vllm_omni/diffusion/cache/magcache/strategy.py b/vllm_omni/diffusion/cache/magcache/strategy.py new file mode 100644 index 00000000000..135f52b8d77 --- /dev/null +++ b/vllm_omni/diffusion/cache/magcache/strategy.py @@ -0,0 +1,440 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +MagCache strategy definitions for different model architectures. + +This module provides model-specific strategies for MagCache, allowing easy +extension to new models by implementing the MagCacheStrategy interface. + +Architecture: +- MagCacheStrategy: Abstract base class defining the strategy interface +- FluxMagCacheStrategy: Strategy for Flux (dual-stream) models +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Callable + +import torch +from diffusers.hooks._helpers import TransformerBlockRegistry, TransformerBlockMetadata + + +def register_transformer_block( + model_class, + return_hidden_states_index: int = 1, + return_encoder_hidden_states_index: int = 0, +) -> None: + """Register a transformer block class with the TransformerBlockRegistry. + + Args: + model_class: The transformer block class to register. + return_hidden_states_index: Index of hidden_states in the forward output tuple. 
+ return_encoder_hidden_states_index: Index of encoder_hidden_states in the output. + """ + try: + TransformerBlockRegistry.get(model_class) + except ValueError: + TransformerBlockRegistry.register( + model_class=model_class, + metadata=TransformerBlockMetadata( + return_hidden_states_index=return_hidden_states_index, + return_encoder_hidden_states_index=return_encoder_hidden_states_index, + ), + ) + + +@dataclass +class MagCacheContext: + """ + Context object containing model-specific information for MagCache. + + Attributes: + hidden_states: Current hidden states before transformer blocks. + encoder_hidden_states: Optional encoder states (None for single-stream). + temb: Timestep embedding tensor. + head_block_input: Input to the first transformer block (for residual calculation). + run_transformer_blocks: Callable to run transformer blocks. + run_single_transformer_blocks: Callable to run single transformer blocks. + postprocess: Callable to produce final output from block outputs. + """ + + hidden_states: torch.Tensor + encoder_hidden_states: torch.Tensor | None + temb: torch.Tensor + head_block_input: torch.Tensor | None + run_transformer_blocks: Callable[[], tuple[torch.Tensor, torch.Tensor]] + run_single_transformer_blocks: Callable[[], torch.Tensor] + postprocess: Callable[[torch.Tensor, torch.Tensor], Any] + + +class MagCacheStrategy(ABC): + """ + Abstract base class for MagCache strategies. + + Each model architecture requires a specific strategy to handle: + - Preprocessing of inputs (embeddings, positional encodings) + - Running transformer blocks + - Postprocessing (normalization, projection) + - Computing residuals for caching + + Implement this class to add support for new model architectures. 
+ """ + + @property + @abstractmethod + def transformer_type(self) -> str: + """Returns the transformer class name this strategy supports.""" + pass + + @property + @abstractmethod + def mag_ratios(self) -> torch.Tensor: + """Return the default mag_ratios tensor for this model. + + This tensor defines caching ratios for each transformer block. + Values should be calibrated for the specific model architecture. + + Returns: + 1D tensor of mag_ratios (one per transformer block). + """ + pass + + @abstractmethod + def create_context( + self, + module: torch.nn.Module, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None, + timestep: torch.Tensor, + guidance: torch.Tensor | None, + **kwargs, + ) -> MagCacheContext: + """ + Create a MagCacheContext from model inputs. + + Args: + module: The transformer module. + hidden_states: Input latents. + encoder_hidden_states: Text encoder outputs (None for single-stream). + timestep: Denoising timestep. + guidance: Guidance scale tensor (optional). + **kwargs: Additional model-specific arguments. + + Returns: + MagCacheContext with all information needed for caching. + """ + pass + + @abstractmethod + def get_head_block_input(self, context: MagCacheContext) -> torch.Tensor: + """ + Get the input to the first transformer block. + + Args: + context: MagCacheContext from create_context. + + Returns: + Tensor representing the input to the first block. + """ + pass + + @abstractmethod + def compute_residual( + self, + output: torch.Tensor, + head_input: torch.Tensor, + context: MagCacheContext, + ) -> torch.Tensor: + """ + Compute residual between output and head input. + + Args: + output: Output from transformer blocks. + head_input: Input to the first block. + context: MagCacheContext. + + Returns: + Residual tensor for caching. + """ + pass + + @abstractmethod + def apply_residual(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: + """ + Apply cached residual to hidden states. 
+ + Args: + hidden_states: Current hidden states. + residual: Cached residual to apply. + + Returns: + Hidden states with residual added. + """ + pass + + def apply_residual_tuple( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + residual: tuple[torch.Tensor, torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply cached residual tuple to both hidden_states and encoder_hidden_states. + + Default implementation: add residuals separately. + Override this method for models with specific residual application logic. + + Args: + hidden_states: Current hidden states. + encoder_hidden_states: Current encoder hidden states. + residual: Tuple of (hidden_states_residual, encoder_hidden_states_residual). + + Returns: + Tuple of (hidden_states, encoder_hidden_states) with residuals applied. + """ + h_res, e_res = residual + return hidden_states + h_res, encoder_hidden_states + e_res + + +class FluxMagCacheStrategy(MagCacheStrategy): + """ + MagCache strategy for Flux (dual-stream) models. 
+ + Flux architecture: + - transformer blocks (dual-stream): image tokens and text tokens + processed independently with separate weights + - single transformer blocks (single-stream): concatenated sequence + (image + text tokens) shares the same group of weights + - Final norm_out and proj_out layers + + This strategy provides: + - mag_ratios: Pre-computed magnitude ratios for Flux (28 steps) + - nearest_interp(): Interpolate mag_ratios to match num_inference_steps + """ + + @property + def transformer_type(self) -> str: + return "FluxTransformer2DModel" + + @property + def mag_ratios(self) -> torch.Tensor: + """Return default mag_ratios for Flux model.""" + return self.FLUX_MAG_RATIOS + + FLUX_MAG_RATIOS = torch.tensor( + [1.0] + + [ + 1.21094, + 1.11719, + 1.07812, + 1.0625, + 1.03906, + 1.03125, + 1.03906, + 1.02344, + 1.03125, + 1.02344, + 0.98047, + 1.01562, + 1.00781, + 1.0, + 1.00781, + 1.0, + 1.00781, + 1.0, + 1.0, + 0.99609, + 0.99609, + 0.98047, + 0.98828, + 0.96484, + 0.95703, + 0.93359, + 0.89062, + ] + ) + + def create_context( + self, + module: torch.nn.Module, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None, + timestep: torch.Tensor, + guidance: torch.Tensor | None, + **kwargs, + ) -> MagCacheContext: + """Create context for Flux model.""" + temb = ( + module.time_text_embed(timestep, guidance) + if guidance is not None + else module.time_text_embed(timestep) + ) + + def run_transformer_blocks(): + h = hidden_states + e = encoder_hidden_states + for block in module.transformer_blocks: + e, h = block( + hidden_states=h, + encoder_hidden_states=e, + temb=temb, + ) + return e, h + + def run_single_transformer_blocks(h): + for block in module.single_transformer_blocks: + h = block( + hidden_states=h, + encoder_hidden_states=torch.zeros(1, 1, h.shape[-1], device=h.device, dtype=h.dtype), + temb=temb, + ) + return h + + def postprocess(e: torch.Tensor, h: torch.Tensor) -> Any: + h = torch.cat([e, h], dim=1) + h = 
module.norm_out(h, temb) + output = module.proj_out(h) + from diffusers.models.modeling_outputs import Transformer2DModelOutput + + return Transformer2DModelOutput(sample=output) + + return MagCacheContext( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + head_block_input=None, + run_transformer_blocks=run_transformer_blocks, + run_single_transformer_blocks=run_single_transformer_blocks, + postprocess=postprocess, + ) + + def get_head_block_input(self, context: MagCacheContext) -> torch.Tensor: + """Get input to the first transformer block.""" + return context.hidden_states + + def compute_residual( + self, + output: torch.Tensor, + head_input: torch.Tensor, + context: MagCacheContext, + ) -> torch.Tensor: + """Compute residual for Flux single transformer blocks. + + For single transformer blocks, the output is concatenated (encoder + decoder). + We need to extract encoder residual from the combined output. + """ + if context is not None: + encoder_hidden_states = context.encoder_hidden_states + if encoder_hidden_states is not None: + encoder_len = encoder_hidden_states.shape[1] + if isinstance(output, tuple): + out_e = output[0] + out_h = output[1] + else: + out_e = output[:, :encoder_len, :] + out_h = output[:, encoder_len:, :] + + e_res = out_e - encoder_hidden_states + h_res = out_h - head_input + return (e_res, h_res) + + return output + + def apply_residual( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + ) -> torch.Tensor: + """Apply residual by adding to hidden states.""" + return hidden_states + residual + + def apply_residual_tuple( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + residual: tuple[torch.Tensor, torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + """Apply residual tuple to both hidden_states and encoder_hidden_states. 
+ + Flux architecture: + - encoder: 512 tokens + - decoder: 4096 tokens + + The residual tuple (e_res, h_res) comes from compute_residual: + - e_res: encoder residual (512 tokens) + - h_res: decoder residual (4096 tokens) + + We apply residuals separately to encoder_hidden_states and hidden_states. + """ + e_res, h_res = residual + + output = hidden_states + h_res + enc_output = encoder_hidden_states + e_res + + return output, enc_output + + @staticmethod + def register_blocks() -> None: + """Register vLLM-Omni Flux transformer blocks with TransformerBlockRegistry. + + Blocks: + - FluxTransformerBlock: dual-stream block + - FluxSingleTransformerBlock: single-stream block + """ + try: + from vllm_omni.diffusion.models.flux.flux_transformer import ( + FluxTransformerBlock, + FluxSingleTransformerBlock, + ) + + register_transformer_block(FluxTransformerBlock) + register_transformer_block(FluxSingleTransformerBlock) + except ImportError: + pass + + @staticmethod + def nearest_interp(src_array: torch.Tensor, target_length: int) -> torch.Tensor: + """Interpolate mag_ratios to target length using nearest neighbor.""" + src_length = len(src_array) + if target_length == 1: + return src_array[-1:] + + scale = (src_length - 1) / (target_length - 1) + grid = torch.arange(target_length, device=src_array.device, dtype=torch.float32) + mapped_indices = torch.round(grid * scale).long() + + return src_array[mapped_indices] + + +class MagCacheStrategyRegistry: + """Registry for MagCache strategies by transformer type.""" + + _registry: dict[str, MagCacheStrategy] = {} + + @classmethod + def register(cls, strategy: MagCacheStrategy) -> None: + """Register a strategy.""" + cls._registry[strategy.transformer_type] = strategy + + @classmethod + def get(cls, transformer_type: str) -> MagCacheStrategy: + """Get strategy for given transformer type.""" + if transformer_type not in cls._registry: + available = list(cls._registry.keys()) + raise ValueError( + f"Unknown model type: 
'{transformer_type}'. " + f"Available types: {available}" + ) + return cls._registry[transformer_type] + + @classmethod + def get_if_exists(cls, transformer_type: str) -> MagCacheStrategy | None: + """Get strategy if exists, None otherwise.""" + return cls._registry.get(transformer_type) + + +# Register default strategies +MagCacheStrategyRegistry.register(FluxMagCacheStrategy()) diff --git a/vllm_omni/diffusion/cache/selector.py b/vllm_omni/diffusion/cache/selector.py index 7c09bf66475..857bb429507 100644 --- a/vllm_omni/diffusion/cache/selector.py +++ b/vllm_omni/diffusion/cache/selector.py @@ -1,8 +1,6 @@ from typing import Any from vllm_omni.diffusion.cache.base import CacheBackend -from vllm_omni.diffusion.cache.cache_dit_backend import CacheDiTBackend -from vllm_omni.diffusion.cache.teacache.backend import TeaCacheBackend from vllm_omni.diffusion.data import DiffusionCacheConfig @@ -12,14 +10,15 @@ def get_cache_backend(cache_backend: str | None, cache_config: Any) -> CacheBack This is a selector function that routes to the appropriate backend implementation. - cache_dit: Uses CacheDiTBackend with enable()/refresh() interface - tea_cache: Uses TeaCacheBackend with enable()/refresh() interface + - mag_cache: Uses MagCacheBackend with enable()/refresh() interface Args: - cache_backend: Cache backend name ("cache_dit", "tea_cache", or None). + cache_backend: Cache backend name ("cache_dit", "tea_cache", "mag_cache", or None). cache_config: Cache configuration (dict or DiffusionCacheConfig instance). Returns: - Cache backend instance (CacheDiTBackend or TeaCacheBackend) if cache_backend is set, - None otherwise. + Cache backend instance (CacheDiTBackend, TeaCacheBackend, or MagCacheBackend) + if cache_backend is set, None otherwise. Raises: ValueError: If cache_backend is unsupported. 
@@ -31,8 +30,19 @@ def get_cache_backend(cache_backend: str | None, cache_config: Any) -> CacheBack cache_config = DiffusionCacheConfig.from_dict(cache_config) if cache_backend == "cache_dit": + from vllm_omni.diffusion.cache.cache_dit_backend import CacheDiTBackend + return CacheDiTBackend(cache_config) elif cache_backend == "tea_cache": + from vllm_omni.diffusion.cache.teacache.backend import TeaCacheBackend + return TeaCacheBackend(cache_config) + elif cache_backend == "mag_cache": + from vllm_omni.diffusion.cache.magcache.backend import MagCacheBackend + + return MagCacheBackend(cache_config) else: - raise ValueError(f"Unsupported cache backend: {cache_backend}. Supported: 'cache_dit', 'tea_cache'") + raise ValueError( + f"Unsupported cache backend: {cache_backend}. " + f"Supported: 'cache_dit', 'tea_cache', 'mag_cache'" + ) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index f884fb6f177..e9680326acd 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -119,7 +119,7 @@ def __getattr__(self, item: str) -> Any: @dataclass class DiffusionCacheConfig: """ - Configuration for cache adapters (TeaCache, cache-dit, etc.). + Configuration for cache adapters (TeaCache, cache-dit, MagCache, etc.). This dataclass provides a unified interface for cache configuration parameters. It can be initialized from a dictionary and accessed via attributes. 
@@ -129,6 +129,8 @@ class DiffusionCacheConfig: - cache-dit: Fn_compute_blocks, Bn_compute_blocks, max_warmup_steps, residual_diff_threshold, enable_taylorseer, taylorseer_order, scm_steps_mask_policy, scm_steps_policy + - MagCache: threshold, max_skip_steps, retention_ratio, num_inference_steps, + mag_ratios, calibrate Example: >>> # From dict (user-facing API) - partial config uses defaults for missing keys @@ -138,7 +140,7 @@ class DiffusionCacheConfig: >>> print(config.Fn_compute_blocks) # 8 (default) >>> # Empty dict uses all defaults >>> default_config = DiffusionCacheConfig.from_dict({}) - >>> print(default_config.rel_l1_thresh) # 0.2 (default) + >>> print(config.rel_l1_thresh) # 0.2 (default) """ # TeaCache parameters [tea_cache only] @@ -146,6 +148,18 @@ class DiffusionCacheConfig: rel_l1_thresh: float = 0.2 coefficients: list[float] | None = None # Uses model-specific defaults if None + # MagCache parameters [mag_cache only] + # Default: 0.06 threshold for accumulated magnitude error + threshold: float = 0.06 + # Default: 3 maximum consecutive skip steps + max_skip_steps: int = 3 + # Default: 0.2 retention ratio (initial steps that never skip) + retention_ratio: float = 0.2 + # Default: None magnitude ratios (model-specific, required for inference) + mag_ratios: list[float] | None = None + # Default: False calibration mode (computes mag_ratios on first run) + calibrate: bool = False + # cache-dit parameters [cache-dit only] # Default: 1 forward compute block (optimized for single-transformer models) # Use 1 as default instead of cache-dit's 8, optimized for single-transformer models diff --git a/vllm_omni/diffusion/hooks/base.py b/vllm_omni/diffusion/hooks/base.py index 8b330a19f42..18da45224c9 100644 --- a/vllm_omni/diffusion/hooks/base.py +++ b/vllm_omni/diffusion/hooks/base.py @@ -12,7 +12,11 @@ from dataclasses import dataclass from typing import Any +import torch import torch.nn as nn +from vllm_omni.logger import init_logger + +logger = 
init_logger(__name__) class BaseState: @@ -25,17 +29,24 @@ def reset(self) -> None: # pragma: no cover - default is no-op class StateManager: """Manage per-context hook state instances.""" - def __init__(self, state_cls: Callable[[], BaseState]): + def __init__(self, state_cls: Callable[[], BaseState], init_args: tuple = (), init_kwargs: dict | None = None): self._state_cls = state_cls + self._init_args = init_args + self._init_kwargs = init_kwargs or {} self._states: dict[str, BaseState] = {} self._context: str = "default" + @property + def _current_context(self) -> str | None: + """Alias for _context for compatibility with diffusers hook code.""" + return self._context if self._context != "default" else None + def set_context(self, name: str) -> None: self._context = name or "default" def get_state(self) -> BaseState: if self._context not in self._states: - self._states[self._context] = self._state_cls() + self._states[self._context] = self._state_cls(*self._init_args, **self._init_kwargs) return self._states[self._context] def reset(self) -> None: @@ -130,7 +141,9 @@ class _WrappedForward: def __call__(self, *args: Any, **kwargs: Any): registry: HookRegistry | None = getattr(self.module, "_hook_registry", None) - if registry is None or not registry._hooks: + if registry is None: + return self.module._original_forward(*args, **kwargs) + if not registry._hooks: return self.module._original_forward(*args, **kwargs) return registry.dispatch(*args, **kwargs) @@ -146,6 +159,15 @@ def __init__(self, module: nn.Module): self.module = module self._hooks: dict[str, ModelHook] = {} + def __getstate__(self): + """Handle pickling - preserve hooks.""" + return {"module": self.module, "_hooks": self._hooks} + + def __setstate__(self, state): + """Handle unpickling - restore hooks.""" + self.module = state["module"] + self._hooks = state["_hooks"] + @classmethod def get_or_create(cls, module: nn.Module) -> HookRegistry: """Get existing registry or create a new one for the 
module. @@ -161,22 +183,67 @@ def get_or_create(cls, module: nn.Module) -> HookRegistry: registry = cls(module) setattr(module, "_hook_registry", registry) - # Wrap module.forward once so hooks can intercept calls. if not hasattr(module, "_original_forward"): module._original_forward = module.forward # type: ignore[attr-defined] module.forward = _WrappedForward(module) # type: ignore[assignment] - return registry - def register_hook(self, name: str, hook: ModelHook) -> None: + @classmethod + def check_if_exists_or_initialize(cls, module: nn.Module) -> HookRegistry: + """Get existing registry or create a new one for the module. + + This method ensures a HookRegistry exists on the module and returns it. + If a registry doesn't exist, it creates one and attaches it to the module. + This is equivalent to get_or_create() for compatibility with diffusers API. + + Args: + module: The module to get/create a registry for. + + Returns: + The HookRegistry for this module. + """ + return cls.get_or_create(module) + + def register_hook(self, hook: ModelHook, name: str | None = None) -> str | None: """Register a hook with the given name. + This method follows the diffusers API convention where the hook object + comes first, followed by an optional name. If no name is provided, + uses hook._HOOK_NAME. + Args: - name: Unique name for this hook. hook: The hook instance to register. + name: Optional unique name for this hook. If not provided, + uses hook._HOOK_NAME. + + Returns: + The name the hook was registered under, or None if registration failed. """ + if name is None: + name = getattr(hook, "_HOOK_NAME", None) + if name is None: + return None + + if name in self._hooks: + raise ValueError( + f"Hook with name '{name}' already exists. Remove it first or use a different name." 
+ ) + hook.initialize_hook(self.module) + + if hasattr(hook, "fn_ref"): + hook.fn_ref.original_forward = self.module._original_forward + else: + original_forward = self.module._original_forward # type: ignore[attr-defined] + + class _FnRef: + def __init__(self, orig_forward): + self.original_forward = orig_forward + + hook.fn_ref = _FnRef(original_forward) + self._hooks[name] = hook + return name def remove_hook(self, name: str) -> None: """Remove a hook by name. @@ -245,3 +312,12 @@ def reset_hook(self, name: str) -> None: hook = self._hooks.get(name) if hook is not None: hook.reset_state(self.module) + + def reset(self) -> None: + """Reset all hooks and clear the registry. + + This removes all hooks from the registry and resets each hook's state. + """ + for name, hook in list(self._hooks.items()): + hook.reset_state(self.module) + self._hooks.clear() diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index acefc081b0b..d0680a32193 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -167,13 +167,16 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu "--cache-backend", type=str, default="none", - help="Cache backend for diffusion models, options: 'tea_cache', 'cache_dit'", + help="Cache backend for diffusion models, options: 'tea_cache', 'cache_dit', 'mag_cache'", ) omni_config_group.add_argument( "--cache-config", type=str, default=None, - help="JSON string of cache configuration (e.g., '{\"rel_l1_thresh\": 0.2}').", + help="JSON string of cache configuration. " + "TeaCache: '{\"rel_l1_thresh\": 0.2}'. " + "MagCache: '{\"threshold\": 0.06, \"max_skip_steps\": 3, \"mag_ratios\": [1.0, ...]}'. 
" + "Calibration mode: add '\"calibrate\": true'", ) omni_config_group.add_argument( "--enable-cache-dit-summary", From 39d12e24f2f5698d9d60969ee8ff2d5a009f0d16 Mon Sep 17 00:00:00 2001 From: Lancer Date: Mon, 9 Feb 2026 19:22:27 +0800 Subject: [PATCH 02/11] upd Signed-off-by: Lancer --- .../diffusion/cache/magcache/__init__.py | 9 +- vllm_omni/diffusion/cache/magcache/backend.py | 29 +- vllm_omni/diffusion/cache/magcache/config.py | 43 +-- .../diffusion/cache/magcache/examples.py | 279 --------------- vllm_omni/diffusion/cache/magcache/hook.py | 151 ++++---- vllm_omni/diffusion/cache/magcache/state.py | 44 +++ .../diffusion/cache/magcache/strategy.py | 324 +++++++----------- vllm_omni/diffusion/cache/selector.py | 5 +- vllm_omni/diffusion/hooks/base.py | 6 +- vllm_omni/entrypoints/cli/serve.py | 2 +- 10 files changed, 285 insertions(+), 607 deletions(-) delete mode 100644 vllm_omni/diffusion/cache/magcache/examples.py create mode 100644 vllm_omni/diffusion/cache/magcache/state.py diff --git a/vllm_omni/diffusion/cache/magcache/__init__.py b/vllm_omni/diffusion/cache/magcache/__init__.py index f5ab42812d7..2fb03ec3293 100644 --- a/vllm_omni/diffusion/cache/magcache/__init__.py +++ b/vllm_omni/diffusion/cache/magcache/__init__.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm_omni.diffusion.cache.magcache.backend import CUSTOM_MAG_CACHE_ENABLERS -from vllm_omni.diffusion.cache.magcache.config import FLUX_MAG_RATIOS, MagCacheConfig +from vllm_omni.diffusion.cache.magcache.config import MagCacheConfig from vllm_omni.diffusion.cache.magcache.hook import ( MagCacheBlockHook, MagCacheHeadHook, @@ -10,15 +10,15 @@ apply_mag_cache_hook, ) from vllm_omni.diffusion.cache.magcache.strategy import ( + FluxMagCacheStrategy, + MagCacheContext, MagCacheStrategy, MagCacheStrategyRegistry, - MagCacheContext, - FluxMagCacheStrategy, ) __all__ = [ "CUSTOM_MAG_CACHE_ENABLERS", - "FLUX_MAG_RATIOS", + "FluxMagCacheStrategy", 
"MagCacheBlockHook", "MagCacheConfig", "MagCacheContext", @@ -26,6 +26,5 @@ "MagCacheState", "MagCacheStrategy", "MagCacheStrategyRegistry", - "FluxMagCacheStrategy", "apply_mag_cache_hook", ] diff --git a/vllm_omni/diffusion/cache/magcache/backend.py b/vllm_omni/diffusion/cache/magcache/backend.py index e711d90ea50..af29e67a33e 100644 --- a/vllm_omni/diffusion/cache/magcache/backend.py +++ b/vllm_omni/diffusion/cache/magcache/backend.py @@ -53,9 +53,10 @@ class MagCacheBackend(CacheBackend): Example: >>> from vllm_omni.diffusion.data import DiffusionCacheConfig - >>> from vllm_omni.diffusion.cache.magcache import MagCacheConfig, FLUX_MAG_RATIOS + >>> from vllm_omni.diffusion.cache.magcache import MagCacheConfig + >>> from vllm_omni.diffusion.cache.magcache.strategy import FluxMagCacheStrategy >>> cache_config = DiffusionCacheConfig( - ... mag_ratios=FLUX_MAG_RATIOS, + ... mag_ratios=FluxMagCacheStrategy.FLUX_MAG_RATIOS, ... num_inference_steps=28, ... threshold=0.06, ... max_skip_steps=3, @@ -93,10 +94,19 @@ def enable(self, pipeline: Any) -> None: if mag_ratios is None: strategy = MagCacheStrategyRegistry.get_if_exists(transformer_type) if strategy is not None: - mag_ratios = strategy.mag_ratios - logger.info( - f"MagCache: Using default mag_ratios from strategy '{transformer_type}'" - ) + original_ratios = strategy.mag_ratios + if len(original_ratios) != num_inference_steps: + if hasattr(strategy, "nearest_interp"): + mag_ratios = strategy.nearest_interp(original_ratios, num_inference_steps) + logger.info( + f"MagCache: Interpolated mag_ratios from {len(original_ratios)} " + f"to {num_inference_steps} steps" + ) + else: + mag_ratios = original_ratios + else: + mag_ratios = original_ratios + logger.info(f"MagCache: Using default mag_ratios from strategy '{transformer_type}'") if mag_ratios is None and not self.config.calibrate: raise ValueError( @@ -138,15 +148,12 @@ def refresh(self, pipeline: Any, num_inference_steps: int) -> None: num_inference_steps: Number 
of inference steps for the current generation. May be used for cache context updates. """ - from diffusers.hooks._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS - transformer = pipeline.transformer - transformer_type = transformer.__class__.__name__ current_transformer_id = id(transformer) needs_re_register = False - if self._registered and hasattr(self, '_transformer_id'): + if self._registered and hasattr(self, "_transformer_id"): if current_transformer_id != self._transformer_id: logger.warning( f"Transformer was replaced (id changed from {self._transformer_id} " @@ -163,7 +170,7 @@ def refresh(self, pipeline: Any, num_inference_steps: int) -> None: blocks_with_hooks = [] for name, submodule in transformer.named_children(): - if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList): + if not isinstance(submodule, torch.nn.ModuleList): continue for index, block in enumerate(submodule): registry = getattr(block, "_hook_registry", None) diff --git a/vllm_omni/diffusion/cache/magcache/config.py b/vllm_omni/diffusion/cache/magcache/config.py index 227da641a2e..229783d3eaa 100644 --- a/vllm_omni/diffusion/cache/magcache/config.py +++ b/vllm_omni/diffusion/cache/magcache/config.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + from dataclasses import dataclass -from typing import Optional, Union import torch @@ -19,30 +20,19 @@ class MagCacheConfig: Reference: https://github.com/Zehong-Ma/MagCache Args: - threshold: The threshold for the accumulated error. If the accumulated error - is below this threshold, the block computation is skipped. A higher threshold - allows for more aggressive skipping (faster) but may degrade quality. + threshold: Accumulated error threshold. Higher = more aggressive skipping (faster, lower quality). 
Default: 0.06 - max_skip_steps: The maximum number of consecutive steps that can be skipped (K). + max_skip_steps: Max consecutive skip steps (K). Default: 3 - retention_ratio: The fraction of initial steps during which skipping is disabled - to ensure stability. For example, if num_inference_steps is 28 and - retention_ratio is 0.2, the first 6 steps will never be skipped. + retention_ratio: Fraction of initial steps where skipping is disabled (stability). Default: 0.2 - num_inference_steps: The number of inference steps used in the pipeline. - This is required to interpolate mag_ratios correctly. + num_inference_steps: Total inference steps. Required for retention step calculation. Default: 28 - mag_ratios: The pre-computed magnitude ratios for the model. These are - checkpoint-dependent. If not provided, you must set calibrate=True to - calculate them for your specific model. For Flux models, you can use - FLUX_MAG_RATIOS. + mag_ratios: Pre-computed magnitude ratios per step. Calibrate or use strategy defaults. Default: None - calibrate: If True, enables calibration mode. In this mode, no blocks are skipped. - Instead, the hook calculates the magnitude ratios for the current run and logs - them at the end. Use this to obtain mag_ratios for new models or schedulers. + calibrate: If True, runs without skipping and logs norm_ratios for calibration. Default: False - transformer_type: Transformer class name for logging and identification. - Auto-detected from pipeline.transformer.__class__.__name__ in backend. + transformer_type: Transformer class name for logging. 
Default: "FluxTransformer2DModel" """ @@ -50,7 +40,7 @@ class MagCacheConfig: max_skip_steps: int = 3 retention_ratio: float = 0.2 num_inference_steps: int = 28 - mag_ratios: Optional[Union[torch.Tensor, list[float]]] = None + mag_ratios: torch.Tensor | list[float] | None = None calibrate: bool = False transformer_type: str = "FluxTransformer2DModel" @@ -88,16 +78,3 @@ def __post_init__(self) -> None: if not self.calibrate and self.mag_ratios is not None: if not torch.is_tensor(self.mag_ratios): self.mag_ratios = torch.tensor(self.mag_ratios) - - -FLUX_MAG_RATIOS = None - - -def get_flux_mag_ratios() -> torch.Tensor: - """Get FLUX_MAG_RATIOS from FluxMagCacheStrategy, importing only when needed.""" - global FLUX_MAG_RATIOS - if FLUX_MAG_RATIOS is None: - from vllm_omni.diffusion.cache.magcache.strategy import FluxMagCacheStrategy - - FLUX_MAG_RATIOS = FluxMagCacheStrategy.FLUX_MAG_RATIOS - return FLUX_MAG_RATIOS diff --git a/vllm_omni/diffusion/cache/magcache/examples.py b/vllm_omni/diffusion/cache/magcache/examples.py deleted file mode 100644 index 84eb05869db..00000000000 --- a/vllm_omni/diffusion/cache/magcache/examples.py +++ /dev/null @@ -1,279 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -""" -MagCache Integration Template for New Models. - -This module provides a complete template for integrating new diffusion models -into MagCache. Copy this file and modify according to your model's architecture. - -Integration Steps: - 1. Analyze model architecture (block structure, I/O format) - 2. Create Strategy class (inherit from MagCacheStrategy) - 3. Implement required methods - 4. Register the strategy - 5. 
Test and calibrate mag_ratios -""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Any, Callable - -import torch -from diffusers.hooks._helpers import TransformerBlockRegistry, TransformerBlockMetadata - -from vllm_omni.diffusion.cache.magcache.strategy import ( - MagCacheStrategy, - MagCacheStrategyRegistry, -) - - -def register_transformer_block( - model_class, - return_hidden_states_index: int = 1, - return_encoder_hidden_states_index: int = 0, -) -> None: - """Register a transformer block class with the TransformerBlockRegistry.""" - try: - TransformerBlockRegistry.get(model_class) - except ValueError: - TransformerBlockRegistry.register( - model_class=model_class, - metadata=TransformerBlockMetadata( - return_hidden_states_index=return_hidden_states_index, - return_encoder_hidden_states_index=return_encoder_hidden_states_index, - ), - ) - - -# ============================================================================= -# EXAMPLE: SD3 (Stable Diffusion 3) Integration -# ============================================================================= - -class SD3MagCacheStrategy(MagCacheStrategy): - """ - MagCache strategy for SD3 (Stable Diffusion 3). - - SD3 Architecture Analysis: - - Single stream: hidden_states only (no encoder_hidden_states separation) - - Block structure: transformer_blocks (nn.ModuleList) - - Output: tuple of (hidden_states,) - - Residual: output - input (simple subtraction) - - Integration Steps: - 1. Identify block name: "transformer_blocks" - 2. Determine I/O indices: return_hidden_states_index=0 - 3. 
Define residual: output - input - """ - - @property - def transformer_type(self) -> str: - return "SD3Transformer2DModel" - - @property - def mag_ratios(self) -> torch.Tensor: - """Return default mag_ratios for SD3 model.""" - return self.SD3_MAG_RATIOS - - SD3_MAG_RATIOS = torch.tensor( - [ - 1.0, - 1.15, - 1.10, - 1.05, - 1.02, - 1.00, - 0.98, - 0.95, - 0.92, - 0.90, - 0.88, - 0.85, - 0.82, - 0.80, - 0.78, - 0.75, - 0.72, - 0.70, - 0.68, - 0.65, - 0.62, - 0.60, - 0.58, - 0.55, - 0.52, - 0.50, - 0.48, - 0.45, - ] - ) - - @staticmethod - def nearest_interp(src_array: torch.Tensor, target_length: int) -> torch.Tensor: - """Interpolate mag_ratios to target length using nearest neighbor.""" - src_length = len(src_array) - if target_length == 1: - return src_array[-1:] - - scale = (src_length - 1) / (target_length - 1) - grid = torch.arange(target_length, device=src_array.device, dtype=torch.float32) - mapped_indices = torch.round(grid * scale).long() - - return src_array[mapped_indices] - - @staticmethod - def register_blocks() -> None: - """Register SD3 transformer blocks. - - SD3 uses a simple transformer block structure. 
- """ - try: - from diffusers.models.transformers.sd_transformer_2d import ( - SD3TransformerBlock, - ) - - register_transformer_block( - SD3TransformerBlock, - return_hidden_states_index=0, - ) - except ImportError: - pass - - def create_context( - self, - module: torch.nn.Module, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor | None, - timestep: torch.Tensor, - guidance: torch.Tensor | None, - **kwargs, - ) -> MagCacheContext: - """Create context for SD3 model.""" - temb = module.time_embed(timestep) - - def run_transformer_blocks(): - for block in module.transformer_blocks: - hidden_states = block(hidden_states, temb=temb) - return hidden_states - - def run_single_transformer_blocks(h): - return h - - def postprocess(h: torch.Tensor, e: torch.Tensor) -> Any: - return h - - return MagCacheContext( - hidden_states=hidden_states, - encoder_hidden_states=None, - temb=temb, - head_block_input=None, - run_transformer_blocks=run_transformer_blocks, - run_single_transformer_blocks=run_single_transformer_blocks, - postprocess=postprocess, - ) - - def get_head_block_input(self, context: MagCacheContext) -> torch.Tensor: - """Get input to the first transformer block.""" - return context.hidden_states - - def compute_residual( - self, - output: torch.Tensor, - head_input: torch.Tensor, - context: MagCacheContext | None, - ) -> torch.Tensor: - """Compute residual for SD3: output - input.""" - return output - head_input - - def apply_residual( - self, - hidden_states: torch.Tensor, - residual: torch.Tensor, - ) -> torch.Tensor: - """Apply residual for SD3: hidden_states + residual.""" - return hidden_states + residual - - -# Register SD3 strategy -MagCacheStrategyRegistry.register(SD3MagCacheStrategy()) - - -# ============================================================================= -# TEMPLATE: Copy and modify for your model -# ============================================================================= - -# class 
YourModelMagCacheStrategy(MagCacheStrategy): -# """ -# MagCache strategy for YourModel. -# -# Model Architecture Analysis: -# - [Describe the architecture] -# - Block name: [e.g., "transformer_blocks", "layers", "blocks"] -# - I/O format: [Describe input/output format] -# - Residual: [Describe how to compute residual] -# """ -# -# @property -# def transformer_type(self) -> str: -# return "YourTransformerModel" -# -# YOUR_MAG_RATIOS = torch.tensor([...]) # Replace with your model's ratios -# -# @staticmethod -# def nearest_interp(src_array: torch.Tensor, target_length: int) -> torch.Tensor: -# """Interpolate mag_ratios to target length.""" -# src_length = len(src_array) -# if target_length == 1: -# return src_array[-1:] -# -# scale = (src_length - 1) / (target_length - 1) -# grid = torch.arange(target_length, device=src_array.device, dtype=torch.float32) -# mapped_indices = torch.round(grid * scale).long() -# -# return src_array[mapped_indices] -# -# @staticmethod -# def register_blocks() -> None: -# """Register your model's transformer blocks.""" -# try: -# from your_model import YourTransformerBlock -# -# register_transformer_block( -# YourTransformerBlock, -# return_hidden_states_index=0, # Adjust based on your model's output -# ) -# except ImportError: -# pass -# -# def compute_residual( -# self, -# output: torch.Tensor, -# head_input: torch.Tensor, -# context: MagCacheContext | None, -# ) -> torch.Tensor: -# """Compute residual for your model. -# -# Common patterns: -# - Simple: output - head_input -# - Complex: output - head_input (with shape adjustments) -# """ -# return output - head_input -# -# def apply_residual( -# self, -# hidden_states: torch.Tensor, -# residual: torch.Tensor, -# ) -> torch.Tensor: -# """Apply residual for your model. 
-# -# Common patterns: -# - Simple: hidden_states + residual -# - Complex: hidden_states + residual (with shape adjustments) -# """ -# return hidden_states + residual -# -# -# # Register your strategy -# MagCacheStrategyRegistry.register(YourModelMagCacheStrategy()) diff --git a/vllm_omni/diffusion/cache/magcache/hook.py b/vllm_omni/diffusion/cache/magcache/hook.py index 5ba4ebf7ac5..e483b177c9c 100644 --- a/vllm_omni/diffusion/cache/magcache/hook.py +++ b/vllm_omni/diffusion/cache/magcache/hook.py @@ -22,17 +22,16 @@ from __future__ import annotations -from typing import Any - import torch +import torch.nn.functional as F from diffusers.hooks._helpers import TransformerBlockRegistry -from diffusers.hooks._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS from diffusers.utils.torch_utils import unwrap_module -from vllm_omni.logger import init_logger from vllm_omni.diffusion.cache.magcache.config import MagCacheConfig -from vllm_omni.diffusion.cache.magcache.strategy import MagCacheStrategy, MagCacheStrategyRegistry, FluxMagCacheStrategy +from vllm_omni.diffusion.cache.magcache.state import MagCacheState +from vllm_omni.diffusion.cache.magcache.strategy import MagCacheStrategy, MagCacheStrategyRegistry from vllm_omni.diffusion.hooks.base import HookRegistry, ModelHook, StateManager +from vllm_omni.logger import init_logger logger = init_logger(__name__) @@ -40,31 +39,6 @@ _MAG_CACHE_BLOCK_HOOK = "mag_cache_block_hook" -class MagCacheState: - """State management for MagCache hook.""" - - def __init__(self) -> None: - """Initialize empty MagCache state.""" - self.previous_residual: torch.Tensor | None = None - self.head_block_input: torch.Tensor | tuple | None = None - self.should_compute: bool = True - self.accumulated_ratio: float = 1.0 - self.accumulated_err: float = 0.0 - self.accumulated_steps: int = 0 - self.step_index: int = 0 - self.calibration_ratios: list[float] = [] - - def reset(self) -> None: - """Reset all state variables for a new inference run.""" - 
self.previous_residual = None - self.should_compute = True - self.accumulated_ratio = 1.0 - self.accumulated_err = 0.0 - self.accumulated_steps = 0 - self.step_index = 0 - self.calibration_ratios = [] - - class MagCacheHeadHook(ModelHook): """Head block hook for MagCache - decides whether to skip computation.""" @@ -88,7 +62,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): if self.state_manager._current_context is None: self.state_manager.set_context("inference") - if hasattr(self._metadata, 'hidden_states_argument_name'): + if hasattr(self._metadata, "hidden_states_argument_name"): arg_name = self._metadata.hidden_states_argument_name else: arg_name = "hidden_states" @@ -97,6 +71,12 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): state: MagCacheState = self.state_manager.get_state() state.head_block_input = hidden_states + if state._is_first_step: + state.accumulated_ratio = 1.0 + state.accumulated_err = 0.0 + state.accumulated_steps = 0 + state._is_first_step = False + should_compute = True if self.config.calibrate: @@ -147,7 +127,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): ret_list = [None] * 2 ret_list[self._metadata.return_hidden_states_index] = output ret_list[self._metadata.return_encoder_hidden_states_index] = enc_output - return self._log_cache_hit(state, output, ret_list) + return self.log_cache_hit(state, output, ret_list) else: raise RuntimeError( f"MagCache residual is tuple but no strategy available for {self._metadata.transformer_type}. 
" @@ -181,14 +161,14 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): ret_list = [None] * (max_idx + 1) ret_list[self._metadata.return_hidden_states_index] = output ret_list[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states - return self._log_cache_hit(state, output, ret_list) + return self.log_cache_hit(state, output, ret_list) else: - return self._log_cache_hit(state, output, None) + return self.log_cache_hit(state, output, None) else: output = self.fn_ref.original_forward(*args, **kwargs) - return self._log_cache_miss(state, output) + return self.log_cache_miss(state, output) - def _log_cache_hit(self, state: MagCacheState, output, ret): + def log_cache_hit(self, state: MagCacheState, output, ret): step = state.step_index if state.previous_residual is not None: if isinstance(state.previous_residual, tuple): @@ -203,7 +183,7 @@ def _log_cache_hit(self, state: MagCacheState, output, ret): ) return ret if ret is not None else output - def _log_cache_miss(self, state: MagCacheState, output): + def log_cache_miss(self, state: MagCacheState, output): step = state.step_index residual_norm = 0.0 if state.previous_residual is not None: @@ -212,8 +192,10 @@ def _log_cache_miss(self, state: MagCacheState, output): else: residual_norm = float(torch.norm(state.previous_residual).item()) logger.debug( - f"[MagCache][HEAD] STEP={step}: CACHE_MISS (err={state.accumulated_err:.6f}, " - f"acc_ratio={state.accumulated_ratio:.6f}, residual_norm={residual_norm:.6f}, threshold={self.config.threshold}, max_skip={self.config.max_skip_steps})" + f"[MagCache][HEAD] STEP={step}: CACHE_MISS " + f"(err={state.accumulated_err:.6f}, acc_ratio={state.accumulated_ratio:.6f}, " + f"residual_norm={residual_norm:.6f}, threshold={self.config.threshold}, " + f"max_skip={self.config.max_skip_steps})" ) return output @@ -257,14 +239,14 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): state: MagCacheState = 
self.state_manager.get_state() if not state.should_compute: - if hasattr(self._metadata, 'hidden_states_argument_name'): + if hasattr(self._metadata, "hidden_states_argument_name"): arg_name = self._metadata.hidden_states_argument_name else: arg_name = "hidden_states" hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs) if self.is_tail: - self._advance_step(state) + self.advance_step(state) if self._metadata.return_encoder_hidden_states_index is not None: encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs( @@ -294,7 +276,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): return output if self._strategy is not None: - residual = self._strategy.compute_residual(output, in_hidden, None) + residual = self._strategy.compute_residual(output, in_hidden) elif out_hidden.shape == in_hidden.shape: residual = out_hidden - in_hidden elif out_hidden.ndim == 3 and in_hidden.ndim == 3 and out_hidden.shape[2] == in_hidden.shape[2]: @@ -307,16 +289,20 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): residual = out_hidden if self.config.calibrate: - self._perform_calibration_step(state, residual) + self.perform_calibration(state, residual) state.previous_residual = residual - self._advance_step(state) + self.advance_step(state) - self._log_residual_computed(state, residual) + self.log_residual_computed(state, residual) return output - def _log_residual_computed(self, state: MagCacheState, residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor]) -> None: + def log_residual_computed( + self, + state: MagCacheState, + residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + ) -> None: step = state.step_index if residual is None: residual_norm = 0.0 @@ -328,35 +314,60 @@ def _log_residual_computed(self, state: MagCacheState, residual: torch.Tensor | residual_norm = float(torch.norm(residual).item()) residual_shape = residual.shape logger.debug( - f"[MagCache][TAIL] STEP={step}: 
RESIDUAL_COMPUTED (norm={residual_norm:.6f}, " - f"shape={residual_shape})" + f"[MagCache][TAIL] STEP={step}: RESIDUAL_COMPUTED (norm={residual_norm:.6f}, shape={residual_shape})" ) - def _perform_calibration_step(self, state: MagCacheState, current_residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor]) -> None: - def _get_norm(residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: - if isinstance(residual, tuple): - return sum(torch.linalg.norm(r.float(), dim=-1) for r in residual) - return torch.linalg.norm(residual.float(), dim=-1) - - if state.previous_residual is None: - ratio = 1.0 + def perform_calibration( + self, + state: MagCacheState, + current_residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + ) -> None: + if self._strategy is not None: + ratio, std, cos_dis = self._strategy.compute_calibration_metrics(current_residual, state.previous_residual) else: - curr_norm = _get_norm(current_residual) - prev_norm = _get_norm(state.previous_residual) - ratio = (curr_norm / (prev_norm + 1e-8)).mean().item() + + def _get_norm(residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + if isinstance(residual, tuple): + return sum(torch.linalg.norm(r.float(), dim=-1) for r in residual) + return torch.linalg.norm(residual.float(), dim=-1) + + if state.previous_residual is None: + ratio, std, cos_dis = 1.0, 0.0, 0.0 + else: + curr_norm = _get_norm(current_residual) + prev_norm = _get_norm(state.previous_residual) + ratio = (curr_norm / (prev_norm + 1e-8)).mean().item() + std = (curr_norm / (prev_norm + 1e-8)).std().item() + cos_dis = ( + ( + 1 + - F.cosine_similarity( + current_residual.flatten(0, -2) if current_residual.ndim > 2 else current_residual, + state.previous_residual.flatten(0, -2) + if state.previous_residual.ndim > 2 + else state.previous_residual, + dim=-1, + eps=1e-8, + ) + ) + .mean() + .item() + ) state.calibration_ratios.append(ratio) + state.norm_ratios.append(round(ratio, 5)) + 
state.norm_stds.append(round(std, 5)) + state.cos_dises.append(round(cos_dis, 5)) - def _advance_step(self, state: MagCacheState) -> None: + def advance_step(self, state: MagCacheState) -> None: state.step_index += 1 if state.step_index >= self.config.num_inference_steps: if self.config.calibrate: - logger.info( - f"MagCache calibration complete. mag_ratios={state.calibration_ratios}" - ) - logger.info( - "Copy these values to DiffusionCacheConfig(mag_ratios=...) for production use" - ) + logger.info("MagCache calibration complete.") + logger.info(f"norm_ratios: {state.norm_ratios}") + logger.info(f"norm_stds: {state.norm_stds}") + logger.info(f"cos_dises: {state.cos_dises}") + logger.info("Copy these values to DiffusionCacheConfig(mag_ratios=...) for production use") state.step_index = 0 state.accumulated_ratio = 1.0 @@ -364,6 +375,10 @@ def _advance_step(self, state: MagCacheState) -> None: state.accumulated_err = 0.0 state.previous_residual = None state.calibration_ratios = [] + state.norm_ratios = [] + state.norm_stds = [] + state.cos_dises = [] + state._is_first_step = True def apply_mag_cache_hook(module: torch.nn.Module, config: MagCacheConfig) -> None: @@ -384,14 +399,14 @@ def apply_mag_cache_hook(module: torch.nn.Module, config: MagCacheConfig) -> Non ) else: logger.info(f"MagCache: Using strategy '{transformer_type}' for optimization") - if hasattr(strategy, 'register_blocks'): + if hasattr(strategy, "register_blocks"): strategy.register_blocks() state_manager = StateManager(MagCacheState, (), {}) remaining_blocks = [] for name, submodule in module.named_children(): - if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList): + if not isinstance(submodule, torch.nn.ModuleList): continue for index, block in enumerate(submodule): remaining_blocks.append((f"{name}.{index}", block)) diff --git a/vllm_omni/diffusion/cache/magcache/state.py b/vllm_omni/diffusion/cache/magcache/state.py new file mode 100644 index 
00000000000..28060ec832d --- /dev/null +++ b/vllm_omni/diffusion/cache/magcache/state.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +MagCache state management. + +This module contains the MagCacheState class which manages the internal state +for MagCache caching logic, including residuals, accumulated metrics, and step tracking. +""" + +import torch + + +class MagCacheState: + """State management for MagCache caching logic.""" + + def __init__(self) -> None: + """Initialize empty MagCache state.""" + self.previous_residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None = None + self.head_block_input: torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None = None + self.should_compute: bool = True + self.accumulated_ratio: float = 1.0 + self.accumulated_err: float = 0.0 + self.accumulated_steps: int = 0 + self.step_index: int = 0 + self.calibration_ratios: list[float] = [] + self.norm_ratios: list[float] = [] + self.norm_stds: list[float] = [] + self.cos_dises: list[float] = [] + self._is_first_step: bool = True + + def reset(self) -> None: + """Reset all state variables for a new inference run.""" + self.previous_residual = None + self.should_compute = True + self.accumulated_ratio = 1.0 + self.accumulated_err = 0.0 + self.accumulated_steps = 0 + self.step_index = 0 + self.calibration_ratios = [] + self.norm_ratios = [] + self.norm_stds = [] + self.cos_dises = [] + self._is_first_step = True diff --git a/vllm_omni/diffusion/cache/magcache/strategy.py b/vllm_omni/diffusion/cache/magcache/strategy.py index 135f52b8d77..b9140514fe0 100644 --- a/vllm_omni/diffusion/cache/magcache/strategy.py +++ b/vllm_omni/diffusion/cache/magcache/strategy.py @@ -15,11 +15,12 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Callable +from typing import Any import 
torch -from diffusers.hooks._helpers import TransformerBlockRegistry, TransformerBlockMetadata +from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry def register_transformer_block( @@ -75,10 +76,9 @@ class MagCacheStrategy(ABC): Abstract base class for MagCache strategies. Each model architecture requires a specific strategy to handle: - - Preprocessing of inputs (embeddings, positional encodings) - - Running transformer blocks - - Postprocessing (normalization, projection) - - Computing residuals for caching + - Residual computation (how to calculate the residual for caching) + - Residual application (how to apply cached residual) + - Model-specific magnitude ratios Implement this class to add support for new model architectures. """ @@ -102,59 +102,17 @@ def mag_ratios(self) -> torch.Tensor: """ pass - @abstractmethod - def create_context( - self, - module: torch.nn.Module, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor | None, - timestep: torch.Tensor, - guidance: torch.Tensor | None, - **kwargs, - ) -> MagCacheContext: - """ - Create a MagCacheContext from model inputs. - - Args: - module: The transformer module. - hidden_states: Input latents. - encoder_hidden_states: Text encoder outputs (None for single-stream). - timestep: Denoising timestep. - guidance: Guidance scale tensor (optional). - **kwargs: Additional model-specific arguments. - - Returns: - MagCacheContext with all information needed for caching. - """ - pass - - @abstractmethod - def get_head_block_input(self, context: MagCacheContext) -> torch.Tensor: - """ - Get the input to the first transformer block. - - Args: - context: MagCacheContext from create_context. - - Returns: - Tensor representing the input to the first block. - """ - pass - @abstractmethod def compute_residual( self, output: torch.Tensor, head_input: torch.Tensor, - context: MagCacheContext, ) -> torch.Tensor: - """ - Compute residual between output and head input. 
+ """Compute residual between block output and input. Args: output: Output from transformer blocks. head_input: Input to the first block. - context: MagCacheContext. Returns: Residual tensor for caching. @@ -162,9 +120,12 @@ def compute_residual( pass @abstractmethod - def apply_residual(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: - """ - Apply cached residual to hidden states. + def apply_residual( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + ) -> torch.Tensor: + """Apply cached residual to hidden states. Args: hidden_states: Current hidden states. @@ -181,8 +142,7 @@ def apply_residual_tuple( encoder_hidden_states: torch.Tensor, residual: tuple[torch.Tensor, torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Apply cached residual tuple to both hidden_states and encoder_hidden_states. + """Apply cached residual tuple to both hidden_states and encoder_hidden_states. Default implementation: add residuals separately. Override this method for models with specific residual application logic. @@ -198,6 +158,34 @@ def apply_residual_tuple( h_res, e_res = residual return hidden_states + h_res, encoder_hidden_states + e_res + @abstractmethod + def compute_calibration_metrics( + self, + current_residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + previous_residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None, + ) -> tuple[float, float, float]: + """Compute calibration metrics for mag_ratios generation. + + Args: + current_residual: Residual from the current step. + previous_residual: Residual from the previous step (None for first step). 
+ + Returns: + Tuple of (norm_ratio, norm_std, cos_dis): + - norm_ratio: Mean ratio of current to previous residual norms + - norm_std: Standard deviation of the norm ratios + - cos_dis: Mean cosine dissimilarity (1 - cosine_similarity) + """ + pass + + def get_calibration_metrics_names(self) -> tuple[str, str, str]: + """Return the names of calibration metrics for logging. + + Returns: + Tuple of metric names in order: (norm_ratio_name, norm_std_name, cos_dis_name) + """ + return ("norm_ratio", "norm_std", "cos_dis") + class FluxMagCacheStrategy(MagCacheStrategy): """ @@ -212,9 +200,42 @@ class FluxMagCacheStrategy(MagCacheStrategy): This strategy provides: - mag_ratios: Pre-computed magnitude ratios for Flux (28 steps) - - nearest_interp(): Interpolate mag_ratios to match num_inference_steps + - compute_calibration_metrics: FLUX-specific metric computation """ + FLUX_MAG_RATIOS = torch.tensor( + [ + 1.0, + 1.07313, + 1.21035, + 1.04432, + 1.06818, + 1.05547, + 1.0183, + 1.03405, + 1.02574, + 1.03042, + 1.02739, + 1.01955, + 1.01585, + 1.02439, + 1.01154, + 1.01377, + 1.00994, + 1.01444, + 1.00839, + 1.02269, + 1.0007, + 1.00714, + 1.00484, + 1.01381, + 1.00426, + 0.99764, + 1.00778, + 1.00233, + ] + ) + @property def transformer_type(self) -> str: return "FluxTransformer2DModel" @@ -224,124 +245,50 @@ def mag_ratios(self) -> torch.Tensor: """Return default mag_ratios for Flux model.""" return self.FLUX_MAG_RATIOS - FLUX_MAG_RATIOS = torch.tensor( - [1.0] - + [ - 1.21094, - 1.11719, - 1.07812, - 1.0625, - 1.03906, - 1.03125, - 1.03906, - 1.02344, - 1.03125, - 1.02344, - 0.98047, - 1.01562, - 1.00781, - 1.0, - 1.00781, - 1.0, - 1.00781, - 1.0, - 1.0, - 0.99609, - 0.99609, - 0.98047, - 0.98828, - 0.96484, - 0.95703, - 0.93359, - 0.89062, - ] - ) - - def create_context( + def compute_calibration_metrics( self, - module: torch.nn.Module, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor | None, - timestep: torch.Tensor, - guidance: torch.Tensor | 
None, - **kwargs, - ) -> MagCacheContext: - """Create context for Flux model.""" - temb = ( - module.time_text_embed(timestep, guidance) - if guidance is not None - else module.time_text_embed(timestep) - ) + current_residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + previous_residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None, + ) -> tuple[float, float, float]: + """Compute calibration metrics for FLUX model.""" + import torch.nn.functional as F - def run_transformer_blocks(): - h = hidden_states - e = encoder_hidden_states - for block in module.transformer_blocks: - e, h = block( - hidden_states=h, - encoder_hidden_states=e, - temb=temb, - ) - return e, h - - def run_single_transformer_blocks(h): - for block in module.single_transformer_blocks: - h = block( - hidden_states=h, - encoder_hidden_states=torch.zeros(1, 1, h.shape[-1], device=h.device, dtype=h.dtype), - temb=temb, - ) - return h - - def postprocess(e: torch.Tensor, h: torch.Tensor) -> Any: - h = torch.cat([e, h], dim=1) - h = module.norm_out(h, temb) - output = module.proj_out(h) - from diffusers.models.modeling_outputs import Transformer2DModelOutput - - return Transformer2DModelOutput(sample=output) - - return MagCacheContext( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - head_block_input=None, - run_transformer_blocks=run_transformer_blocks, - run_single_transformer_blocks=run_single_transformer_blocks, - postprocess=postprocess, - ) + if previous_residual is None: + return 1.0, 0.0, 0.0 + + curr_norm = torch.linalg.norm(current_residual.float(), dim=-1) + prev_norm = torch.linalg.norm(previous_residual.float(), dim=-1) + + ratio = (curr_norm / (prev_norm + 1e-8)).mean().item() + std = (curr_norm / (prev_norm + 1e-8)).std().item() + cos_dis = (1 - F.cosine_similarity(current_residual, previous_residual, dim=-1, eps=1e-8)).mean().item() + + return ratio, std, cos_dis - def get_head_block_input(self, context: MagCacheContext) -> 
torch.Tensor: - """Get input to the first transformer block.""" - return context.hidden_states + @staticmethod + def nearest_interp(src_array: torch.Tensor, target_length: int) -> torch.Tensor: + """Interpolate mag_ratios to target length using nearest neighbor.""" + src_length = len(src_array) + if target_length == 1: + return src_array[-1:] + + scale = (src_length - 1) / (target_length - 1) + grid = torch.arange(target_length, device=src_array.device, dtype=torch.float32) + mapped_indices = torch.round(grid * scale).long() + + return src_array[mapped_indices] def compute_residual( self, output: torch.Tensor, head_input: torch.Tensor, - context: MagCacheContext, ) -> torch.Tensor: - """Compute residual for Flux single transformer blocks. - - For single transformer blocks, the output is concatenated (encoder + decoder). - We need to extract encoder residual from the combined output. - """ - if context is not None: - encoder_hidden_states = context.encoder_hidden_states - if encoder_hidden_states is not None: - encoder_len = encoder_hidden_states.shape[1] - if isinstance(output, tuple): - out_e = output[0] - out_h = output[1] - else: - out_e = output[:, :encoder_len, :] - out_h = output[:, encoder_len:, :] - - e_res = out_e - encoder_hidden_states - h_res = out_h - head_input - return (e_res, h_res) - - return output + """Compute residual for Flux single transformer blocks.""" + if isinstance(output, tuple): + decoder_output = output[1] if len(output) > 1 else output[0] + else: + decoder_output = output - head_input + return decoder_output def apply_residual( self, @@ -357,37 +304,24 @@ def apply_residual_tuple( encoder_hidden_states: torch.Tensor, residual: tuple[torch.Tensor, torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: - """Apply residual tuple to both hidden_states and encoder_hidden_states. 
- - Flux architecture: - - encoder: 512 tokens - - decoder: 4096 tokens - - The residual tuple (e_res, h_res) comes from compute_residual: - - e_res: encoder residual (512 tokens) - - h_res: decoder residual (4096 tokens) - - We apply residuals separately to encoder_hidden_states and hidden_states. - """ - e_res, h_res = residual + """Apply residual tuple (for compatibility with hook interface).""" + if isinstance(residual, tuple): + decoder_residual = residual[1] + else: + decoder_residual = residual - output = hidden_states + h_res - enc_output = encoder_hidden_states + e_res + output = hidden_states + decoder_residual + enc_output = encoder_hidden_states return output, enc_output @staticmethod def register_blocks() -> None: - """Register vLLM-Omni Flux transformer blocks with TransformerBlockRegistry. - - Blocks: - - FluxTransformerBlock: dual-stream block - - FluxSingleTransformerBlock: single-stream block - """ + """Register vLLM-Omni Flux transformer blocks with TransformerBlockRegistry.""" try: from vllm_omni.diffusion.models.flux.flux_transformer import ( - FluxTransformerBlock, FluxSingleTransformerBlock, + FluxTransformerBlock, ) register_transformer_block(FluxTransformerBlock) @@ -395,19 +329,6 @@ def register_blocks() -> None: except ImportError: pass - @staticmethod - def nearest_interp(src_array: torch.Tensor, target_length: int) -> torch.Tensor: - """Interpolate mag_ratios to target length using nearest neighbor.""" - src_length = len(src_array) - if target_length == 1: - return src_array[-1:] - - scale = (src_length - 1) / (target_length - 1) - grid = torch.arange(target_length, device=src_array.device, dtype=torch.float32) - mapped_indices = torch.round(grid * scale).long() - - return src_array[mapped_indices] - class MagCacheStrategyRegistry: """Registry for MagCache strategies by transformer type.""" @@ -424,10 +345,7 @@ def get(cls, transformer_type: str) -> MagCacheStrategy: """Get strategy for given transformer type.""" if transformer_type not 
in cls._registry: available = list(cls._registry.keys()) - raise ValueError( - f"Unknown model type: '{transformer_type}'. " - f"Available types: {available}" - ) + raise ValueError(f"Unknown model type: '{transformer_type}'. Available types: {available}") return cls._registry[transformer_type] @classmethod diff --git a/vllm_omni/diffusion/cache/selector.py b/vllm_omni/diffusion/cache/selector.py index 857bb429507..98b88720384 100644 --- a/vllm_omni/diffusion/cache/selector.py +++ b/vllm_omni/diffusion/cache/selector.py @@ -17,7 +17,7 @@ def get_cache_backend(cache_backend: str | None, cache_config: Any) -> CacheBack cache_config: Cache configuration (dict or DiffusionCacheConfig instance). Returns: - Cache backend instance (CacheDiTBackend, TeaCacheBackend, or MagCacheBackend) + Cache backend instance (CacheDiTBackend, TeaCacheBackend, or MagCacheBackend) if cache_backend is set, None otherwise. Raises: @@ -43,6 +43,5 @@ def get_cache_backend(cache_backend: str | None, cache_config: Any) -> CacheBack return MagCacheBackend(cache_config) else: raise ValueError( - f"Unsupported cache backend: {cache_backend}. " - f"Supported: 'cache_dit', 'tea_cache', 'mag_cache'" + f"Unsupported cache backend: {cache_backend}. Supported: 'cache_dit', 'tea_cache', 'mag_cache'" ) diff --git a/vllm_omni/diffusion/hooks/base.py b/vllm_omni/diffusion/hooks/base.py index 18da45224c9..25eea172cf3 100644 --- a/vllm_omni/diffusion/hooks/base.py +++ b/vllm_omni/diffusion/hooks/base.py @@ -12,8 +12,8 @@ from dataclasses import dataclass from typing import Any -import torch import torch.nn as nn + from vllm_omni.logger import init_logger logger = init_logger(__name__) @@ -225,9 +225,7 @@ def register_hook(self, hook: ModelHook, name: str | None = None) -> str | None: return None if name in self._hooks: - raise ValueError( - f"Hook with name '{name}' already exists. Remove it first or use a different name." - ) + raise ValueError(f"Hook with name '{name}' already exists. 
Remove it first or use a different name.") hook.initialize_hook(self.module) diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index d0680a32193..18354af2467 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -175,7 +175,7 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu default=None, help="JSON string of cache configuration. " "TeaCache: '{\"rel_l1_thresh\": 0.2}'. " - "MagCache: '{\"threshold\": 0.06, \"max_skip_steps\": 3, \"mag_ratios\": [1.0, ...]}'. " + 'MagCache: \'{"threshold": 0.06, "max_skip_steps": 3, "mag_ratios": [1.0, ...]}\'. ' "Calibration mode: add '\"calibrate\": true'", ) omni_config_group.add_argument( From 2cc04045977159770762d6e7805531aa5006b238 Mon Sep 17 00:00:00 2001 From: Lancer Date: Mon, 9 Feb 2026 20:48:58 +0800 Subject: [PATCH 03/11] upd Signed-off-by: Lancer --- vllm_omni/diffusion/cache/magcache/backend.py | 9 +++------ vllm_omni/diffusion/cache/magcache/hook.py | 4 ++-- vllm_omni/diffusion/hooks/base.py | 17 ++--------------- 3 files changed, 7 insertions(+), 23 deletions(-) diff --git a/vllm_omni/diffusion/cache/magcache/backend.py b/vllm_omni/diffusion/cache/magcache/backend.py index af29e67a33e..2dcaffafc3b 100644 --- a/vllm_omni/diffusion/cache/magcache/backend.py +++ b/vllm_omni/diffusion/cache/magcache/backend.py @@ -16,10 +16,8 @@ from vllm_omni.diffusion.cache.base import CacheBackend from vllm_omni.diffusion.cache.magcache.config import MagCacheConfig from vllm_omni.diffusion.cache.magcache.hook import ( - MagCacheState, apply_mag_cache_hook, ) -from vllm_omni.diffusion.hooks.base import StateManager logger = init_logger(__name__) @@ -165,8 +163,6 @@ def refresh(self, pipeline: Any, num_inference_steps: int) -> None: self.enable(pipeline) return - state_manager = StateManager(MagCacheState, (), {}) - blocks_with_hooks = [] for name, submodule in transformer.named_children(): @@ -185,8 +181,9 @@ def refresh(self, 
pipeline: Any, num_inference_steps: int) -> None: for name, block, registry in blocks_with_hooks: if hasattr(block, "do_true_cfg"): delattr(block, "do_true_cfg") - - state_manager.reset() + for hook in registry._hooks: + if hasattr(hook, "reset_state"): + hook.reset_state(block) def is_enabled(self) -> bool: """Check if MagCache is enabled. diff --git a/vllm_omni/diffusion/cache/magcache/hook.py b/vllm_omni/diffusion/cache/magcache/hook.py index e483b177c9c..168813614c0 100644 --- a/vllm_omni/diffusion/cache/magcache/hook.py +++ b/vllm_omni/diffusion/cache/magcache/hook.py @@ -447,7 +447,7 @@ def _apply_mag_cache_head_hook( registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) hook = MagCacheHeadHook(state_manager, config, strategy) - registry.register_hook(hook, _MAG_CACHE_LEADER_BLOCK_HOOK) + registry.register_hook(_MAG_CACHE_LEADER_BLOCK_HOOK, hook) def _apply_mag_cache_block_hook( @@ -463,4 +463,4 @@ def _apply_mag_cache_block_hook( registry.remove_hook(_MAG_CACHE_BLOCK_HOOK) hook = MagCacheBlockHook(state_manager, is_tail, config, strategy) - registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK) + registry.register_hook(_MAG_CACHE_BLOCK_HOOK, hook) diff --git a/vllm_omni/diffusion/hooks/base.py b/vllm_omni/diffusion/hooks/base.py index 25eea172cf3..0ec01dff236 100644 --- a/vllm_omni/diffusion/hooks/base.py +++ b/vllm_omni/diffusion/hooks/base.py @@ -204,26 +204,13 @@ def check_if_exists_or_initialize(cls, module: nn.Module) -> HookRegistry: """ return cls.get_or_create(module) - def register_hook(self, hook: ModelHook, name: str | None = None) -> str | None: + def register_hook(self, name: str, hook: ModelHook) -> None: """Register a hook with the given name. - This method follows the diffusers API convention where the hook object - comes first, followed by an optional name. If no name is provided, - uses hook._HOOK_NAME. - Args: + name: Unique name for this hook. hook: The hook instance to register. - name: Optional unique name for this hook. 
If not provided, - uses hook._HOOK_NAME. - - Returns: - The name the hook was registered under, or None if registration failed. """ - if name is None: - name = getattr(hook, "_HOOK_NAME", None) - if name is None: - return None - if name in self._hooks: raise ValueError(f"Hook with name '{name}' already exists. Remove it first or use a different name.") From a29533b30c3898a82d243be568192d61a110d198 Mon Sep 17 00:00:00 2001 From: Lancer Date: Tue, 10 Feb 2026 00:15:58 +0800 Subject: [PATCH 04/11] upd Signed-off-by: Lancer --- vllm_omni/diffusion/cache/__init__.py | 4 ++-- vllm_omni/diffusion/cache/magcache/backend.py | 2 +- .../diffusion/cache/magcache/strategy.py | 2 +- vllm_omni/diffusion/cache/selector.py | 9 +++------ vllm_omni/diffusion/hooks/base.py | 14 ------------- .../longcat_image_transformer.py | 20 +++---------------- 6 files changed, 10 insertions(+), 41 deletions(-) diff --git a/vllm_omni/diffusion/cache/__init__.py b/vllm_omni/diffusion/cache/__init__.py index dc544ea73e9..d6403f0fa7f 100644 --- a/vllm_omni/diffusion/cache/__init__.py +++ b/vllm_omni/diffusion/cache/__init__.py @@ -12,17 +12,17 @@ """ from vllm_omni.diffusion.cache.base import CacheBackend -from vllm_omni.diffusion.cache.selector import get_cache_backend from vllm_omni.diffusion.cache.teacache import ( CacheContext, TeaCacheConfig, apply_teacache_hook, ) +from vllm_omni.diffusion.cache.teacache.backend import TeaCacheBackend __all__ = [ "CacheBackend", "CacheContext", - "get_cache_backend", + "TeaCacheBackend", "TeaCacheConfig", "apply_teacache_hook", ] diff --git a/vllm_omni/diffusion/cache/magcache/backend.py b/vllm_omni/diffusion/cache/magcache/backend.py index 2dcaffafc3b..e389ab5b1ab 100644 --- a/vllm_omni/diffusion/cache/magcache/backend.py +++ b/vllm_omni/diffusion/cache/magcache/backend.py @@ -181,7 +181,7 @@ def refresh(self, pipeline: Any, num_inference_steps: int) -> None: for name, block, registry in blocks_with_hooks: if hasattr(block, "do_true_cfg"): delattr(block, 
"do_true_cfg") - for hook in registry._hooks: + for hook in registry._hooks.values(): if hasattr(hook, "reset_state"): hook.reset_state(block) diff --git a/vllm_omni/diffusion/cache/magcache/strategy.py b/vllm_omni/diffusion/cache/magcache/strategy.py index b9140514fe0..298e3756a16 100644 --- a/vllm_omni/diffusion/cache/magcache/strategy.py +++ b/vllm_omni/diffusion/cache/magcache/strategy.py @@ -24,7 +24,7 @@ def register_transformer_block( - model_class, + model_class: type, return_hidden_states_index: int = 1, return_encoder_hidden_states_index: int = 0, ) -> None: diff --git a/vllm_omni/diffusion/cache/selector.py b/vllm_omni/diffusion/cache/selector.py index 98b88720384..e6c2e70e318 100644 --- a/vllm_omni/diffusion/cache/selector.py +++ b/vllm_omni/diffusion/cache/selector.py @@ -1,6 +1,9 @@ from typing import Any from vllm_omni.diffusion.cache.base import CacheBackend +from vllm_omni.diffusion.cache.cache_dit_backend import CacheDiTBackend +from vllm_omni.diffusion.cache.magcache.backend import MagCacheBackend +from vllm_omni.diffusion.cache.teacache.backend import TeaCacheBackend from vllm_omni.diffusion.data import DiffusionCacheConfig @@ -30,16 +33,10 @@ def get_cache_backend(cache_backend: str | None, cache_config: Any) -> CacheBack cache_config = DiffusionCacheConfig.from_dict(cache_config) if cache_backend == "cache_dit": - from vllm_omni.diffusion.cache.cache_dit_backend import CacheDiTBackend - return CacheDiTBackend(cache_config) elif cache_backend == "tea_cache": - from vllm_omni.diffusion.cache.teacache.backend import TeaCacheBackend - return TeaCacheBackend(cache_config) elif cache_backend == "mag_cache": - from vllm_omni.diffusion.cache.magcache.backend import MagCacheBackend - return MagCacheBackend(cache_config) else: raise ValueError( diff --git a/vllm_omni/diffusion/hooks/base.py b/vllm_omni/diffusion/hooks/base.py index 0ec01dff236..91d212c2b2f 100644 --- a/vllm_omni/diffusion/hooks/base.py +++ b/vllm_omni/diffusion/hooks/base.py @@ -14,10 
+14,6 @@ import torch.nn as nn -from vllm_omni.logger import init_logger - -logger = init_logger(__name__) - class BaseState: """Base class for hook state containers.""" @@ -159,15 +155,6 @@ def __init__(self, module: nn.Module): self.module = module self._hooks: dict[str, ModelHook] = {} - def __getstate__(self): - """Handle pickling - preserve hooks.""" - return {"module": self.module, "_hooks": self._hooks} - - def __setstate__(self, state): - """Handle unpickling - restore hooks.""" - self.module = state["module"] - self._hooks = state["_hooks"] - @classmethod def get_or_create(cls, module: nn.Module) -> HookRegistry: """Get existing registry or create a new one for the module. @@ -228,7 +215,6 @@ def __init__(self, orig_forward): hook.fn_ref = _FnRef(original_forward) self._hooks[name] = hook - return name def remove_hook(self, name: str) -> None: """Remove a hook by name. diff --git a/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py index 2f6d54aef71..ee6317372ce 100644 --- a/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py +++ b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py @@ -209,9 +209,9 @@ def forward( encoder_query = self.norm_added_q(encoder_query) encoder_key = self.norm_added_k(encoder_key) - # Check if SP is enabled from forward context (set by LongCatImageTransformer2DModel) + # Check if SP is enabled forward_ctx = get_forward_context() - sp_size = forward_ctx.sequence_parallel_size + sp_size = self.parallel_config.sequence_parallel_size use_sp_joint_attention = sp_size > 1 and not forward_ctx.split_text_embed_in_sp if use_sp_joint_attention: @@ -249,7 +249,7 @@ def forward( # Check if SP is enabled and we have text_seq_len info forward_ctx = get_forward_context() - sp_size = forward_ctx.sequence_parallel_size + sp_size = self.parallel_config.sequence_parallel_size text_seq_len = kwargs.get("text_seq_len", None) 
use_sp_single_stream = sp_size > 1 and not forward_ctx.split_text_embed_in_sp and text_seq_len is not None @@ -590,22 +590,8 @@ def forward( if sp_size > 1: sp_world_size = get_sequence_parallel_world_size() sp_rank = get_sequence_parallel_rank() - original_shape = hidden_states.shape hidden_states = torch.chunk(hidden_states, sp_world_size, dim=1)[sp_rank] - # LongCat uses dual-stream (text + image) with joint attention - # Text embeddings should be replicated across SP ranks for correctness get_forward_context().split_text_embed_in_sp = False - # Debug log (only first forward) - if not hasattr(self, "_sp_forward_logged"): - self._sp_forward_logged = True - logger.info( - f"[LongCat Transformer] SP enabled: sp_size={sp_size}, world_size={sp_world_size}, " - f"rank={sp_rank}, original_shape={original_shape}, chunked_shape={hidden_states.shape}" - ) - else: - if not hasattr(self, "_sp_forward_logged"): - self._sp_forward_logged = True - logger.info(f"[LongCat Transformer] SP disabled: sp_size={sp_size}") hidden_states = self.x_embedder(hidden_states) From 214c0324555ee607bbf781a3e2597517f2598382 Mon Sep 17 00:00:00 2001 From: Lancer Date: Wed, 11 Feb 2026 14:20:55 +0800 Subject: [PATCH 05/11] upd Signed-off-by: Lancer --- tests/diffusion/cache/test_cache_backends.py | 106 +++++++++++ .../diffusion/cache/magcache/__init__.py | 6 +- vllm_omni/diffusion/cache/magcache/backend.py | 85 ++++----- vllm_omni/diffusion/cache/magcache/config.py | 12 +- vllm_omni/diffusion/cache/magcache/hook.py | 57 ++++-- .../diffusion/cache/magcache/strategy.py | 174 ++++++++++-------- 6 files changed, 294 insertions(+), 146 deletions(-) diff --git a/tests/diffusion/cache/test_cache_backends.py b/tests/diffusion/cache/test_cache_backends.py index a9312f4b1ad..872b062257b 100644 --- a/tests/diffusion/cache/test_cache_backends.py +++ b/tests/diffusion/cache/test_cache_backends.py @@ -18,6 +18,7 @@ from vllm_omni.diffusion.cache.cache_dit_backend import ( CacheDiTBackend, ) +from 
vllm_omni.diffusion.cache.magcache.backend import MagCacheBackend from vllm_omni.diffusion.cache.selector import get_cache_backend from vllm_omni.diffusion.cache.teacache.backend import TeaCacheBackend from vllm_omni.diffusion.data import DiffusionCacheConfig @@ -221,3 +222,108 @@ def test_get_cache_backend_invalid(self): """Test getting invalid backend raises error.""" with pytest.raises(ValueError, match="Unsupported cache backend"): get_cache_backend("invalid_backend", {}) + + +class TestMagCacheBackend: + """Test MagCacheBackend implementation.""" + + from vllm_omni.diffusion.cache.magcache.backend import MagCacheBackend + + def test_init(self): + """Test initialization.""" + config = DiffusionCacheConfig(threshold=0.1, max_skip_steps=2, calibrate=True) + backend = MagCacheBackend(config) + assert backend.config.threshold == 0.1 + assert backend.config.max_skip_steps == 2 + assert backend.enabled is False + + @patch("vllm_omni.diffusion.cache.magcache.backend.apply_mag_cache_hook") + def test_enable(self, mock_apply_hook): + """Test enabling MagCache on pipeline.""" + mock_pipeline = Mock() + mock_pipeline.__class__.__name__ = "FluxPipeline" + mock_transformer = Mock() + mock_transformer.__class__.__name__ = "FluxTransformer2DModel" + mock_pipeline.transformer = mock_transformer + + mock_ratios = [1.0] * 28 + config = DiffusionCacheConfig( + mag_ratios=mock_ratios, + num_inference_steps=28, + threshold=0.06, + max_skip_steps=3, + retention_ratio=0.2, + ) + backend = MagCacheBackend(config) + backend.enable(mock_pipeline) + + assert backend.enabled is True + mock_apply_hook.assert_called_once() + + call_args = mock_apply_hook.call_args + assert call_args[0][0] == mock_transformer + + @patch("vllm_omni.diffusion.cache.magcache.backend.apply_mag_cache_hook") + def test_enable_with_calibration(self, mock_apply_hook): + """Test enabling MagCache in calibration mode.""" + mock_pipeline = Mock() + mock_pipeline.__class__.__name__ = "FluxPipeline" + mock_transformer = 
Mock() + mock_transformer.__class__.__name__ = "FluxTransformer2DModel" + mock_pipeline.transformer = mock_transformer + + config = DiffusionCacheConfig( + calibrate=True, + num_inference_steps=28, + threshold=0.06, + max_skip_steps=3, + retention_ratio=0.2, + ) + backend = MagCacheBackend(config) + backend.enable(mock_pipeline) + + assert backend.enabled is True + mock_apply_hook.assert_called_once() + + def test_refresh(self): + """Test refreshing MagCache state calls enable when not registered.""" + mock_pipeline = Mock() + mock_pipeline.__class__.__name__ = "FluxPipeline" + mock_transformer = Mock() + mock_transformer.__class__.__name__ = "FluxTransformer2DModel" + mock_pipeline.transformer = mock_transformer + + mock_transformer.named_children = Mock(return_value=[]) + + mock_ratios = [1.0] * 28 + config = DiffusionCacheConfig( + mag_ratios=mock_ratios, + num_inference_steps=28, + ) + backend = MagCacheBackend(config) + + assert backend._registered is False + backend.refresh(mock_pipeline, num_inference_steps=50) + assert backend._registered is True + + def test_is_enabled(self): + """Test is_enabled method.""" + mock_ratios = [1.0] * 28 + config = DiffusionCacheConfig(mag_ratios=mock_ratios) + backend = MagCacheBackend(config) + assert backend.is_enabled() is False + + def test_get_mag_cache_backend(self): + """Test getting MagCache backend via selector.""" + mock_ratios = [1.0] * 28 + config_dict = { + "mag_ratios": mock_ratios, + "num_inference_steps": 28, + "threshold": 0.06, + "max_skip_steps": 3, + "retention_ratio": 0.2, + } + backend = get_cache_backend("mag_cache", config_dict) + assert backend is not None + assert isinstance(backend, MagCacheBackend) + assert backend.config.threshold == 0.06 diff --git a/vllm_omni/diffusion/cache/magcache/__init__.py b/vllm_omni/diffusion/cache/magcache/__init__.py index 2fb03ec3293..835a2a306d4 100644 --- a/vllm_omni/diffusion/cache/magcache/__init__.py +++ b/vllm_omni/diffusion/cache/magcache/__init__.py @@ -1,7 
+1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm_omni.diffusion.cache.magcache.backend import CUSTOM_MAG_CACHE_ENABLERS from vllm_omni.diffusion.cache.magcache.config import MagCacheConfig from vllm_omni.diffusion.cache.magcache.hook import ( MagCacheBlockHook, @@ -14,10 +13,11 @@ MagCacheContext, MagCacheStrategy, MagCacheStrategyRegistry, + get_strategy, + register_strategy, ) __all__ = [ - "CUSTOM_MAG_CACHE_ENABLERS", "FluxMagCacheStrategy", "MagCacheBlockHook", "MagCacheConfig", @@ -27,4 +27,6 @@ "MagCacheStrategy", "MagCacheStrategyRegistry", "apply_mag_cache_hook", + "get_strategy", + "register_strategy", ] diff --git a/vllm_omni/diffusion/cache/magcache/backend.py b/vllm_omni/diffusion/cache/magcache/backend.py index e389ab5b1ab..bc3149b11c8 100644 --- a/vllm_omni/diffusion/cache/magcache/backend.py +++ b/vllm_omni/diffusion/cache/magcache/backend.py @@ -18,25 +18,13 @@ from vllm_omni.diffusion.cache.magcache.hook import ( apply_mag_cache_hook, ) +from vllm_omni.diffusion.cache.magcache.strategy import ( + get_strategy, +) +from vllm_omni.diffusion.data import DiffusionCacheConfig logger = init_logger(__name__) -CUSTOM_MAG_CACHE_ENABLERS = {} - - -def _register_pipeline_magcache( - pipeline: Any, - magcache_config: MagCacheConfig, -) -> None: - """Apply MagCache hooks to transformer using pre-built MagCacheConfig. - - Args: - pipeline: Diffusion pipeline instance. - magcache_config: Pre-configured MagCacheConfig with all parameters set. 
- """ - transformer = pipeline.transformer - apply_mag_cache_hook(transformer, magcache_config) - class MagCacheBackend(CacheBackend): """ @@ -65,6 +53,12 @@ class MagCacheBackend(CacheBackend): >>> backend.refresh(pipeline, num_inference_steps=50) """ + def __init__(self, config: DiffusionCacheConfig): + super().__init__(config) + self._registered = False + self._magcache_config: MagCacheConfig | None = None + self._transformer_id: int | None = None + def enable(self, pipeline: Any) -> None: """Enable MagCache on transformer using hooks. @@ -76,35 +70,33 @@ def enable(self, pipeline: Any) -> None: - transformer: pipeline.transformer - transformer_type: pipeline.transformer.__class__.__name__ """ - from vllm_omni.diffusion.cache.magcache.strategy import ( - MagCacheStrategyRegistry, - ) - - pipeline_type = pipeline.__class__.__name__ transformer = pipeline.transformer transformer_type = transformer.__class__.__name__ - num_inference_steps = self.config.num_inference_steps - if num_inference_steps is None: - num_inference_steps = 28 + num_inference_steps = self.config.num_inference_steps or 28 mag_ratios = self.config.mag_ratios - if mag_ratios is None: - strategy = MagCacheStrategyRegistry.get_if_exists(transformer_type) - if strategy is not None: - original_ratios = strategy.mag_ratios + strategy = None + + if mag_ratios is None and not self.config.calibrate: + strategy = get_strategy(transformer_type) + original_ratios = strategy.mag_ratios + + if len(original_ratios) != num_inference_steps and hasattr(strategy, "nearest_interp"): + mag_ratios = strategy.nearest_interp(original_ratios, num_inference_steps) + logger.info( + f"MagCache: Interpolated mag_ratios from {len(original_ratios)} to {num_inference_steps} steps" + ) + else: + mag_ratios = original_ratios if len(original_ratios) != num_inference_steps: - if hasattr(strategy, "nearest_interp"): - mag_ratios = strategy.nearest_interp(original_ratios, num_inference_steps) - logger.info( - f"MagCache: 
Interpolated mag_ratios from {len(original_ratios)} " - f"to {num_inference_steps} steps" - ) - else: - mag_ratios = original_ratios - else: - mag_ratios = original_ratios - logger.info(f"MagCache: Using default mag_ratios from strategy '{transformer_type}'") + logger.warning( + f"MagCache: mag_ratios length ({len(original_ratios)}) != " + f"num_inference_steps ({num_inference_steps}), " + f"this may cause unexpected behavior" + ) + + logger.info(f"MagCache: Using mag_ratios from {type(strategy).__name__}") if mag_ratios is None and not self.config.calibrate: raise ValueError( @@ -112,7 +104,7 @@ def enable(self, pipeline: Any) -> None: f"For {transformer_type}, you need to provide mag_ratios or run in calibrate mode." ) - magcache_config = MagCacheConfig( + self._magcache_config = MagCacheConfig( transformer_type=transformer_type, threshold=self.config.threshold, max_skip_steps=self.config.max_skip_steps, @@ -121,16 +113,9 @@ def enable(self, pipeline: Any) -> None: calibrate=self.config.calibrate, mag_ratios=mag_ratios if not self.config.calibrate else None, ) - - self._registered = False - self._magcache_config = magcache_config self._transformer_id = id(transformer) - if pipeline_type in CUSTOM_MAG_CACHE_ENABLERS: - logger.info(f"Using custom MagCache enabler for model: {pipeline_type}") - CUSTOM_MAG_CACHE_ENABLERS[pipeline_type](pipeline, magcache_config) - else: - _register_pipeline_magcache(pipeline, magcache_config) + apply_mag_cache_hook(transformer, self._magcache_config, strategy=strategy) self._registered = True self.enabled = True @@ -175,12 +160,10 @@ def refresh(self, pipeline: Any, num_inference_steps: int) -> None: if not blocks_with_hooks: logger.warning("No hooks found on transformer blocks, re-registering") - _register_pipeline_magcache(pipeline, self._magcache_config) + apply_mag_cache_hook(transformer, self._magcache_config) self._transformer_id = current_transformer_id else: for name, block, registry in blocks_with_hooks: - if hasattr(block, 
"do_true_cfg"): - delattr(block, "do_true_cfg") for hook in registry._hooks.values(): if hasattr(hook, "reset_state"): hook.reset_state(block) diff --git a/vllm_omni/diffusion/cache/magcache/config.py b/vllm_omni/diffusion/cache/magcache/config.py index 229783d3eaa..7d6e6fd88d3 100644 --- a/vllm_omni/diffusion/cache/magcache/config.py +++ b/vllm_omni/diffusion/cache/magcache/config.py @@ -21,11 +21,11 @@ class MagCacheConfig: Args: threshold: Accumulated error threshold. Higher = more aggressive skipping (faster, lower quality). - Default: 0.06 + Default: 0.24 max_skip_steps: Max consecutive skip steps (K). - Default: 3 + Default: 5 retention_ratio: Fraction of initial steps where skipping is disabled (stability). - Default: 0.2 + Default: 0.1 num_inference_steps: Total inference steps. Required for retention step calculation. Default: 28 mag_ratios: Pre-computed magnitude ratios per step. Calibrate or use strategy defaults. @@ -36,9 +36,9 @@ class MagCacheConfig: Default: "FluxTransformer2DModel" """ - threshold: float = 0.06 - max_skip_steps: int = 3 - retention_ratio: float = 0.2 + threshold: float = 0.24 + max_skip_steps: int = 5 + retention_ratio: float = 0.1 num_inference_steps: int = 28 mag_ratios: torch.Tensor | list[float] | None = None calibrate: bool = False diff --git a/vllm_omni/diffusion/cache/magcache/hook.py b/vllm_omni/diffusion/cache/magcache/hook.py index 168813614c0..33926468698 100644 --- a/vllm_omni/diffusion/cache/magcache/hook.py +++ b/vllm_omni/diffusion/cache/magcache/hook.py @@ -24,7 +24,7 @@ import torch import torch.nn.functional as F -from diffusers.hooks._helpers import TransformerBlockRegistry +from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry from diffusers.utils.torch_utils import unwrap_module from vllm_omni.diffusion.cache.magcache.config import MagCacheConfig @@ -53,14 +53,26 @@ def __init__(self, state_manager: StateManager, config: MagCacheConfig, strategy def initialize_hook(self, module: 
torch.nn.Module) -> torch.nn.Module: unwrapped_module = unwrap_module(module) - self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__) - self.state_manager.set_context("inference") + block_class = unwrapped_module.__class__ + + try: + self._metadata = TransformerBlockRegistry.get(block_class) + except ValueError: + TransformerBlockRegistry.register( + model_class=block_class, + metadata=TransformerBlockMetadata( + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + self._metadata = TransformerBlockRegistry.get(block_class) + return module @torch.compiler.disable def new_forward(self, module: torch.nn.Module, *args, **kwargs): if self.state_manager._current_context is None: - self.state_manager.set_context("inference") + self.state_manager.set_context("magcache") if hasattr(self._metadata, "hidden_states_argument_name"): arg_name = self._metadata.hidden_states_argument_name @@ -166,7 +178,9 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): return self.log_cache_hit(state, output, None) else: output = self.fn_ref.original_forward(*args, **kwargs) - return self.log_cache_miss(state, output) + result = self.log_cache_miss(state, output) + + return result def log_cache_hit(self, state: MagCacheState, output, ret): step = state.step_index @@ -225,7 +239,20 @@ def __init__( def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: unwrapped_module = unwrap_module(module) - self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__) + block_class = unwrapped_module.__class__ + + try: + self._metadata = TransformerBlockRegistry.get(block_class) + except ValueError: + TransformerBlockRegistry.register( + model_class=block_class, + metadata=TransformerBlockMetadata( + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + self._metadata = TransformerBlockRegistry.get(block_class) + return module def reset_state(self, module: torch.nn.Module) -> 
torch.nn.Module: @@ -381,26 +408,34 @@ def advance_step(self, state: MagCacheState) -> None: state._is_first_step = True -def apply_mag_cache_hook(module: torch.nn.Module, config: MagCacheConfig) -> None: +def apply_mag_cache_hook( + module: torch.nn.Module, + config: MagCacheConfig, + strategy: MagCacheStrategy | None = None, +) -> None: """Apply MagCache optimization to a transformer module. Args: module: Transformer model to optimize (e.g., FluxTransformer2DModel) config: MagCacheConfig specifying caching parameters + strategy: Optional strategy to use. If None, will be looked up from registry. """ HookRegistry.check_if_exists_or_initialize(module) transformer_type = config.transformer_type - strategy = MagCacheStrategyRegistry.get_if_exists(transformer_type) + if strategy is None: + strategy = MagCacheStrategyRegistry.get_if_exists(transformer_type) + if strategy is None: logger.warning( f"MagCache: No strategy found for '{transformer_type}'. " f"Using default behavior. Available strategies: {list(MagCacheStrategyRegistry._registry.keys())}" ) else: - logger.info(f"MagCache: Using strategy '{transformer_type}' for optimization") - if hasattr(strategy, "register_blocks"): - strategy.register_blocks() + strategy_name = type(strategy).__name__ + logger.info(f"MagCache: Applying {strategy_name} for '{transformer_type}'") + if hasattr(type(strategy), "register_blocks"): + type(strategy).register_blocks() state_manager = StateManager(MagCacheState, (), {}) remaining_blocks = [] diff --git a/vllm_omni/diffusion/cache/magcache/strategy.py b/vllm_omni/diffusion/cache/magcache/strategy.py index 298e3756a16..df780aadcf5 100644 --- a/vllm_omni/diffusion/cache/magcache/strategy.py +++ b/vllm_omni/diffusion/cache/magcache/strategy.py @@ -83,12 +83,6 @@ class MagCacheStrategy(ABC): Implement this class to add support for new model architectures. 
""" - @property - @abstractmethod - def transformer_type(self) -> str: - """Returns the transformer class name this strategy supports.""" - pass - @property @abstractmethod def mag_ratios(self) -> torch.Tensor: @@ -102,7 +96,6 @@ def mag_ratios(self) -> torch.Tensor: """ pass - @abstractmethod def compute_residual( self, output: torch.Tensor, @@ -110,6 +103,9 @@ def compute_residual( ) -> torch.Tensor: """Compute residual between block output and input. + Default implementation: output - head_input. + Override this method for models with non-standard output formats. + Args: output: Output from transformer blocks. head_input: Input to the first block. @@ -117,9 +113,8 @@ def compute_residual( Returns: Residual tensor for caching. """ - pass + return output - head_input - @abstractmethod def apply_residual( self, hidden_states: torch.Tensor, @@ -127,6 +122,9 @@ def apply_residual( ) -> torch.Tensor: """Apply cached residual to hidden states. + Default implementation: add residual to hidden_states. + This works for most model architectures. + Args: hidden_states: Current hidden states. residual: Cached residual to apply. @@ -134,7 +132,7 @@ def apply_residual( Returns: Hidden states with residual added. """ - pass + return hidden_states + residual def apply_residual_tuple( self, @@ -158,7 +156,6 @@ def apply_residual_tuple( h_res, e_res = residual return hidden_states + h_res, encoder_hidden_states + e_res - @abstractmethod def compute_calibration_metrics( self, current_residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor], @@ -166,6 +163,9 @@ def compute_calibration_metrics( ) -> tuple[float, float, float]: """Compute calibration metrics for mag_ratios generation. + Default implementation computes norm ratios and cosine dissimilarity. + Override this method for models with custom metric computation. + Args: current_residual: Residual from the current step. previous_residual: Residual from the previous step (None for first step). 
@@ -176,7 +176,19 @@ def compute_calibration_metrics( - norm_std: Standard deviation of the norm ratios - cos_dis: Mean cosine dissimilarity (1 - cosine_similarity) """ - pass + import torch.nn.functional as F + + if previous_residual is None: + return 1.0, 0.0, 0.0 + + curr_norm = torch.linalg.norm(current_residual.float(), dim=-1) + prev_norm = torch.linalg.norm(previous_residual.float(), dim=-1) + + ratio = (curr_norm / (prev_norm + 1e-8)).mean().item() + std = (curr_norm / (prev_norm + 1e-8)).std().item() + cos_dis = (1 - F.cosine_similarity(current_residual, previous_residual, dim=-1, eps=1e-8)).mean().item() + + return ratio, std, cos_dis def get_calibration_metrics_names(self) -> tuple[str, str, str]: """Return the names of calibration metrics for logging. @@ -200,7 +212,8 @@ class FluxMagCacheStrategy(MagCacheStrategy): This strategy provides: - mag_ratios: Pre-computed magnitude ratios for Flux (28 steps) - - compute_calibration_metrics: FLUX-specific metric computation + - compute_residual: Handles tuple output format + - apply_residual_tuple: Handles decoder residual only """ FLUX_MAG_RATIOS = torch.tensor( @@ -236,75 +249,38 @@ class FluxMagCacheStrategy(MagCacheStrategy): ] ) - @property - def transformer_type(self) -> str: - return "FluxTransformer2DModel" - @property def mag_ratios(self) -> torch.Tensor: """Return default mag_ratios for Flux model.""" return self.FLUX_MAG_RATIOS - def compute_calibration_metrics( - self, - current_residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor], - previous_residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None, - ) -> tuple[float, float, float]: - """Compute calibration metrics for FLUX model.""" - import torch.nn.functional as F - - if previous_residual is None: - return 1.0, 0.0, 0.0 - - curr_norm = torch.linalg.norm(current_residual.float(), dim=-1) - prev_norm = torch.linalg.norm(previous_residual.float(), dim=-1) - - ratio = (curr_norm / (prev_norm + 1e-8)).mean().item() - std = (curr_norm 
/ (prev_norm + 1e-8)).std().item() - cos_dis = (1 - F.cosine_similarity(current_residual, previous_residual, dim=-1, eps=1e-8)).mean().item() - - return ratio, std, cos_dis - - @staticmethod - def nearest_interp(src_array: torch.Tensor, target_length: int) -> torch.Tensor: - """Interpolate mag_ratios to target length using nearest neighbor.""" - src_length = len(src_array) - if target_length == 1: - return src_array[-1:] - - scale = (src_length - 1) / (target_length - 1) - grid = torch.arange(target_length, device=src_array.device, dtype=torch.float32) - mapped_indices = torch.round(grid * scale).long() - - return src_array[mapped_indices] - def compute_residual( self, output: torch.Tensor, head_input: torch.Tensor, ) -> torch.Tensor: - """Compute residual for Flux single transformer blocks.""" + """Compute residual for Flux output format (tuple or single tensor). + + Flux single transformer blocks return a tuple, so we extract + the decoder output (index 1) before computing residual. + """ if isinstance(output, tuple): decoder_output = output[1] if len(output) > 1 else output[0] else: decoder_output = output - head_input return decoder_output - def apply_residual( - self, - hidden_states: torch.Tensor, - residual: torch.Tensor, - ) -> torch.Tensor: - """Apply residual by adding to hidden states.""" - return hidden_states + residual - def apply_residual_tuple( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, residual: tuple[torch.Tensor, torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: - """Apply residual tuple (for compatibility with hook interface).""" + """Apply residual tuple for Flux - only add decoder residual. + + Flux has separate image and text processing, so the residual + is only applied to the decoder (image) branch. 
+ """ if isinstance(residual, tuple): decoder_residual = residual[1] else: @@ -316,18 +292,17 @@ def apply_residual_tuple( return output, enc_output @staticmethod - def register_blocks() -> None: - """Register vLLM-Omni Flux transformer blocks with TransformerBlockRegistry.""" - try: - from vllm_omni.diffusion.models.flux.flux_transformer import ( - FluxSingleTransformerBlock, - FluxTransformerBlock, - ) + def nearest_interp(src_array: torch.Tensor, target_length: int) -> torch.Tensor: + """Interpolate mag_ratios to target length using nearest neighbor.""" + src_length = len(src_array) + if target_length == 1: + return src_array[-1:] - register_transformer_block(FluxTransformerBlock) - register_transformer_block(FluxSingleTransformerBlock) - except ImportError: - pass + scale = (src_length - 1) / (target_length - 1) + grid = torch.arange(target_length, device=src_array.device, dtype=torch.float32) + mapped_indices = torch.round(grid * scale).long() + + return src_array[mapped_indices] class MagCacheStrategyRegistry: @@ -336,9 +311,14 @@ class MagCacheStrategyRegistry: _registry: dict[str, MagCacheStrategy] = {} @classmethod - def register(cls, strategy: MagCacheStrategy) -> None: - """Register a strategy.""" - cls._registry[strategy.transformer_type] = strategy + def register(cls, name: str, strategy: MagCacheStrategy) -> None: + """Register a strategy with explicit name. 
+ + Args: + name: Transformer model type identifier (e.g., "FluxTransformer2DModel") + strategy: MagCacheStrategy instance + """ + cls._registry[name] = strategy @classmethod def get(cls, transformer_type: str) -> MagCacheStrategy: @@ -354,5 +334,47 @@ def get_if_exists(cls, transformer_type: str) -> MagCacheStrategy | None: return cls._registry.get(transformer_type) -# Register default strategies -MagCacheStrategyRegistry.register(FluxMagCacheStrategy()) +MagCacheStrategyRegistry.register("FluxTransformer2DModel", FluxMagCacheStrategy()) + + +def register_strategy( + transformer_cls_name: str, + strategy: MagCacheStrategy, +) -> None: + """Register a MagCache strategy for a model type. + + This allows extending MagCache support to new models without modifying + the core MagCache code. + + Args: + transformer_cls_name: Transformer model type identifier (class name or type string) + Must match pipeline.transformer.__class__.__name__ + strategy: MagCacheStrategy instance for this model type + + Example: + >>> class MyModelMagCacheStrategy(MagCacheStrategy): + ... @property + ... def mag_ratios(self): + ... return torch.tensor([...]) + >>> register_strategy("MyModelTransformer", MyModelMagCacheStrategy()) + """ + MagCacheStrategyRegistry.register(transformer_cls_name, strategy) + + +def get_strategy(transformer_cls_name: str) -> MagCacheStrategy: + """Get strategy function for given transformer class. + + This function looks up the strategy based on the exact transformer_cls_name string, + which should match the transformer type in the pipeline (i.e., pipeline.transformer.__class__.__name__). + + Args: + transformer_cls_name: Transformer class name (e.g., "FluxTransformer2DModel") + Must exactly match a registered strategy. 
+ + Returns: + MagCacheStrategy instance for the model + + Raises: + ValueError: If model type not found in registry + """ + return MagCacheStrategyRegistry.get(transformer_cls_name) From 1e222bfe16226b849b651991d3f9930c19043c24 Mon Sep 17 00:00:00 2001 From: Lancer Date: Wed, 11 Feb 2026 16:53:04 +0800 Subject: [PATCH 06/11] upd Signed-off-by: Lancer --- .../diffusion/cache/magcache/__init__.py | 4 +- vllm_omni/diffusion/cache/magcache/hook.py | 76 +++++-- .../diffusion/cache/magcache/strategy.py | 212 +++++++++++++----- 3 files changed, 221 insertions(+), 71 deletions(-) diff --git a/vllm_omni/diffusion/cache/magcache/__init__.py b/vllm_omni/diffusion/cache/magcache/__init__.py index 835a2a306d4..08f0d88fa15 100644 --- a/vllm_omni/diffusion/cache/magcache/__init__.py +++ b/vllm_omni/diffusion/cache/magcache/__init__.py @@ -9,8 +9,8 @@ apply_mag_cache_hook, ) from vllm_omni.diffusion.cache.magcache.strategy import ( + Flux2MagCacheStrategy, FluxMagCacheStrategy, - MagCacheContext, MagCacheStrategy, MagCacheStrategyRegistry, get_strategy, @@ -18,10 +18,10 @@ ) __all__ = [ + "Flux2MagCacheStrategy", "FluxMagCacheStrategy", "MagCacheBlockHook", "MagCacheConfig", - "MagCacheContext", "MagCacheHeadHook", "MagCacheState", "MagCacheStrategy", diff --git a/vllm_omni/diffusion/cache/magcache/hook.py b/vllm_omni/diffusion/cache/magcache/hook.py index 33926468698..2a3095d28b7 100644 --- a/vllm_omni/diffusion/cache/magcache/hook.py +++ b/vllm_omni/diffusion/cache/magcache/hook.py @@ -58,13 +58,26 @@ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: try: self._metadata = TransformerBlockRegistry.get(block_class) except ValueError: - TransformerBlockRegistry.register( - model_class=block_class, - metadata=TransformerBlockMetadata( - return_hidden_states_index=1, - return_encoder_hidden_states_index=0, - ), - ) + if self._strategy is not None: + metadata = self._strategy.register_block_metadata(block_class) + if metadata is not None: + 
TransformerBlockRegistry.register(model_class=block_class, metadata=metadata) + else: + TransformerBlockRegistry.register( + model_class=block_class, + metadata=TransformerBlockMetadata( + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + else: + TransformerBlockRegistry.register( + model_class=block_class, + metadata=TransformerBlockMetadata( + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) self._metadata = TransformerBlockRegistry.get(block_class) return module @@ -244,13 +257,26 @@ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: try: self._metadata = TransformerBlockRegistry.get(block_class) except ValueError: - TransformerBlockRegistry.register( - model_class=block_class, - metadata=TransformerBlockMetadata( - return_hidden_states_index=1, - return_encoder_hidden_states_index=0, - ), - ) + if self._strategy is not None: + metadata = self._strategy.register_block_metadata(block_class) + if metadata is not None: + TransformerBlockRegistry.register(model_class=block_class, metadata=metadata) + else: + TransformerBlockRegistry.register( + model_class=block_class, + metadata=TransformerBlockMetadata( + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + else: + TransformerBlockRegistry.register( + model_class=block_class, + metadata=TransformerBlockMetadata( + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) self._metadata = TransformerBlockRegistry.get(block_class) return module @@ -266,6 +292,10 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): state: MagCacheState = self.state_manager.get_state() if not state.should_compute: + res = state.previous_residual + if res is None: + res = torch.zeros_like(args[0]) + if hasattr(self._metadata, "hidden_states_argument_name"): arg_name = self._metadata.hidden_states_argument_name else: @@ -279,15 +309,27 @@ def new_forward(self, module: torch.nn.Module, *args, 
**kwargs): encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs( "encoder_hidden_states", args, kwargs ) + + if self._strategy is not None: + out_hidden, enc_out = self._strategy.apply_residual_tuple(hidden_states, encoder_hidden_states, res) + else: + out_hidden = hidden_states + res + enc_out = encoder_hidden_states + max_idx = max( self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index ) ret_list = [None] * (max_idx + 1) - ret_list[self._metadata.return_hidden_states_index] = hidden_states - ret_list[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states + ret_list[self._metadata.return_hidden_states_index] = out_hidden + ret_list[self._metadata.return_encoder_hidden_states_index] = enc_out return tuple(ret_list) - return hidden_states + if self._strategy is not None: + output = self._strategy.apply_residual(hidden_states, res) + else: + output = hidden_states + res + + return output output = self.fn_ref.original_forward(*args, **kwargs) diff --git a/vllm_omni/diffusion/cache/magcache/strategy.py b/vllm_omni/diffusion/cache/magcache/strategy.py index df780aadcf5..5c771b5a77a 100644 --- a/vllm_omni/diffusion/cache/magcache/strategy.py +++ b/vllm_omni/diffusion/cache/magcache/strategy.py @@ -15,60 +15,9 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Callable -from dataclasses import dataclass -from typing import Any import torch -from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry - - -def register_transformer_block( - model_class: type, - return_hidden_states_index: int = 1, - return_encoder_hidden_states_index: int = 0, -) -> None: - """Register a transformer block class with the TransformerBlockRegistry. - - Args: - model_class: The transformer block class to register. - return_hidden_states_index: Index of hidden_states in the forward output tuple. 
- return_encoder_hidden_states_index: Index of encoder_hidden_states in the output. - """ - try: - TransformerBlockRegistry.get(model_class) - except ValueError: - TransformerBlockRegistry.register( - model_class=model_class, - metadata=TransformerBlockMetadata( - return_hidden_states_index=return_hidden_states_index, - return_encoder_hidden_states_index=return_encoder_hidden_states_index, - ), - ) - - -@dataclass -class MagCacheContext: - """ - Context object containing model-specific information for MagCache. - - Attributes: - hidden_states: Current hidden states before transformer blocks. - encoder_hidden_states: Optional encoder states (None for single-stream). - temb: Timestep embedding tensor. - head_block_input: Input to the first transformer block (for residual calculation). - run_transformer_blocks: Callable to run transformer blocks. - run_single_transformer_blocks: Callable to run single transformer blocks. - postprocess: Callable to produce final output from block outputs. - """ - - hidden_states: torch.Tensor - encoder_hidden_states: torch.Tensor | None - temb: torch.Tensor - head_block_input: torch.Tensor | None - run_transformer_blocks: Callable[[], tuple[torch.Tensor, torch.Tensor]] - run_single_transformer_blocks: Callable[[], torch.Tensor] - postprocess: Callable[[torch.Tensor, torch.Tensor], Any] +from diffusers.hooks._helpers import TransformerBlockMetadata class MagCacheStrategy(ABC): @@ -96,6 +45,20 @@ def mag_ratios(self) -> torch.Tensor: """ pass + def register_block_metadata(self, block_class: type) -> TransformerBlockMetadata | None: + """Register model-specific transformer block metadata. + + Override this method to provide custom metadata for transformer blocks + that have non-standard output formats (e.g., tuple returns). + + Args: + block_class: The transformer block class to register. + + Returns: + TransformerBlockMetadata if custom registration is needed, None otherwise. 
+ """ + return None + def compute_residual( self, output: torch.Tensor, @@ -305,6 +268,149 @@ def nearest_interp(src_array: torch.Tensor, target_length: int) -> torch.Tensor: return src_array[mapped_indices] +class Flux2MagCacheStrategy(FluxMagCacheStrategy): + """MagCache strategy for Flux2 model. + + Flux2 shares the same dual-stream architecture as Flux, but may have + different tensor shapes in some transformer blocks, requiring special + handling in residual computation. + """ + + FLUX2_MAG_RATIOS = torch.tensor( + [ + 1.0, + 0.96528, + 1.11559, + 1.0565, + 1.00425, + 1.0805, + 0.98616, + 1.09289, + 1.03196, + 1.06679, + 1.03941, + 1.05375, + 1.03128, + 1.05349, + 1.01983, + 1.05535, + 1.0662, + 1.05748, + 1.00318, + 1.05222, + 1.04556, + 1.0506, + 1.05058, + 1.05219, + 1.02025, + 1.05052, + 1.04143, + 1.0498, + ] + ) + + @property + def mag_ratios(self) -> torch.Tensor: + """Return default mag_ratios for Flux2 model.""" + return self.FLUX2_MAG_RATIOS + + def register_block_metadata(self, block_class: type) -> TransformerBlockMetadata | None: + """Register Flux2-specific block metadata based on block type. + + Flux2 has two block types with different output formats: + - Dual-stream (Flux2TransformerBlock): returns (encoder_hidden_states, hidden_states) + - Single-stream (Flux2SingleTransformerBlock): returns single tensor + """ + class_name = block_class.__name__ + is_single_stream = "Single" in class_name + + if is_single_stream: + return TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ) + else: + return TransformerBlockMetadata( + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ) + + def compute_residual( + self, + output: torch.Tensor, + head_input: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """Compute residual for Flux2 with dual-stream support. + + For dual-stream blocks, computes residual for both encoder and decoder branches. 
+ For single-stream blocks: if shapes match, computes output - input; otherwise returns output. + """ + if isinstance(output, tuple): + enc_output, dec_output = output[0], output[1] + + if isinstance(head_input, tuple): + enc_head, dec_head = head_input[0], head_input[1] + else: + enc_head = head_input[:, : enc_output.shape[1], ...] + dec_head = head_input[:, enc_output.shape[1] :, ...] + + enc_residual = enc_output - enc_head + dec_residual = dec_output - dec_head + + return (enc_residual, dec_residual) + else: + if output.shape == head_input.shape: + return output - head_input + else: + return output + + def apply_residual_tuple( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + residual: tuple[torch.Tensor, torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + """Apply residual tuple for Flux2 with shape mismatch handling. + + For Flux2 dual-stream blocks, handles the case where hidden_states + and decoder_residual may have different shapes. + """ + if isinstance(residual, tuple): + decoder_residual = residual[1] + else: + decoder_residual = residual + + if hidden_states.shape == decoder_residual.shape: + output = hidden_states + decoder_residual + else: + output = hidden_states + enc_output = encoder_hidden_states + + return output, enc_output + + def apply_residual( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + ) -> torch.Tensor: + """Apply residual for Flux2. + + For single-stream blocks: if shapes match, adds residual; otherwise returns input. + For dual-stream blocks: applies decoder residual only. 
+ """ + if isinstance(residual, tuple): + dec_residual = residual[1] + if hidden_states.shape == dec_residual.shape: + return hidden_states + dec_residual + else: + return hidden_states + else: + if residual.shape == hidden_states.shape: + return hidden_states + residual + else: + return hidden_states + + class MagCacheStrategyRegistry: """Registry for MagCache strategies by transformer type.""" @@ -336,6 +442,8 @@ def get_if_exists(cls, transformer_type: str) -> MagCacheStrategy | None: MagCacheStrategyRegistry.register("FluxTransformer2DModel", FluxMagCacheStrategy()) +MagCacheStrategyRegistry.register("Flux2Transformer2DModel", Flux2MagCacheStrategy()) + def register_strategy( transformer_cls_name: str, From 8e1285e34630d8eddd8cf4f32e1ee686177f19d3 Mon Sep 17 00:00:00 2001 From: Lancer Date: Thu, 26 Feb 2026 20:47:33 +0800 Subject: [PATCH 07/11] upd Signed-off-by: Lancer --- vllm_omni/diffusion/cache/magcache/backend.py | 22 +++++++++---------- vllm_omni/diffusion/cache/magcache/hook.py | 12 +++++----- vllm_omni/diffusion/cache/magcache/state.py | 1 + .../diffusion/cache/magcache/strategy.py | 4 ++-- vllm_omni/diffusion/data.py | 18 +++++++-------- vllm_omni/diffusion/hooks/base.py | 11 +++++++++- vllm_omni/entrypoints/cli/serve.py | 4 ++-- 7 files changed, 41 insertions(+), 31 deletions(-) diff --git a/vllm_omni/diffusion/cache/magcache/backend.py b/vllm_omni/diffusion/cache/magcache/backend.py index bc3149b11c8..b5604e478a1 100644 --- a/vllm_omni/diffusion/cache/magcache/backend.py +++ b/vllm_omni/diffusion/cache/magcache/backend.py @@ -39,14 +39,14 @@ class MagCacheBackend(CacheBackend): Example: >>> from vllm_omni.diffusion.data import DiffusionCacheConfig - >>> from vllm_omni.diffusion.cache.magcache import MagCacheConfig + >>> from vllm_omni.diffusion.cache.magcache.config import MagCacheConfig >>> from vllm_omni.diffusion.cache.magcache.strategy import FluxMagCacheStrategy >>> cache_config = DiffusionCacheConfig( ... 
mag_ratios=FluxMagCacheStrategy.FLUX_MAG_RATIOS, ... num_inference_steps=28, - ... threshold=0.06, - ... max_skip_steps=3, - ... retention_ratio=0.2, + ... mag_threshold=0.24, + ... mag_max_skip_steps=5, + ... mag_retention_ratio=0.1, ... ) >>> backend = MagCacheBackend(cache_config) >>> backend.enable(pipeline) @@ -78,7 +78,7 @@ def enable(self, pipeline: Any) -> None: mag_ratios = self.config.mag_ratios strategy = None - if mag_ratios is None and not self.config.calibrate: + if mag_ratios is None and not self.config.mag_calibrate: strategy = get_strategy(transformer_type) original_ratios = strategy.mag_ratios @@ -98,7 +98,7 @@ def enable(self, pipeline: Any) -> None: logger.info(f"MagCache: Using mag_ratios from {type(strategy).__name__}") - if mag_ratios is None and not self.config.calibrate: + if mag_ratios is None and not self.config.mag_calibrate: raise ValueError( f"mag_ratios must be provided for MagCache. " f"For {transformer_type}, you need to provide mag_ratios or run in calibrate mode." 
@@ -106,12 +106,12 @@ def enable(self, pipeline: Any) -> None: self._magcache_config = MagCacheConfig( transformer_type=transformer_type, - threshold=self.config.threshold, - max_skip_steps=self.config.max_skip_steps, - retention_ratio=self.config.retention_ratio, + threshold=self.config.mag_threshold, + max_skip_steps=self.config.mag_max_skip_steps, + retention_ratio=self.config.mag_retention_ratio, num_inference_steps=num_inference_steps, - calibrate=self.config.calibrate, - mag_ratios=mag_ratios if not self.config.calibrate else None, + calibrate=self.config.mag_calibrate, + mag_ratios=mag_ratios if not self.config.mag_calibrate else None, ) self._transformer_id = id(transformer) diff --git a/vllm_omni/diffusion/cache/magcache/hook.py b/vllm_omni/diffusion/cache/magcache/hook.py index 2a3095d28b7..192e516acc0 100644 --- a/vllm_omni/diffusion/cache/magcache/hook.py +++ b/vllm_omni/diffusion/cache/magcache/hook.py @@ -175,6 +175,10 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): if diff > 0: output = hidden_states.clone() output[:, diff:, :] = output[:, diff:, :] + res + else: + output = hidden_states + res + else: + output = hidden_states + res if self._metadata.return_encoder_hidden_states_index is not None: original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs( @@ -288,7 +292,7 @@ def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: @torch.compiler.disable def new_forward(self, module: torch.nn.Module, *args, **kwargs): if self.state_manager._current_context is None: - self.state_manager.set_context("inference") + self.state_manager.set_context("magcache") state: MagCacheState = self.state_manager.get_state() if not state.should_compute: @@ -349,11 +353,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): elif out_hidden.shape == in_hidden.shape: residual = out_hidden - in_hidden elif out_hidden.ndim == 3 and in_hidden.ndim == 3 and out_hidden.shape[2] == in_hidden.shape[2]: - diff = 
in_hidden.shape[1] - out_hidden.shape[1] - if diff == 0: - residual = out_hidden - in_hidden - else: - residual = out_hidden - in_hidden + residual = out_hidden - in_hidden else: residual = out_hidden diff --git a/vllm_omni/diffusion/cache/magcache/state.py b/vllm_omni/diffusion/cache/magcache/state.py index 28060ec832d..220f3709a6f 100644 --- a/vllm_omni/diffusion/cache/magcache/state.py +++ b/vllm_omni/diffusion/cache/magcache/state.py @@ -32,6 +32,7 @@ def __init__(self) -> None: def reset(self) -> None: """Reset all state variables for a new inference run.""" self.previous_residual = None + self.head_block_input = None self.should_compute = True self.accumulated_ratio = 1.0 self.accumulated_err = 0.0 diff --git a/vllm_omni/diffusion/cache/magcache/strategy.py b/vllm_omni/diffusion/cache/magcache/strategy.py index 5c771b5a77a..7743ff2c75a 100644 --- a/vllm_omni/diffusion/cache/magcache/strategy.py +++ b/vllm_omni/diffusion/cache/magcache/strategy.py @@ -229,9 +229,9 @@ def compute_residual( """ if isinstance(output, tuple): decoder_output = output[1] if len(output) > 1 else output[0] + return decoder_output - head_input else: - decoder_output = output - head_input - return decoder_output + return output - head_input def apply_residual_tuple( self, diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 0c65c316b3e..8e35350e49b 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -138,8 +138,8 @@ class DiffusionCacheConfig: - cache-dit: Fn_compute_blocks, Bn_compute_blocks, max_warmup_steps, residual_diff_threshold, enable_taylorseer, taylorseer_order, scm_steps_mask_policy, scm_steps_policy - - MagCache: threshold, max_skip_steps, retention_ratio, num_inference_steps, - mag_ratios, calibrate + - MagCache: mag_threshold, mag_max_skip_steps, mag_retention_ratio, + mag_ratios, mag_calibrate Example: >>> # From dict (user-facing API) - partial config uses defaults for missing keys @@ -158,16 +158,16 @@ class 
DiffusionCacheConfig: coefficients: list[float] | None = None # Uses model-specific defaults if None # MagCache parameters [mag_cache only] - # Default: 0.06 threshold for accumulated magnitude error - threshold: float = 0.06 - # Default: 3 maximum consecutive skip steps - max_skip_steps: int = 3 - # Default: 0.2 retention ratio (initial steps that never skip) - retention_ratio: float = 0.2 + # Default: 0.24 threshold for accumulated magnitude error + mag_threshold: float = 0.24 + # Default: 5 maximum consecutive skip steps (K) + mag_max_skip_steps: int = 5 + # Default: 0.1 fraction of initial steps where skipping is disabled (stability) + mag_retention_ratio: float = 0.1 # Default: None magnitude ratios (model-specific, required for inference) mag_ratios: list[float] | None = None # Default: False calibration mode (computes mag_ratios on first run) - calibrate: bool = False + mag_calibrate: bool = False # cache-dit parameters [cache-dit only] # Default: 1 forward compute block (optimized for single-transformer models) diff --git a/vllm_omni/diffusion/hooks/base.py b/vllm_omni/diffusion/hooks/base.py index 91d212c2b2f..b05af519361 100644 --- a/vllm_omni/diffusion/hooks/base.py +++ b/vllm_omni/diffusion/hooks/base.py @@ -13,6 +13,9 @@ from typing import Any import torch.nn as nn +from vllm.logger import init_logger + +logger = init_logger(__name__) class BaseState: @@ -199,7 +202,8 @@ def register_hook(self, name: str, hook: ModelHook) -> None: hook: The hook instance to register. """ if name in self._hooks: - raise ValueError(f"Hook with name '{name}' already exists. Remove it first or use a different name.") + logger.warning(f"Hook with name '{name}' already exists. Overwriting existing hook.") + self.remove_hook(name) hook.initialize_hook(self.module) @@ -288,7 +292,12 @@ def reset(self) -> None: """Reset all hooks and clear the registry. This removes all hooks from the registry and resets each hook's state. 
+ Also restores module.forward to its original implementation. """ for name, hook in list(self._hooks.items()): hook.reset_state(self.module) self._hooks.clear() + + if hasattr(self.module, "_original_forward"): + self.module.forward = self.module._original_forward # type: ignore[attr-defined] + delattr(self.module, "_original_forward") diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index 064ff68eb68..f4748f818c6 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -223,8 +223,8 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu default=None, help="JSON string of cache configuration. " "TeaCache: '{\"rel_l1_thresh\": 0.2}'. " - 'MagCache: \'{"threshold": 0.06, "max_skip_steps": 3, "mag_ratios": [1.0, ...]}\'. ' - "Calibration mode: add '\"calibrate\": true'", + 'MagCache: \'{"mag_threshold": 0.24, "mag_max_skip_steps": 5, "mag_retention_ratio": 0.1}\'. ' + "Calibration mode: add '\"mag_calibrate\": true'", ) omni_config_group.add_argument( "--enable-cache-dit-summary", From 065f3ea61c539cdf4052a52962305be0226e0793 Mon Sep 17 00:00:00 2001 From: Lancer Date: Mon, 2 Mar 2026 19:45:26 +0800 Subject: [PATCH 08/11] upd Signed-off-by: Lancer --- vllm_omni/diffusion/cache/magcache/backend.py | 2 +- vllm_omni/diffusion/cache/magcache/config.py | 10 +++++----- vllm_omni/diffusion/cache/magcache/hook.py | 6 +++--- vllm_omni/entrypoints/omni.py | 6 ++++++ 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/vllm_omni/diffusion/cache/magcache/backend.py b/vllm_omni/diffusion/cache/magcache/backend.py index b5604e478a1..4ad427d3a2f 100644 --- a/vllm_omni/diffusion/cache/magcache/backend.py +++ b/vllm_omni/diffusion/cache/magcache/backend.py @@ -110,7 +110,7 @@ def enable(self, pipeline: Any) -> None: max_skip_steps=self.config.mag_max_skip_steps, retention_ratio=self.config.mag_retention_ratio, num_inference_steps=num_inference_steps, - 
calibrate=self.config.mag_calibrate, + mag_calibrate=self.config.mag_calibrate, mag_ratios=mag_ratios if not self.config.mag_calibrate else None, ) self._transformer_id = id(transformer) diff --git a/vllm_omni/diffusion/cache/magcache/config.py b/vllm_omni/diffusion/cache/magcache/config.py index 7d6e6fd88d3..6ccab8fa5a0 100644 --- a/vllm_omni/diffusion/cache/magcache/config.py +++ b/vllm_omni/diffusion/cache/magcache/config.py @@ -30,7 +30,7 @@ class MagCacheConfig: Default: 28 mag_ratios: Pre-computed magnitude ratios per step. Calibrate or use strategy defaults. Default: None - calibrate: If True, runs without skipping and logs norm_ratios for calibration. + mag_calibrate: If True, runs without skipping and logs norm_ratios for calibration. Default: False transformer_type: Transformer class name for logging. Default: "FluxTransformer2DModel" @@ -41,7 +41,7 @@ class MagCacheConfig: retention_ratio: float = 0.1 num_inference_steps: int = 28 mag_ratios: torch.Tensor | list[float] | None = None - calibrate: bool = False + mag_calibrate: bool = False transformer_type: str = "FluxTransformer2DModel" def __post_init__(self) -> None: @@ -65,16 +65,16 @@ def __post_init__(self) -> None: if self.num_inference_steps <= 0: raise ValueError(f"num_inference_steps must be positive, got {self.num_inference_steps}") - if not self.calibrate and self.mag_ratios is None: + if not self.mag_calibrate and self.mag_ratios is None: raise ValueError( "mag_ratios must be provided for MagCache inference because these ratios " "are model-dependent. To get them for your model:\n" - "1. Initialize MagCacheConfig(calibrate=True, ...)\n" + "1. Initialize MagCacheConfig(mag_calibrate=True, ...)\n" "2. Run inference on your model once.\n" "3. Copy the printed ratios array and pass it to mag_ratios in the config.\n" "For Flux models, you can import FLUX_MAG_RATIOS from vllm_omni.diffusion.cache.magcache.strategy." 
) - if not self.calibrate and self.mag_ratios is not None: + if not self.mag_calibrate and self.mag_ratios is not None: if not torch.is_tensor(self.mag_ratios): self.mag_ratios = torch.tensor(self.mag_ratios) diff --git a/vllm_omni/diffusion/cache/magcache/hook.py b/vllm_omni/diffusion/cache/magcache/hook.py index 192e516acc0..f2b4fadceba 100644 --- a/vllm_omni/diffusion/cache/magcache/hook.py +++ b/vllm_omni/diffusion/cache/magcache/hook.py @@ -104,7 +104,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): should_compute = True - if self.config.calibrate: + if self.config.mag_calibrate: should_compute = True else: current_step = state.step_index @@ -357,7 +357,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): else: residual = out_hidden - if self.config.calibrate: + if self.config.mag_calibrate: self.perform_calibration(state, residual) state.previous_residual = residual @@ -431,7 +431,7 @@ def _get_norm(residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor]) -> tor def advance_step(self, state: MagCacheState) -> None: state.step_index += 1 if state.step_index >= self.config.num_inference_steps: - if self.config.calibrate: + if self.config.mag_calibrate: logger.info("MagCache calibration complete.") logger.info(f"norm_ratios: {state.norm_ratios}") logger.info(f"norm_stds: {state.norm_stds}") diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index dc18978f967..f6b83aadf2d 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -198,6 +198,12 @@ def _get_default_cache_config(self, cache_backend: str | None) -> dict[str, Any] return { "rel_l1_thresh": 0.2, } + if cache_backend == "mag_cache": + return { + "mag_threshold": 0.24, + "mag_max_skip_steps": 5, + "mag_retention_ratio": 0.1, + } return None def _normalize_cache_config(self, cache_backend: str | None, cache_config: Any | None) -> Any | None: From c23da63dee39f97e90a0a6c893f4cd4dc730261f Mon Sep 17 00:00:00 2001 From: 
Lancer Date: Tue, 3 Mar 2026 08:19:10 +0800 Subject: [PATCH 09/11] upd Signed-off-by: Lancer --- tests/diffusion/cache/test_cache_backends.py | 68 +++++++++++++---- vllm_omni/diffusion/cache/magcache/hook.py | 79 ++++++++++++++++++-- 2 files changed, 128 insertions(+), 19 deletions(-) diff --git a/tests/diffusion/cache/test_cache_backends.py b/tests/diffusion/cache/test_cache_backends.py index 872b062257b..6e01ae4f220 100644 --- a/tests/diffusion/cache/test_cache_backends.py +++ b/tests/diffusion/cache/test_cache_backends.py @@ -231,10 +231,10 @@ class TestMagCacheBackend: def test_init(self): """Test initialization.""" - config = DiffusionCacheConfig(threshold=0.1, max_skip_steps=2, calibrate=True) + config = DiffusionCacheConfig(mag_threshold=0.1, mag_max_skip_steps=2, mag_calibrate=True) backend = MagCacheBackend(config) - assert backend.config.threshold == 0.1 - assert backend.config.max_skip_steps == 2 + assert backend.config.mag_threshold == 0.1 + assert backend.config.mag_max_skip_steps == 2 assert backend.enabled is False @patch("vllm_omni.diffusion.cache.magcache.backend.apply_mag_cache_hook") @@ -249,10 +249,6 @@ def test_enable(self, mock_apply_hook): mock_ratios = [1.0] * 28 config = DiffusionCacheConfig( mag_ratios=mock_ratios, - num_inference_steps=28, - threshold=0.06, - max_skip_steps=3, - retention_ratio=0.2, ) backend = MagCacheBackend(config) backend.enable(mock_pipeline) @@ -273,11 +269,7 @@ def test_enable_with_calibration(self, mock_apply_hook): mock_pipeline.transformer = mock_transformer config = DiffusionCacheConfig( - calibrate=True, - num_inference_steps=28, - threshold=0.06, - max_skip_steps=3, - retention_ratio=0.2, + mag_calibrate=True, ) backend = MagCacheBackend(config) backend.enable(mock_pipeline) @@ -298,7 +290,6 @@ def test_refresh(self): mock_ratios = [1.0] * 28 config = DiffusionCacheConfig( mag_ratios=mock_ratios, - num_inference_steps=28, ) backend = MagCacheBackend(config) @@ -327,3 +318,54 @@ def 
test_get_mag_cache_backend(self): assert backend is not None assert isinstance(backend, MagCacheBackend) assert backend.config.threshold == 0.06 + + @patch("vllm_omni.diffusion.cache.magcache.backend.apply_mag_cache_hook") + def test_enable_single_block(self, mock_apply_hook): + """Test enabling MagCache on single transformer block.""" + mock_pipeline = Mock() + mock_pipeline.__class__.__name__ = "FluxPipeline" + + mock_block = Mock() + mock_block.__class__.__name__ = "FluxTransformer2DModel" + mock_blocks = [mock_block] + + mock_transformer = Mock() + mock_transformer.__class__.__name__ = "FluxTransformer2DModel" + mock_transformer.blocks = mock_blocks + mock_pipeline.transformer = mock_transformer + + mock_ratios = [1.0] * 28 + config = DiffusionCacheConfig( + mag_ratios=mock_ratios, + ) + backend = MagCacheBackend(config) + backend.enable(mock_pipeline) + + assert backend.enabled is True + mock_apply_hook.assert_called_once() + + call_args = mock_apply_hook.call_args + assert call_args[0][0] == mock_transformer + + @patch("vllm_omni.diffusion.cache.magcache.backend.apply_mag_cache_hook") + def test_enable_multi_block(self, mock_apply_hook): + """Test enabling MagCache on multiple transformer blocks.""" + mock_pipeline = Mock() + mock_pipeline.__class__.__name__ = "FluxPipeline" + + mock_blocks = [Mock() for _ in range(24)] + + mock_transformer = Mock() + mock_transformer.__class__.__name__ = "FluxTransformer2DModel" + mock_transformer.blocks = mock_blocks + mock_pipeline.transformer = mock_transformer + + mock_ratios = [1.0] * 28 + config = DiffusionCacheConfig( + mag_ratios=mock_ratios, + ) + backend = MagCacheBackend(config) + backend.enable(mock_pipeline) + + assert backend.enabled is True + mock_apply_hook.assert_called_once() diff --git a/vllm_omni/diffusion/cache/magcache/hook.py b/vllm_omni/diffusion/cache/magcache/hook.py index f2b4fadceba..6a440e020ef 100644 --- a/vllm_omni/diffusion/cache/magcache/hook.py +++ 
b/vllm_omni/diffusion/cache/magcache/hook.py @@ -44,12 +44,19 @@ class MagCacheHeadHook(ModelHook): _HOOK_NAME = "mag_cache_head" - def __init__(self, state_manager: StateManager, config: MagCacheConfig, strategy: MagCacheStrategy | None = None): + def __init__( + self, + state_manager: StateManager, + config: MagCacheConfig, + strategy: MagCacheStrategy | None = None, + is_tail: bool = False, + ): super().__init__() self.state_manager = state_manager self.config = config self._strategy = strategy self._metadata = None + self._is_tail = is_tail def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: unwrapped_module = unwrap_module(module) @@ -145,7 +152,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): ) if original_encoder_hidden_states.device != res[1].device: original_encoder_hidden_states = original_encoder_hidden_states.to(res[1].device) - h_res, e_res = res output, enc_output = self._strategy.apply_residual_tuple( hidden_states, original_encoder_hidden_states, res ) @@ -197,8 +203,69 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): output = self.fn_ref.original_forward(*args, **kwargs) result = self.log_cache_miss(state, output) + if self._is_tail: + if isinstance(output, tuple): + out_hidden = output[self._metadata.return_hidden_states_index] + else: + out_hidden = output + + in_hidden = state.head_block_input + + if in_hidden is not None: + if self._strategy is not None: + residual = self._strategy.compute_residual(out_hidden, in_hidden) + elif out_hidden.shape == in_hidden.shape: + residual = out_hidden - in_hidden + else: + residual = out_hidden + + if self.config.mag_calibrate: + self._perform_calibration_head(state, residual) + + state.previous_residual = residual + self._advance_step_head(state) + return result + def _perform_calibration_head( + self, + state: MagCacheState, + current_residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + ) -> None: + if self._strategy is not None: + ratio, 
std, cos_dis = self._strategy.compute_calibration_metrics(current_residual, state.previous_residual) + else: + if state.previous_residual is None: + ratio, std, cos_dis = 1.0, 0.0, 0.0 + else: + ratio, std, cos_dis = 1.0, 0.0, 0.0 + + state.calibration_ratios.append(ratio) + state.norm_ratios.append(round(ratio, 5)) + state.norm_stds.append(round(std, 5)) + state.cos_dises.append(round(cos_dis, 5)) + + def _advance_step_head(self, state: MagCacheState) -> None: + state.step_index += 1 + if state.step_index >= self.config.num_inference_steps: + if self.config.mag_calibrate: + logger.info("MagCache calibration complete.") + logger.info(f"norm_ratios: {state.norm_ratios}") + logger.info(f"norm_stds: {state.norm_stds}") + logger.info(f"cos_dises: {state.cos_dises}") + logger.info("Copy these values to DiffusionCacheConfig(mag_ratios=...) for production use") + + state.step_index = 0 + state.accumulated_ratio = 1.0 + state.accumulated_steps = 0 + state.accumulated_err = 0.0 + state.previous_residual = None + state.calibration_ratios = [] + state.norm_ratios = [] + state.norm_stds = [] + state.cos_dises = [] + state._is_first_step = True + def log_cache_hit(self, state: MagCacheState, output, ret): step = state.step_index if state.previous_residual is not None: @@ -494,9 +561,8 @@ def apply_mag_cache_hook( if len(remaining_blocks) == 1: name, block = remaining_blocks[0] - logger.info(f"MagCache: Applying Head+Tail Hooks to single block '{name}'") - _apply_mag_cache_block_hook(block, state_manager, config, is_tail=True, strategy=strategy) - _apply_mag_cache_head_hook(block, state_manager, config, strategy) + logger.info(f"MagCache: Applying Head+Tail Hook to single block '{name}'") + _apply_mag_cache_head_hook(block, state_manager, config, strategy, is_tail=True) return head_block_name, head_block = remaining_blocks.pop(0) @@ -517,13 +583,14 @@ def _apply_mag_cache_head_hook( state_manager: StateManager, config: MagCacheConfig, strategy: MagCacheStrategy | None = None, + 
is_tail: bool = False, ) -> None: registry = HookRegistry.check_if_exists_or_initialize(block) if registry.get_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) is not None: registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) - hook = MagCacheHeadHook(state_manager, config, strategy) + hook = MagCacheHeadHook(state_manager, config, strategy, is_tail=is_tail) registry.register_hook(_MAG_CACHE_LEADER_BLOCK_HOOK, hook) From 082db90d35d0cf4aa102756dc31c40d8220f0737 Mon Sep 17 00:00:00 2001 From: Lancer Date: Sat, 21 Mar 2026 00:40:52 +0800 Subject: [PATCH 10/11] upd Signed-off-by: Lancer --- vllm_omni/engine/async_omni_engine.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 5562b84ff29..77c1dfc80c8 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -750,6 +750,12 @@ def _get_default_cache_config(cache_backend: str | None) -> dict[str, Any] | Non return { "rel_l1_thresh": 0.2, } + if cache_backend == "mag_cache": + return { + "mag_threshold": 0.24, + "mag_max_skip_steps": 5, + "mag_retention_ratio": 0.1, + } return None @staticmethod From 0bc1099571141a276ec85ad0ed67350f1c0280c0 Mon Sep 17 00:00:00 2001 From: Lancer Date: Thu, 30 Apr 2026 16:12:27 +0800 Subject: [PATCH 11/11] upd Signed-off-by: Lancer --- vllm_omni/diffusion/hooks/base.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm_omni/diffusion/hooks/base.py b/vllm_omni/diffusion/hooks/base.py index 5099637fa74..6863a459425 100644 --- a/vllm_omni/diffusion/hooks/base.py +++ b/vllm_omni/diffusion/hooks/base.py @@ -238,9 +238,9 @@ def register_hook(self, name: str, hook: ModelHook) -> None: hook.initialize_hook(self.module) if hasattr(hook, "fn_ref"): - hook.fn_ref.original_forward = self.module._original_forward + hook.fn_ref.original_forward = self.module._omni_original_forward else: - original_forward = self.module._original_forward # type: 
ignore[attr-defined] + original_forward = self.module._omni_original_forward # type: ignore[attr-defined] class _FnRef: def __init__(self, orig_forward): @@ -343,6 +343,6 @@ def reset(self) -> None: hook.reset_state(self.module) self._hooks.clear() - if hasattr(self.module, "_original_forward"): - self.module.forward = self.module._original_forward # type: ignore[attr-defined] - delattr(self.module, "_original_forward") + if hasattr(self.module, "_omni_original_forward"): + self.module.forward = self.module._omni_original_forward # type: ignore[attr-defined] + delattr(self.module, "_omni_original_forward")