From cf8c816a4bf3315b253a8af4de7130d78ef60113 Mon Sep 17 00:00:00 2001 From: Icey <1790571317@qq.com> Date: Wed, 8 Apr 2026 03:34:24 +0000 Subject: [PATCH 1/4] mixer2_gated_rms_norm Signed-off-by: Icey <1790571317@qq.com> --- vllm/config/kernel.py | 3 + vllm/ir/ops/__init__.py | 4 +- vllm/ir/ops/layernorm.py | 26 ++++++++ vllm/kernels/__init__.py | 4 +- vllm/kernels/triton/__init__.py | 5 ++ vllm/kernels/triton/ops/__init__.py | 5 ++ vllm/kernels/triton/ops/layernorm.py | 42 +++++++++++++ .../layers/mamba/mamba_mixer2.py | 60 ++++++++----------- .../layers/mamba/ops/layernorm_gated.py | 7 ++- vllm/platforms/cuda.py | 4 +- vllm/platforms/rocm.py | 4 +- 11 files changed, 119 insertions(+), 45 deletions(-) create mode 100644 vllm/kernels/triton/__init__.py create mode 100644 vllm/kernels/triton/ops/__init__.py create mode 100644 vllm/kernels/triton/ops/layernorm.py diff --git a/vllm/config/kernel.py b/vllm/config/kernel.py index f3ffbe4e8b19..a69065edd89b 100644 --- a/vllm/config/kernel.py +++ b/vllm/config/kernel.py @@ -31,6 +31,9 @@ class IrOpPriorityConfig: rms_norm: list[str] = Field(default_factory=list) """Priority list for vllm.ir.ops.rms_norm""" + mixer2_rms_norm_gated: list[str] = Field(default_factory=list) + """Priority list for vllm.ir.ops.mixer2_rms_norm_gated""" + def compute_hash(self) -> str: """ Produces a hash unique to the pass configuration. 
diff --git a/vllm/ir/ops/__init__.py b/vllm/ir/ops/__init__.py index 25ad27c8a078..99228b25097b 100644 --- a/vllm/ir/ops/__init__.py +++ b/vllm/ir/ops/__init__.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .layernorm import rms_norm +from .layernorm import mixer2_rms_norm_gated, rms_norm -__all__ = ["rms_norm"] +__all__ = ["rms_norm", "mixer2_rms_norm_gated"] diff --git a/vllm/ir/ops/layernorm.py b/vllm/ir/ops/layernorm.py index ac0c38a9e4d6..744990d3871e 100644 --- a/vllm/ir/ops/layernorm.py +++ b/vllm/ir/ops/layernorm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch +import torch.nn.functional as F from torch import Tensor from ..op import register_op @@ -19,3 +20,28 @@ def rms_norm( if weight is not None: x = x.to(weight.dtype) * weight return x.to(orig_dtype) + + +@register_op +def mixer2_rms_norm_gated( + x: Tensor, + gate: Tensor, + weight: Tensor | None, + epsilon: float, + group_size: int | None = None, +) -> Tensor: + input_dtype = x.dtype + x = x * F.silu(gate.to(torch.float32)) + if group_size is None: + # Standard RMSNorm: compute variance over the full hidden dimension + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + epsilon) + else: + # Grouped RMSNorm: compute variance independently within each group + *prefix_dims, hidden_dims = x.shape + x_grouped = x.view(*prefix_dims, hidden_dims // group_size, group_size) + variance = x_grouped.pow(2).mean(dim=-1, keepdim=True) + x_grouped = x_grouped * torch.rsqrt(variance + epsilon) + x = x_grouped.view(*prefix_dims, hidden_dims) + + return (weight * x).to(input_dtype) diff --git a/vllm/kernels/__init__.py b/vllm/kernels/__init__.py index 075bc01f3ba3..5940818247c6 100644 --- a/vllm/kernels/__init__.py +++ b/vllm/kernels/__init__.py @@ -2,6 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the 
vLLM project """Kernel implementations for vLLM.""" -from . import aiter_ops, oink_ops, vllm_c, xpu_ops +from . import aiter_ops, oink_ops, triton, vllm_c, xpu_ops -__all__ = ["vllm_c", "aiter_ops", "oink_ops", "xpu_ops"] +__all__ = ["vllm_c", "aiter_ops", "oink_ops", "xpu_ops", "triton"] diff --git a/vllm/kernels/triton/__init__.py b/vllm/kernels/triton/__init__.py new file mode 100644 index 000000000000..32423fb069ed --- /dev/null +++ b/vllm/kernels/triton/__init__.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from . import ops + +__all__ = ["ops"] diff --git a/vllm/kernels/triton/ops/__init__.py b/vllm/kernels/triton/ops/__init__.py new file mode 100644 index 000000000000..a4042744629b --- /dev/null +++ b/vllm/kernels/triton/ops/__init__.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from . import layernorm + +__all__ = ["layernorm"] diff --git a/vllm/kernels/triton/ops/layernorm.py b/vllm/kernels/triton/ops/layernorm.py new file mode 100644 index 000000000000..52bc5e7737f7 --- /dev/null +++ b/vllm/kernels/triton/ops/layernorm.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from torch import Tensor + +from vllm import ir +from vllm.platforms import current_platform + +current_platform.import_kernels() + +CUDA_ALIKE = current_platform.is_cuda_alike() +"""Most kernels in this file are supported on all CUDA-alike platforms.""" + + +mixer2_rms_norm_gated_has_weight = ( + lambda x, gate, weight, epsilon, group_size=None: weight is not None +) +"""Triton gated RMSNorm kernel requires a weight tensor.""" + + +@ir.ops.mixer2_rms_norm_gated.register_impl( + "triton", supports_args=mixer2_rms_norm_gated_has_weight, supported=CUDA_ALIKE +) +def mixer2_rms_norm_gated( + x: Tensor, + gate: Tensor, + weight: Tensor | None, + 
epsilon: float, + group_size: int | None = None, +) -> Tensor: + assert weight is not None + from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated + + return rms_norm_gated( + x, + weight, + bias=None, + z=gate, + eps=epsilon, + group_size=group_size, + norm_before_gate=False, + ) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 01ea3fdca574..0e5bf51cad49 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -5,6 +5,7 @@ import torch from torch import nn +import vllm.ir.ops from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed import ( divide, @@ -24,13 +25,11 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator, - is_conv_state_dim_first, ) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update, ) -from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update from vllm.model_executor.layers.mamba.ops.ssd_combined import ( mamba_chunk_scan_combined_varlen, @@ -96,67 +95,56 @@ def forward_native( # Each rank computes a local sum of squares followed by AllReduce # 2. tp_size divides n_groups # Each rank only reduces within its local group(s). - # No collective ops necessary. + # No collective ops necessary (use IR op directly). # 3. The general case can be pretty complicated so we AllGather # the input and then redundantly compute the RMSNorm. 
input_dtype = x.dtype - x = x * nn.functional.silu(gate.to(torch.float32)) if not self.use_rms_norm: - return x.to(input_dtype) + return (x * nn.functional.silu(gate.to(torch.float32))).to(input_dtype) if self.n_groups == 1: if self.tp_size > 1: # Compute local sum and then reduce to obtain global sum + x = x * nn.functional.silu(gate.to(torch.float32)) local_sums = x.pow(2).sum(dim=-1, keepdim=True) global_sums = tensor_model_parallel_all_reduce(local_sums) - # Calculate the variance count = self.tp_size * x.shape[-1] variance = global_sums / count - + x = x * torch.rsqrt(variance + self.variance_epsilon) + return (self.weight * x).to(input_dtype) else: - variance = x.pow(2).mean(-1, keepdim=True) - x = x * torch.rsqrt(variance + self.variance_epsilon) + # No TP collective needed: use IR op + return vllm.ir.ops.mixer2_rms_norm_gated( + x, gate, self.weight, self.variance_epsilon + ) else: redundant_tp: bool = self.n_groups % self.tp_size != 0 if redundant_tp: # To handle the general case, redundantly apply the variance + x = x * nn.functional.silu(gate.to(torch.float32)) x = tensor_model_parallel_all_gather(x, -1) - - *prefix_dims, hidden_dim = x.shape - group_count = hidden_dim // self.group_size - x_grouped = x.view(*prefix_dims, group_count, self.group_size) - variance = x_grouped.pow(2).mean(-1, keepdim=True) - x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon) - x = x_grouped.view(*prefix_dims, hidden_dim) - - if redundant_tp: + *prefix_dims, hidden_dim = x.shape + group_count = hidden_dim // self.group_size + x_grouped = x.view(*prefix_dims, group_count, self.group_size) + variance = x_grouped.pow(2).mean(-1, keepdim=True) + x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon) + x = x_grouped.view(*prefix_dims, hidden_dim) start = self.per_rank_hidden_size * self.tp_rank end = start + self.per_rank_hidden_size x = x[..., start:end] - - return self.weight * x.to(input_dtype) + return (self.weight * x).to(input_dtype) + 
else: + # n_groups % tp_size == 0: local grouped RMSNorm, use IR op + return vllm.ir.ops.mixer2_rms_norm_gated( + x, gate, self.weight, self.variance_epsilon, self.group_size + ) def forward_cuda( self, x: torch.Tensor, gate: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - input_dtype = x.dtype - if not self.use_rms_norm: - # Keep gate in float32 for numerical stability during silu - return x * nn.functional.silu(gate.to(torch.float32)).to(input_dtype) - - if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1: - return self.forward_native(x, gate) - - return rms_norm_gated( - x, - self.weight.data, - bias=None, - z=gate, - eps=self.variance_epsilon, - norm_before_gate=False, - ) + return self.forward_native(x, gate) def mamba_v2_sharded_weight_loader( diff --git a/vllm/model_executor/layers/mamba/ops/layernorm_gated.py b/vllm/model_executor/layers/mamba/ops/layernorm_gated.py index 19db051cf801..1924ae4a2b47 100644 --- a/vllm/model_executor/layers/mamba/ops/layernorm_gated.py +++ b/vllm/model_executor/layers/mamba/ops/layernorm_gated.py @@ -4,12 +4,11 @@ # Adapted from https://github.com/state-spaces/mamba/blob/60dadf2e0ee730ac337035d5533de10bc26e4847/mamba_ssm/ops/triton/layernorm_gated.py import torch +from torch.library import wrap_triton from vllm.triton_utils import tl, triton -@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) -@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) @triton.jit def _layer_norm_fwd_1pass_kernel( X, # pointer to the input @@ -120,7 +119,7 @@ def _layer_norm_fwd( num_warps = min(max(BLOCK_N // 256, 1), 8) grid = (M, ngroups) with torch.accelerator.device_index(x.device.index): - _layer_norm_fwd_1pass_kernel[grid]( + wrap_triton(_layer_norm_fwd_1pass_kernel)[grid]( x, out, weight, @@ -135,6 +134,8 @@ def _layer_norm_fwd( group_size, eps, BLOCK_N=BLOCK_N, + HAS_BIAS=bias is not None, + HAS_Z=z is not None, NORM_BEFORE_GATE=norm_before_gate, IS_RMS_NORM=is_rms_norm, 
num_warps=num_warps, diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 10dd1b8697db..e4b6840d1ca7 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -570,7 +570,9 @@ def get_default_ir_op_priority(cls, vllm_config: VllmConfig) -> IrOpPriorityConf if envs.VLLM_USE_OINK_OPS: rms_norm = ["oink"] + default - return IrOpPriorityConfig.with_default(default, rms_norm=rms_norm) + return IrOpPriorityConfig.with_default( + default, rms_norm=rms_norm, mixer2_rms_norm_gated=["triton", "native"] + ) # NVML utils diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 2ba4ef3fe8a2..880987b8a442 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -943,4 +943,6 @@ def get_default_ir_op_priority( else: rms_norm = default - return IrOpPriorityConfig.with_default(default, rms_norm=rms_norm) + return IrOpPriorityConfig.with_default( + default, rms_norm=rms_norm, mixer2_rms_norm_gated=["triton", "native"] + ) From 2c5bcacbc5633d40a36c0cbfe00d40467887bf1a Mon Sep 17 00:00:00 2001 From: Icey <1790571317@qq.com> Date: Wed, 8 Apr 2026 06:37:14 +0000 Subject: [PATCH 2/4] wrap triton Signed-off-by: Icey <1790571317@qq.com> --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 1 + vllm/model_executor/layers/mamba/ops/layernorm_gated.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 0e5bf51cad49..6968c765dfb9 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -25,6 +25,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator, + is_conv_state_dim_first, ) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, diff --git a/vllm/model_executor/layers/mamba/ops/layernorm_gated.py b/vllm/model_executor/layers/mamba/ops/layernorm_gated.py index 
1924ae4a2b47..77a3f57c3995 100644 --- a/vllm/model_executor/layers/mamba/ops/layernorm_gated.py +++ b/vllm/model_executor/layers/mamba/ops/layernorm_gated.py @@ -4,11 +4,11 @@ # Adapted from https://github.com/state-spaces/mamba/blob/60dadf2e0ee730ac337035d5533de10bc26e4847/mamba_ssm/ops/triton/layernorm_gated.py import torch -from torch.library import wrap_triton from vllm.triton_utils import tl, triton +@torch.library.wrap_triton @triton.jit def _layer_norm_fwd_1pass_kernel( X, # pointer to the input @@ -119,7 +119,7 @@ def _layer_norm_fwd( num_warps = min(max(BLOCK_N // 256, 1), 8) grid = (M, ngroups) with torch.accelerator.device_index(x.device.index): - wrap_triton(_layer_norm_fwd_1pass_kernel)[grid]( + _layer_norm_fwd_1pass_kernel[grid]( x, out, weight, From bcb8bb70eb307f2e00dff526fcbcc19cf45c4c6f Mon Sep 17 00:00:00 2001 From: Icey <1790571317@qq.com> Date: Wed, 8 Apr 2026 07:27:34 +0000 Subject: [PATCH 3/4] fix native impl Signed-off-by: Icey <1790571317@qq.com> --- vllm/ir/ops/layernorm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/ir/ops/layernorm.py b/vllm/ir/ops/layernorm.py index 744990d3871e..5614078644cc 100644 --- a/vllm/ir/ops/layernorm.py +++ b/vllm/ir/ops/layernorm.py @@ -44,4 +44,6 @@ def mixer2_rms_norm_gated( x_grouped = x_grouped * torch.rsqrt(variance + epsilon) x = x_grouped.view(*prefix_dims, hidden_dims) - return (weight * x).to(input_dtype) + if weight is not None: + x = x.to(weight.dtype) * weight + return x.to(input_dtype) From a5b4f98015e0c48c61261fa5106353e9e07f315c Mon Sep 17 00:00:00 2001 From: Chaojun Zhang Date: Tue, 21 Apr 2026 07:20:18 +0000 Subject: [PATCH 4/4] Support mixer2_rms_norm_gated IR on XPU platform Signed-off-by: Chaojun Zhang --- tests/kernels/ir/test_layernorm.py | 154 +++++++++++++++++++++++++++ vllm/kernels/triton/ops/layernorm.py | 7 +- vllm/platforms/xpu.py | 4 +- 3 files changed, 163 insertions(+), 2 deletions(-) diff --git a/tests/kernels/ir/test_layernorm.py 
b/tests/kernels/ir/test_layernorm.py index 3d21169098dc..b63a9e70f2cd 100644 --- a/tests/kernels/ir/test_layernorm.py +++ b/tests/kernels/ir/test_layernorm.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest import torch +import torch.nn.functional as F # This registers op implementations import vllm.kernels # noqa: F401 @@ -16,7 +17,40 @@ def rms_norm_inputs(n_tokens: int, hidden_size: int, dtype: torch.dtype): return x, weight +def mixer2_rms_norm_gated_inputs(n_tokens: int, hidden_size: int, dtype: torch.dtype): + x = torch.randn(n_tokens, hidden_size, dtype=dtype) + gate = torch.randn(n_tokens, hidden_size, dtype=dtype) + weight = torch.rand(hidden_size, dtype=dtype) + return x, gate, weight + + rms_norm_native = ir.ops.rms_norm.impls["native"].impl_fn +mixer2_rms_norm_gated_native = ir.ops.mixer2_rms_norm_gated.impls["native"].impl_fn + + +def mixer2_rms_norm_gated_reference( + x: torch.Tensor, + gate: torch.Tensor, + weight: torch.Tensor | None, + epsilon: float, + group_size: int | None = None, +) -> torch.Tensor: + input_dtype = x.dtype + x = x * F.silu(gate.to(torch.float32)) + if group_size is None: + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + epsilon) + else: + *prefix_dims, hidden_size = x.shape + x_grouped = x.reshape(*prefix_dims, hidden_size // group_size, group_size) + variance = x_grouped.pow(2).mean(dim=-1, keepdim=True) + x = (x_grouped * torch.rsqrt(variance + epsilon)).reshape( + *prefix_dims, hidden_size + ) + + if weight is not None: + x = x.to(weight.dtype) * weight + return x.to(input_dtype) @pytest.mark.skipif( @@ -127,3 +161,123 @@ def test_torch_opcheck(self, dtype, n_tokens, hidden_size, epsilon, provider): # When checking the torch op, we have to set priority and use dispatch with ir.ops.rms_norm.set_priority([provider, "native"]): torch.library.opcheck(torch.ops.vllm_ir.rms_norm, args) + + +@pytest.mark.skipif( + not current_platform.is_cuda_alike() 
and not current_platform.is_xpu(), + reason="Currently only kernels on CUDA, ROCm and XPU", +) +def test_mixer2_rms_norm_gated_registration(): + expected = { + "native": True, + "triton": current_platform.is_cuda_alike() or current_platform.is_xpu(), + } + + actual = { + provider: impl.supported + for provider, impl in ir.ops.mixer2_rms_norm_gated.impls.items() + } + + assert actual == expected + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("n_tokens", [1, 17]) +@pytest.mark.parametrize("hidden_size", [16, 4096]) +@pytest.mark.parametrize("epsilon", [1e-6, 1e-5]) +@pytest.mark.parametrize("group_size", [None, 4]) +@pytest.mark.skipif( + not current_platform.is_cuda_alike() and not current_platform.is_xpu(), + reason="Currently only kernels on CUDA, ROCm and XPU", +) +class TestMixer2RMSNormGated: + @classmethod + def setup_class(cls, **kwargs): + torch.set_default_device(current_platform.device_type) + + def test_native_semantics(self, dtype, n_tokens, hidden_size, epsilon, group_size): + x, gate, weight = mixer2_rms_norm_gated_inputs(n_tokens, hidden_size, dtype) + out = mixer2_rms_norm_gated_native( + x, gate, weight, epsilon=epsilon, group_size=group_size + ) + ref = mixer2_rms_norm_gated_reference( + x, gate, weight, epsilon=epsilon, group_size=group_size + ) + + assert out.shape == x.shape + assert out.dtype == x.dtype + assert out.device == x.device + + torch.testing.assert_close(out, ref, rtol=get_default_rtol(out), atol=1e-3) + + weight1 = torch.ones_like(weight) + out3 = mixer2_rms_norm_gated_native( + x, gate, weight1, epsilon=epsilon, group_size=group_size + ) + out4 = mixer2_rms_norm_gated_native( + x, gate, None, epsilon=epsilon, group_size=group_size + ) + torch.testing.assert_close(out3, out4) + + out_full = mixer2_rms_norm_gated_native( + x, gate, weight, epsilon=epsilon, group_size=None + ) + out_all_grouped = mixer2_rms_norm_gated_native( + x, gate, weight, epsilon=epsilon, 
group_size=hidden_size + ) + torch.testing.assert_close( + out_all_grouped, out_full, rtol=get_default_rtol(out_full), atol=1e-3 + ) + + @pytest.mark.parametrize("provider", ["triton"]) + def test_impls(self, dtype, n_tokens, hidden_size, epsilon, group_size, provider): + impl = ir.ops.mixer2_rms_norm_gated.impls[provider] + if not impl.supported: + pytest.skip(f"{provider} impl not supported on this platform") + + x, gate, weight = mixer2_rms_norm_gated_inputs(n_tokens, hidden_size, dtype) + args = (x, gate, weight, epsilon, group_size) + + assert impl.supported + assert impl.supports_args(*args) + + out_impl = impl.impl_fn(*args) + out_native = mixer2_rms_norm_gated_native(*args) + + torch.testing.assert_close( + out_impl, out_native, rtol=get_default_rtol(out_impl), atol=1e-3 + ) + + with ir.ops.mixer2_rms_norm_gated.set_priority([provider, "native"]): + out_impl2 = ir.ops.mixer2_rms_norm_gated(*args) + + torch.testing.assert_close(out_impl2, out_impl, rtol=0.0, atol=0.0) + + assert not impl.supports_args(x, gate, None, epsilon, group_size) + assert not impl.supports_args(x, gate, None, epsilon, group_size=group_size) + + out_native_no_weight = mixer2_rms_norm_gated_native( + x, gate, None, epsilon, group_size + ) + out_native_unit_weight = mixer2_rms_norm_gated_native( + x, gate, torch.ones_like(weight), epsilon, group_size + ) + torch.testing.assert_close( + out_native_no_weight, + out_native_unit_weight, + rtol=get_default_rtol(out_native_no_weight), + atol=2e-4, + ) + + @pytest.mark.parametrize("provider", ["triton", "native"]) + def test_torch_opcheck( + self, dtype, n_tokens, hidden_size, epsilon, group_size, provider + ): + if not ir.ops.mixer2_rms_norm_gated.impls[provider].supported: + pytest.skip(f"{provider} impl not supported on this platform") + + x, gate, weight = mixer2_rms_norm_gated_inputs(n_tokens, hidden_size, dtype) + args = (x, gate, weight, epsilon, group_size) + + with ir.ops.mixer2_rms_norm_gated.set_priority([provider, "native"]): + 
torch.library.opcheck(torch.ops.vllm_ir.mixer2_rms_norm_gated, args) diff --git a/vllm/kernels/triton/ops/layernorm.py b/vllm/kernels/triton/ops/layernorm.py index 52bc5e7737f7..29ef898cda12 100644 --- a/vllm/kernels/triton/ops/layernorm.py +++ b/vllm/kernels/triton/ops/layernorm.py @@ -9,6 +9,9 @@ current_platform.import_kernels() CUDA_ALIKE = current_platform.is_cuda_alike() + +GPGPU_DEVICE = CUDA_ALIKE or current_platform.is_xpu() + """Most kernels in this file are supported on all CUDA-alike platforms.""" @@ -19,7 +22,9 @@ @ir.ops.mixer2_rms_norm_gated.register_impl( - "triton", supports_args=mixer2_rms_norm_gated_has_weight, supported=CUDA_ALIKE + "triton", + supports_args=mixer2_rms_norm_gated_has_weight, + supported=GPGPU_DEVICE, ) def mixer2_rms_norm_gated( x: Tensor, diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index ffc765257ed7..0d3333b14b4c 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -322,7 +322,9 @@ def get_default_ir_op_priority( using_inductor = cc.backend == "inductor" and cc.mode != CompilationMode.NONE default = ["native"] if using_inductor else ["xpu_kernels", "native"] - return IrOpPriorityConfig.with_default(default) + return IrOpPriorityConfig.with_default( + default, mixer2_rms_norm_gated=["triton", "native"] + ) @classmethod def device_count(cls) -> int: