Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions tests/compile/passes/ir/test_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,76 @@ def test_lowering_rms_norm(rms_provider, default_vllm_config):

torch.testing.assert_close(output_unlowered, output)
torch.testing.assert_close(output_unlowered, output2)


# ===================
# GELU Lowering Tests
# ===================

# TODO: Refactor lowering tests into a modular, parameterized framework
# that can test all IR ops uniformly. Current approach has separate tests
# for each op (rms_norm, gelu_*), which leads to code duplication.
# A better design would be:
# 1. Define a generic test_ir_op_lowering(op_name, providers) function
# 2. Use pytest.mark.parametrize to test all ops with their providers
# 3. Keep only special-case tests (e.g., variance_size fallback, mixed ops)
# Example:
# @pytest.mark.parametrize("op_name,providers", [
# ("rms_norm", ["vllm_c", "native"]),
# ("gelu_new", ["vllm_c", "native"]),
# ...
# ])
# def test_ir_op_lowering_basic(op_name, providers): ...


class GeluMixedModel(nn.Module):
    """Model that chains GELU IR ops around an RMSNorm IR op."""

    def __init__(self, hidden_size=16):
        super().__init__()
        self.hidden_size = hidden_size
        # All-ones weight so rms_norm applies no per-channel scaling.
        self.weight = torch.ones(hidden_size, dtype=torch.bfloat16)

    def forward(self, x):
        # gelu_new -> rms_norm -> gelu_fast: exercises a mixed-op graph.
        hidden = ops.gelu_new(x)
        normed = ops.rms_norm(hidden, self.weight, 1e-5)
        return ops.gelu_fast(normed)


def test_lowering_gelu_mixed_model(default_vllm_config):
    """Test lowering with mixed GELU and RMSNorm ops.

    Compiles GeluMixedModel with and without the IR lowering pass and
    checks that (a) the pass selected an implementation for every op and
    (b) the lowered and unlowered outputs agree within bfloat16 tolerances.
    """
    torch.set_default_device(current_platform.device_type)

    lowering_pass = VllmIRLoweringPass(get_current_vllm_config())
    backend = TestBackend(lowering_pass)
    backend_unlowered = TestBackend()

    model = GeluMixedModel()
    x = torch.randn(8, 16, dtype=torch.bfloat16)

    # One priority list shared by all three ops: prefer the vllm_c kernel
    # on CUDA-alike platforms, with "native" as the fallback. (The previous
    # per-op construction duplicated "native" on non-CUDA platforms and
    # built rms_norm's priority through a separate inline ternary.)
    priority = (
        ["vllm_c", "native"] if current_platform.is_cuda_alike() else ["native"]
    )

    with (
        ops.gelu_new.set_priority(priority),
        ops.gelu_fast.set_priority(priority),
        ops.rms_norm.set_priority(priority),
        ir.enable_torch_wrap(True),
    ):
        compiled_model = torch.compile(model, backend=backend, fullgraph=True)
        compiled_unlowered_model = torch.compile(
            model, backend=backend_unlowered, fullgraph=True
        )
        output = compiled_model(x)
        output_unlowered = compiled_unlowered_model(x)

    # The lowering pass records which implementation it picked per op.
    assert "gelu_new" in lowering_pass.selected_impls
    assert "gelu_fast" in lowering_pass.selected_impls
    assert "rms_norm" in lowering_pass.selected_impls

    # Relaxed tolerances: bfloat16 plus potentially different kernels.
    torch.testing.assert_close(output_unlowered, output, rtol=0.1, atol=0.01)
87 changes: 87 additions & 0 deletions tests/kernels/ir/test_activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for GELU activation function IR ops."""

import pytest
import torch

import vllm.kernels # noqa: F401
from vllm import ir
from vllm.platforms import current_platform


@pytest.mark.skipif(
    not current_platform.is_cuda_alike(),
    reason="vllm_c kernels only supported on CUDA-alike platforms",
)
def test_gelu_registration():
    """Each GELU op must expose supported native and vllm_c providers."""
    expected = {"native": True, "vllm_c": True}

    for op_name in ("gelu_new", "gelu_fast", "quick_gelu"):
        gelu_op = getattr(ir.ops, op_name)
        actual = {
            provider_name: registration.supported
            for provider_name, registration in gelu_op.impls.items()
        }
        assert actual == expected, f"{op_name} has incorrect registration"


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("shape", [(1, 8), (4, 16), (17, 64)])
class TestGeluOps:
    """Tests for the gelu_new, gelu_fast, and quick_gelu IR ops."""

    @pytest.mark.parametrize(
        "gelu_op_name", ["gelu_new", "gelu_fast", "quick_gelu"]
    )
    def test_native_semantics(self, gelu_op_name, dtype, shape):
        """Dispatching through the IR op must match the native impl."""
        gelu_op = getattr(ir.ops, gelu_op_name)
        x = torch.randn(*shape, dtype=dtype)

        out_ir = gelu_op(x)
        out_native = gelu_op.impls["native"].impl_fn(x)

        torch.testing.assert_close(out_ir, out_native)

    @pytest.mark.parametrize("provider", ["vllm_c"])
    @pytest.mark.parametrize(
        "gelu_op_name", ["gelu_new", "gelu_fast", "quick_gelu"]
    )
    def test_vllm_c_impl(self, gelu_op_name, provider, dtype, shape):
        """vllm_c kernel must match native, and dispatch must match it exactly."""
        gelu_op = getattr(ir.ops, gelu_op_name)
        impl = gelu_op.impls[provider]

        if not impl.supported:
            pytest.skip(f"{provider} impl not supported on this platform")

        x = torch.randn(
            *shape, dtype=dtype, device=current_platform.device_type
        )
        out_impl = impl.impl_fn(x)
        out_native = gelu_op.impls["native"].impl_fn(x)

        torch.testing.assert_close(out_impl, out_native)

        # With this provider prioritized, dispatch must be bitwise
        # identical to calling the impl directly (rtol=atol=0).
        with gelu_op.set_priority([provider, "native"]):
            out_dispatch = gelu_op(x)
            torch.testing.assert_close(out_dispatch, out_impl, rtol=0.0, atol=0.0)

    @pytest.mark.parametrize(
        "gelu_op_name", ["gelu_new", "gelu_fast", "quick_gelu"]
    )
    def test_torch_opcheck(self, gelu_op_name, dtype, shape):
        """torch.library.opcheck validates the registered torch op."""
        gelu_op = getattr(ir.ops, gelu_op_name)
        x = torch.randn(*shape, dtype=dtype)

        with gelu_op.set_priority(["native"]):
            # Use getattr() for dynamic attribute lookup instead of
            # calling __getattr__ directly (ruff B009).
            torch.library.opcheck(
                getattr(torch.ops.vllm_ir, gelu_op_name), (x,)
            )
9 changes: 9 additions & 0 deletions vllm/config/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ class IrOpPriorityConfig:
rms_norm: list[str] = Field(default_factory=list)
"""Priority list for vllm.ir.ops.rms_norm"""

gelu_new: list[str] = Field(default_factory=list)
"""Priority list for vllm.ir.ops.gelu_new"""

gelu_fast: list[str] = Field(default_factory=list)
"""Priority list for vllm.ir.ops.gelu_fast"""

quick_gelu: list[str] = Field(default_factory=list)
"""Priority list for vllm.ir.ops.quick_gelu"""

def compute_hash(self) -> str:
"""
Produces a hash unique to the pass configuration.
Expand Down
3 changes: 2 additions & 1 deletion vllm/ir/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .activation import gelu_fast, gelu_new, quick_gelu
from .layernorm import rms_norm

__all__ = ["rms_norm"]
__all__ = ["rms_norm", "gelu_new", "gelu_fast", "quick_gelu"]
49 changes: 49 additions & 0 deletions vllm/ir/ops/activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math

import torch
from torch import Tensor

from ..op import register_op

c_gelu_new = math.sqrt(2.0 / math.pi)


@register_op
def gelu_new(x: Tensor) -> Tensor:
    """
    New GELU activation function (GPT-2 style tanh approximation).

    Formula: 0.5 * x * (1.0 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))

    This is the GELU approximation used in GPT-2 and other transformer models.
    """
    inner = c_gelu_new * (x + 0.044715 * torch.pow(x, 3.0))
    return 0.5 * x * (1.0 + torch.tanh(inner))


@register_op
def gelu_fast(x: Tensor) -> Tensor:
    """
    Fast GELU activation function.

    Formula: 0.5 * x * (1.0 + tanh(x * 0.7978845608 * (1.0 + 0.044715 * x^2)))

    A computationally efficient approximation of the GELU function.
    """
    scaled = x * 0.7978845608 * (1.0 + 0.044715 * x * x)
    return 0.5 * x * (1.0 + torch.tanh(scaled))


@register_op
def quick_gelu(x: Tensor) -> Tensor:
    """
    Quick GELU activation function: x * sigmoid(1.702 * x).

    A fast approximation of GELU used in various transformer models.
    Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    """
    gate = torch.sigmoid(1.702 * x)
    return x * gate
41 changes: 41 additions & 0 deletions vllm/kernels/vllm_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,44 @@ def rms_norm(
output = torch.empty(x.shape, device=x.device, dtype=x.dtype)
torch.ops._C.rms_norm(output, x, weight, epsilon)
return output


# ===================
# GELU Activations
# ===================


@ir.ops.gelu_new.register_impl("vllm_c", supported=CUDA_ALIKE)
def gelu_new(x: Tensor) -> Tensor:
    """
    New GELU activation via the vLLM C++ kernel (torch.ops._C.gelu_new).

    Formula: 0.5 * x * (1.0 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    """
    # The C++ op writes into a caller-provided output buffer.
    result = torch.empty_like(x)
    torch.ops._C.gelu_new(result, x)
    return result


@ir.ops.gelu_fast.register_impl("vllm_c", supported=CUDA_ALIKE)
def gelu_fast(x: Tensor) -> Tensor:
    """
    Fast GELU activation via the vLLM C++ kernel (torch.ops._C.gelu_fast).

    Formula: 0.5 * x * (1.0 + tanh(x * 0.7978845608 * (1.0 + 0.044715 * x^2)))
    """
    # The C++ op writes into a caller-provided output buffer.
    result = torch.empty_like(x)
    torch.ops._C.gelu_fast(result, x)
    return result


@ir.ops.quick_gelu.register_impl("vllm_c", supported=CUDA_ALIKE)
def quick_gelu(x: Tensor) -> Tensor:
    """
    Quick GELU activation via the vLLM C++ kernel (torch.ops._C.gelu_quick).

    Formula: x * sigmoid(1.702 * x)
    """
    # Note: the C++ op is named gelu_quick, not quick_gelu.
    result = torch.empty_like(x)
    torch.ops._C.gelu_quick(result, x)
    return result
38 changes: 7 additions & 31 deletions vllm/model_executor/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm import ir
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
Expand Down Expand Up @@ -410,22 +411,13 @@ class NewGELU(CustomOp):

def __init__(self):
super().__init__()
if (
current_platform.is_cuda_alike()
or current_platform.is_cpu()
or current_platform.is_xpu()
):
self.op = torch.ops._C.gelu_new

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
c = math.sqrt(2.0 / math.pi)
return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))
return ir.ops.gelu_new(x)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The forward_native method is intended to be a reference implementation using standard PyTorch operations. By calling ir.ops.gelu_new(x), it now uses the IR dispatch mechanism, which may select an optimized kernel (like vllm_c) depending on the environment and priority settings. This makes correctness tests (such as those in tests/kernels/core/test_activation.py) tautological, as they compare the optimized output against itself. To maintain the integrity of these tests, forward_native should explicitly call the native implementation.

Suggested change
return ir.ops.gelu_new(x)
return ir.ops.gelu_new.impls["native"].impl_fn(x)


def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
self.op(out, x)
return out
return ir.ops.gelu_new(x)

def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_cuda(x)
Expand All @@ -438,21 +430,13 @@ class FastGELU(CustomOp):

def __init__(self):
super().__init__()
if (
current_platform.is_cuda_alike()
or current_platform.is_cpu()
or current_platform.is_xpu()
):
self.op = torch.ops._C.gelu_fast

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
return ir.ops.gelu_fast(x)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to NewGELU, FastGELU.forward_native should bypass the IR dispatch logic and call the native implementation directly to ensure it remains a valid baseline for correctness verification.

Suggested change
return ir.ops.gelu_fast(x)
return ir.ops.gelu_fast.impls["native"].impl_fn(x)


def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
self.op(out, x)
return out
return ir.ops.gelu_fast(x)

def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_cuda(x)
Expand All @@ -466,21 +450,13 @@ class QuickGELU(CustomOp):

def __init__(self):
super().__init__()
if (
current_platform.is_cuda_alike()
or current_platform.is_cpu()
or current_platform.is_xpu()
):
self.op = torch.ops._C.gelu_quick

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return x * torch.sigmoid(1.702 * x)
return ir.ops.quick_gelu(x)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

To ensure QuickGELU.forward_native remains a reliable reference for testing, it should explicitly invoke the native implementation of the IR op rather than relying on the default dispatch.

Suggested change
return ir.ops.quick_gelu(x)
return ir.ops.quick_gelu.impls["native"].impl_fn(x)


def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
self.op(out, x)
return out
return ir.ops.quick_gelu(x)

def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_cuda(x)
Expand Down
Loading