
[Feat] Support MagCache#1287

Open
RuixiangMa wants to merge 21 commits into vllm-project:main from RuixiangMa:supportmagcache

@RuixiangMa
Collaborator

@RuixiangMa RuixiangMa commented Feb 9, 2026

MagCache Integration

1. Overview

MagCache (Magnitude-based Cache) accelerates diffusion model inference by reusing transformer block computations. It decides whether to skip computation based on the residual magnitude ratio between consecutive timesteps.

Reference: https://github.com/Zehong-Ma/MagCache and https://github.com/huggingface/diffusers

2. Architecture

graph TB
    subgraph "ConfigSection"
        ConfigData[MagCacheConfig]
    end

    subgraph "StateLayer"
        StateData[MagCacheState]
    end

    subgraph "StrategyLayer"
        StrategyDef[MagCacheStrategy]
    end

    subgraph "Hooks"
        Head[MagCacheHeadHook]
        Block[MagCacheBlockHook]
    end

    ConfigData --> Head
    ConfigData --> Block
    StateData --> Head
    StateData --> Block
    StrategyDef --> Head
    StrategyDef --> Block

Component Responsibilities:

  • MagCacheHeadHook (entry): decides whether to skip computation
  • MagCacheBlockHook (exit): computes/stores residuals
graph LR
    Input[hidden_states] --> Head
    Head -->|compute| Blocks
    Head -->|skip| Tail
    Blocks --> Tail
    Tail --> Output[output]

3. Usage

3.1 Quick Start

from vllm_omni.diffusion.cache.magcache import MagCacheConfig
from vllm_omni.diffusion.cache.magcache.strategy import FluxMagCacheStrategy
from vllm_omni.diffusion.cache.base import DiffusionCacheConfig

cache_config = DiffusionCacheConfig(
    mag_ratios=FluxMagCacheStrategy.FLUX_MAG_RATIOS,
    num_inference_steps=28,
    mag_threshold=0.24,
    mag_max_skip_steps=5,
    mag_retention_ratio=0.1,
)

3.2 Calibration Mode

Run calibration when using a new model or changing the scheduler to obtain suitable mag_ratios:

calibrate_config = DiffusionCacheConfig(
    calibrate=True,
    num_inference_steps=28,
)
# After inference, output will show:
# [INFO] norm_ratios: [1.0, 1.07313, 1.21035, ...]
# [INFO] norm_stds: [0.0, 0.0348, 0.0156, ...]
# [INFO] cos_dises: [0.0, 0.304, 0.148, ...]
# [INFO] Copy these values to DiffusionCacheConfig(mag_ratios=...) for production use

Calibration output:

  • norm_ratios: use directly as mag_ratios
  • norm_stds: residual fluctuation per step (reference)
  • cos_dises: residual direction change per step (reference)
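The three calibration statistics can be read as per-token comparisons between the residuals of two consecutive steps. The sketch below is an assumption about how such statistics are typically computed, not this PR's code: `calibration_stats` and its list-of-vectors input layout are hypothetical, and it uses the population standard deviation for simplicity.

```python
import math
from statistics import mean, pstdev


def _norm(vec):
    """Euclidean norm of one token vector."""
    return math.sqrt(sum(x * x for x in vec))


def calibration_stats(prev_residual, cur_residual):
    """Compare the residuals of two consecutive steps, token by token.

    Both arguments are lists of per-token vectors (hypothetical layout).
    Returns (norm_ratio, norm_std, cos_dis). Zero-norm tokens are not
    handled and would raise ZeroDivisionError.
    """
    ratios, cos_dists = [], []
    for prev, cur in zip(prev_residual, cur_residual):
        n_prev, n_cur = _norm(prev), _norm(cur)
        ratios.append(n_cur / n_prev)  # magnitude change
        dot = sum(a * b for a, b in zip(prev, cur))
        cos_dists.append(1.0 - dot / (n_prev * n_cur))  # direction change
    return mean(ratios), pstdev(ratios), mean(cos_dists)


# Toy residuals: every token doubles in magnitude, direction unchanged.
prev = [[1.0, 0.0], [0.0, 2.0]]
cur = [[2.0, 0.0], [0.0, 4.0]]
norm_ratio, norm_std, cos_dis = calibration_stats(prev, cur)
```

A ratio close to 1 with a small standard deviation and a small cosine distance suggests the residual barely changes between steps, which is the situation where reusing a cached residual is cheapest in quality terms.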

4. Adapting New Models

4.1 Overview

To add MagCache support for a new model, you need to implement a MagCacheStrategy subclass and register it. The strategy handles model-specific logic for:

  • mag_ratios: Pre-computed magnitude ratios, one per inference step
  • compute_residual: How to calculate residual (override only if model has special output format)
  • apply_residual_tuple: How to apply residual for tuple outputs (override for dual-stream models)

4.2 Minimal Implementation

For models with the standard diffusion output format (output = hidden_states + residual):

from vllm_omni.diffusion.cache.magcache.strategy import (
    MagCacheStrategy,
    register_strategy,
)
import torch

class MyModelMagCacheStrategy(MagCacheStrategy):
    """MagCache strategy for MyModel."""

    @property
    def mag_ratios(self) -> torch.Tensor:
        """Pre-computed mag_ratios for this model (28 values for 28 inference steps)."""
        return torch.tensor([
            1.0, 1.05, 1.08, 1.04, 1.06, 1.03, 1.02, 1.04,
            1.03, 1.02, 1.03, 1.02, 1.01, 1.02, 1.01, 1.02,
            1.01, 1.02, 1.01, 1.03, 1.00, 1.01, 1.01, 1.02,
            1.01, 1.00, 1.01, 1.01,
        ])

# Register the strategy (transformer_cls_name must match pipeline.transformer.__class__.__name__)
register_strategy("MyTransformer2DModel", MyModelMagCacheStrategy())

Note: transformer_cls_name must exactly match pipeline.transformer.__class__.__name__.
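To illustrate why the name has to match exactly, here is a minimal sketch of a name-keyed registry. It assumes the lookup is a plain dictionary keyed on the class name; `register_toy_strategy`, `lookup_toy_strategy`, and the stand-in classes are hypothetical and are not the PR's API.

```python
_TOY_STRATEGIES = {}


def register_toy_strategy(transformer_cls_name, strategy):
    """Store a strategy under the transformer's class name (toy registry)."""
    _TOY_STRATEGIES[transformer_cls_name] = strategy


def lookup_toy_strategy(transformer):
    """Look up by exact class name; returns None when nothing matches."""
    return _TOY_STRATEGIES.get(type(transformer).__name__)


class MyTransformer2DModel:  # stand-in for a real transformer class
    pass


class OtherTransformer:  # a class nobody registered a strategy for
    pass


register_toy_strategy("MyTransformer2DModel", "my-strategy")
found = lookup_toy_strategy(MyTransformer2DModel())
missing = lookup_toy_strategy(OtherTransformer())
```

With a lookup like this, a typo or a different capitalization in the registered name simply produces no match, so the strategy is silently not applied.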

4.3 Models with Special Output Format

If your model returns a tuple (e.g., dual-stream architectures like Flux), override compute_residual and apply_residual_tuple:

class MyModelMagCacheStrategy(MagCacheStrategy):
    """MagCache strategy for dual-stream MyModel."""

    def register_block_metadata(self, block_class: type) -> TransformerBlockMetadata | None:
        """Register model-specific transformer block metadata."""
        return TransformerBlockMetadata(...)

    @property
    def mag_ratios(self) -> torch.Tensor:
        return torch.tensor([...])

    def compute_residual(
        self,
        output: torch.Tensor,
        head_input: torch.Tensor,
    ) -> torch.Tensor:
        """Handle tuple output: (encoder_hidden_states, hidden_states)."""
        if isinstance(output, tuple):
            decoder_output = output[1] if len(output) > 1 else output[0]
            return decoder_output - head_input
        return output - head_input

    def apply_residual_tuple(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        residual: tuple[torch.Tensor, torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply residual only to decoder branch."""
        decoder_residual = residual[1] if isinstance(residual, tuple) else residual
        return hidden_states + decoder_residual, encoder_hidden_states

register_strategy("MyDualStreamTransformer", MyModelMagCacheStrategy())

4.4 Block Registration and Metadata

MagCache uses TransformerBlockRegistry to get metadata about transformer blocks. Each strategy can provide custom metadata for different block types.

Auto-Registration with Strategy

# In MagCacheHeadHook.initialize_hook / MagCacheBlockHook.initialize_hook
try:
    self._metadata = TransformerBlockRegistry.get(block_class)
except ValueError:
    if self._strategy is not None:
        metadata = self._strategy.register_block_metadata(block_class)
        if metadata is not None:
            TransformerBlockRegistry.register(model_class=block_class, metadata=metadata)
        else:
            # Default registration
            TransformerBlockRegistry.register(
                model_class=block_class,
                metadata=TransformerBlockMetadata(
                    return_hidden_states_index=1,
                    return_encoder_hidden_states_index=0,
                ),
            )
    else:
        TransformerBlockRegistry.register(
            model_class=block_class,
            metadata=TransformerBlockMetadata(
                return_hidden_states_index=1,
                return_encoder_hidden_states_index=0,
            ),
        )
    self._metadata = TransformerBlockRegistry.get(block_class)

Base Strategy Method

Override register_block_metadata in your strategy to handle model-specific block types:

def register_block_metadata(self, block_class: type) -> TransformerBlockMetadata | None:
    """Register model-specific block metadata.

    Return TransformerBlockMetadata for custom registration,
    or None to use default indices.

    Args:
        block_class: The transformer block class to register

    Returns:
        TransformerBlockMetadata if custom registration is needed, None otherwise
    """
    return None

Default Indices

| Block Type | return_hidden_states_index | return_encoder_hidden_states_index |
| --- | --- | --- |
| Single output tensor | 0 | None |
| Tuple output (enc, dec) | 1 | 0 |

To determine indices: check the return format of block.forward().
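As a rough illustration of how these indices are used, the sketch below picks the hidden states and encoder hidden states out of a block's return value. `ToyBlockMetadata` is a hypothetical stand-in for TransformerBlockMetadata that only mirrors the two index fields named above, and `split_block_output` is an illustrative helper, not part of the PR.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyBlockMetadata:
    """Hypothetical stand-in mirroring the two index fields."""
    return_hidden_states_index: Optional[int] = 0
    return_encoder_hidden_states_index: Optional[int] = None


def split_block_output(output, meta):
    """Return (hidden_states, encoder_hidden_states) from a block output."""
    if not isinstance(output, tuple):
        # Single-tensor blocks have no encoder stream.
        return output, None
    hidden = output[meta.return_hidden_states_index]
    enc_idx = meta.return_encoder_hidden_states_index
    encoder = output[enc_idx] if enc_idx is not None else None
    return hidden, encoder


# Strings stand in for tensors to keep the example dependency-free.
single = split_block_output("dec", ToyBlockMetadata())
dual = split_block_output(("enc", "dec"), ToyBlockMetadata(1, 0))
```

If a block returned (dec, enc) instead, swapping the two indices in the metadata would be enough; the rest of the logic stays the same.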

5. Core Parameters

This default configuration is consistent with the official MagCache implementation.

| Parameter | Default | Description |
| --- | --- | --- |
| mag_threshold | 0.24 | Accumulated error threshold; larger = faster |
| mag_max_skip_steps | 5 | Maximum number of consecutive skipped steps |
| mag_retention_ratio | 0.1 | The first 10% of steps are never skipped |
| num_inference_steps | 28 | Total inference steps |
| mag_ratios | - | Magnitude ratio per inference step |
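The interaction between these parameters can be sketched as follows. This is a simplified model that assumes the accumulation rule used by the reference MagCache implementation (multiply the ratios of skipped steps, accumulate the deviation from 1, reset after a real compute) and one mag_ratios value per inference step; `SkipState` and `should_skip` are hypothetical names, and the PR's actual hook code may differ in detail.

```python
from dataclasses import dataclass


@dataclass
class SkipState:
    accumulated_ratio: float = 1.0
    accumulated_err: float = 0.0
    accumulated_steps: int = 0


def should_skip(step, mag_ratios, state, *, num_inference_steps=28,
                mag_threshold=0.24, mag_max_skip_steps=5,
                mag_retention_ratio=0.1):
    """Return True if this step may reuse the cached residual."""
    # The first retained steps are always computed.
    if step < int(num_inference_steps * mag_retention_ratio):
        return False
    state.accumulated_ratio *= mag_ratios[step]
    state.accumulated_steps += 1
    state.accumulated_err += abs(1.0 - state.accumulated_ratio)
    if (state.accumulated_err <= mag_threshold
            and state.accumulated_steps <= mag_max_skip_steps):
        return True
    # Too much drift or too many consecutive skips: compute and reset.
    state.accumulated_ratio = 1.0
    state.accumulated_err = 0.0
    state.accumulated_steps = 0
    return False


# Toy schedule: ratios of exactly 1.0 never accumulate error, so only
# the retention window and mag_max_skip_steps force real computation.
flat_ratios = [1.0] * 28
state = SkipState()
decisions = [should_skip(i, flat_ratios, state) for i in range(28)]
```

In this toy run the skip budget is the only limit; with calibrated ratios that drift away from 1, the accumulated error usually crosses mag_threshold first, which is why a larger threshold means more skipped steps.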

6. Workflow

sequenceDiagram
    participant Head as Head Hook
    participant Blocks as Transformer
    participant Block as Tail Hook

    Note over Head: each step start
    Head->>Head: check accumulated_err
    alt accumulated_err <= threshold
        Head->>Head: skip computation, apply cached residual
        Head-->>Block: return output
    else need compute
        Head->>Blocks: execute transformer
        Blocks-->>Block: return output
        Block->>Block: compute residual = output - input
        Block->>Block: store residual
    end

7. Hook Responsibilities

MagCacheHeadHook vs MagCacheBlockHook

| Aspect | MagCacheHeadHook | MagCacheBlockHook |
| --- | --- | --- |
| Position | first block (entry) | last block (exit) |
| Responsibility | decides whether to skip computation | computes/stores residuals |
| Core Logic | decides should_compute based on accumulated error | skip = return input, compute = compute and store residual |

graph LR
    subgraph "Inside Transformer"
        Head[MagCacheHeadHook<br/>entry]
        Blocks[All Transformer Blocks]
        Tail[MagCacheBlockHook<br/>exit]
    end

    Input[hidden_states] --> Head
    Head -->|need compute| Blocks
    Head -->|skip| Tail
    Blocks --> Tail
    Tail --> Output[output]

Head Hook responsibilities:

  1. Get hidden_states, save as head_block_input
  2. Check accumulated error accumulated_err
  3. If accumulated_err <= threshold: skip computation, return directly
  4. Otherwise: continue executing transformer

Block Hook responsibilities:

  1. If should_compute=False: return hidden_states directly
  2. If should_compute=True: execute block, compute residual = output - head_block_input, save residual
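The division of labour between the two hooks can be sketched with plain floats standing in for tensors. Everything here (`ToyState`, `head_hook`, `tail_hook`, `run_step`, and the `skip` flag passed in from outside) is a hypothetical simplification of the flow described above, not the PR's hook classes.

```python
class ToyState:
    def __init__(self):
        self.head_block_input = None
        self.residual = None
        self.should_compute = True


def head_hook(hidden_states, state, skip):
    """Entry: record the input and decide whether the blocks run."""
    state.head_block_input = hidden_states
    # Skipping is only possible once a residual has been cached.
    state.should_compute = not (skip and state.residual is not None)
    if state.should_compute:
        return hidden_states
    return hidden_states + state.residual  # apply cached residual


def tail_hook(hidden_states, block_fn, state):
    """Exit: pass through on skip, otherwise run the block and cache."""
    if not state.should_compute:
        return hidden_states
    output = block_fn(hidden_states)
    state.residual = output - state.head_block_input
    return output


def run_step(x, blocks, state, skip):
    h = head_hook(x, state, skip)
    if state.should_compute:
        for block in blocks[:-1]:
            h = block(h)
    return tail_hook(h, blocks[-1], state)


blocks = [lambda h: h * 2.0, lambda h: h + 1.0]
state = ToyState()
computed = run_step(1.0, blocks, state, skip=False)  # runs both blocks
skipped = run_step(1.5, blocks, state, skip=True)    # reuses the residual
```

The skipped step returns the input plus the residual cached from the previous step rather than the exact block output, which is the source of the small quality loss that caching trades for speed.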

8. File Structure

magcache/
├── __init__.py       # public API
├── backend.py        # MagCacheBackend
├── config.py         # MagCacheConfig
├── hook.py           # HeadHook, BlockHook
├── state.py          # MagCacheState
└── strategy.py       # Strategy base class + Flux implementation

9. Supported Models

MagCache currently supports the following models:

| Model | Hugging Face ID | Status | Notes |
| --- | --- | --- | --- |
| FLUX.1-dev | black-forest-labs/FLUX.1-dev | | Default implementation with calibrated mag_ratios |
| FLUX.2-klein | black-forest-labs/FLUX.2-klein-4B | | Shape mismatch handling for single-stream |

10. Test

2 × NVIDIA 4090 (24 GB)

vllm serve black-forest-labs/FLUX.1-dev \
  --omni \
  --port 8004 \
  --tensor-parallel-size 2 \
  --enable_cpu_offload \
  --cache-backend mag_cache \
  --cache-config '{"mag_threshold": 0.24, "mag_max_skip_steps": 5, "mag_retention_ratio": 0.1}'

curl -X POST http://localhost:8004/v1/images/generations   -H "Content-Type: application/json"   -d '{
    "prompt": "a dragon laying over the spine of the Green Mountains of Vermont",
    "size": "1024x1024",
    "num_inference_steps": 50,
    "cfg_scale": 4.0,
    "guidance_scale": 4.0,
    "seed": 42
  }' | jq -r '.data[0].b64_json' | base64 -d > dragon.png

| Metric | No MagCache | MagCache |
| --- | --- | --- |
| Image | dragon | dragon |
| Time | 27.614 s/img | 10.439 s/img |
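
For reference, the two timings reported for this test can be turned into a speedup factor with a few lines of arithmetic. The variable names are illustrative; the numbers are the ones from the test run, which correspond to a speedup of roughly 2.6x.

```python
baseline_s = 27.614   # seconds per image without MagCache
magcache_s = 10.439   # seconds per image with MagCache

speedup = baseline_s / magcache_s
time_saved_pct = 100.0 * (1.0 - magcache_s / baseline_s)

print(f"speedup: {speedup:.2f}x, time saved: {time_saved_pct:.1f}%")
```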

Signed-off-by: Lancer <maruixiang6688@gmail.com>

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 55e3509660


Comment thread vllm_omni/diffusion/hooks/base.py Outdated
Comment thread vllm_omni/diffusion/cache/magcache/backend.py Outdated
Signed-off-by: Lancer <maruixiang6688@gmail.com>
@princepride
Collaborator

Thanks for your contribution!😁 May I ask what's the difference with Cache Dit?

@RuixiangMa
Collaborator Author

RuixiangMa commented Feb 9, 2026

Thanks for your contribution!😁 May I ask what's the difference with Cache Dit?

MagCache decides before computation whether to skip blocks by comparing residual magnitude ratios (||r_t|| / ||r_{t-1}||). It's faster but less accurate. Cache-DIT decides after computation by comparing actual residual differences (max(|r_t - r_cached|)). It's more accurate but requires computing the residual first.

In addition to this, users only need one-time calibration per schedule—run once, copy the magnitude ratios, and reuse them everywhere. Super convenient.

Honestly, the main reason for this PR is just that I found MagCache pretty interesting and the tests show it works pretty well 😄 Plus, seeing that Diffusers already added support made it even more fun to contribute.

@princepride
Collaborator

Thanks for the clarification!

Signed-off-by: Lancer <maruixiang6688@gmail.com>
Comment thread vllm_omni/diffusion/cache/magcache/backend.py Outdated
@princepride
Collaborator

@RuixiangMa Can we create some custom functions like: https://github.com/vllm-project/vllm-omni/blob/26ba1e4de8767aaaefb531d24ee8f195f57814e6/vllm_omni/diffusion/cache/cache_dit_backend.py#L856-L867, because many model structures are not standardized

Comment thread vllm_omni/diffusion/cache/__init__.py
Comment thread vllm_omni/diffusion/cache/selector.py Outdated
Comment thread vllm_omni/diffusion/cache/magcache/strategy.py Outdated
Signed-off-by: Lancer <maruixiang6688@gmail.com>
Collaborator

@gcanlin gcanlin left a comment


It looks like the generated picture accuracy has a little regression. Is it expected?

@princepride
Collaborator

It looks like the generated picture accuracy has a little regression. Is it expected?

Yes, all cache algorithms will regress accuracy

@RuixiangMa
Collaborator Author

It looks like the generated picture accuracy has a little regression. Is it expected?

Yeah, three key parameters control the quality-speed trade-off:

| Parameter | Quality | Speed |
| --- | --- | --- |
| threshold | higher = lower quality | higher = faster |
| max_skip_steps | lower = higher quality | lower = slower |
| retention_ratio | higher = higher quality | higher = slower |

@princepride
Collaborator

@RuixiangMa Hi, how is the progress? Can you also update the docs with the supported model list? And you can click the Resolve conversation button if you think my suggestion is reasonable.

princepride and others added 2 commits February 11, 2026 13:33
Signed-off-by: Lancer <maruixiang6688@gmail.com>
@RuixiangMa
Collaborator Author

@RuixiangMa Hi, how is the progress? Can you also update the docs with the supported model list? And you can click the Resolve conversation button if you think my suggestion is reasonable.

Thanks for the suggestion! I've done some refactoring to clean up the code, making it clearer and easier to plug in other models.

Signed-off-by: Lancer <maruixiang6688@gmail.com>
@RuixiangMa RuixiangMa changed the title Support MagCache [Feat] Support MagCache Feb 13, 2026
Collaborator

@lishunyang12 lishunyang12 left a comment


Solid feature addition -- the hook architecture is clean and the strategy pattern makes it easy to extend. A few issues around correctness and dead code worth addressing before merge.

Comment thread vllm_omni/diffusion/cache/magcache/hook.py
Comment thread vllm_omni/diffusion/cache/magcache/hook.py
Comment thread vllm_omni/diffusion/cache/magcache/strategy.py
Comment thread vllm_omni/diffusion/cache/magcache/config.py
Comment thread vllm_omni/diffusion/hooks/base.py
Comment thread vllm_omni/diffusion/hooks/base.py
Comment thread vllm_omni/diffusion/cache/magcache/hook.py
Comment thread vllm_omni/diffusion/cache/magcache/state.py
Comment thread vllm_omni/diffusion/data.py Outdated
@hsliuustc0106
Collaborator

@vllm-omni-reviewer

1 similar comment
@princepride
Collaborator

@vllm-omni-reviewer

Collaborator

@lishunyang12 lishunyang12 left a comment


Most of the issues from my last review have been addressed -- thanks for the fixes. Two remaining things:

  1. The single-block path in apply_mag_cache_hook registers two hooks on the same block (MagCacheHeadHook + MagCacheBlockHook). But HookRegistry.dispatch uses the pre_forward/post_forward chain for multi-hook blocks, which bypasses new_forward entirely. The caching logic won't run for single-block transformers.

  2. The dead code branch I flagged earlier is still present (see inline thread).

logger.info(f"MagCache: Applying Head+Tail Hooks to single block '{name}'")
_apply_mag_cache_block_hook(block, state_manager, config, is_tail=True, strategy=strategy)
_apply_mag_cache_head_hook(block, state_manager, config, strategy)
return
Collaborator


When len(remaining_blocks) == 1, this registers both a head hook and a block hook on the same module. HookRegistry.dispatch with 2+ hooks uses the pre_forward/post_forward chain, not new_forward. Since both hooks only implement new_forward, the caching logic is silently skipped and you get a plain forward pass.

Either make single-block use a combined hook, or update dispatch to handle multi-hook new_forward chaining.

Collaborator Author


Thanks, fixed

Comment thread vllm_omni/diffusion/cache/magcache/hook.py Outdated
Signed-off-by: Lancer <maruixiang6688@gmail.com>
Signed-off-by: Lancer <maruixiang6688@gmail.com>
@Gaohan123
Collaborator

@RuixiangMa Please resolve conflicts

@Gaohan123 Gaohan123 added this to the v0.18.0 milestone Mar 17, 2026
Signed-off-by: Lancer <maruixiang6688@gmail.com>
Signed-off-by: Lancer <maruixiang6688@gmail.com>
Signed-off-by: Lancer <maruixiang6688@gmail.com>
@RuixiangMa
Collaborator Author

@RuixiangMa Please resolve conflicts

fixed

@Gaohan123 Gaohan123 modified the milestones: v0.18.0, v0.20.0 Apr 14, 2026
@Gaohan123 Gaohan123 added the ready label to trigger buildkite CI label Apr 30, 2026
Signed-off-by: Lancer <maruixiang6688@gmail.com>
@RuixiangMa
Collaborator Author

CI failure is unrelated to this PR.

@Gaohan123 Gaohan123 modified the milestones: v0.20.0, v0.22.0 May 9, 2026
Labels

ready label to trigger buildkite CI

6 participants