Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions tests/diffusion/cache/test_cache_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
CUSTOM_DIT_ENABLERS,
CacheDiTBackend,
)
from vllm_omni.diffusion.cache.magcache.backend import MagCacheBackend
from vllm_omni.diffusion.cache.selector import get_cache_backend
from vllm_omni.diffusion.cache.teacache.backend import TeaCacheBackend
from vllm_omni.diffusion.data import DiffusionCacheConfig
Expand Down Expand Up @@ -313,3 +314,150 @@ def test_get_cache_backend_invalid(self):
"""Test getting invalid backend raises error."""
with pytest.raises(ValueError, match="Unsupported cache backend"):
get_cache_backend("invalid_backend", {})


class TestMagCacheBackend:
    """Test MagCacheBackend implementation.

    The ``enable`` tests patch ``apply_mag_cache_hook`` in the backend module,
    so they exercise only the backend's own bookkeeping (flags and the
    arguments handed to the hook installer).
    """

    @staticmethod
    def _make_flux_pipeline(blocks=None):
        """Build a mock pipeline whose transformer reports a Flux class name.

        Args:
            blocks: Optional list assigned to ``transformer.blocks``. When
                omitted, the attribute is left as the auto-created Mock.

        Returns:
            The mock pipeline; the mock transformer is ``pipeline.transformer``.
        """
        pipeline = Mock()
        pipeline.__class__.__name__ = "FluxPipeline"
        transformer = Mock()
        # MagCacheBackend.enable derives the transformer type from this name.
        transformer.__class__.__name__ = "FluxTransformer2DModel"
        if blocks is not None:
            transformer.blocks = blocks
        pipeline.transformer = transformer
        return pipeline

    def test_init(self):
        """Test initialization."""
        config = DiffusionCacheConfig(mag_threshold=0.1, mag_max_skip_steps=2, mag_calibrate=True)
        backend = MagCacheBackend(config)
        assert backend.config.mag_threshold == 0.1
        assert backend.config.mag_max_skip_steps == 2
        assert backend.enabled is False

    @patch("vllm_omni.diffusion.cache.magcache.backend.apply_mag_cache_hook")
    def test_enable(self, mock_apply_hook):
        """Test enabling MagCache on pipeline."""
        pipeline = self._make_flux_pipeline()
        backend = MagCacheBackend(DiffusionCacheConfig(mag_ratios=[1.0] * 28))

        backend.enable(pipeline)

        assert backend.enabled is True
        mock_apply_hook.assert_called_once()
        # The hook must be installed on the pipeline's transformer.
        assert mock_apply_hook.call_args[0][0] == pipeline.transformer

    @patch("vllm_omni.diffusion.cache.magcache.backend.apply_mag_cache_hook")
    def test_enable_with_calibration(self, mock_apply_hook):
        """Test enabling MagCache in calibration mode."""
        pipeline = self._make_flux_pipeline()
        backend = MagCacheBackend(DiffusionCacheConfig(mag_calibrate=True))

        backend.enable(pipeline)

        assert backend.enabled is True
        mock_apply_hook.assert_called_once()

    def test_refresh(self):
        """Test refreshing MagCache state calls enable when not registered."""
        pipeline = self._make_flux_pipeline()
        # No child modules, so there are no blocks to iterate over.
        pipeline.transformer.named_children = Mock(return_value=[])
        backend = MagCacheBackend(DiffusionCacheConfig(mag_ratios=[1.0] * 28))

        assert backend._registered is False
        backend.refresh(pipeline, num_inference_steps=50)
        assert backend._registered is True

    def test_is_enabled(self):
        """Test is_enabled method."""
        backend = MagCacheBackend(DiffusionCacheConfig(mag_ratios=[1.0] * 28))
        assert backend.is_enabled() is False

    def test_get_mag_cache_backend(self):
        """Test getting MagCache backend via selector."""
        # NOTE(review): MagCacheBackend.enable reads ``mag_threshold``,
        # ``mag_max_skip_steps`` and ``mag_retention_ratio`` from the config,
        # while this dict uses unprefixed keys -- confirm which names the
        # selector is expected to accept.
        config_dict = {
            "mag_ratios": [1.0] * 28,
            "num_inference_steps": 28,
            "threshold": 0.06,
            "max_skip_steps": 3,
            "retention_ratio": 0.2,
        }
        backend = get_cache_backend("mag_cache", config_dict)
        assert backend is not None
        assert isinstance(backend, MagCacheBackend)
        assert backend.config.threshold == 0.06

    @patch("vllm_omni.diffusion.cache.magcache.backend.apply_mag_cache_hook")
    def test_enable_single_block(self, mock_apply_hook):
        """Test enabling MagCache on single transformer block."""
        block = Mock()
        block.__class__.__name__ = "FluxTransformer2DModel"
        pipeline = self._make_flux_pipeline(blocks=[block])
        backend = MagCacheBackend(DiffusionCacheConfig(mag_ratios=[1.0] * 28))

        backend.enable(pipeline)

        assert backend.enabled is True
        mock_apply_hook.assert_called_once()
        assert mock_apply_hook.call_args[0][0] == pipeline.transformer

    @patch("vllm_omni.diffusion.cache.magcache.backend.apply_mag_cache_hook")
    def test_enable_multi_block(self, mock_apply_hook):
        """Test enabling MagCache on multiple transformer blocks."""
        pipeline = self._make_flux_pipeline(blocks=[Mock() for _ in range(24)])
        backend = MagCacheBackend(DiffusionCacheConfig(mag_ratios=[1.0] * 28))

        backend.enable(pipeline)

        assert backend.enabled is True
        # Hooks are applied once to the transformer, not once per block.
        mock_apply_hook.assert_called_once()
        assert mock_apply_hook.call_args[0][0] == pipeline.transformer
3 changes: 2 additions & 1 deletion vllm_omni/diffusion/cache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

This module provides a unified cache backend system for different caching strategies:
- TeaCache: Timestep Embedding Aware Cache for adaptive transformer caching
- MagCache: Magnitude-based Cache for adaptive transformer caching
- cache-dit: DBCache, SCM, and TaylorSeer caching strategies

Cache backends are instantiated directly via their constructors and configured via OmniDiffusionConfig.
Expand All @@ -20,8 +21,8 @@

# Names exported by ``from vllm_omni.diffusion.cache import *``.
__all__ = [
    "CacheBackend",
    "CacheContext",
    "TeaCacheBackend",
    "TeaCacheConfig",
    "apply_teacache_hook",
]
32 changes: 32 additions & 0 deletions vllm_omni/diffusion/cache/magcache/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm_omni.diffusion.cache.magcache.config import MagCacheConfig
from vllm_omni.diffusion.cache.magcache.hook import (
MagCacheBlockHook,
MagCacheHeadHook,
MagCacheState,
apply_mag_cache_hook,
)
from vllm_omni.diffusion.cache.magcache.strategy import (
Flux2MagCacheStrategy,
FluxMagCacheStrategy,
MagCacheStrategy,
MagCacheStrategyRegistry,
get_strategy,
register_strategy,
)

# Names exported by ``from vllm_omni.diffusion.cache.magcache import *``.
__all__ = [
    "Flux2MagCacheStrategy",
    "FluxMagCacheStrategy",
    "MagCacheBlockHook",
    "MagCacheConfig",
    "MagCacheHeadHook",
    "MagCacheState",
    "MagCacheStrategy",
    "MagCacheStrategyRegistry",
    "apply_mag_cache_hook",
    "get_strategy",
    "register_strategy",
]
177 changes: 177 additions & 0 deletions vllm_omni/diffusion/cache/magcache/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
MagCache backend implementation.

This module provides the MagCache backend that implements the CacheBackend
interface using the hooks-based MagCache system.
"""

from typing import Any

import torch
from vllm.logger import init_logger

from vllm_omni.diffusion.cache.base import CacheBackend
from vllm_omni.diffusion.cache.magcache.config import MagCacheConfig
from vllm_omni.diffusion.cache.magcache.hook import (
apply_mag_cache_hook,
)
from vllm_omni.diffusion.cache.magcache.strategy import (
get_strategy,
)
from vllm_omni.diffusion.data import DiffusionCacheConfig

logger = init_logger(__name__)


class MagCacheBackend(CacheBackend):
    """
    MagCache implementation using hooks.

    MagCache (Magnitude-based Cache) is an adaptive caching technique that
    speeds up diffusion inference by reusing transformer block computations
    based on accumulated magnitude error between timesteps.

    The backend applies MagCache hooks to the transformer which intercept the
    forward pass and implement the caching logic transparently.

    Example:
        >>> from vllm_omni.diffusion.data import DiffusionCacheConfig
        >>> from vllm_omni.diffusion.cache.magcache.config import MagCacheConfig
        >>> from vllm_omni.diffusion.cache.magcache.strategy import FluxMagCacheStrategy
        >>> cache_config = DiffusionCacheConfig(
        ...     mag_ratios=FluxMagCacheStrategy.FLUX_MAG_RATIOS,
        ...     num_inference_steps=28,
        ...     mag_threshold=0.24,
        ...     mag_max_skip_steps=5,
        ...     mag_retention_ratio=0.1,
        ... )
        >>> backend = MagCacheBackend(cache_config)
        >>> backend.enable(pipeline)
        >>> backend.refresh(pipeline, num_inference_steps=50)
    """

    def __init__(self, config: DiffusionCacheConfig):
        """Store the cache config and initialise registration bookkeeping.

        Args:
            config: Cache configuration; the ``mag_*`` fields and
                ``num_inference_steps`` are read later by ``enable``.
        """
        super().__init__(config)
        self._registered = False
        self._magcache_config: MagCacheConfig | None = None
        self._transformer_id: int | None = None
        # Strategy resolved by enable(); kept so refresh() can re-apply hooks
        # with the same strategy. Stays None when ratios come from the config
        # or when calibrating.
        self._strategy: Any = None

    @staticmethod
    def _ratios_from_strategy(strategy: Any, num_inference_steps: int) -> Any:
        """Return the strategy's mag_ratios, interpolated to the step count if possible.

        Args:
            strategy: Object returned by ``get_strategy``; its ``mag_ratios``
                attribute and optional ``nearest_interp`` method are used.
            num_inference_steps: Number of steps the ratios should cover.

        Returns:
            The ratios to use, or None when the strategy provides none.
        """
        original_ratios = strategy.mag_ratios
        if original_ratios is None:
            # Let the caller raise its descriptive ValueError instead of
            # failing on len(None) below.
            return None

        length_matches = len(original_ratios) == num_inference_steps
        if not length_matches and hasattr(strategy, "nearest_interp"):
            mag_ratios = strategy.nearest_interp(original_ratios, num_inference_steps)
            logger.info(
                "MagCache: Interpolated mag_ratios from %s to %s steps",
                len(original_ratios),
                num_inference_steps,
            )
        else:
            mag_ratios = original_ratios
            if not length_matches:
                logger.warning(
                    "MagCache: mag_ratios length (%s) != num_inference_steps (%s), "
                    "this may cause unexpected behavior",
                    len(original_ratios),
                    num_inference_steps,
                )

        logger.info("MagCache: Using mag_ratios from %s", type(strategy).__name__)
        return mag_ratios

    @staticmethod
    def _find_hooked_blocks(transformer: Any) -> list[tuple[Any, Any]]:
        """Collect ``(block, registry)`` pairs for blocks that carry hooks.

        Only direct children of ``transformer`` that are ``torch.nn.ModuleList``
        instances are scanned; a block counts as hooked when it has a
        ``_hook_registry`` attribute whose ``_hooks`` mapping is non-empty.
        """
        hooked = []
        for _, child in transformer.named_children():
            if not isinstance(child, torch.nn.ModuleList):
                continue
            for block in child:
                registry = getattr(block, "_hook_registry", None)
                if registry is not None and len(registry._hooks) > 0:
                    hooked.append((block, registry))
        return hooked

    def enable(self, pipeline: Any) -> None:
        """Enable MagCache on transformer using hooks.

        This creates a MagCacheConfig from the backend's DiffusionCacheConfig
        and applies the MagCache hook to the transformer.

        Args:
            pipeline: Diffusion pipeline instance. Extracts transformer and transformer_type:
                - transformer: pipeline.transformer
                - transformer_type: pipeline.transformer.__class__.__name__

        Raises:
            ValueError: If no ``mag_ratios`` are configured, calibration is
                off, and the strategy for the transformer type yields none.
        """
        transformer = pipeline.transformer
        transformer_type = transformer.__class__.__name__

        # Fall back to 28 steps when the config leaves the step count unset.
        num_inference_steps = self.config.num_inference_steps or 28

        mag_ratios = self.config.mag_ratios
        strategy = None

        # Ratios are only looked up from a strategy when the user supplied
        # none and we are not calibrating (calibration needs no ratios).
        if mag_ratios is None and not self.config.mag_calibrate:
            strategy = get_strategy(transformer_type)
            mag_ratios = self._ratios_from_strategy(strategy, num_inference_steps)
            if mag_ratios is None:
                raise ValueError(
                    f"mag_ratios must be provided for MagCache. "
                    f"For {transformer_type}, you need to provide mag_ratios or run in calibrate mode."
                )

        self._magcache_config = MagCacheConfig(
            transformer_type=transformer_type,
            threshold=self.config.mag_threshold,
            max_skip_steps=self.config.mag_max_skip_steps,
            retention_ratio=self.config.mag_retention_ratio,
            num_inference_steps=num_inference_steps,
            mag_calibrate=self.config.mag_calibrate,
            # Configured ratios are ignored while calibrating.
            mag_ratios=mag_ratios if not self.config.mag_calibrate else None,
        )
        self._transformer_id = id(transformer)
        self._strategy = strategy

        apply_mag_cache_hook(transformer, self._magcache_config, strategy=strategy)

        self._registered = True
        self.enabled = True

    def refresh(self, pipeline: Any, num_inference_steps: int) -> None:
        """Refresh MagCache state for new generation.

        If the backend is not registered yet, or the pipeline's transformer
        object changed since registration, this delegates to ``enable``.
        Otherwise it calls ``reset_state(block)`` on every hook that defines
        it, or re-applies the hooks when none are found on the blocks.

        Args:
            pipeline: Diffusion pipeline instance. Extracts transformer via pipeline.transformer.
            num_inference_steps: Number of inference steps for the current generation.
                NOTE(review): this value is not used here; ``enable`` takes
                the step count from ``self.config.num_inference_steps``.
                Confirm whether a differing value should rebuild the config.
        """
        transformer = pipeline.transformer
        current_transformer_id = id(transformer)

        if not self._registered:
            self.enable(pipeline)
            return

        if current_transformer_id != self._transformer_id:
            logger.warning(
                "Transformer was replaced (id changed from %s to %s), re-registering hooks",
                self._transformer_id,
                current_transformer_id,
            )
            self.enable(pipeline)
            return

        hooked_blocks = self._find_hooked_blocks(transformer)
        if not hooked_blocks:
            logger.warning("No hooks found on transformer blocks, re-registering")
            # Reuse the strategy chosen by enable() so both paths install
            # the hooks with the same arguments.
            apply_mag_cache_hook(transformer, self._magcache_config, strategy=self._strategy)
            self._transformer_id = current_transformer_id
            return

        for block, registry in hooked_blocks:
            for hook in registry._hooks.values():
                if hasattr(hook, "reset_state"):
                    hook.reset_state(block)

    def is_enabled(self) -> bool:
        """Check if MagCache is enabled.

        Returns:
            True if enabled, False otherwise.
        """
        return self.enabled
Loading
Loading