diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index a92469a98fce..b582ac5e06c6 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -284,6 +284,15 @@ def forward_aiter( residual: Optional[torch.Tensor] = None, post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + # Aiter's RMSNorm kernels expect 2D contiguous inputs. Keep the + # already-safe layout as a zero-copy path, and only normalize strided or + # higher-rank views such as Q/K slices from packed QKV projections. + needs_reshape = x.dim() != 2 and residual is None + if needs_reshape: + original_shape = x.shape + x = x.contiguous().reshape(-1, original_shape[-1]) + elif not x.is_contiguous(): + x = x.contiguous() if residual is not None: residual_out = torch.empty_like(x) output = torch.empty_like(x) @@ -298,7 +307,10 @@ def forward_aiter( self.variance_epsilon, ) return output, residual_out - return rms_norm(x, self.weight.data, self.variance_epsilon) + output = rms_norm(x, self.weight.data, self.variance_epsilon) + if needs_reshape: + output = output.reshape(original_shape) + return output def forward_hip( self,