Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
285 changes: 161 additions & 124 deletions tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import triton
import triton.language as tl

from tensorrt_llm._torch.modules.fla.utils import input_guard
from tensorrt_llm._torch.modules.fla.utils import custom_device_ctx


@triton.heuristics({
Expand All @@ -30,6 +30,12 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
cu_seqlens,
scale,
T,
total_nh,
stride_q,
stride_k,
stride_v,
stride_a,
stride_b,
s_h0_0,
h0_dim0,
B: tl.constexpr,
Expand All @@ -46,117 +52,127 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
"""
Fused kernel that combines sigmoid gating computation with recurrent delta rule update.
"""
i_nh, i_v, i_k = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_n, i_hv = i_nh // HV, i_nh % HV
i_h = i_hv // (HV // H)

if IS_VARLEN:
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int64),
tl.load(cu_seqlens + i_n + 1).to(tl.int64),
)
all = T
T = eos - bos
else:
bos, eos = i_n * T, i_n * T + T
all = B * T

i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_k = i_k * BK + tl.arange(0, BK)
o_v = i_v * BV + tl.arange(0, BV)

p_q = q + (bos * H + i_h) * K + o_k
p_k = k + (bos * H + i_h) * K + o_k
p_v = v + (bos * HV + i_hv) * V + o_v
p_b = b + bos * HV + i_hv
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

# Gating computation pointers
p_A_log = A_log + i_hv
p_a = a + bos * HV + i_hv
p_dt_bias = dt_bias + i_hv

mask_k = o_k < K
mask_v = o_v < V
mask_h = mask_k[:, None] & mask_v[None, :]
grid_stride_nh = tl.num_programs(2)

b_h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
idx = tl.load(h0_indices + i_n).to(tl.int64) # prevent int32 overflow
if idx >= 0:
tl.device_assert(idx < h0_dim0,
"idx out of bounds in h0_source load")
p_h0 = (h0_source + idx * s_h0_0 + i_hv * K * V + o_k[:, None] * V +
o_v[None, :])
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

for _ in range(0, T):
# Load inputs
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
b_b = tl.load(p_b).to(tl.float32)

# Compute sigmoid gating
# Load gating parameters
b_A_log = tl.load(p_A_log).to(tl.float32)
b_a = tl.load(p_a).to(tl.float32)
b_dt_bias = tl.load(p_dt_bias).to(tl.float32)

# Compute g = -exp(A_log) * softplus(a + dt_bias)
x = b_a + b_dt_bias
beta_x = softplus_beta * x
# Apply softplus with numerical stability
softplus_x = tl.where(
beta_x <= softplus_threshold,
(1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)),
x,
)
b_g = -tl.exp(b_A_log) * softplus_x
while i_nh < total_nh:
i_n, i_hv = i_nh // HV, i_nh % HV
i_h = i_hv // (HV // H)

if IS_VARLEN:
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int64),
tl.load(cu_seqlens + i_n + 1).to(tl.int64),
)
all = T
seq_T = eos - bos
else:
bos, eos = i_n * T, i_n * T + T
all = B * T
seq_T = T

# Decode q/k/v/a/b often arrive as views sliced out of larger packed tensors.
# Use the caller-provided token strides so the kernel can consume those views
# directly instead of relying on a packed contiguous layout.
p_q = q + bos * stride_q + i_h * K + o_k
p_k = k + bos * stride_k + i_h * K + o_k
p_v = v + bos * stride_v + i_hv * V + o_v
p_b = b + bos * stride_b + i_hv
# o is allocated contiguously by the Python wrapper, so the output
# pointer arithmetic can use the packed [NK, B, T, HV, V] layout.
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

# Compute beta = sigmoid(b)
b_beta = 1.0 / (1.0 + tl.exp(-b_b))
# Gating computation pointers
p_A_log = A_log + i_hv
p_a = a + bos * stride_a + i_hv
p_dt_bias = dt_bias + i_hv

# Apply L2 normalization if enabled
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
b_h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
idx = tl.load(h0_indices + i_n).to(tl.int64)
if idx >= 0:
tl.device_assert(idx < h0_dim0,
"idx out of bounds in h0_source load")
p_h0 = (h0_source + idx * s_h0_0 + i_hv * K * V +
o_k[:, None] * V + o_v[None, :])
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

b_q = b_q * scale
for _ in range(0, seq_T):
# Load inputs
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
b_b = tl.load(p_b).to(tl.float32)

# Apply gating to hidden state: h *= exp(g)
b_h *= tl.exp(b_g)
# Compute sigmoid gating
# Load gating parameters
b_A_log = tl.load(p_A_log).to(tl.float32)
b_a = tl.load(p_a).to(tl.float32)
b_dt_bias = tl.load(p_dt_bias).to(tl.float32)

# Delta rule: v -= sum(h * k, dim=0)
b_v -= tl.sum(b_h * b_k[:, None], 0)
# Compute g = -exp(A_log) * softplus(a + dt_bias)
x = b_a + b_dt_bias
beta_x = softplus_beta * x
# Apply softplus with numerical stability
softplus_x = tl.where(
beta_x <= softplus_threshold,
(1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)),
x,
)
b_g = -tl.exp(b_A_log) * softplus_x

# Apply beta gating: v *= beta
b_v *= b_beta
# Compute beta = sigmoid(b)
b_beta = 1.0 / (1.0 + tl.exp(-b_b))

# Update hidden state: h += k[:, None] * v[None, :]
b_h += b_k[:, None] * b_v[None, :]
# Apply L2 normalization if enabled
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)

# Compute output: o = sum(h * q, dim=0)
b_o = tl.sum(b_h * b_q[:, None], 0)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
b_q = b_q * scale

# Update pointers for next timestep
p_q += H * K
p_k += H * K
p_o += HV * V
p_v += HV * V
p_b += HV
p_a += HV
# Apply gating to hidden state: h *= exp(g)
b_h *= tl.exp(b_g)

# Store final state back to h0_source with bounds checking
if USE_INITIAL_STATE:
idx = tl.load(h0_indices + i_n).to(tl.int64)
if idx >= 0:
p_h0 = (h0_source + idx * s_h0_0 + i_hv * K * V + o_k[:, None] * V +
o_v[None, :])
tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)
# Delta rule: v -= sum(h * k, dim=0)
b_v -= tl.sum(b_h * b_k[:, None], 0)

# Apply beta gating: v *= beta
b_v *= b_beta

# Update hidden state: h += k[:, None] * v[None, :]
b_h += b_k[:, None] * b_v[None, :]

# Compute output: o = sum(h * q, dim=0)
b_o = tl.sum(b_h * b_q[:, None], 0)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

# Update pointers for next timestep
p_q += stride_q
p_k += stride_k
p_o += HV * V
p_v += stride_v
p_b += stride_b
p_a += stride_a

# Store final state back to h0_source with bounds checking
if USE_INITIAL_STATE:
idx = tl.load(h0_indices + i_n).to(tl.int64)
if idx >= 0:
tl.device_assert(idx < h0_dim0,
"idx out of bounds in h0_source store")
p_h0 = (h0_source + idx * s_h0_0 + i_hv * K * V +
o_k[:, None] * V + o_v[None, :])
tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)

i_nh += grid_stride_nh


@input_guard(exclude_args=["initial_state_source"])
def fused_sigmoid_gating_delta_rule_update(
A_log: torch.Tensor,
a: torch.Tensor,
Expand All @@ -181,6 +197,14 @@ def fused_sigmoid_gating_delta_rule_update(
B, T, H, K, V = *k.shape, v.shape[-1]
HV = v.shape[2]
N = B if cu_seqlens is None else len(cu_seqlens) - 1

# Accept native view layouts from forward_decode rather than forcing packed
# copies through input_guard.
stride_q = q.stride(1)
stride_k = k.stride(1)
stride_v = v.stride(1)
stride_a = a.stride(-2)
stride_b = b.stride(-2)
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
assert NK == 1, "NK > 1 is not supported yet"
Expand All @@ -193,7 +217,10 @@ def fused_sigmoid_gating_delta_rule_update(
assert scale > 0, "scale must be positive"

o = q.new_empty(NK, *v.shape)
grid = (N * HV, NV, NK)
# The (NK, NV, N * HV) grid ordering was found to be faster than (N * HV, NV, NK).
# Since the maximum value of grid.z is 65535, we cap grid.z and let each Triton
# program grid-stride across the remaining N * HV tiles.
grid = (NK, NV, min(N * HV, 65535))

if initial_state_source is not None:
s_h0_0, s_h0_1, s_h0_2, s_h0_3 = initial_state_source.stride()
Expand All @@ -205,34 +232,44 @@ def fused_sigmoid_gating_delta_rule_update(
s_h0_0 = 0
slot_num = 0

fused_sigmoid_gating_delta_rule_update_kernel[grid](
A_log=A_log,
a=a,
dt_bias=dt_bias,
softplus_beta=softplus_beta,
softplus_threshold=softplus_threshold,
q=q,
k=k,
v=v,
b=b,
o=o,
h0_source=initial_state_source,
h0_indices=initial_state_indices,
cu_seqlens=cu_seqlens,
scale=scale,
T=T,
s_h0_0=s_h0_0,
h0_dim0=slot_num,
B=B,
H=H,
HV=HV,
K=K,
V=V,
BK=BK,
BV=BV,
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
num_warps=num_warps,
num_stages=num_stages,
)
# input_guard used to set the active CUDA device and make inputs contiguous.
# We keep only the device-context part here so Triton launches on q's device
# without re-packing the decode views.
with custom_device_ctx(q.device.index):
fused_sigmoid_gating_delta_rule_update_kernel[grid](
A_log=A_log,
a=a,
dt_bias=dt_bias,
softplus_beta=softplus_beta,
softplus_threshold=softplus_threshold,
q=q,
k=k,
v=v,
b=b,
o=o,
h0_source=initial_state_source,
h0_indices=initial_state_indices,
cu_seqlens=cu_seqlens,
scale=scale,
T=T,
total_nh=N * HV,
stride_q=stride_q,
stride_k=stride_k,
stride_v=stride_v,
stride_a=stride_a,
stride_b=stride_b,
s_h0_0=s_h0_0,
h0_dim0=slot_num,
B=B,
H=H,
HV=HV,
K=K,
V=V,
BK=BK,
BV=BV,
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
num_warps=num_warps,
num_stages=num_stages,
)
o = o.squeeze(0)
return o
22 changes: 12 additions & 10 deletions tensorrt_llm/_torch/modules/mamba/gdn_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,10 @@ def __init__(
self.head_v_dim = config.linear_value_head_dim
self.key_dim = self.head_k_dim * self.num_k_heads
self.value_dim = self.head_v_dim * self.num_v_heads
self.num_k_heads_per_tp = divide(self.num_k_heads, self.attn_tp_size)
self.num_v_heads_per_tp = divide(self.num_v_heads, self.attn_tp_size)
self.key_dim_per_tp = self.head_k_dim * self.num_k_heads_per_tp
self.value_dim_per_tp = self.head_v_dim * self.num_v_heads_per_tp

self.conv_kernel_size = config.linear_conv_kernel_dim
self.layer_idx = layer_idx
Expand Down Expand Up @@ -479,17 +483,15 @@ def forward_decode(
conv_state_indices=cache_indices,
)

# Direct slicing instead of torch.split for better performance
key_size = self.key_dim // self.attn_tp_size
query = mixed_qkv[..., :key_size]
key = mixed_qkv[..., key_size : key_size * 2]
value = mixed_qkv[..., key_size * 2 :]
# Reshape from [l, h*d] to [1, l, h, d]
# Keep q/k/v as views over mixed_qkv so the fused decode kernel can
# consume their native strides without forcing packed copies.
query = mixed_qkv[..., : self.key_dim_per_tp]
key = mixed_qkv[..., self.key_dim_per_tp : self.key_dim_per_tp * 2]
value = mixed_qkv[..., self.key_dim_per_tp * 2 :]
seq_len = query.shape[0]
num_heads = query.shape[1] // self.head_k_dim
query = query.view(1, seq_len, num_heads, self.head_k_dim)
key = key.view(1, seq_len, num_heads, self.head_k_dim)
value = value.view(1, seq_len, value.shape[1] // self.head_v_dim, self.head_v_dim)
query = query.view(1, seq_len, self.num_k_heads_per_tp, self.head_k_dim)
key = key.view(1, seq_len, self.num_k_heads_per_tp, self.head_k_dim)
value = value.view(1, seq_len, self.num_v_heads_per_tp, self.head_v_dim)

core_attn_out = fused_sigmoid_gating_delta_rule_update(
A_log=self.A_log,
Expand Down
Loading