Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions tests/e2e/singlecard/test_fused_sigmoid_gating_delta_rule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import torch
from vllm.model_executor.layers.fla.ops import fused_recurrent_gated_delta_rule
from vllm.model_executor.models.qwen3_next import fused_gdn_gating

from vllm_ascend.ops.triton.fla.sigmoid_gating import \
fused_sigmoid_gating_delta_rule_update


def test_triton_fusion_ops():
q = torch.randn(1, 1, 4, 128, dtype=torch.bfloat16).npu()
k = torch.randn(1, 1, 4, 128, dtype=torch.bfloat16).npu()
v = torch.randn(1, 1, 8, 128, dtype=torch.bfloat16).npu()
a = torch.tensor([[
-2.6094, -0.2617, -0.3848, 2.2656, 3.6250, -0.7383, -1.0938, -0.0505
]]).bfloat16().npu()
b = torch.tensor(
[[0.4277, 0.8906, 1.6875, 2.3750, 4.1562, 0.3809, 1.0625,
3.6719]]).bfloat16().npu()
ssm_state = torch.randn(1, 8, 128, 128, dtype=torch.bfloat16).npu()
non_spec_state_indices_tensor = torch.tensor([2]).int().npu()
non_spec_query_start_loc = torch.tensor([0, 1]).int().npu()
a_log = torch.tensor([
-2.6875, -3.2031, -3.3438, -2.7812, -3.0625, -4.0312, -5.3750, 5.7188
]).bfloat16().npu()
dt_bias = torch.tensor(
[-4.7812, -5.0938, -5.5000, 9.4375, 7.6250, -4.3750, -3.0938,
0.9688]).bfloat16().npu()

core_attn_out_non_spec_fused = fused_sigmoid_gating_delta_rule_update(
A_log=a_log.contiguous(),
dt_bias=dt_bias.contiguous(),
q=q.contiguous(),
k=k.contiguous(),
v=v.contiguous(),
a=a.contiguous(),
b=b.contiguous(),
initial_state_source=ssm_state,
initial_state_indices=non_spec_state_indices_tensor,
cu_seqlens=non_spec_query_start_loc,
use_qk_l2norm_in_kernel=True,
softplus_beta=1.0,
softplus_threshold=20.0,
)

g, beta = fused_gdn_gating(a_log, a, b, dt_bias)
g_non_spec = g
beta_non_spec = beta
core_attn_out_non_spec_split, last_recurrent_state = (
fused_recurrent_gated_delta_rule(
q=q,
k=k,
v=v,
g=g_non_spec,
beta=beta_non_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=non_spec_query_start_loc,
ssm_state_indices=non_spec_state_indices_tensor,
use_qk_l2norm_in_kernel=True,
))
torch.testing.assert_close(core_attn_out_non_spec_fused,
core_attn_out_non_spec_split,
rtol=1e-02,
atol=1e-02,
equal_nan=True)
226 changes: 226 additions & 0 deletions vllm_ascend/ops/triton/fla/sigmoid_gating.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import os

import torch
from vllm.triton_utils import tl, tldevice, triton

if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
Expand Down Expand Up @@ -169,3 +170,228 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
p_ht = ht + (bos + i_t) * stride_final_state_token
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)


@triton.heuristics({
"USE_INITIAL_STATE": lambda args: args["h0_source"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
})
@triton.jit(do_not_specialize=["T"])
def fused_sigmoid_gating_delta_rule_update_kernel(
A_log,
a,
dt_bias,
softplus_beta,
softplus_threshold,
q,
k,
v,
b,
o,
h0_source,
h0_indices,
cu_seqlens,
scale,
T,
B: tl.constexpr,
H: tl.constexpr,
HV: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
"""
Fused kernel that combines sigmoid gating computation with recurrent delta rule update.
"""
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_n, i_hv = i_nh // HV, i_nh % HV
i_h = i_hv // (HV // H)

if IS_VARLEN:
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int64),
tl.load(cu_seqlens + i_n + 1).to(tl.int64),
)
all = T
T = eos - bos
else:
bos, eos = i_n * T, i_n * T + T
all = B * T

o_k = i_k * BK + tl.arange(0, BK)
o_v = i_v * BV + tl.arange(0, BV)

p_q = q + (bos * H + i_h) * K + o_k
p_k = k + (bos * H + i_h) * K + o_k
p_v = v + (bos * HV + i_hv) * V + o_v
p_b = b + bos * HV + i_hv
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

# Gating computation pointers
p_A_log = A_log + i_hv
p_a = a + bos * HV + i_hv
p_dt_bias = dt_bias + i_hv

mask_k = o_k < K
mask_v = o_v < V
mask_h = mask_k[:, None] & mask_v[None, :]

b_h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
idx = tl.load(h0_indices + i_n)
# if idx >= 0:
tmp0 = tl.where(idx < 0, 0, idx)
p_h0 = (h0_source + tmp0 * HV * K * V + i_hv * K * V +
o_k[:, None] * V + o_v[None, :])
temp1 = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
temp2 = tl.zeros_like(temp1)
value0 = tl.where(idx < 0, temp2, temp1)
b_h += value0 # tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
Comment on lines +245 to +253
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The logic for loading the initial state is a bit convoluted and can be simplified for better readability and efficiency. The current implementation uses tl.where and performs a memory load that is subsequently discarded for negative indices. Using a simple if idx >= 0: check, as is done when storing the final state later in the kernel, would be cleaner and avoid the unnecessary load operation.

Suggested change
idx = tl.load(h0_indices + i_n)
# if idx >= 0:
tmp0 = tl.where(idx < 0, 0, idx)
p_h0 = (
h0_source
+ tmp0 * HV * K * V
+ i_hv * K * V
+ o_k[:, None] * V
+ o_v[None, :]
)
temp1 = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
temp2 = tl.zeros_like(temp1)
value0 = tl.where(idx < 0, temp2, temp1)
b_h += value0 # tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
idx = tl.load(h0_indices + i_n)
if idx >= 0:
p_h0 = (
h0_source
+ idx * HV * K * V
+ i_hv * K * V
+ o_k[:, None] * V
+ o_v[None, :]
)
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)


for i in range(0, T):
# Load inputs
b_q = tl.load(p_q + i * H * K, mask=mask_k, other=0).to(tl.float32)
b_k = tl.load(p_k + i * H * K, mask=mask_k, other=0).to(tl.float32)
b_v = tl.load(p_v + i * HV * V, mask=mask_v, other=0).to(tl.float32)
b_b = tl.load(p_b + i * HV).to(tl.float32)

# Compute sigmoid gating
# Load gating parameters
b_A_log = tl.load(p_A_log).to(tl.float32)
b_a = tl.load(p_a + i * HV).to(tl.float32)
b_dt_bias = tl.load(p_dt_bias).to(tl.float32)

# Compute g = -exp(A_log) * softplus(a + dt_bias)
x = b_a + b_dt_bias
beta_x = softplus_beta * x
# Apply softplus with numerical stability
softplus_x = tl.where(
beta_x <= softplus_threshold,
(1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)),
x,
)
b_g = -tl.exp(b_A_log) * softplus_x

# Compute beta = sigmoid(b)
b_beta = 1.0 / (1.0 + tl.exp(-b_b))

# Apply L2 normalization if enabled
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)

b_q = b_q * scale

# Apply gating to hidden state: h *= exp(g)
b_h *= tl.exp(b_g)

# Delta rule: v -= sum(h * k, dim=0)
b_v -= tl.sum(b_h * b_k[:, None], 0)

# Apply beta gating: v *= beta
b_v *= b_beta

# Update hidden state: h += k[:, None] * v[None, :]
b_h += b_k[:, None] * b_v[None, :]

# Compute output: o = sum(h * q, dim=0)
b_o = tl.sum(b_h * b_q[:, None], 0)
tl.store(p_o + i * HV * V, b_o.to(p_o.dtype.element_ty), mask=mask_v)

# # Update pointers for next timestep
# p_q += H * K
# p_k += H * K
# p_o += HV * V
# p_v += HV * V
# p_b += HV
# p_a += HV
Comment on lines +305 to +311
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This block of commented-out code appears to be from a different implementation pattern and is no longer in use. It should be removed to improve code clarity and maintainability.


# Store final state back to h0_source with bounds checking
if USE_INITIAL_STATE:
idx = tl.load(h0_indices + i_n)
if idx >= 0:
p_h0 = (h0_source + idx * HV * K * V + i_hv * K * V +
o_k[:, None] * V + o_v[None, :])
tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)


def fused_sigmoid_gating_delta_rule_update(
A_log: torch.Tensor,
a: torch.Tensor,
dt_bias: torch.Tensor,
softplus_beta: float,
softplus_threshold: float,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
b: torch.Tensor,
initial_state_source: torch.Tensor,
initial_state_indices: torch.Tensor,
scale: float = None,
use_qk_l2norm_in_kernel: bool = False,
cu_seqlens: torch.Tensor = None,
):
"""
Fused triton implementation of sigmoid gating delta rule update.
This function uses a single fused kernel that combines both sigmoid gating computation
and the recurrent delta rule update for better performance.
"""
B, T, H, K, V = *k.shape, v.shape[-1]
HV = v.shape[2]
N = B if cu_seqlens is None else len(cu_seqlens) - 1
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 64)
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
assert NK == 1, "NK > 1 is not supported yet"
num_stages = 3
num_warps = 1

if scale is None:
scale = k.shape[-1]**-0.5
else:
assert scale > 0, "scale must be positive"

o = q.new_empty(NK, *v.shape)
grid = (NK, NV, N * HV)

if not initial_state_indices.is_contiguous():
initial_state_indices = initial_state_indices.contiguous()
if not initial_state_source.is_contiguous():
initial_state_source_contiguous = initial_state_source.contiguous()
if not cu_seqlens.is_contiguous():
cu_seqlens = cu_seqlens.contiguous()

fused_sigmoid_gating_delta_rule_update_kernel[grid](
A_log=A_log,
a=a,
dt_bias=dt_bias,
softplus_beta=softplus_beta,
softplus_threshold=softplus_threshold,
q=q,
k=k,
v=v,
b=b,
o=o,
h0_source=initial_state_source_contiguous,
h0_indices=initial_state_indices,
cu_seqlens=cu_seqlens,
scale=scale,
T=T,
B=B,
H=H,
HV=HV,
K=K,
V=V,
BK=BK,
BV=BV,
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
num_warps=num_warps,
num_stages=num_stages,
)
initial_state_source.copy_(
initial_state_source_contiguous.view_as(initial_state_source))
o = o.squeeze(0)
return o
12 changes: 12 additions & 0 deletions vllm_ascend/patch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,15 @@
# Future Plan:
# Remove this patch when vLLM support these operators.
#
# ** 15. File: worker/patch_qwen3_next.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.models.qwen3_next.Qwen3NextGatedDeltaNet._forward_core`
# Why:
# triton ops fused_recurrent_gated_delta_rule and fused_gdn_gating in vLLM perform not good on NPU.
# How:
# add a new fused triton ops in vLLM with ascend implementation.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/30860
# Future Plan:
# Remove this patch when vLLM support these operators.
#
1 change: 1 addition & 0 deletions vllm_ascend/patch/worker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@
import vllm_ascend.patch.worker.patch_qwen3_next # noqa
import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa
import vllm_ascend.patch.worker.patch_rejection_sampler # noqa
import vllm_ascend.patch.worker.patch_qwen3_next # noqa
Loading
Loading