-
-
Notifications
You must be signed in to change notification settings - Fork 14.7k
Use aiter triton fused_add_rmsnorm_pad for gpt-oss #30976
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
ProExpertProg
merged 20 commits into
vllm-project:main
from
ROCm:fused_aiter_triton_rmsnorm_pad
Jan 28, 2026
Merged
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
1bdfd52
squash into one (now working) commit
Rohan138 1ba632a
fix comment
Rohan138 da95aa1
fix comment
Rohan138 7ef50bc
move aiter triton gemm above skinny splitKrc gemm; add hidden_size pa…
Rohan138 6adba53
Add optional fusion pass
Rohan138 8ba717a
fix inputs to rmsnormpad pattern
Rohan138 8ba2d7f
Enable fuse_norm_padding
Rohan138 9bc90ce
Use AITER RMSNorm env var instead
Rohan138 77e7fa1
fix pre-commit
Rohan138 fe14e16
fix test; rename fusion
Rohan138 aa9709d
Make pass conditional on AITER Triton GEMM
Rohan138 521ca35
Move rocm_unquantized_gemm into the fusion; drop F.pad
Rohan138 a6b6cfa
Merge branch 'main' into fused_aiter_triton_rmsnorm_pad
Rohan138 2cdce68
drop num_local experts, just use a dummy shape
Rohan138 8fe47e0
Add unit test
Rohan138 1228560
fix lint
Rohan138 de0be00
add layers to fuse_act_padding test
Rohan138 88a805c
move import into test to fix CI
Rohan138 ea638bc
fix lint
Rohan138 71254a6
Merge branch 'main' into fused_aiter_triton_rmsnorm_pad
Rohan138 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,131 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| import vllm.config | ||
| from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops | ||
| from vllm.compilation.noop_elimination import NoOpEliminationPass | ||
| from vllm.compilation.post_cleanup import PostCleanupPass | ||
| from vllm.config import ( | ||
| CompilationConfig, | ||
| CompilationMode, | ||
| ModelConfig, | ||
| PassConfig, | ||
| VllmConfig, | ||
| ) | ||
| from vllm.model_executor.layers.layernorm import RMSNorm | ||
| from vllm.model_executor.layers.utils import rocm_unquantized_gemm | ||
|
|
||
| from .backend import TestBackend | ||
|
|
||
|
|
||
| class TestModel(torch.nn.Module): | ||
| def __init__( | ||
| self, | ||
| num_layers: int, | ||
| hidden_size: int, | ||
| num_local_experts: int, | ||
| x_pad_to_multiple: int, | ||
| ): | ||
| super().__init__() | ||
| self.num_layers = num_layers | ||
| self.hidden_size = hidden_size | ||
| self.x_pad_to_multiple = x_pad_to_multiple | ||
| self.pad_dim = x_pad_to_multiple - (hidden_size % x_pad_to_multiple) | ||
|
|
||
| self.norm = [RMSNorm(hidden_size, eps=1e-5) for _ in range(num_layers)] | ||
| self.router = [ | ||
| torch.nn.Linear(hidden_size, num_local_experts) for _ in range(4) | ||
| ] | ||
|
|
||
| def forward(self, x): | ||
| # avoid having graph input be an arg to a pattern directly | ||
| x = resid = torch.relu(x) | ||
| all_router_logits = [] | ||
| for layer in range(self.num_layers): | ||
| x = x[:, : self.hidden_size] | ||
| x, resid = self.norm[layer](x, resid) | ||
| router_logits = rocm_unquantized_gemm( | ||
| self, x, self.router[layer].weight, self.router[layer].bias | ||
| ) | ||
| x = torch.nn.functional.pad( | ||
| x, (0, self.pad_dim), mode="constant", value=0.0 | ||
| ) | ||
| all_router_logits.append(router_logits) | ||
|
|
||
| return x, resid, *all_router_logits | ||
|
|
||
| def ops_in_model_before(self): | ||
| return [ | ||
| rocm_aiter_ops.get_rmsnorm_fused_add_op(), | ||
| torch.ops.aten.constant_pad_nd, | ||
| ] | ||
|
|
||
| def ops_in_model_after(self): | ||
| return [rocm_aiter_ops.get_triton_add_rmsnorm_pad_op()] | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("dtype", [torch.bfloat16]) | ||
| @pytest.mark.parametrize("num_layers", [3]) | ||
| @pytest.mark.parametrize("hidden_size", [2880]) | ||
| @pytest.mark.parametrize("num_local_experts", [128]) | ||
| @pytest.mark.parametrize("x_pad_to_multiple", [256]) | ||
| @pytest.mark.skipif( | ||
| not is_aiter_found_and_supported(), | ||
| reason="Only test on ROCm with AITER installed and supported", | ||
| ) | ||
| def test_fuse_act_padding( | ||
| dtype: torch.dtype, | ||
| num_layers: int, | ||
| hidden_size: int, | ||
| num_local_experts: int, | ||
| x_pad_to_multiple: int, | ||
| monkeypatch: pytest.MonkeyPatch, | ||
| ): | ||
| vllm_config = VllmConfig( | ||
| model_config=ModelConfig(dtype=dtype), | ||
| compilation_config=CompilationConfig( | ||
| mode=CompilationMode.VLLM_COMPILE, | ||
| custom_ops=["+rms_norm"], | ||
| pass_config=PassConfig(fuse_act_padding=True, eliminate_noops=True), | ||
| ), | ||
| ) | ||
|
|
||
| with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m: | ||
| from vllm.compilation.rocm_aiter_fusion import ( | ||
| RocmAiterTritonAddRMSNormPadFusionPass, | ||
| ) | ||
|
|
||
| torch.set_default_device("cuda") | ||
| torch.set_default_dtype(dtype) | ||
| torch.manual_seed(1) | ||
|
|
||
| m.setenv("VLLM_ROCM_USE_AITER", "1") | ||
| rocm_aiter_ops.refresh_env_variables() | ||
|
|
||
| fusion_pass = RocmAiterTritonAddRMSNormPadFusionPass(vllm_config) | ||
| passes = [ | ||
| NoOpEliminationPass(vllm_config), | ||
| fusion_pass, | ||
| PostCleanupPass(vllm_config), | ||
| ] | ||
| backend = TestBackend(*passes) | ||
| model = TestModel(num_layers, hidden_size, num_local_experts, x_pad_to_multiple) | ||
|
|
||
| x = torch.rand(1, hidden_size) | ||
| torch._dynamo.mark_dynamic(x, 0) | ||
|
|
||
| outputs_unfused = model(x) | ||
|
|
||
| model_fused = torch.compile(model, backend=backend) | ||
| outputs_fused = model_fused(x) | ||
|
|
||
| torch.testing.assert_close(outputs_unfused, outputs_fused) | ||
|
|
||
| assert fusion_pass.matched_count == num_layers | ||
|
|
||
| backend.check_before_ops(model.ops_in_model_before()) | ||
| backend.check_after_ops(model.ops_in_model_after()) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.