diff --git a/docs/design/custom_op.md b/docs/design/custom_op.md index 3f4934b15699..13c2915abe8f 100644 --- a/docs/design/custom_op.md +++ b/docs/design/custom_op.md @@ -8,6 +8,15 @@ This document will introduce how CustomOp works in vLLM and how to implement a n `CustomOp` manages two dictionaries of all custom ops (i.e., op classes, indexed by registered name) in its class, for vLLM and OOT plugins respectively. +??? code + + ```python + class CustomOp(nn.Module): + + op_registry: dict[str, type["CustomOp"]] = {} + op_registry_oot: dict[str, type["CustomOp"]] = {} + ``` + We can use `@CustomOp.register("op_name")` to register an op class to the `CustomOp` system. After this, the `op_name` and its class will be added into the `op_registry` dictionary. In addition, We can also register an OOT op by `@CustomOp.register_oot("op_name")`. We will introduce this mechanism in detail later. When a `CustomOp` is called (i.e., call its `forward()` method), if it is enabled (i.e., with `--compilation_config.custom_ops '["+op_name"]'`), it will automatically dispatch the forward method to the appropriate backend according to `current_platform`. Otherwise (i.e., it is disabled), it will only call the `forward_native()` method to use PyTorch-native implementation of this forward method. diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 316caf06b29c..8ee1b1a37ca6 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -11,7 +11,7 @@ get_cached_compilation_config, set_current_vllm_config, ) -from vllm.model_executor.custom_op import CustomOp, op_registry +from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import ( GeluAndMul, ReLUSquaredActivation, @@ -98,17 +98,17 @@ def test_enabled_ops( ops_enabled = [bool(x) for x in ops_enabled] assert RMSNorm(1024).enabled() == ops_enabled[0] - assert op_registry["rms_norm"].enabled() == ops_enabled[0] + assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0] assert SiluAndMul().enabled() == ops_enabled[1] - assert op_registry["silu_and_mul"].enabled() == ops_enabled[1] + assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1] assert GeluAndMul().enabled() == ops_enabled[2] - assert op_registry["gelu_and_mul"].enabled() == ops_enabled[2] + assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2] # If registered, subclasses should follow their own name assert Relu3().enabled() == ops_enabled[3] - assert op_registry["relu3"].enabled() == ops_enabled[3] + assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3] # Unregistered subclass class SiluAndMul2(SiluAndMul): diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 6fe252fa27ee..81ba544b4813 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -11,86 +11,6 @@ logger = init_logger(__name__) -# Dictionary of all custom ops (classes, indexed by registered name). -# To check if an op with a name is enabled, call .enabled() on the class. -# Examples: -# - MyOp.enabled() -# - op_registry["my_op"].enabled() -op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {} -op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {} - - -class PluggableLayer(nn.Module): - """ - Base class for pluggable layers. - - A PluggableLayer is a *module-composing* abstraction: it may instantiate other - ``torch.nn.Module`` objects as sub-layers, and its functionality depends on - these sub-layers following a generalized invocation sequence. Also, it is stateful - and may hold parameters or buffers. - - Unlike :class:`CustomOp`, PluggableLayer does NOT provide per-platform - ``forward_*`` dispatch. Instead, it supports out-of-tree (OOT) replacement - of the entire layer class at instantiation time, allowing customized - initialization and submodule composition. - """ - - def __new__(cls, *args, **kwargs): - try: - layer_class_name = cls.__name__ - except AttributeError: - raise TypeError( - f"Cannot instantiate '{cls.__name__}': its 'name' attribute " - f"was not set, possibly because it was not decorated with " - f"@PluggableLayer.register, or it's the PluggableLayer itself." - ) from None - - if layer_class_name not in op_registry_oot: - layer_cls_to_instantiate = cls - else: - layer_cls_to_instantiate = op_registry_oot[layer_class_name] - logger.debug( - "Instantiating pluggable layer: %s using %s", - layer_class_name, - str(layer_cls_to_instantiate), - ) - return super().__new__(layer_cls_to_instantiate) - - # Decorator to register pluggable layers. - @classmethod - def register(cls, name: str): - def decorator(op_cls): - assert name not in op_registry, f"Duplicate op name: {name}" - op_cls.name = name - op_registry[name] = op_cls - return op_cls - - return decorator - - # Decorator to register out-of-tree(oot) pluggable layers. - # For OOT pluggable layers: - # if in-tree layer class is registered with an oot_custom_layer, - # the oot_custom_layer will be used instead. - @classmethod - def register_oot(cls, _decorated_layer_cls=None, name: str | None = None): - def decorator(layer_cls): - reg_name = name if name is not None else cls.__name__ - assert reg_name not in op_registry_oot, f"Duplicate layer name: {reg_name}" - layer_cls.name = reg_name - op_registry_oot[reg_name] = layer_cls - return layer_cls - - if _decorated_layer_cls is None: - # Called with parentheses: @PluggableLayer.register_oot() - # or @PluggableLayer.register_oot(name="...") - return decorator - elif isinstance(_decorated_layer_cls, type): # Check if it's a class - # Called without parentheses: @PluggableLayer.register_oot - return decorator(_decorated_layer_cls) - else: - raise TypeError("Decorator can only be applied to classes.") - - class CustomOp(nn.Module): """ Base class for custom ops. @@ -107,10 +27,10 @@ def __new__(cls, *args, **kwargs): f"@CustomOp.register, or it's the CustomOp base class itself." ) from None - if op_name not in op_registry_oot: + if op_name not in cls.op_registry_oot: op_cls_to_instantiate = cls else: - op_cls_to_instantiate = op_registry_oot[op_name] + op_cls_to_instantiate = cls.op_registry_oot[op_name] logger.debug( "Instantiating custom op: %s using %s", op_name, @@ -230,13 +150,21 @@ def default_on() -> bool: return not count_none > 0 or count_all > 0 + # Dictionary of all custom ops (classes, indexed by registered name). + # To check if an op with a name is enabled, call .enabled() on the class. + # Examples: + # - MyOp.enabled() + # - op_registry["my_op"].enabled() + op_registry: dict[str, type["CustomOp"]] = {} + op_registry_oot: dict[str, type["CustomOp"]] = {} + # Decorator to register custom ops. @classmethod def register(cls, name: str): def decorator(op_cls): - assert name not in op_registry, f"Duplicate op name: {name}" + assert name not in cls.op_registry, f"Duplicate op name: {name}" op_cls.name = name - op_registry[name] = op_cls + cls.op_registry[name] = op_cls return op_cls return decorator @@ -254,9 +182,9 @@ def decorator(op_cls): def register_oot(cls, _decorated_op_cls=None, name: str | None = None): def decorator(op_cls): reg_name = name if name is not None else cls.__name__ - assert reg_name not in op_registry_oot, f"Duplicate op name: {reg_name}" + assert reg_name not in cls.op_registry_oot, f"Duplicate op name: {reg_name}" op_cls.name = reg_name - op_registry_oot[reg_name] = op_cls + cls.op_registry_oot[reg_name] = op_cls return op_cls if _decorated_op_cls is None: diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 2549f1221f36..65541d2a485a 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -6,7 +6,7 @@ from vllm.attention.layer import MLAAttention from vllm.config import CacheConfig -from vllm.model_executor.custom_op import PluggableLayer +from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.quantization import QuantizationConfig @@ -30,13 +30,13 @@ class MLAModules: # --8<-- [start:multi_head_latent_attention] -@PluggableLayer.register("multi_head_latent_attention") -class MultiHeadLatentAttentionWrapper(PluggableLayer): - """Pluggable MLA layer which allows OOT backends to add +@CustomOp.register("multi_head_latent_attention") +class MultiHeadLatentAttentionWrapper(CustomOp): + """MLA layer registered as CustomOp to allow OOT backends to add custom implementations of the outer MLA layer (including rope & o_proj). - Note that currently oot platforms can still use CustomOp.register_oot to - replace MLA layer entirly, although we use PluggableLayer to register - this layer now. + Note that currently MLA ignores the enable/disable mechanism of CustomOp + because there is only one in-tree implementation in forward_native. + TODO: implement this with a new PluggableLayer mechanism. This class takes positions and hidden_states as input. The input tensors can either contain prefill tokens or decode tokens. @@ -110,7 +110,7 @@ def __init__( self.prefix = prefix - def forward( + def forward_native( self, positions: torch.Tensor, hidden_states: torch.Tensor, @@ -174,3 +174,6 @@ def forward( ) return self.o_proj(attn_out)[0] + + def forward_cuda(self, *args, **kwargs): + return self.forward_native(*args, **kwargs)