diff --git a/docs/design/custom_op.md b/docs/design/custom_op.md index 13c2915abe8f..3f4934b15699 100644 --- a/docs/design/custom_op.md +++ b/docs/design/custom_op.md @@ -8,15 +8,6 @@ This document will introduce how CustomOp works in vLLM and how to implement a n `CustomOp` manages two dictionaries of all custom ops (i.e., op classes, indexed by registered name) in its class, for vLLM and OOT plugins respectively. -??? code - - ```python - class CustomOp(nn.Module): - - op_registry: dict[str, type["CustomOp"]] = {} - op_registry_oot: dict[str, type["CustomOp"]] = {} - ``` - We can use `@CustomOp.register("op_name")` to register an op class to the `CustomOp` system. After this, the `op_name` and its class will be added into the `op_registry` dictionary. In addition, We can also register an OOT op by `@CustomOp.register_oot("op_name")`. We will introduce this mechanism in detail later. When a `CustomOp` is called (i.e., call its `forward()` method), if it is enabled (i.e., with `--compilation_config.custom_ops '["+op_name"]'`), it will automatically dispatch the forward method to the appropriate backend according to `current_platform`. Otherwise (i.e., it is disabled), it will only call the `forward_native()` method to use PyTorch-native implementation of this forward method. diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 8ee1b1a37ca6..316caf06b29c 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -11,7 +11,7 @@ get_cached_compilation_config, set_current_vllm_config, ) -from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.custom_op import CustomOp, op_registry from vllm.model_executor.layers.activation import ( GeluAndMul, ReLUSquaredActivation, @@ -98,17 +98,17 @@ def test_enabled_ops( ops_enabled = [bool(x) for x in ops_enabled] assert RMSNorm(1024).enabled() == ops_enabled[0] - assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0] + assert op_registry["rms_norm"].enabled() == ops_enabled[0] assert SiluAndMul().enabled() == ops_enabled[1] - assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1] + assert op_registry["silu_and_mul"].enabled() == ops_enabled[1] assert GeluAndMul().enabled() == ops_enabled[2] - assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2] + assert op_registry["gelu_and_mul"].enabled() == ops_enabled[2] # If registered, subclasses should follow their own name assert Relu3().enabled() == ops_enabled[3] - assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3] + assert op_registry["relu3"].enabled() == ops_enabled[3] # Unregistered subclass class SiluAndMul2(SiluAndMul): diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 81ba544b4813..6fe252fa27ee 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -11,6 +11,86 @@ logger = init_logger(__name__) +# Dictionary of all custom ops (classes, indexed by registered name). +# To check if an op with a name is enabled, call .enabled() on the class. +# Examples: +# - MyOp.enabled() +# - op_registry["my_op"].enabled() +op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {} +op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {} + + +class PluggableLayer(nn.Module): + """ + Base class for pluggable layers. + + A PluggableLayer is a *module-composing* abstraction: it may instantiate other + ``torch.nn.Module`` objects as sub-layers, and its functionality depends on + these sub-layers following a generalized invocation sequence. Also, it is stateful + and may hold parameters or buffers. + + Unlike :class:`CustomOp`, PluggableLayer does NOT provide per-platform + ``forward_*`` dispatch. Instead, it supports out-of-tree (OOT) replacement + of the entire layer class at instantiation time, allowing customized + initialization and submodule composition. + """ + + def __new__(cls, *args, **kwargs): + try: + layer_class_name = cls.__name__ + except AttributeError: + raise TypeError( + f"Cannot instantiate '{cls.__name__}': its 'name' attribute " + f"was not set, possibly because it was not decorated with " + f"@PluggableLayer.register, or it's the PluggableLayer itself." + ) from None + + if layer_class_name not in op_registry_oot: + layer_cls_to_instantiate = cls + else: + layer_cls_to_instantiate = op_registry_oot[layer_class_name] + logger.debug( + "Instantiating pluggable layer: %s using %s", + layer_class_name, + str(layer_cls_to_instantiate), + ) + return super().__new__(layer_cls_to_instantiate) + + # Decorator to register pluggable layers. + @classmethod + def register(cls, name: str): + def decorator(op_cls): + assert name not in op_registry, f"Duplicate op name: {name}" + op_cls.name = name + op_registry[name] = op_cls + return op_cls + + return decorator + + # Decorator to register out-of-tree(oot) pluggable layers. + # For OOT pluggable layers: + # if in-tree layer class is registered with an oot_custom_layer, + # the oot_custom_layer will be used instead. + @classmethod + def register_oot(cls, _decorated_layer_cls=None, name: str | None = None): + def decorator(layer_cls): + reg_name = name if name is not None else cls.__name__ + assert reg_name not in op_registry_oot, f"Duplicate layer name: {reg_name}" + layer_cls.name = reg_name + op_registry_oot[reg_name] = layer_cls + return layer_cls + + if _decorated_layer_cls is None: + # Called with parentheses: @PluggableLayer.register_oot() + # or @PluggableLayer.register_oot(name="...") + return decorator + elif isinstance(_decorated_layer_cls, type): # Check if it's a class + # Called without parentheses: @PluggableLayer.register_oot + return decorator(_decorated_layer_cls) + else: + raise TypeError("Decorator can only be applied to classes.") + + class CustomOp(nn.Module): """ Base class for custom ops. @@ -27,10 +107,10 @@ def __new__(cls, *args, **kwargs): f"@CustomOp.register, or it's the CustomOp base class itself." ) from None - if op_name not in cls.op_registry_oot: + if op_name not in op_registry_oot: op_cls_to_instantiate = cls else: - op_cls_to_instantiate = cls.op_registry_oot[op_name] + op_cls_to_instantiate = op_registry_oot[op_name] logger.debug( "Instantiating custom op: %s using %s", op_name, @@ -150,21 +230,13 @@ def default_on() -> bool: return not count_none > 0 or count_all > 0 - # Dictionary of all custom ops (classes, indexed by registered name). - # To check if an op with a name is enabled, call .enabled() on the class. - # Examples: - # - MyOp.enabled() - # - op_registry["my_op"].enabled() - op_registry: dict[str, type["CustomOp"]] = {} - op_registry_oot: dict[str, type["CustomOp"]] = {} - # Decorator to register custom ops. @classmethod def register(cls, name: str): def decorator(op_cls): - assert name not in cls.op_registry, f"Duplicate op name: {name}" + assert name not in op_registry, f"Duplicate op name: {name}" op_cls.name = name - cls.op_registry[name] = op_cls + op_registry[name] = op_cls return op_cls return decorator @@ -182,9 +254,9 @@ def decorator(op_cls): def register_oot(cls, _decorated_op_cls=None, name: str | None = None): def decorator(op_cls): reg_name = name if name is not None else cls.__name__ - assert reg_name not in cls.op_registry_oot, f"Duplicate op name: {reg_name}" + assert reg_name not in op_registry_oot, f"Duplicate op name: {reg_name}" op_cls.name = reg_name - cls.op_registry_oot[reg_name] = op_cls + op_registry_oot[reg_name] = op_cls return op_cls if _decorated_op_cls is None: diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 65541d2a485a..2549f1221f36 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -6,7 +6,7 @@ from vllm.attention.layer import MLAAttention from vllm.config import CacheConfig -from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.quantization import QuantizationConfig @@ -30,13 +30,13 @@ class MLAModules: # --8<-- [start:multi_head_latent_attention] -@CustomOp.register("multi_head_latent_attention") -class MultiHeadLatentAttentionWrapper(CustomOp): - """MLA layer registered as CustomOp to allow OOT backends to add +@PluggableLayer.register("multi_head_latent_attention") +class MultiHeadLatentAttentionWrapper(PluggableLayer): + """Pluggable MLA layer which allows OOT backends to add custom implementations of the outer MLA layer (including rope & o_proj). - Note that currently MLA ignores the enable/disable mechanism of CustomOp - because there is only one in-tree implementation in forward_native. - TODO: implement this with a new PluggableLayer mechanism. + Note that currently oot platforms can still use CustomOp.register_oot to + replace MLA layer entirly, although we use PluggableLayer to register + this layer now. This class takes positions and hidden_states as input. The input tensors can either contain prefill tokens or decode tokens. @@ -110,7 +110,7 @@ def __init__( self.prefix = prefix - def forward_native( + def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, @@ -174,6 +174,3 @@ def forward_native( ) return self.o_proj(attn_out)[0] - - def forward_cuda(self, *args, **kwargs): - return self.forward_native(*args, **kwargs)