diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 3c18cea7046..88524ffae01 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -20,9 +20,12 @@
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, is_hip
+
+logger = logging.getLogger(__name__)
 
 _is_cuda = is_cuda()
+_is_hip = is_hip()
 
 if _is_cuda:
     from sgl_kernel import (
@@ -32,8 +35,20 @@
         rmsnorm,
     )
 
+if _is_hip:
-logger = logging.getLogger(__name__)
+    from aiter.ops.rmsnorm import rms_norm, rmsnorm2d_fwd_with_add
+
+    rmsnorm = rms_norm
+
+    def fused_add_rmsnorm(
+        x: torch.Tensor,
+        residual: torch.Tensor,
+        w: torch.Tensor,
+        eps: float,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        rmsnorm2d_fwd_with_add(x, x, residual, residual, w, eps)
+        return x, residual
 
 
 class RMSNorm(CustomOp):
@@ -139,7 +154,7 @@ def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 
-if not _is_cuda:
+if not (_is_cuda or _is_hip):
     logger.info(
         "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
     )