From ac9c70e8b2f9438b94189b93e441bb1aa75c1f6e Mon Sep 17 00:00:00 2001
From: thomawan
Date: Wed, 8 Apr 2026 12:16:57 +0800
Subject: [PATCH 1/2] Use ck layernorm kernel instead of torch implementation

---
 python/sglang/srt/layers/layernorm.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 0db6675e648f..60eb942e118d 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -65,11 +65,14 @@
         gemma_rmsnorm,
         rmsnorm,
     )
+_has_aiter_layer_norm = False
 _has_vllm_rms_norm = False
 if _use_aiter:
+    from aiter import layernorm2d_fwd as layer_norm
     from aiter import rmsnorm2d_fwd as rms_norm
     from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
 
+    _has_aiter_layer_norm = True  # aiter provides the layer_norm functions
     _has_vllm_rms_norm = True  # aiter provides the rms_norm functions
 elif _is_hip:
     try:
@@ -428,7 +431,18 @@ def forward_hip(
         self,
         x: torch.Tensor,
     ) -> torch.Tensor:
-        return self.forward_native(x)
+        if (
+            _has_aiter_layer_norm
+            and x.dtype in (torch.bfloat16, torch.float16)
+            and x.dtype == self.dtype
+        ):
+            orig_shape = x.shape
+            x = x.reshape(-1, self.hidden_size)
+            return layer_norm(x, self.weight, self.bias, self.variance_epsilon).view(
+                orig_shape
+            )
+        else:
+            return self.forward_native(x)
 
     def forward_npu(
         self,

From 59eeff573021cf00bec17d45ae296d6f05b4a397 Mon Sep 17 00:00:00 2001
From: thomawan
Date: Wed, 8 Apr 2026 12:21:07 +0800
Subject: [PATCH 2/2] Use bf16 for LayerNorm in the NSA indexer when aiter is
 enabled

---
 .../sglang/srt/layers/attention/nsa/nsa_indexer.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
index 6bfcb3f66852..4ffd13bddbe8 100644
--- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
+++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
@@ -16,12 +16,20 @@
 from sglang.srt.layers.layernorm import LayerNorm
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.utils import MultiPlatformOp
-from sglang.srt.utils import add_prefix, ceil_align, is_cuda, is_hip, is_npu
+from sglang.srt.utils import (
+    add_prefix,
+    ceil_align,
+    get_bool_env_var,
+    is_cuda,
+    is_hip,
+    is_npu,
+)
 
 global _use_multi_stream
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_npu = is_npu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_fp8_fnuz = is_fp8_fnuz()
 if _is_cuda:
     try:
@@ -212,7 +220,9 @@ def __init__(
             params_dtype=torch.bfloat16 if _is_cuda else torch.float32,
             prefix=add_prefix("weights_proj", prefix),
         )
-        self.k_norm = LayerNorm(self.head_dim, dtype=torch.float32)
+        self.k_norm = LayerNorm(
+            self.head_dim, dtype=torch.bfloat16 if _use_aiter else torch.float32
+        )
         self.rotary_emb = get_rope_wrapper(
            rope_head_dim,
            rotary_dim=rope_head_dim,
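
A minimal smoke test for the new aiter layernorm path could look like the sketch below. It assumes a ROCm machine with aiter installed and SGLANG_USE_AITER=1 exported before sglang is imported; the hidden size, input shape, and tolerances are illustrative only.

import torch

from sglang.srt.layers.layernorm import LayerNorm

# Build the layer in bf16 so the dtype check in forward_hip selects the aiter kernel.
ln = LayerNorm(4096, dtype=torch.bfloat16).to("cuda")
x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")

out_aiter = ln.forward_hip(x)      # dispatches to aiter layernorm2d_fwd
out_native = ln.forward_native(x)  # reference torch implementation

# Compare the two paths with a loose, illustrative bf16 tolerance.
torch.testing.assert_close(out_aiter, out_native, rtol=2e-2, atol=2e-2)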