Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/sglang/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,8 @@ def _parse_quant_hf_config(self):
return quant_cfg

def _find_quant_modelslim_config(self):
if self.is_draft_model:
return None
quant_config_file = Path(self.model_path, "quant_model_description.json")
quant_cfg = None
if quant_config_file.is_file():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch
import torch_npu
from sgl_kernel_npu.norm.fused_split_qk_norm import fused_split_qk_norm

from sglang.srt.environ import envs
from sglang.srt.hardware_backend.npu.attention.mla_preprocess import (
Expand Down Expand Up @@ -323,46 +324,75 @@ def forward_dsa_prepare_npu(
)
else:
fused_qkv_a_proj_out = m.fused_qkv_a_proj_with_mqa(hidden_states)[0]
q, latent_cache = fused_qkv_a_proj_out.split(
[m.q_lora_rank, m.kv_lora_rank + m.qk_rope_head_dim], dim=-1
)

# overlap qk norm
q = m.q_a_layernorm(q)
if (
_use_ag_after_qlora
and layer_scatter_modes.layer_input_mode == ScatterMode.SCATTERED
and layer_scatter_modes.attn_mode == ScatterMode.TP_ATTN_FULL
):
q = scattered_to_tp_attn_full(q, forward_batch)
latent_cache = scattered_to_tp_attn_full(latent_cache, forward_batch)
q_lora = q.clone() # required for topk_indices

q_event = None
if m.alt_stream is not None:
m.alt_stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(m.alt_stream):
if m.rotary_emb.is_neox_style:
q, latent_cache = fused_qkv_a_proj_out.split(
[m.q_lora_rank, m.kv_lora_rank + m.qk_rope_head_dim], dim=-1
)
# overlap qk norm
q = m.q_a_layernorm(q)
if (
_use_ag_after_qlora
and layer_scatter_modes.layer_input_mode == ScatterMode.SCATTERED
and layer_scatter_modes.attn_mode == ScatterMode.TP_ATTN_FULL
):
q = scattered_to_tp_attn_full(q, forward_batch)
latent_cache = scattered_to_tp_attn_full(latent_cache, forward_batch)
q_lora = q.clone() # required for topk_indices

q_event = None
if m.alt_stream is not None:
m.alt_stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(m.alt_stream):
q = m.q_b_proj(q_lora)[0].view(-1, m.num_local_heads, m.qk_head_dim)
# record q to ensure memory space will not be released
q.record_stream(m.alt_stream)
q_event = m.alt_stream.record_event()
else:
q = m.q_b_proj(q_lora)[0].view(-1, m.num_local_heads, m.qk_head_dim)
# record q to ensure memory space will not be released
q.record_stream(m.alt_stream)
q_event = m.alt_stream.record_event()

k_nope, k_pe = latent_cache.unsqueeze(1).split(
[m.kv_lora_rank, m.qk_rope_head_dim], dim=-1
)
k_nope = m.kv_a_layernorm(k_nope)
# main stream waits for the completion of the event on the alt stream to ensure data dependency is complete
if q_event is not None:
torch.npu.current_stream().wait_event(q_event)
else:
q = m.q_b_proj(q_lora)[0].view(-1, m.num_local_heads, m.qk_head_dim)
if fused_qkv_a_proj_out.shape[0] < 65535:
q_lora, k_nope, k_pe = fused_split_qk_norm(
fused_qkv_a_proj_out,
m.q_a_layernorm,
m.kv_a_layernorm,
m.q_lora_rank,
m.kv_lora_rank,
m.qk_rope_head_dim,
eps=m.q_a_layernorm.variance_epsilon,
)
else:
q, latent_cache = fused_qkv_a_proj_out.split(
[m.q_lora_rank, m.kv_lora_rank + m.qk_rope_head_dim], dim=-1
)
# overlap qk norm
q = m.q_a_layernorm(q)

k_nope, k_pe = latent_cache.unsqueeze(1).split(
[m.kv_lora_rank, m.qk_rope_head_dim], dim=-1
)
k_nope = m.kv_a_layernorm(k_nope)
# main stream waits for the completion of the event on the alt stream to ensure data dependency is complete
if q_event is not None:
torch.npu.current_stream().wait_event(q_event)
q_lora = q.clone() # required for topk_indices
k_nope, k_pe = latent_cache.unsqueeze(1).split(
[m.kv_lora_rank, m.qk_rope_head_dim], dim=-1
)
k_nope = m.kv_a_layernorm(k_nope)
q = m.q_b_proj(q_lora)[0].view(-1, m.num_local_heads, m.qk_head_dim)

q_nope, q_pe = q.split([m.qk_nope_head_dim, m.qk_rope_head_dim], dim=-1)

q_nope_out = torch.bmm(q_nope.transpose(0, 1), m.w_kc)

q_nope_out = q_nope_out.transpose(0, 1)

if m.layer_id == 0:
m.rotary_emb.sin_cos_cache = m.rotary_emb.cos_sin_cache.index_select(
0, positions
)

q_pe, k_pe = m.rotary_emb(positions, q_pe, k_pe)

if nsa_use_prefill_cp(forward_batch):
Expand Down
146 changes: 99 additions & 47 deletions python/sglang/srt/layers/attention/nsa/nsa_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,22 +1254,48 @@ def forward_npu(
and not forward_batch.forward_mode.is_draft_extend()
)

cos_sin = self.rotary_emb.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
cos = cos.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
sin = sin.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)

bs = q_lora.shape[0]
if self.alt_stream is not None:
self.alt_stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(self.alt_stream):

if self.rotary_emb.is_neox_style:
if not hasattr(forward_batch, "npu_indexer_sin_cos_cache"):
cos_sin = self.rotary_emb.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
cos = cos.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
sin = sin.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
forward_batch.npu_indexer_sin_cos_cache = (sin, cos)
else:
sin, cos = forward_batch.npu_indexer_sin_cos_cache

if self.alt_stream is not None:
self.alt_stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(self.alt_stream):
q_lora = (
(q_lora, dynamic_scale) if dynamic_scale is not None else q_lora
)
q = self.wq_b(q_lora)[
0
] # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128]
wq_b_event = self.alt_stream.record_event()
q = q.view(bs, self.n_heads, self.head_dim) # [bs, 64, 128]
q_pe, q_nope = torch.split(
q,
[self.rope_head_dim, self.head_dim - self.rope_head_dim],
dim=-1,
) # [bs, 64, 64 + 64]
q_pe = q_pe.view(bs, self.n_heads, 1, self.rope_head_dim)
q_pe = torch_npu.npu_rotary_mul(q_pe, cos, sin).view(
bs, self.n_heads, self.rope_head_dim
) # [bs, n, d]
q = torch.cat([q_pe, q_nope], dim=-1)
q.record_stream(self.alt_stream)
q_rope_event = self.alt_stream.record_event()
else:
q_lora = (
(q_lora, dynamic_scale) if dynamic_scale is not None else q_lora
)
q = self.wq_b(q_lora)[
0
] # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128]
wq_b_event = self.alt_stream.record_event()
q = q.view(bs, self.n_heads, self.head_dim) # [bs, 64, 128]
q_pe, q_nope = torch.split(
q,
Expand All @@ -1281,9 +1307,52 @@ def forward_npu(
bs, self.n_heads, self.rope_head_dim
) # [bs, n, d]
q = torch.cat([q_pe, q_nope], dim=-1)
q.record_stream(self.alt_stream)
q_rope_event = self.alt_stream.record_event()

if envs.SGLANG_NPU_USE_MULTI_STREAM.get():
indexer_weight_stream = get_indexer_weight_stream()
indexer_weight_stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(indexer_weight_stream):
x = x.view(-1, self.hidden_size)
weights = self.weights_proj(x.float())[0].to(torch.bfloat16)
weights.record_stream(indexer_weight_stream)
weights_event = indexer_weight_stream.record_event()
else:
x = x.view(-1, self.hidden_size)
weights = self.weights_proj(x.float())[0].to(torch.bfloat16)

k_proj = self.wk(x)[0] # [b, s, 7168] @ [7168, 128] = [b, s, 128]
k = self.k_norm(k_proj)
if (
_use_ag_after_qlora
and layer_scatter_modes.layer_input_mode == ScatterMode.SCATTERED
and layer_scatter_modes.attn_mode == ScatterMode.TP_ATTN_FULL
):
k = scattered_to_tp_attn_full(k, forward_batch)
k_pe, k_nope = torch.split(
k,
[self.rope_head_dim, self.head_dim - self.rope_head_dim],
dim=-1,
) # [bs, 64 + 64]

k_pe = k_pe.view(-1, 1, 1, self.rope_head_dim)
k_pe = torch.ops.npu.npu_rotary_mul(k_pe, cos, sin).view(
bs, 1, self.rope_head_dim
) # [bs, 1, d]
k = torch.cat([k_pe, k_nope.unsqueeze(1)], dim=-1) # [bs, 1, 128]

else:
if envs.SGLANG_NPU_USE_MULTI_STREAM.get():
indexer_weight_stream = get_indexer_weight_stream()
indexer_weight_stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(indexer_weight_stream):
x = x.view(-1, self.hidden_size)
weights = self.weights_proj(x.float())[0].to(torch.bfloat16)
weights.record_stream(indexer_weight_stream)
weights_event = indexer_weight_stream.record_event()
else:
x = x.view(-1, self.hidden_size)
weights = self.weights_proj(x.float())[0].to(torch.bfloat16)

q_lora = (q_lora, dynamic_scale) if dynamic_scale is not None else q_lora
q = self.wq_b(q_lora)[0] # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128]
q = q.view(bs, self.n_heads, self.head_dim) # [bs, 64, 128]
Expand All @@ -1292,43 +1361,26 @@ def forward_npu(
[self.rope_head_dim, self.head_dim - self.rope_head_dim],
dim=-1,
) # [bs, 64, 64 + 64]
q_pe = q_pe.view(bs, self.n_heads, 1, self.rope_head_dim)
q_pe = torch_npu.npu_rotary_mul(q_pe, cos, sin).view(
bs, self.n_heads, self.rope_head_dim
) # [bs, n, d]
q = torch.cat([q_pe, q_nope], dim=-1)

if envs.SGLANG_NPU_USE_MULTI_STREAM.get():
indexer_weight_stream = get_indexer_weight_stream()
indexer_weight_stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(indexer_weight_stream):
x = x.view(-1, self.hidden_size)
weights = self.weights_proj(x.float())[0].to(torch.bfloat16)
weights.record_stream(indexer_weight_stream)
weights_event = indexer_weight_stream.record_event()
else:
x = x.view(-1, self.hidden_size)
weights = self.weights_proj(x.float())[0].to(torch.bfloat16)
k_proj = self.wk(x)[0] # [b, s, 7168] @ [7168, 128] = [b, s, 128]
k = self.k_norm(k_proj)
k_pe, k_nope = torch.split(
k,
[self.rope_head_dim, self.head_dim - self.rope_head_dim],
dim=-1,
) # [bs, 64 + 64]

k_proj = self.wk(x)[0] # [b, s, 7168] @ [7168, 128] = [b, s, 128]
k = self.k_norm(k_proj)
if (
_use_ag_after_qlora
and layer_scatter_modes.layer_input_mode == ScatterMode.SCATTERED
and layer_scatter_modes.attn_mode == ScatterMode.TP_ATTN_FULL
):
k = scattered_to_tp_attn_full(k, forward_batch)
k_pe, k_nope = torch.split(
k,
[self.rope_head_dim, self.head_dim - self.rope_head_dim],
dim=-1,
) # [bs, 64 + 64]

k_pe = k_pe.view(-1, 1, 1, self.rope_head_dim)
k_pe = torch.ops.npu.npu_rotary_mul(k_pe, cos, sin).view(
bs, 1, self.rope_head_dim
) # [bs, 1, d]
k = torch.cat([k_pe, k_nope.unsqueeze(1)], dim=-1) # [bs, 1, 128]
k_pe = k_pe.unsqueeze(1)

if layer_id == 0:
self.rotary_emb.sin_cos_cache = (
self.rotary_emb.cos_sin_cache.index_select(0, positions)
)

q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
k_pe = k_pe.squeeze(1)
q = torch.cat([q_pe, q_nope], dim=-1)
k = torch.cat([k_pe, k_nope], dim=-1)

if (
is_prefill
Expand Down Expand Up @@ -1394,7 +1446,7 @@ def forward_npu(

past_key_states = forward_batch.token_to_kv_pool.get_index_k_buffer(layer_id)

if self.alt_stream is not None:
if self.rotary_emb.is_neox_style and self.alt_stream is not None:
torch.npu.current_stream().wait_event(q_rope_event)
if envs.SGLANG_NPU_USE_MULTI_STREAM.get():
torch.npu.current_stream().wait_event(weights_event)
Expand Down
30 changes: 27 additions & 3 deletions python/sglang/srt/layers/rotary_embedding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

if _is_npu:
import torch_npu
from sgl_kernel_npu.norm.fused_rope_qk_mqa import fused_rope_qk_mqa


class RotaryEmbedding(MultiPlatformOp):
Expand Down Expand Up @@ -202,9 +203,14 @@ def forward_native(

if offsets is not None:
positions = positions + offsets

positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)

if hasattr(self, "sin_cos_cache"):
cos_sin = self.sin_cos_cache
else:
cos_sin = self.cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)

query_shape = query.shape
Expand Down Expand Up @@ -236,8 +242,26 @@ def forward_npu(
assert (
fused_set_kv_buffer_arg is None
), "fused_set_kv_buffer_arg is not supported for npu implementation"
if query.dtype == torch.bfloat16 and self.cos_sin_cache.dtype == torch.float:
return self.forward_native(positions, query, key, offsets)
if (
query.dtype == torch.bfloat16
and self.cos_sin_cache.dtype == torch.float
or key.ndim == 3
):
if hasattr(self, "sin_cos_cache"):
cos_sin = self.sin_cos_cache
else:
cos_sin = self.cos_sin_cache.index_select(0, positions)

if query.shape[0] * query.shape[1] < 65535:
return fused_rope_qk_mqa(
query,
key,
cos_sin,
self.rotary_dim,
self.is_neox_style,
)
else:
return self.forward_native(positions, query, key, offsets)
if self.is_neox_style:
rotary_mode = "half"
else:
Expand Down
Loading
Loading