21 changes: 18 additions & 3 deletions python/sglang/srt/layers/layernorm.py
@@ -20,9 +20,12 @@
import torch.nn as nn

from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import is_cuda
from sglang.srt.utils import is_cuda, is_hip

logger = logging.getLogger(__name__)

_is_cuda = is_cuda()
_is_hip = is_hip()

if _is_cuda:
    from sgl_kernel import (
@@ -32,8 +35,20 @@
        rmsnorm,
    )

if _is_hip:

logger = logging.getLogger(__name__)
    from aiter.ops.rmsnorm import rms_norm, rmsnorm2d_fwd_with_add

    rmsnorm = rms_norm

    def fused_add_rmsnorm(
        x: torch.Tensor,
        residual: torch.Tensor,
        w: torch.Tensor,
        eps: float,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        rmsnorm2d_fwd_with_add(x, x, residual, residual, w, eps)
        return x, residual


class RMSNorm(CustomOp):
@@ -139,7 +154,7 @@ def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"


if not _is_cuda:
if not (_is_cuda or _is_hip):
    logger.info(
        "sgl-kernel is not available on non-NVIDIA and non-AMD platforms. Falling back to other kernel libraries."
    )
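For context, a minimal sketch of what a pure-PyTorch fallback for the RMSNorm forward can look like when neither sgl-kernel nor aiter is available. This is an illustration only; the actual fallback path in sglang may instead dispatch to another kernel library such as vLLM's ops.

```python
import torch


def rmsnorm_native(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root-mean-square over the last dimension, then scale
    # by the learned weight; computed in float32 for numerical stability.
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return (x * weight.float()).to(orig_dtype)
```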