Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions tests/diffusion/hooks/test_hook_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""
Tests for hook registry.

NOTE: The hook registry is also tested indirectly through a lot of
other tests, e.g., tests/diffusion/distributed/test_sp_plan_hooks.py
"""

from typing import Any

import pytest
from torch import nn

from vllm_omni.diffusion.hooks.base import HookRegistry, ModelHook

DEFAULT_OUT = "ECHO"
OVERRIDE_OUT = "OVERRIDE"
INPUT_KWARG = "inp"


class EchoModule(nn.Module):
"""Just echo the input."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

def forward(self, *args, **kwargs):
input_val = kwargs[INPUT_KWARG]
return input_val + DEFAULT_OUT


class AppendHook(ModelHook):
"""Append an echo value to the input string on pre / post forward."""

def __init__(self, echo_val: str):
self.echo_val = echo_val

def pre_forward(self, module: nn.Module, *args, **kwargs):
input_val = kwargs[INPUT_KWARG]
return (), {INPUT_KWARG: input_val + self.echo_val}

def post_forward(self, module: nn.Module, output):
return output + self.echo_val


class OverrideAppendHook(AppendHook):
"""Same as AppendHook, but replace the forward call with a different string."""

def new_forward(self, module: nn.Module, *args, **kwargs):
return kwargs[INPUT_KWARG] + OVERRIDE_OUT


def test_register_no_fwd_override_hooks():
"""Ensure registration is correct with no forward hooks."""
mod = EchoModule()
registry = HookRegistry.get_or_create(mod)
first_hook = AppendHook("1")
second_hook = AppendHook("2")
sorted_no_fwd_hooks = [first_hook, second_hook]

# Will add and sort the hook by key
registry.register_hook(name="b", hook=second_hook)
registry.register_hook(name="a", hook=first_hook)

assert len(registry._hooks) == 2
assert len(registry._sorted_hooks) == 2
assert registry._new_fwd_impl_hook is None
# Ensure registering a new hook sorting alphabetically
for actual_hook, expected_hook in zip(registry._sorted_hooks, sorted_no_fwd_hooks):
assert actual_hook is expected_hook


def test_register_with_forward_hooks():
"""Ensure registration is correct with a forward hooks."""
mod = EchoModule()
registry = HookRegistry.get_or_create(mod)
first_hook = AppendHook("1")
second_hook = AppendHook("2")
exec_hook = OverrideAppendHook("3")
sorted_no_fwd_hooks = [first_hook, second_hook]

# Will add and sort the hook by key
registry.register_hook(name="b", hook=second_hook)
registry.register_hook(name="a", hook=first_hook)
registry.register_hook(name="c", hook=exec_hook)

assert len(registry._hooks) == 3
assert len(registry._sorted_hooks) == 3
assert registry._new_fwd_impl_hook is exec_hook
# Ensure registering a new hook sorting alphabetically
for actual_hook, expected_hook in zip(registry._sorted_hooks, sorted_no_fwd_hooks):
assert actual_hook is expected_hook


def test_register_fails_with_multiple_forward_hooks():
"""Ensure registration only allows one hook overriding new_forward"""
mod = EchoModule()
registry = HookRegistry.get_or_create(mod)

registry.register_hook(name="foo", hook=OverrideAppendHook("1"))
with pytest.raises(RuntimeError):
registry.register_hook(name="bar", hook=OverrideAppendHook("2"))


def test_remove_hooks():
"""Ensure removal sorts hooks."""
mod = EchoModule()
registry = HookRegistry.get_or_create(mod)

first_hook = AppendHook("1")
second_hook = AppendHook("2")
exec_hook = OverrideAppendHook("3")

registry.register_hook(name="b", hook=second_hook)
registry.register_hook(name="a", hook=first_hook)
registry.register_hook(name="c", hook=exec_hook)
# Explicitly reorder our hooks to be in the wrong order, since register
# forces them to be sorted too. Ensure that remove the hook will also
# enforce the sorted order.
registry._sorted_hooks = [second_hook, first_hook]

assert registry._new_fwd_impl_hook is exec_hook
registry.remove_hook("c")
assert registry._new_fwd_impl_hook is None

sorted_no_fwd_hooks = [first_hook, second_hook]
for actual_hook, expected_hook in zip(registry._sorted_hooks, sorted_no_fwd_hooks):
assert actual_hook is expected_hook


def test_dispatch_no_fwd_override_hooks():
"""Ensure dispatch runs hooks in deterministic sorted order."""
mod = EchoModule()
registry = HookRegistry.get_or_create(mod)

first_hook = AppendHook("1")
second_hook = AppendHook("2")

# Register will sort the hooks, so hook 1 will run first
# on preprocess and last in post process
registry.register_hook(name="2", hook=second_hook)
registry.register_hook(name="1", hook=first_hook)
res = registry.dispatch(inp="")
assert isinstance(res, str)
assert res == f"12{DEFAULT_OUT}21"


def test_dispatch_with_fwd_hooks():
"""Ensure dispatch runs hooks in deterministic sorted order."""
mod = EchoModule()
registry = HookRegistry.get_or_create(mod)

first_hook = AppendHook("1")
second_hook = AppendHook("2")
exec_hook = OverrideAppendHook("3")

# Register will sort the hooks, so hook 1 will run first on preprocess and last in
# post process. Since the override hook mutates forward, it will run last even
# though the name of the exec_hook is alphabetically before the second hook.
registry.register_hook(name="c", hook=second_hook)
registry.register_hook(name="a", hook=first_hook)
registry.register_hook(name="b", hook=exec_hook)
res = registry.dispatch(inp="")
assert isinstance(res, str)
assert res == f"123{OVERRIDE_OUT}321"
92 changes: 66 additions & 26 deletions vllm_omni/diffusion/hooks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from __future__ import annotations

import functools
import inspect
from collections.abc import Callable
from dataclasses import dataclass
Expand Down Expand Up @@ -94,22 +95,19 @@ def post_forward(self, module: nn.Module, output: Any) -> Any:
return output

def new_forward(self, module: nn.Module, *args: Any, **kwargs: Any) -> Any:
"""Override the module's forward pass completely.

The default implementation calls pre_forward, then the original forward,
then post_forward. Override this method for more complex behavior.
"""Override the module's forward pass. This should be overridden for more complex
cases, e.g., TeaCache. If this method is overridden in a subclass, it will be called
instead of self.module._omni_original_forward when executing the hooks.

Args:
module: The module being called.
*args: Positional arguments to forward.
**kwargs: Keyword arguments to forward.

Returns:
The output of the forward pass.
The output of the replacement for the forward pass.
"""
args, kwargs = self.pre_forward(module, *args, **kwargs)
output = module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined]
return self.post_forward(module, output)
raise NotImplementedError("By default, hooks do not implement new_forward")

def reset_state(self, module: nn.Module) -> nn.Module:
"""Reset any state associated with this hook.
Expand All @@ -136,6 +134,21 @@ def __call__(self, *args: Any, **kwargs: Any):
return registry.dispatch(*args, **kwargs)


def sort_hooks_after_call(func):
"""Calls the method on the hook registry, then sorts the hooks.

This should be added to methods that mutate add or remove hooks.
"""

@functools.wraps(func)
def wrapper(self: HookRegistry, *args, **kwargs):
res = func(self, *args, **kwargs)
self.update_sorted_hooks()
return res

return wrapper


class HookRegistry:
"""Registry of hooks attached to a module.

Expand All @@ -146,6 +159,10 @@ class HookRegistry:
def __init__(self, module: nn.Module):
self.module = module
self._hooks: dict[str, ModelHook] = {}
# Hooks sorted by execution order
self._sorted_hooks: list[ModelHook] = []
# Hooks overriding new_forward (if any)
self._new_fwd_impl_hook: ModelHook | None = None

@classmethod
def get_or_create(cls, module: nn.Module) -> HookRegistry:
Expand Down Expand Up @@ -173,6 +190,14 @@ def get_or_create(cls, module: nn.Module) -> HookRegistry:

return registry

def update_sorted_hooks(self):
"""Sort hooks by name, which dictates pre/post process order."""
sorted_hooks = [self._hooks[k] for k in sorted(self._hooks) if self._hooks[k] != self._new_fwd_impl_hook]
if self._new_fwd_impl_hook is not None:
sorted_hooks.append(self._new_fwd_impl_hook)
self._sorted_hooks = sorted_hooks

@sort_hooks_after_call
def register_hook(self, name: str, hook: ModelHook) -> None:
"""Register a hook with the given name.

Expand All @@ -182,14 +207,24 @@ def register_hook(self, name: str, hook: ModelHook) -> None:
"""
hook.initialize_hook(self.module)
self._hooks[name] = hook

# We can only have one hook that overrides new_forward,
# since we don't currently have a mechanism for combining them.
if type(hook).new_forward is not ModelHook.new_forward:
if self._new_fwd_impl_hook is not None:
raise RuntimeError("Cannot have multiple hooks that override forward active simultaneously")
self._new_fwd_impl_hook = hook

@sort_hooks_after_call
def remove_hook(self, name: str) -> None:
"""Remove a hook by name.

Args:
name: The name of the hook to remove.
"""
if name in self._hooks:
# clear the forward hook if it's the one to delete
if self._new_fwd_impl_hook is self._hooks[name]:
self._new_fwd_impl_hook = None
del self._hooks[name]

def get_hook(self, name: str) -> ModelHook | None:
Expand All @@ -206,8 +241,18 @@ def get_hook(self, name: str) -> ModelHook | None:
def dispatch(self, *args: Any, **kwargs: Any) -> Any:
"""Dispatch a forward call through registered hooks.

Currently supports a single active hook. Multiple hooks are called
in sorted order by name, with each hook's output passed to the next.
Multiple hooks may be used with the caveat that only one hook
may override new_forward. While it is assumed that pre/post process
on hooks are composable, the execution flow is as follows for determinism:

- Run preprocess on all hooks in their sorted order; hooks are sorted alphabetically,
except for the hook overriding forward (`self._new_fwd_impl_hook`), which is last
if it exists.

- If `self._new_fwd_impl_hook` isn't None, call its forward. Otherwise call the
original model forward.

- Run post process on all hooks in the reverse sorted order.

Args:
*args: Positional arguments to forward.
Expand All @@ -219,24 +264,19 @@ def dispatch(self, *args: Any, **kwargs: Any) -> Any:
if not self._hooks:
return self.module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined]

# For single hook case, call directly
if len(self._hooks) == 1:
hook = next(iter(self._hooks.values()))
return hook.new_forward(self.module, *args, **kwargs)

# For multiple hooks, chain them in sorted order
# Each hook can modify args/kwargs via pre_forward
sorted_hooks = sorted(self._hooks.items(), key=lambda x: x[0])

# Apply all pre_forward hooks
for _, hook in sorted_hooks:
# Apply all pre_forward hooks; if _new_fwd_impl_hook is set, it's last
for hook in self._sorted_hooks:
args, kwargs = hook.pre_forward(self.module, *args, **kwargs)

# Call original forward
output = self.module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined]
# If we have a hook that overrides new_forward, call it directly
if self._new_fwd_impl_hook is not None:
output = self._new_fwd_impl_hook.new_forward(self.module, *args, **kwargs)
# Otherwise just call the original forward.
else:
output = self.module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined]

# Apply all post_forward hooks in reverse order
for _, hook in reversed(sorted_hooks):
# Apply all post_forward hooks in reverse order; if _new_fwd_impl_hook is set, it's first
for hook in reversed(self._sorted_hooks):
output = hook.post_forward(self.module, output)

return output
Expand Down
Loading