Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -625,11 +625,7 @@ def forward_extend(
key = key.view(1, actual_seq_len, num_heads, head_k_dim)
value = value.view(1, actual_seq_len, num_value_heads, head_v_dim)

beta = b.sigmoid()
g = fused_gdn_gating(A_log, a, dt_bias)

g = g.unsqueeze(0)
beta = beta.unsqueeze(0)
g, beta = fused_gdn_gating(A_log, a, b, dt_bias)

if is_target_verify:
core_attn_out = fused_recurrent_gated_delta_rule_update(
Expand Down
15 changes: 11 additions & 4 deletions python/sglang/srt/models/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,10 @@ def fused_qkvzba_split_reshape_cat(
@triton.jit
def fused_gdn_gating_kernel(
g,
beta_output,
A_log,
a,
b,
dt_bias,
seq_len,
NUM_HEADS: tl.constexpr,
Expand All @@ -209,30 +211,35 @@ def fused_gdn_gating_kernel(
mask = head_off < NUM_HEADS
blk_A_log = tl.load(A_log + head_off, mask=mask)
blk_a = tl.load(a + off, mask=mask)
blk_b = tl.load(b + off, mask=mask)
blk_bias = tl.load(dt_bias + head_off, mask=mask)
x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
softplus_x = tl.where(
beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
)
blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
tl.store(beta_output + off, blk_beta_output.to(b.dtype.element_ty), mask=mask)


def fused_gdn_gating(
A_log: torch.Tensor,
a: torch.Tensor,
b: torch.Tensor,
dt_bias: torch.Tensor,
beta: float = 1.0,
threshold: float = 20.0,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
batch, num_heads = a.shape
seq_len = 1
grid = (batch, seq_len, triton.cdiv(num_heads, 8))
g = torch.empty_like(a, dtype=torch.float32)
g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device)
fused_gdn_gating_kernel[grid](
g, A_log, a, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1
g, beta_output, A_log, a, b, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1
)
return g
return g, beta_output


class Qwen3GatedDeltaNet(nn.Module):
Expand Down
Loading