Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions python/sgl_kernel_npu/sgl_kernel_npu/fla/fused_gdn_gating.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from typing import Tuple

import torch
import triton
import triton.language as tl
import triton.runtime.driver as driver


# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
# beta_output = b.sigmoid()
@triton.jit(do_not_specialize=["batch", "seq_len"])
def fused_gdn_gating_kernel(
g,
beta_output,
A_log,
a,
b,
dt_bias,
batch,
seq_len,
NUM_HEADS: tl.constexpr,
beta: tl.constexpr,
threshold: tl.constexpr,
BLK_HEADS: tl.constexpr,
):
core, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)

head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
mask = head_off < NUM_HEADS

blk_A_log = tl.load(A_log + head_off, mask=mask)
blk_bias = tl.load(dt_bias + head_off, mask=mask)

for i_b in tl.range(core, batch, tl.num_programs(0)):
off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off

blk_a = tl.load(a + off, mask=mask)
blk_b = tl.load(b + off, mask=mask)

x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
softplus_x = tl.where(
beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
)

blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))

tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
tl.store(beta_output + off, blk_beta_output.to(b.dtype.element_ty), mask=mask)


def fused_gdn_gating_npu(
A_log: torch.Tensor,
a: torch.Tensor,
b: torch.Tensor,
dt_bias: torch.Tensor,
beta: float = 1.0,
threshold: float = 20.0,
):
batch, num_heads = a.shape
seq_len = 1

g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device)

device = torch.npu.current_device()
num_cores = driver.active.utils.get_device_properties(device)["num_vectorcore"]

grid = (
triton.cdiv(num_cores, triton.cdiv(num_heads, 8)),
seq_len,
triton.cdiv(num_heads, 8),
)

fused_gdn_gating_kernel[grid](
g,
beta_output,
A_log,
a,
b,
dt_bias,
batch,
seq_len,
num_heads,
beta,
threshold,
8,
multibuffer=True,
num_warps=1,
)
return g, beta_output
40 changes: 40 additions & 0 deletions python/sgl_kernel_npu/sgl_kernel_npu/fla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,43 @@ def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
@tensor_cache
def prepare_sequence_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
return prepare_position_ids(cu_seqlens).eq(0).cumsum(0) - 1


def fused_qkvzba_split_reshape_cat_torch(
mixed_qkvz: torch.Tensor, # [B, 3072]
mixed_ba: torch.Tensor, # [B, 16]
num_heads_qk: int = 4,
num_heads_v: int = 8,
head_qk: int = 128,
head_v: int = 128,
):
B = mixed_qkvz.shape[0]
v_group_size = num_heads_v // num_heads_qk # = 2

# Step 1: Reshape to [B, num_heads_qk, per_head_dim]
per_head_dim = 2 * head_qk + 2 * v_group_size * head_v # 768
x = mixed_qkvz.view(B, num_heads_qk, per_head_dim)

# Extract components per head
q = x[:, :, :head_qk] # [B, 4, 128]
k = x[:, :, head_qk : 2 * head_qk] # [B, 4, 128]
v_groups = x[:, :, 2 * head_qk : 2 * head_qk + v_group_size * head_v] # [B, 4, 256]
z_groups = x[:, :, 2 * head_qk + v_group_size * head_v :] # [B, 4, 256]

# Reshape V and Z to [B, num_heads_v, head_v]
v = v_groups.reshape(B, num_heads_v, head_v) # [B, 8, 128]
z = z_groups.reshape(B, num_heads_v, head_v) # [B, 8, 128]

# Build mixed_qkv = [Q_flat, K_flat, V_flat]
# Q_flat: concatenate all q heads → [B, 4*128]
q_flat = q.reshape(B, -1)
k_flat = k.reshape(B, -1)
v_flat = v.reshape(B, -1)
mixed_qkv = torch.cat([q_flat, k_flat, v_flat], dim=1) # [B, 2048]

# Process mixed_ba: [B, 16] → view as [B, 4, 4] → split b/a
ba = mixed_ba.view(B, num_heads_qk, 2 * v_group_size) # [B, 4, 4]
b = ba[:, :, :v_group_size].reshape(B, num_heads_v) # [B, 8]
a = ba[:, :, v_group_size:].reshape(B, num_heads_v) # [B, 8]

return mixed_qkv, z, b, a