From 2f5602ff2721f53c7110fb002391b44a6caf6bf8 Mon Sep 17 00:00:00 2001 From: Hubert Lu Date: Wed, 29 Apr 2026 00:16:15 +0000 Subject: [PATCH] Fix Aiter RMSNorm layout handling Made-with: Cursor --- python/sglang/srt/layers/layernorm.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 55d880643c9a..b5d2eba98517 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -242,6 +242,15 @@ def forward_aiter( residual: Optional[torch.Tensor] = None, post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + # Aiter's RMSNorm kernels expect 2D contiguous inputs. Keep the + # already-safe layout as a zero-copy path, and only normalize strided or + # higher-rank views such as Q/K slices from packed QKV projections. + needs_reshape = x.dim() != 2 and residual is None + if needs_reshape: + original_shape = x.shape + x = x.contiguous().reshape(-1, original_shape[-1]) + elif not x.is_contiguous(): + x = x.contiguous() if residual is not None: residual_out = torch.empty_like(x) output = torch.empty_like(x) @@ -256,7 +265,10 @@ def forward_aiter( self.variance_epsilon, ) return output, residual_out - return rms_norm(x, self.weight.data, self.variance_epsilon) + output = rms_norm(x, self.weight.data, self.variance_epsilon) + if needs_reshape: + output = output.reshape(original_shape) + return output def forward_hip( self,