Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
3216bce
Add gdn fusions
hellozhuo-amd Apr 10, 2026
9811501
style: fix ruff F841 and black-format Triton PR files
hellozhuo-amd Apr 10, 2026
b26972f
Update fused_rearrange_sigmoid_gdr.py
hellozhuo-amd Apr 13, 2026
8695885
Update op_tests
hellozhuo-amd Apr 13, 2026
b69cb72
Fix BLACK format problem
hellozhuo-amd Apr 13, 2026
c4db40f
Fix black check failure
hellozhuo-amd Apr 13, 2026
ac48df0
Update test_fused_rearrange_sigmoid_gdr.py
hellozhuo-amd Apr 13, 2026
b9f33dd
Merge branch 'origin/main' into zhuo/qwen3_triton_gdn
hellozhuo-amd Apr 13, 2026
56a2b85
Merge branch 'main' into zhuo/qwen3_triton_gdn
hellozhuo-amd Apr 14, 2026
f214128
Merge branch 'main' into zhuo/qwen3_triton_gdn
hellozhuo-amd Apr 16, 2026
5084462
Merge branch 'main' into zhuo/qwen3_triton_gdn
juuso-oskari Apr 20, 2026
3d084e2
Merge branch 'main' into zhuo/qwen3_triton_gdn
juuso-oskari Apr 21, 2026
b2ab876
Merge branch 'main' into zhuo/qwen3_triton_gdn
juuso-oskari Apr 21, 2026
bdc9a96
Merge branch 'main' into zhuo/qwen3_triton_gdn
hellozhuo-amd Apr 22, 2026
7fbd9ad
Merge branch 'main' into zhuo/qwen3_triton_gdn
hellozhuo-amd Apr 22, 2026
3ffd13c
Replace _fast with _single_token for causal conv1d update kernels for…
hellozhuo-amd Apr 22, 2026
9946258
Fix black format error
hellozhuo-amd Apr 22, 2026
b8ea372
Merge branch 'main' into zhuo/qwen3_triton_gdn
hellozhuo-amd Apr 22, 2026
0f41d78
Merge branch 'main' into zhuo/qwen3_triton_gdn
hellozhuo-amd Apr 23, 2026
2aa2493
refactor(triton): rename gated RMSNorm+FP8 op to fused_rms_gated_fp8_…
hellozhuo-amd Apr 23, 2026
35035ff
Merge branch 'main' into zhuo/qwen3_triton_gdn
juuso-oskari Apr 24, 2026
711c9e9
Merge branch 'main' into zhuo/qwen3_triton_gdn
hellozhuo-amd Apr 24, 2026
d5e7712
Merge branch 'main' into zhuo/qwen3_triton_gdn
nholmber May 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
507 changes: 507 additions & 0 deletions aiter/ops/triton/_triton_kernels/causal_conv1d_update_single_token.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@
This module provides optimized Triton kernels for decode/inference operations.
"""

from .fused_rearrange_sigmoid_gdr import (
fused_rearrange_sigmoid_gated_delta_rule_update_kernel,
)
from .fused_recurrent import _fused_recurrent_gated_delta_rule_fwd_kernel
from .fused_sigmoid_gating_recurrent import fused_sigmoid_gating_delta_rule_update

# Public re-exports of this subpackage; keep in sync with the imports above.
__all__ = [
    "_fused_recurrent_gated_delta_rule_fwd_kernel",
    "fused_rearrange_sigmoid_gated_delta_rule_update_kernel",
    "fused_sigmoid_gating_delta_rule_update",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import triton
import triton.language as tl


@triton.heuristics(
    {
        # Launch-time flags derived from which optional pointers were supplied.
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
        "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
    }
)
@triton.jit(do_not_specialize=["N", "T"])
def fused_rearrange_sigmoid_gated_delta_rule_update_kernel(
    A_log,  # log decay parameter, indexed per v-head (per channel of K when IS_KDA)
    a,  # gating pre-activation; per v-head scalar, or per (v-head, K) when IS_KDA
    b,  # beta pre-activation; passed through sigmoid inside the kernel
    dt_bias,  # bias added to `a` before the softplus
    beta,  # softplus sharpness scalar used in the thresholded softplus below
    threshold,  # softplus cutoff: when beta*x > threshold, x is used directly
    qkv,  # packed tensor: q at head-dim offset 0, k at H*K, v at 2*H*K
    o,  # output pointer, laid out as [i_k, token, v-head, V]
    h0,  # optional initial recurrent state (see USE_INITIAL_STATE)
    ht,  # final recurrent state output
    cu_seqlens,  # optional cumulative sequence lengths (varlen batching)
    ssm_state_indices,  # optional per-(seq, step) state slot table (continuous batching)
    num_accepted_tokens,  # optional accepted-token counts (speculative decoding)
    scale,  # scale applied to q after the optional L2 normalization
    N: tl.int64,  # num of sequences
    T: tl.int64,  # num of tokens
    B: tl.constexpr,
    H: tl.constexpr,  # number of q/k heads
    HV: tl.constexpr,  # number of v heads; each q/k head serves HV // H v heads
    K: tl.constexpr,  # q/k head dimension
    V: tl.constexpr,  # v head dimension
    BK: tl.constexpr,  # tile width over K
    BV: tl.constexpr,  # tile width over V
    stride_qkv_l: tl.constexpr,  # qkv stride between tokens
    stride_qkv_hd: tl.constexpr,  # qkv stride along the packed head dimension
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,  # L2-normalize q and k inside the kernel
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    IS_KDA: tl.constexpr,  # per-channel (width-K) gate instead of a per-head scalar gate
):
    """Fused recurrent gated delta-rule update over a packed qkv tensor.

    Each program instance owns one (K-tile, V-tile, sequence x v-head) slice and
    scans its sequence token by token, maintaining a [BV, BK] recurrent state
    ``b_h``. Per token it applies an exponential decay gate derived from
    ``-exp(A_log) * softplus(a + dt_bias)``, performs the delta-rule update with
    learning rate ``sigmoid(b)``, emits ``o = sum(b_h * q)``, and writes the
    running state to ``ht``.
    """
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # q/k head shared by this v head (grouped-query style mapping).
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        all = T  # total tokens across the whole batch
        T = eos - bos  # tokens in this sequence only
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    # Empty sequences have nothing to read or write.
    if T == 0:
        return

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    # q/k/v share one buffer, packed along the head dimension: q, then k, then v.
    p_q = qkv + bos * stride_qkv_l + ((i_h * K) + o_k) * stride_qkv_hd
    p_k = qkv + bos * stride_qkv_l + (H * K + (i_h * K) + o_k) * stride_qkv_hd
    p_v = qkv + bos * stride_qkv_l + (2 * H * K + (i_hv * V) + o_v) * stride_qkv_hd

    p_A_log = A_log + i_hv
    if not IS_KDA:
        # Scalar gate per v head per token.
        p_a = a + bos * HV + i_hv
        p_dt_bias = dt_bias + i_hv
    else:
        # KDA: per-channel gate of width K per v head per token.
        p_a = a + (bos * HV + i_hv) * K + o_k
        p_dt_bias = dt_bias + i_hv * K + o_k

    p_b = b + bos * HV + i_hv
    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_v[:, None] & mask_k[None, :]

    # Recurrent state tile, accumulated in fp32 for numerical stability.
    b_h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            if IS_SPEC_DECODING:
                # Resume from the state left by the last accepted draft token.
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            state_idx = tl.load(
                ssm_state_indices + i_n * stride_indices_seq + i_t * stride_indices_tok
            ).to(tl.int64)
            # A negative slot index marks an inactive sequence; skip entirely.
            if state_idx < 0:
                return
            p_h0 = h0 + state_idx * stride_init_state_token
        else:
            p_h0 = h0 + bos * HV * V * K
        p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_b = tl.load(p_b).to(tl.float32)

        # Thresholded softplus: falls back to the identity for large inputs
        # to avoid overflow in exp().
        x = tl.load(p_a).to(tl.float32) + tl.load(p_dt_bias).to(tl.float32)
        softplus_x = tl.where(
            beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
        )
        # Log-space decay gate; A_log > -inf and softplus >= 0 make b_g <= 0.
        b_g = -tl.exp(tl.load(p_A_log).to(tl.float32)) * softplus_x

        # Delta-rule learning rate in (0, 1).
        b_beta = tl.sigmoid(b_b.to(tl.float32))

        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q * tl.rsqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k = b_k * tl.rsqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q = b_q * scale
        # Delta rule: decay state, remove current prediction from v,
        # then rank-1 update with the corrected value.
        if not IS_KDA:
            b_h *= tl.exp(b_g)
        else:
            b_h *= tl.exp(b_g[None, :])
        b_v -= tl.sum(b_h * b_k[None, :], 1)
        b_v *= b_beta
        b_h += b_v[:, None] * b_k[None, :]
        b_o = tl.sum(b_h * b_q[None, :], 1)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        if INPLACE_FINAL_STATE:
            # Persist state into the slot assigned to this (sequence, step);
            # negative slot indices are skipped.
            final_state_idx = tl.load(
                ssm_state_indices + i_n * stride_indices_seq + i_t * stride_indices_tok
            ).to(tl.int64)
            if final_state_idx >= 0:
                p_ht = ht + final_state_idx * stride_final_state_token
                p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
                tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
        else:
            # Dense layout: one state record per token position.
            p_ht = ht + (bos + i_t) * stride_final_state_token
            p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
            tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

        # Advance per-token pointers.
        p_q += stride_qkv_l
        p_k += stride_qkv_l
        p_v += stride_qkv_l
        p_o += HV * V
        p_b += HV
        # NOTE(review): in the IS_KDA layout p_a has a per-token stride of
        # HV * K, so `+= HV` looks too small there — confirm against the caller.
        p_a += HV
124 changes: 124 additions & 0 deletions aiter/ops/triton/_triton_kernels/quant/fused_fp8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,3 +766,127 @@ def _fused_silu_mul_fp8_per_tensor_static_quant_kernel(
quant_fp8_out.to(out_fp8_ptr.dtype.element_ty),
mask=mask,
)


# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.


@triton.heuristics(
    {
        # Launch-time flags derived from which optional pointers were supplied.
        "HAS_BIAS": lambda args: args["B"] is not None,
        "HAS_Z": lambda args: args["Z"] is not None,
    }
)
@triton.jit
def _fused_rms_gated_fp8_group_quant_kernel(
    X,  # input rows [M, N]
    W,  # RMSNorm weight [N]
    B,  # optional bias [N]
    Z,  # optional gate input [M, N]
    Y_quant,  # FP8 output [M, N]
    Scales,  # per-(row, group) dequant scales [M, NUM_GROUPS]
    stride_x_row,
    stride_z_row,
    stride_y_row,
    stride_s_row,
    stride_s_g,
    M,  # number of rows
    N: tl.constexpr,  # row width
    eps,  # RMS variance epsilon
    RMS_TILE: tl.constexpr,  # tile width for the sum-of-squares pass
    ROWS_PER_BLOCK: tl.constexpr,  # rows handled by one program
    GROUP_SIZE: tl.constexpr,  # quantization group width along N
    NUM_GROUPS: tl.constexpr,  # ceil(N / GROUP_SIZE)
    BLOCK_G: tl.constexpr,  # power-of-two tile covering one group
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,  # gate after normalization (True) or before (False)
    FP8_MIN: tl.constexpr,
    FP8_MAX: tl.constexpr,
    USE_UE8M0: tl.constexpr,  # round scales up to a power of two (UE8M0 scale format)
    FP8_MIN_SCALING_FACTOR: tl.constexpr,  # lower bound preventing zero/degenerate scales
    ACTIVATION: tl.constexpr,  # gate activation: "swish"/"silu" or "sigmoid"
):
    """Fused RMSNorm + optional gating + per-group FP8 quantization.

    Pass 1 accumulates the full-row sum of squares in fp32 (applying the gate
    first when NORM_BEFORE_GATE is False, so the statistic matches the gated
    values). Pass 2 re-normalizes each group, applies weight/bias and the
    post-norm gate when NORM_BEFORE_GATE is True, then quantizes each group
    to FP8 with its own absmax-derived scale written to ``Scales``.
    """
    row_start = tl.program_id(0) * ROWS_PER_BLOCK
    rows = row_start + tl.arange(0, ROWS_PER_BLOCK)
    row_mask_1d = rows < M

    # --- Full-row RMS: accumulate sum of squares in float32 ---
    sumsq = tl.zeros([ROWS_PER_BLOCK], dtype=tl.float32)
    off = 0
    while off < N:
        cols = tl.arange(0, RMS_TILE) + off
        col_mask = cols < N
        mask = row_mask_1d[:, None] & col_mask[None, :]
        row_offsets = rows[:, None] * stride_x_row
        col_offsets = cols[None, :]
        X_base = X + row_offsets + col_offsets
        x = tl.load(X_base, mask=mask, other=0.0).to(tl.float32)
        if HAS_Z and not NORM_BEFORE_GATE:
            # Gate-before-norm: the RMS statistic is computed on gated values.
            Z_base = Z + rows[:, None] * stride_z_row + col_offsets
            z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
            if ACTIVATION == "swish" or ACTIVATION == "silu":
                x *= z * tl.sigmoid(z)
            elif ACTIVATION == "sigmoid":
                x *= tl.sigmoid(z)
        # Zero out masked lanes so they do not contribute to the sum.
        xbar = tl.where(mask, x, 0.0)
        sumsq += tl.sum(xbar * xbar, axis=1)
        off += RMS_TILE

    var = sumsq / N
    rstd = tl.rsqrt(var + eps)

    # --- Per-group: normalize (when NORM_BEFORE_GATE), linear, optional gate, FP8 ---
    for g in range(NUM_GROUPS):
        col0 = g * GROUP_SIZE
        cols = tl.arange(0, BLOCK_G) + col0
        col_mask = cols < N
        mask = row_mask_1d[:, None] & col_mask[None, :]
        row_offsets = rows[:, None] * stride_x_row
        col_offsets = cols[None, :]
        X_base = X + row_offsets + col_offsets
        x = tl.load(X_base, mask=mask, other=0.0).to(tl.float32)

        if HAS_Z and not NORM_BEFORE_GATE:
            # Re-apply the pre-norm gate so x matches what pass 1 measured.
            Z_base = Z + rows[:, None] * stride_z_row + col_offsets
            z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
            if ACTIVATION == "swish" or ACTIVATION == "silu":
                x *= z * tl.sigmoid(z)
            elif ACTIVATION == "sigmoid":
                x *= tl.sigmoid(z)

        x_hat = x * rstd[:, None]

        # Same condition as col_mask above; kept for the 1-D weight/bias loads.
        w_mask = cols < N
        w = tl.load(W + cols, mask=w_mask, other=0.0).to(tl.float32)
        if HAS_BIAS:
            b = tl.load(B + cols, mask=w_mask, other=0.0).to(tl.float32)
            y = x_hat * w[None, :] + b[None, :]
        else:
            y = x_hat * w[None, :]

        if HAS_Z and NORM_BEFORE_GATE:
            Z_base = Z + rows[:, None] * stride_z_row + col_offsets
            z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
            if ACTIVATION == "swish" or ACTIVATION == "silu":
                y *= z * tl.sigmoid(z)
            elif ACTIVATION == "sigmoid":
                y *= tl.sigmoid(z)

        # Per-(row, group) absmax scale; masked lanes contribute 0.
        abs_y = tl.where(mask, tl.abs(y), 0.0)
        absmax = tl.max(abs_y, axis=1)
        scales_raw = absmax / FP8_MAX
        if USE_UE8M0:
            # Round the scale up to the next power of two.
            scales_raw = tl.exp2(tl.ceil(tl.log2(scales_raw)))
        # Clamp below to avoid a zero scale (all-zero groups, UE8M0 log2(0)).
        scales = tl.maximum(scales_raw, FP8_MIN_SCALING_FACTOR)

        y_scaled = y / scales[:, None]
        # Saturate into the representable FP8 range before the dtype cast.
        y_quant = tl.maximum(tl.minimum(y_scaled, FP8_MAX), FP8_MIN)

        Y_base = Y_quant + rows[:, None] * stride_y_row + col_offsets
        tl.store(Y_base, y_quant.to(Y_quant.dtype.element_ty), mask=mask)

        S_ptr = Scales + rows * stride_s_row + g * stride_s_g
        tl.store(S_ptr, scales, mask=row_mask_1d)
Loading
Loading