Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion python/sglang/srt/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,15 @@ def forward_aiter(
residual: Optional[torch.Tensor] = None,
post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# Aiter's RMSNorm kernels expect 2D contiguous inputs. Keep the
# already-safe layout as a zero-copy path, and only normalize strided or
# higher-rank views such as Q/K slices from packed QKV projections.
needs_reshape = x.dim() != 2 and residual is None
if needs_reshape:
original_shape = x.shape
x = x.contiguous().reshape(-1, original_shape[-1])
elif not x.is_contiguous():
x = x.contiguous()
if residual is not None:
residual_out = torch.empty_like(x)
output = torch.empty_like(x)
Expand All @@ -298,7 +307,10 @@ def forward_aiter(
self.variance_epsilon,
)
return output, residual_out
return rms_norm(x, self.weight.data, self.variance_epsilon)
output = rms_norm(x, self.weight.data, self.variance_epsilon)
if needs_reshape:
output = output.reshape(original_shape)
return output

def forward_hip(
self,
Expand Down
Loading