diff --git a/python/sgl_kernel_npu/sgl_kernel_npu/fla/fused_gdn_gating.py b/python/sgl_kernel_npu/sgl_kernel_npu/fla/fused_gdn_gating.py
new file mode 100644
index 000000000..2afd7bb88
--- /dev/null
+++ b/python/sgl_kernel_npu/sgl_kernel_npu/fla/fused_gdn_gating.py
@@ -0,0 +1,94 @@
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+import triton.runtime.driver as driver
+
+
+# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+# beta_output = b.sigmoid()
+@triton.jit(do_not_specialize=["batch", "seq_len"])
+def fused_gdn_gating_kernel(
+    g,
+    beta_output,
+    A_log,
+    a,
+    b,
+    dt_bias,
+    batch,
+    seq_len,
+    NUM_HEADS: tl.constexpr,
+    beta: tl.constexpr,
+    threshold: tl.constexpr,
+    BLK_HEADS: tl.constexpr,
+):
+    core, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+    head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
+    mask = head_off < NUM_HEADS
+
+    blk_A_log = tl.load(A_log + head_off, mask=mask)
+    blk_bias = tl.load(dt_bias + head_off, mask=mask)
+
+    # Each program owns one block of heads; the launched cores stride over the batch.
+    for i_b in tl.range(core, batch, tl.num_programs(0)):
+        off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off
+
+        blk_a = tl.load(a + off, mask=mask)
+        blk_b = tl.load(b + off, mask=mask)
+
+        x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
+        # softplus(x) with the same beta/threshold linearization as F.softplus
+        softplus_x = tl.where(
+            beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
+        )
+
+        blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
+        blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
+
+        tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
+        tl.store(beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask)
+
+
+def fused_gdn_gating_npu(
+    A_log: torch.Tensor,
+    a: torch.Tensor,
+    b: torch.Tensor,
+    dt_bias: torch.Tensor,
+    beta: float = 1.0,
+    threshold: float = 20.0,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    batch, num_heads = a.shape
+    seq_len = 1
+
+    g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
+    beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device)
+
+    device = torch.npu.current_device()
+    num_cores = driver.active.utils.get_device_properties(device)["num_vectorcore"]
+
+    # ~num_cores programs in total: axis 0 strides over the batch, axis 2 covers blocks of 8 heads
+    grid = (
+        triton.cdiv(num_cores, triton.cdiv(num_heads, 8)),
+        seq_len,
+        triton.cdiv(num_heads, 8),
+    )
+
+    fused_gdn_gating_kernel[grid](
+        g,
+        beta_output,
+        A_log,
+        a,
+        b,
+        dt_bias,
+        batch,
+        seq_len,
+        num_heads,
+        beta,
+        threshold,
+        8,
+        multibuffer=True,
+        num_warps=1,
+    )
+    return g, beta_output
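For reference, a minimal decode-time call of the new wrapper could look like the
sketch below; the shapes follow the wrapper's own assumptions (seq_len fixed to 1,
one row per request), and the sizes and random tensors are illustrative
placeholders, not values taken from this patch:

    import torch

    batch, num_heads = 32, 16  # hypothetical sizes for illustration
    A_log = torch.randn(num_heads, dtype=torch.float32, device="npu")
    dt_bias = torch.randn(num_heads, dtype=torch.float32, device="npu")
    a = torch.randn(batch, num_heads, dtype=torch.bfloat16, device="npu")
    b = torch.randn(batch, num_heads, dtype=torch.bfloat16, device="npu")

    # Both outputs come back as [1, batch, num_heads] in float32:
    # g = -exp(A_log) * softplus(a + dt_bias), beta_output = sigmoid(b)
    g, beta_output = fused_gdn_gating_npu(A_log, a, b, dt_bias)
    assert g.shape == (1, batch, num_heads)
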
diff --git a/python/sgl_kernel_npu/sgl_kernel_npu/fla/utils.py b/python/sgl_kernel_npu/sgl_kernel_npu/fla/utils.py
index f4c4ae21f..4a3f67be5 100644
--- a/python/sgl_kernel_npu/sgl_kernel_npu/fla/utils.py
+++ b/python/sgl_kernel_npu/sgl_kernel_npu/fla/utils.py
@@ -174,3 +174,43 @@ def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
 @tensor_cache
 def prepare_sequence_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
     return prepare_position_ids(cu_seqlens).eq(0).cumsum(0) - 1
+
+
+def fused_qkvzba_split_reshape_cat_torch(
+    mixed_qkvz: torch.Tensor,  # [B, 3072]
+    mixed_ba: torch.Tensor,  # [B, 16]
+    num_heads_qk: int = 4,
+    num_heads_v: int = 8,
+    head_qk: int = 128,
+    head_v: int = 128,
+):
+    B = mixed_qkvz.shape[0]
+    v_group_size = num_heads_v // num_heads_qk  # = 2
+
+    # Step 1: Reshape to [B, num_heads_qk, per_head_dim]
+    per_head_dim = 2 * head_qk + 2 * v_group_size * head_v  # 768
+    x = mixed_qkvz.view(B, num_heads_qk, per_head_dim)
+
+    # Step 2: Extract the Q/K/V/Z components per QK head
+    q = x[:, :, :head_qk]  # [B, 4, 128]
+    k = x[:, :, head_qk : 2 * head_qk]  # [B, 4, 128]
+    v_groups = x[:, :, 2 * head_qk : 2 * head_qk + v_group_size * head_v]  # [B, 4, 256]
+    z_groups = x[:, :, 2 * head_qk + v_group_size * head_v :]  # [B, 4, 256]
+
+    # Step 3: Reshape V and Z to [B, num_heads_v, head_v]
+    v = v_groups.reshape(B, num_heads_v, head_v)  # [B, 8, 128]
+    z = z_groups.reshape(B, num_heads_v, head_v)  # [B, 8, 128]
+
+    # Step 4: Build mixed_qkv = [Q_flat, K_flat, V_flat]
+    # Q_flat: concatenate all q heads → [B, 4*128]
+    q_flat = q.reshape(B, -1)
+    k_flat = k.reshape(B, -1)
+    v_flat = v.reshape(B, -1)
+    mixed_qkv = torch.cat([q_flat, k_flat, v_flat], dim=1)  # [B, 2048]
+
+    # Step 5: Process mixed_ba: [B, 16] → view as [B, 4, 4] → split b/a
+    ba = mixed_ba.view(B, num_heads_qk, 2 * v_group_size)  # [B, 4, 4]
+    b = ba[:, :, :v_group_size].reshape(B, num_heads_v)  # [B, 8]
+    a = ba[:, :, v_group_size:].reshape(B, num_heads_v)  # [B, 8]
+
+    return mixed_qkv, z, b, a
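As a quick shape check for the torch reference path, under the default layout the
signature assumes (4 Q/K heads feeding 8 V heads, 128-dim heads), something like
the following should hold; the random inputs stand in for the fused projection
outputs:

    import torch

    B = 2  # hypothetical batch size
    mixed_qkvz = torch.randn(B, 3072)  # 4 heads x (128 q + 128 k + 256 v + 256 z)
    mixed_ba = torch.randn(B, 16)      # 4 heads x (2 b + 2 a)

    mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat_torch(mixed_qkvz, mixed_ba)
    assert mixed_qkv.shape == (B, 2048)  # [q | k | v] flattened: 512 + 512 + 1024
    assert z.shape == (B, 8, 128)
    assert b.shape == (B, 8) and a.shape == (B, 8)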