python/sglang/srt/layers/attention/nsa/nsa_indexer.py (12 additions, 2 deletions)
@@ -16,12 +16,20 @@
 from sglang.srt.layers.layernorm import LayerNorm
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.utils import MultiPlatformOp
-from sglang.srt.utils import add_prefix, ceil_align, is_cuda, is_hip, is_npu
+from sglang.srt.utils import (
+    add_prefix,
+    ceil_align,
+    get_bool_env_var,
+    is_cuda,
+    is_hip,
+    is_npu,
+)
 
 global _use_multi_stream
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_npu = is_npu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_fp8_fnuz = is_fp8_fnuz()
 if _is_cuda:
     try:
@@ -212,7 +220,9 @@ def __init__(
             params_dtype=torch.bfloat16 if _is_cuda else torch.float32,
             prefix=add_prefix("weights_proj", prefix),
         )
-        self.k_norm = LayerNorm(self.head_dim, dtype=torch.float32)
+        self.k_norm = LayerNorm(
+            self.head_dim, dtype=torch.bfloat16 if _use_aiter else torch.float32
+        )
         self.rotary_emb = get_rope_wrapper(
             rope_head_dim,
             rotary_dim=rope_head_dim,
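Taken together, the two hunks above gate the indexer's key norm on the aiter path: with `SGLANG_USE_AITER` enabled on a HIP build, `k_norm` is constructed in bfloat16 so it can hit the aiter kernel; otherwise it stays in float32. A minimal sketch of that gating, with `get_bool_env_var` and `is_hip` re-implemented here as stand-ins for the real helpers in `sglang.srt.utils`:

```python
import os

import torch


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Stand-in for sglang.srt.utils.get_bool_env_var: treat the usual
    # truthy strings as enabling the flag.
    return os.getenv(name, default).lower() in ("1", "true", "yes")


def is_hip() -> bool:
    # Stand-in for the real ROCm/HIP platform check.
    return torch.version.hip is not None


_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()

# Same dtype selection as the diff: bf16 when the aiter path is active,
# fp32 everywhere else.
k_norm_dtype = torch.bfloat16 if _use_aiter else torch.float32
```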
python/sglang/srt/layers/layernorm.py (15 additions, 1 deletion)
@@ -65,11 +65,14 @@
         gemma_rmsnorm,
         rmsnorm,
     )
+_has_aiter_layer_norm = False
 _has_vllm_rms_norm = False
 if _use_aiter:
+    from aiter import layernorm2d_fwd as layer_norm
     from aiter import rmsnorm2d_fwd as rms_norm
     from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
 
+    _has_aiter_layer_norm = True  # aiter provides the layer_norm functions
     _has_vllm_rms_norm = True  # aiter provides the rms_norm functions
 elif _is_hip:
     try:
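The new `_has_aiter_layer_norm` flag mirrors the existing `_has_vllm_rms_norm` pattern: default to False at module scope, and flip to True only on the branch that actually imported the backing kernel, so call sites branch on a cheap boolean rather than re-testing imports. A hedged sketch of that pattern with a hypothetical `layer_norm_2d` dispatcher (the aiter import only resolves on a ROCm build with aiter installed):

```python
import os

import torch

# SGLANG_USE_AITER=1 enables the aiter path; only meaningful on HIP builds.
_use_aiter = os.getenv("SGLANG_USE_AITER", "false").lower() in ("1", "true")

# Module-scope capability flag, following the diff's pattern: default to
# False, set True only where the kernel was actually imported.
_has_aiter_layer_norm = False
if _use_aiter:
    from aiter import layernorm2d_fwd as layer_norm

    _has_aiter_layer_norm = True


def layer_norm_2d(x: torch.Tensor, weight, bias, eps: float) -> torch.Tensor:
    # Hypothetical dispatcher: take the aiter kernel when available,
    # otherwise fall back to the native PyTorch op.
    if _has_aiter_layer_norm:
        return layer_norm(x, weight, bias, eps)
    return torch.nn.functional.layer_norm(x, (x.shape[-1],), weight, bias, eps)
```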
@@ -428,7 +431,18 @@ def forward_hip(
         self,
         x: torch.Tensor,
     ) -> torch.Tensor:
-        return self.forward_native(x)
+        if (
+            _has_aiter_layer_norm
+            and x.dtype in (torch.bfloat16, torch.float16)
+            and x.dtype == self.dtype
+        ):
+            orig_shape = x.shape
+            x = x.reshape(-1, self.hidden_size)
+            return layer_norm(x, self.weight, self.bias, self.variance_epsilon).view(
+                orig_shape
+            )
+        else:
+            return self.forward_native(x)
 
     def forward_npu(
         self,
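The `forward_hip` dispatch only takes the aiter path when three conditions line up: the kernel was imported, the activation is half precision, and its dtype matches the dtype the module was constructed with; anything else falls back to `forward_native`. A hedged parity check for the fast path, assuming a ROCm machine with aiter installed and reusing the positional call signature shown in the diff:

```python
import torch
import torch.nn.functional as F

from aiter import layernorm2d_fwd as layer_norm  # signature as used in the diff

hidden_size, eps = 4096, 1e-6
x = torch.randn(8, hidden_size, dtype=torch.bfloat16, device="cuda")
weight = torch.rand(hidden_size, dtype=torch.bfloat16, device="cuda")
bias = torch.zeros(hidden_size, dtype=torch.bfloat16, device="cuda")

out_aiter = layer_norm(x, weight, bias, eps)

# fp32 reference via the native PyTorch op.
out_ref = F.layer_norm(x.float(), (hidden_size,), weight.float(), bias.float(), eps)

# Loose bf16 tolerances; the kernel should agree with the fp32 reference.
torch.testing.assert_close(out_aiter.float(), out_ref, rtol=2e-2, atol=2e-2)
```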