9 changes: 0 additions & 9 deletions docs/design/custom_op.md
@@ -8,15 +8,6 @@ This document will introduce how CustomOp works in vLLM and how to implement a n

`CustomOp` manages two dictionaries of all custom ops (i.e., op classes, indexed by registered name) in its class, for vLLM and OOT plugins respectively.

??? code

```python
class CustomOp(nn.Module):

op_registry: dict[str, type["CustomOp"]] = {}
op_registry_oot: dict[str, type["CustomOp"]] = {}
```

We can use `@CustomOp.register("op_name")` to register an op class to the `CustomOp` system. After this, the `op_name` and its class are added to the `op_registry` dictionary. In addition, we can also register an OOT op with `@CustomOp.register_oot("op_name")`. We will introduce this mechanism in detail later.

When a `CustomOp` is called (i.e., its `forward()` method is invoked), if it is enabled (i.e., with `--compilation_config.custom_ops '["+op_name"]'`), it automatically dispatches the forward call to the appropriate backend according to `current_platform`. Otherwise (i.e., it is disabled), it only calls `forward_native()`, the PyTorch-native implementation of the forward method.
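To make the registration and dispatch flow above concrete, here is a minimal sketch of a custom op. The class and op name (`GeluFastDemo` / `gelu_fast_demo`) are hypothetical and not part of vLLM, and a real op would call an optimized kernel rather than reuse the PyTorch implementation.

```python
import torch
import torch.nn.functional as F

from vllm.model_executor.custom_op import CustomOp


@CustomOp.register("gelu_fast_demo")  # hypothetical name, added to op_registry
class GeluFastDemo(CustomOp):
    """Toy op: forward() dispatches to a backend-specific method."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Reference implementation, used when the op is disabled.
        return F.gelu(x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # Chosen on CUDA platforms when the op is enabled via
        # --compilation_config.custom_ops '["+gelu_fast_demo"]'.
        return F.gelu(x)
```

Under an active vLLM config, `op_registry["gelu_fast_demo"].enabled()` (or `GeluFastDemo.enabled()`) then reports whether the dispatched path is active, as the test file below does for the built-in ops.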
10 changes: 5 additions & 5 deletions tests/model_executor/test_enabled_custom_ops.py
@@ -11,7 +11,7 @@
get_cached_compilation_config,
set_current_vllm_config,
)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.custom_op import CustomOp, op_registry
from vllm.model_executor.layers.activation import (
GeluAndMul,
ReLUSquaredActivation,
@@ -98,17 +98,17 @@ def test_enabled_ops(
ops_enabled = [bool(x) for x in ops_enabled]

assert RMSNorm(1024).enabled() == ops_enabled[0]
assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
assert op_registry["rms_norm"].enabled() == ops_enabled[0]

assert SiluAndMul().enabled() == ops_enabled[1]
assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
assert op_registry["silu_and_mul"].enabled() == ops_enabled[1]

assert GeluAndMul().enabled() == ops_enabled[2]
assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
assert op_registry["gelu_and_mul"].enabled() == ops_enabled[2]

# If registered, subclasses should follow their own name
assert Relu3().enabled() == ops_enabled[3]
assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
assert op_registry["relu3"].enabled() == ops_enabled[3]

# Unregistered subclass
class SiluAndMul2(SiluAndMul):
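For downstream code, the visible effect of this refactor is the lookup path: the registries are module-level dictionaries rather than class attributes of `CustomOp`. A minimal before/after sketch, using the `rms_norm` op exercised by the test above:

```python
# Before this change:
#   from vllm.model_executor.custom_op import CustomOp
#   CustomOp.op_registry["rms_norm"].enabled()

# After this change:
from vllm.model_executor.custom_op import op_registry

# enabled() consults the compilation config, so call it under an active
# vLLM config (e.g. inside set_current_vllm_config, as in the test above).
rms_norm_enabled = op_registry["rms_norm"].enabled()
```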
100 changes: 86 additions & 14 deletions vllm/model_executor/custom_op.py
@@ -11,6 +11,86 @@
logger = init_logger(__name__)


# Dictionary of all custom ops (classes, indexed by registered name).
# To check if an op with a name is enabled, call .enabled() on the class.
# Examples:
# - MyOp.enabled()
# - op_registry["my_op"].enabled()
op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}


class PluggableLayer(nn.Module):
"""
Base class for pluggable layers.

A PluggableLayer is a *module-composing* abstraction: it may instantiate other
``torch.nn.Module`` objects as sub-layers, and its functionality depends on
those sub-layers being invoked in a generalized sequence. It is also stateful
and may hold parameters or buffers.

Unlike :class:`CustomOp`, PluggableLayer does NOT provide per-platform
``forward_*`` dispatch. Instead, it supports out-of-tree (OOT) replacement
of the entire layer class at instantiation time, allowing customized
initialization and submodule composition.
"""

def __new__(cls, *args, **kwargs):
# Unlike CustomOp, OOT replacement here is keyed by the in-tree layer's
# class name, which always exists, so no registered `name` attribute is
# required to instantiate the layer.
layer_class_name = cls.__name__

if layer_class_name not in op_registry_oot:
layer_cls_to_instantiate = cls
else:
layer_cls_to_instantiate = op_registry_oot[layer_class_name]
logger.debug(
"Instantiating pluggable layer: %s using %s",
layer_class_name,
str(layer_cls_to_instantiate),
)
return super().__new__(layer_cls_to_instantiate)

# Decorator to register pluggable layers.
@classmethod
def register(cls, name: str):
def decorator(op_cls):
assert name not in op_registry, f"Duplicate op name: {name}"
op_cls.name = name
op_registry[name] = op_cls
return op_cls

return decorator

# Decorator to register out-of-tree (OOT) pluggable layers.
# For OOT pluggable layers:
# if an OOT replacement is registered under an in-tree layer's class name,
# that OOT class is instantiated in place of the in-tree layer.
@classmethod
def register_oot(cls, _decorated_layer_cls=None, name: str | None = None):
def decorator(layer_cls):
reg_name = name if name is not None else cls.__name__
assert reg_name not in op_registry_oot, f"Duplicate layer name: {reg_name}"
layer_cls.name = reg_name
op_registry_oot[reg_name] = layer_cls
return layer_cls

if _decorated_layer_cls is None:
# Called with parentheses: @PluggableLayer.register_oot()
# or @PluggableLayer.register_oot(name="...")
return decorator
elif isinstance(_decorated_layer_cls, type): # Check if it's a class
# Called without parentheses: @PluggableLayer.register_oot
return decorator(_decorated_layer_cls)
else:
raise TypeError("Decorator can only be applied to classes.")
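A minimal sketch of the replacement flow, using hypothetical `DemoBlock` / `OOTDemoBlock` classes (not part of this PR): the OOT class is registered under the in-tree class's name, which is exactly the key `__new__` looks up, so constructing the in-tree class yields the OOT replacement.

```python
import torch
from torch import nn

from vllm.model_executor.custom_op import PluggableLayer


@PluggableLayer.register("demo_block")  # hypothetical layer
class DemoBlock(PluggableLayer):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


# In an out-of-tree plugin: registered under "DemoBlock" (the calling
# class's __name__), so PluggableLayer.__new__ swaps it in.
@DemoBlock.register_oot
class OOTDemoBlock(DemoBlock):
    def __init__(self, hidden_size: int):
        super().__init__(hidden_size)
        # Customized initialization / submodule composition goes here.


layer = DemoBlock(hidden_size=16)
assert type(layer) is OOTDemoBlock  # the OOT class was instantiated
```

The replacement subclasses the in-tree layer here; with an unrelated base class, Python would skip calling `__init__` on the swapped object, since `type.__call__` only invokes `__init__` when `__new__` returns an instance of the class being constructed.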


class CustomOp(nn.Module):
"""
Base class for custom ops.
@@ -27,10 +107,10 @@ def __new__(cls, *args, **kwargs):
f"@CustomOp.register, or it's the CustomOp base class itself."
) from None

if op_name not in cls.op_registry_oot:
if op_name not in op_registry_oot:
op_cls_to_instantiate = cls
else:
op_cls_to_instantiate = cls.op_registry_oot[op_name]
op_cls_to_instantiate = op_registry_oot[op_name]
logger.debug(
"Instantiating custom op: %s using %s",
op_name,
@@ -150,21 +230,13 @@ def default_on() -> bool:

return not count_none > 0 or count_all > 0

# Dictionary of all custom ops (classes, indexed by registered name).
# To check if an op with a name is enabled, call .enabled() on the class.
# Examples:
# - MyOp.enabled()
# - op_registry["my_op"].enabled()
op_registry: dict[str, type["CustomOp"]] = {}
op_registry_oot: dict[str, type["CustomOp"]] = {}

# Decorator to register custom ops.
@classmethod
def register(cls, name: str):
def decorator(op_cls):
assert name not in cls.op_registry, f"Duplicate op name: {name}"
assert name not in op_registry, f"Duplicate op name: {name}"
op_cls.name = name
cls.op_registry[name] = op_cls
op_registry[name] = op_cls
return op_cls

return decorator
@@ -182,9 +254,9 @@ def decorator(op_cls):
def register_oot(cls, _decorated_op_cls=None, name: str | None = None):
def decorator(op_cls):
reg_name = name if name is not None else cls.__name__
assert reg_name not in cls.op_registry_oot, f"Duplicate op name: {reg_name}"
assert reg_name not in op_registry_oot, f"Duplicate op name: {reg_name}"
op_cls.name = reg_name
cls.op_registry_oot[reg_name] = op_cls
op_registry_oot[reg_name] = op_cls
return op_cls

if _decorated_op_cls is None:
19 changes: 8 additions & 11 deletions vllm/model_executor/layers/mla.py
@@ -6,7 +6,7 @@

from vllm.attention.layer import MLAAttention
from vllm.config import CacheConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.quantization import QuantizationConfig


@@ -30,13 +30,13 @@ class MLAModules:


# --8<-- [start:multi_head_latent_attention]
@CustomOp.register("multi_head_latent_attention")
class MultiHeadLatentAttentionWrapper(CustomOp):
"""MLA layer registered as CustomOp to allow OOT backends to add
@PluggableLayer.register("multi_head_latent_attention")
class MultiHeadLatentAttentionWrapper(PluggableLayer):
"""Pluggable MLA layer which allows OOT backends to add
custom implementations of the outer MLA layer (including rope & o_proj).
Note that currently MLA ignores the enable/disable mechanism of CustomOp
because there is only one in-tree implementation in forward_native.
TODO: implement this with a new PluggableLayer mechanism.
Note that OOT platforms can still use CustomOp.register_oot to replace
this MLA layer entirely, although the layer itself is now registered
through PluggableLayer rather than CustomOp.

This class takes positions and hidden_states as input.
The input tensors can either contain prefill tokens or decode tokens.
@@ -110,7 +110,7 @@ def __init__(

self.prefix = prefix

def forward_native(
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
@@ -174,6 +174,3 @@ def forward_native(
)

return self.o_proj(attn_out)[0]

def forward_cuda(self, *args, **kwargs):
return self.forward_native(*args, **kwargs)
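As a sketch of the out-of-tree replacement the docstring mentions, a platform plugin might do something like the following. `MyPlatformMLAWrapper` is hypothetical, and the sketch assumes replacement is keyed by the in-tree class name, as in `PluggableLayer.__new__` above.

```python
import torch

from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper


# Registered under "MultiHeadLatentAttentionWrapper", the key that
# PluggableLayer.__new__ looks up, so in-tree model code constructing
# MultiHeadLatentAttentionWrapper(...) gets this class instead.
@MultiHeadLatentAttentionWrapper.register_oot
class MyPlatformMLAWrapper(MultiHeadLatentAttentionWrapper):  # hypothetical
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        # A platform-specific version of the outer MLA layer
        # (rope + core attention + o_proj) would go here.
        return super().forward(positions, hidden_states)
```

Because the replacement subclasses the in-tree wrapper, it inherits the existing `__init__` and only overrides what the platform needs to change.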