Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions python/sglang/srt/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,9 +451,6 @@ def __init__(
super().__init__()
self.weight = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
# Re-dispatch
if _is_hip:
self._forward_method = self.forward_native

def _forward_impl(
self,
Expand Down Expand Up @@ -499,6 +496,47 @@ def forward_cuda(
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
return self._forward_impl(x, residual, post_residual_addition)

def forward_hip(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """RMSNorm forward on ROCm, dispatching to aiter or vllm fused kernels.

    Falls back to ``forward_native`` when no vllm rms_norm kernel is
    available. Returns the normalized tensor; when ``residual`` is given,
    returns ``(output, updated_residual)`` instead.
    """
    if not _has_vllm_rms_norm:
        return self.forward_native(x, residual, post_residual_addition)

    # Kernels expect weight + 1.0 (the parameter is zero-initialized,
    # so the stored weight is an offset from unit gain).
    gamma = self.weight.data + 1.0

    if _use_aiter:
        # aiter API: rms_norm(input, weight, eps) -> output
        # fused_add_rms_norm(output, input, residual, residual_out, weight, eps)
        if residual is None:
            return rms_norm(x, gamma, self.variance_epsilon)
        if post_residual_addition is not None:
            residual = residual + post_residual_addition
        y = torch.empty_like(x)
        new_residual = torch.empty_like(x)
        fused_add_rms_norm(
            y, x, residual, new_residual, gamma, self.variance_epsilon
        )
        return y, new_residual

    # vllm API: rms_norm(out, input, weight, eps) -> None (in-place)
    # fused_add_rms_norm(out, input, residual_out, residual, weight, eps)
    # NOTE: the vllm kernels require contiguous input.
    if not x.is_contiguous():
        x = x.contiguous()
    if residual is None:
        y = torch.empty_like(x)
        rms_norm(y, x, gamma, self.variance_epsilon)
        return y
    if post_residual_addition is not None:
        residual = residual + post_residual_addition
    y = torch.empty_like(x)
    new_residual = torch.empty_like(x)
    fused_add_rms_norm(
        y, x, new_residual, residual, gamma, self.variance_epsilon
    )
    return y, new_residual

def forward_cpu(
self,
x: torch.Tensor,
Expand Down
Loading