From dd459291b39088bd086a77ef9067aef27f5880a9 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Fri, 10 Apr 2026 18:24:26 +0000 Subject: [PATCH 1/3] Fix multihook handling Signed-off-by: Alex Brooks --- tests/diffusion/hooks/test_hook_registry.py | 167 ++++++++++++++++++++ vllm_omni/diffusion/hooks/base.py | 73 +++++++-- 2 files changed, 224 insertions(+), 16 deletions(-) create mode 100644 tests/diffusion/hooks/test_hook_registry.py diff --git a/tests/diffusion/hooks/test_hook_registry.py b/tests/diffusion/hooks/test_hook_registry.py new file mode 100644 index 00000000000..c2698bca05f --- /dev/null +++ b/tests/diffusion/hooks/test_hook_registry.py @@ -0,0 +1,167 @@ +""" +Tests for hook registry. + +NOTE: The hook registry is also tested indirectly through a lot of +other tests, e.g., tests/diffusion/distributed/test_sp_plan_hooks.py +""" + +from typing import Any + +import pytest +from torch import nn + +from vllm_omni.diffusion.hooks.base import HookRegistry, ModelHook + +DEFAULT_OUT = "ECHO" +OVERRIDE_OUT = "OVERRIDE" +INPUT_KWARG = "inp" + + +class EchoModule(nn.Module): + """Just echo the input.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + def forward(self, *args, **kwargs): + input_val = kwargs[INPUT_KWARG] + return input_val + DEFAULT_OUT + + +class AppendHook(ModelHook): + """Append an echo value to the input string on pre / post forward.""" + + def __init__(self, echo_val: str): + self.echo_val = echo_val + + def pre_forward(self, module: nn.Module, *args, **kwargs): + input_val = kwargs[INPUT_KWARG] + return (), {INPUT_KWARG: input_val + self.echo_val} + + def post_forward(self, module: nn.Module, output): + return output + self.echo_val + + +class OverrideAppendHook(AppendHook): + """Same as AppendHook, but replace the forward call with a different string.""" + + def new_forward(self, module: nn.Module, *args, **kwargs): + """Call pre_forward, do something instead of fwd, then call post forward.""" 
+ _, new_kwargs = self.pre_forward(module, *args, **kwargs) + fwd_out = new_kwargs[INPUT_KWARG] + OVERRIDE_OUT + return self.post_forward(module, fwd_out) + + +def test_register_no_fwd_override_hooks(): + """Ensure registration is correct with no forward hooks.""" + mod = EchoModule() + registry = HookRegistry.get_or_create(mod) + first_hook = AppendHook("1") + second_hook = AppendHook("2") + sorted_no_fwd_hooks = [first_hook, second_hook] + + # Will add and sort the hook by key + registry.register_hook(name="b", hook=second_hook) + registry.register_hook(name="a", hook=first_hook) + + assert len(registry._hooks) == 2 + assert len(registry._sorted_def_fwd_hooks) == 2 + assert registry._new_fwd_impl_hook is None + # Ensure registering a new hook sorting alphabetically + for actual_hook, expected_hook in zip(registry._sorted_def_fwd_hooks, sorted_no_fwd_hooks): + assert actual_hook is expected_hook + + +def test_register_with_forward_hooks(): + """Ensure registration is correct with a forward hooks.""" + mod = EchoModule() + registry = HookRegistry.get_or_create(mod) + first_hook = AppendHook("1") + second_hook = AppendHook("2") + exec_hook = OverrideAppendHook("3") + sorted_no_fwd_hooks = [first_hook, second_hook] + + # Will add and sort the hook by key + registry.register_hook(name="b", hook=second_hook) + registry.register_hook(name="a", hook=first_hook) + registry.register_hook(name="c", hook=exec_hook) + + assert len(registry._hooks) == 3 + assert len(registry._sorted_def_fwd_hooks) == 2 + assert registry._new_fwd_impl_hook is exec_hook + # Ensure registering a new hook sorting alphabetically + for actual_hook, expected_hook in zip(registry._sorted_def_fwd_hooks, sorted_no_fwd_hooks): + assert actual_hook is expected_hook + + +def test_register_fails_with_multiple_forward_hooks(): + """Ensure registration only allows one hook overriding new_forward""" + mod = EchoModule() + registry = HookRegistry.get_or_create(mod) + + registry.register_hook(name="foo", 
hook=OverrideAppendHook("1")) + with pytest.raises(RuntimeError): + registry.register_hook(name="bar", hook=OverrideAppendHook("2")) + + +def test_remove_hooks(): + """Ensure removal sorts hooks.""" + mod = EchoModule() + registry = HookRegistry.get_or_create(mod) + + first_hook = AppendHook("1") + second_hook = AppendHook("2") + exec_hook = OverrideAppendHook("3") + + registry.register_hook(name="b", hook=second_hook) + registry.register_hook(name="a", hook=first_hook) + registry.register_hook(name="c", hook=exec_hook) + # Explicitly reorder our hooks to be in the wrong order, since register + # forces them to be sorted too. Ensure that remove the hook will also + # enforce the sorted order. + registry._sorted_def_fwd_hooks = [second_hook, first_hook] + + assert registry._new_fwd_impl_hook is exec_hook + registry.remove_hook("c") + assert registry._new_fwd_impl_hook is None + + sorted_no_fwd_hooks = [first_hook, second_hook] + for actual_hook, expected_hook in zip(registry._sorted_def_fwd_hooks, sorted_no_fwd_hooks): + assert actual_hook is expected_hook + + +def test_dispatch_no_fwd_override_hooks(): + """Ensure dispatch runs hooks in deterministic sorted order.""" + mod = EchoModule() + registry = HookRegistry.get_or_create(mod) + + first_hook = AppendHook("1") + second_hook = AppendHook("2") + + # Register will sort the hooks, so hook 1 will run first + # on preprocess and last in post process + registry.register_hook(name="2", hook=second_hook) + registry.register_hook(name="1", hook=first_hook) + res = registry.dispatch(inp="") + assert isinstance(res, str) + assert res == f"12{DEFAULT_OUT}21" + + +def test_dispatch_with_fwd_hooks(): + """Ensure dispatch runs hooks in deterministic sorted order.""" + mod = EchoModule() + registry = HookRegistry.get_or_create(mod) + + first_hook = AppendHook("1") + second_hook = AppendHook("2") + exec_hook = OverrideAppendHook("3") + + # Register will sort the hooks, so hook 1 will run first on preprocess and last in + # post 
process. Since the override hook mutates forward, it will run last even + # though the name of the exec_hook is alphabetically before the second hook. + registry.register_hook(name="c", hook=second_hook) + registry.register_hook(name="a", hook=first_hook) + registry.register_hook(name="b", hook=exec_hook) + res = registry.dispatch(inp="") + assert isinstance(res, str) + assert res == f"123{OVERRIDE_OUT}321" diff --git a/vllm_omni/diffusion/hooks/base.py b/vllm_omni/diffusion/hooks/base.py index cda4201ccf3..8a45af17f94 100644 --- a/vllm_omni/diffusion/hooks/base.py +++ b/vllm_omni/diffusion/hooks/base.py @@ -8,6 +8,7 @@ from __future__ import annotations +import functools import inspect from collections.abc import Callable from dataclasses import dataclass @@ -136,6 +137,21 @@ def __call__(self, *args: Any, **kwargs: Any): return registry.dispatch(*args, **kwargs) +def sort_hooks_after_call(func): + """Calls the method on the hook registry, then sorts the hooks. + + This should be added to methods that add or remove hooks. + """ + + @functools.wraps(func) + def wrapper(self: HookRegistry, *args, **kwargs): + res = func(self, *args, **kwargs) + self.update_sorted_hooks() + return res + + return wrapper + + class HookRegistry: """Registry of hooks attached to a module. 
@@ -146,6 +162,10 @@ class HookRegistry: def __init__(self, module: nn.Module): self.module = module self._hooks: dict[str, ModelHook] = {} + # Sorted hook execution order for hooks that don't override new_forward + self._sorted_def_fwd_hooks: list[ModelHook] = [] + # Hooks overriding new_forward (if any), which includes pre/post process for now + self._new_fwd_impl_hook: ModelHook | None = None @classmethod def get_or_create(cls, module: nn.Module) -> HookRegistry: @@ -173,6 +193,13 @@ def get_or_create(cls, module: nn.Module) -> HookRegistry: return registry + def update_sorted_hooks(self): + """Sort hooks by name, which dictates pre/post process order.""" + self._sorted_def_fwd_hooks = [ + self._hooks[k] for k in sorted(self._hooks) if self._hooks[k] != self._new_fwd_impl_hook + ] + + @sort_hooks_after_call def register_hook(self, name: str, hook: ModelHook) -> None: """Register a hook with the given name. @@ -182,7 +209,14 @@ def register_hook(self, name: str, hook: ModelHook) -> None: """ hook.initialize_hook(self.module) self._hooks[name] = hook - + # We can only have one hook that overrides new_forward, + # since we don't currently have a mechanism for combining them. + if type(hook).new_forward is not ModelHook.new_forward: + if self._new_fwd_impl_hook is not None: + raise RuntimeError("Cannot have multiple hooks that override forward active simultaneously") + self._new_fwd_impl_hook = hook + + @sort_hooks_after_call def remove_hook(self, name: str) -> None: """Remove a hook by name. @@ -190,6 +224,9 @@ def remove_hook(self, name: str) -> None: name: The name of the hook to remove. 
""" if name in self._hooks: + # clear the forward hook if it's the one to delete + if self._new_fwd_impl_hook is self._hooks[name]: + self._new_fwd_impl_hook = None del self._hooks[name] def get_hook(self, name: str) -> ModelHook | None: @@ -206,8 +243,16 @@ def get_hook(self, name: str) -> ModelHook | None: def dispatch(self, *args: Any, **kwargs: Any) -> Any: """Dispatch a forward call through registered hooks. - Currently supports a single active hook. Multiple hooks are called - in sorted order by name, with each hook's output passed to the next. + Multiple hooks may be used with the caveat that only one hook + may override new_forward. While it is assumed that pre/post process + on hooks are composable, the execution flow is as follows for determinism: + + - Run preprocess on all hooks that don't override new_forward in alphabetical order + + - If a hook overrides new_forward, call new_forward on the hook, which will also call + its pre/post process. Otherwise call the original model forward. + + - Run post process on all hooks that don't override new_forward in reverse order. Args: *args: Positional arguments to forward. 
@@ -219,24 +264,20 @@ def dispatch(self, *args: Any, **kwargs: Any) -> Any: if not self._hooks: return self.module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined] - # For single hook case, call directly - if len(self._hooks) == 1: - hook = next(iter(self._hooks.values())) - return hook.new_forward(self.module, *args, **kwargs) - - # For multiple hooks, chain them in sorted order - # Each hook can modify args/kwargs via pre_forward - sorted_hooks = sorted(self._hooks.items(), key=lambda x: x[0]) - # Apply all pre_forward hooks - for _, hook in sorted_hooks: + for hook in self._sorted_def_fwd_hooks: args, kwargs = hook.pre_forward(self.module, *args, **kwargs) - # Call original forward - output = self.module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined] + # If we have a hook that overrides new_forward, call it directly; + # this will also call its pre/post process at the moment. + if self._new_fwd_impl_hook is not None: + output = self._new_fwd_impl_hook.new_forward(self.module, *args, **kwargs) + # Otherwise just call the original forward. 
+ else: + output = self.module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined] # Apply all post_forward hooks in reverse order - for _, hook in reversed(sorted_hooks): + for hook in reversed(self._sorted_def_fwd_hooks): output = hook.post_forward(self.module, output) return output From 5fec22db0b3c99b1ba7393d6bcf4a2f159849b24 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sat, 11 Apr 2026 03:50:50 +0000 Subject: [PATCH 2/3] move prev / post out of new_forward default Signed-off-by: Alex Brooks --- tests/diffusion/hooks/test_hook_registry.py | 17 ++++---- vllm_omni/diffusion/hooks/base.py | 44 ++++++++++----------- 2 files changed, 29 insertions(+), 32 deletions(-) diff --git a/tests/diffusion/hooks/test_hook_registry.py b/tests/diffusion/hooks/test_hook_registry.py index c2698bca05f..6c8535cfec4 100644 --- a/tests/diffusion/hooks/test_hook_registry.py +++ b/tests/diffusion/hooks/test_hook_registry.py @@ -46,10 +46,7 @@ class OverrideAppendHook(AppendHook): """Same as AppendHook, but replace the forward call with a different string.""" def new_forward(self, module: nn.Module, *args, **kwargs): - """Call pre_forward, do something instead of fwd, then call post forward.""" - _, new_kwargs = self.pre_forward(module, *args, **kwargs) - fwd_out = new_kwargs[INPUT_KWARG] + OVERRIDE_OUT - return self.post_forward(module, fwd_out) + return kwargs[INPUT_KWARG] + OVERRIDE_OUT def test_register_no_fwd_override_hooks(): @@ -65,10 +62,10 @@ def test_register_no_fwd_override_hooks(): registry.register_hook(name="a", hook=first_hook) assert len(registry._hooks) == 2 - assert len(registry._sorted_def_fwd_hooks) == 2 + assert len(registry._sorted_hooks) == 2 assert registry._new_fwd_impl_hook is None # Ensure registering a new hook sorting alphabetically - for actual_hook, expected_hook in zip(registry._sorted_def_fwd_hooks, sorted_no_fwd_hooks): + for actual_hook, expected_hook in zip(registry._sorted_hooks, sorted_no_fwd_hooks): assert actual_hook is 
expected_hook @@ -87,10 +84,10 @@ def test_register_with_forward_hooks(): registry.register_hook(name="c", hook=exec_hook) assert len(registry._hooks) == 3 - assert len(registry._sorted_def_fwd_hooks) == 2 + assert len(registry._sorted_hooks) == 3 assert registry._new_fwd_impl_hook is exec_hook # Ensure registering a new hook sorting alphabetically - for actual_hook, expected_hook in zip(registry._sorted_def_fwd_hooks, sorted_no_fwd_hooks): + for actual_hook, expected_hook in zip(registry._sorted_hooks, sorted_no_fwd_hooks): assert actual_hook is expected_hook @@ -119,14 +116,14 @@ def test_remove_hooks(): # Explicitly reorder our hooks to be in the wrong order, since register # forces them to be sorted too. Ensure that remove the hook will also # enforce the sorted order. - registry._sorted_def_fwd_hooks = [second_hook, first_hook] + registry._sorted_hooks = [second_hook, first_hook] assert registry._new_fwd_impl_hook is exec_hook registry.remove_hook("c") assert registry._new_fwd_impl_hook is None sorted_no_fwd_hooks = [first_hook, second_hook] - for actual_hook, expected_hook in zip(registry._sorted_def_fwd_hooks, sorted_no_fwd_hooks): + for actual_hook, expected_hook in zip(registry._sorted_hooks, sorted_no_fwd_hooks): assert actual_hook is expected_hook diff --git a/vllm_omni/diffusion/hooks/base.py b/vllm_omni/diffusion/hooks/base.py index 8a45af17f94..78f3091f819 100644 --- a/vllm_omni/diffusion/hooks/base.py +++ b/vllm_omni/diffusion/hooks/base.py @@ -95,10 +95,10 @@ def post_forward(self, module: nn.Module, output: Any) -> Any: return output def new_forward(self, module: nn.Module, *args: Any, **kwargs: Any) -> Any: - """Override the module's forward pass completely. + """Override the module's forward pass. This should be overridden for more complex + cases, e.g., TeaCache. - The default implementation calls pre_forward, then the original forward, - then post_forward. Override this method for more complex behavior. 
+ NOTE: only one hook overriding `new_forward` can be enabled at a time. Args: module: The module being called. @@ -108,9 +108,7 @@ def new_forward(self, module: nn.Module, *args: Any, **kwargs: Any) -> Any: Returns: The output of the forward pass. """ - args, kwargs = self.pre_forward(module, *args, **kwargs) - output = module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined] - return self.post_forward(module, output) + return module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined] def reset_state(self, module: nn.Module) -> nn.Module: """Reset any state associated with this hook. @@ -162,9 +160,9 @@ class HookRegistry: def __init__(self, module: nn.Module): self.module = module self._hooks: dict[str, ModelHook] = {} - # Sorted hook execution order for hooks that don't override new_forward - self._sorted_def_fwd_hooks: list[ModelHook] = [] - # Hooks overriding new_forward (if any), which includes pre/post process for now + # Hooks sorted by execution order + self._sorted_hooks: list[ModelHook] = [] + # Hooks overriding new_forward (if any) self._new_fwd_impl_hook: ModelHook | None = None @classmethod @@ -195,9 +193,10 @@ def get_or_create(cls, module: nn.Module) -> HookRegistry: def update_sorted_hooks(self): """Sort hooks by name, which dictates pre/post process order.""" - self._sorted_def_fwd_hooks = [ - self._hooks[k] for k in sorted(self._hooks) if self._hooks[k] != self._new_fwd_impl_hook - ] + sorted_hooks = [self._hooks[k] for k in sorted(self._hooks) if self._hooks[k] != self._new_fwd_impl_hook] + if self._new_fwd_impl_hook is not None: + sorted_hooks.append(self._new_fwd_impl_hook) + self._sorted_hooks = sorted_hooks @sort_hooks_after_call def register_hook(self, name: str, hook: ModelHook) -> None: @@ -247,12 +246,14 @@ def dispatch(self, *args: Any, **kwargs: Any) -> Any: may override new_forward. 
While it is assumed that pre/post process on hooks are composable, the execution flow is as follows for determinism: - - Run preprocess on all hooks that don't override new_forward in alphabetical order + - Run preprocess on all hooks in their sorted order; hooks are sorted alphabetically, + except for the hook overriding forward (`self._new_fwd_impl_hook`), which is last + if it exists. - - If a hook overrides new_forward, call new_forward on the hook, which will also call - its pre/post process. Otherwise call the original model forward. + - If `self._new_fwd_impl_hook` isn't None, call its forward. Otherwise call the + original model forward. - - Run post process on all hooks that don't override new_forward in reverse order. + - Run post process on all hooks in the reverse sorted order. Args: *args: Positional arguments to forward. @@ -264,20 +265,19 @@ def dispatch(self, *args: Any, **kwargs: Any) -> Any: if not self._hooks: return self.module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined] - # Apply all pre_forward hooks - for hook in self._sorted_def_fwd_hooks: + # Apply all pre_forward hooks; if _new_fwd_impl_hook is set, it's last + for hook in self._sorted_hooks: args, kwargs = hook.pre_forward(self.module, *args, **kwargs) - # If we have a hook that overrides new_forward, call it directly; - # this will also call its pre/post process at the moment. + # If we have a hook that overrides new_forward, call it directly if self._new_fwd_impl_hook is not None: output = self._new_fwd_impl_hook.new_forward(self.module, *args, **kwargs) # Otherwise just call the original forward. 
else: output = self.module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined] - # Apply all post_forward hooks in reverse order - for hook in reversed(self._sorted_def_fwd_hooks): + # Apply all post_forward hooks in reverse order; if _new_fwd_impl_hook is set, it's first + for hook in reversed(self._sorted_hooks): output = hook.post_forward(self.module, output) return output From 3d543ef01db2f74fe75896500146b84f50d1f72a Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sat, 11 Apr 2026 17:59:34 +0000 Subject: [PATCH 3/3] raise notimplementederror for default new_forward Signed-off-by: Alex Brooks --- vllm_omni/diffusion/hooks/base.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm_omni/diffusion/hooks/base.py b/vllm_omni/diffusion/hooks/base.py index 78f3091f819..517c6615877 100644 --- a/vllm_omni/diffusion/hooks/base.py +++ b/vllm_omni/diffusion/hooks/base.py @@ -96,9 +96,8 @@ def post_forward(self, module: nn.Module, output: Any) -> Any: def new_forward(self, module: nn.Module, *args: Any, **kwargs: Any) -> Any: """Override the module's forward pass. This should be overridden for more complex - cases, e.g., TeaCache. - - NOTE: only one hook overriding `new_forward` can be enabled at a time. + cases, e.g., TeaCache. If this method is overridden in a subclass, it will be called + instead of self.module._omni_original_forward when executing the hooks. Args: module: The module being called. @@ -106,9 +105,9 @@ def new_forward(self, module: nn.Module, *args: Any, **kwargs: Any) -> Any: **kwargs: Keyword arguments to forward. Returns: - The output of the forward pass. + The output of the replacement for the forward pass. """ - return module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined] + raise NotImplementedError("By default, hooks do not implement new_forward") def reset_state(self, module: nn.Module) -> nn.Module: """Reset any state associated with this hook.