diff --git a/fla/ops/oja2/OJA.pdf b/fla/ops/oja2/OJA.pdf new file mode 100644 index 0000000000..f2809df718 Binary files /dev/null and b/fla/ops/oja2/OJA.pdf differ diff --git a/fla/ops/oja2/__init__.py b/fla/ops/oja2/__init__.py new file mode 100644 index 0000000000..9a96d4bc78 --- /dev/null +++ b/fla/ops/oja2/__init__.py @@ -0,0 +1,7 @@ +from .chunk import chunk_oja2 +from .fused_recurrent import fused_recurrent_oja2 + +__all__ = [ + "chunk_oja2", + "fused_recurrent_oja2" +] diff --git a/fla/ops/oja2/chunk.py b/fla/ops/oja2/chunk.py new file mode 100644 index 0000000000..0cd860044e --- /dev/null +++ b/fla/ops/oja2/chunk.py @@ -0,0 +1,381 @@ +# # -*- coding: utf-8 -*- +# # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import warnings +from typing import Optional + +import torch + +from fla.modules.l2norm import l2norm_bwd, l2norm_fwd +from fla.ops.utils import chunk_local_cumsum, solve_tril +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + +from fla.ops.oja2.wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd +from fla.ops.oja2.chunk_kkt import chunk_scaled_dot_kkt_fwd, chunk_scaled_dot_kkt_bwd_gk +from fla.ops.oja2.chunk_h import ( + chunk_oja2_fwd_h, + chunk_oja2_bwd_dhu, + chunk_oja2_bwd_dvwg_h) +from fla.ops.oja2.chunk_o import ( + chunk_oja2_fwd_o, + chunk_oja2_bwd_dA, + chunk_oja2_bwd_dqk, + chunk_oja2_bwd_dv_o, + ) + + + + +def chunk_oja2_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gv: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + g_cumsum: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None +): + if g_cumsum: + gv = chunk_local_cumsum(gv, chunk_size=64, cu_seqlens=cu_seqlens) + # obtain WY representation. u is actually the new v. + A = chunk_scaled_dot_kkt_fwd( + k=v, + gk=gv, + beta=beta, + cu_seqlens=cu_seqlens, + output_dtype=torch.float32 + ) + A = solve_tril( + A=A, + cu_seqlens=cu_seqlens, + output_dtype=k.dtype + ) + # w = Avg, u = Ak + w, u, vg = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + gv=gv, + cu_seqlens=cu_seqlens, + ) + # grid in K + h, k_new, final_state = chunk_oja2_fwd_h( + v=vg, + w=w, + u=u, + gv=gv, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + _, o = chunk_oja2_fwd_o( + q=q, + k=k_new, + v=v, + h=h, + gv=gv, + scale=scale, + cu_seqlens=cu_seqlens, + ) + return gv, o, A, final_state + + +def chunk_oja2_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gv: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + o: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + dgk: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, +): + w, u, vg = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + gv=gv, + cu_seqlens=cu_seqlens, + ) + # w = w.to(torch.float32) + # u = u.to(torch.float32) + # vg = vg.to(torch.float32) + h, k_new, _ = chunk_oja2_fwd_h( + v=vg, + w=w, + u=u, + gv=gv, + initial_state=initial_state, + output_final_state=False, + cu_seqlens=cu_seqlens, + ) + """ + 对于S = g_last * S + Vg @ (U - WS) + O = g_i * (QS + tri(Q @ (U - WS)) (V/g)) + 1. 计算dA = do * g_i * v/g + 2. 计算dA里面的dk_new=dA * q, 顺便收集tri(A), 计算全部dq = do * g_i * S :: 🚩所有dq完毕 + 3. 计算dS, 进一步收集所有S里面的dk_new, 计算递归中的dS以及dk_new中的dS :: 🚩所有dk_new(du), dS, dS0完毕 + 4. 计算o递归里的dv = do * g_i * A(细粒度), 顺便收集dg + 5. 计算S中的dv以及dk_new里的dw :: 🚩所有dw, dv完毕 + @ 至此dq, dk_new, dv, dw, du, dS, dS0完毕,还需要最后解开WY表征 + 6. 先计算W = M * beta * AV以及U = M * beta * K外面的dbeta, dk, dv, dg, 存下来dM + 7. 通过存下来的dM计算内部的dv, dbeta, dg :: 🚩所有dq, dk, dv, dw, du, dS, dS0, dbeta, dg完毕 + """ + # grid = (NV, NT * NC * NC, B * H) + + dAqk = chunk_oja2_bwd_dA( + v=v, + gv=gv, + do=do, + scale=scale, + cu_seqlens=cu_seqlens, + ) + + # (NK, NT, B * H) + Aqk, dq, dk_new = chunk_oja2_bwd_dqk( + q=q, + k=k_new, + h=h, + gv=gv, + dA=dAqk, + do=do, + scale=scale, + cu_seqlens=cu_seqlens, + ) + + # (NK, B*H) + dh, dh0, dk_new = chunk_oja2_bwd_dhu( + q=q, + vg=vg, + w=w, + gv=gv, + h0=initial_state, + dht=dht, + do=do, + dk=dk_new, + scale=scale, + cu_seqlens=cu_seqlens, + states_in_fp32=False, + ) + + # grid = (NV, NT, B * H) + dv, dw, dgv_last = chunk_oja2_bwd_dvwg_h( + k=k_new, + v=v, + gv=gv, + h=h, + dh=dh, + dk=dk_new, + dgk=dgk, + cu_seqlens=cu_seqlens, + ) + + # (NV, NT * NC, B * H) + dv, dgv1 = chunk_oja2_bwd_dv_o( + v=v, + gv=gv, + o=o, + A=Aqk, + dv=dv, + do=do, + cu_seqlens=cu_seqlens, + ) + + # (NT, B * H) + dk, dv1, db, dgv2, dAvv = prepare_wy_repr_bwd( + k=k, + v=v, + beta=beta, + gv=gv, + A=A, + dw=dw, + du=dk_new, + cu_seqlens=cu_seqlens, + ) + + # (NK, NT * NC, B * H) + dv2, dgv3, db2 = chunk_scaled_dot_kkt_bwd_gk( + k=v, + g=gv, + beta=beta, + dA=dAvv, + cu_seqlens=cu_seqlens, + ) + + dv = dv.add_(dv1).add_(dv2) + db = db.add_(db2) + dgv = dgv_last.add_(chunk_local_cumsum(dgv1.add_(dgv2).add_(dgv3), chunk_size=64, reverse=True, cu_seqlens=cu_seqlens)) + return dq, dk, dv, db, dgv, dh0 + + +class ChunkOJA2Function(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gv: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + use_q_l2norm: bool = False, + use_k_l2norm: bool = False, + ): + q_rstd, k_rstd = None, None + if use_q_l2norm: + q, q_rstd = l2norm_fwd(q) + if use_k_l2norm: + k, k_rstd = l2norm_fwd(k) + + gv, o, A, final_state = chunk_oja2_fwd( + q=q, + k=k, + v=v, + gv=gv, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + ctx.save_for_backward(q, q_rstd, k, k_rstd, v, gv, beta, A, o, initial_state, cu_seqlens) + ctx.scale = scale + ctx.use_q_l2norm = use_q_l2norm + ctx.use_k_l2norm = use_k_l2norm + return o.to(q.dtype), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward( + ctx, + do: torch.Tensor, + dht: torch.Tensor + ): + q, q_rstd, k, k_rstd, v, gv, beta, A, o, initial_state, cu_seqlens = ctx.saved_tensors + dq, dk, dv, db, dg, dh0 = chunk_oja2_bwd( + q=q, + k=k, + v=v, + gv=gv, + beta=beta, + A=A, + o=o, + scale=ctx.scale, + initial_state=initial_state, + do=do, + dht=dht, + cu_seqlens=cu_seqlens, + ) + # === 遍历检查所有梯度,定位具体是哪个 NaN === + # 将变量名和tensor对应起来 + # grad_tensors = { + # 'dq': dq, 'dk': dk, 'dv': dv, 'db': db, + # 'dg': dg, 'dh0': dh0 + # } + + # for name, t in grad_tensors.items(): + # if t is not None and torch.isnan(t).any(): + # import os + # import torch.distributed as dist + + # # 获取 Rank ID + # # try: + # # rank = dist.get_rank() if dist.is_initialized() else 0 + # # except: + # # rank = 0 + # rank = 0 + + # base_dir = "/mnt/moonfs/hujiaxi-m2/oja_nan_12" + # os.makedirs(base_dir, exist_ok=True) + + # # 保存路径:nan_dump_rank{卡号}.pt + # save_path = os.path.join(base_dir, f"nan_dump_rank{rank}.pt") + + # torch.save({ + # "q": q, + # "k": k, + # "v": v, + # "beta": beta, + # "gv": gv, + # "do": do, + # "cu_seqlens": cu_seqlens, + # "error_source": name # 顺便把出错的变量名也存进文件 + # }, save_path) + + # # 明确报错:指出是哪个变量出的问题 + # raise RuntimeError(f"NaN detected in [{name}] on Rank {rank}! Context saved to: {save_path}") + if ctx.use_q_l2norm: + dq = l2norm_bwd(q, q_rstd, dq) + if ctx.use_k_l2norm: + dk = l2norm_bwd(k, k_rstd, dk) + return dq.to(q), dk.to(k), dv.to(v), dg.to(gv), db.to(beta), None, dh0, None, None, None, None + + +@torch.compiler.disable +def chunk_oja2( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gv: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + use_q_l2norm: bool = False, + use_k_l2norm: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + **kwargs, +): + if 'head_first' in kwargs: + warnings.warn( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead." + ) + if 'use_qk_l2norm_in_kernel' in kwargs and (not use_q_l2norm and not use_k_l2norm): + use_q_l2norm = True + use_k_l2norm = True + + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + o, final_state = ChunkOJA2Function.apply( + q, + k, + v, + gv, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + use_q_l2norm, + use_k_l2norm + ) + return o, final_state diff --git a/fla/ops/oja2/chunk_h.py b/fla/ops/oja2/chunk_h.py new file mode 100644 index 0000000000..0eb7251888 --- /dev/null +++ b/fla/ops/oja2/chunk_h.py @@ -0,0 +1,821 @@ +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets +from fla.ops.utils.op import exp +from fla.utils import is_nvidia_hopper, use_cuda_graph +from fla.utils import check_shared_mem + +BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16] + + +@triton.heuristics({ + 'USE_GV': lambda args: args['gv'] is not None, + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'SAVE_NEW_KEY': lambda args: args['k_new'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + for BK in [32, 64] + ], + key=['H', 'K', 'V', 'BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_oja2_fwd_kernel_h_blockdim64( + v, + u, + w, + k_new, + gv, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + USE_GV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_KEY: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + # (triton.cdiv(K, meta['BK']), N*H) + i_k, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h1 = tl.zeros([BK, 64], dtype=tl.float32) + if V > 64: + b_h2 = tl.zeros([BK, 64], dtype=tl.float32) + if V > 128: + b_h3 = tl.zeros([BK, 64], dtype=tl.float32) + if V > 192: + b_h4 = tl.zeros([BK, 64], dtype=tl.float32) + + # calculate offset + h += ((boh * H + i_h) * K*V).to(tl.int64) + v += ((bos * H + i_h) * V).to(tl.int64) + u += ((bos * H + i_h) * K).to(tl.int64) + w += ((bos * H + i_h) * V).to(tl.int64) + if SAVE_NEW_KEY: + k_new += ((bos * H + i_h) * K).to(tl.int64) + stride_v = H*V + stride_h = H*K*V + stride_k = H*K + if USE_INITIAL_STATE: + h0 = h0 + i_nh * K*V + if STORE_FINAL_STATE: + ht = ht + i_nh * K*V + BV=64 + + # load initial state + if USE_INITIAL_STATE: + p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (i_k * BK, 0), (BK, 64), (1, 0)) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) + if V > 64: + p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (i_k * BK, 64), (BK, 64), (1, 0)) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + if V > 128: + p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (i_k * BK, 128), (BK, 64), (1, 0)) + b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) + if V > 192: + p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (i_k * BK, 192), (BK, 64), (1, 0)) + b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) + + # main recurrence + for i_t in range(NT): + p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (i_k * BK, 0), (BK, 64), (1, 0)) + tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) + if V > 64: + p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (i_k * BK, 64), (BK, 64), (1, 0)) + tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) + if V > 128: + p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (i_k * BK, 128), (BK, 64), (1, 0)) + tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) + if V > 192: + p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (i_k * BK, 192), (BK, 64), (1, 0)) + tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) + + p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_k = tl.dot(b_w, tl.trans(b_h1).to(b_w.dtype)) # BT BK + if V > 64: + p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_k += tl.dot(b_w, tl.trans(b_h2).to(b_w.dtype)) + if V > 128: + p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_k += tl.dot(b_w, tl.trans(b_h3).to(b_w.dtype)) + if V > 192: + p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_k += tl.dot(b_w, tl.trans(b_h4).to(b_w.dtype)) + + p_u = tl.make_block_ptr(u, (T, K), (stride_k, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_u, boundary_check=(0, 1)) - b_k + + if SAVE_NEW_KEY: + p_k = tl.make_block_ptr(k_new, (T, K), (stride_k, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_k, b_k.to(p_k.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min((i_t + 1) * BT, T) - 1 + + if USE_GV: + o_v1 = tl.arange(0, 64) + b_gk_last1 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v1, mask=(o_v1 < V), other=0.) + b_h1 *= exp(b_gk_last1)[None, :] + if K > 64: + o_v2 = 64 + o_v1 + b_gk_last2 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v2, mask=(o_v2 < V), other=0.) + b_h2 *= exp(b_gk_last2)[None, :] + if K > 128: + o_v3 = 128 + o_v1 + b_gk_last3 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v3, mask=(o_v3 < V), other=0.) + b_h3 *= exp(b_gk_last3)[None, :] + if K > 192: + o_v4 = 192 + o_v1 + b_gk_last4 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v4, mask=(o_v4 < K), other=0.) + b_h4 *= exp(b_gk_last4)[None, :] + + b_k = b_k.to(v.dtype.element_ty) # BT BK + + p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) # BT BV + b_h1 += tl.dot(tl.trans(b_k), b_v) + if V > 64: + p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_h2 += tl.dot(tl.trans(b_k), b_v) + if V > 128: + p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_h3 += tl.dot(tl.trans(b_k), b_v) + if V > 192: + p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_h4 += tl.dot(tl.trans(b_k), b_v) + # epilogue + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (i_k * BK, 0), (BK, 64), (1, 0)) + tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if V > 64: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (i_k * BK, 64), (BK, 64), (1, 0)) + tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if V > 128: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (i_k * BK, 128), (BK, 64), (1, 0)) + tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if V > 192: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (i_k * BK, 192), (BK, 64), (1, 0)) + tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_oja2_fwd_h( + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + gv: Optional[torch.Tensor] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + save_new_key: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, V, K = *v.shape, u.shape[-1] + BT = chunk_size + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + assert V <= 256, "current kernel does not support head dimension larger than 256." + + h = v.new_empty(B, NT, H, K, V) + final_state = v.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + + k_new = torch.empty_like(u) if save_new_key else None + def grid(meta): return (triton.cdiv(K, meta['BK']), N*H) + chunk_oja2_fwd_kernel_h_blockdim64[grid]( + v=v, + u=u, + w=w, + k_new=k_new, + gv=gv, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT + ) + return h, k_new, final_state + + + + + +@triton.heuristics({ + 'USE_GV': lambda args: args['gv'] is not None, + 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None, + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [4, 3, 2] + for BK in [64, 32] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'USE_GV'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_oja2_bwd_kernel_dhu_blockdim64( + q, + vg, + w, + gv, + dht, + dh0, + do, + dh, + dk, + dk2, + cu_seqlens, + chunk_offsets, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + USE_GV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + IS_VARLEN: tl.constexpr +): + i_k, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_dh1 = tl.zeros([BK, 64], dtype=tl.float32) + if V > 64: + b_dh2 = tl.zeros([BK, 64], dtype=tl.float32) + if V > 128: + b_dh3 = tl.zeros([BK, 64], dtype=tl.float32) + if V > 192: + b_dh4 = tl.zeros([BK, 64], dtype=tl.float32) + + # calculate offset + q += ((bos * H + i_h) * K).to(tl.int64) + vg += ((bos * H + i_h) * V).to(tl.int64) + w += ((bos * H + i_h) * V).to(tl.int64) + do += ((bos * H + i_h) * V).to(tl.int64) + dk += ((bos * H + i_h) * K).to(tl.int64) + dk2 += ((bos * H + i_h) * K).to(tl.int64) + dh += ((boh * H + i_h) * K*V).to(tl.int64) + if USE_GV: + gv += ((bos * H + i_h) * V).to(tl.int64) + + stride_v = H*V + stride_h = H*K*V + stride_k = H*K + if USE_INITIAL_STATE: + dh0 += i_nh * K*V + if USE_FINAL_STATE_GRADIENT: + dht += i_nh * K*V + + if USE_FINAL_STATE_GRADIENT: + p_dht1 = tl.make_block_ptr(dht, (K, V), (V, 1), (i_k * BK, 0), (BK, 64), (1, 0)) # [BK, BV] + b_dh1 += tl.load(p_dht1, boundary_check=(0, 1)) + if V > 64: + p_dht2 = tl.make_block_ptr(dht, (K, V), (V, 1), (i_k * BK, 64), (BK, 64), (1, 0)) + b_dh2 += tl.load(p_dht2, boundary_check=(0, 1)) + if V > 128: + p_dht3 = tl.make_block_ptr(dht, (K, V), (V, 1), (i_k * BK, 128), (BK, 64), (1, 0)) + b_dh3 += tl.load(p_dht3, boundary_check=(0, 1)) + if V > 192: + p_dht4 = tl.make_block_ptr(dht, (K, V), (V, 1), (i_k * BK, 192), (BK, 64), (1, 0)) + b_dh4 += tl.load(p_dht4, boundary_check=(0, 1)) + + for i_t in range(NT - 1, -1, -1): + p_dh1 = tl.make_block_ptr(dh + i_t*stride_h, (K, V), (V, 1), (i_k * BK, 0), (BK, 64), (1, 0)) + tl.store(p_dh1, b_dh1.to(p_dh1.dtype.element_ty), boundary_check=(0, 1)) + if V > 64: + p_dh2 = tl.make_block_ptr(dh + i_t*stride_h, (K, V), (V, 1), (i_k * BK, 64), (BK, 64), (1, 0)) + tl.store(p_dh2, b_dh2.to(p_dh2.dtype.element_ty), boundary_check=(0, 1)) + if V > 128: + p_dh3 = tl.make_block_ptr(dh + i_t*stride_h, (K, V), (V, 1), (i_k * BK, 128), (BK, 64), (1, 0)) + tl.store(p_dh3, b_dh3.to(p_dh3.dtype.element_ty), boundary_check=(0, 1)) + if V > 192: + p_dh4 = tl.make_block_ptr(dh + i_t*stride_h, (K, V), (V, 1), (i_k * BK, 192), (BK, 64), (1, 0)) + tl.store(p_dh4, b_dh4.to(p_dh4.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min((i_t + 1) * BT, T) - 1 + + # Update dk_new, 按K切分 + p_dk = tl.make_block_ptr(dk, (T, K), (stride_k, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) # [BT, BK] + p_dk2 = tl.make_block_ptr(dk2, (T, K), (stride_k, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) # [BT, BK] + + if V > 0: + p_v = tl.make_block_ptr(vg, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) # [BT, BV] + b_dk = tl.dot(b_v, tl.trans(b_dh1).to(b_v.dtype)) # [BT, BV] @ [BV, BK] -> [BT, BK] + + if V > 64: + p_v = tl.make_block_ptr(vg, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dk += tl.dot(b_v, tl.trans(b_dh2).to(b_v.dtype)) + + if V > 128: + p_v = tl.make_block_ptr(vg, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dk += tl.dot(b_v, tl.trans(b_dh3).to(b_v.dtype)) + + if V > 192: + p_v = tl.make_block_ptr(vg, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dk += tl.dot(b_v, tl.trans(b_dh4).to(b_v.dtype)) + + b_dk += tl.load(p_dk, boundary_check=(0, 1)) + + tl.store(p_dk2, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + + # Update dh, 按照K切分,收集所有V维度,q一次就好,wdo要收集所有 + + p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + + if V > 0: + p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) # [BT, BV] + b_w = tl.load(p_w, boundary_check=(0, 1)) + p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) # [BT, BV] + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + if USE_GV: + o_v1 = tl.arange(0, 64) + b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.) + b_dh1 *= exp(b_gv_last1[None, :]) + b_do *= exp(b_gv) + b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w) # [BK, BT] @ [BT, BV] - [BK, BT] @ [BT, BV] + + if V > 64: + p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0)) # [BT, BV] + b_w = tl.load(p_w, boundary_check=(0, 1)) + p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0)) # [BT, BV] + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + if USE_GV: + o_v2 = 64 + o_v1 + b_gv_last2 = tl.load(gv + last_idx * H*V + o_v2, mask=(o_v2 < V), other=0.) + b_dh2 *= exp(b_gv_last2[None, :]) + b_do *= exp(b_gv) + b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w) + + if V > 128: + p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0)) # [BT, BV] + b_w = tl.load(p_w, boundary_check=(0, 1)) + p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0)) # [BT, BV] + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + if USE_GV: + o_v3 = 128 + o_v1 + b_gv_last3 = tl.load(gv + last_idx * H*V + o_v3, mask=(o_v3 < V), other=0.) + b_dh3 *= exp(b_gv_last3[None, :]) + b_do *= exp(b_gv) + b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w) + + if V > 192: + p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0)) # [BT, BV] + b_w = tl.load(p_w, boundary_check=(0, 1)) + p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0)) # [BT, BV] + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + if USE_GV: + o_v4 = 192 + o_v1 + b_gv_last4 = tl.load(gv + last_idx * H*V + o_v4, mask=(o_v4 < V), other=0.) + b_dh4 *= exp(b_gv_last4[None, :]) + b_do *= exp(b_gv) + b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w) + + if USE_INITIAL_STATE: + p_dh0 = tl.make_block_ptr(dh0, (K, V), (V, 1), (i_k * BK, 0), (BK, 64), (1, 0)) + tl.store(p_dh0, b_dh1.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + if V > 64: + p_dh1 = tl.make_block_ptr(dh0, (K, V), (V, 1), (i_k * BK, 64), (BK, 64), (1, 0)) + tl.store(p_dh1, b_dh2.to(p_dh1.dtype.element_ty), boundary_check=(0, 1)) + if V > 128: + p_dh2 = tl.make_block_ptr(dh0, (K, V), (V, 1), (i_k * BK, 128), (BK, 64), (1, 0)) + tl.store(p_dh2, b_dh3.to(p_dh2.dtype.element_ty), boundary_check=(0, 1)) + if V > 192: + p_dh3 = tl.make_block_ptr(dh0, (K, V), (V, 1), (i_k * BK, 192), (BK, 64), (1, 0)) + tl.store(p_dh3, b_dh4.to(p_dh3.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_oja2_bwd_dhu( + q: torch.Tensor, + vg: torch.Tensor, + w: torch.Tensor, + do: torch.Tensor, + dk: torch.Tensor, + gv: Optional[torch.Tensor] = None, + h0: Optional[torch.Tensor] = None, + dht: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + states_in_fp32: bool = False +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K, V = *q.shape, do.shape[-1] + # N: the actual number of sequences in the batch with either equal or variable lengths + BT = 64 + assert K <= 256, "current kernel does not support head dimension being larger than 256." + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + + dh = q.new_empty(B, NT, H, K, V, dtype=q.dtype if not states_in_fp32 else torch.float) + dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None + dk2 = torch.empty_like(dk) + + def grid(meta): return (triton.cdiv(K, meta['BK']), N*H) + chunk_oja2_bwd_kernel_dhu_blockdim64[grid]( + q=q, + vg=vg, + w=w, + gv=gv, + dht=dht, + dh0=dh0, + do=do, + dh=dh, + dk=dk, + dk2=dk2, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + ) + return dh, dh0, dk2 + + + + + + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + ], + key=['BT'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_gsa_bwd_k_kernel_dqkvg( + q, + k, + v, + h, + g, + A, + do, + dh, + dq, + dk, + dv, + dg, + dgv, + dA, + cu_seqlens, + chunk_indices, + scale, + T, + B: tl.constexpr, + HQ: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // NG + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + all = T + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + all = B * T + + o_i = tl.arange(0, BT) + o_t = min(i_t * BT + BT, T) + m_s = o_i[:, None] >= o_i[None, :] + + p_q = tl.make_block_ptr(q + (bos*HQ+i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H+i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_A = tl.make_block_ptr(A + ((i_k*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k)) + b_A = tl.where(m_s, b_A, 0.) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + o_v = i_v * BV + tl.arange(0, BV) + p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_g = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gn = g + (bos + o_t - 1) * H*V + i_h * V + o_v + p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + ((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dgv = tl.make_block_ptr(dgv+((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh + (i_tg * HQ + i_hq) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + m_v = o_v < V + + # [BV,] + b_gn = tl.load(p_gn, mask=m_v, other=0) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_gv = exp(b_gn[None, :] - b_g) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * exp(b_g) * scale).to(b_do.dtype) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + # [BV] + b_dg = tl.sum(tl.trans(b_h) * b_dh, 0) * exp(b_gn) + + b_dh = b_dh.to(b_k.dtype) + # [BT, BK] + b_dq += tl.dot(b_do, b_h.to(b_k.dtype)) + b_dk += tl.dot((b_v * b_gv).to(b_v.dtype), tl.trans(b_dh)) + # [BT, BV] + b_dv = tl.dot(b_k, b_dh) * b_gv + # [BV] + b_dg += tl.sum(b_dv * b_v, 0) + + if i_k == 0: + b_dgv = tl.load(p_dg, boundary_check=(0, 1)) + b_dg[None, :] + else: + b_dgv = tl.zeros([BT, BV], dtype=tl.float32) + b_dg[None, :] + + tl.store(p_dgv, b_dgv.to(p_dgv.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + p_dA = tl.make_block_ptr(dA + (bos*HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_dq = tl.make_block_ptr(dq + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + # [BT, BT] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BT, BK] + b_dq += tl.dot(b_dA, b_k) + b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q) + + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + + + + + + +@triton.heuristics({ + 'USE_GV': lambda args: args['gv'] is not None, + 'HAVE_GK': lambda args: args['dgk'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_GV'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_oja2_bwd_kernel_dvwg_h( + k, + v, + gv, + h, + dh, + dk, + dw, + dv, + dgv_last, + dgk, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_GV: tl.constexpr, + HAVE_GK: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + gv += (bos * H + i_h) * V + h += (i_tg * H + i_h).to(tl.int64) * K*V + dh += (i_tg * H + i_h).to(tl.int64) * K*V + dk += (bos * H + i_h) * K + dw += (bos * H + i_h) * V + dv += (bos * H + i_h) * V + dgv_last += (bos * H + i_h) * V + + b_dvg = tl.zeros([BT, BV], dtype=tl.float32) + b_dw = tl.zeros([BT, BV], dtype=tl.float32) + b_dgv_last = tl.zeros([BV,], dtype=tl.float32) + + if USE_GV: + o_v = i_v * BV + tl.arange(0, BV) + m_v = o_v < V + p_gn = gv + (min(T, i_t * BT + BT) - 1) * H*V + o_v + p_gv = tl.make_block_ptr(gv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_gn = tl.load(p_gn, mask=m_v, other=0) + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) # BT BK + b_dk = tl.load(p_dk, boundary_check=(0, 1)) # BT BK + b_h = tl.load(p_h, boundary_check=(0, 1)) # BK BV + b_dh = tl.load(p_dh, boundary_check=(0, 1)) # BK BV + + b_dvg += tl.dot(b_k, b_dh.to(b_k.dtype)) # BT BK @ BK BV -> BT BV + b_dw += tl.dot(b_dk.to(b_k.dtype), b_h.to(b_k.dtype)) # BT BK @ BK BV -> BT BV + b_dgv_last += tl.sum((b_h * b_dh) * exp(b_gn), axis=0) + + if USE_GV: + b_dv = b_dvg * exp(b_gn[None, :] - b_gv) + + p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dw = tl.make_block_ptr(dw, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dgv_last = tl.make_block_ptr(dgv_last, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + + b_dgv_last += tl.sum(b_dv * b_v, axis=0) + + # 留给GSA2的接口 + if HAVE_GK: + dgk += (bos * H + i_h) * V + p_dgk = tl.make_block_ptr(dgk, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dgk = tl.load(p_dgk, boundary_check=(0, 1)) + b_dgv_last = b_dgk + b_dgv_last[None, :] + else: + b_dgv_last = tl.zeros([BT, BV], dtype=tl.float32) + b_dgv_last[None, :] + + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dgv_last, b_dgv_last.to(p_dgv_last.dtype.element_ty), boundary_check=(0, 1)) + + + + +def chunk_oja2_bwd_dvwg_h( + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + dh: torch.Tensor, + dk: torch.Tensor, + gv: Optional[torch.Tensor] = None, + dgk: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + CONST_TILING = 64 if check_shared_mem() else 32 + BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING) + BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING) + NV = triton.cdiv(V, BV) + dv = torch.empty_like(v, dtype=torch.float) + dw = torch.empty_like(v) + dgv_last = torch.empty_like(gv) + + grid = (NV, NT, B * H) + chunk_oja2_bwd_kernel_dvwg_h[grid]( + k=k, + v=v, + gv=gv, + h=h, + dh=dh, + dw=dw, + dk=dk, + dv=dv, + dgv_last=dgv_last, + dgk=dgk, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return dv, dw, dgv_last \ No newline at end of file diff --git a/fla/ops/oja2/chunk_kkt.py b/fla/ops/oja2/chunk_kkt.py new file mode 100644 index 0000000000..8ff0a21d96 --- /dev/null +++ b/fla/ops/oja2/chunk_kkt.py @@ -0,0 +1,517 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.ops.utils import prepare_chunk_indices +from fla.ops.utils.op import exp + + +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'BT', 'IS_VARLEN'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_scaled_dot_kkt_fwd_kernel( + k, + g, + beta, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_G: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + + p_b = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_b = tl.load(p_b, boundary_check=(0,)) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_A += tl.dot(b_k, tl.trans(b_k)) + + if USE_G: + p_g = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_diff = b_g[:, None] - b_g[None, :] + b_A *= exp(b_g_diff) + b_A *= b_b[:, None] + + m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) + p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (BT*H, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64] + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["BC"] +) +@triton.jit(do_not_specialize=['T']) +def chunk_scaled_dot_kkt_fwd_kernel_intra_sub_inter( + k, + g, + beta, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_i, i_j = i_c // NC, i_c % NC + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT + i_i * BC >= T: + return + if i_i <= i_j: + return + + k += (bos * H + i_h) * K + g += (bos * H + i_h) * K + A += (bos * H + i_h) * BT + + p_b = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,)) + b_b = tl.load(p_b, boundary_check=(0,)) + + b_A = tl.zeros([BC, BC], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + b_kt = tl.make_block_ptr(k, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gk = tl.make_block_ptr(g, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + # [BK,] + b_gn = tl.load(g + (i_t * BT + i_i * BC) * H*K + o_k, mask=m_k, other=0) + # [BC, BK] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) * exp(b_g - b_gn[None, :]) + # [BK, BC] + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kt = tl.load(b_kt, boundary_check=(0, 1)) * exp(b_gn[:, None] - b_gk) + # [BC, BC] + b_A += tl.dot(b_k, b_kt) + b_A *= b_b[:, None] + + p_A = tl.make_block_ptr(A, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=["BK", "BT"] +) +@triton.jit(do_not_specialize=['T']) +def chunk_scaled_dot_kkt_fwd_kernel_intra_sub_intra( + k, + g, + beta, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT + i_i * BC >= T: + return + + o_i = tl.arange(0, BC) + o_k = tl.arange(0, BK) + m_k = o_k < K + m_A = (i_t * BT + i_i * BC + o_i) < T + o_A = (bos + i_t * BT + i_i * BC + o_i) * H*BT + i_h * BT + i_i * BC + + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_b = beta + (bos + i_t * BT + i_i * BC + o_i) * H + i_h + + b_k = tl.load(p_k, boundary_check=(0, 1)) * tl.load(p_b, mask=m_A, other=0)[:, None] + b_g = tl.load(p_g, boundary_check=(0, 1)) + + p_kt = k + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_gk = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + b_kt = tl.load(p_kt, mask=m_k, other=0).to(tl.float32) + b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32) + b_A = tl.sum(b_k * b_kt[None, :] * exp(b_g - b_gk[None, :]), 1) + b_A = tl.where(o_i > j, b_A, 0.) + + tl.store(A + o_A + j, b_A, mask=m_A) + p_kt += H*K + p_gk += H*K + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8] + ], + key=['BK', 'NC', 'BT'], +) +@triton.jit(do_not_specialize=['B', 'T']) +def chunk_scaled_dot_kkt_bwd_kernel_gk( + k, + g, + beta, + dA, + dk, + dg, + db, + cu_seqlens, + chunk_indices, + B, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_t, i_i = i_c // NC, i_c % NC + + all = B * T + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = eos - bos + if i_t * BT + i_i * BC >= T: + return + + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + k += (bos * H + i_h) * K + g += (bos * H + i_h) * K + beta += bos * H + i_h + + dA += (bos * H + i_h) * BT + dk += (bos * H + i_h) * K + dg += (bos * H + i_h) * K + db += (i_k * all + bos) * H + i_h + + p_g = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_b = tl.make_block_ptr(beta, (T,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,)) + # [BC, BK] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_dk = tl.zeros([BC, BK], dtype=tl.float32) + # [BC] + b_b = tl.load(p_b, boundary_check=(0,)) + if i_i > 0: + p_gn = g + (i_t * BT + i_i * BC) * H*K + o_k + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + for i_j in range(0, i_i): + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = b_k * exp(b_gn[None, :] - b_gk) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dkb = tl.dot(b_dA, b_kg) * exp(b_g - b_gn[None, :]) + b_dk += b_dkb + + o_i = tl.arange(0, BC) + m_dA = (i_t * BT + i_i * BC + o_i) < T + o_dA = (i_t * BT + i_i * BC + o_i) * H*BT + i_i * BC + p_kj = k + (i_t * BT + i_i * BC) * H*K + o_k + p_gkj = g + (i_t * BT + i_i * BC) * H*K + o_k + + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # [BC] + b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0) + # [BK] + b_kj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) + b_gkj = tl.load(p_gkj, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] >= j + # [BC, BK] + b_dkb = tl.where(m_i, b_dA[:, None] * b_kj[None, :] * exp(b_g - b_gkj[None, :]), 0.) + b_dk += b_dkb + + p_kj += H*K + p_gkj += H*K + b_db = tl.sum(b_dk * b_k, 1) + b_dk *= b_b[:, None] + p_db = tl.make_block_ptr(db, (T,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) + + tl.debug_barrier() + # [BC, BK] + b_dkt = tl.zeros([BC, BK], dtype=tl.float32) + + NC = min(NC, tl.cdiv(T - i_t * BT, BC)) + if i_i < NC - 1: + p_gn = g + (min(i_t * BT + i_i * BC + BC, T) - 1) * H*K + o_k + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + for i_j in range(i_i + 1, NC): + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k*BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA, (BT, T), (1, H*BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + p_b = tl.make_block_ptr(beta, (T,), (H,), (i_t * BT + i_j * BC,), (BC,), (0,)) + + o_j = i_t * BT + i_j * BC + o_i + m_j = o_j < T + # [BC] + b_b = tl.load(p_b, boundary_check=(0,)) + # [BC, BK] + b_kb = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32) * b_b[:, None] + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kbg = b_kb * tl.where(m_j[:, None], exp(b_gk - b_gn[None, :]), 0) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + # (SY 09/17) important to not use bf16 here to have a good precision. + b_dkt += tl.dot(b_dA, b_kbg) + b_dkt *= exp(b_gn[None, :] - b_g) + o_dA = (i_t * BT + i_i * BC) * H*BT + i_i * BC + o_i + p_kj = k + (i_t * BT + i_i * BC) * H*K + o_k + p_gkj = g + (i_t * BT + i_i * BC) * H*K + o_k + p_bj = beta + (i_t * BT + i_i * BC) * H + + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # [BC,] + b_dA = tl.load(dA + o_dA + j * H*BT) + # [BK,] + b_kbj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) * tl.load(p_bj) + b_gkj = tl.load(p_gkj, mask=m_k, other=0).to(tl.float32) + b_kbgj = b_kbj[None, :] * exp(b_gkj[None, :] - b_g) + # [BC, BK] + m_i = o_i[:, None] <= j + b_dkt += tl.where(m_i, b_dA[:, None] * b_kbgj, 0.) + + p_kj += H*K + p_gkj += H*K + p_bj += H + b_dg = (b_dk - b_dkt) * b_k + b_dk += b_dkt + + p_dk = tl.make_block_ptr(dk, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_scaled_dot_kkt_fwd( + k: torch.Tensor, + g: Optional[torch.Tensor] = None, + gk: Optional[torch.Tensor] = None, + beta: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + output_dtype: torch.dtype = torch.float32 +) -> torch.Tensor: + r""" + Compute beta * K * K^T. + + Args: + k (torch.Tensor): + The key tensor of shape `[B, T, H, K]`. + beta (torch.Tensor): + The beta tensor of shape `[B, T, H]`. + g (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`. + gk (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`. + cu_seqlens (torch.LongTensor): + The cumulative sequence lengths of the input tensor. + Default: None + chunk_size (int): + The chunk size. Default: 64. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float32` + + Returns: + beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size. + """ + B, T, H, K = k.shape + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + if gk is None: + A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) + chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)]( + k=k, + g=g, + beta=beta, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + ) + return A + + BC = min(16, BT) + NC = triton.cdiv(BT, BC) + BK = max(triton.next_power_of_2(K), 16) + A = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype) + grid = (NT, NC * NC, B * H) + chunk_scaled_dot_kkt_fwd_kernel_intra_sub_inter[grid]( + k=k, + g=gk, + beta=beta, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + NC=NC, + ) + + grid = (NT, NC, B * H) + chunk_scaled_dot_kkt_fwd_kernel_intra_sub_intra[grid]( + k=k, + g=gk, + beta=beta, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + ) + return A + + +def chunk_scaled_dot_kkt_bwd_gk( + k: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + dA: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +): + B, T, H, K = k.shape + BT = chunk_size + BC = min(16, BT) + BK = min(64, triton.next_power_of_2(K)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + NC = triton.cdiv(BT, BC) + NK = triton.cdiv(K, BK) + + dk = torch.empty_like(k, dtype=torch.float) + dg = torch.empty_like(g, dtype=torch.float) + db = beta.new_empty(NK, *beta.shape, dtype=torch.float) + grid = (NK, NT * NC, B * H) + chunk_scaled_dot_kkt_bwd_kernel_gk[grid]( + k=k, + g=g, + beta=beta, + dA=dA, + dk=dk, + dg=dg, + db=db, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + B=B, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + NC=NC, + ) + db = db.sum(0) + + return dk, dg, db diff --git a/fla/ops/oja2/chunk_o.py b/fla/ops/oja2/chunk_o.py new file mode 100644 index 0000000000..a43fec6d6a --- /dev/null +++ b/fla/ops/oja2/chunk_o.py @@ -0,0 +1,691 @@ +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils import prepare_chunk_indices +from fla.ops.utils.op import exp +from fla.utils import check_shared_mem, is_nvidia_hopper +from fla.ops.utils.cumsum import chunk_local_cumsum + +BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + +exp = tl.exp + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64] + for BV in [32, 64] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BT'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_oja2_fwd_inter( + q, + k, + h, + gv, + o, + A, + cu_seqlens, + chunk_indices, + scale, + T, + HQ: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // NG + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_o += tl.dot(b_q, b_h) + # [BT, BT] + b_A += tl.dot(b_q, b_k) + p_g = tl.make_block_ptr(gv + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + (bos * HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_o = b_o * exp(b_g) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + # [BT, BT] + b_A = tl.where(m_s, b_A, 0.) + if i_v == 0: + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.jit(do_not_specialize=['T']) +def chunk_oja2_fwd_intra( + v, + gv, + o, + A, + cu_seqlens, + chunk_indices, + T, + HQ: tl.constexpr, + H: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr, + NG: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // NG + i_t, i_i = i_c // NC, i_c % NC + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_v = i_v * BV + tl.arange(0, BV) + m_v = o_v < V + + if i_t * BT + i_i * BC >= T: + return + + p_g = tl.make_block_ptr(gv + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_gn = gv + (bos + min(i_t * BT + i_i * BC, T)) * H*V + i_h * V + o_v + # [BV,] + b_gn = tl.load(p_gn, mask=m_v, other=0) + # [BC, BV] + b_o = tl.zeros([BC, BV], dtype=tl.float32) + for i_j in range(0, i_i): + p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0)) + p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + p_gv = tl.make_block_ptr(gv + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_vg = (b_v * exp(b_gn[None, :] - b_gv)).to(b_v.dtype) + # [BC, BC] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_o += tl.dot(b_A, b_vg) + # [BC, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_o *= exp(b_g - b_gn[None, :]) + + o_i = tl.arange(0, BC) + o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * HQ*BT + i_hq * BT + i_i * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + p_v = v + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v + p_gv = gv + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v + # [BC,] + b_A = tl.load(A + o_A + j, mask=m_A, other=0) + # [BV,] + b_v = tl.load(p_v, mask=m_v, other=0).to(tl.float32) + b_gv = tl.load(p_gv, mask=m_v, other=0).to(tl.float32) + # [BC, BV] + b_vg = b_v[None, :] * exp(b_g - b_gv[None, :]) + # avoid 0 * inf = inf + b_o += tl.where(o_i[:, None] >= j, b_A[:, None] * b_vg, 0.) + p_o = tl.make_block_ptr(o + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + b_o += tl.load(p_o, boundary_check=(0, 1)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + + + +def chunk_oja2_fwd_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gv: torch.Tensor, + h: torch.Tensor, + scale: float = 1., + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + BC = min(16, BT) + BV = min(64, triton.next_power_of_2(V)) + HQ = q.shape[2] + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + NC = triton.cdiv(BT, BC) + NG = HQ // H + + o = v.new_empty(B, T, HQ, V) + A = q.new_empty(B, T, HQ, BT) + def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * HQ) + chunk_oja2_fwd_inter[grid]( + q, + k, + h, + gv, + o, + A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + HQ=HQ, + H=H, + K=K, + V=V, + BT=BT, + NG=NG, + ) + + def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ) + chunk_oja2_fwd_intra[grid]( + v, + gv, + o, + A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + HQ=HQ, + H=H, + V=V, + BT=BT, + BC=BC, + BV=BV, + NC=NC, + NG=NG, + num_warps=4, + num_stages=2 + ) + return A, o + + + + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [2, 4, 8] + ], + key=["BT"] +) +@triton.jit(do_not_specialize=['T']) +def chunk_oja2_bwd_kernel_dA( + v, + gv, + do, + dA, + chunk_indices, + cu_seqlens, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + all = T + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + all = B * T + + o_v = i_v * BV + tl.arange(0, BV) + m_v = o_v < V + + if i_t * BT + i_i * BC >= T: + return + + # [BC, BC] + b_dA = tl.zeros([BC, BC], dtype=tl.float32) + if i_i > i_j: + p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1)) + p_gv = tl.make_block_ptr(gv + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1)) + p_gn = gv + (bos + i_t*BT + i_i*BC) * H*V + i_h * V + o_v + p_g = tl.make_block_ptr(gv + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + # [BV,] + b_gn = tl.load(p_gn, mask=m_v, other=0.) + # [BC, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * exp(b_g - b_gn[None, :]) * scale).to(b_do.dtype) + # [BV, BC] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_vg = (b_v * exp(b_gn[:, None] - b_gv)).to(b_v.dtype) + # [BC, BC] + b_dA = tl.dot(b_do, b_vg) + elif i_i == i_j: + p_g = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_v = v + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v + p_gv = gv + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v + # [BC, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) * scale + m_v = o_v < V + + o_i = tl.arange(0, BC) + # [BC, BC] + m_dA = o_i[:, None] >= o_i[None, :] + for j in range(0, min(BC, T - i_t * BT - i_j * BC)): + # [BV,] + b_v = tl.load(p_v, mask=m_v, other=0).to(tl.float32) + b_gv = tl.load(p_gv, mask=m_v, other=0).to(tl.float32) + # [BC,] + b_dAj = tl.sum(b_do * b_v[None, :] * exp(b_g - b_gv[None, :]), 1) + b_dA = tl.where((o_i == j)[None, :], b_dAj[:, None], b_dA) + + p_v += H*V + p_gv += H*V + b_dA = tl.where(m_dA, b_dA, 0.) + + p_dA = tl.make_block_ptr(dA+((i_v*all+bos)*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t*BT+i_i*BC, i_j*BC), (BC, BC), (1, 0)) + tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1)) + + + + + +def chunk_oja2_bwd_dA( + v: torch.Tensor, + gv: torch.Tensor, + do: torch.Tensor, + scale: float = 1., + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +): + B, T, H, V = v.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + BC = min(16, BT) + BV = min(64, triton.next_power_of_2(V)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + NC = triton.cdiv(BT, BC) + NV = triton.cdiv(V, BV) + + dA = v.new_empty(NV, B, T, H, BT) + # 计算dA + grid = (NV, NT * NC * NC, B * H) + chunk_oja2_bwd_kernel_dA[grid]( + v, + gv, + do, + dA, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + B=B, + H=H, + V=V, + BT=BT, + BC=BC, + BV=BV, + NC=NC, + ) + dA = dA.sum(0, dtype=dA.dtype) + + return dA + + + + + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + ], + key=['BT'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_oja2_bwd_kernel_dqk( + q, + k, + h, + gv, + A, + dq, + dk, + dA, + do, + scale, + cu_seqlens, + chunk_indices, + B, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + all = T + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + all = B * T + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + # [B, T, H, BT] + p_q = tl.make_block_ptr(q + (bos*H+i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H+i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_A = tl.make_block_ptr(A + ((i_k*all+bos)*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + + b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k)) + b_A = tl.where(m_s, b_A, 0.) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + + # 先计算do对应的dq + for i_v in range(tl.cdiv(V, BV)): + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_do = tl.make_block_ptr(do + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv = tl.make_block_ptr(gv + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_do = (b_do * exp(b_gv) * scale).to(b_do.dtype) + b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) + + # 接着计算dA对应的dq, dk + p_dA = tl.make_block_ptr(dA + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_dq = tl.make_block_ptr(dq + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + # [BT, BT] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BT, BK] + b_dq += tl.dot(b_dA.to(b_q.dtype), b_k) + b_dk = tl.dot(tl.trans(b_dA).to(b_q.dtype), b_q) + + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + + + +def chunk_oja2_bwd_dqk( + q: torch.Tensor, + k: torch.Tensor, + h: torch.Tensor, + gv: torch.Tensor, + dA: torch.Tensor, + do: torch.Tensor, + scale: float = 1., + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +): + B, T, H, K, V = *q.shape, gv.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + NK = triton.cdiv(K, BK) + + dq = torch.empty_like(q) + dk = torch.empty_like(k) + A = dA.new_empty(NK, B, T, H, BT) + # 计算dA + grid = (NK, NT, B * H) + chunk_oja2_bwd_kernel_dqk[grid]( + q, + k, + h, + gv, + A, + dq, + dk, + dA, + do, + scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + B=B, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV + ) + + A = A.sum(0, dtype=A.dtype) + + return A, dq, dk + + + + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.jit(do_not_specialize=['T']) +def chunk_oja2_bwd_kernel_dv_o( + v, + g, + o, + A, + do, + dv, + dv2, + dg, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_t, i_i = i_c // NC, i_c % NC + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_v = i_v * BV + tl.arange(0, BV) + m_v = o_v < V + + if i_t * BT + i_i * BC >= T: + return + + p_gv = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_gn = g + (bos + min(i_t * BT + i_i * BC + BC, T)-1)*H*V + i_h*V + o_v + # [BV,] + b_gn = tl.load(p_gn, mask=m_v, other=0) + # [BC, BV] + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_dvg = tl.zeros([BC, BV], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + p_A = tl.make_block_ptr(A + (bos*H+i_h) * BT, (BT, T), (1, H*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1)) + p_do = tl.make_block_ptr(do + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_j*BC, i_v*BV), (BC, BV), (1, 0)) + # [BC, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) * exp(b_g - b_gn[None, :]) + # [BC, BC] + b_A = tl.load(p_A, boundary_check=(0, 1)) + # [BC, BV] + b_dvg += tl.dot(b_A, b_do.to(b_A.dtype)) + b_dv = b_dvg * exp(b_gn[None, :] - b_gv) + + o_i = tl.arange(0, BC) + o_c = i_i * BC + tl.arange(0, BC) + + p_g = g + (bos + i_t * BT + i_i * BC) * H*V + i_h * V + o_v + p_A = A + (bos + i_t*BT + i_i*BC) * H*BT + i_h * BT + o_c + p_do = do + (bos + i_t*BT + i_i*BC) * H*V + i_h * V + o_v + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # [BC,] + b_A = tl.load(p_A) + # [BV,] + b_g = tl.load(p_g, mask=m_v, other=0) + b_do = tl.load(p_do, mask=m_v, other=0) + # [BC, BV] + m_i = o_i[:, None] <= j + b_dv += tl.where(m_i, exp(b_g[None, :] - b_gv) * b_A[:, None] * b_do[None, :], 0.) + + p_g += H * V + p_A += H * BT + p_do += H * V + p_o = tl.make_block_ptr(o + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_dv2 = tl.make_block_ptr(dv2 + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_dg = tl.make_block_ptr(dg + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + + b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32) + b_v = tl.load(p_v, boundary_check=(0, 1)).to(tl.float32) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(tl.float32) + b_dv = b_dv + tl.load(p_dv, boundary_check=(0, 1)).to(tl.float32) + b_dg = b_o * b_do - b_v * b_dv + tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + + + +def chunk_oja2_bwd_dv_o( + v: torch.Tensor, + gv: torch.Tensor, + o: torch.Tensor, + A: torch.Tensor, + dv: torch.Tensor, + do: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +): + B, T, H, V = v.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + BC = min(16, BT) + BV = min(64, triton.next_power_of_2(V)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + NC = triton.cdiv(BT, BC) + + dv2 = torch.empty_like(v, dtype=torch.float) + dgv = torch.empty_like(gv) + # 计算dA + def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * H) + chunk_oja2_bwd_kernel_dv_o[grid]( + v=v, + g=gv, + o=o, + A=A, + do=do, + dv=dv, + dv2=dv2, + dg=dgv, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + V=V, + BT=BT, + BC=BC, + BV=BV, + NC=NC, + num_warps=4, + num_stages=2 + ) + return dv2, dgv + + diff --git a/fla/ops/oja2/fused_recurrent.py b/fla/ops/oja2/fused_recurrent.py new file mode 100644 index 0000000000..4f509a13e0 --- /dev/null +++ b/fla/ops/oja2/fused_recurrent.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import exp +from fla.utils import input_guard + + +@triton.heuristics({ + 'USE_GV': lambda args: args['gv'] is not None, + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_oja2_fwd_kernel( + q, + k, + v, + gv, + beta, + o, + h0, + ht, + cu_seqlens, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_GV: tl.constexpr, + USE_Q_L2NORM: tl.constexpr, + USE_K_L2NORM: tl.constexpr, + IS_BETA_HEADWISE: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + o_k = tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v + if USE_GV: + p_gv = gv + (bos * HV + i_hv) * V + o_v + if IS_BETA_HEADWISE: + p_beta = beta + bos * HV + i_hv + else: + p_beta = beta + (bos * HV + i_hv) * V + o_v + + p_o = o + (bos * HV + i_hv) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + if USE_Q_L2NORM: + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + if USE_K_L2NORM: + b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) + b_q = b_q * scale + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta).to(tl.float32) + else: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + + # [BK, BV] + if USE_GV: + b_gv = tl.load(p_gv).to(tl.float32) + b_h *= exp(b_gv[None, :]) + + b_k = b_beta * (b_k - tl.sum(b_h * b_v[None, :], 1)) + b_h += b_k[:, None] * b_v + + # [BV] + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + p_q += H*K + p_k += H*K + p_v += HV*V + if USE_GV: + p_gv += HV*V + p_beta += HV * (1 if IS_BETA_HEADWISE else V) + p_o += HV*V + + if STORE_FINAL_STATE: + p_ht = ht + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +def fused_recurrent_oja2_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gv: Optional[torch.Tensor] = None, + beta: Optional[torch.Tensor] = None, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + use_q_l2norm: bool = False, + use_k_l2norm: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + assert V <= 128 + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 256) + NV = triton.cdiv(V, BV) + num_stages = 3 + num_warps = 1 + + o = torch.empty_like(v) + final_state = q.new_empty(N, HV, K, V, dtype=torch.float32) if output_final_state else None + + grid = (NV, N * HV) + fused_recurrent_oja2_fwd_kernel[grid]( + q=q, + k=k, + v=v, + gv=gv, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + scale=scale, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + IS_BETA_HEADWISE=beta.ndim != v.ndim, + USE_Q_L2NORM=use_q_l2norm, + USE_K_L2NORM=use_k_l2norm, + num_warps=num_warps, + num_stages=num_stages, + ) + return o, final_state + + +class FusedRecurrentFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gv: Optional[torch.Tensor] = None, + beta: Optional[torch.Tensor] = None, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + use_q_l2norm: bool = False, + use_k_l2norm: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + ): + o, final_state = fused_recurrent_oja2_fwd( + q=q, + k=k, + v=v, + gv=gv, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + use_q_l2norm=use_q_l2norm, + use_k_l2norm=use_k_l2norm, + cu_seqlens=cu_seqlens, + ) + + return o, final_state + + @staticmethod + @input_guard + def backward(ctx, do, dht): + raise NotImplementedError( + "Backward pass is not implemented yet and we do not have plans to implement it " + "because we haven't figured out how to compute dg without materializing the full " + "hidden states for all time steps." + ) + + +def fused_recurrent_oja2( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gv: Optional[torch.Tensor] = None, + beta: Optional[torch.Tensor] = None, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + use_q_l2norm: bool = False, + use_k_l2norm: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + + if 'use_qk_l2norm_in_kernel' in kwargs and (not use_q_l2norm and not use_k_l2norm): + use_q_l2norm = True + use_k_l2norm = True + + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + if beta is None: + beta = torch.ones_like(q[..., 0]) + + o, final_state = FusedRecurrentFunction.apply( + q, + k, + v, + gv, + beta, + scale, + initial_state, + output_final_state, + use_q_l2norm, + use_k_l2norm, + cu_seqlens, + ) + return o, final_state diff --git a/fla/ops/oja2/wy_fast.py b/fla/ops/oja2/wy_fast.py new file mode 100644 index 0000000000..7fdd5d994c --- /dev/null +++ b/fla/ops/oja2/wy_fast.py @@ -0,0 +1,289 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + + +from fla.ops.utils import chunk_local_cumsum, prepare_chunk_indices +from fla.ops.utils.op import exp +from fla.utils import check_shared_mem + + +@triton.heuristics({ + 'STORE_VG': lambda args: args['vg'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'], +) +@triton.jit(do_not_specialize=['T']) +def recompute_w_u_fwd_kernel( + k, + v, + vg, + beta, + w, + u, + A, + gv, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + STORE_VG: tl.constexpr, + IS_VARLEN: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + p_b = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_b = tl.load(p_b, boundary_check=(0,)) + + p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_w = tl.make_block_ptr(w + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = b_v * b_b[:, None] + + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_vb *= exp(b_gv) + if STORE_VG: + last_idx = min(i_t * BT + BT, T) - 1 + + o_v = i_v * BV + tl.arange(0, BV) + m_v = o_v < V + b_gn = tl.load(gv + ((bos + last_idx) * H + i_h) * V + o_v, mask=m_v, other=0.) + b_vg = b_v * exp(b_gn - b_gv) + + p_vg = tl.make_block_ptr(vg + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_vg, b_vg.to(p_vg.dtype.element_ty), boundary_check=(0, 1)) + + b_w = tl.dot(b_A, b_vb.to(b_v.dtype)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_u = tl.make_block_ptr(u + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_b[:, None]).to(b_k.dtype) + b_u = tl.dot(b_A, b_kb, allow_tf32=False) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'] +) +@triton.jit(do_not_specialize=['T']) +def prepare_wy_repr_bwd_kernel( + k, + v, + beta, + gv, + A, + dA, + dw, + du, + dk, + dv, + db, + dgv, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + p_b = tl.make_block_ptr(beta + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_db = tl.make_block_ptr(db + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + + b_b = tl.load(p_b, boundary_check=(0,)) + b_db = tl.zeros([BT], dtype=tl.float32) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_gv_exp = exp(tl.load(p_gv, boundary_check=(0, 1))) + b_vbg = b_v * b_b[:, None] * b_gv_exp + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + + b_dA += tl.dot(b_dw, tl.trans(b_vbg).to(b_dw.dtype)) + b_dvbg = tl.dot(b_A, b_dw) + b_dv = b_dvbg * b_gv_exp * b_b[:, None] + b_db += tl.sum(b_dvbg * b_v * b_gv_exp, 1) + b_dgv = b_dvbg * b_vbg + + p_dgv = tl.make_block_ptr(dgv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_dgv, b_dgv.to(p_dgv.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_du = tl.make_block_ptr(du + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_b[:, None]).to(b_k.dtype) # BT BK + b_du = tl.load(p_du, boundary_check=(0, 1)) # BT BK + b_dA += tl.dot(b_du, tl.trans(b_kb)) # BT BT + b_dkb = tl.dot(b_A, b_du) # BT BK + b_dk = b_dkb * b_b[:, None] + b_db += tl.sum(b_dkb * b_k, 1) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) + b_dA = tl.where(m_A, b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), b_A) + b_dA = tl.dot(b_A, b_dA.to(b_A.dtype)) + + b_dA = tl.where(m_A, -b_dA, 0) + + # if USE_GV: + p_dA = tl.make_block_ptr(dA + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + gv: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = A.shape[-1] + BK = 64 + BV = 64 + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + w = torch.empty_like(v) + u = torch.empty_like(k) + vg = torch.empty_like(v) if gv is not None else None + recompute_w_u_fwd_kernel[(NT, B*H)]( + k=k, + v=v, + vg=vg, + beta=beta, + w=w, + u=u, + A=A, + gv=gv, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u, vg + + +def prepare_wy_repr_bwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + dw: torch.Tensor, + du: torch.Tensor, + gv: torch.Tensor = None, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = 64 + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + CONST_TILING = 64 if check_shared_mem() else 32 + BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING) + BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING) + + dk = torch.empty_like(k) + dv = torch.empty_like(v, dtype=torch.float) + + dgv = torch.empty_like(gv, dtype=torch.float) + dA = torch.empty_like(A, dtype=torch.float) + db = torch.empty_like(beta, dtype=torch.float) + + prepare_wy_repr_bwd_kernel[(NT, B * H)]( + k=k, + v=v, + beta=beta, + gv=gv, + A=A, + dA=dA, + dw=dw, + du=du, + dk=dk, + dv=dv, + db=db, + dgv=dgv, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + + return dk, dv, db, dgv, dA