Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 6 additions & 13 deletions fla/ops/kda/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@

from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
from fla.ops.common.chunk_o import chunk_bwd_dv_local
from fla.ops.gla.chunk import chunk_gla_bwd_dA, chunk_gla_fwd_o_gk
from fla.ops.kda.chunk_inter import chunk_kda_bwd_dqkwg
from fla.ops.gla.chunk import chunk_gla_fwd_o_gk
from fla.ops.kda.chunk_bwd import chunk_kda_bwd_dAv, chunk_kda_bwd_dqkwg
from fla.ops.kda.chunk_intra import chunk_kda_bwd_intra, chunk_kda_fwd_intra
from fla.ops.kda.gate import kda_gate_bwd, kda_gate_fwd
from fla.ops.kda.wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd
Expand Down Expand Up @@ -103,9 +102,12 @@ def chunk_kda_bwd(
chunk_indices=chunk_indices,
use_exp2=True,
)
dv = chunk_bwd_dv_local(
# dAqk = tril(do @ v.T) * scale
# dv = A.T @ do  (A restricted to its in-chunk lower triangle)
dAqk, dv = chunk_kda_bwd_dAv(
q=q,
k=k,
v=v_new,
do=do,
A=Aqk,
scale=scale,
Expand All @@ -128,15 +130,6 @@ def chunk_kda_bwd(
chunk_indices=chunk_indices,
use_exp2=True,
)
# dq dk in fp32
dAqk = chunk_gla_bwd_dA(
v=v_new,
do=do,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
chunk_indices=chunk_indices,
)
dq, dk, dw, dg = chunk_kda_bwd_dqkwg(
q=q,
k=k,
Expand Down
137 changes: 133 additions & 4 deletions fla/ops/kda/chunk_inter.py → fla/ops/kda/chunk_bwd.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,94 @@
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang


import torch
import triton
import triton.language as tl

from fla.ops.utils import prepare_chunk_indices
from fla.ops.utils.op import exp2
from fla.utils import autotune_cache_kwargs, check_shared_mem
from fla.utils import IS_NVIDIA_HOPPER, autotune_cache_kwargs, check_shared_mem

# Candidate block sizes along the K and V dimensions. The larger pair is chosen
# when `check_shared_mem` returns a truthy value for the queried architecture
# (presumably meaning the device has enough shared memory -- confirm against
# `fla.utils.check_shared_mem`).
BK_LIST = [32, 64] if check_shared_mem() else [16, 32]
BV_LIST = [64, 128] if check_shared_mem('ampere') else [16, 32]
# `num_warps` values swept by the autotuner configs below; a shorter list is
# used when IS_NVIDIA_HOPPER is set.
NUM_WARPS = [2, 4] if IS_NVIDIA_HOPPER else [2, 4, 8]


@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in NUM_WARPS
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'V', 'BT', 'BK', 'BV'],
    **autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
def chunk_bwd_kernel_dAv(
    q,
    k,
    v,
    A,
    do,
    dv,
    dA,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Fused per-chunk backward step. Each program handles one (chunk, batch*head)
    # pair and writes:
    #   dv = A_chunk^T @ do              (one [BT, BV] tile per V-block)
    #   dA = tril(do @ v^T) * scale      (one [BT, BT] tile, accumulated in fp32)
    # Program axis 0 walks the chunks, axis 1 walks the flattened batch*head index.
    i_t = tl.program_id(0)
    i_bh = tl.program_id(1)
    i_b = i_bh // H
    i_h = i_bh % H
    if IS_VARLEN:
        # chunk_indices stores (sequence id, chunk id within that sequence) pairs;
        # both loads must use the launch-grid i_t, so i_n is read first.
        i_n = tl.load(chunk_indices + i_t * 2).to(tl.int32)
        i_t = tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos = tl.load(cu_seqlens + i_n).to(tl.int32)
        eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos = i_b * T
        eos = i_b * T + T

    # advance every base pointer to the first token of this sequence / head
    o_bh = bos * H + i_h
    q += o_bh * K
    k += o_bh * K
    v += o_bh * V
    do += o_bh * V
    dv += o_bh * V
    dA += o_bh * BT

    # positions covered by this chunk and their validity w.r.t. the sequence length
    o_t = i_t * BT + tl.arange(0, BT)
    m_t = o_t < T
    # keep entries with row <= col of the *transposed* tile, restricted to valid positions
    m_A = (o_t[:, None] <= o_t[None, :]) & (m_t[:, None] & m_t)

    # the block pointer views A as (BT, T) with strides (1, H*BT), i.e. the chunk's
    # [BT, BT] tile is read transposed
    p_A = tl.make_block_ptr(A + o_bh * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_A = tl.where(m_A, b_A, 0).to(do.dtype.element_ty)

    b_dA = tl.zeros([BT, BT], dtype=tl.float32)
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1))
        p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        # v is read transposed: [BV, BT]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BT, BV] @ [BV, BT] -> [BT, BT], summed over all V-blocks
        b_dA += tl.dot(b_do, b_v)
        # [BT, BT] @ [BT, BV] -> [BT, BV]
        b_dv = tl.dot(b_A.to(b_do.dtype), b_do)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

    # only the lower triangle (row >= col) of dA is kept; scale is applied here
    p_dA = tl.make_block_ptr(dA, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    b_dA = tl.where(o_t[:, None] >= o_t, b_dA * scale, 0.)
    tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({
Expand All @@ -28,7 +106,7 @@
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
def chunk_kda_bwd_kernel_inter(
def chunk_kda_bwd_kernel_dqkwg(
q,
k,
v,
Expand Down Expand Up @@ -138,6 +216,57 @@ def chunk_kda_bwd_kernel_inter(
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))


def chunk_kda_bwd_dAv(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    do: torch.Tensor,
    A: torch.Tensor | None = None,
    scale: float | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
    chunk_indices: torch.LongTensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch `chunk_bwd_kernel_dAv` to compute `dA` and `dv` in a single pass.

    Args:
        q: Forwarded to the kernel unchanged.
        k: 4-D tensor unpacked as `[B, T, H, K]`; it supplies the problem sizes
            and the device used for the shared-memory query.
        v: Values; also used as the allocation template for `dA`.
        do: Gradient of the output; its last dimension is taken as `V`.
        A: Intra-chunk matrix read by the kernel. The kernel loads it
            unconditionally, so callers are expected to pass a tensor.
        scale: Factor the kernel multiplies into `dA`.
        cu_seqlens: Cumulative sequence lengths for variable-length inputs,
            or `None` for fixed-length batches.
        chunk_size: Chunk length `BT`.
        chunk_indices: Precomputed chunk index table; derived from
            `cu_seqlens` when omitted.

    Returns:
        `(dA, dv)`, where `dA` is a float32 tensor of shape `[B, T, H, BT]`
        and `dv` has the same shape and dtype as `do`.
    """
    B, T, H, K, V = *k.shape, do.shape[-1]
    BT = chunk_size
    if chunk_indices is None and cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
    # H100 can have larger block size
    if check_shared_mem('hopper', k.device.index):
        CONST_TILING = 128
    # The function must be *called* here: a bare `check_shared_mem` is a function
    # object and therefore always truthy, which made the 32 fallback unreachable.
    elif check_shared_mem('ampere', k.device.index):
        CONST_TILING = 64
    else:
        CONST_TILING = 32
    BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING)
    BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING)
    # one program per chunk; with variable lengths the chunk table defines the count
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    dA = v.new_empty(B, T, H, BT, dtype=torch.float)
    dv = torch.empty_like(do)
    grid = (NT, B * H)
    chunk_bwd_kernel_dAv[grid](
        q=q,
        k=k,
        v=v,
        A=A,
        do=do,
        dv=dv,
        dA=dA,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return dA, dv


def chunk_kda_bwd_dqkwg(
q: torch.Tensor,
k: torch.Tensor,
Expand Down Expand Up @@ -165,7 +294,7 @@ def chunk_kda_bwd_dqkwg(
dw = torch.empty_like(w)
dg = torch.empty_like(g)
def grid(meta): return (triton.cdiv(K, meta['BK']), NT, B * H)
chunk_kda_bwd_kernel_inter[grid](
chunk_kda_bwd_kernel_dqkwg[grid](
q=q,
k=k,
v=v,
Expand Down
Loading