diff --git a/benchmarks/kernels/benchmark_activation.py b/benchmarks/kernels/benchmark_activation.py
index fbe5f744148e..bb66e5d088ef 100644
--- a/benchmarks/kernels/benchmark_activation.py
+++ b/benchmarks/kernels/benchmark_activation.py
@@ -7,7 +7,7 @@
 import torch
 
 import vllm.model_executor.layers.activation  # noqa F401
-from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.custom_op import op_registry
 from vllm.triton_utils import triton
 from vllm.utils.argparse_utils import FlexibleArgumentParser
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed
@@ -33,14 +33,14 @@ def benchmark_activation(
     torch.set_default_device(device)
 
     if func_name == "gelu_and_mul":
-        layer = CustomOp.op_registry[func_name](approximate="none")
+        layer = op_registry[func_name](approximate="none")
     elif func_name == "gelu_and_mul_tanh":
-        layer = CustomOp.op_registry["gelu_and_mul"](approximate="tanh")
+        layer = op_registry["gelu_and_mul"](approximate="tanh")
     elif func_name == "fatrelu_and_mul":
         threshold = 0.5
-        layer = CustomOp.op_registry[func_name](threshold)
+        layer = op_registry[func_name](threshold)
     else:
-        layer = CustomOp.op_registry[func_name]()
+        layer = op_registry[func_name]()
 
     x = torch.randn(num_tokens, dim, dtype=dtype, device=device)
     compiled_layer = torch.compile(layer.forward_native)
diff --git a/docs/design/custom_op.md b/docs/design/custom_op.md
index 13c2915abe8f..3f4934b15699 100644
--- a/docs/design/custom_op.md
+++ b/docs/design/custom_op.md
@@ -8,15 +8,6 @@ This document will introduce how CustomOp works in vLLM and how to implement a n
 
-`CustomOp` manages two dictionaries of all custom ops (i.e., op classes, indexed by registered name) in its class, for vLLM and OOT plugins respectively.
+The `custom_op` module maintains two dictionaries of all custom ops (i.e., op classes, indexed by registered name), `op_registry` and `op_registry_oot`, for vLLM and OOT plugins respectively.
 
-??? code
-
-    ```python
-    class CustomOp(nn.Module):
-
-        op_registry: dict[str, type["CustomOp"]] = {}
-        op_registry_oot: dict[str, type["CustomOp"]] = {}
-    ```
-
-We can use `@CustomOp.register("op_name")` to register an op class to the `CustomOp` system. After this, the `op_name` and its class will be added into the `op_registry` dictionary. In addition, We can also register an OOT op by `@CustomOp.register_oot("op_name")`. We will introduce this mechanism in detail later.
+We can use `@CustomOp.register("op_name")` to register an op class to the `CustomOp` system. After this, the `op_name` and its class will be added into the `op_registry` dictionary. In addition, we can also register an OOT op with `@CustomOp.register_oot(name="op_name")`. We will introduce this mechanism in detail later.
 
 When a `CustomOp` is called (i.e., call its `forward()` method), if it is enabled (i.e., with `--compilation_config.custom_ops '["+op_name"]'`), it will automatically dispatch the forward method to the appropriate backend according to `current_platform`. Otherwise (i.e., it is disabled), it will only call the `forward_native()` method to use PyTorch-native implementation of this forward method.
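For reference, the snippet below (not part of the diff) sketches how the registration flow described in `docs/design/custom_op.md` looks once `op_registry` is module-level, as introduced above. `MyReLU` and `"my_relu"` are hypothetical names used purely for illustration.

```python
import torch

from vllm.model_executor.custom_op import CustomOp, op_registry


# Registering adds the class to the module-level op_registry under "my_relu".
@CustomOp.register("my_relu")
class MyReLU(CustomOp):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # PyTorch-native fallback, used whenever the custom op is disabled.
        return torch.relu(x)


# The class is now reachable by its registered name without going through a
# CustomOp class attribute; .enabled() can be queried on it once a vLLM
# compilation config is in place (e.g., op_registry["my_relu"].enabled()).
assert op_registry["my_relu"] is MyReLU
```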
diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py
index 7763be0cb5bf..9c6cc4dabb26 100644
--- a/tests/kernels/utils.py
+++ b/tests/kernels/utils.py
@@ -13,7 +13,7 @@
 from torch._prims_common import TensorLikeType
 
 from tests.kernels.quant_utils import native_w8a8_block_matmul
-from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.custom_op import op_registry
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
 from vllm.utils.torch_utils import make_tensor_with_pad
@@ -883,7 +883,7 @@ def torch_experts(
     f32 = torch.float32
 
-    act = CustomOp.op_registry[activation]
+    act = op_registry[activation]
 
     for i in range(num_experts):
         mask = topk_ids == i
diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py
index 8ee1b1a37ca6..316caf06b29c 100644
--- a/tests/model_executor/test_enabled_custom_ops.py
+++ b/tests/model_executor/test_enabled_custom_ops.py
@@ -11,7 +11,7 @@
     get_cached_compilation_config,
     set_current_vllm_config,
 )
-from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.custom_op import CustomOp, op_registry
 from vllm.model_executor.layers.activation import (
     GeluAndMul,
     ReLUSquaredActivation,
@@ -98,17 +98,17 @@ def test_enabled_ops(
     ops_enabled = [bool(x) for x in ops_enabled]
 
     assert RMSNorm(1024).enabled() == ops_enabled[0]
-    assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
+    assert op_registry["rms_norm"].enabled() == ops_enabled[0]
 
     assert SiluAndMul().enabled() == ops_enabled[1]
-    assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
+    assert op_registry["silu_and_mul"].enabled() == ops_enabled[1]
 
     assert GeluAndMul().enabled() == ops_enabled[2]
-    assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
+    assert op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
 
     # If registered, subclasses should follow their own name
     assert Relu3().enabled() == ops_enabled[3]
-    assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
+    assert op_registry["relu3"].enabled() == ops_enabled[3]
 
     # Unregistered subclass
     class SiluAndMul2(SiluAndMul):
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 8ce31ad18133..0608b3755eab 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -1032,13 +1032,13 @@ def custom_op_log_check(self):
             # check if op name exists in model
             op_name = op[1:]
             if op_name not in all_ops_in_model:
-                from vllm.model_executor.custom_op import CustomOp
+                from vllm.model_executor.custom_op import op_registry
 
                 # Does op exist at all or is it just not present in this model?
                 # Note: Only imported op classes appear in the registry.
                 missing_str = (
                     "doesn't exist (or wasn't imported/registered)"
-                    if op_name not in CustomOp.op_registry
+                    if op_name not in op_registry
                     else "not present in model"
                 )
diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py
index 81ba544b4813..6fe252fa27ee 100644
--- a/vllm/model_executor/custom_op.py
+++ b/vllm/model_executor/custom_op.py
@@ -11,6 +11,86 @@
 logger = init_logger(__name__)
 
 
+# Dictionary of all custom ops (classes, indexed by registered name).
+# To check if an op with a name is enabled, call .enabled() on the class.
+# Examples:
+# - MyOp.enabled()
+# - op_registry["my_op"].enabled()
+op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
+op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
+
+
+class PluggableLayer(nn.Module):
+    """
+    Base class for pluggable layers.
+
+    A PluggableLayer is a *module-composing* abstraction: it may instantiate other
+    ``torch.nn.Module`` objects as sub-layers, and its functionality comes from
+    invoking these sub-layers in a generic sequence. It is stateful and may hold
+    parameters or buffers.
+
+    Unlike :class:`CustomOp`, PluggableLayer does NOT provide per-platform
+    ``forward_*`` dispatch. Instead, it supports out-of-tree (OOT) replacement
+    of the entire layer class at instantiation time, allowing customized
+    initialization and submodule composition.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        try:
+            layer_class_name = cls.__name__
+        except AttributeError:
+            raise TypeError(
+                f"Cannot instantiate '{cls.__name__}': its 'name' attribute "
+                f"was not set, possibly because it was not decorated with "
+                f"@PluggableLayer.register, or it's the PluggableLayer itself."
+            ) from None
+
+        if layer_class_name not in op_registry_oot:
+            layer_cls_to_instantiate = cls
+        else:
+            layer_cls_to_instantiate = op_registry_oot[layer_class_name]
+            logger.debug(
+                "Instantiating pluggable layer: %s using %s",
+                layer_class_name,
+                str(layer_cls_to_instantiate),
+            )
+        return super().__new__(layer_cls_to_instantiate)
+
+    # Decorator to register pluggable layers.
+    @classmethod
+    def register(cls, name: str):
+        def decorator(op_cls):
+            assert name not in op_registry, f"Duplicate op name: {name}"
+            op_cls.name = name
+            op_registry[name] = op_cls
+            return op_cls
+
+        return decorator
+
+    # Decorator to register out-of-tree (OOT) pluggable layers.
+    # For OOT pluggable layers:
+    # if an in-tree layer class is registered with an OOT custom layer,
+    # the OOT custom layer will be used instead.
+    @classmethod
+    def register_oot(cls, _decorated_layer_cls=None, name: str | None = None):
+        def decorator(layer_cls):
+            reg_name = name if name is not None else cls.__name__
+            assert reg_name not in op_registry_oot, f"Duplicate layer name: {reg_name}"
+            layer_cls.name = reg_name
+            op_registry_oot[reg_name] = layer_cls
+            return layer_cls
+
+        if _decorated_layer_cls is None:
+            # Called with parentheses: @PluggableLayer.register_oot()
+            # or @PluggableLayer.register_oot(name="...")
+            return decorator
+        elif isinstance(_decorated_layer_cls, type):  # Check if it's a class
+            # Called without parentheses: @PluggableLayer.register_oot
+            return decorator(_decorated_layer_cls)
+        else:
+            raise TypeError("Decorator can only be applied to classes.")
+
+
 class CustomOp(nn.Module):
     """
     Base class for custom ops.
@@ -27,10 +107,10 @@ def __new__(cls, *args, **kwargs):
                 f"@CustomOp.register, or it's the CustomOp base class itself."
             ) from None
 
-        if op_name not in cls.op_registry_oot:
+        if op_name not in op_registry_oot:
             op_cls_to_instantiate = cls
         else:
-            op_cls_to_instantiate = cls.op_registry_oot[op_name]
+            op_cls_to_instantiate = op_registry_oot[op_name]
             logger.debug(
                 "Instantiating custom op: %s using %s",
                 op_name,
@@ -150,21 +230,13 @@ def default_on() -> bool:
 
         return not count_none > 0 or count_all > 0
 
-    # Dictionary of all custom ops (classes, indexed by registered name).
-    # To check if an op with a name is enabled, call .enabled() on the class.
-    # Examples:
-    # - MyOp.enabled()
-    # - op_registry["my_op"].enabled()
-    op_registry: dict[str, type["CustomOp"]] = {}
-    op_registry_oot: dict[str, type["CustomOp"]] = {}
-
     # Decorator to register custom ops.
     @classmethod
     def register(cls, name: str):
         def decorator(op_cls):
-            assert name not in cls.op_registry, f"Duplicate op name: {name}"
+            assert name not in op_registry, f"Duplicate op name: {name}"
             op_cls.name = name
-            cls.op_registry[name] = op_cls
+            op_registry[name] = op_cls
             return op_cls
 
         return decorator
@@ -182,9 +254,9 @@ def decorator(op_cls):
     def register_oot(cls, _decorated_op_cls=None, name: str | None = None):
         def decorator(op_cls):
             reg_name = name if name is not None else cls.__name__
-            assert reg_name not in cls.op_registry_oot, f"Duplicate op name: {reg_name}"
+            assert reg_name not in op_registry_oot, f"Duplicate op name: {reg_name}"
             op_cls.name = reg_name
-            cls.op_registry_oot[reg_name] = op_cls
+            op_registry_oot[reg_name] = op_cls
             return op_cls
 
         if _decorated_op_cls is None:
diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py
index 65541d2a485a..2549f1221f36 100644
--- a/vllm/model_executor/layers/mla.py
+++ b/vllm/model_executor/layers/mla.py
@@ -6,7 +6,7 @@
 from vllm.attention.layer import MLAAttention
 from vllm.config import CacheConfig
-from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.custom_op import PluggableLayer
 from vllm.model_executor.layers.quantization import QuantizationConfig
 
 
@@ -30,13 +30,13 @@ class MLAModules:
 
 
 # --8<-- [start:multi_head_latent_attention]
-@CustomOp.register("multi_head_latent_attention")
-class MultiHeadLatentAttentionWrapper(CustomOp):
-    """MLA layer registered as CustomOp to allow OOT backends to add
+@PluggableLayer.register("multi_head_latent_attention")
+class MultiHeadLatentAttentionWrapper(PluggableLayer):
+    """Pluggable MLA layer which allows OOT backends to add
     custom implementations of the outer MLA layer (including rope & o_proj).
-    Note that currently MLA ignores the enable/disable mechanism of CustomOp
-    because there is only one in-tree implementation in forward_native.
-    TODO: implement this with a new PluggableLayer mechanism.
+    Note that currently OOT platforms can still use CustomOp.register_oot to
+    replace the MLA layer entirely, even though this layer is now registered
+    via PluggableLayer.
 
     This class takes positions and hidden_states as input.
     The input tensors can either contain prefill tokens or decode tokens.
@@ -110,7 +110,7 @@ def __init__(
 
         self.prefix = prefix
 
-    def forward_native(
+    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
@@ -174,6 +174,3 @@ def forward_native(
         )
 
         return self.o_proj(attn_out)[0]
-
-    def forward_cuda(self, *args, **kwargs):
-        return self.forward_native(*args, **kwargs)
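As a usage illustration of the new mechanism (not part of this diff), an out-of-tree platform could swap in its own MLA wrapper roughly as sketched below. `MyPlatformMLAWrapper` is a hypothetical class name, and the sketch assumes the registration key must equal the in-tree class name, since `PluggableLayer.__new__` looks up `op_registry_oot` by `cls.__name__`.

```python
import torch

from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper


# Registering under the in-tree class name makes PluggableLayer.__new__ return
# this subclass whenever MultiHeadLatentAttentionWrapper(...) is instantiated.
@PluggableLayer.register_oot(name="MultiHeadLatentAttentionWrapper")
class MyPlatformMLAWrapper(MultiHeadLatentAttentionWrapper):
    def forward(
        self, positions: torch.Tensor, hidden_states: torch.Tensor
    ) -> torch.Tensor:
        # A platform-specific outer-MLA implementation (rope, o_proj fusion,
        # etc.) would go here; this sketch just defers to the in-tree forward.
        return super().forward(positions, hidden_states)
```

Because the OOT class subclasses the in-tree wrapper, `__init__` still runs with the original constructor arguments after `__new__` substitutes the class.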