class TestMagCacheBackend:
    """Unit tests for MagCacheBackend: construction, enable(), refresh() and selector lookup."""

    @staticmethod
    def _make_flux_pipeline(blocks=None):
        """Return ``(pipeline, transformer)`` mocks named like the Flux classes.

        Args:
            blocks: Optional list assigned to ``transformer.blocks``; left unset when None.
        """
        pipeline = Mock()
        pipeline.__class__.__name__ = "FluxPipeline"
        transformer = Mock()
        transformer.__class__.__name__ = "FluxTransformer2DModel"
        if blocks is not None:
            transformer.blocks = blocks
        pipeline.transformer = transformer
        return pipeline, transformer

    @staticmethod
    def _ratio_config():
        """Return a DiffusionCacheConfig carrying 28 constant magnitude ratios."""
        return DiffusionCacheConfig(mag_ratios=[1.0] * 28)

    def test_init(self):
        """Test initialization."""
        config = DiffusionCacheConfig(mag_threshold=0.1, mag_max_skip_steps=2, mag_calibrate=True)
        backend = MagCacheBackend(config)
        assert backend.config.mag_threshold == 0.1
        assert backend.config.mag_max_skip_steps == 2
        assert backend.enabled is False

    @patch("vllm_omni.diffusion.cache.magcache.backend.apply_mag_cache_hook")
    def test_enable(self, mock_apply_hook):
        """Test enabling MagCache on pipeline."""
        pipeline, transformer = self._make_flux_pipeline()
        backend = MagCacheBackend(self._ratio_config())
        backend.enable(pipeline)

        assert backend.enabled is True
        mock_apply_hook.assert_called_once()
        # The hook must be applied to the pipeline's transformer itself.
        assert mock_apply_hook.call_args[0][0] == transformer

    @patch("vllm_omni.diffusion.cache.magcache.backend.apply_mag_cache_hook")
    def test_enable_with_calibration(self, mock_apply_hook):
        """Test enabling MagCache in calibration mode."""
        pipeline, _ = self._make_flux_pipeline()
        backend = MagCacheBackend(DiffusionCacheConfig(mag_calibrate=True))
        backend.enable(pipeline)

        assert backend.enabled is True
        mock_apply_hook.assert_called_once()

    def test_refresh(self):
        """Test refreshing MagCache state calls enable when not registered."""
        pipeline, transformer = self._make_flux_pipeline()
        # No child modules: the (unpatched) hook application finds no blocks.
        transformer.named_children = Mock(return_value=[])
        backend = MagCacheBackend(self._ratio_config())

        assert backend._registered is False
        backend.refresh(pipeline, num_inference_steps=50)
        assert backend._registered is True

    def test_is_enabled(self):
        """Test is_enabled method."""
        backend = MagCacheBackend(self._ratio_config())
        assert backend.is_enabled() is False

    def test_get_mag_cache_backend(self):
        """Test getting MagCache backend via selector."""
        # NOTE(review): these keys are un-prefixed ("threshold") while MagCacheBackend
        # reads config.mag_threshold / mag_max_skip_steps / mag_retention_ratio —
        # confirm DiffusionCacheConfig accepts both spellings.
        config_dict = {
            "mag_ratios": [1.0] * 28,
            "num_inference_steps": 28,
            "threshold": 0.06,
            "max_skip_steps": 3,
            "retention_ratio": 0.2,
        }
        backend = get_cache_backend("mag_cache", config_dict)
        assert backend is not None
        assert isinstance(backend, MagCacheBackend)
        assert backend.config.threshold == 0.06

    @patch("vllm_omni.diffusion.cache.magcache.backend.apply_mag_cache_hook")
    def test_enable_single_block(self, mock_apply_hook):
        """Test enabling MagCache on single transformer block."""
        block = Mock()
        block.__class__.__name__ = "FluxTransformer2DModel"
        pipeline, transformer = self._make_flux_pipeline(blocks=[block])
        backend = MagCacheBackend(self._ratio_config())
        backend.enable(pipeline)

        assert backend.enabled is True
        mock_apply_hook.assert_called_once()
        assert mock_apply_hook.call_args[0][0] == transformer

    @patch("vllm_omni.diffusion.cache.magcache.backend.apply_mag_cache_hook")
    def test_enable_multi_block(self, mock_apply_hook):
        """Test enabling MagCache on multiple transformer blocks."""
        pipeline, _ = self._make_flux_pipeline(blocks=[Mock() for _ in range(24)])
        backend = MagCacheBackend(self._ratio_config())
        backend.enable(pipeline)

        assert backend.enabled is True
        mock_apply_hook.assert_called_once()
cache-dit: DBCache, SCM, and TaylorSeer caching strategies Cache backends are instantiated directly via their constructors and configured via OmniDiffusionConfig. @@ -20,8 +21,8 @@ __all__ = [ "CacheBackend", - "TeaCacheConfig", "CacheContext", "TeaCacheBackend", + "TeaCacheConfig", "apply_teacache_hook", ] diff --git a/vllm_omni/diffusion/cache/magcache/__init__.py b/vllm_omni/diffusion/cache/magcache/__init__.py new file mode 100644 index 00000000000..08f0d88fa15 --- /dev/null +++ b/vllm_omni/diffusion/cache/magcache/__init__.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.diffusion.cache.magcache.config import MagCacheConfig +from vllm_omni.diffusion.cache.magcache.hook import ( + MagCacheBlockHook, + MagCacheHeadHook, + MagCacheState, + apply_mag_cache_hook, +) +from vllm_omni.diffusion.cache.magcache.strategy import ( + Flux2MagCacheStrategy, + FluxMagCacheStrategy, + MagCacheStrategy, + MagCacheStrategyRegistry, + get_strategy, + register_strategy, +) + +__all__ = [ + "Flux2MagCacheStrategy", + "FluxMagCacheStrategy", + "MagCacheBlockHook", + "MagCacheConfig", + "MagCacheHeadHook", + "MagCacheState", + "MagCacheStrategy", + "MagCacheStrategyRegistry", + "apply_mag_cache_hook", + "get_strategy", + "register_strategy", +] diff --git a/vllm_omni/diffusion/cache/magcache/backend.py b/vllm_omni/diffusion/cache/magcache/backend.py new file mode 100644 index 00000000000..4ad427d3a2f --- /dev/null +++ b/vllm_omni/diffusion/cache/magcache/backend.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +MagCache backend implementation. + +This module provides the MagCache backend that implements the CacheBackend +interface using the hooks-based MagCache system. 
class MagCacheBackend(CacheBackend):
    """
    MagCache implementation using hooks.

    MagCache (Magnitude-based Cache) is an adaptive caching technique that
    speeds up diffusion inference by reusing transformer block computations
    based on accumulated magnitude error between timesteps.

    The backend applies MagCache hooks to the transformer which intercept the
    forward pass and implement the caching logic transparently.

    Example:
        >>> from vllm_omni.diffusion.data import DiffusionCacheConfig
        >>> from vllm_omni.diffusion.cache.magcache.strategy import FluxMagCacheStrategy
        >>> cache_config = DiffusionCacheConfig(
        ...     mag_ratios=FluxMagCacheStrategy.FLUX_MAG_RATIOS,
        ...     num_inference_steps=28,
        ...     mag_threshold=0.24,
        ...     mag_max_skip_steps=5,
        ...     mag_retention_ratio=0.1,
        ... )
        >>> backend = MagCacheBackend(cache_config)
        >>> backend.enable(pipeline)
        >>> backend.refresh(pipeline, num_inference_steps=50)
    """

    def __init__(self, config: DiffusionCacheConfig):
        super().__init__(config)
        self._registered = False
        self._magcache_config: MagCacheConfig | None = None
        self._transformer_id: int | None = None
        # Strategy resolved by enable(); stays None when the caller supplied
        # mag_ratios or calibration is on (apply_mag_cache_hook then looks one up).
        self._strategy: Any = None

    @staticmethod
    def _ratios_from_strategy(strategy: Any, num_inference_steps: int) -> Any:
        """Return the strategy's default ratios, interpolated to the step count when possible.

        Returns None when the strategy carries no ratios.
        """
        original_ratios = strategy.mag_ratios
        if original_ratios is None:
            return None
        if len(original_ratios) == num_inference_steps:
            return original_ratios
        if hasattr(strategy, "nearest_interp"):
            logger.info(
                f"MagCache: Interpolated mag_ratios from {len(original_ratios)} to {num_inference_steps} steps"
            )
            return strategy.nearest_interp(original_ratios, num_inference_steps)
        logger.warning(
            f"MagCache: mag_ratios length ({len(original_ratios)}) != "
            f"num_inference_steps ({num_inference_steps}), "
            f"this may cause unexpected behavior"
        )
        return original_ratios

    def enable(self, pipeline: Any) -> None:
        """Enable MagCache on transformer using hooks.

        Builds a MagCacheConfig from the backend's DiffusionCacheConfig and
        applies the MagCache hook to ``pipeline.transformer``. When no
        ``mag_ratios`` are configured and calibration is off, the ratios come
        from the strategy registered for the transformer's class name.

        Args:
            pipeline: Diffusion pipeline instance exposing ``transformer``.

        Raises:
            ValueError: If no ratios are configured or available from the
                strategy and calibration mode is off.
        """
        transformer = pipeline.transformer
        transformer_type = transformer.__class__.__name__
        num_inference_steps = self.config.num_inference_steps or 28
        calibrate = self.config.mag_calibrate

        mag_ratios = self.config.mag_ratios
        strategy = None
        if mag_ratios is None and not calibrate:
            strategy = get_strategy(transformer_type)
            mag_ratios = self._ratios_from_strategy(strategy, num_inference_steps)
            if mag_ratios is None:
                raise ValueError(
                    f"mag_ratios must be provided for MagCache. "
                    f"For {transformer_type}, you need to provide mag_ratios or run in calibrate mode."
                )
            logger.info(f"MagCache: Using mag_ratios from {type(strategy).__name__}")

        self._magcache_config = MagCacheConfig(
            transformer_type=transformer_type,
            threshold=self.config.mag_threshold,
            max_skip_steps=self.config.mag_max_skip_steps,
            retention_ratio=self.config.mag_retention_ratio,
            num_inference_steps=num_inference_steps,
            mag_calibrate=calibrate,
            mag_ratios=None if calibrate else mag_ratios,
        )
        self._strategy = strategy
        self._transformer_id = id(transformer)

        apply_mag_cache_hook(transformer, self._magcache_config, strategy=strategy)

        self._registered = True
        self.enabled = True

    def _sync_step_count(self, num_inference_steps: int) -> None:
        """Propagate a per-request step count to the hook configuration.

        The hooks are constructed with ``self._magcache_config`` and read
        ``num_inference_steps`` / ``mag_ratios`` from it on each forward, so the
        config is updated in place instead of re-registering hooks.
        """
        cfg = self._magcache_config
        if cfg is None or num_inference_steps is None or num_inference_steps <= 0:
            return
        if num_inference_steps == cfg.num_inference_steps:
            return

        logger.info(
            f"MagCache: num_inference_steps changed from {cfg.num_inference_steps} to {num_inference_steps}"
        )
        cfg.num_inference_steps = num_inference_steps
        # Only strategy defaults are re-interpolated; user-supplied ratios are kept as given.
        if self._strategy is not None and not cfg.mag_calibrate:
            ratios = self._ratios_from_strategy(self._strategy, num_inference_steps)
            if ratios is not None:
                cfg.mag_ratios = ratios if torch.is_tensor(ratios) else torch.tensor(ratios)

    def refresh(self, pipeline: Any, num_inference_steps: int) -> None:
        """Refresh MagCache state for new generation.

        Registers hooks when none are registered yet or the transformer object
        was replaced; otherwise resets the state of every registered hook.
        The hook configuration is aligned with ``num_inference_steps`` so the
        step counter wraps at the end of the current generation.

        Args:
            pipeline: Diffusion pipeline instance exposing ``transformer``.
            num_inference_steps: Number of inference steps for the current generation.
        """
        transformer = pipeline.transformer
        current_transformer_id = id(transformer)

        replaced = self._registered and current_transformer_id != self._transformer_id
        if replaced:
            logger.warning(
                f"Transformer was replaced (id changed from {self._transformer_id} "
                f"to {current_transformer_id}), re-registering hooks"
            )

        if not self._registered or replaced:
            self.enable(pipeline)
            self._sync_step_count(num_inference_steps)
            return

        self._sync_step_count(num_inference_steps)

        # Collect blocks (children of ModuleList containers) that carry hooks.
        blocks_with_hooks = []
        for name, submodule in transformer.named_children():
            if not isinstance(submodule, torch.nn.ModuleList):
                continue
            for index, block in enumerate(submodule):
                registry = getattr(block, "_hook_registry", None)
                if registry is not None and len(registry._hooks) > 0:
                    blocks_with_hooks.append((f"{name}.{index}", block, registry))

        if not blocks_with_hooks:
            logger.warning("No hooks found on transformer blocks, re-registering")
            # Pass the strategy resolved at enable() so re-registration matches it.
            apply_mag_cache_hook(transformer, self._magcache_config, strategy=self._strategy)
            self._transformer_id = current_transformer_id
            return

        for _name, block, registry in blocks_with_hooks:
            for hook in registry._hooks.values():
                if hasattr(hook, "reset_state"):
                    hook.reset_state(block)

    def is_enabled(self) -> bool:
        """Check if MagCache is enabled.

        Returns:
            True if enabled, False otherwise.
        """
        return self.enabled
@dataclass
class MagCacheConfig:
    """
    Configuration for MagCache applied to transformer models.

    MagCache (Magnitude-based Cache) is an adaptive caching technique that speeds up
    diffusion model inference by reusing transformer block computations based on
    magnitude ratios between consecutive timesteps.

    Reference: https://github.com/Zehong-Ma/MagCache

    Args:
        threshold: Accumulated error threshold. Higher = more aggressive skipping (faster, lower quality).
            Default: 0.24
        max_skip_steps: Max consecutive skip steps (K).
            Default: 5
        retention_ratio: Fraction of initial steps where skipping is disabled (stability).
            Default: 0.1
        num_inference_steps: Total inference steps. Required for retention step calculation.
            Default: 28
        mag_ratios: Pre-computed magnitude ratios per step. Calibrate or use strategy defaults.
            Lists are converted to a tensor when calibration is off.
            Default: None
        mag_calibrate: If True, runs without skipping and logs norm_ratios for calibration.
            Default: False
        transformer_type: Transformer class name for logging.
            Default: "FluxTransformer2DModel"

    Raises:
        ValueError: If a numeric field is out of range (NaN included), or if
            ``mag_ratios`` is missing or empty while ``mag_calibrate`` is False.
    """

    threshold: float = 0.24
    max_skip_steps: int = 5
    retention_ratio: float = 0.1
    num_inference_steps: int = 28
    mag_ratios: torch.Tensor | list[float] | None = None
    mag_calibrate: bool = False
    transformer_type: str = "FluxTransformer2DModel"

    def __post_init__(self) -> None:
        """Validate the fields and normalize ``mag_ratios`` to a tensor."""
        # "not x > 0" (rather than "x <= 0") also rejects NaN.
        if not self.threshold > 0:
            raise ValueError(f"threshold must be positive, got {self.threshold}")

        if not self.max_skip_steps > 0:
            raise ValueError(f"max_skip_steps must be positive, got {self.max_skip_steps}")

        if not 0 < self.retention_ratio < 1:
            raise ValueError(f"retention_ratio must be in (0, 1), got {self.retention_ratio}")

        if self.num_inference_steps is None:
            raise ValueError(
                "num_inference_steps must be provided for MagCache. "
                "This is required to determine retention steps and interpolate mag_ratios. "
                "For Flux models, use num_inference_steps=28."
            )

        if not self.num_inference_steps > 0:
            raise ValueError(f"num_inference_steps must be positive, got {self.num_inference_steps}")

        # Calibration mode measures the ratios, so none are required or normalized.
        if self.mag_calibrate:
            return

        if self.mag_ratios is None:
            raise ValueError(
                "mag_ratios must be provided for MagCache inference because these ratios "
                "are model-dependent. To get them for your model:\n"
                "1. Initialize MagCacheConfig(mag_calibrate=True, ...)\n"
                "2. Run inference on your model once.\n"
                "3. Copy the printed ratios array and pass it to mag_ratios in the config.\n"
                "For Flux models, you can import FLUX_MAG_RATIOS from vllm_omni.diffusion.cache.magcache.strategy."
            )

        if not torch.is_tensor(self.mag_ratios):
            self.mag_ratios = torch.tensor(self.mag_ratios)

        # An empty ratio table would make every step fall back to scale 1.0.
        if self.mag_ratios.numel() == 0:
            raise ValueError("mag_ratios must not be empty")
class MagCacheHeadHook(ModelHook):
    """Head block hook for MagCache - decides whether to skip computation.

    On every forward it records the block input in the shared state, updates the
    accumulated ratio/error from ``config.mag_ratios`` and sets
    ``state.should_compute``. On a cache hit it returns the input plus the cached
    residual instead of running the block. When ``is_tail`` is True (a single
    block acting as head and tail) it also stores the residual and advances the
    step counter.
    """

    _HOOK_NAME = "mag_cache_head"

    def __init__(
        self,
        state_manager: StateManager,
        config: MagCacheConfig,
        strategy: MagCacheStrategy | None = None,
        is_tail: bool = False,
    ):
        super().__init__()
        self.state_manager = state_manager
        self.config = config
        self._strategy = strategy
        self._metadata = None
        self._is_tail = is_tail

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        """Resolve block metadata, registering a default when the class is unknown."""
        block_class = unwrap_module(module).__class__

        try:
            self._metadata = TransformerBlockRegistry.get(block_class)
        except ValueError:
            metadata = None
            if self._strategy is not None:
                metadata = self._strategy.register_block_metadata(block_class)
            if metadata is None:
                # Default layout: (encoder_hidden_states, hidden_states).
                metadata = TransformerBlockMetadata(
                    return_hidden_states_index=1,
                    return_encoder_hidden_states_index=0,
                )
            TransformerBlockRegistry.register(model_class=block_class, metadata=metadata)
            self._metadata = TransformerBlockRegistry.get(block_class)

        return module

    @torch.compiler.disable
    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        """Run the block or reuse the cached residual, depending on accumulated error."""
        if self.state_manager._current_context is None:
            self.state_manager.set_context("magcache")

        arg_name = getattr(self._metadata, "hidden_states_argument_name", "hidden_states")
        hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs)

        state: MagCacheState = self.state_manager.get_state()
        state.head_block_input = hidden_states

        if state._is_first_step:
            state.accumulated_ratio = 1.0
            state.accumulated_err = 0.0
            state.accumulated_steps = 0
            state._is_first_step = False

        state.should_compute = self._decide_should_compute(state)

        if state.should_compute:
            return self._compute_and_record(state, args, kwargs)

        result = self._reuse_cached_residual(state, hidden_states, args, kwargs)
        # A single head+tail block has no tail hook to advance the step on a skip.
        if self._is_tail:
            self._advance_step_head(state)
        return result

    def _decide_should_compute(self, state: MagCacheState) -> bool:
        """Update the accumulators for the current step and return whether to compute."""
        if self.config.mag_calibrate:
            return True

        current_step = state.step_index
        mag_ratios = self.config.mag_ratios
        # float() keeps the accumulators plain Python numbers, not 0-d tensors.
        current_scale = float(mag_ratios[current_step]) if current_step < len(mag_ratios) else 1.0

        retention_step = int(self.config.retention_ratio * self.config.num_inference_steps + 0.5)
        if current_step < retention_step:
            return True

        state.accumulated_ratio *= current_scale
        state.accumulated_steps += 1
        state.accumulated_err += abs(1.0 - state.accumulated_ratio)

        if (
            state.previous_residual is not None
            and state.accumulated_err <= self.config.threshold
            and state.accumulated_steps <= self.config.max_skip_steps
        ):
            return False

        state.accumulated_ratio = 1.0
        state.accumulated_steps = 0
        state.accumulated_err = 0.0
        return True

    def _pack_outputs(self, hidden, encoder_hidden):
        """Arrange the outputs at the metadata's return indices; None if the block returns no encoder states."""
        hidden_idx = self._metadata.return_hidden_states_index
        encoder_idx = self._metadata.return_encoder_hidden_states_index
        if encoder_idx is None:
            return None
        packed = [None] * (max(hidden_idx, encoder_idx) + 1)
        packed[hidden_idx] = hidden
        packed[encoder_idx] = encoder_hidden
        # A tuple, matching MagCacheBlockHook's cache-hit return.
        return tuple(packed)

    @staticmethod
    def _add_residual(hidden_states: torch.Tensor, res: torch.Tensor) -> torch.Tensor:
        """Add ``res`` to ``hidden_states``; a shorter 3-D residual is added to the trailing positions."""
        compatible_3d = (
            hidden_states.ndim == 3
            and res.ndim == 3
            and hidden_states.shape[0] == res.shape[0]
            and hidden_states.shape[2] == res.shape[2]
        )
        if res.shape != hidden_states.shape and compatible_3d:
            diff = hidden_states.shape[1] - res.shape[1]
            if diff > 0:
                output = hidden_states.clone()
                output[:, diff:, :] = output[:, diff:, :] + res
                return output
        return hidden_states + res

    def _reuse_cached_residual(self, state: MagCacheState, hidden_states, args, kwargs):
        """Build the block output from the cached residual (cache hit)."""
        res = state.previous_residual

        if isinstance(res, tuple):
            if self._strategy is None:
                raise RuntimeError(
                    f"MagCache residual is tuple but no strategy available for {self.config.transformer_type}. "
                    f"Please register a MagCacheStrategy for this model."
                )
            res = tuple(r.to(hidden_states.device) for r in res)
            encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
                "encoder_hidden_states", args, kwargs
            )
            if encoder_hidden_states.device != res[1].device:
                encoder_hidden_states = encoder_hidden_states.to(res[1].device)
            output, enc_output = self._strategy.apply_residual_tuple(hidden_states, encoder_hidden_states, res)
            return self.log_cache_hit(state, output, self._pack_outputs(output, enc_output))

        if res.device != hidden_states.device:
            res = res.to(hidden_states.device)

        if self._strategy is not None:
            output = self._strategy.apply_residual(hidden_states, res)
        else:
            output = self._add_residual(hidden_states, res)

        if self._metadata.return_encoder_hidden_states_index is None:
            return self.log_cache_hit(state, output, None)

        # Encoder states pass through unchanged for a single-tensor residual.
        encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
            "encoder_hidden_states", args, kwargs
        )
        return self.log_cache_hit(state, output, self._pack_outputs(output, encoder_hidden_states))

    def _compute_and_record(self, state: MagCacheState, args, kwargs):
        """Run the original forward (cache miss); in head+tail mode also store the residual."""
        output = self.fn_ref.original_forward(*args, **kwargs)
        result = self.log_cache_miss(state, output)

        if not self._is_tail:
            return result

        if isinstance(output, tuple):
            out_hidden = output[self._metadata.return_hidden_states_index]
        else:
            out_hidden = output
        in_hidden = state.head_block_input

        if in_hidden is not None:
            if self._strategy is not None:
                residual = self._strategy.compute_residual(out_hidden, in_hidden)
            elif out_hidden.shape == in_hidden.shape:
                residual = out_hidden - in_hidden
            else:
                residual = out_hidden

            if self.config.mag_calibrate:
                self._perform_calibration_head(state, residual)

            state.previous_residual = residual

        self._advance_step_head(state)
        return result

    def _perform_calibration_head(
        self,
        state: MagCacheState,
        current_residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
    ) -> None:
        """Append calibration metrics for this step; without a strategy they are the neutral 1.0/0.0/0.0."""
        if self._strategy is not None:
            ratio, std, cos_dis = self._strategy.compute_calibration_metrics(current_residual, state.previous_residual)
        else:
            ratio, std, cos_dis = 1.0, 0.0, 0.0

        state.calibration_ratios.append(ratio)
        state.norm_ratios.append(round(ratio, 5))
        state.norm_stds.append(round(std, 5))
        state.cos_dises.append(round(cos_dis, 5))

    def _advance_step_head(self, state: MagCacheState) -> None:
        """Advance the step counter; reset all state once ``num_inference_steps`` is reached."""
        state.step_index += 1
        if state.step_index < self.config.num_inference_steps:
            return

        if self.config.mag_calibrate:
            logger.info("MagCache calibration complete.")
            logger.info(f"norm_ratios: {state.norm_ratios}")
            logger.info(f"norm_stds: {state.norm_stds}")
            logger.info(f"cos_dises: {state.cos_dises}")
            logger.info("Copy these values to DiffusionCacheConfig(mag_ratios=...) for production use")

        state.step_index = 0
        state.accumulated_ratio = 1.0
        state.accumulated_steps = 0
        state.accumulated_err = 0.0
        state.previous_residual = None
        state.calibration_ratios = []
        state.norm_ratios = []
        state.norm_stds = []
        state.cos_dises = []
        state._is_first_step = True

    def log_cache_hit(self, state: MagCacheState, output, ret):
        """Log a cache hit and return ``ret`` when given, otherwise ``output``."""
        step = state.step_index
        previous = state.previous_residual
        if previous is None:
            residual_shape = "None"
        elif isinstance(previous, tuple):
            residual_shape = tuple(r.shape for r in previous)
        else:
            residual_shape = previous.shape
        logger.debug(
            f"[MagCache][HEAD] STEP={step}: CACHE_HIT (err={state.accumulated_err:.6f}, "
            f"steps_skipped={state.accumulated_steps}, residual_shape={residual_shape})"
        )
        return ret if ret is not None else output

    def log_cache_miss(self, state: MagCacheState, output):
        """Log a cache miss and return ``output`` unchanged."""
        import logging

        # The norm needs .item(); compute it only when the message will be emitted.
        if logger.isEnabledFor(logging.DEBUG):
            previous = state.previous_residual
            if previous is None:
                residual_norm = 0.0
            elif isinstance(previous, tuple):
                residual_norm = sum(float(torch.norm(r).item()) for r in previous)
            else:
                residual_norm = float(torch.norm(previous).item())
            logger.debug(
                f"[MagCache][HEAD] STEP={state.step_index}: CACHE_MISS "
                f"(err={state.accumulated_err:.6f}, acc_ratio={state.accumulated_ratio:.6f}, "
                f"residual_norm={residual_norm:.6f}, threshold={self.config.threshold}, "
                f"max_skip={self.config.max_skip_steps})"
            )
        return output

    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
        """Reset the shared state manager and return ``module``."""
        self.state_manager.reset()
        return module
TransformerBlockRegistry.get(block_class) + + return module + + def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: + self.state_manager.reset() + return module + + @torch.compiler.disable + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if self.state_manager._current_context is None: + self.state_manager.set_context("magcache") + state: MagCacheState = self.state_manager.get_state() + + if not state.should_compute: + res = state.previous_residual + if res is None: + res = torch.zeros_like(args[0]) + + if hasattr(self._metadata, "hidden_states_argument_name"): + arg_name = self._metadata.hidden_states_argument_name + else: + arg_name = "hidden_states" + hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs) + + if self.is_tail: + self.advance_step(state) + + if self._metadata.return_encoder_hidden_states_index is not None: + encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs( + "encoder_hidden_states", args, kwargs + ) + + if self._strategy is not None: + out_hidden, enc_out = self._strategy.apply_residual_tuple(hidden_states, encoder_hidden_states, res) + else: + out_hidden = hidden_states + res + enc_out = encoder_hidden_states + + max_idx = max( + self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index + ) + ret_list = [None] * (max_idx + 1) + ret_list[self._metadata.return_hidden_states_index] = out_hidden + ret_list[self._metadata.return_encoder_hidden_states_index] = enc_out + return tuple(ret_list) + + if self._strategy is not None: + output = self._strategy.apply_residual(hidden_states, res) + else: + output = hidden_states + res + + return output + + output = self.fn_ref.original_forward(*args, **kwargs) + + if self.is_tail: + if isinstance(output, tuple): + out_hidden = output[self._metadata.return_hidden_states_index] + else: + out_hidden = output + + in_hidden = state.head_block_input + + if in_hidden is None: + return output + + 
if self._strategy is not None: + residual = self._strategy.compute_residual(output, in_hidden) + elif out_hidden.shape == in_hidden.shape: + residual = out_hidden - in_hidden + elif out_hidden.ndim == 3 and in_hidden.ndim == 3 and out_hidden.shape[2] == in_hidden.shape[2]: + residual = out_hidden - in_hidden + else: + residual = out_hidden + + if self.config.mag_calibrate: + self.perform_calibration(state, residual) + + state.previous_residual = residual + self.advance_step(state) + + self.log_residual_computed(state, residual) + + return output + + def log_residual_computed( + self, + state: MagCacheState, + residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + ) -> None: + step = state.step_index + if residual is None: + residual_norm = 0.0 + residual_shape = "None" + elif isinstance(residual, tuple): + residual_norm = sum(float(torch.norm(r).item()) for r in residual) + residual_shape = tuple(r.shape for r in residual) + else: + residual_norm = float(torch.norm(residual).item()) + residual_shape = residual.shape + logger.debug( + f"[MagCache][TAIL] STEP={step}: RESIDUAL_COMPUTED (norm={residual_norm:.6f}, shape={residual_shape})" + ) + + def perform_calibration( + self, + state: MagCacheState, + current_residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + ) -> None: + if self._strategy is not None: + ratio, std, cos_dis = self._strategy.compute_calibration_metrics(current_residual, state.previous_residual) + else: + + def _get_norm(residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + if isinstance(residual, tuple): + return sum(torch.linalg.norm(r.float(), dim=-1) for r in residual) + return torch.linalg.norm(residual.float(), dim=-1) + + if state.previous_residual is None: + ratio, std, cos_dis = 1.0, 0.0, 0.0 + else: + curr_norm = _get_norm(current_residual) + prev_norm = _get_norm(state.previous_residual) + ratio = (curr_norm / (prev_norm + 1e-8)).mean().item() + std = (curr_norm / (prev_norm + 1e-8)).std().item() + 
def advance_step(self, state: MagCacheState) -> None:
    """Move to the next denoising step and roll the state over at run end.

    ``state.step_index`` is incremented. While it stays below
    ``self.config.num_inference_steps`` nothing else happens. Once the
    configured step count is reached, the calibration summary is logged (only
    when ``self.config.mag_calibrate`` is set) and the per-run fields are put
    back to their starting values.

    Args:
        state: MagCache state to advance; mutated in place.
    """
    state.step_index += 1
    if state.step_index < self.config.num_inference_steps:
        return

    if self.config.mag_calibrate:
        logger.info("MagCache calibration complete.")
        logger.info(f"norm_ratios: {state.norm_ratios}")
        logger.info(f"norm_stds: {state.norm_stds}")
        logger.info(f"cos_dises: {state.cos_dises}")
        logger.info("Copy these values to DiffusionCacheConfig(mag_ratios=...) for production use")

    # Step/skip bookkeeping for the next run.
    state.step_index = 0
    state.accumulated_steps = 0
    state.accumulated_ratio = 1.0
    state.accumulated_err = 0.0
    state._is_first_step = True

    # Cached residual and calibration traces. Fresh list objects are assigned
    # rather than clearing the old ones in place.
    state.previous_residual = None
    state.calibration_ratios = []
    state.norm_ratios = []
    state.norm_stds = []
    state.cos_dises = []
Available strategies: {list(MagCacheStrategyRegistry._registry.keys())}" + ) + else: + strategy_name = type(strategy).__name__ + logger.info(f"MagCache: Applying {strategy_name} for '{transformer_type}'") + if hasattr(type(strategy), "register_blocks"): + type(strategy).register_blocks() + + state_manager = StateManager(MagCacheState, (), {}) + remaining_blocks = [] + + for name, submodule in module.named_children(): + if not isinstance(submodule, torch.nn.ModuleList): + continue + for index, block in enumerate(submodule): + remaining_blocks.append((f"{name}.{index}", block)) + + if not remaining_blocks: + logger.warning("MagCache: No transformer blocks found to apply hooks.") + return + + if len(remaining_blocks) == 1: + name, block = remaining_blocks[0] + logger.info(f"MagCache: Applying Head+Tail Hook to single block '{name}'") + _apply_mag_cache_head_hook(block, state_manager, config, strategy, is_tail=True) + return + + head_block_name, head_block = remaining_blocks.pop(0) + tail_block_name, tail_block = remaining_blocks.pop(-1) + + logger.info(f"MagCache: Applying Head Hook to {head_block_name}") + _apply_mag_cache_head_hook(head_block, state_manager, config, strategy) + + for name, block in remaining_blocks: + _apply_mag_cache_block_hook(block, state_manager, config, strategy=strategy) + + logger.info(f"MagCache: Applying Tail Hook to {tail_block_name}") + _apply_mag_cache_block_hook(tail_block, state_manager, config, is_tail=True, strategy=strategy) + + +def _apply_mag_cache_head_hook( + block: torch.nn.Module, + state_manager: StateManager, + config: MagCacheConfig, + strategy: MagCacheStrategy | None = None, + is_tail: bool = False, +) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + + if registry.get_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) is not None: + registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) + + hook = MagCacheHeadHook(state_manager, config, strategy, is_tail=is_tail) + registry.register_hook(_MAG_CACHE_LEADER_BLOCK_HOOK, 
def _apply_mag_cache_block_hook(
    block: torch.nn.Module,
    state_manager: StateManager,
    config: MagCacheConfig,
    is_tail: bool = False,
    strategy: MagCacheStrategy | None = None,
) -> None:
    """Attach a fresh ``MagCacheBlockHook`` to ``block``.

    Any hook already stored under ``_MAG_CACHE_BLOCK_HOOK`` in the block's
    registry is removed first, so at most one MagCache block hook is ever
    registered on a block.

    Args:
        block: Transformer block that receives the hook.
        state_manager: State manager handed to the hook.
        config: MagCache configuration handed to the hook.
        is_tail: Forwarded to the hook constructor; callers pass ``True`` for
            the last block.
        strategy: Optional model-specific strategy handed to the hook.
    """
    block_registry = HookRegistry.check_if_exists_or_initialize(block)

    stale_hook = block_registry.get_hook(_MAG_CACHE_BLOCK_HOOK)
    if stale_hook is not None:
        block_registry.remove_hook(_MAG_CACHE_BLOCK_HOOK)

    # MagCacheBlockHook takes is_tail *before* config; keep this argument order.
    block_registry.register_hook(
        _MAG_CACHE_BLOCK_HOOK,
        MagCacheBlockHook(state_manager, is_tail, config, strategy),
    )
+""" + +import torch + + +class MagCacheState: + """State management for MagCache caching logic.""" + + def __init__(self) -> None: + """Initialize empty MagCache state.""" + self.previous_residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None = None + self.head_block_input: torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None = None + self.should_compute: bool = True + self.accumulated_ratio: float = 1.0 + self.accumulated_err: float = 0.0 + self.accumulated_steps: int = 0 + self.step_index: int = 0 + self.calibration_ratios: list[float] = [] + self.norm_ratios: list[float] = [] + self.norm_stds: list[float] = [] + self.cos_dises: list[float] = [] + self._is_first_step: bool = True + + def reset(self) -> None: + """Reset all state variables for a new inference run.""" + self.previous_residual = None + self.head_block_input = None + self.should_compute = True + self.accumulated_ratio = 1.0 + self.accumulated_err = 0.0 + self.accumulated_steps = 0 + self.step_index = 0 + self.calibration_ratios = [] + self.norm_ratios = [] + self.norm_stds = [] + self.cos_dises = [] + self._is_first_step = True diff --git a/vllm_omni/diffusion/cache/magcache/strategy.py b/vllm_omni/diffusion/cache/magcache/strategy.py new file mode 100644 index 00000000000..7743ff2c75a --- /dev/null +++ b/vllm_omni/diffusion/cache/magcache/strategy.py @@ -0,0 +1,488 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +MagCache strategy definitions for different model architectures. + +This module provides model-specific strategies for MagCache, allowing easy +extension to new models by implementing the MagCacheStrategy interface. 
class MagCacheStrategy(ABC):
    """
    Abstract base class for MagCache strategies.

    Each model architecture requires a specific strategy to handle:
    - Residual computation (how to calculate the residual for caching)
    - Residual application (how to apply cached residual)
    - Model-specific magnitude ratios
    - Calibration metrics used to derive new magnitude ratios

    Implement this class to add support for new model architectures.
    """

    @property
    @abstractmethod
    def mag_ratios(self) -> torch.Tensor:
        """Return the default mag_ratios tensor for this model.

        The values come from a calibration run, which records one ratio per
        denoising step; the bundled Flux table, for instance, is documented
        as covering 28 steps.

        Returns:
            1D tensor of calibrated magnitude ratios.
        """

    def register_block_metadata(self, block_class: type) -> TransformerBlockMetadata | None:
        """Register model-specific transformer block metadata.

        Override this method to provide custom metadata for transformer blocks
        that have non-standard output formats (e.g., tuple returns).

        Args:
            block_class: The transformer block class to register.

        Returns:
            TransformerBlockMetadata if custom registration is needed, None otherwise.
        """
        return None

    def compute_residual(
        self,
        output: torch.Tensor,
        head_input: torch.Tensor,
    ) -> torch.Tensor:
        """Compute residual between block output and input.

        Default implementation: output - head_input.
        Override this method for models with non-standard output formats.

        Args:
            output: Output from transformer blocks.
            head_input: Input to the first block.

        Returns:
            Residual tensor for caching.
        """
        return output - head_input

    def apply_residual(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
    ) -> torch.Tensor:
        """Apply cached residual to hidden states.

        Default implementation: add residual to hidden_states.

        Args:
            hidden_states: Current hidden states.
            residual: Cached residual to apply.

        Returns:
            Hidden states with residual added.
        """
        return hidden_states + residual

    def apply_residual_tuple(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        residual: tuple[torch.Tensor, torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply cached residual tuple to both hidden_states and encoder_hidden_states.

        Default implementation: add residuals separately.
        Override this method for models with specific residual application logic.

        Args:
            hidden_states: Current hidden states.
            encoder_hidden_states: Current encoder hidden states.
            residual: Tuple of (hidden_states_residual, encoder_hidden_states_residual).

        Returns:
            Tuple of (hidden_states, encoder_hidden_states) with residuals applied.
        """
        h_res, e_res = residual
        return hidden_states + h_res, encoder_hidden_states + e_res

    @staticmethod
    def _merge_streams(
        residual: torch.Tensor | tuple[torch.Tensor, ...],
    ) -> torch.Tensor:
        """Return ``residual`` as one tensor whose last dim is the feature dim.

        A plain tensor is returned untouched. A tuple of stream residuals is
        reshaped to ``(-1, feature_dim)`` per stream and concatenated, so every
        row contributes one norm / one cosine term regardless of which stream
        it came from. All streams must share the trailing dimension.
        """
        if isinstance(residual, tuple):
            return torch.cat([part.reshape(-1, part.shape[-1]) for part in residual], dim=0)
        return residual

    def compute_calibration_metrics(
        self,
        current_residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        previous_residual: torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None,
    ) -> tuple[float, float, float]:
        """Compute calibration metrics for mag_ratios generation.

        Norms and cosine similarity are taken along the last dimension and
        then reduced over everything else. Tuple residuals (as produced by
        strategies whose ``compute_residual`` returns one residual per stream)
        are merged row-wise first instead of raising ``AttributeError``.

        Args:
            current_residual: Residual from the current step.
            previous_residual: Residual from the previous step (None for first step).

        Returns:
            Tuple of (norm_ratio, norm_std, cos_dis):
            - norm_ratio: Mean ratio of current to previous residual norms
            - norm_std: Standard deviation of the norm ratios
            - cos_dis: Mean cosine dissimilarity (1 - cosine_similarity)
        """
        import torch.nn.functional as F

        if previous_residual is None:
            # Nothing to compare against on the first step.
            return 1.0, 0.0, 0.0

        current = self._merge_streams(current_residual)
        previous = self._merge_streams(previous_residual)

        # Norms are taken in float32; the epsilon guards against a zero
        # previous residual.
        norm_ratio = torch.linalg.norm(current.float(), dim=-1) / (
            torch.linalg.norm(previous.float(), dim=-1) + 1e-8
        )

        ratio = norm_ratio.mean().item()
        std = norm_ratio.std().item()
        cos_dis = (1 - F.cosine_similarity(current, previous, dim=-1, eps=1e-8)).mean().item()

        return ratio, std, cos_dis

    def get_calibration_metrics_names(self) -> tuple[str, str, str]:
        """Return the names of calibration metrics for logging.

        Returns:
            Tuple of metric names in order: (norm_ratio_name, norm_std_name, cos_dis_name)
        """
        return ("norm_ratio", "norm_std", "cos_dis")
+ + Flux architecture: + - transformer blocks (dual-stream): image tokens and text tokens + processed independently with separate weights + - single transformer blocks (single-stream): concatenated sequence + (image + text tokens) shares the same group of weights + - Final norm_out and proj_out layers + + This strategy provides: + - mag_ratios: Pre-computed magnitude ratios for Flux (28 steps) + - compute_residual: Handles tuple output format + - apply_residual_tuple: Handles decoder residual only + """ + + FLUX_MAG_RATIOS = torch.tensor( + [ + 1.0, + 1.07313, + 1.21035, + 1.04432, + 1.06818, + 1.05547, + 1.0183, + 1.03405, + 1.02574, + 1.03042, + 1.02739, + 1.01955, + 1.01585, + 1.02439, + 1.01154, + 1.01377, + 1.00994, + 1.01444, + 1.00839, + 1.02269, + 1.0007, + 1.00714, + 1.00484, + 1.01381, + 1.00426, + 0.99764, + 1.00778, + 1.00233, + ] + ) + + @property + def mag_ratios(self) -> torch.Tensor: + """Return default mag_ratios for Flux model.""" + return self.FLUX_MAG_RATIOS + + def compute_residual( + self, + output: torch.Tensor, + head_input: torch.Tensor, + ) -> torch.Tensor: + """Compute residual for Flux output format (tuple or single tensor). + + Flux single transformer blocks return a tuple, so we extract + the decoder output (index 1) before computing residual. + """ + if isinstance(output, tuple): + decoder_output = output[1] if len(output) > 1 else output[0] + return decoder_output - head_input + else: + return output - head_input + + def apply_residual_tuple( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + residual: tuple[torch.Tensor, torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + """Apply residual tuple for Flux - only add decoder residual. + + Flux has separate image and text processing, so the residual + is only applied to the decoder (image) branch. 
+ """ + if isinstance(residual, tuple): + decoder_residual = residual[1] + else: + decoder_residual = residual + + output = hidden_states + decoder_residual + enc_output = encoder_hidden_states + + return output, enc_output + + @staticmethod + def nearest_interp(src_array: torch.Tensor, target_length: int) -> torch.Tensor: + """Interpolate mag_ratios to target length using nearest neighbor.""" + src_length = len(src_array) + if target_length == 1: + return src_array[-1:] + + scale = (src_length - 1) / (target_length - 1) + grid = torch.arange(target_length, device=src_array.device, dtype=torch.float32) + mapped_indices = torch.round(grid * scale).long() + + return src_array[mapped_indices] + + +class Flux2MagCacheStrategy(FluxMagCacheStrategy): + """MagCache strategy for Flux2 model. + + Flux2 shares the same dual-stream architecture as Flux, but may have + different tensor shapes in some transformer blocks, requiring special + handling in residual computation. + """ + + FLUX2_MAG_RATIOS = torch.tensor( + [ + 1.0, + 0.96528, + 1.11559, + 1.0565, + 1.00425, + 1.0805, + 0.98616, + 1.09289, + 1.03196, + 1.06679, + 1.03941, + 1.05375, + 1.03128, + 1.05349, + 1.01983, + 1.05535, + 1.0662, + 1.05748, + 1.00318, + 1.05222, + 1.04556, + 1.0506, + 1.05058, + 1.05219, + 1.02025, + 1.05052, + 1.04143, + 1.0498, + ] + ) + + @property + def mag_ratios(self) -> torch.Tensor: + """Return default mag_ratios for Flux2 model.""" + return self.FLUX2_MAG_RATIOS + + def register_block_metadata(self, block_class: type) -> TransformerBlockMetadata | None: + """Register Flux2-specific block metadata based on block type. 
+ + Flux2 has two block types with different output formats: + - Dual-stream (Flux2TransformerBlock): returns (encoder_hidden_states, hidden_states) + - Single-stream (Flux2SingleTransformerBlock): returns single tensor + """ + class_name = block_class.__name__ + is_single_stream = "Single" in class_name + + if is_single_stream: + return TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ) + else: + return TransformerBlockMetadata( + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ) + + def compute_residual( + self, + output: torch.Tensor, + head_input: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """Compute residual for Flux2 with dual-stream support. + + For dual-stream blocks, computes residual for both encoder and decoder branches. + For single-stream blocks: if shapes match, computes output - input; otherwise returns output. + """ + if isinstance(output, tuple): + enc_output, dec_output = output[0], output[1] + + if isinstance(head_input, tuple): + enc_head, dec_head = head_input[0], head_input[1] + else: + enc_head = head_input[:, : enc_output.shape[1], ...] + dec_head = head_input[:, enc_output.shape[1] :, ...] + + enc_residual = enc_output - enc_head + dec_residual = dec_output - dec_head + + return (enc_residual, dec_residual) + else: + if output.shape == head_input.shape: + return output - head_input + else: + return output + + def apply_residual_tuple( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + residual: tuple[torch.Tensor, torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + """Apply residual tuple for Flux2 with shape mismatch handling. + + For Flux2 dual-stream blocks, handles the case where hidden_states + and decoder_residual may have different shapes. 
+ """ + if isinstance(residual, tuple): + decoder_residual = residual[1] + else: + decoder_residual = residual + + if hidden_states.shape == decoder_residual.shape: + output = hidden_states + decoder_residual + else: + output = hidden_states + enc_output = encoder_hidden_states + + return output, enc_output + + def apply_residual( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + ) -> torch.Tensor: + """Apply residual for Flux2. + + For single-stream blocks: if shapes match, adds residual; otherwise returns input. + For dual-stream blocks: applies decoder residual only. + """ + if isinstance(residual, tuple): + dec_residual = residual[1] + if hidden_states.shape == dec_residual.shape: + return hidden_states + dec_residual + else: + return hidden_states + else: + if residual.shape == hidden_states.shape: + return hidden_states + residual + else: + return hidden_states + + +class MagCacheStrategyRegistry: + """Registry for MagCache strategies by transformer type.""" + + _registry: dict[str, MagCacheStrategy] = {} + + @classmethod + def register(cls, name: str, strategy: MagCacheStrategy) -> None: + """Register a strategy with explicit name. + + Args: + name: Transformer model type identifier (e.g., "FluxTransformer2DModel") + strategy: MagCacheStrategy instance + """ + cls._registry[name] = strategy + + @classmethod + def get(cls, transformer_type: str) -> MagCacheStrategy: + """Get strategy for given transformer type.""" + if transformer_type not in cls._registry: + available = list(cls._registry.keys()) + raise ValueError(f"Unknown model type: '{transformer_type}'. 
Available types: {available}") + return cls._registry[transformer_type] + + @classmethod + def get_if_exists(cls, transformer_type: str) -> MagCacheStrategy | None: + """Get strategy if exists, None otherwise.""" + return cls._registry.get(transformer_type) + + +MagCacheStrategyRegistry.register("FluxTransformer2DModel", FluxMagCacheStrategy()) + +MagCacheStrategyRegistry.register("Flux2Transformer2DModel", Flux2MagCacheStrategy()) + + +def register_strategy( + transformer_cls_name: str, + strategy: MagCacheStrategy, +) -> None: + """Register a MagCache strategy for a model type. + + This allows extending MagCache support to new models without modifying + the core MagCache code. + + Args: + transformer_cls_name: Transformer model type identifier (class name or type string) + Must match pipeline.transformer.__class__.__name__ + strategy: MagCacheStrategy instance for this model type + + Example: + >>> class MyModelMagCacheStrategy(MagCacheStrategy): + ... @property + ... def mag_ratios(self): + ... return torch.tensor([...]) + >>> register_strategy("MyModelTransformer", MyModelMagCacheStrategy()) + """ + MagCacheStrategyRegistry.register(transformer_cls_name, strategy) + + +def get_strategy(transformer_cls_name: str) -> MagCacheStrategy: + """Get strategy function for given transformer class. + + This function looks up the strategy based on the exact transformer_cls_name string, + which should match the transformer type in the pipeline (i.e., pipeline.transformer.__class__.__name__). + + Args: + transformer_cls_name: Transformer class name (e.g., "FluxTransformer2DModel") + Must exactly match a registered strategy. 
+ + Returns: + MagCacheStrategy instance for the model + + Raises: + ValueError: If model type not found in registry + """ + return MagCacheStrategyRegistry.get(transformer_cls_name) diff --git a/vllm_omni/diffusion/cache/selector.py b/vllm_omni/diffusion/cache/selector.py index 7c09bf66475..e6c2e70e318 100644 --- a/vllm_omni/diffusion/cache/selector.py +++ b/vllm_omni/diffusion/cache/selector.py @@ -2,6 +2,7 @@ from vllm_omni.diffusion.cache.base import CacheBackend from vllm_omni.diffusion.cache.cache_dit_backend import CacheDiTBackend +from vllm_omni.diffusion.cache.magcache.backend import MagCacheBackend from vllm_omni.diffusion.cache.teacache.backend import TeaCacheBackend from vllm_omni.diffusion.data import DiffusionCacheConfig @@ -12,14 +13,15 @@ def get_cache_backend(cache_backend: str | None, cache_config: Any) -> CacheBack This is a selector function that routes to the appropriate backend implementation. - cache_dit: Uses CacheDiTBackend with enable()/refresh() interface - tea_cache: Uses TeaCacheBackend with enable()/refresh() interface + - mag_cache: Uses MagCacheBackend with enable()/refresh() interface Args: - cache_backend: Cache backend name ("cache_dit", "tea_cache", or None). + cache_backend: Cache backend name ("cache_dit", "tea_cache", "mag_cache", or None). cache_config: Cache configuration (dict or DiffusionCacheConfig instance). Returns: - Cache backend instance (CacheDiTBackend or TeaCacheBackend) if cache_backend is set, - None otherwise. + Cache backend instance (CacheDiTBackend, TeaCacheBackend, or MagCacheBackend) + if cache_backend is set, None otherwise. Raises: ValueError: If cache_backend is unsupported. 
@@ -34,5 +36,9 @@ def get_cache_backend(cache_backend: str | None, cache_config: Any) -> CacheBack return CacheDiTBackend(cache_config) elif cache_backend == "tea_cache": return TeaCacheBackend(cache_config) + elif cache_backend == "mag_cache": + return MagCacheBackend(cache_config) else: - raise ValueError(f"Unsupported cache backend: {cache_backend}. Supported: 'cache_dit', 'tea_cache'") + raise ValueError( + f"Unsupported cache backend: {cache_backend}. Supported: 'cache_dit', 'tea_cache', 'mag_cache'" + ) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 0675c226aee..e4f34df290e 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -281,7 +281,7 @@ def __getattr__(self, item: str) -> Any: @dataclass class DiffusionCacheConfig: """ - Configuration for cache adapters (TeaCache, cache-dit, etc.). + Configuration for cache adapters (TeaCache, cache-dit, MagCache, etc.). This dataclass provides a unified interface for cache configuration parameters. It can be initialized from a dictionary and accessed via attributes. 
@@ -291,6 +291,8 @@ class DiffusionCacheConfig: - cache-dit: Fn_compute_blocks, Bn_compute_blocks, max_warmup_steps, residual_diff_threshold, enable_taylorseer, taylorseer_order, scm_steps_mask_policy, scm_steps_policy + - MagCache: mag_threshold, mag_max_skip_steps, mag_retention_ratio, + mag_ratios, mag_calibrate Example: >>> # From dict (user-facing API) - partial config uses defaults for missing keys @@ -300,7 +302,7 @@ class DiffusionCacheConfig: >>> print(config.Fn_compute_blocks) # 8 (default) >>> # Empty dict uses all defaults >>> default_config = DiffusionCacheConfig.from_dict({}) - >>> print(default_config.rel_l1_thresh) # 0.2 (default) + >>> print(config.rel_l1_thresh) # 0.2 (default) """ # TeaCache parameters [tea_cache only] @@ -308,6 +310,18 @@ class DiffusionCacheConfig: rel_l1_thresh: float = 0.2 coefficients: list[float] | None = None # Uses model-specific defaults if None + # MagCache parameters [mag_cache only] + # Default: 0.24 threshold for accumulated magnitude error + mag_threshold: float = 0.24 + # Default: 5 maximum consecutive skip steps (K) + mag_max_skip_steps: int = 5 + # Default: 0.1 fraction of initial steps where skipping is disabled (stability) + mag_retention_ratio: float = 0.1 + # Default: None magnitude ratios (model-specific, required for inference) + mag_ratios: list[float] | None = None + # Default: False calibration mode (computes mag_ratios on first run) + mag_calibrate: bool = False + # cache-dit parameters [cache-dit only] # Default: 1 forward compute block (optimized for single-transformer models) # Use 1 as default instead of cache-dit's 8, optimized for single-transformer models diff --git a/vllm_omni/diffusion/hooks/base.py b/vllm_omni/diffusion/hooks/base.py index 517c6615877..6863a459425 100644 --- a/vllm_omni/diffusion/hooks/base.py +++ b/vllm_omni/diffusion/hooks/base.py @@ -15,6 +15,9 @@ from typing import Any import torch.nn as nn +from vllm.logger import init_logger + +logger = init_logger(__name__) class 
BaseState: @@ -27,17 +30,24 @@ def reset(self) -> None: # pragma: no cover - default is no-op class StateManager: """Manage per-context hook state instances.""" - def __init__(self, state_cls: Callable[[], BaseState]): + def __init__(self, state_cls: Callable[[], BaseState], init_args: tuple = (), init_kwargs: dict | None = None): self._state_cls = state_cls + self._init_args = init_args + self._init_kwargs = init_kwargs or {} self._states: dict[str, BaseState] = {} self._context: str = "default" + @property + def _current_context(self) -> str | None: + """Alias for _context for compatibility with diffusers hook code.""" + return self._context if self._context != "default" else None + def set_context(self, name: str) -> None: self._context = name or "default" def get_state(self) -> BaseState: if self._context not in self._states: - self._states[self._context] = self._state_cls() + self._states[self._context] = self._state_cls(*self._init_args, **self._init_kwargs) return self._states[self._context] def reset(self) -> None: @@ -190,6 +200,22 @@ def get_or_create(cls, module: nn.Module) -> HookRegistry: return registry + @classmethod + def check_if_exists_or_initialize(cls, module: nn.Module) -> HookRegistry: + """Get existing registry or create a new one for the module. + + This method ensures a HookRegistry exists on the module and returns it. + If a registry doesn't exist, it creates one and attaches it to the module. + This is equivalent to get_or_create() for compatibility with diffusers API. + + Args: + module: The module to get/create a registry for. + + Returns: + The HookRegistry for this module. + """ + return cls.get_or_create(module) + def update_sorted_hooks(self): """Sort hooks by name, which dictates pre/post process order.""" sorted_hooks = [self._hooks[k] for k in sorted(self._hooks) if self._hooks[k] != self._new_fwd_impl_hook] @@ -205,7 +231,23 @@ def register_hook(self, name: str, hook: ModelHook) -> None: name: Unique name for this hook. 
def reset(self) -> None:
    """Reset all hooks and clear the registry.

    Every registered hook gets ``reset_state(self.module)`` called, the hook
    table is emptied, and - if the module still carries the saved
    ``_omni_original_forward`` - ``module.forward`` is restored from it and
    the saved attribute is deleted. Calling this on an already-reset registry
    is a no-op.
    """
    # Only the hook objects are needed, not their names. Iterate a snapshot so
    # a hook that touches the registry while resetting cannot disturb the loop.
    for hook in list(self._hooks.values()):
        hook.reset_state(self.module)
    self._hooks.clear()

    # NOTE(review): only the hook table is cleared here; any other state the
    # registry derives from it is left untouched - confirm that is intended.
    if hasattr(self.module, "_omni_original_forward"):
        self.module.forward = self.module._omni_original_forward  # type: ignore[attr-defined]
        delattr(self.module, "_omni_original_forward")