Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

import torch
import torch_npu
from sgl_kernel_npu.attention.sinks_attention import (
attention_sinks_prefill_triton,
attention_sinks_triton,
)

from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.hardware_backend.npu.attention.mla_preprocess import (
Expand Down Expand Up @@ -260,9 +264,17 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
// self.page_size
)
if forward_batch.extend_seq_lens is not None:
self.forward_metadata.extend_seq_lens = forward_batch.extend_seq_lens
self.forward_metadata.extend_seq_lens_cpu_int = (
forward_batch.extend_seq_lens.cpu().int()
)
if forward_batch.seq_lens is not None:
self.forward_metadata.seq_lens = forward_batch.seq_lens.int()
else:
self.forward_metadata.seq_lens = forward_batch.seq_lens_cpu.to(
self.device
).int()

self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
if (
not forward_batch.forward_mode.is_draft_extend_v2()
Expand Down Expand Up @@ -576,6 +588,7 @@ def forward_extend(
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None,
):
if topk_indices is not None:
return self.forward_sparse(
Expand Down Expand Up @@ -617,6 +630,22 @@ def forward_extend(
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)

if sinks is not None:
attn_out = attention_sinks_prefill_triton(
q,
k_cache,
v_cache,
sinks,
self.forward_metadata.extend_seq_lens,
self.forward_metadata.block_tables,
self.forward_metadata.seq_lens,
layer.scaling,
layer.sliding_window_size,
layer.tp_q_head_num,
layer.tp_k_head_num,
)
return attn_out

if self.use_fia:
"""FIA will support multi-bs in the later version of CANN"""
q = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
Expand Down Expand Up @@ -1036,6 +1065,7 @@ def forward_decode_graph(
save_kv_cache: bool = True,
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None,
):
if save_kv_cache:
if self.use_mla:
Expand All @@ -1049,6 +1079,24 @@ def forward_decode_graph(
layer, forward_batch.out_cache_loc, k, v
)

if sinks is not None:
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)

attn_out = attention_sinks_triton(
q,
k_cache,
v_cache,
sinks,
self.forward_metadata.block_tables,
self.forward_metadata.seq_lens,
layer.scaling,
layer.sliding_window_size,
layer.tp_q_head_num,
layer.tp_k_head_num,
)
return attn_out

if not self.use_mla:
num_tokens = q.shape[0]
"""PA will support bs<tp in the later version of CANN"""
Expand Down Expand Up @@ -1217,6 +1265,7 @@ def forward_decode(
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None,
):
if is_mla_preprocess_enabled():
# MLAPO does saving kv_cache
Expand Down Expand Up @@ -1244,6 +1293,7 @@ def forward_decode(
save_kv_cache,
q_rope=q_rope,
k_rope=k_rope,
sinks=sinks,
)

if not self.use_mla:
Expand All @@ -1254,6 +1304,22 @@ def forward_decode(
num_tokens = q.shape[0]
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)

if sinks is not None:
attn_out = attention_sinks_triton(
q,
k_cache,
v_cache,
sinks,
self.forward_metadata.block_tables,
self.forward_metadata.seq_lens,
layer.scaling,
layer.sliding_window_size,
layer.tp_q_head_num,
layer.tp_k_head_num,
)
return attn_out

if self.use_fia:
if self.forward_metadata.seq_lens_cpu_int is None:
actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
Expand Down
10 changes: 9 additions & 1 deletion python/sglang/srt/layers/quantization/unquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,8 @@ def forward_npu(
)

expert_tokens = expert_tokens.to(torch.int64)
w13_bias = [layer.w13_weight_bias] if self.with_bias else None
w2_bias = [layer.w2_weight_bias] if self.with_bias else None
if layer.w13_weight.shape[-1] == layer.hidden_size:
w13 = layer.w13_weight.transpose(1, 2)
w2 = layer.w2_weight.transpose(1, 2)
Expand All @@ -500,6 +502,7 @@ def forward_npu(
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w13],
bias=w13_bias,
split_item=2,
group_list_type=0,
group_type=0,
Expand All @@ -508,7 +511,11 @@ def forward_npu(
)[0]

# act_fn:
if self.moe_runner_config.activation == "silu":
if self.moe_runner_config.activation == "npu_swiglu_oai":
from sgl_kernel_npu.activation.swiglu_oai import swiglu_oai

hidden_states = swiglu_oai(layer, hidden_states)
elif self.moe_runner_config.activation == "silu":
hidden_states = torch_npu.npu_swiglu(hidden_states)
else:
from sglang.srt.layers.activation import GeluAndMul
Expand All @@ -519,6 +526,7 @@ def forward_npu(
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
bias=w2_bias,
split_item=2,
group_list_type=0,
group_type=0,
Expand Down
35 changes: 20 additions & 15 deletions python/sglang/srt/models/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,10 @@
enable_fused_set_kv_buffer,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import LazyValue, add_prefix, is_cuda, make_layers
from sglang.srt.utils import LazyValue, add_prefix, is_cuda, is_npu, make_layers

_is_cuda = is_cuda()
_is_npu = is_npu()


if _is_cuda:
Expand Down Expand Up @@ -129,6 +130,7 @@ def __init__(
"use_weight_loader_fused": quant_config_name
!= "mxfp4"
}

self.experts = experts_type(
num_experts=config.num_local_experts
+ get_global_server_args().ep_num_redundant_experts,
Expand Down Expand Up @@ -305,20 +307,20 @@ def forward_prepare(
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

q, k = self.rotary_emb(
positions,
q,
k,
fused_set_kv_buffer_arg=(
create_fused_set_kv_buffer_arg(
value=v,
layer=self.attn,
forward_batch=forward_batch,
)
if enable_fused_set_kv_buffer(forward_batch)
else None
),
)
extra_args = {}
if not _is_npu:
extra_args = {
"fused_set_kv_buffer_arg": (
create_fused_set_kv_buffer_arg(
value=v,
layer=self.attn,
forward_batch=forward_batch,
)
if enable_fused_set_kv_buffer(forward_batch)
else None
),
}
q, k = self.rotary_emb(positions, q, k, **extra_args)
inner_state = q, k, v, forward_batch
return None, forward_batch, inner_state

Expand Down Expand Up @@ -490,6 +492,9 @@ def __init__(
self.vocab_size = config.vocab_size
self.pp_group = get_pp_group()

if is_npu:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be _is_npu, otherwise this condition will always be True and will hinder running on other backends.
@Todobe can you fix it?

config.hidden_act = "npu_swiglu_oai"

if self.pp_group.is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,7 +1250,7 @@ def _handle_model_specific_adjustments(self):
else:
self.attention_backend = "triton"

supported_backends = ["triton", "trtllm_mha", "fa3", "fa4"]
supported_backends = ["triton", "trtllm_mha", "fa3", "fa4", "ascend"]
prefill_attn_backend, decode_attn_backend = self.get_attention_backends()
assert (
prefill_attn_backend in supported_backends
Expand Down
Loading