Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions vllm/config/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ class IrOpPriorityConfig:
rms_norm: list[str] = Field(default_factory=list)
"""Priority list for vllm.ir.ops.rms_norm"""

mixer2_rms_norm_gated: list[str] = Field(default_factory=list)
"""Priority list for vllm.ir.ops.rms_norm_gated"""

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: docstring should have vllm.ir.ops.mixer2_rms_norm_gated instead of vllm.ir.ops.rms_norm_gated.
(Or change the op name to rms_norm_gated)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your review. I will change it.


def compute_hash(self) -> str:
"""
Produces a hash unique to the pass configuration.
Expand Down
4 changes: 2 additions & 2 deletions vllm/ir/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .layernorm import rms_norm
from .layernorm import mixer2_rms_norm_gated, rms_norm

__all__ = ["rms_norm"]
__all__ = ["rms_norm", "mixer2_rms_norm_gated"]
28 changes: 28 additions & 0 deletions vllm/ir/ops/layernorm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn.functional as F
from torch import Tensor

from ..op import register_op
Expand All @@ -19,3 +20,30 @@ def rms_norm(
if weight is not None:
x = x.to(weight.dtype) * weight
return x.to(orig_dtype)


@register_op
def mixer2_rms_norm_gated(
x: Tensor,
gate: Tensor,
weight: Tensor | None,
epsilon: float,
group_size: int | None = None,
) -> Tensor:
input_dtype = x.dtype
x = x * F.silu(gate.to(torch.float32))
if group_size is None:
# Standard RMSNorm: compute variance over the full hidden dimension
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + epsilon)
else:
# Grouped RMSNorm: compute variance independently within each group
*prefix_dims, hidden_dims = x.shape
x_grouped = x.view(*prefix_dims, hidden_dims // group_size, group_size)
variance = x_grouped.pow(2).mean(dim=-1, keepdim=True)
x_grouped = x_grouped * torch.rsqrt(variance + epsilon)
x = x_grouped.view(*prefix_dims, hidden_dims)

if weight is not None:
x = x.to(weight.dtype) * weight
return x.to(input_dtype)
4 changes: 2 additions & 2 deletions vllm/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Kernel implementations for vLLM."""

from . import aiter_ops, oink_ops, vllm_c, xpu_ops
from . import aiter_ops, oink_ops, triton, vllm_c, xpu_ops

__all__ = ["vllm_c", "aiter_ops", "oink_ops", "xpu_ops"]
__all__ = ["vllm_c", "aiter_ops", "oink_ops", "xpu_ops", "triton"]
5 changes: 5 additions & 0 deletions vllm/kernels/triton/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from . import ops

__all__ = ["ops"]
5 changes: 5 additions & 0 deletions vllm/kernels/triton/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from . import layernorm

__all__ = ["layernorm"]
42 changes: 42 additions & 0 deletions vllm/kernels/triton/ops/layernorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from torch import Tensor

from vllm import ir
from vllm.platforms import current_platform

current_platform.import_kernels()

CUDA_ALIKE = current_platform.is_cuda_alike()
"""Most kernels in this file are supported on all CUDA-alike platforms."""


mixer2_rms_norm_gated_has_weight = (
lambda x, gate, weight, epsilon, group_size=None: weight is not None
)
"""Triton gated RMSNorm kernel requires a weight tensor."""


@ir.ops.mixer2_rms_norm_gated.register_impl(
"triton", supports_args=mixer2_rms_norm_gated_has_weight, supported=CUDA_ALIKE
)
def mixer2_rms_norm_gated(
x: Tensor,
gate: Tensor,
weight: Tensor | None,
epsilon: float,
group_size: int | None = None,
) -> Tensor:
assert weight is not None
from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated

return rms_norm_gated(
x,
weight,
bias=None,
z=gate,
eps=epsilon,
group_size=group_size,
norm_before_gate=False,
)
59 changes: 24 additions & 35 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
from torch import nn

import vllm.ir.ops
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import (
divide,
Expand All @@ -30,7 +31,6 @@
causal_conv1d_fn,
causal_conv1d_update,
)
from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined_varlen,
Expand Down Expand Up @@ -96,67 +96,56 @@ def forward_native(
# Each rank computes a local sum of squares followed by AllReduce
# 2. tp_size divides n_groups
# Each rank only reduces within its local group(s).
# No collective ops necessary.
# No collective ops necessary (use IR op directly).
# 3. The general case can be pretty complicated so we AllGather
# the input and then redundantly compute the RMSNorm.
input_dtype = x.dtype
x = x * nn.functional.silu(gate.to(torch.float32))
if not self.use_rms_norm:
return x.to(input_dtype)
return (x * nn.functional.silu(gate.to(torch.float32))).to(input_dtype)

if self.n_groups == 1:
if self.tp_size > 1:
# Compute local sum and then reduce to obtain global sum
x = x * nn.functional.silu(gate.to(torch.float32))
local_sums = x.pow(2).sum(dim=-1, keepdim=True)
global_sums = tensor_model_parallel_all_reduce(local_sums)
# Calculate the variance
count = self.tp_size * x.shape[-1]
variance = global_sums / count

x = x * torch.rsqrt(variance + self.variance_epsilon)
return (self.weight * x).to(input_dtype)
else:
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
# No TP collective needed: use IR op
return vllm.ir.ops.mixer2_rms_norm_gated(
x, gate, self.weight, self.variance_epsilon
)
else:
redundant_tp: bool = self.n_groups % self.tp_size != 0
if redundant_tp:
# To handle the general case, redundantly apply the variance
x = x * nn.functional.silu(gate.to(torch.float32))
x = tensor_model_parallel_all_gather(x, -1)

*prefix_dims, hidden_dim = x.shape
group_count = hidden_dim // self.group_size
x_grouped = x.view(*prefix_dims, group_count, self.group_size)
variance = x_grouped.pow(2).mean(-1, keepdim=True)
x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
x = x_grouped.view(*prefix_dims, hidden_dim)

if redundant_tp:
*prefix_dims, hidden_dim = x.shape
group_count = hidden_dim // self.group_size
x_grouped = x.view(*prefix_dims, group_count, self.group_size)
variance = x_grouped.pow(2).mean(-1, keepdim=True)
x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
x = x_grouped.view(*prefix_dims, hidden_dim)
start = self.per_rank_hidden_size * self.tp_rank
end = start + self.per_rank_hidden_size
x = x[..., start:end]

return self.weight * x.to(input_dtype)
return (self.weight * x).to(input_dtype)
else:
# n_groups % tp_size == 0: local grouped RMSNorm, use IR op
return vllm.ir.ops.mixer2_rms_norm_gated(
x, gate, self.weight, self.variance_epsilon, self.group_size
)

def forward_cuda(
self,
x: torch.Tensor,
gate: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
input_dtype = x.dtype
if not self.use_rms_norm:
# Keep gate in float32 for numerical stability during silu
return x * nn.functional.silu(gate.to(torch.float32)).to(input_dtype)

if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
return self.forward_native(x, gate)

return rms_norm_gated(
x,
self.weight.data,
bias=None,
z=gate,
eps=self.variance_epsilon,
norm_before_gate=False,
)
return self.forward_native(x, gate)


def mamba_v2_sharded_weight_loader(
Expand Down
5 changes: 3 additions & 2 deletions vllm/model_executor/layers/mamba/ops/layernorm_gated.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from vllm.triton_utils import tl, triton


@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@torch.library.wrap_triton
@triton.jit
def _layer_norm_fwd_1pass_kernel(
X, # pointer to the input
Expand Down Expand Up @@ -135,6 +134,8 @@ def _layer_norm_fwd(
group_size,
eps,
BLOCK_N=BLOCK_N,
HAS_BIAS=bias is not None,
HAS_Z=z is not None,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
num_warps=num_warps,
Expand Down
4 changes: 3 additions & 1 deletion vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,9 @@ def get_default_ir_op_priority(cls, vllm_config: VllmConfig) -> IrOpPriorityConf
if envs.VLLM_USE_OINK_OPS:
rms_norm = ["oink"] + default

return IrOpPriorityConfig.with_default(default, rms_norm=rms_norm)
return IrOpPriorityConfig.with_default(
default, rms_norm=rms_norm, mixer2_rms_norm_gated=["triton", "native"]
)


# NVML utils
Expand Down
4 changes: 3 additions & 1 deletion vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,4 +943,6 @@ def get_default_ir_op_priority(
else:
rms_norm = default

return IrOpPriorityConfig.with_default(default, rms_norm=rms_norm)
return IrOpPriorityConfig.with_default(
default, rms_norm=rms_norm, mixer2_rms_norm_gated=["triton", "native"]
)
Loading