-
-
Notifications
You must be signed in to change notification settings - Fork 15.9k
[vLLM IR] Port activations (gelu) to IR op #40135
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
Alex-ai-future
wants to merge
8
commits into
vllm-project:main
Choose a base branch
from
Alex-ai-future:feature/gelu-on-vllm-ir
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
59ce53a
[vLLM IR] Add GELU activation functions IR support
Alex-ai-future cf81e01
refactor(activation): unify GELU implementations using ir.ops
Alex-ai-future 3a6eb0c
rm
Alex-ai-future f083b6e
test(ir/ops): refactor GELU tests with pytest parameterization
Alex-ai-future 74518c9
move
Alex-ai-future dedd74c
feat(config): add priority lists for GELU variants to kernel config
Alex-ai-future 843d41b
test(compile/ir): remove basic gelu tests and add refactor todo
Alex-ai-future 44c7037
roll back
Alex-ai-future File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,87 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """Tests for GELU activation function IR ops.""" | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| import vllm.kernels # noqa: F401 | ||
| from vllm import ir | ||
| from vllm.platforms import current_platform | ||
|
|
||
|
|
||
| @pytest.mark.skipif( | ||
| not current_platform.is_cuda_alike(), | ||
| reason="vllm_c kernels only supported on CUDA-alike platforms", | ||
| ) | ||
| def test_gelu_registration(): | ||
| """Test that GELU ops have correct provider registration.""" | ||
| expected = { | ||
| "native": True, | ||
| "vllm_c": True, | ||
| } | ||
|
|
||
| for op_name in ["gelu_new", "gelu_fast", "quick_gelu"]: | ||
| gelu_op = getattr(ir.ops, op_name) | ||
| actual = { | ||
| provider: impl.supported | ||
| for provider, impl in gelu_op.impls.items() | ||
| } | ||
| assert actual == expected, f"{op_name} has incorrect registration" | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) | ||
| @pytest.mark.parametrize("shape", [(1, 8), (4, 16), (17, 64)]) | ||
| class TestGeluOps: | ||
| """Tests for GELU IR ops.""" | ||
|
|
||
| @pytest.mark.parametrize( | ||
| "gelu_op_name", ["gelu_new", "gelu_fast", "quick_gelu"] | ||
| ) | ||
| def test_native_semantics(self, gelu_op_name, dtype, shape): | ||
| """Test that IR op matches native implementation.""" | ||
| gelu_op = getattr(ir.ops, gelu_op_name) | ||
| x = torch.randn(*shape, dtype=dtype) | ||
|
|
||
| out_ir = gelu_op(x) | ||
| out_native = gelu_op.impls["native"].impl_fn(x) | ||
|
|
||
| torch.testing.assert_close(out_ir, out_native) | ||
|
|
||
| @pytest.mark.parametrize("provider", ["vllm_c"]) | ||
| @pytest.mark.parametrize( | ||
| "gelu_op_name", ["gelu_new", "gelu_fast", "quick_gelu"] | ||
| ) | ||
| def test_vllm_c_impl(self, gelu_op_name, provider, dtype, shape): | ||
| """Test vllm_c implementation correctness.""" | ||
| gelu_op = getattr(ir.ops, gelu_op_name) | ||
| impl = gelu_op.impls[provider] | ||
|
|
||
| if not impl.supported: | ||
| pytest.skip(f"{provider} impl not supported on this platform") | ||
|
|
||
| x = torch.randn( | ||
| *shape, dtype=dtype, device=current_platform.device_type | ||
| ) | ||
| out_impl = impl.impl_fn(x) | ||
| out_native = gelu_op.impls["native"].impl_fn(x) | ||
|
|
||
| torch.testing.assert_close(out_impl, out_native) | ||
|
|
||
| # Verify dispatch matches direct call | ||
| with gelu_op.set_priority([provider, "native"]): | ||
| out_dispatch = gelu_op(x) | ||
| torch.testing.assert_close(out_dispatch, out_impl, rtol=0.0, atol=0.0) | ||
|
|
||
| @pytest.mark.parametrize( | ||
| "gelu_op_name", ["gelu_new", "gelu_fast", "quick_gelu"] | ||
| ) | ||
| def test_torch_opcheck(self, gelu_op_name, dtype, shape): | ||
| """Test torch op integration.""" | ||
| gelu_op = getattr(ir.ops, gelu_op_name) | ||
| x = torch.randn(*shape, dtype=dtype) | ||
|
|
||
| with gelu_op.set_priority(["native"]): | ||
| torch.library.opcheck( | ||
| torch.ops.vllm_ir.__getattr__(gelu_op_name), (x,) | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| from .activation import gelu_fast, gelu_new, quick_gelu | ||
| from .layernorm import rms_norm | ||
|
|
||
| __all__ = ["rms_norm"] | ||
| __all__ = ["rms_norm", "gelu_new", "gelu_fast", "quick_gelu"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| import math | ||
|
|
||
| import torch | ||
| from torch import Tensor | ||
|
|
||
| from ..op import register_op | ||
|
|
||
| c_gelu_new = math.sqrt(2.0 / math.pi) | ||
|
|
||
|
|
||
| @register_op | ||
| def gelu_new(x: Tensor) -> Tensor: | ||
| """ | ||
| New GELU activation function. | ||
|
|
||
| Formula: 0.5 * x * (1.0 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) | ||
|
|
||
| This is the GELU approximation used in GPT-2 and other transformer models. | ||
| """ | ||
| return 0.5 * x * (1.0 + torch.tanh(c_gelu_new * (x + 0.044715 * torch.pow(x, 3.0)))) | ||
|
|
||
|
|
||
| @register_op | ||
| def gelu_fast(x: Tensor) -> Tensor: | ||
| """ | ||
| Fast GELU activation function. | ||
|
|
||
| Formula: 0.5 * x * (1.0 + tanh(x * 0.7978845608 * (1.0 + 0.044715 * x^2))) | ||
|
|
||
| A computationally efficient approximation of the GELU function. | ||
| """ | ||
| return 0.5 * x * ( | ||
| 1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)) | ||
| ) | ||
|
|
||
|
|
||
| @register_op | ||
| def quick_gelu(x: Tensor) -> Tensor: | ||
| """ | ||
| Quick GELU activation function. | ||
|
|
||
| Formula: x * sigmoid(1.702 * x) | ||
|
|
||
| A fast approximation of GELU used in various transformer models. | ||
| Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90 | ||
| """ | ||
| return x * torch.sigmoid(1.702 * x) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |
| get_tensor_model_parallel_rank, | ||
| get_tensor_model_parallel_world_size, | ||
| ) | ||
| from vllm import ir | ||
| from vllm.logger import init_logger | ||
| from vllm.model_executor.custom_op import CustomOp | ||
| from vllm.model_executor.utils import set_weight_attrs | ||
|
|
@@ -410,22 +411,13 @@ class NewGELU(CustomOp): | |
|
|
||
| def __init__(self): | ||
| super().__init__() | ||
| if ( | ||
| current_platform.is_cuda_alike() | ||
| or current_platform.is_cpu() | ||
| or current_platform.is_xpu() | ||
| ): | ||
| self.op = torch.ops._C.gelu_new | ||
|
|
||
| def forward_native(self, x: torch.Tensor) -> torch.Tensor: | ||
| """PyTorch-native implementation equivalent to forward().""" | ||
| c = math.sqrt(2.0 / math.pi) | ||
| return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0)))) | ||
| return ir.ops.gelu_new(x) | ||
|
|
||
| def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: | ||
| out = torch.empty_like(x) | ||
| self.op(out, x) | ||
| return out | ||
| return ir.ops.gelu_new(x) | ||
|
|
||
| def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: | ||
| return self.forward_cuda(x) | ||
|
|
@@ -438,21 +430,13 @@ class FastGELU(CustomOp): | |
|
|
||
| def __init__(self): | ||
| super().__init__() | ||
| if ( | ||
| current_platform.is_cuda_alike() | ||
| or current_platform.is_cpu() | ||
| or current_platform.is_xpu() | ||
| ): | ||
| self.op = torch.ops._C.gelu_fast | ||
|
|
||
| def forward_native(self, x: torch.Tensor) -> torch.Tensor: | ||
| """PyTorch-native implementation equivalent to forward().""" | ||
| return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) | ||
| return ir.ops.gelu_fast(x) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: | ||
| out = torch.empty_like(x) | ||
| self.op(out, x) | ||
| return out | ||
| return ir.ops.gelu_fast(x) | ||
|
|
||
| def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: | ||
| return self.forward_cuda(x) | ||
|
|
@@ -466,21 +450,13 @@ class QuickGELU(CustomOp): | |
|
|
||
| def __init__(self): | ||
| super().__init__() | ||
| if ( | ||
| current_platform.is_cuda_alike() | ||
| or current_platform.is_cpu() | ||
| or current_platform.is_xpu() | ||
| ): | ||
| self.op = torch.ops._C.gelu_quick | ||
|
|
||
| def forward_native(self, x: torch.Tensor) -> torch.Tensor: | ||
| """PyTorch-native implementation equivalent to forward().""" | ||
| return x * torch.sigmoid(1.702 * x) | ||
| return ir.ops.quick_gelu(x) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: | ||
| out = torch.empty_like(x) | ||
| self.op(out, x) | ||
| return out | ||
| return ir.ops.quick_gelu(x) | ||
|
|
||
| def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: | ||
| return self.forward_cuda(x) | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
forward_nativemethod is intended to be a reference implementation using standard PyTorch operations. By callingir.ops.gelu_new(x), it now uses the IR dispatch mechanism, which may select an optimized kernel (likevllm_c) depending on the environment and priority settings. This makes correctness tests (such as those intests/kernels/core/test_activation.py) tautological, as they compare the optimized output against itself. To maintain the integrity of these tests,forward_nativeshould explicitly call the native implementation.