[Bugfix] Fix Incompatible Multihook Integration (TeaCache <-> CPU Offload)#2689

Merged: wtomin merged 4 commits into vllm-project:main from alex-jw-brooks:fix_multihook on Apr 12, 2026

alex-jw-brooks (Contributor) commented Apr 10, 2026

Purpose

Fixes the root cause of #1868

The hook registry currently behaves as follows:

  • With a single hook: call preprocess -> call new_forward on that hook -> call postprocess
  • With multiple hooks: sort them alphabetically, then call preprocess on all of them -> call the original forward -> call postprocess on all hooks in reverse order

This means that any hook that overrides new_forward (e.g., TeaCache) never has its override called when it runs alongside other hooks such as CPU offload.
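The failure mode can be reproduced with a toy model of the old dispatch. This is a minimal sketch, not the vllm-omni code: `Hook`, `CachingHook`, `PassiveHook`, `Module`, and `old_dispatch` are hypothetical names chosen for illustration, and the single-hook branch is simplified to a direct `new_forward` call.

```python
# Hypothetical, simplified model of the pre-fix dispatch; not the vllm-omni code.
class Hook:
    def pre_forward(self, module, *args, **kwargs):
        return args, kwargs

    def post_forward(self, module, output):
        return output

    def new_forward(self, module, *args, **kwargs):
        # Default: wrap the original forward with this hook's pre/post steps.
        args, kwargs = self.pre_forward(module, *args, **kwargs)
        return self.post_forward(module, module.forward(*args, **kwargs))


class CachingHook(Hook):
    """Stands in for a TeaCache-like hook that replaces the forward call."""

    def new_forward(self, module, *args, **kwargs):
        return "cached"


class PassiveHook(Hook):
    """Stands in for a CPU-offload-like hook that only uses pre/post."""


class Module:
    def forward(self, x):
        return f"computed({x})"


def old_dispatch(hooks, module, *args, **kwargs):
    if len(hooks) == 1:
        # Single hook: its new_forward is honored.
        return next(iter(hooks.values())).new_forward(module, *args, **kwargs)
    ordered = [hooks[name] for name in sorted(hooks)]
    for hook in ordered:
        args, kwargs = hook.pre_forward(module, *args, **kwargs)
    # Multiple hooks: the original forward runs and no new_forward is consulted.
    output = module.forward(*args, **kwargs)
    for hook in reversed(ordered):
        output = hook.post_forward(module, output)
    return output


alone = old_dispatch({"cache": CachingHook()}, Module(), 1)
combined = old_dispatch({"cache": CachingHook(), "offload": PassiveHook()}, Module(), 1)
print(alone)     # cached
print(combined)  # computed(1)
```

With only the caching hook registered, the override runs. Adding any second hook silently routes the call back to the original forward.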

This PR fixes that by ensuring at most one hook that overrides new_forward is active at a time, and that it is always the last in the chain. The flow then becomes: call preprocess on all hooks that don't override new_forward, call new_forward on the overriding hook (which currently encapsulates that hook's own pre/post processing as well), and then call postprocess on the other hooks in reverse order. Trying to register multiple hooks that each override new_forward raises a RuntimeError, since we currently don't have a way of combining hooks with custom forwards.
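A rough sketch of that flow is shown here under stated assumptions. `Registry`, `overrides_forward`, `Recorder`, and `Shortcut` are hypothetical names, and detecting an override by comparing against the base class method is one possible approach rather than necessarily what the PR does.

```python
class Hook:
    """Hypothetical base hook; the default new_forward wraps the original forward."""

    def pre_forward(self, module, *args, **kwargs):
        return args, kwargs

    def post_forward(self, module, output):
        return output

    def new_forward(self, module, *args, **kwargs):
        args, kwargs = self.pre_forward(module, *args, **kwargs)
        return self.post_forward(module, module.forward(*args, **kwargs))


def overrides_forward(hook):
    # True when the hook's class supplies its own new_forward.
    return type(hook).new_forward is not Hook.new_forward


class Registry:
    def __init__(self):
        self.hooks = {}

    def register(self, name, hook):
        others = (h for n, h in self.hooks.items() if n != name)
        if overrides_forward(hook) and any(overrides_forward(h) for h in others):
            raise RuntimeError("only one hook may override new_forward")
        self.hooks[name] = hook

    def dispatch(self, module, *args, **kwargs):
        ordered = [self.hooks[n] for n in sorted(self.hooks)]
        plain = [h for h in ordered if not overrides_forward(h)]
        custom = [h for h in ordered if overrides_forward(h)]
        for hook in plain:
            args, kwargs = hook.pre_forward(module, *args, **kwargs)
        # The single overriding hook (if any) runs last, in place of module.forward.
        if custom:
            output = custom[0].new_forward(module, *args, **kwargs)
        else:
            output = module.forward(*args, **kwargs)
        for hook in reversed(plain):
            output = hook.post_forward(module, output)
        return output


class Module:
    def forward(self, x):
        return x + 1


class Recorder(Hook):
    def __init__(self, tag, log):
        self.tag, self.log = tag, log

    def pre_forward(self, module, *args, **kwargs):
        self.log.append(f"pre:{self.tag}")
        return args, kwargs

    def post_forward(self, module, output):
        self.log.append(f"post:{self.tag}")
        return output


class Shortcut(Hook):
    def __init__(self, log):
        self.log = log

    def new_forward(self, module, *args, **kwargs):
        self.log.append("forward:cache")
        return "cached"


log = []
registry = Registry()
registry.register("a_cache", Shortcut(log))  # sorts first by name, still runs last
registry.register("b_offload", Recorder("offload", log))
result = registry.dispatch(Module(), 1)

try:
    registry.register("c_second", Shortcut(log))
    rejected = False
except RuntimeError:
    rejected = True

print(log)       # ['pre:offload', 'forward:cache', 'post:offload']
print(result)    # cached
print(rejected)  # True
```

Even though `a_cache` sorts ahead of `b_offload`, the overriding hook is invoked after the other hook's preprocess step, and a second override is rejected at registration time.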

Test Plan

The hook registry is already tested indirectly through the SP tests, but does not have direct unit tests, so I added a suite for it as well to verify the sorting / execution behavior on simple hooks.

Test Result

Tests pass, and TeaCache and CPU offload are now compatible.

import os
import gc
import time
import torch
from PIL import Image
from vllm_omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams

# Configuration
MODEL_ID = "black-forest-labs/FLUX.2-klein-4B"

PROMPT = "A cup of coffee sitting on a table."
STEPS = 50
SEEDS = [444, 111, 3919]

TEACACHE_DIR = "cache_results"
NO_CACHE_DIR = "no_cache_results"
os.makedirs(TEACACHE_DIR, exist_ok=True)
os.makedirs(NO_CACHE_DIR, exist_ok=True)


def run_benchmark(use_cache=False):
    print(f"\n{'Testing with tea_cache' if use_cache else 'Testing without tea_cache'}...")
    times = []
    out_dir = TEACACHE_DIR if use_cache else NO_CACHE_DIR
    cache_config = {
        "rel_l1_thresh": 0.2,
    } if use_cache else {}
    cache_backend = "tea_cache" if use_cache else None

    omni = Omni(
        model=MODEL_ID,
        cache_backend=cache_backend,
        cache_config=cache_config,
        dtype="bfloat16",
        enable_cpu_offload=True,
    )

    for seed in SEEDS:
        sampling_params = OmniDiffusionSamplingParams(num_inference_steps=STEPS, seed=seed)
        start = time.time()
        outputs = omni.generate(PROMPT, sampling_params)
        end = time.time()
        run_time = end - start
        times.append(run_time)
        # Save the generated image
        image = outputs[0].images[0]
        filename = f"{out_dir}/seed_{seed}.png"
        print(f"Run time: {run_time} for seed: {seed} [use_cache={use_cache}]")
        image.save(filename)

    avg_time = sum(times) / len(times)
    print(f"Average latency [use_cache={use_cache}]: {avg_time}")
    return avg_time


def create_comparison_grid(seeds, cache_dir, no_cache_dir, output_path="comparison_grid.png"):
    """Create a 2x3 grid comparing teacache vs no cache results.

    Top row: with teacache
    Bottom row: without teacache
    """
    print("\nCreating comparison grid...")

    # Load all images
    cache_images = [Image.open(f"{cache_dir}/seed_{seed}.png") for seed in seeds]
    no_cache_images = [Image.open(f"{no_cache_dir}/seed_{seed}.png") for seed in seeds]

    # Get dimensions from first image
    img_width, img_height = cache_images[0].size

    # Calculate grid dimensions (no padding, just images)
    grid_width = img_width * len(seeds)
    grid_height = img_height * 2

    # Create composite image
    composite = Image.new('RGB', (grid_width, grid_height))

    # Paste top row (with cache)
    for i, img in enumerate(cache_images):
        x_offset = i * img_width
        composite.paste(img, (x_offset, 0))

    # Paste bottom row (no cache)
    for i, img in enumerate(no_cache_images):
        x_offset = i * img_width
        composite.paste(img, (x_offset, img_height))

    # Save composite
    composite.save(output_path)
    print(f"Comparison grid saved to: {output_path}")


if __name__ == "__main__":
    # Run tests
    time_with_cache = run_benchmark(use_cache=True)
    # Collect garbage first so freed tensors can be released from the CUDA cache
    gc.collect()
    torch.cuda.empty_cache()
    time_no_cache = run_benchmark(use_cache=False)

    print("\nResults:")
    print(f"Speedup: {time_no_cache / time_with_cache:.2f}x")

    # Create comparison grid
    create_comparison_grid(SEEDS, TEACACHE_DIR, NO_CACHE_DIR)

On main, you should see a speedup of ~1x, since TeaCache doesn't actually run when CPU offload is enabled. On this branch, you should see a speedup of ~1.3x and the following grid (top row: cache backend set; bottom row: no cache backend set).

[Image: comparison_grid]

Signed-off-by: Alex Brooks <albrooks@redhat.com>

hsliuustc0106 (Collaborator) left a comment

Clean fix. The old dispatch had a latent bug where multi-hook scenarios would skip new_forward overrides entirely. New dispatch correctly separates the two hook types. Tests are thorough.

One observation: hooks with new_forward are expected to call their own pre_forward/post_forward internally (see OverrideAppendHook). This is a subtle contract — if someone writes a new_forward hook that forgets to call its own pre/post, the hook silently does nothing. Worth documenting this expectation on the ModelHook.new_forward docstring.
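To make the pitfall concrete, here is a small sketch of a hook that overrides new_forward but forgets its own pre/post steps. All class names (`Hook`, `ScalingHook`, `ForgetfulHook`, `CarefulHook`, `Module`) are hypothetical, and the base class only approximates the contract described in this comment.

```python
class Hook:
    """Hypothetical base: the default new_forward calls pre/post itself."""

    def pre_forward(self, module, *args, **kwargs):
        return args, kwargs

    def post_forward(self, module, output):
        return output

    def new_forward(self, module, *args, **kwargs):
        args, kwargs = self.pre_forward(module, *args, **kwargs)
        return self.post_forward(module, module.forward(*args, **kwargs))


class Module:
    def forward(self, x):
        return x + 1


class ScalingHook(Hook):
    # Doubles every positional input before the forward call.
    def pre_forward(self, module, *args, **kwargs):
        return tuple(a * 2 for a in args), kwargs


class ForgetfulHook(ScalingHook):
    def new_forward(self, module, *args, **kwargs):
        # Bug: goes straight to the module, so pre_forward never runs.
        return module.forward(*args, **kwargs)


class CarefulHook(ScalingHook):
    def new_forward(self, module, *args, **kwargs):
        # Honors the contract by calling its own pre/post around the forward.
        args, kwargs = self.pre_forward(module, *args, **kwargs)
        return self.post_forward(module, module.forward(*args, **kwargs))


forgetful_result = ForgetfulHook().new_forward(Module(), 3)
careful_result = CarefulHook().new_forward(Module(), 3)
print(forgetful_result)  # 4
print(careful_result)    # 7
```

Nothing raises in the forgetful case. The scaling step is simply skipped, which is why the expectation is easy to miss.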

Signed-off-by: Alex Brooks <albrooks@redhat.com>

alex-jw-brooks (Contributor, Author) commented Apr 11, 2026

Thanks for the quick review @hsliuustc0106 🙂

> One observation: hooks with new_forward are expected to call their own pre_forward/post_forward internally (see OverrideAppendHook). This is a subtle contract: if someone writes a new_forward hook that forgets to call its own pre/post, the hook silently does nothing. Worth documenting this expectation on the ModelHook.new_forward docstring.

I agree, I think this is very unintuitive. IMO it's better to let the original model forward be the default call for new_forward instead of calling pre/post inside it, especially since we no longer call the default implementation as a special case when there is only one hook.

Currently, no hook that overrides new_forward also implements a custom pre or post forward. So I went ahead and moved the pre/post forward calls out of the default new_forward, and changed the flow to explicitly call pre/post forward on the hook with the new_forward override as well, since that is less error prone.

hsliuustc0106 (Collaborator) commented

Good fix. The new register_fails_with_multiple_forward_hooks test and the dispatch tests are solid.

One subtlety: the contract for new_forward changed — it no longer calls pre_forward/post_forward by default. Any future hook that overrides new_forward must explicitly call them. Worth a brief docstring note or a base-class new_forward that raises NotImplementedError to make the contract explicit, since the old default implementation silently handled this.

Signed-off-by: Alex Brooks <albrooks@redhat.com>

alex-jw-brooks (Contributor, Author) commented Apr 11, 2026

Sounds good to me. I changed the default new_forward, which is now unused, to raise NotImplementedError so the contract is clearer.
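The revised contract might look roughly like the sketch below. It assumes a base method that raises, override detection by comparing against the base class method, and a dispatcher that calls pre/post on every hook. `Hook`, `overrides_forward`, `dispatch`, `DoublingHook`, and `SkipHook` are hypothetical names, and `SkipHook` overriding both new_forward and post_forward is contrived purely to show that the dispatcher now owns the pre/post calls.

```python
class Hook:
    def pre_forward(self, module, *args, **kwargs):
        return args, kwargs

    def post_forward(self, module, output):
        return output

    def new_forward(self, module, *args, **kwargs):
        # No silent default: a hook either overrides this or is never asked to run it.
        raise NotImplementedError("override new_forward to replace the forward call")


def overrides_forward(hook):
    return type(hook).new_forward is not Hook.new_forward


def dispatch(hooks, module, *args, **kwargs):
    """Assumes `hooks` is already ordered, with the overriding hook (if any) last."""
    for hook in hooks:
        args, kwargs = hook.pre_forward(module, *args, **kwargs)
    forward_hook = next((h for h in hooks if overrides_forward(h)), None)
    if forward_hook is not None:
        output = forward_hook.new_forward(module, *args, **kwargs)
    else:
        output = module.forward(*args, **kwargs)
    # Post steps run in reverse order, including the overriding hook's own.
    for hook in reversed(hooks):
        output = hook.post_forward(module, output)
    return output


class Module:
    def forward(self, x):
        return x + 1


class DoublingHook(Hook):
    def pre_forward(self, module, *args, **kwargs):
        return tuple(a * 2 for a in args), kwargs


class SkipHook(Hook):
    def new_forward(self, module, *args, **kwargs):
        return args[0]  # skips the module entirely and passes the input through

    def post_forward(self, module, output):
        return output + 100


with_skip = dispatch([DoublingHook(), SkipHook()], Module(), 3)
without_skip = dispatch([DoublingHook()], Module(), 3)

try:
    Hook().new_forward(Module(), 3)
    default_raises = False
except NotImplementedError:
    default_raises = True

print(with_skip)       # 106
print(without_skip)    # 7
print(default_raises)  # True
```

Because the dispatcher calls pre/post on every hook, an override only has to decide what replaces the forward call.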

hsliuustc0106 added the ready label (label to trigger buildkite CI) on Apr 11, 2026

wtomin (Collaborator) left a comment

LGTM. This fix is nice and clean

wtomin merged commit ef230ac into vllm-project:main on Apr 12, 2026 (7 of 8 checks passed)
daixinning pushed a commit to daixinning/vllm-omni that referenced this pull request Apr 13, 2026
…load) (vllm-project#2689)

Signed-off-by: Alex Brooks <albrooks@redhat.com>
Co-authored-by: SYLAR <125541396+lishunyang12@users.noreply.github.com>
lengrongfu pushed a commit to lengrongfu/vllm-omni that referenced this pull request May 1, 2026
…load) (vllm-project#2689)

Signed-off-by: Alex Brooks <albrooks@redhat.com>
Co-authored-by: SYLAR <125541396+lishunyang12@users.noreply.github.com>
clodaghwalsh17 pushed a commit to clodaghwalsh17/nm-vllm-omni-ent that referenced this pull request May 12, 2026
…load) (vllm-project#2689)

Signed-off-by: Alex Brooks <albrooks@redhat.com>
Co-authored-by: SYLAR <125541396+lishunyang12@users.noreply.github.com>
Labels

ready label to trigger buildkite CI

4 participants