diff --git a/tests/kernels/core/test_activation.py b/tests/kernels/core/test_activation.py index e7de7731286f..10c98f0d8a34 100644 --- a/tests/kernels/core/test_activation.py +++ b/tests/kernels/core/test_activation.py @@ -6,9 +6,11 @@ import pytest import torch -from tests.kernels.allclose_default import get_default_atol, get_default_rtol -from tests.kernels.utils import opcheck +from tests.kernels.allclose_default import get_default_atol +from tests.kernels.utils import assert_pluggable_layer_calls_ir_op, opcheck +from vllm import ir from vllm.model_executor.layers.activation import ( + GELU, FastGELU, FatreluAndMul, GeluAndMul, @@ -20,6 +22,7 @@ SwigluStepAndMul, swiglustep_and_mul_triton, ) +from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -31,13 +34,18 @@ ] +# TODO: This test validates kernel output correctness, which overlaps with +# the ir.ops routing tests. As custom ops are migrated to PluggableLayer, +# remove cases from here and add them to ACTIVATION_LAYER_CONFIGS. +# Eventually, test_activation_ir_op_routing will be the single source of truth. +@pytest.mark.skipif( + current_platform.is_cpu(), reason="CUDA activation tests require GPU platform" +) @pytest.mark.parametrize( "activation", [ "silu_and_mul", "mul_and_silu", - "gelu", - "gelu_tanh", "fatrelu", "swigluoai_and_mul", "swiglustep_and_mul", @@ -67,12 +75,6 @@ def test_act_and_mul( if activation == "mul_and_silu": layer = MulAndSilu() fn = torch.ops._C.mul_and_silu - elif activation == "gelu": - layer = GeluAndMul(approximate="none") - fn = torch.ops._C.gelu_and_mul - elif activation == "gelu_tanh": - layer = GeluAndMul(approximate="tanh") - fn = torch.ops._C.gelu_tanh_and_mul elif activation == "fatrelu": threshold = random.uniform(0, 1) layer = FatreluAndMul(threshold) @@ -116,39 +118,32 @@ def _get_rtol(output) -> float: opcheck(fn, (out, x)) -@pytest.mark.parametrize( - "activation", - [ - (FastGELU, torch.ops._C.gelu_fast), - (NewGELU, torch.ops._C.gelu_new), - (QuickGELU, torch.ops._C.gelu_quick), - ], -) -@pytest.mark.parametrize("num_tokens", NUM_TOKENS) -@pytest.mark.parametrize("d", D) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) +# Activation layer test configs (currently only GELU variants) +ACTIVATION_LAYER_CONFIGS = [ + ("GELU", GELU, ir.ops.gelu, {}), + ("GeluAndMul_none", GeluAndMul, ir.ops.gelu_and_mul, {"approximate": "none"}), + ("GeluAndMul_tanh", GeluAndMul, ir.ops.gelu_and_mul, {"approximate": "tanh"}), + ("NewGELU", NewGELU, ir.ops.gelu_new, {}), + ("FastGELU", FastGELU, ir.ops.gelu_fast, {}), + ("QuickGELU", QuickGELU, ir.ops.quick_gelu, {}), +] + + +# This test verifies that PluggableLayer instances correctly dispatch to +# their corresponding ir.ops functions. Add new activations to +# ACTIVATION_LAYER_CONFIGS to automatically get test coverage. +# The test only verifies routing correctness, not output values, +# so we don't need multiple parameter combinations. +@pytest.mark.parametrize("name, layer_cls, ir_op, kwargs", ACTIVATION_LAYER_CONFIGS) @torch.inference_mode() -def test_activation( +def test_activation_ir_op_routing( default_vllm_config, - activation: type[torch.nn.Module], - num_tokens: int, - d: int, - dtype: torch.dtype, - seed: int, - device: str, + name: str, + layer_cls: type, + ir_op, + kwargs: dict, ) -> None: - set_random_seed(seed) - torch.set_default_device(device) - x = torch.randn(num_tokens, d, dtype=dtype) - layer = activation[0]() - fn = activation[1] - out = layer(x) - ref_out = layer.forward_native(x) - torch.testing.assert_close( - out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out) - ) - - out = torch.empty_like(x) - opcheck(fn, (out, x)) + """Test PluggableLayer instances call the corresponding ir.ops function.""" + x = torch.randn(83, 512) + layer = layer_cls(**kwargs) + assert_pluggable_layer_calls_ir_op(layer, ir_op, x) diff --git a/tests/kernels/core/test_cpu_activation.py b/tests/kernels/core/test_cpu_activation.py index 40b5f0454683..3f664b80e210 100644 --- a/tests/kernels/core/test_cpu_activation.py +++ b/tests/kernels/core/test_cpu_activation.py @@ -1,25 +1,25 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# NOTE: This file contains CPU-only activation kernel tests. +# GELU-related tests have been removed because GeluAndMul and other GELU +# activations are now PluggableLayers that call ir.ops. The vllm_c kernel +# implementations for GELU are CUDA-only (marked with CUDA_ALIKE in vllm_c.py). +# As GELU and other activations gain CPU ir.ops implementations, add routing +# tests to test_activation.py's test_activation_ir_op_routing. + import pytest import torch from tests.kernels.allclose_default import get_default_atol, get_default_rtol from tests.kernels.utils import opcheck -from vllm.platforms import CpuArchEnum, current_platform +from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed if not current_platform.is_cpu(): pytest.skip("skipping CPU-only tests", allow_module_level=True) -from vllm.model_executor.layers.activation import ( - GELU, - FastGELU, - GeluAndMul, - NewGELU, - QuickGELU, - SiluAndMul, -) +from vllm.model_executor.layers.activation import SiluAndMul DTYPES = [torch.bfloat16, torch.float32] NUM_TOKENS = [7, 83] @@ -31,8 +31,6 @@ ("activation_cls", "fn"), [ (SiluAndMul, torch.ops._C.silu_and_mul), - (GeluAndMul, torch.ops._C.gelu_and_mul), - (GeluAndMul, torch.ops._C.gelu_tanh_and_mul), ], ) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -40,7 +38,7 @@ @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() -def test_cpu_act_and_mul( +def test_cpu_silu_and_mul( default_vllm_config, activation_cls: type[torch.nn.Module], fn: object, @@ -62,50 +60,4 @@ def test_cpu_act_and_mul( output_shape = x.shape[:-1] + (x.shape[-1] // 2,) raw_out = torch.empty(output_shape, dtype=x.dtype, device=x.device) - opcheck(fn, (raw_out, x)) - - -@pytest.mark.parametrize( - ("activation_cls", "fn", "op_args"), - [ - (NewGELU, torch.ops._C.gelu_new, ()), - (FastGELU, torch.ops._C.gelu_fast, ()), - (QuickGELU, torch.ops._C.gelu_quick, ()), - pytest.param( - GELU, - getattr(torch.ops._C, "activation_lut_bf16", None), - ("gelu",), - marks=pytest.mark.skipif( - current_platform.get_cpu_architecture() != CpuArchEnum.ARM, - reason="activation_lut_bf16 is only built on Arm CPU", - ), - ), - ], -) -@pytest.mark.parametrize("num_tokens", NUM_TOKENS) -@pytest.mark.parametrize("d", D) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@torch.inference_mode() -def test_cpu_unary_activation( - default_vllm_config, - activation_cls: type[torch.nn.Module], - fn: object, - op_args: tuple[str, ...], - num_tokens: int, - d: int, - dtype: torch.dtype, - seed: int, -) -> None: - set_random_seed(seed) - x = torch.randn(num_tokens, d, dtype=dtype) - layer = activation_cls() - out = layer(x) - ref_out = layer.forward_native(x) - torch.testing.assert_close( - out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out) - ) - # gelu with activation_lut_bf16 only makes sense for BF16 - if not (activation_cls is GELU and dtype != torch.bfloat16): - raw_out = torch.empty_like(x) - opcheck(fn, (raw_out, x, *op_args)) + opcheck(fn, (raw_out, x)) \ No newline at end of file diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 12ff3830c21f..6591e7e3de41 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1085,3 +1085,36 @@ def group_broadcast(t, shape): output = output + bias return output + + +def assert_pluggable_layer_calls_ir_op( + layer: torch.nn.Module, + ir_op, + x: torch.Tensor, +) -> None: + """Assert that a PluggableLayer calls the corresponding ir.ops function. + + This is a test utility for verifying that PluggableLayer instances + (like GELU, GeluAndMul, etc.) correctly dispatch to their corresponding + ir.ops functions. + + The mock is applied to the underlying ir_op._inner_call which is the + Python entry point for eager execution. + + Args: + layer: PluggableLayer instance to test + ir_op: The corresponding ir.ops function (e.g., ir.ops.gelu) + x: Input tensor + """ + # Disable torch wrapping to directly test ir_op._inner_call + from vllm.ir.op import enable_torch_wrap + + with enable_torch_wrap(False): + with patch.object(ir_op, "_inner_call", wraps=ir_op._inner_call) as mock_op: + _ = layer(x) + mock_op.assert_called_once() + # Verify the call was made with the input tensor + call_args = mock_op.call_args + assert call_args[0][0].shape == x.shape, ( + f"Expected input shape {x.shape}, got {call_args[0][0].shape}" + ) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index fc4f6f6b63f9..59294dd8a4b3 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -13,7 +13,6 @@ ) from vllm.model_executor.custom_op import CustomOp, op_registry from vllm.model_executor.layers.activation import ( - GeluAndMul, ReLUSquaredActivation, SiluAndMul, ) @@ -44,39 +43,39 @@ class Relu3(ReLUSquaredActivation): [ # Default values based on compile level # - All by default (no Inductor compilation) - (None, 0, "eager", [True] * 4, True), - (None, 1, "eager", [True] * 4, True), - (None, 2, "eager", [True] * 4, True), - (None, 3, "eager", [True] * 4, True), + (None, 0, "eager", [True] * 3, True), + (None, 1, "eager", [True] * 3, True), + (None, 2, "eager", [True] * 3, True), + (None, 3, "eager", [True] * 3, True), # - None by default (with Inductor) - (None, 0, "inductor", [True] * 4, True), + (None, 0, "inductor", [True] * 3, True), # - None by default (with Inductor) - (None, 1, "inductor", [False] * 4, False), - (None, 2, "inductor", [False] * 4, False), - (None, 3, "inductor", [False] * 4, False), + (None, 1, "inductor", [False] * 3, False), + (None, 2, "inductor", [False] * 3, False), + (None, 3, "inductor", [False] * 3, False), # Explicitly enabling/disabling # # Default: all # # All but SiluAndMul - ("+rms_norm,-silu_and_mul", 0, "inductor", [1, 0, 1, 1], True), + ("+rms_norm,-silu_and_mul", 0, "inductor", [1, 0, 1], True), # Only ReLU3 - ("none,-rms_norm,+relu3", 1, "eager", [0, 0, 0, 1], False), + ("none,-rms_norm,+relu3", 1, "eager", [0, 0, 1], False), # All but SiluAndMul - ("all,-silu_and_mul", 2, "inductor", [1, 0, 1, 1], True), + ("all,-silu_and_mul", 2, "inductor", [1, 0, 1], True), # All but ReLU3 (even if ReLU2 is on) - ("-relu3,+relu2", 3, "eager", [1, 1, 1, 0], True), + ("-relu3,+relu2", 3, "eager", [1, 1, 0], True), # RMSNorm and SiluAndMul - ("none,-relu3,+rms_norm,+silu_and_mul", 3, "eager", [1, 1, 0, 0], False), + ("none,-relu3,+rms_norm,+silu_and_mul", 3, "eager", [1, 1, 0], False), # All but RMSNorm - ("-rms_norm", 3, "eager", [0, 1, 1, 1], True), + ("-rms_norm", 3, "eager", [0, 1, 1], True), # # Default: none # # Only ReLU3 - ("none,+relu3", 3, "inductor", [0, 0, 0, 1], False), + ("none,+relu3", 3, "inductor", [0, 0, 1], False), # All but RMSNorm - ("all,-rms_norm", 3, "inductor", [0, 1, 1, 1], True), + ("all,-rms_norm", 3, "inductor", [0, 1, 1], True), ], ) def test_enabled_ops( @@ -104,12 +103,9 @@ def test_enabled_ops( assert SiluAndMul().enabled() == ops_enabled[1] assert op_registry["silu_and_mul"].enabled() == ops_enabled[1] - assert GeluAndMul().enabled() == ops_enabled[2] - assert op_registry["gelu_and_mul"].enabled() == ops_enabled[2] - # If registered, subclasses should follow their own name - assert Relu3().enabled() == ops_enabled[3] - assert op_registry["relu3"].enabled() == ops_enabled[3] + assert Relu3().enabled() == ops_enabled[2] + assert op_registry["relu3"].enabled() == ops_enabled[2] # Unregistered subclass class SiluAndMul2(SiluAndMul): diff --git a/vllm/config/kernel.py b/vllm/config/kernel.py index 8d8e37a0549a..bc175edb2009 100644 --- a/vllm/config/kernel.py +++ b/vllm/config/kernel.py @@ -31,6 +31,21 @@ class IrOpPriorityConfig: rms_norm: list[str] = Field(default_factory=list) """Priority list for vllm.ir.ops.rms_norm""" + gelu: list[str] = Field(default_factory=list) + """Priority list for vllm.ir.ops.gelu""" + + gelu_and_mul: list[str] = Field(default_factory=list) + """Priority list for vllm.ir.ops.gelu_and_mul""" + + gelu_new: list[str] = Field(default_factory=list) + """Priority list for vllm.ir.ops.gelu_new""" + + gelu_fast: list[str] = Field(default_factory=list) + """Priority list for vllm.ir.ops.gelu_fast""" + + quick_gelu: list[str] = Field(default_factory=list) + """Priority list for vllm.ir.ops.quick_gelu""" + def compute_hash(self) -> str: """ Produces a hash unique to the pass configuration. diff --git a/vllm/ir/ops/__init__.py b/vllm/ir/ops/__init__.py index 25ad27c8a078..9af8b4d79607 100644 --- a/vllm/ir/ops/__init__.py +++ b/vllm/ir/ops/__init__.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from .activation import gelu, gelu_and_mul, gelu_fast, gelu_new, quick_gelu from .layernorm import rms_norm -__all__ = ["rms_norm"] +__all__ = ["rms_norm", "gelu", "gelu_and_mul", "gelu_new", "gelu_fast", "quick_gelu"] diff --git a/vllm/ir/ops/activation.py b/vllm/ir/ops/activation.py new file mode 100644 index 000000000000..8555b22699e6 --- /dev/null +++ b/vllm/ir/ops/activation.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math + +import torch +from torch import Tensor + +from ..op import register_op + +c_gelu_new = math.sqrt(2.0 / math.pi) + + +@register_op +def gelu(x: Tensor, approximate: str = "none") -> Tensor: + """ + Standard GELU activation function. + + Formula: x * 0.5 * (1.0 + erf(x / sqrt(2))) + + Args: + x: Input tensor + approximate: If 'tanh', use tanh approximation for faster computation + """ + from vllm.platforms import CpuArchEnum, current_platform + + # ARM NEON LUT optimization for BF16 + is_arm_bf16 = ( + current_platform.is_cpu() + and current_platform.get_cpu_architecture() == CpuArchEnum.ARM + and x.dtype == torch.bfloat16 + and x.is_contiguous() + and hasattr(torch.ops._C, "activation_lut_bf16") + ) + if is_arm_bf16: + out = torch.empty_like(x) + torch.ops._C.activation_lut_bf16(out, x, "gelu") + return out + + import torch.nn.functional as F + return F.gelu(x, approximate=approximate) + + +@gelu.register_input_generator +def _gelu_input_generator( + num_tokens: int, hidden_size: int, dtype: torch.dtype +) -> tuple: + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + return (x,) + + +@register_op +def gelu_and_mul(x: Tensor, approximate: str = "none") -> Tensor: + """ + GeGLU activation function: GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2. + + This is used in models with gated feed-forward networks. + + Shapes: + x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d) + return: (num_tokens, d) or (batch_size, seq_len, d) + """ + import torch.nn.functional as F + d = x.shape[-1] // 2 + return F.gelu(x[..., :d], approximate=approximate) * x[..., d:] + + +@gelu_and_mul.register_input_generator +def _gelu_and_mul_input_generator( + num_tokens: int, hidden_size: int, dtype: torch.dtype +) -> tuple: + # hidden_size must be even for gelu_and_mul (it's 2*d) + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + return (x,) + + +@register_op +def gelu_new(x: Tensor) -> Tensor: + """ + New GELU activation function. + + Formula: 0.5 * x * (1.0 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + + This is the GELU approximation used in GPT-2 and other transformer models. + """ + return 0.5 * x * (1.0 + torch.tanh(c_gelu_new * (x + 0.044715 * torch.pow(x, 3.0)))) + + +@gelu_new.register_input_generator +def _gelu_new_input_generator( + num_tokens: int, hidden_size: int, dtype: torch.dtype +) -> tuple: + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + return (x,) + + +@register_op +def gelu_fast(x: Tensor) -> Tensor: + """ + Fast GELU activation function. + + Formula: 0.5 * x * (1.0 + tanh(x * 0.7978845608 * (1.0 + 0.044715 * x^2))) + + A computationally efficient approximation of the GELU function. + """ + return 0.5 * x * ( + 1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)) + ) + + +@gelu_fast.register_input_generator +def _gelu_fast_input_generator( + num_tokens: int, hidden_size: int, dtype: torch.dtype +) -> tuple: + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + return (x,) + + +@register_op +def quick_gelu(x: Tensor) -> Tensor: + """ + Quick GELU activation function. + + Formula: x * sigmoid(1.702 * x) + + A fast approximation of GELU used in various transformer models. + Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90 + """ + return x * torch.sigmoid(1.702 * x) + + +@quick_gelu.register_input_generator +def _quick_gelu_input_generator( + num_tokens: int, hidden_size: int, dtype: torch.dtype +) -> tuple: + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + return (x,) diff --git a/vllm/kernels/vllm_c.py b/vllm/kernels/vllm_c.py index 124b02e4e27a..27bdf03a724d 100644 --- a/vllm/kernels/vllm_c.py +++ b/vllm/kernels/vllm_c.py @@ -31,3 +31,63 @@ def rms_norm( output = torch.empty(x.shape, device=x.device, dtype=x.dtype) torch.ops._C.rms_norm(output, x, weight, epsilon) return output + + +@ir.ops.gelu_and_mul.register_impl("vllm_c", supported=CUDA_ALIKE) +def gelu_and_mul(x: Tensor, approximate: str = "none") -> Tensor: + """ + GeGLU activation function: GELU(x[:d]) * x[d:] using vLLM C++ kernel. + + Shapes: + x: (num_tokens, 2 * d) + return: (num_tokens, d) + """ + d = x.shape[-1] // 2 + output_shape = x.shape[:-1] + (d,) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + if approximate == "none": + torch.ops._C.gelu_and_mul(out, x) + else: + torch.ops._C.gelu_tanh_and_mul(out, x) + return out + + +#=================== +# GELU Activations +#=================== + + +@ir.ops.gelu_new.register_impl("vllm_c", supported=CUDA_ALIKE) +def gelu_new(x: Tensor) -> Tensor: + """ + New GELU activation function using vLLM C++ kernel. + + Formula: 0.5 * x * (1.0 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + """ + out = torch.empty_like(x) + torch.ops._C.gelu_new(out, x) + return out + + +@ir.ops.gelu_fast.register_impl("vllm_c", supported=CUDA_ALIKE) +def gelu_fast(x: Tensor) -> Tensor: + """ + Fast GELU activation function using vLLM C++ kernel. + + Formula: 0.5 * x * (1.0 + tanh(x * 0.7978845608 * (1.0 + 0.044715 * x^2))) + """ + out = torch.empty_like(x) + torch.ops._C.gelu_fast(out, x) + return out + + +@ir.ops.quick_gelu.register_impl("vllm_c", supported=CUDA_ALIKE) +def quick_gelu(x: Tensor) -> Tensor: + """ + Quick GELU activation function using vLLM C++ kernel. + + Formula: x * sigmoid(1.702 * x) + """ + out = torch.empty_like(x) + torch.ops._C.gelu_quick(out, x) + return out diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index e2b70b771a12..69a6776d9157 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -13,8 +13,9 @@ get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) +from vllm import ir from vllm.logger import init_logger -from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.custom_op import CustomOp, PluggableLayer from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import CpuArchEnum, current_platform from vllm.triton_utils import tl, triton @@ -248,36 +249,20 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: # --8<-- [start:gelu] -@CustomOp.register("gelu") -class GELU(CustomOp): +@PluggableLayer.register("gelu") +class GELU(PluggableLayer): # --8<-- [end:gelu] def __init__(self): super().__init__() - if current_platform.get_cpu_architecture() == CpuArchEnum.ARM and hasattr( - torch.ops._C, "activation_lut_bf16" - ): - self.op = torch.ops._C.activation_lut_bf16 - else: - self.op = None - - def forward_native(self, x: torch.Tensor) -> torch.Tensor: - return F.gelu(x, approximate="none") - def forward_cpu(self, x: torch.Tensor) -> torch.Tensor: - if self.op and x.dtype == torch.bfloat16 and x.is_contiguous(): - out = torch.empty_like(x) - self.op(out, x, "gelu") - return out - return self.forward_native(x) - - def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: - return self.forward_native(x) + def forward(self, x: torch.Tensor) -> torch.Tensor: + return ir.ops.gelu(x) # --8<-- [start:gelu_and_mul] -@CustomOp.register("gelu_and_mul") -class GeluAndMul(CustomOp): +@PluggableLayer.register("gelu_and_mul") +class GeluAndMul(PluggableLayer): """An activation function for GeGLU. The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2. @@ -294,15 +279,6 @@ def __init__(self, approximate: str = "none"): self.approximate = approximate if approximate not in ("none", "tanh"): raise ValueError(f"Unknown approximate mode: {approximate}") - if ( - current_platform.is_cuda_alike() - or current_platform.is_cpu() - or current_platform.is_xpu() - ): - if approximate == "none": - self.op = torch.ops._C.gelu_and_mul - elif approximate == "tanh": - self.op = torch.ops._C.gelu_tanh_and_mul if current_platform.is_rocm() and approximate == "tanh": logger.warning_once( "[ROCm] PyTorch's native GELU with tanh approximation is unstable " @@ -310,24 +286,8 @@ def __init__(self, approximate: str = "none"): "approximation. The custom kernel implementation is unaffected." ) - def forward_native(self, x: torch.Tensor) -> torch.Tensor: - """PyTorch-native implementation equivalent to forward().""" - # TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile - approximate = self.approximate - if current_platform.is_rocm() and approximate == "tanh": - approximate = "none" - d = x.shape[-1] // 2 - return F.gelu(x[..., :d], approximate=approximate) * x[..., d:] - - def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: - d = x.shape[-1] // 2 - output_shape = x.shape[:-1] + (d,) - out = torch.empty(output_shape, dtype=x.dtype, device=x.device) - self.op(out, x) - return out - - def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: - return self.forward_cuda(x) + def forward(self, x: torch.Tensor) -> torch.Tensor: + return ir.ops.gelu_and_mul(x, approximate=self.approximate) def extra_repr(self) -> str: return f"approximate={repr(self.approximate)}" @@ -404,86 +364,40 @@ def extra_repr(self) -> str: # --8<-- [start:gelu_new] -@CustomOp.register("gelu_new") -class NewGELU(CustomOp): +@PluggableLayer.register("gelu_new") +class NewGELU(PluggableLayer): # --8<-- [end:gelu_new] def __init__(self): super().__init__() - if ( - current_platform.is_cuda_alike() - or current_platform.is_cpu() - or current_platform.is_xpu() - ): - self.op = torch.ops._C.gelu_new - - def forward_native(self, x: torch.Tensor) -> torch.Tensor: - """PyTorch-native implementation equivalent to forward().""" - c = math.sqrt(2.0 / math.pi) - return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0)))) - def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: - out = torch.empty_like(x) - self.op(out, x) - return out - - def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: - return self.forward_cuda(x) + def forward(self, x: torch.Tensor) -> torch.Tensor: + return ir.ops.gelu_new(x) # --8<-- [start:gelu_fast] -@CustomOp.register("gelu_fast") -class FastGELU(CustomOp): +@PluggableLayer.register("gelu_fast") +class FastGELU(PluggableLayer): # --8<-- [end:gelu_fast] def __init__(self): super().__init__() - if ( - current_platform.is_cuda_alike() - or current_platform.is_cpu() - or current_platform.is_xpu() - ): - self.op = torch.ops._C.gelu_fast - - def forward_native(self, x: torch.Tensor) -> torch.Tensor: - """PyTorch-native implementation equivalent to forward().""" - return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) - - def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: - out = torch.empty_like(x) - self.op(out, x) - return out - def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: - return self.forward_cuda(x) + def forward(self, x: torch.Tensor) -> torch.Tensor: + return ir.ops.gelu_fast(x) # --8<-- [start:quick_gelu] -@CustomOp.register("quick_gelu") -class QuickGELU(CustomOp): +@PluggableLayer.register("quick_gelu") +class QuickGELU(PluggableLayer): # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90 # --8<-- [end:quick_gelu] def __init__(self): super().__init__() - if ( - current_platform.is_cuda_alike() - or current_platform.is_cpu() - or current_platform.is_xpu() - ): - self.op = torch.ops._C.gelu_quick - def forward_native(self, x: torch.Tensor) -> torch.Tensor: - """PyTorch-native implementation equivalent to forward().""" - return x * torch.sigmoid(1.702 * x) - - def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: - out = torch.empty_like(x) - self.op(out, x) - return out - - def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: - return self.forward_cuda(x) + def forward(self, x: torch.Tensor) -> torch.Tensor: + return ir.ops.quick_gelu(x) # --8<-- [start:relu2]