36 changes: 33 additions & 3 deletions python/sglang/srt/layers/layernorm.py
@@ -20,9 +20,10 @@
import torch.nn as nn

from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import is_cuda
from sglang.srt.utils import is_cuda, is_hip

_is_cuda = is_cuda()
_is_hip = is_hip()

if _is_cuda:
from sgl_kernel import (
@@ -32,6 +33,8 @@
rmsnorm,
)

if _is_hip:
from vllm._custom_ops import fused_add_rms_norm, rms_norm

logger = logging.getLogger(__name__)

@@ -46,23 +49,47 @@ def __init__(
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps

def forward(self, *args, **kwargs):
if _is_cuda:
return self.forward_cuda(*args, **kwargs)
elif _is_hip:
return self.forward_hip(*args, **kwargs)
else:
return self.forward_native(*args, **kwargs)

def forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:

if residual is not None:
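# fused_add_rmsnorm updates x and residual in place, so they are returned directly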
fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
return x, residual
out = rmsnorm(x, self.weight.data, self.variance_epsilon)
return out

def forward_hip(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if not x.is_contiguous():
# NOTE: Remove this once the aiter kernel supports non-contiguous input
x = x.contiguous()
if residual is not None:
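# vllm's fused_add_rms_norm likewise operates in place on x and residual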
fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
return x, residual
out = torch.empty_like(x)
rms_norm(out, x, self.weight.data, self.variance_epsilon)
return out

def forward_native(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if not x.is_contiguous():
x = x.contiguous()
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
@@ -143,4 +170,7 @@ def extra_repr(self):
logger.info(
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
Collaborator:

Can you fix the comment? sgl-kernel also exists on ROCm. Also, simplify the cascaded ifs to:

if _is_hip:
   <...>
elif not _is_cuda:
   logger.info("Fallback to other kernel libraries.")
   <...>

Collaborator:

updated
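For reference, a minimal sketch of how the suggested shape could be applied to the fallback imports; this is an assumption about the surrounding module-level code, which is not fully visible in this diff:

if _is_hip:
    from vllm.model_executor.layers.layernorm import GemmaRMSNorm
elif not _is_cuda:
    logger.info("Fallback to other kernel libraries.")
    from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm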

)
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
if _is_hip:
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
else:
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
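As a usage illustration, a minimal sketch of how the dispatching forward path behaves; the constructor arguments, shapes, dtype, and device below are assumptions, and the fused path requires the matching sgl_kernel (CUDA) or vllm custom ops (ROCm) build to be installed:

import torch
from sglang.srt.layers.layernorm import RMSNorm

# hidden_size and eps assumed from the constructor shown above
norm = RMSNorm(hidden_size=4096, eps=1e-6).to(device="cuda", dtype=torch.float16)
x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)

out = norm(x)                 # plain RMSNorm, returns a single tensor
x2, res2 = norm(x, residual)  # fused add + RMSNorm, returns (x, residual)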