From 1d307a1f923b6fa31b6b2076ad0958c2c365bfda Mon Sep 17 00:00:00 2001 From: Todobe Date: Thu, 27 Nov 2025 16:19:29 +0800 Subject: [PATCH 1/9] adapt gpt-oss-120b --- .../srt/layers/attention/ascend_backend.py | 57 +++ .../attention/triton_ops/sinks_attention.py | 374 ++++++++++++++++++ .../srt/layers/moe/fused_moe_triton/layer.py | 4 +- .../sglang/srt/layers/moe/moe_runner/base.py | 1 + .../sglang/srt/layers/quantization/unquant.py | 10 +- python/sglang/srt/models/gpt_oss.py | 54 ++- python/sglang/srt/server_args.py | 2 +- 7 files changed, 482 insertions(+), 20 deletions(-) create mode 100644 python/sglang/srt/layers/attention/triton_ops/sinks_attention.py diff --git a/python/sglang/srt/layers/attention/ascend_backend.py b/python/sglang/srt/layers/attention/ascend_backend.py index 458e98016f5e..155c4b4ff737 100644 --- a/python/sglang/srt/layers/attention/ascend_backend.py +++ b/python/sglang/srt/layers/attention/ascend_backend.py @@ -10,6 +10,10 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess_enabled from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend +from sglang.srt.layers.attention.triton_ops.sinks_attention import( + attention_sinks_prefill_triton, + attention_sinks_triton, +) from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.radix_attention import AttentionType from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode @@ -460,6 +464,7 @@ def forward_extend( q_rope: Optional[torch.Tensor] = None, k_rope: Optional[torch.Tensor] = None, topk_indices: Optional[torch.Tensor] = None, + sinks: Optional[torch.Tensor] = None, ): if topk_indices is not None: return self.forward_sparse( @@ -501,6 +506,21 @@ def forward_extend( k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) + if sinks is not None: + attn_out = attention_sinks_prefill_triton( + q, + k_cache, + v_cache, + sinks, + self.forward_metadata.block_tables, + self.forward_metadata.seq_lens_cpu_int, + layer.scaling, + layer.sliding_window_size, + layer.tp_q_head_num, + layer.tp_k_head_num, + ) + return attn_out + if self.use_fia: """FIA will support multi-bs in the later version of CANN""" q = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim) @@ -755,6 +775,7 @@ def forward_decode_graph( save_kv_cache: bool = True, q_rope: Optional[torch.Tensor] = None, k_rope: Optional[torch.Tensor] = None, + sinks: Optional[torch.Tensor] = None, ): if save_kv_cache: if self.use_mla: @@ -767,6 +788,24 @@ def forward_decode_graph( forward_batch.token_to_kv_pool.set_kv_buffer( layer, forward_batch.out_cache_loc, k, v ) + + if sinks is not None: + k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) + + attn_out = attention_sinks_triton( + q, + k_cache, + v_cache, + sinks, + self.forward_metadata.block_tables, + self.forward_metadata.seq_lens, + layer.scaling, + layer.sliding_window_size, + layer.tp_q_head_num, + layer.tp_k_head_num, + ) + return attn_out if not self.use_mla: num_tokens = q.shape[0] @@ -927,6 +966,7 @@ def forward_decode( q_rope: Optional[torch.Tensor] = None, k_rope: Optional[torch.Tensor] = None, topk_indices: Optional[torch.Tensor] = None, + sinks: Optional[torch.Tensor] = None, ): if is_mla_preprocess_enabled(): # MLAPO does saving kv_cache @@ -954,6 +994,7 @@ def forward_decode( save_kv_cache, q_rope=q_rope, k_rope=k_rope, + sinks=sinks, ) if not self.use_mla: @@ -964,6 +1005,22 @@ def forward_decode( num_tokens = q.shape[0] k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) + + if sinks is not None: + attn_out =attention_sinks_triton( + q, + k_cache, + v_cache, + sinks, + self.forward_metadata.block_tables, + self.forward_metadata.seq_lens_cpu_int, + layer.scaling, + layer.sliding_window_size, + layer.tp_q_head_num, + layer.tp_k_head_num, + ) + return attn_out + if self.use_fia: attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score( q.view( diff --git a/python/sglang/srt/layers/attention/triton_ops/sinks_attention.py b/python/sglang/srt/layers/attention/triton_ops/sinks_attention.py new file mode 100644 index 000000000000..a280d7658aae --- /dev/null +++ b/python/sglang/srt/layers/attention/triton_ops/sinks_attention.py @@ -0,0 +1,374 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def attention_sinks_kernel( + query, + k_cache, + v_cache, + sinks, + attn_out, + block_tables, + kv_seq_lens, + scale, + sliding_window_size, + q_head_num: tl.constexpr, + k_head_num: tl.constexpr, + D: tl.constexpr, + PAGE_SIZE: tl.constexpr, + MAX_BLOCKS: tl.constexpr, + sync_space, +): + i_s, i_qh = tl.program_id(0), tl.program_id(1) + i_kvh = i_qh // (q_head_num // k_head_num) + + kv_seq_len = tl.load(kv_seq_lens + i_s) + page_num = tl.cdiv(kv_seq_len, PAGE_SIZE) + start_page_num = 0 + start_kv_len = 0 + if sliding_window_size != -1 and kv_seq_len > sliding_window_size: + start_kv_len = (kv_seq_len - sliding_window_size).to(tl.int32) + start_page_num = start_kv_len // PAGE_SIZE + + cur_page_start = i_s * MAX_BLOCKS + offset_page = tl.arange(0, PAGE_SIZE) + offset_d = tl.arange(0, D) + Br: tl.constexpr = 1 + + sink = tl.load(sinks + i_qh) + history_max = tl.zeros([Br], dtype=tl.float32) + sink + l = tl.zeros([Br], dtype=tl.float32) + acc = tl.zeros([Br, D], dtype=tl.float32) + + offset_q = i_qh * D + offset_d + offset_seq = (tl.arange(0, Br) + i_s) * D * q_head_num + q = tl.load(query + offset_seq[:, None] + offset_q[None, :]).to(tl.float32) + + for page_idx in range(start_page_num, page_num): + block_idx = tl.load(block_tables + cur_page_start + page_idx) + mask_page = ((page_idx * PAGE_SIZE + offset_page) < kv_seq_len) & ((page_idx * PAGE_SIZE + offset_page) >= start_kv_len) + + offset_k = ( + block_idx * PAGE_SIZE * k_head_num * D + + offset_page[:, None] * k_head_num * D + + i_kvh * D + + offset_d[None, :] + ) + k = tl.load(k_cache + offset_k, mask=mask_page[:, None]).to(tl.float32) + v = tl.load(v_cache + offset_k, mask=mask_page[:, None]).to(tl.float32) + + k = tl.trans(k, (1, 0)) + qk = tl.dot(q, k) + qk = qk * scale + qk = tl.where(mask_page[None, :], qk, float("-inf")) + + new_e_max = tl.maximum(tl.max(qk, 1), history_max) + re_scale = tl.exp(history_max - new_e_max) + p_exp = tl.exp(qk - new_e_max[:, None]) + + # Online softmax update + l = l * re_scale + tl.sum(p_exp, 1) + acc = acc * re_scale[:, None] + tl.dot(p_exp.to(v.dtype), v) + tl.store(sync_space + tl.arange(0, Br), new_e_max) + history_max = new_e_max + + sink = tl.math.exp(sink - history_max) + l = l + sink + acc = acc / l[:, None] + tl.store(attn_out + offset_seq[:, None] + offset_q[None, :], acc.to(attn_out.type.element_ty)) + + +def attention_sinks_triton( + query, + k_cache, + v_cache, + sinks, + block_tables, + context_lens, + scale, + sliding_window_size, + q_head_num, + k_head_num, +): + S = query.shape[0] + D = query.shape[-1] // q_head_num + PAGE_SIZE = k_cache.shape[1] + v_head_dim = v_cache.shape[-1] + attn_output = torch.zeros( + (S, q_head_num, v_head_dim), + dtype=query.dtype, + device=query.device, + ) + sync_space = torch.empty( + (PAGE_SIZE,), + dtype=torch.float32, + device=query.device, + ) + + if isinstance(context_lens, list): + context_lens = torch.tensor(context_lens, device=query.device) + else: + context_lens = context_lens.to(query.device) + + grid = [S, q_head_num] + attention_sinks_kernel[grid]( + query, + k_cache, + v_cache, + sinks, + attn_output, + block_tables, + context_lens, + scale, + sliding_window_size, + q_head_num, + k_head_num, + D, + PAGE_SIZE, + block_tables.stride(0), + sync_space, + ) + + return attn_output.reshape(-1, q_head_num * v_head_dim) + + +@triton.jit +def attention_sinks_prefill_kernel( + query, + k_cache, + v_cache, + sinks, + attn_out, + block_tables, + kv_seq_lens, + scale, + sliding_window_size, + q_head_num: tl.constexpr, + k_head_num: tl.constexpr, + D: tl.constexpr, + PAGE_SIZE: tl.constexpr, + MAX_BLOCKS: tl.constexpr, + B: tl.constexpr, + BS: tl.constexpr, + sync_space, +): + i_ns, i_qh = tl.program_id(0), tl.program_id(1) + i_kvh = i_qh // (q_head_num // k_head_num) + + for i_bs in range(BS): + i_s = i_ns * BS + i_bs + + i_pos = -1 + kv_seq_len = i_s + + for i in range(B): + tmp_seq_len = tl.load(kv_seq_lens + i) + if kv_seq_len >= tmp_seq_len and i_pos == -1: + kv_seq_len -= tmp_seq_len + elif i_pos == -1: + i_pos = i + + if i_pos != -1: + kv_seq_len += 1 + + page_num = tl.cdiv(kv_seq_len, PAGE_SIZE) + start_page_num = 0 + start_kv_len = 0 + if sliding_window_size != -1 and kv_seq_len > sliding_window_size: + start_kv_len = (kv_seq_len - sliding_window_size).to(tl.int32) + start_page_num = start_kv_len // PAGE_SIZE + + cur_page_start = i_pos * MAX_BLOCKS + offset_page = tl.arange(0, PAGE_SIZE) + offset_d = tl.arange(0, D) + Br: tl.constexpr = 1 + + sink = tl.load(sinks + i_qh) + history_max = tl.zeros([Br], dtype=tl.float32) + sink + l = tl.zeros([Br], dtype=tl.float32) + acc = tl.zeros([Br, D], dtype=tl.float32) + + offset_q = i_qh * D + offset_d + offset_seq = (tl.arange(0, Br) + i_s) * D * q_head_num + q = tl.load(query + offset_seq[:, None] + offset_q[None, :]).to(tl.float32) + + for page_idx in range(start_page_num, page_num): + block_idx = tl.load(block_tables + cur_page_start + page_idx) + mask_page = ((page_idx * PAGE_SIZE + offset_page) < kv_seq_len) & ((page_idx * PAGE_SIZE + offset_page) >= start_kv_len) + + offset_k = ( + block_idx * PAGE_SIZE * k_head_num * D + + offset_page[:, None] * k_head_num * D + + i_kvh * D + + offset_d[None, :] + ) + k = tl.load(k_cache + offset_k, mask=mask_page[:, None]).to(tl.float32) + v = tl.load(v_cache + offset_k, mask=mask_page[:, None]).to(tl.float32) + + k = tl.trans(k, (1, 0)) + qk = tl.dot(q, k) + qk = qk * scale + qk = tl.where(mask_page[None, :], qk, float("-inf")) + + new_e_max = tl.maximum(tl.max(qk, 1), history_max) + re_scale = tl.exp(history_max - new_e_max) + p_exp = tl.exp(qk - new_e_max[:, None]) + + # Online softmax update + l = l * re_scale + tl.sum(p_exp, 1) + acc = acc * re_scale[:, None] + tl.dot(p_exp.to(v.dtype), v) + tl.store(sync_space + tl.arange(0, Br), new_e_max) + history_max = new_e_max + + sink = tl.math.exp(sink - history_max) + l = l + sink + acc = acc / l[:, None] + tl.store(attn_out + offset_seq[:, None] + offset_q[None, :], acc.to(attn_out.type.element_ty)) + + + +def attention_sinks_prefill_triton( + query, + k_cache, + v_cache, + sinks, + block_tables, + context_lens, + scale, + sliding_window_size, + q_head_num, + k_head_num, +): + S = query.shape[0] + kernel_num = 40 + BS = triton.cdiv(S, kernel_num) + NS = triton.cdiv(S, BS) + + D = query.shape[-1] // q_head_num + PAGE_SIZE = k_cache.shape[1] + v_head_dim = v_cache.shape[-1] + attn_output = torch.zeros( + (S, q_head_num, v_head_dim), + dtype=query.dtype, + device=query.device, + ) + sync_space = torch.empty( + (PAGE_SIZE,), + dtype=torch.float32, + device=query.device, + ) + + if isinstance(context_lens, list): + context_lens = torch.tensor(context_lens, device=query.device) + else: + context_lens = context_lens.to(query.device) + B = context_lens.shape[0] + + grid = [NS, q_head_num] + attention_sinks_prefill_kernel[grid]( + query, + k_cache, + v_cache, + sinks, + attn_output, + block_tables, + context_lens, + scale, + sliding_window_size, + q_head_num, + k_head_num, + D, + PAGE_SIZE, + block_tables.stride(0), + B, + BS, + sync_space, + ) + + return attn_output.reshape(-1, q_head_num * v_head_dim) + + + +def native_attn_sinks( + q, + k, + v, + scale, + sinks, +): + seq_len_q, q_head, group, _ = q.shape + qk = torch.einsum("qhgd,khd->hgqk", q, k) + qk = qk * scale + sinks = sinks.reshape(q_head, group, 1, 1).expand(q_head, group, seq_len_q, 1) + qk = torch.cat([qk, sinks], dim=-1) + w = torch.nn.functional.softmax(qk, dim=-1) + w = w[..., :-1] + out = torch.einsum("hgqk,khd->qhgd", w, v) + return out + + +def native_gqa_sinks( + query, + k_cache, + v_cache, + sinks, + block_tables, + context_lens, + scale, + sliding_window_size, + q_head_num, + k_head_num, + is_extend, +): + group_size = q_head_num // k_head_num + q_dim = query.shape[-1] // q_head_num + k_dim = k_cache.shape[-1] + v_dim = v_cache.shape[-1] + page_size = k_cache.shape[1] + out = [] + last = 0 + for i in range(len(context_lens)): + k = [] + v = [] + seq_len = context_lens[i].item() + for page in block_tables[i]: + idx = min(seq_len, page_size) + k.append(k_cache[page][:idx]) + v.append(v_cache[page][:idx]) + if seq_len <= page_size: + break + seq_len -= page_size + + k = torch.cat(k, dim=0) + v = torch.cat(v, dim=0) + + if is_extend: + q = query[last : last + context_lens[i]] + last += context_lens[i] + else: + q = query[last : last + 1] + last += 1 + + o = [] + for idx in range(q.shape[0]): + q_ = q[idx : idx + 1] + k_ = k[: idx + 1] if is_extend else k + v_ = v[: idx + 1] if is_extend else v + if sliding_window_size != -1 and sliding_window_size < k_.shape[0]: + k_ = k_[-sliding_window_size:] + v_ = v_[-sliding_window_size:] + q_ = q_.view(-1, k_head_num, group_size, q_dim) + k_ = k_.view(-1, k_head_num, k_dim) + v_ = v_.view(-1, k_head_num, v_dim) + o_ = native_attn_sinks(q_, k_, v_, scale, sinks) + o.append(o_) + + o = torch.cat(o, dim=0) + out.append(o) + + out = torch.cat(out, dim=0) + return out.reshape(-1, q_head_num * v_dim) + + diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 224ab3c7f1e6..6bd72af10c1d 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -2,7 +2,7 @@ import logging from enum import Enum -from typing import List, Optional, Tuple +from typing import Callable, List, Optional, Tuple import torch @@ -161,6 +161,7 @@ def __init__( with_bias=False, routing_method_type: Optional[RoutingMethodType] = None, is_gated: bool = True, + custom_act_fn: Optional[Callable] = None, ): super().__init__() if params_dtype is None: @@ -221,6 +222,7 @@ def __init__( gemm1_alpha=gemm1_alpha, gemm1_clamp_limit=gemm1_clamp_limit, is_gated=is_gated, + custom_act_fn=custom_act_fn, ) self.quant_method: Optional[FusedMoEMethodBase] = None diff --git a/python/sglang/srt/layers/moe/moe_runner/base.py b/python/sglang/srt/layers/moe/moe_runner/base.py index 118206a904d7..4e4552e58835 100644 --- a/python/sglang/srt/layers/moe/moe_runner/base.py +++ b/python/sglang/srt/layers/moe/moe_runner/base.py @@ -37,6 +37,7 @@ class MoeRunnerConfig: # Runner configuration activation: str = "silu" is_gated: bool = True + custom_act_fn: Optional[Callable] = None apply_router_weight_on_input: bool = False inplace: bool = True no_combine: bool = False diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index 67c65d5f3664..57fcd7b6aa67 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -423,6 +423,8 @@ def forward_npu( ) expert_tokens = expert_tokens.to(torch.int64) + w13_bias = layer.w13_weight_bias + w2_bias = layer.w2_weight_bias if layer.w13_weight.shape[-1] == layer.hidden_size: w13 = layer.w13_weight.transpose(1, 2) w2 = layer.w2_weight.transpose(1, 2) @@ -431,6 +433,7 @@ def forward_npu( hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], weight=[w13], + bias=[w13_bias], split_item=2, group_list_type=0, group_type=0, @@ -439,9 +442,9 @@ def forward_npu( )[0] # act_fn: - if self.moe_runner_config.activation == "silu": - hidden_states = torch_npu.npu_swiglu(hidden_states) - else: + if self.moe_runner_config.custom_act_fn is not None: + hidden_states = self.moe_runner_config.custom_act_fn(layer, hidden_states) + elif self.moe_runner_config.activation == "silu": from sglang.srt.layers.activation import GeluAndMul hidden_states = GeluAndMul()(hidden_states) @@ -450,6 +453,7 @@ def forward_npu( hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], weight=[w2], + bias=[w2_bias], split_item=2, group_list_type=0, group_type=0, diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index 9474700c4342..6c2ec59e1cb1 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -70,14 +70,30 @@ enable_fused_set_kv_buffer, ) from sglang.srt.server_args import get_global_server_args -from sglang.srt.utils import LazyValue, add_prefix, is_cuda, make_layers +from sglang.srt.utils import LazyValue, add_prefix, is_cuda, is_npu, make_layers _is_cuda = is_cuda() +_is_npu = is_npu() if _is_cuda: from sgl_kernel import FusedSetKVBufferArg # noqa: F401 +# When using the Ascend backend, the silu activation function (torch_npu.npu_swiglu) +# caused precision issues with the GPT-OSS model; switching to a custom implementation +# of the swiglu activation function resolved the problem. +def _swiglu_oai(layer, hidden_states): + E, N, _ = layer.w13_weight.size() + gate_up = hidden_states.view(-1, N) + alpha = layer.moe_runner_config.gemm1_alpha + limit = layer.moe_runner_config.gemm1_clamp_limit + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=limit) + up = up.clamp(min=-limit, max=limit) + glu = gate * torch.sigmoid(gate * alpha) + gated_output = (up + 1) * glu + return gated_output + class GptOssConfig(PretrainedConfig): model_type = "gpt_oss" @@ -127,6 +143,13 @@ def __init__( "use_weight_loader_fused": quant_config_name != "mxfp4" } + custom_act_fn = None + if _is_npu: + if self.layer_id == 0: + logger.warning( + "Warning: GPT-OSS use custom activate function on FusedMoE when using ascend backend." + ) + custom_act_fn = _swiglu_oai self.experts = experts_type( num_experts=config.num_local_experts + get_global_server_args().ep_num_redundant_experts, @@ -140,6 +163,7 @@ def __init__( gemm1_clamp_limit=self.gemm1_clamp_limit, with_bias=True, prefix=add_prefix("experts", prefix), + custom_act_fn=custom_act_fn, **extra_kwargs, ) @@ -300,20 +324,20 @@ def forward_prepare( qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb( - positions, - q, - k, - fused_set_kv_buffer_arg=( - create_fused_set_kv_buffer_arg( - value=v, - layer=self.attn, - forward_batch=forward_batch, - ) - if enable_fused_set_kv_buffer(forward_batch) - else None - ), - ) + extra_args = {} + if _is_cuda: + extra_args = { + "fused_set_kv_buffer_arg" : ( + create_fused_set_kv_buffer_arg( + value=v, + layer=self.attn, + forward_batch=forward_batch, + ) + if enable_fused_set_kv_buffer(forward_batch) + else None + ), + } + q, k = self.rotary_emb(positions, q, k, **extra_args) inner_state = q, k, v, forward_batch return None, forward_batch, inner_state diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 5b13ac8016ab..ff1f7b6f65a5 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1066,7 +1066,7 @@ def _handle_model_specific_adjustments(self): else: self.attention_backend = "triton" - supported_backends = ["triton", "trtllm_mha", "fa3", "fa4"] + supported_backends = ["triton", "trtllm_mha", "fa3", "fa4", "ascend"] prefill_attn_backend, decode_attn_backend = self.get_attention_backends() assert ( prefill_attn_backend in supported_backends From fde8ed114bb56b2720675e7e08c1cdb3837a2afd Mon Sep 17 00:00:00 2001 From: Todobe Date: Mon, 1 Dec 2025 14:20:25 +0800 Subject: [PATCH 2/9] move sinks_attention to sgl-kernel-npu --- .../srt/layers/attention/ascend_backend.py | 12 +- .../attention/triton_ops/sinks_attention.py | 374 ------------------ python/sglang/srt/models/gpt_oss.py | 3 +- 3 files changed, 8 insertions(+), 381 deletions(-) delete mode 100644 python/sglang/srt/layers/attention/triton_ops/sinks_attention.py diff --git a/python/sglang/srt/layers/attention/ascend_backend.py b/python/sglang/srt/layers/attention/ascend_backend.py index 155c4b4ff737..59161aa49e45 100644 --- a/python/sglang/srt/layers/attention/ascend_backend.py +++ b/python/sglang/srt/layers/attention/ascend_backend.py @@ -5,15 +5,15 @@ import torch import torch_npu +from sgl_kernel_npu.attention.sinks_attention import ( + attention_sinks_prefill_triton, + attention_sinks_triton, +) from sglang.srt.configs.model_config import AttentionArch from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess_enabled from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend -from sglang.srt.layers.attention.triton_ops.sinks_attention import( - attention_sinks_prefill_triton, - attention_sinks_triton, -) from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.radix_attention import AttentionType from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode @@ -788,7 +788,7 @@ def forward_decode_graph( forward_batch.token_to_kv_pool.set_kv_buffer( layer, forward_batch.out_cache_loc, k, v ) - + if sinks is not None: k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) @@ -1007,7 +1007,7 @@ def forward_decode( v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) if sinks is not None: - attn_out =attention_sinks_triton( + attn_out = attention_sinks_triton( q, k_cache, v_cache, diff --git a/python/sglang/srt/layers/attention/triton_ops/sinks_attention.py b/python/sglang/srt/layers/attention/triton_ops/sinks_attention.py deleted file mode 100644 index a280d7658aae..000000000000 --- a/python/sglang/srt/layers/attention/triton_ops/sinks_attention.py +++ /dev/null @@ -1,374 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def attention_sinks_kernel( - query, - k_cache, - v_cache, - sinks, - attn_out, - block_tables, - kv_seq_lens, - scale, - sliding_window_size, - q_head_num: tl.constexpr, - k_head_num: tl.constexpr, - D: tl.constexpr, - PAGE_SIZE: tl.constexpr, - MAX_BLOCKS: tl.constexpr, - sync_space, -): - i_s, i_qh = tl.program_id(0), tl.program_id(1) - i_kvh = i_qh // (q_head_num // k_head_num) - - kv_seq_len = tl.load(kv_seq_lens + i_s) - page_num = tl.cdiv(kv_seq_len, PAGE_SIZE) - start_page_num = 0 - start_kv_len = 0 - if sliding_window_size != -1 and kv_seq_len > sliding_window_size: - start_kv_len = (kv_seq_len - sliding_window_size).to(tl.int32) - start_page_num = start_kv_len // PAGE_SIZE - - cur_page_start = i_s * MAX_BLOCKS - offset_page = tl.arange(0, PAGE_SIZE) - offset_d = tl.arange(0, D) - Br: tl.constexpr = 1 - - sink = tl.load(sinks + i_qh) - history_max = tl.zeros([Br], dtype=tl.float32) + sink - l = tl.zeros([Br], dtype=tl.float32) - acc = tl.zeros([Br, D], dtype=tl.float32) - - offset_q = i_qh * D + offset_d - offset_seq = (tl.arange(0, Br) + i_s) * D * q_head_num - q = tl.load(query + offset_seq[:, None] + offset_q[None, :]).to(tl.float32) - - for page_idx in range(start_page_num, page_num): - block_idx = tl.load(block_tables + cur_page_start + page_idx) - mask_page = ((page_idx * PAGE_SIZE + offset_page) < kv_seq_len) & ((page_idx * PAGE_SIZE + offset_page) >= start_kv_len) - - offset_k = ( - block_idx * PAGE_SIZE * k_head_num * D - + offset_page[:, None] * k_head_num * D - + i_kvh * D - + offset_d[None, :] - ) - k = tl.load(k_cache + offset_k, mask=mask_page[:, None]).to(tl.float32) - v = tl.load(v_cache + offset_k, mask=mask_page[:, None]).to(tl.float32) - - k = tl.trans(k, (1, 0)) - qk = tl.dot(q, k) - qk = qk * scale - qk = tl.where(mask_page[None, :], qk, float("-inf")) - - new_e_max = tl.maximum(tl.max(qk, 1), history_max) - re_scale = tl.exp(history_max - new_e_max) - p_exp = tl.exp(qk - new_e_max[:, None]) - - # Online softmax update - l = l * re_scale + tl.sum(p_exp, 1) - acc = acc * re_scale[:, None] + tl.dot(p_exp.to(v.dtype), v) - tl.store(sync_space + tl.arange(0, Br), new_e_max) - history_max = new_e_max - - sink = tl.math.exp(sink - history_max) - l = l + sink - acc = acc / l[:, None] - tl.store(attn_out + offset_seq[:, None] + offset_q[None, :], acc.to(attn_out.type.element_ty)) - - -def attention_sinks_triton( - query, - k_cache, - v_cache, - sinks, - block_tables, - context_lens, - scale, - sliding_window_size, - q_head_num, - k_head_num, -): - S = query.shape[0] - D = query.shape[-1] // q_head_num - PAGE_SIZE = k_cache.shape[1] - v_head_dim = v_cache.shape[-1] - attn_output = torch.zeros( - (S, q_head_num, v_head_dim), - dtype=query.dtype, - device=query.device, - ) - sync_space = torch.empty( - (PAGE_SIZE,), - dtype=torch.float32, - device=query.device, - ) - - if isinstance(context_lens, list): - context_lens = torch.tensor(context_lens, device=query.device) - else: - context_lens = context_lens.to(query.device) - - grid = [S, q_head_num] - attention_sinks_kernel[grid]( - query, - k_cache, - v_cache, - sinks, - attn_output, - block_tables, - context_lens, - scale, - sliding_window_size, - q_head_num, - k_head_num, - D, - PAGE_SIZE, - block_tables.stride(0), - sync_space, - ) - - return attn_output.reshape(-1, q_head_num * v_head_dim) - - -@triton.jit -def attention_sinks_prefill_kernel( - query, - k_cache, - v_cache, - sinks, - attn_out, - block_tables, - kv_seq_lens, - scale, - sliding_window_size, - q_head_num: tl.constexpr, - k_head_num: tl.constexpr, - D: tl.constexpr, - PAGE_SIZE: tl.constexpr, - MAX_BLOCKS: tl.constexpr, - B: tl.constexpr, - BS: tl.constexpr, - sync_space, -): - i_ns, i_qh = tl.program_id(0), tl.program_id(1) - i_kvh = i_qh // (q_head_num // k_head_num) - - for i_bs in range(BS): - i_s = i_ns * BS + i_bs - - i_pos = -1 - kv_seq_len = i_s - - for i in range(B): - tmp_seq_len = tl.load(kv_seq_lens + i) - if kv_seq_len >= tmp_seq_len and i_pos == -1: - kv_seq_len -= tmp_seq_len - elif i_pos == -1: - i_pos = i - - if i_pos != -1: - kv_seq_len += 1 - - page_num = tl.cdiv(kv_seq_len, PAGE_SIZE) - start_page_num = 0 - start_kv_len = 0 - if sliding_window_size != -1 and kv_seq_len > sliding_window_size: - start_kv_len = (kv_seq_len - sliding_window_size).to(tl.int32) - start_page_num = start_kv_len // PAGE_SIZE - - cur_page_start = i_pos * MAX_BLOCKS - offset_page = tl.arange(0, PAGE_SIZE) - offset_d = tl.arange(0, D) - Br: tl.constexpr = 1 - - sink = tl.load(sinks + i_qh) - history_max = tl.zeros([Br], dtype=tl.float32) + sink - l = tl.zeros([Br], dtype=tl.float32) - acc = tl.zeros([Br, D], dtype=tl.float32) - - offset_q = i_qh * D + offset_d - offset_seq = (tl.arange(0, Br) + i_s) * D * q_head_num - q = tl.load(query + offset_seq[:, None] + offset_q[None, :]).to(tl.float32) - - for page_idx in range(start_page_num, page_num): - block_idx = tl.load(block_tables + cur_page_start + page_idx) - mask_page = ((page_idx * PAGE_SIZE + offset_page) < kv_seq_len) & ((page_idx * PAGE_SIZE + offset_page) >= start_kv_len) - - offset_k = ( - block_idx * PAGE_SIZE * k_head_num * D - + offset_page[:, None] * k_head_num * D - + i_kvh * D - + offset_d[None, :] - ) - k = tl.load(k_cache + offset_k, mask=mask_page[:, None]).to(tl.float32) - v = tl.load(v_cache + offset_k, mask=mask_page[:, None]).to(tl.float32) - - k = tl.trans(k, (1, 0)) - qk = tl.dot(q, k) - qk = qk * scale - qk = tl.where(mask_page[None, :], qk, float("-inf")) - - new_e_max = tl.maximum(tl.max(qk, 1), history_max) - re_scale = tl.exp(history_max - new_e_max) - p_exp = tl.exp(qk - new_e_max[:, None]) - - # Online softmax update - l = l * re_scale + tl.sum(p_exp, 1) - acc = acc * re_scale[:, None] + tl.dot(p_exp.to(v.dtype), v) - tl.store(sync_space + tl.arange(0, Br), new_e_max) - history_max = new_e_max - - sink = tl.math.exp(sink - history_max) - l = l + sink - acc = acc / l[:, None] - tl.store(attn_out + offset_seq[:, None] + offset_q[None, :], acc.to(attn_out.type.element_ty)) - - - -def attention_sinks_prefill_triton( - query, - k_cache, - v_cache, - sinks, - block_tables, - context_lens, - scale, - sliding_window_size, - q_head_num, - k_head_num, -): - S = query.shape[0] - kernel_num = 40 - BS = triton.cdiv(S, kernel_num) - NS = triton.cdiv(S, BS) - - D = query.shape[-1] // q_head_num - PAGE_SIZE = k_cache.shape[1] - v_head_dim = v_cache.shape[-1] - attn_output = torch.zeros( - (S, q_head_num, v_head_dim), - dtype=query.dtype, - device=query.device, - ) - sync_space = torch.empty( - (PAGE_SIZE,), - dtype=torch.float32, - device=query.device, - ) - - if isinstance(context_lens, list): - context_lens = torch.tensor(context_lens, device=query.device) - else: - context_lens = context_lens.to(query.device) - B = context_lens.shape[0] - - grid = [NS, q_head_num] - attention_sinks_prefill_kernel[grid]( - query, - k_cache, - v_cache, - sinks, - attn_output, - block_tables, - context_lens, - scale, - sliding_window_size, - q_head_num, - k_head_num, - D, - PAGE_SIZE, - block_tables.stride(0), - B, - BS, - sync_space, - ) - - return attn_output.reshape(-1, q_head_num * v_head_dim) - - - -def native_attn_sinks( - q, - k, - v, - scale, - sinks, -): - seq_len_q, q_head, group, _ = q.shape - qk = torch.einsum("qhgd,khd->hgqk", q, k) - qk = qk * scale - sinks = sinks.reshape(q_head, group, 1, 1).expand(q_head, group, seq_len_q, 1) - qk = torch.cat([qk, sinks], dim=-1) - w = torch.nn.functional.softmax(qk, dim=-1) - w = w[..., :-1] - out = torch.einsum("hgqk,khd->qhgd", w, v) - return out - - -def native_gqa_sinks( - query, - k_cache, - v_cache, - sinks, - block_tables, - context_lens, - scale, - sliding_window_size, - q_head_num, - k_head_num, - is_extend, -): - group_size = q_head_num // k_head_num - q_dim = query.shape[-1] // q_head_num - k_dim = k_cache.shape[-1] - v_dim = v_cache.shape[-1] - page_size = k_cache.shape[1] - out = [] - last = 0 - for i in range(len(context_lens)): - k = [] - v = [] - seq_len = context_lens[i].item() - for page in block_tables[i]: - idx = min(seq_len, page_size) - k.append(k_cache[page][:idx]) - v.append(v_cache[page][:idx]) - if seq_len <= page_size: - break - seq_len -= page_size - - k = torch.cat(k, dim=0) - v = torch.cat(v, dim=0) - - if is_extend: - q = query[last : last + context_lens[i]] - last += context_lens[i] - else: - q = query[last : last + 1] - last += 1 - - o = [] - for idx in range(q.shape[0]): - q_ = q[idx : idx + 1] - k_ = k[: idx + 1] if is_extend else k - v_ = v[: idx + 1] if is_extend else v - if sliding_window_size != -1 and sliding_window_size < k_.shape[0]: - k_ = k_[-sliding_window_size:] - v_ = v_[-sliding_window_size:] - q_ = q_.view(-1, k_head_num, group_size, q_dim) - k_ = k_.view(-1, k_head_num, k_dim) - v_ = v_.view(-1, k_head_num, v_dim) - o_ = native_attn_sinks(q_, k_, v_, scale, sinks) - o.append(o_) - - o = torch.cat(o, dim=0) - out.append(o) - - out = torch.cat(out, dim=0) - return out.reshape(-1, q_head_num * v_dim) - - diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index 6c2ec59e1cb1..b96b30daa26e 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -79,6 +79,7 @@ if _is_cuda: from sgl_kernel import FusedSetKVBufferArg # noqa: F401 + # When using the Ascend backend, the silu activation function (torch_npu.npu_swiglu) # caused precision issues with the GPT-OSS model; switching to a custom implementation # of the swiglu activation function resolved the problem. @@ -327,7 +328,7 @@ def forward_prepare( extra_args = {} if _is_cuda: extra_args = { - "fused_set_kv_buffer_arg" : ( + "fused_set_kv_buffer_arg": ( create_fused_set_kv_buffer_arg( value=v, layer=self.attn, From 9212884a250c29a662e7c49234a19769e986bb14 Mon Sep 17 00:00:00 2001 From: Todobe Date: Mon, 1 Dec 2025 15:50:24 +0800 Subject: [PATCH 3/9] fix silu --- python/sglang/srt/layers/quantization/unquant.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index 57fcd7b6aa67..66d303fc234e 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -445,6 +445,8 @@ def forward_npu( if self.moe_runner_config.custom_act_fn is not None: hidden_states = self.moe_runner_config.custom_act_fn(layer, hidden_states) elif self.moe_runner_config.activation == "silu": + hidden_states = torch_npu.npu_swiglu(hidden_states) + else: from sglang.srt.layers.activation import GeluAndMul hidden_states = GeluAndMul()(hidden_states) From d215df7ded01df13544c4f91cd6b9876cd8ed222 Mon Sep 17 00:00:00 2001 From: Todobe Date: Thu, 4 Dec 2025 10:51:04 +0800 Subject: [PATCH 4/9] move swiglu_oai to sgl-kernel-npu --- python/sglang/srt/models/gpt_oss.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index b96b30daa26e..1ac544e1c128 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -79,21 +79,8 @@ if _is_cuda: from sgl_kernel import FusedSetKVBufferArg # noqa: F401 - -# When using the Ascend backend, the silu activation function (torch_npu.npu_swiglu) -# caused precision issues with the GPT-OSS model; switching to a custom implementation -# of the swiglu activation function resolved the problem. -def _swiglu_oai(layer, hidden_states): - E, N, _ = layer.w13_weight.size() - gate_up = hidden_states.view(-1, N) - alpha = layer.moe_runner_config.gemm1_alpha - limit = layer.moe_runner_config.gemm1_clamp_limit - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=limit) - up = up.clamp(min=-limit, max=limit) - glu = gate * torch.sigmoid(gate * alpha) - gated_output = (up + 1) * glu - return gated_output +if _is_npu: + from sgl_kernel_npu.moe.swiglu_oai import swiglu_oai class GptOssConfig(PretrainedConfig): @@ -150,7 +137,7 @@ def __init__( logger.warning( "Warning: GPT-OSS use custom activate function on FusedMoE when using ascend backend." ) - custom_act_fn = _swiglu_oai + custom_act_fn = swiglu_oai self.experts = experts_type( num_experts=config.num_local_experts + get_global_server_args().ep_num_redundant_experts, From 19a130eb0425664f1e5b63b368f984acc2dd6fa8 Mon Sep 17 00:00:00 2001 From: Todobe Date: Thu, 4 Dec 2025 14:06:05 +0800 Subject: [PATCH 5/9] fix swiglu_oai --- python/sglang/srt/models/gpt_oss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index 1ac544e1c128..5f1d2f95e156 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -80,7 +80,7 @@ from sgl_kernel import FusedSetKVBufferArg # noqa: F401 if _is_npu: - from sgl_kernel_npu.moe.swiglu_oai import swiglu_oai + from sgl_kernel_npu.activation.swiglu_oai import swiglu_oai class GptOssConfig(PretrainedConfig): From 17a030986efacec5acef6f3a9bd07661ac260563 Mon Sep 17 00:00:00 2001 From: Todobe Date: Fri, 19 Dec 2025 15:14:22 +0800 Subject: [PATCH 6/9] gptoss adpat prefix cache --- .../npu/attention/ascend_backend.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py index e2b20a7043c3..2cc27c1e3a05 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py @@ -294,9 +294,17 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): // self.page_size ) if forward_batch.extend_seq_lens is not None: + self.forward_metadata.extend_seq_lens = forward_batch.extend_seq_lens self.forward_metadata.extend_seq_lens_cpu_int = ( forward_batch.extend_seq_lens.cpu().int() ) + if forward_batch.seq_lens is not None: + self.forward_metadata.seq_lens = forward_batch.seq_lens.int() + else: + self.forward_metadata.seq_lens = forward_batch.seq_lens_cpu.to( + self.device + ).int() + self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int() if ( not forward_batch.forward_mode.is_draft_extend_v2() @@ -662,8 +670,9 @@ def forward_extend( k_cache, v_cache, sinks, + self.forward_metadata.extend_seq_lens, self.forward_metadata.block_tables, - self.forward_metadata.seq_lens_cpu_int, + self.forward_metadata.seq_lens, layer.scaling, layer.sliding_window_size, layer.tp_q_head_num, @@ -1325,7 +1334,7 @@ def forward_decode( v_cache, sinks, self.forward_metadata.block_tables, - self.forward_metadata.seq_lens_cpu_int, + self.forward_metadata.seq_lens, layer.scaling, layer.sliding_window_size, layer.tp_q_head_num, From 3a8e0a1e1f8803273b98531c9d65614c996a084c Mon Sep 17 00:00:00 2001 From: Todobe Date: Fri, 16 Jan 2026 20:07:48 +0800 Subject: [PATCH 7/9] modify swiglu activation --- .../srt/layers/moe/fused_moe_triton/layer.py | 4 +--- python/sglang/srt/layers/moe/moe_runner/base.py | 1 - .../sglang/srt/layers/quantization/unquant.py | 6 ++++-- python/sglang/srt/models/gpt_oss.py | 17 +++++------------ 4 files changed, 10 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 0b22dc98fe09..73fe410b8c41 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -2,7 +2,7 @@ import logging from enum import Enum -from typing import Callable, List, Optional, Tuple +from typing import List, Optional, Tuple import torch @@ -162,7 +162,6 @@ def __init__( with_bias=False, routing_method_type: Optional[RoutingMethodType] = None, is_gated: bool = True, - custom_act_fn: Optional[Callable] = None, ): super().__init__() if params_dtype is None: @@ -224,7 +223,6 @@ def __init__( gemm1_alpha=gemm1_alpha, gemm1_clamp_limit=gemm1_clamp_limit, is_gated=is_gated, - custom_act_fn=custom_act_fn, ) self.quant_method: Optional[FusedMoEMethodBase] = None diff --git a/python/sglang/srt/layers/moe/moe_runner/base.py b/python/sglang/srt/layers/moe/moe_runner/base.py index 4e4552e58835..118206a904d7 100644 --- a/python/sglang/srt/layers/moe/moe_runner/base.py +++ b/python/sglang/srt/layers/moe/moe_runner/base.py @@ -37,7 +37,6 @@ class MoeRunnerConfig: # Runner configuration activation: str = "silu" is_gated: bool = True - custom_act_fn: Optional[Callable] = None apply_router_weight_on_input: bool = False inplace: bool = True no_combine: bool = False diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index 69f6ec7931be..72ac7cda0fdb 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -511,8 +511,10 @@ def forward_npu( )[0] # act_fn: - if self.moe_runner_config.custom_act_fn is not None: - hidden_states = self.moe_runner_config.custom_act_fn(layer, hidden_states) + if self.moe_runner_config.activation == "npu_swiglu_oai": + from sgl_kernel_npu.activation.swiglu_oai import swiglu_oai + + hidden_states = swiglu_oai(layer, hidden_states) elif self.moe_runner_config.activation == "silu": hidden_states = torch_npu.npu_swiglu(hidden_states) else: diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index d4a4f0a27334..817a4d7915d7 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -79,9 +79,6 @@ if _is_cuda: from sgl_kernel import FusedSetKVBufferArg # noqa: F401 -if _is_npu: - from sgl_kernel_npu.activation.swiglu_oai import swiglu_oai - class GptOssConfig(PretrainedConfig): model_type = "gpt_oss" @@ -131,13 +128,7 @@ def __init__( "use_weight_loader_fused": quant_config_name != "mxfp4" } - custom_act_fn = None - if _is_npu: - if self.layer_id == 0: - logger.warning( - "Warning: GPT-OSS use custom activate function on FusedMoE when using ascend backend." - ) - custom_act_fn = swiglu_oai + self.experts = experts_type( num_experts=config.num_local_experts + get_global_server_args().ep_num_redundant_experts, @@ -151,7 +142,6 @@ def __init__( gemm1_clamp_limit=self.gemm1_clamp_limit, with_bias=True, prefix=add_prefix("experts", prefix), - custom_act_fn=custom_act_fn, **extra_kwargs, ) @@ -313,7 +303,7 @@ def forward_prepare( q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) extra_args = {} - if _is_cuda: + if not _is_npu: extra_args = { "fused_set_kv_buffer_arg": ( create_fused_set_kv_buffer_arg( @@ -497,6 +487,9 @@ def __init__( self.vocab_size = config.vocab_size self.pp_group = get_pp_group() + if is_npu: + config.hidden_act = "npu_swiglu_oai" + if self.pp_group.is_first_rank: self.embed_tokens = VocabParallelEmbedding( config.vocab_size, From d46f1f5919c4a11c9e89a44bc50185bd9683a30a Mon Sep 17 00:00:00 2001 From: Todobe Date: Sat, 17 Jan 2026 19:43:01 +0800 Subject: [PATCH 8/9] fix bias --- .codemate/mcp/mcp_settings.json | 5 +++++ python/sglang/srt/layers/quantization/unquant.py | 8 ++++---- 2 files changed, 9 insertions(+), 4 deletions(-) create mode 100644 .codemate/mcp/mcp_settings.json diff --git a/.codemate/mcp/mcp_settings.json b/.codemate/mcp/mcp_settings.json new file mode 100644 index 000000000000..bfd801149da7 --- /dev/null +++ b/.codemate/mcp/mcp_settings.json @@ -0,0 +1,5 @@ +{ + "mcpServers": { + + } +} \ No newline at end of file diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index 72ac7cda0fdb..299a7c7ea7b1 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -492,8 +492,8 @@ def forward_npu( ) expert_tokens = expert_tokens.to(torch.int64) - w13_bias = layer.w13_weight_bias - w2_bias = layer.w2_weight_bias + w13_bias = [layer.w13_weight_bias] if self.with_bias else None + w2_bias = [layer.w2_weight_bias] if self.with_bias else None if layer.w13_weight.shape[-1] == layer.hidden_size: w13 = layer.w13_weight.transpose(1, 2) w2 = layer.w2_weight.transpose(1, 2) @@ -502,7 +502,7 @@ def forward_npu( hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], weight=[w13], - bias=[w13_bias], + bias=w13_bias, split_item=2, group_list_type=0, group_type=0, @@ -526,7 +526,7 @@ def forward_npu( hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], weight=[w2], - bias=[w2_bias], + bias=w2_bias, split_item=2, group_list_type=0, group_type=0, From 30a7bd2324831f65538f1a52aa6fb0e75870172f Mon Sep 17 00:00:00 2001 From: Todobe Date: Sat, 17 Jan 2026 19:47:45 +0800 Subject: [PATCH 9/9] rm extra file --- .codemate/mcp/mcp_settings.json | 5 ----- 1 file changed, 5 deletions(-) delete mode 100644 .codemate/mcp/mcp_settings.json diff --git a/.codemate/mcp/mcp_settings.json b/.codemate/mcp/mcp_settings.json deleted file mode 100644 index bfd801149da7..000000000000 --- a/.codemate/mcp/mcp_settings.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "mcpServers": { - - } -} \ No newline at end of file