diff --git a/python/sglang/srt/layers/attention/fla/fused_gdn_gating.py b/python/sglang/srt/layers/attention/fla/fused_gdn_gating.py
new file mode 100644
index 00000000000..6e92208ec13
--- /dev/null
+++ b/python/sglang/srt/layers/attention/fla/fused_gdn_gating.py
@@ -0,0 +1,69 @@
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+
+# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+# beta_output = b.sigmoid()
+@triton.jit
+def fused_gdn_gating_kernel(
+    g,
+    beta_output,
+    A_log,
+    a,
+    b,
+    dt_bias,
+    seq_len,
+    NUM_HEADS: tl.constexpr,
+    beta: tl.constexpr,
+    threshold: tl.constexpr,
+    BLK_HEADS: tl.constexpr,
+):
+    i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
+    off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off
+    mask = head_off < NUM_HEADS
+    blk_A_log = tl.load(A_log + head_off, mask=mask)
+    blk_a = tl.load(a + off, mask=mask)
+    blk_b = tl.load(b + off, mask=mask)
+    blk_bias = tl.load(dt_bias + head_off, mask=mask)
+    x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
+    softplus_x = tl.where(
+        beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
+    )
+    blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
+    tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
+    blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
+    tl.store(beta_output + off, blk_beta_output.to(b.dtype.element_ty), mask=mask)
+
+
+def fused_gdn_gating(
+    A_log: torch.Tensor,
+    a: torch.Tensor,
+    b: torch.Tensor,
+    dt_bias: torch.Tensor,
+    beta: float = 1.0,
+    threshold: float = 20.0,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    batch, num_heads = a.shape
+    seq_len = 1
+    grid = (batch, seq_len, triton.cdiv(num_heads, 8))
+    g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
+    beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device)
+    fused_gdn_gating_kernel[grid](
+        g,
+        beta_output,
+        A_log,
+        a,
+        b,
+        dt_bias,
+        seq_len,
+        num_heads,
+        beta,
+        threshold,
+        8,
+        num_warps=1,
+    )
+    return g, beta_output
diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
index 016b803ffbb..99785403346 100644
--- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
+++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
@@ -5,6 +5,7 @@

 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
+from sglang.srt.layers.attention.fla.fused_gdn_gating import fused_gdn_gating
 from sglang.srt.layers.attention.fla.fused_recurrent import (
     fused_recurrent_gated_delta_rule_update,
 )
@@ -30,7 +31,6 @@
 from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, MambaPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.models.qwen3_next import fused_gdn_gating
 from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
 from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import is_cuda, is_npu
@@ -697,11 +697,7 @@ def forward_extend(
         key = key.view(1, actual_seq_len, num_heads, head_k_dim)
         value = value.view(1, actual_seq_len, num_value_heads, head_v_dim)

-        beta = b.sigmoid()
-        g = fused_gdn_gating(A_log, a, dt_bias)
-
-        g = g.unsqueeze(0)
-        beta = beta.unsqueeze(0)
+        g, beta = fused_gdn_gating(A_log, a, b, dt_bias)

         if is_target_verify:
             core_attn_out = fused_recurrent_gated_delta_rule_update(
diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py
index d817076fbfa..111272272b0 100644
--- a/python/sglang/srt/models/qwen3_next.py
+++ b/python/sglang/srt/models/qwen3_next.py
@@ -190,51 +190,6 @@ def fused_qkvzba_split_reshape_cat(
     return mixed_qkv, z, b, a


-# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
-@triton.jit
-def fused_gdn_gating_kernel(
-    g,
-    A_log,
-    a,
-    dt_bias,
-    seq_len,
-    NUM_HEADS: tl.constexpr,
-    beta: tl.constexpr,
-    threshold: tl.constexpr,
-    BLK_HEADS: tl.constexpr,
-):
-    i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
-    head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
-    off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off
-    mask = head_off < NUM_HEADS
-    blk_A_log = tl.load(A_log + head_off, mask=mask)
-    blk_a = tl.load(a + off, mask=mask)
-    blk_bias = tl.load(dt_bias + head_off, mask=mask)
-    x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
-    softplus_x = tl.where(
-        beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
-    )
-    blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
-    tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
-
-
-def fused_gdn_gating(
-    A_log: torch.Tensor,
-    a: torch.Tensor,
-    dt_bias: torch.Tensor,
-    beta: float = 1.0,
-    threshold: float = 20.0,
-) -> torch.Tensor:
-    batch, num_heads = a.shape
-    seq_len = 1
-    grid = (batch, seq_len, triton.cdiv(num_heads, 8))
-    g = torch.empty_like(a, dtype=torch.float32)
-    fused_gdn_gating_kernel[grid](
-        g, A_log, a, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1
-    )
-    return g
-
-
 class Qwen3GatedDeltaNet(nn.Module):
     def __init__(
         self,
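
For reference, the relocated fused_gdn_gating kernel fuses two elementwise gates into a single launch: the decay gate g = -exp(A_log) * softplus(a + dt_bias) and the mixing gate sigmoid(b). The unfused PyTorch sketch below is only an illustrative equivalence check under assumed decode-path shapes (a, b: [batch, num_heads]; A_log, dt_bias: [num_heads]); the helper name gdn_gating_reference is hypothetical and not part of the change.

import torch
import torch.nn.functional as F

def gdn_gating_reference(A_log, a, b, dt_bias, beta=1.0, threshold=20.0):
    # Decay gate: same softplus with the kernel's beta/threshold linearization cutoff.
    g = -A_log.float().exp() * F.softplus(
        a.float() + dt_bias.float(), beta=beta, threshold=threshold
    )
    # Mixing gate: plain sigmoid of b.
    beta_output = b.float().sigmoid()
    # The fused path returns both tensors with a leading length-1 sequence dimension.
    return g.unsqueeze(0), beta_output.unsqueeze(0)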