diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index a6d9c0402e7a..e4960bdb42d6 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -451,9 +451,6 @@ def __init__( super().__init__() self.weight = nn.Parameter(torch.zeros(hidden_size)) self.variance_epsilon = eps - # Re-dispatch - if _is_hip: - self._forward_method = self.forward_native def _forward_impl( self, @@ -499,6 +496,47 @@ def forward_cuda( ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: return self._forward_impl(x, residual, post_residual_addition) + def forward_hip( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + post_residual_addition: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if not _has_vllm_rms_norm: + return self.forward_native(x, residual, post_residual_addition) + + w = self.weight.data + 1.0 + if _use_aiter: + # aiter API: rms_norm(input, weight, eps) -> output + # fused_add_rms_norm(output, input, residual, residual_out, weight, eps) + if residual is not None: + output = torch.empty_like(x) + residual_out = torch.empty_like(x) + if post_residual_addition is not None: + residual = residual + post_residual_addition + fused_add_rms_norm( + output, x, residual, residual_out, w, self.variance_epsilon + ) + return output, residual_out + return rms_norm(x, w, self.variance_epsilon) + else: + # vllm API: rms_norm(out, input, weight, eps) -> None (in-place) + # fused_add_rms_norm(out, input, residual_out, residual, weight, eps) + if not x.is_contiguous(): + x = x.contiguous() + if residual is not None: + out = torch.empty_like(x) + residual_out = torch.empty_like(x) + if post_residual_addition is not None: + residual = residual + post_residual_addition + fused_add_rms_norm( + out, x, residual_out, residual, w, self.variance_epsilon + ) + return out, residual_out + out = torch.empty_like(x) + rms_norm(out, x, w, self.variance_epsilon) + return out + def forward_cpu( self, x: torch.Tensor,