Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 27 additions & 28 deletions fla/ops/common/chunk_delta_h.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,19 +316,19 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
b_dv += tl.dot(b_k, b_dh1.to(b_k.dtype))

if K > 64:
p_k2 = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0))
b_k2 = tl.load(p_k2, boundary_check=(0, 1))
b_dv += tl.dot(b_k2, b_dh2.to(b_k.dtype))
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype))

if K > 128:
p_k3 = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0))
b_k3 = tl.load(p_k3, boundary_check=(0, 1))
b_dv += tl.dot(b_k3, b_dh3.to(b_k.dtype))
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype))

if K > 192:
p_k4 = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0))
b_k4 = tl.load(p_k4, boundary_check=(0, 1))
b_dv += tl.dot(b_k4, b_dh4.to(b_k.dtype))
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_dv += tl.dot(b_k, b_dh4.to(b_k.dtype))

tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

Expand All @@ -342,33 +342,32 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
b_dh1 *= bg_last
b_dh1 += tl.dot(b_q, b_do.to(b_q.dtype))-tl.dot(b_d, b_dv.to(b_d.dtype))
if K > 64:
p_q2 = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
p_d2 = tl.make_block_ptr(d, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
b_q2 = tl.load(p_q2, boundary_check=(0, 1))
b_q2 = (b_q2 * scale).to(b_q.dtype)
b_d2 = tl.load(p_d2, boundary_check=(0, 1))
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
p_d = tl.make_block_ptr(d, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
b_d = tl.load(p_d, boundary_check=(0, 1))
if USE_G:
b_dh2 *= bg_last
b_dh2 += tl.dot(b_q2, b_do.to(b_q.dtype))-tl.dot(b_d2, b_dv.to(b_d.dtype))
b_dh2 += tl.dot(b_q, b_do.to(b_q.dtype))-tl.dot(b_d, b_dv.to(b_d.dtype))
if K > 128:
p_d3 = tl.make_block_ptr(d, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
p_q3 = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
b_q3 = tl.load(p_q3, boundary_check=(0, 1))
b_q3 = (b_q3 * scale).to(b_q.dtype)
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
p_d = tl.make_block_ptr(d, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
b_d = tl.load(p_d, boundary_check=(0, 1))
if USE_G:
b_dh3 *= bg_last
b_d3 = tl.load(p_d3, boundary_check=(0, 1))
b_dh3 += tl.dot(b_q3, b_do.to(b_q.dtype))-tl.dot(b_d3, b_dv.to(b_d.dtype))
b_dh3 += tl.dot(b_q, b_do.to(b_q.dtype))-tl.dot(b_d, b_dv.to(b_d.dtype))
if K > 192:
p_d4 = tl.make_block_ptr(d, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
b_d4 = tl.load(p_d4, boundary_check=(0, 1))
p_q4 = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
b_q4 = tl.load(p_q4, boundary_check=(0, 1))
b_q4 = (b_q4 * scale).to(b_q.dtype)
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
p_d = tl.make_block_ptr(d, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
b_d = tl.load(p_d, boundary_check=(0, 1))
if USE_G:
b_dh4 *= bg_last
b_d4 = tl.load(p_d4, boundary_check=(0, 1))
b_dh4 += tl.dot(b_q4, b_do.to(b_q.dtype))-tl.dot(b_d4, b_dv.to(b_d.dtype))
b_dh4 += tl.dot(b_q, b_do.to(b_q.dtype))-tl.dot(b_d, b_dv.to(b_d.dtype))

if USE_INITIAL_STATE:
p_dh0 = tl.make_block_ptr(dh0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
Expand Down
26 changes: 23 additions & 3 deletions fla/ops/common/chunk_scaled_dot_kkt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
import triton.language as tl

from fla.ops.common.utils import prepare_chunk_indices
from fla.ops.utils.op import safe_exp


@triton.heuristics({
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
'USE_G': lambda args: args['g_cumsum'] is not None
})
@triton.autotune(
configs=[
Expand All @@ -26,7 +28,9 @@
def chunk_scaled_dot_kkt_fwd_kernel(
k,
beta,
g_cumsum,
A,
Ag,
cu_seqlens,
chunk_indices,
T,
Expand All @@ -35,6 +39,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
BT: tl.constexpr,
BK: tl.constexpr,
IS_VARLEN: tl.constexpr,
USE_G: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
Expand All @@ -60,11 +65,20 @@ def chunk_scaled_dot_kkt_fwd_kernel(
p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (BT*H, 1), (i_t * BT, 0), (BT, BT), (1, 0))
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))

if USE_G:
p_g = tl.make_block_ptr(g_cumsum + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_g_diff = b_g[:, None] - b_g[None, :]
b_Ag = b_A * safe_exp(b_g_diff)
p_Ag = tl.make_block_ptr(Ag + (bos*H + i_h) * BT, (T, BT), (BT*H, 1), (i_t * BT, 0), (BT, BT), (1, 0))
tl.store(p_Ag, b_Ag.to(p_Ag.dtype.element_ty), boundary_check=(0, 1))


def chunk_scaled_dot_kkt_fwd(
k: torch.Tensor,
beta: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor],
g_cumsum: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32
) -> torch.Tensor:
Expand All @@ -76,6 +90,9 @@ def chunk_scaled_dot_kkt_fwd(
The key tensor of shape `[B, T, H, K]`.
beta (torch.Tensor):
The beta tensor of shape `[B, T, H]`.
g_cumsum (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H]`.
Default: None
cu_seqlens (torch.LongTensor):
The cumulative sequence lengths of the input tensor.
Default: None
Expand All @@ -92,15 +109,18 @@ def chunk_scaled_dot_kkt_fwd(
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
Ag = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) if g_cumsum is not None else None
chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
k=k,
beta=beta,
g_cumsum=g_cumsum,
A=A,
Ag=Ag,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
BT=BT,
)
return A
return A, Ag
48 changes: 13 additions & 35 deletions fla/ops/delta_rule/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
from typing import Optional

import torch
import triton
from einops import rearrange

from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o
from fla.ops.delta_rule.wy_fast import bwd_prepare_wy_repr, fwd_prepare_wy_repr, fwd_recompute_w_u
from fla.ops.delta_rule.wy_fast import prepare_wy_repr_bwd, prepare_wy_repr_fwd, recompute_w_u_fwd
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard


Expand All @@ -24,38 +23,32 @@ def chunk_delta_rule_fwd(
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64
):
T = q.shape[1]
BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
# obtain WY representation. u is actually the new v.
w, u, A = fwd_prepare_wy_repr(
w, u, A = prepare_wy_repr_fwd(
k=k,
v=v,
beta=beta,
cu_seqlens=cu_seqlens,
chunk_size=BT
)

h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
k=k,
w=w,
u=u,
g=None,
initial_state=initial_state,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
chunk_size=BT
cu_seqlens=cu_seqlens
)

o = chunk_fwd_o(
q=q,
k=k,
v=v_new,
h=h,
g=None,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=BT
cu_seqlens=cu_seqlens
)
return o, A, final_state

Expand All @@ -71,17 +64,13 @@ def chunk_delta_rule_bwd(
do: torch.Tensor,
dht: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64
):
T = q.shape[1]
BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
w, u = fwd_recompute_w_u(
w, u = recompute_w_u_fwd(
k=k,
v=v,
beta=beta,
A=A,
cu_seqlens=cu_seqlens,
chunk_size=BT
cu_seqlens=cu_seqlens
)
h, v_new, _ = chunk_gated_delta_rule_fwd_h(
k=k,
Expand All @@ -91,16 +80,14 @@ def chunk_delta_rule_bwd(
initial_state=initial_state,
output_final_state=False,
cu_seqlens=cu_seqlens,
chunk_size=BT
)
dv = chunk_bwd_dv_local(
q=q,
k=k,
do=do,
g=None,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=BT
cu_seqlens=cu_seqlens
)
dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu(
q=q,
Expand All @@ -112,8 +99,7 @@ def chunk_delta_rule_bwd(
do=do,
dv=dv,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=BT
cu_seqlens=cu_seqlens
)
dq, dk, dw, _ = chunk_bwd_dqkwg(
q=q,
Expand All @@ -126,18 +112,16 @@ def chunk_delta_rule_bwd(
dh=dh,
g=None,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=BT
cu_seqlens=cu_seqlens
)
dk2, dv, db = bwd_prepare_wy_repr(
dk2, dv, db = prepare_wy_repr_bwd(
k=k,
v=v,
beta=beta,
A=A,
dw=dw,
du=dv,
cu_seqlens=cu_seqlens,
chunk_size=BT
cu_seqlens=cu_seqlens
)
dk.add_(dk2)
return dq, dk, dv, db, dh0
Expand All @@ -160,9 +144,6 @@ def forward(
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = True
):
T = q.shape[1]
chunk_size = min(64, max(triton.next_power_of_2(T), 16))

q_orig = q
k_orig = k

Expand All @@ -179,10 +160,8 @@ def forward(
initial_state=initial_state,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size
)
ctx.save_for_backward(q_orig, k_orig, v, beta, A, initial_state)
ctx.chunk_size = chunk_size
ctx.scale = scale
ctx.cu_seqlens = cu_seqlens
ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
Expand Down Expand Up @@ -212,8 +191,7 @@ def backward(
initial_state=initial_state,
do=do,
dht=dht,
cu_seqlens=ctx.cu_seqlens,
chunk_size=ctx.chunk_size
cu_seqlens=ctx.cu_seqlens
)
if use_qk_l2norm_in_kernel:
dq = l2norm_bwd(q_orig, dq)
Expand Down
Loading
Loading