From bb6ae129fbc03f8ced2e7f989a3207d69ceb7c34 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Tue, 8 Apr 2025 20:23:37 +0800 Subject: [PATCH 1/3] [Deprecated] `head_first` option removed for gla variants --- fla/ops/abc/chunk.py | 2 +- fla/ops/common/chunk_h.py | 37 ++--- fla/ops/common/chunk_h_parallel.py | 47 ++---- fla/ops/common/chunk_o.py | 45 ++---- fla/ops/common/fused_recurrent.py | 206 +++++++++----------------- fla/ops/gated_delta_rule/chunk.py | 9 +- fla/ops/gla/chunk.py | 155 ++++++++----------- fla/ops/gla/fused_recurrent.py | 22 ++- fla/ops/gla/naive.py | 4 +- fla/ops/gsa/chunk.py | 11 +- fla/ops/gsa/fused_recurrent.py | 2 +- fla/ops/nsa/naive.py | 6 +- fla/ops/retention/chunk.py | 30 ++-- fla/ops/retention/fused_recurrent.py | 27 +++- fla/ops/simple_gla/chunk.py | 66 ++++----- fla/ops/simple_gla/fused_recurrent.py | 22 ++- fla/ops/utils/cumsum.py | 13 +- tests/ops/test_cumsum.py | 8 +- tests/ops/test_gla.py | 51 +++++-- tests/ops/test_retention.py | 6 +- tests/ops/test_simple_gla.py | 19 ++- 21 files changed, 348 insertions(+), 440 deletions(-) diff --git a/fla/ops/abc/chunk.py b/fla/ops/abc/chunk.py index 194e51094c..63a63218df 100644 --- a/fla/ops/abc/chunk.py +++ b/fla/ops/abc/chunk.py @@ -1093,7 +1093,7 @@ def chunk_abc( v (torch.Tensor): values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]` s (torch.Tensor): - slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]` + slot representations of shape `[B, T, H, M]` if `head_first=False` else `[B, H, T, M]` initial_state (Optional[Tuple[torch.Tensor, torch.Tensor]]): Initial states of shape `[B, H, K, M]` and `[B, H, M, V]`. Default: `None`. output_final_state (Optional[bool]): diff --git a/fla/ops/common/chunk_h.py b/fla/ops/common/chunk_h.py index 50c5a9a952..7d6bbdc4f1 100644 --- a/fla/ops/common/chunk_h.py +++ b/fla/ops/common/chunk_h.py @@ -302,29 +302,22 @@ def chunk_fwd_h( h0: torch.Tensor, output_final_state: bool, offsets: Optional[torch.Tensor] = None, - head_first: bool = False, chunk_size: int = 64, split_size: Optional[int] = None, states_in_fp32: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: - if head_first: - B, H, T, K, V = *k.shape, v.shape[-1] - else: - B, T, H, K, V = *k.shape, v.shape[-1] + B, T, H, K, V = *k.shape, v.shape[-1] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) BS = BT if split_size is None else min(split_size, max(16, triton.next_power_of_2(T))) assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}" # N: the actual number of sequences in the batch with either equal or variable lengths if offsets is None: - split_offsets, N, NS = None, B, triton.cdiv(T, BS) + N, NS, split_offsets = B, triton.cdiv(T, BS), None else: split_offsets = prepare_chunk_offsets(offsets, BS) - N, NS = len(offsets) - 1, split_offsets[-1] + N, NS = len(offsets) - 1, split_offsets[-1].item() - if head_first: - h = k.new_empty(B, H, NS, K, V, dtype=k.dtype if not states_in_fp32 else torch.float) - else: - h = k.new_empty(B, NS, H, K, V, dtype=k.dtype if not states_in_fp32 else torch.float) + h = k.new_empty(B, NS, H, K, V, dtype=k.dtype if not states_in_fp32 else torch.float) ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H) chunk_fwd_kernel_h[grid]( @@ -347,7 +340,7 @@ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), USE_G=g is not None, USE_GK=gk is not None, USE_GV=gv is not None, - HEAD_FIRST=head_first + HEAD_FIRST=False ) return h, ht @@ -364,33 +357,25 @@ def chunk_bwd_dh( dht: torch.Tensor, scale: float, offsets: Optional[torch.Tensor] = None, - head_first: bool = False, chunk_size: int = 64, split_size: Optional[int] = None, states_in_fp32: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: - if head_first: - B, H, T, K, V = *k.shape, v.shape[-1] - HQ = q.shape[1] - else: - B, T, H, K, V = *k.shape, v.shape[-1] - HQ = q.shape[2] + B, T, H, K, V = *k.shape, v.shape[-1] + HQ = q.shape[2] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) BS = BT if split_size is None else min(split_size, max(16, triton.next_power_of_2(T))) assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}" # N: the actual number of sequences in the batch with either equal or variable lengths # NG: number of groups in GQA if offsets is None: - split_offsets, N, NS = None, B, triton.cdiv(T, BS) + N, NS, split_offsets = B, triton.cdiv(T, BS), None else: split_offsets = prepare_chunk_offsets(offsets, BS) - N, NS = len(offsets) - 1, split_offsets[-1] + N, NS = len(offsets) - 1, split_offsets[-1].item() NG = HQ // H - if head_first: - dh = k.new_empty(B, HQ, NS, K, V, dtype=k.dtype if not states_in_fp32 else torch.float) - else: - dh = k.new_empty(B, NS, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float) + dh = k.new_empty(B, NS, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float) dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H) @@ -417,6 +402,6 @@ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), USE_G=g is not None, USE_GK=gk is not None, USE_GV=gv is not None, - HEAD_FIRST=head_first + HEAD_FIRST=False ) return dh, dh0 diff --git a/fla/ops/common/chunk_h_parallel.py b/fla/ops/common/chunk_h_parallel.py index 87d0eff07f..163bb41b94 100644 --- a/fla/ops/common/chunk_h_parallel.py +++ b/fla/ops/common/chunk_h_parallel.py @@ -11,6 +11,7 @@ import triton import triton.language as tl +from fla.ops.common.utils import prepare_chunk_indices, prepare_chunk_offsets from fla.ops.utils.op import exp @@ -487,26 +488,18 @@ def chunk_fwd_h( output_final_state: bool, states_in_fp32: bool = False, offsets: Optional[torch.Tensor] = None, - indices: Optional[torch.Tensor] = None, - head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor]: - if head_first: - B, H, T, K, V = *k.shape, v.shape[-1] - else: - B, T, H, K, V = *k.shape, v.shape[-1] + B, T, H, K, V = *k.shape, v.shape[-1] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) # N: the actual number of sequences in the batch with either equal or variable lengths if offsets is None: N, NT, chunk_offsets = B, triton.cdiv(T, BT), None else: - if indices is None: - indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()]) - indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets) - N, NT = len(offsets) - 1, len(indices) - chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1) + indices = prepare_chunk_indices(offsets, BT) + N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT) - h = k.new_empty(B, H, NT, K, V, dtype=torch.float) if head_first else k.new_empty(B, NT, H, K, V, dtype=torch.float) + h = k.new_empty(B, NT, H, K, V, dtype=torch.float) ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * H) chunk_fwd_kernel_h_parallel[grid]( @@ -528,7 +521,7 @@ def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), USE_G=g is not None, USE_GK=gk is not None, USE_GV=gv is not None, - HEAD_FIRST=head_first + HEAD_FIRST=False ) kvt, ht = ht, (torch.empty_like(ht) if output_final_state else None) def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H) @@ -549,7 +542,7 @@ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), USE_G=g is not None, USE_GK=gk is not None, USE_GV=gv is not None, - HEAD_FIRST=head_first + HEAD_FIRST=False ) h = h.to(k.dtype) if not states_in_fp32 else h return h, ht @@ -568,33 +561,21 @@ def chunk_bwd_dh( scale: float, states_in_fp32: bool = False, offsets: Optional[torch.Tensor] = None, - indices: Optional[torch.Tensor] = None, - head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor]: - if head_first: - B, H, T, K, V = *k.shape, v.shape[-1] - HQ = q.shape[1] - else: - B, T, H, K, V = *k.shape, v.shape[-1] - HQ = q.shape[2] + B, T, H, K, V = *k.shape, v.shape[-1] + HQ = q.shape[2] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) # N: the actual number of sequences in the batch with either equal or variable lengths # NG: number of groups in GQA if offsets is None: N, NT, chunk_offsets = B, triton.cdiv(T, BT), None else: - if indices is None: - indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()]) - indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets) - N, NT = len(offsets) - 1, len(indices) - chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1) + indices = prepare_chunk_indices(offsets, BT) + N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT) NG = HQ // H - if head_first: - dh = k.new_empty(B, HQ, NT, K, V, dtype=k.dtype if not states_in_fp32 else torch.float) - else: - dh = k.new_empty(B, NT, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float) + dh = k.new_empty(B, NT, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float) dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * HQ) @@ -620,7 +601,7 @@ def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), USE_G=g is not None, USE_GK=gk is not None, USE_GV=gv is not None, - HEAD_FIRST=head_first + HEAD_FIRST=False ) doq0, dh0 = dh0, (torch.empty_like(dh0) if dh0 is not None else None) @@ -644,7 +625,7 @@ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), USE_G=g is not None, USE_GK=gk is not None, USE_GV=gv is not None, - HEAD_FIRST=head_first + HEAD_FIRST=False ) dh = dh.to(q.dtype) if not states_in_fp32 else dh return dh, dh0 diff --git a/fla/ops/common/chunk_o.py b/fla/ops/common/chunk_o.py index 06469bf7f7..c9b38497c7 100644 --- a/fla/ops/common/chunk_o.py +++ b/fla/ops/common/chunk_o.py @@ -7,6 +7,7 @@ import triton import triton.language as tl +from fla.ops.common.utils import prepare_chunk_indices from fla.ops.utils.op import exp, safe_exp from fla.utils import check_shared_mem, is_nvidia_hopper @@ -461,18 +462,14 @@ def chunk_fwd_o( g: Optional[torch.Tensor] = None, # cumsum of log decay scale: Optional[float] = None, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ) -> torch.Tensor: - if head_first: - B, H, T, K, V = *q.shape, v.shape[-1] - else: - B, T, H, K, V = *q.shape, v.shape[-1] - if scale is None: - scale = k.shape[-1] ** -0.5 + B, T, H, K, V = *q.shape, v.shape[-1] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) + if scale is None: + scale = k.shape[-1] ** -0.5 o = torch.empty_like(v) @@ -492,7 +489,7 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) K=K, V=V, BT=BT, - HEAD_FIRST=head_first + HEAD_FIRST=False ) return o @@ -505,15 +502,11 @@ def chunk_bwd_dv( dh: torch.Tensor, scale: float, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ) -> torch.Tensor: - if head_first: - B, H, T, K, V = *k.shape, do.shape[-1] - else: - B, T, H, K, V = *k.shape, do.shape[-1] + B, T, H, K, V = *k.shape, do.shape[-1] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None # H100 can have larger block size if check_shared_mem('hopper', k.device.index): CONST_TILING = 128 @@ -545,7 +538,7 @@ def chunk_bwd_dv( BT=BT, BK=BK, BV=BV, - HEAD_FIRST=head_first + HEAD_FIRST=False ) return dv @@ -558,15 +551,11 @@ def chunk_bwd_dv_local( dh: torch.Tensor, scale: float, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ) -> torch.Tensor: - if head_first: - B, H, T, K, V = *k.shape, do.shape[-1] - else: - B, T, H, K, V = *k.shape, do.shape[-1] + B, T, H, K, V = *k.shape, do.shape[-1] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None # H100 can have larger block size if check_shared_mem('hopper', k.device.index): CONST_TILING = 128 @@ -596,7 +585,7 @@ def chunk_bwd_dv_local( BT=BT, BK=BK, BV=BV, - HEAD_FIRST=head_first + HEAD_FIRST=False ) return dv @@ -612,17 +601,13 @@ def chunk_bwd_dqkwg( dv: Optional[torch.Tensor] = None, w: Optional[torch.Tensor] = None, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, chunk_size: int = 64, scale: float = 1.0, - head_first: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if head_first: - B, H, T, K, V = *k.shape, v.shape[-1] - else: - B, T, H, K, V = *k.shape, v.shape[-1] + B, T, H, K, V = *k.shape, v.shape[-1] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) CONST_TILING = 64 if check_shared_mem() else 32 @@ -660,7 +645,7 @@ def chunk_bwd_dqkwg( BT=BT, BK=BK, BV=BV, - HEAD_FIRST=head_first + HEAD_FIRST=False ) if dg is not None: diff --git a/fla/ops/common/fused_recurrent.py b/fla/ops/common/fused_recurrent.py index 047e3b9945..491c2cdb70 100644 --- a/fla/ops/common/fused_recurrent.py +++ b/fla/ops/common/fused_recurrent.py @@ -50,8 +50,7 @@ def fused_recurrent_fwd_kernel( USE_GV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, - IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr + IS_VARLEN: tl.constexpr ): # indices i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) @@ -64,28 +63,16 @@ def fused_recurrent_fwd_kernel( bos, eos = i_n * T, i_n * T + T all = B * T - if HEAD_FIRST: - p_q = q + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) - p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) - p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) - p_o = o + (i_k * B*H + i_nh) * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) - if USE_G: - p_g = g + i_nh * T + ((T-1) if REVERSE else 0) - if USE_GK: - p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) - if USE_GV: - p_gv = gv + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) - else: - p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) - p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) - p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) - p_o = o + ((i_k * all + bos) + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) - if USE_G: - p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h - if USE_GK: - p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) - if USE_GV: - p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_o = o + ((i_k * all + bos) + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + if USE_G: + p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h + if USE_GK: + p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + if USE_GV: + p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) mask_k = (i_k * BK + tl.arange(0, BK)) < K mask_v = (i_v * BV + tl.arange(0, BV)) < V @@ -113,16 +100,16 @@ def fused_recurrent_fwd_kernel( b_o = b_h * b_q[None, :] b_o = tl.sum(b_o, axis=1) tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) - p_q += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K - p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K - p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V - p_o += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V + p_q += (-1 if REVERSE else 1) * H*K + p_k += (-1 if REVERSE else 1) * H*K + p_v += (-1 if REVERSE else 1) * H*V + p_o += (-1 if REVERSE else 1) * H*V if USE_GK: - p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_gk += (-1 if REVERSE else 1) * H*K if USE_GV: - p_gv += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V + p_gv += (-1 if REVERSE else 1) * H*V if USE_G: - p_g += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) + p_g += (-1 if REVERSE else 1) * H if STORE_FINAL_STATE: p_ht = ht + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) @@ -174,7 +161,6 @@ def fused_recurrent_bwd_kernel( STORE_INITIAL_STATE_GRADIENT: tl.constexpr, USE_FINAL_STATE_GRADIENT: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) i_n, i_h = i_nh // H, i_nh % H @@ -186,28 +172,16 @@ def fused_recurrent_bwd_kernel( bos, eos = i_n * T, i_n * T + T all = B * T - if HEAD_FIRST: - p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) - p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) - p_do = do + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) - p_dq = dq + (i_v * B*H + i_nh) * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) - if USE_G: - p_g = g + i_nh * T + ((T-1) if REVERSE else 0) - if USE_GK: - p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) - if USE_GV: - p_gv = gv + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) - else: - p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) - p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) - p_do = do + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) - p_dq = dq + ((i_v * all + bos) + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) - if USE_G: - p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h - if USE_GK: - p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) - if USE_GV: - p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_do = do + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_dq = dq + ((i_v * all + bos) + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + if USE_G: + p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h + if USE_GK: + p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + if USE_GV: + p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) mask_k = i_k * BK + tl.arange(0, BK) < K mask_v = i_v * BV + tl.arange(0, BV) < V @@ -236,46 +210,32 @@ def fused_recurrent_bwd_kernel( b_dq = tl.sum(b_dq, axis=1) * scale tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_k) - p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K - p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V - p_do += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V - p_dq += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_k += (-1 if REVERSE else 1) * H*K + p_v += (-1 if REVERSE else 1) * H*V + p_do += (-1 if REVERSE else 1) * H*V + p_dq += (-1 if REVERSE else 1) * H*K if USE_G: - p_g += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) + p_g += (-1 if REVERSE else 1) * H if USE_GK: - p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_gk += (-1 if REVERSE else 1) * H*K if USE_GV: - p_gv += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V + p_gv += (-1 if REVERSE else 1) * H*V # sync threads tl.debug_barrier() - if HEAD_FIRST: - p_q = q + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK) - p_k = k + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK) - p_v = v + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV) - p_do = do + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV) - p_dk = dk + (i_v * B*H + i_nh) * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK) - p_dv = dv + (i_k * B*H + i_nh) * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV) - if USE_G: - p_g = g + i_nh * T + ((T - 1) if not REVERSE else 0) - if USE_GK: - p_gk = gk + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK) - if USE_GV: - p_gv = gv + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV) - else: - p_q = q + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) - p_k = k + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) - p_v = v + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) - p_do = do + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) - p_dk = dk + ((i_v * all + bos) + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) - p_dv = dv + ((i_k * all + bos) + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) - if USE_G: - p_g = g + (bos + ((T - 1) if not REVERSE else 0)) * H + i_h - if USE_GK: - p_gk = gk + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) - if USE_GV: - p_gv = gv + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_q = q + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_k = k + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_v = v + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_do = do + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_dk = dk + ((i_v * all + bos) + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_dv = dv + ((i_k * all + bos) + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + if USE_G: + p_g = g + (bos + ((T - 1) if not REVERSE else 0)) * H + i_h + if USE_GK: + p_gk = gk + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + if USE_GV: + p_gv = gv + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) b_dh = tl.zeros([BK, BV], dtype=tl.float32) if USE_FINAL_STATE_GRADIENT: @@ -302,18 +262,18 @@ def fused_recurrent_bwd_kernel( tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k) tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v) - p_q += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K - p_k += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K - p_v += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V - p_do += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V - p_dk += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K - p_dv += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V + p_q += (1 if REVERSE else -1) * H*K + p_k += (1 if REVERSE else -1) * H*K + p_v += (1 if REVERSE else -1) * H*V + p_do += (1 if REVERSE else -1) * H*V + p_dk += (1 if REVERSE else -1) * H*K + p_dv += (1 if REVERSE else -1) * H*V if USE_G: - p_g += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) + p_g += (1 if REVERSE else -1) * H if USE_GK: - p_gk += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K + p_gk += (1 if REVERSE else -1) * H*K if USE_GV: - p_gv += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V + p_gv += (1 if REVERSE else -1) * H*V if STORE_INITIAL_STATE_GRADIENT: p_dh0 = dh0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) @@ -331,22 +291,15 @@ def fused_recurrent_fwd( initial_state: Optional[torch.Tensor] = None, output_final_state: bool = False, reverse: bool = False, - offsets: Optional[torch.LongTensor] = None, - head_first: bool = False + offsets: Optional[torch.LongTensor] = None ): - if head_first: - B, H, T, K, V = *k.shape, v.shape[-1] - else: - B, T, H, K, V = *k.shape, v.shape[-1] + B, T, H, K, V = *k.shape, v.shape[-1] N = B if offsets is None else len(offsets) - 1 BK, BV = min(K, 64), min(V, 64) NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) h0 = initial_state - if output_final_state: - ht = q.new_empty(N, H, K, V, dtype=torch.float32) - else: - ht = None + ht = q.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None o = q.new_empty(NK, *v.shape, dtype=torch.float32) grid = (NV, NK, N * H) @@ -373,7 +326,6 @@ def fused_recurrent_fwd( USE_GK=gk is not None, USE_GV=gv is not None, REVERSE=reverse, - HEAD_FIRST=head_first ) o = o.sum(0) return o, ht @@ -392,13 +344,9 @@ def fused_recurrent_bwd( scale: Optional[float] = None, initial_state: Optional[torch.Tensor] = None, reverse: bool = False, - offsets: Optional[torch.LongTensor] = None, - head_first: bool = False + offsets: Optional[torch.LongTensor] = None ): - if head_first: - B, H, T, K, V = *k.shape, v.shape[-1] - else: - B, T, H, K, V = *k.shape, v.shape[-1] + B, T, H, K, V = *k.shape, v.shape[-1] N = B if offsets is None else len(offsets) - 1 BK, BV = min(K, 64), min(V, 64) @@ -438,33 +386,17 @@ def fused_recurrent_bwd( USE_GK=gk is not None, USE_GV=gv is not None, REVERSE=reverse, - HEAD_FIRST=head_first ) dq = dq.sum(0) dk = dk.sum(0) dv = dv.sum(0) dg, dgk, dgv = None, None, None if g is not None: - dg = chunk_global_cumsum( - (dq * q.float() - dk * k.float()).sum(-1), - reverse=not reverse, - offsets=offsets, - head_first=head_first - ) + dg = chunk_global_cumsum((dq * q.float() - dk * k.float()).sum(-1), reverse=not reverse, offsets=offsets) if gk is not None: - dgk = chunk_global_cumsum( - dq * q.float() - dk * k.float(), - reverse=not reverse, - offsets=offsets, - head_first=head_first - ) + dgk = chunk_global_cumsum(dq * q.float() - dk * k.float(), reverse=not reverse, offsets=offsets) if gv is not None: - dgv = chunk_global_cumsum( - do.float() * o.float() - dv * v.float(), - reverse=not reverse, - offsets=offsets, - head_first=head_first - ) + dgv = chunk_global_cumsum(do.float() * o.float() - dv * v.float(), reverse=not reverse, offsets=offsets) return dq, dk, dv, dg, dgk, dgv, dh0 @@ -486,8 +418,7 @@ def forward( initial_state: Optional[torch.Tensor] = None, output_final_state: bool = False, reverse: bool = False, - offsets: Optional[torch.LongTensor] = None, - head_first: bool = False + offsets: Optional[torch.LongTensor] = None ): o, ht = fused_recurrent_fwd( q=q, @@ -501,13 +432,11 @@ def forward( output_final_state=output_final_state, reverse=reverse, offsets=offsets, - head_first=head_first ) ctx.save_for_backward(q, k, v, g, gk, gv, initial_state, o) ctx.scale = scale ctx.reverse = reverse ctx.offsets = offsets - ctx.head_first = head_first return o.to(q.dtype), ht @staticmethod @@ -538,9 +467,8 @@ def backward(ctx, do, dht): initial_state=initial_state, reverse=ctx.reverse, offsets=ctx.offsets, - head_first=ctx.head_first ) - return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg, dgk, dgv, None, dh0, None, None, None, None + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg, dgk, dgv, None, dh0, None, None, None def fused_recurrent( @@ -555,7 +483,6 @@ def fused_recurrent( output_final_state: bool = False, reverse: bool = False, cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False ): if scale is None: scale = k.shape[-1] ** -0.5 @@ -571,5 +498,4 @@ def fused_recurrent( output_final_state, reverse, cu_seqlens, - head_first ) diff --git a/fla/ops/gated_delta_rule/chunk.py b/fla/ops/gated_delta_rule/chunk.py index 3f92abb545..680a9f6291 100644 --- a/fla/ops/gated_delta_rule/chunk.py +++ b/fla/ops/gated_delta_rule/chunk.py @@ -348,7 +348,7 @@ def chunk_gated_delta_rule( """ assert q.dtype == k.dtype == v.dtype assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." - assert len(beta.shape) == 3, "beta must be of shape [B, H, T] if head_first=True, or [B, T, H] if head_first=False." + assert len(beta.shape) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." if cu_seqlens is not None: if q.shape[0] != 1: @@ -366,12 +366,9 @@ def chunk_gated_delta_rule( f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." ) if head_first: - q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v)) - beta, g = map(lambda x: rearrange(x, 'b h t -> b t h'), (beta, g)) + q, k, v, beta, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, beta, g)) if scale is None: scale = k.shape[-1] ** -0.5 - else: - assert scale > 0, "Scale must be positive." o, final_state = ChunkGatedDeltaRuleFunction.apply( q, k, @@ -386,5 +383,5 @@ def chunk_gated_delta_rule( use_qk_l2norm_in_kernel ) if head_first: - o = rearrange(o, 'b t h v -> b h t v') + o = rearrange(o, 'b t h ... -> b h t ...') return o, final_state diff --git a/fla/ops/gla/chunk.py b/fla/ops/gla/chunk.py index 473b06d86d..5e69888b7a 100644 --- a/fla/ops/gla/chunk.py +++ b/fla/ops/gla/chunk.py @@ -1,11 +1,13 @@ # -*- coding: utf-8 -*- # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +import warnings from typing import Optional, Tuple import torch import triton import triton.language as tl +from einops import rearrange from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h from fla.ops.common.utils import prepare_chunk_indices @@ -849,20 +851,16 @@ def chunk_gla_fwd_intra_gk( g: torch.Tensor, scale: float, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ): - if head_first: - B, H, T, K = k.shape - else: - B, T, H, K = k.shape + B, T, H, K = k.shape BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) BC = min(16, BT) NC = triton.cdiv(BT, BC) - A = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float) + A = q.new_empty(B, T, H, BT, dtype=torch.float) grid = (NT, NC * NC, B * H) chunk_gla_fwd_A_kernel_intra_sub_inter[grid]( q, @@ -878,7 +876,7 @@ def chunk_gla_fwd_intra_gk( BT=BT, BC=BC, NC=NC, - HEAD_FIRST=head_first + HEAD_FIRST=False ) grid = (NT, NC, B * H) @@ -899,13 +897,13 @@ def chunk_gla_fwd_intra_gk( BT=BT, BC=BC, BK=BK, - HEAD_FIRST=head_first + HEAD_FIRST=False ) # split then merge else: BK = min(128, triton.next_power_of_2(K)) NK = triton.cdiv(K, BK) - A_intra = q.new_empty(NK, B, *((H, T) if head_first else (T, H)), BC, dtype=torch.float) + A_intra = q.new_empty(NK, B, T, H, BC, dtype=torch.float) grid = (NK, NT * NC, B * H) chunk_gla_fwd_A_kernel_intra_sub_intra_split[grid]( @@ -924,7 +922,7 @@ def chunk_gla_fwd_intra_gk( BC=BC, BK=BK, NC=NC, - HEAD_FIRST=head_first + HEAD_FIRST=False ) grid = (NT, NC, B * H) @@ -939,7 +937,7 @@ def chunk_gla_fwd_intra_gk( BT=BT, BC=BC, NK=NK, - HEAD_FIRST=head_first + HEAD_FIRST=False ) return A @@ -952,15 +950,11 @@ def chunk_gla_fwd_o_gk( h: torch.Tensor, scale: float, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ): - if head_first: - B, H, T, K, V = *q.shape, v.shape[-1] - else: - B, T, H, K, V = *q.shape, v.shape[-1] + B, T, H, K, V = *q.shape, v.shape[-1] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) o = torch.empty_like(v) @@ -980,7 +974,7 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) K=K, V=V, BT=BT, - HEAD_FIRST=head_first + HEAD_FIRST=False ) return o @@ -990,19 +984,15 @@ def chunk_gla_bwd_dA( do: torch.Tensor, scale: float, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ): - if head_first: - B, H, T, V = v.shape - else: - B, T, H, V = v.shape + B, T, H, V = v.shape BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) BV = min(64, triton.next_power_of_2(V)) - dA = v.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float) + dA = v.new_empty(B, T, H, BT, dtype=torch.float) grid = (NT, B * H) chunk_gla_bwd_kernel_dA[grid]( v, @@ -1016,7 +1006,7 @@ def chunk_gla_bwd_dA( V=V, BT=BT, BV=BV, - HEAD_FIRST=head_first + HEAD_FIRST=False ) return dA @@ -1028,15 +1018,11 @@ def chunk_gla_bwd_dv( do: torch.Tensor, dh: torch.Tensor, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ): - if head_first: - B, H, T, K, V = *k.shape, do.shape[-1] - else: - B, T, H, K, V = *k.shape, do.shape[-1] + B, T, H, K, V = *k.shape, do.shape[-1] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) dv = torch.empty_like(do) @@ -1055,7 +1041,7 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) K=K, V=V, BT=BT, - HEAD_FIRST=head_first + HEAD_FIRST=False ) return dv @@ -1066,18 +1052,14 @@ def chunk_gla_bwd_dqk_intra( g: torch.Tensor, dA: torch.Tensor, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ): - if head_first: - B, H, T, K = q.shape - else: - B, T, H, K = q.shape + B, T, H, K = q.shape BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None + NT = triton.cdiv(T, BT) if offsets is None else len(indices) BC = min(16, BT) BK = min(64, triton.next_power_of_2(K)) - NT = triton.cdiv(T, BT) if offsets is None else len(indices) NC = triton.cdiv(BT, BC) NK = triton.cdiv(K, BK) @@ -1100,7 +1082,7 @@ def chunk_gla_bwd_dqk_intra( BC=BC, BK=BK, NC=NC, - HEAD_FIRST=head_first + HEAD_FIRST=False ) return dq, dk @@ -1117,15 +1099,11 @@ def chunk_gla_bwd_dqkg( dk: torch.Tensor, scale: float, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ): - if head_first: - B, H, T, K, V = *k.shape, v.shape[-1] - else: - B, T, H, K, V = *k.shape, v.shape[-1] + B, T, H, K, V = *k.shape, v.shape[-1] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) dg = torch.empty_like(g) @@ -1154,7 +1132,7 @@ def grid(meta): return (triton.cdiv(K, meta['BK']), NT, B * H) K=K, V=V, BT=BT, - HEAD_FIRST=head_first + HEAD_FIRST=False ) return dq2, dk2, dg @@ -1169,14 +1147,10 @@ def chunk_gla_fwd( initial_state: torch.Tensor, output_final_state: bool, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - T = q.shape[2] if head_first else q.shape[1] - BT = min(chunk_size, max(16, triton.next_power_of_2(T))) if g_cumsum is None: - g_cumsum = chunk_local_cumsum(g, BT, offsets=offsets, indices=indices, head_first=head_first) + g_cumsum = chunk_local_cumsum(g, chunk_size, offsets=offsets) h, ht = chunk_fwd_h( k=k, @@ -1188,8 +1162,7 @@ def chunk_gla_fwd( output_final_state=output_final_state, states_in_fp32=False, offsets=offsets, - head_first=head_first, - chunk_size=BT + chunk_size=chunk_size ) # the intra A is kept in fp32 @@ -1200,9 +1173,7 @@ def chunk_gla_fwd( g=g_cumsum, scale=scale, offsets=offsets, - indices=indices, - head_first=head_first, - chunk_size=BT + chunk_size=chunk_size ) o = chunk_gla_fwd_o_gk( q=q, @@ -1212,9 +1183,7 @@ def chunk_gla_fwd( h=h, scale=scale, offsets=offsets, - indices=indices, - head_first=head_first, - chunk_size=BT + chunk_size=chunk_size ) return g_cumsum, A, h, ht, o @@ -1232,14 +1201,12 @@ def chunk_gla_bwd( do: torch.Tensor, dht: torch.Tensor, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ): - T = q.shape[2] if head_first else q.shape[1] + T = q.shape[1] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) if g_cumsum is None: - g_cumsum = chunk_local_cumsum(g, BT, offsets=offsets, indices=indices, head_first=head_first) + g_cumsum = chunk_local_cumsum(g, BT, offsets=offsets) if h is None: h, _ = chunk_fwd_h( @@ -1251,7 +1218,6 @@ def chunk_gla_bwd( h0=initial_state, output_final_state=False, offsets=offsets, - head_first=head_first, chunk_size=BT, states_in_fp32=True ) @@ -1267,7 +1233,6 @@ def chunk_gla_bwd( dht=dht, scale=scale, offsets=offsets, - head_first=head_first, chunk_size=BT, states_in_fp32=True ) @@ -1279,8 +1244,6 @@ def chunk_gla_bwd( do=do, dh=dh, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=BT ) @@ -1290,8 +1253,6 @@ def chunk_gla_bwd( do=do, scale=scale, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=BT ) dq, dk = chunk_gla_bwd_dqk_intra( @@ -1300,8 +1261,6 @@ def chunk_gla_bwd( g=g_cumsum, dA=dA, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=BT ) dq, dk, dg = chunk_gla_bwd_dqkg( @@ -1316,8 +1275,6 @@ def chunk_gla_bwd( dk=dk, scale=scale, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=BT ) return dq, dk, dv, dg, dh0 @@ -1336,17 +1293,11 @@ def forward( scale, initial_state, output_final_state, - offsets, - head_first + offsets ): - T = q.shape[2] if head_first else q.shape[1] + T = q.shape[1] chunk_size = min(64, max(16, triton.next_power_of_2(T))) - # 2-d indices denoting the offsets of chunks in each sequence - # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, - # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be - # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] - indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None g_cumsum, A, h, ht, o = chunk_gla_fwd( q=q, k=k, @@ -1357,8 +1308,6 @@ def forward( initial_state=initial_state, output_final_state=output_final_state, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=chunk_size ) # recompute g_cumsum in bwd pass @@ -1370,15 +1319,13 @@ def forward( ctx.chunk_size = chunk_size ctx.scale = scale ctx.offsets = offsets - ctx.indices = indices - ctx.head_first = head_first return o, ht @staticmethod @input_guard def backward(ctx, do, dht): q, k, v, g, g_cumsum, initial_state, A = ctx.saved_tensors - chunk_size, scale, offsets, indices, head_first = ctx.chunk_size, ctx.scale, ctx.offsets, ctx.indices, ctx.head_first + chunk_size, scale, offsets = ctx.chunk_size, ctx.scale, ctx.offsets dq, dk, dv, dg, dh0 = chunk_gla_bwd( q=q, k=k, @@ -1392,11 +1339,9 @@ def backward(ctx, do, dht): do=do, dht=dht, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=chunk_size ) - return dq.to(q), dk.to(k), dv.to(v), dg, None, dh0, None, None, None + return dq.to(q), dk.to(k), dv.to(v), dg, None, dh0, None, None @torch.compiler.disable @@ -1473,6 +1418,19 @@ def chunk_gla( >>> assert o.allclose(o_var.view(o.shape)) >>> assert ht.allclose(ht_var) """ + if head_first: + warnings.warn( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead." + ) + q, k, v, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, g)) + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...]." + ) if cu_seqlens is not None: if q.shape[0] != 1: raise ValueError( @@ -1490,5 +1448,16 @@ def chunk_gla( ) if scale is None: scale = q.shape[-1] ** -0.5 - o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, cu_seqlens, head_first) + o, final_state = ChunkGLAFunction.apply( + q, + k, + v, + g, + scale, + initial_state, + output_final_state, + cu_seqlens, + ) + if head_first: + o = rearrange(o, 'b t h ... -> b h t ...') return o, final_state diff --git a/fla/ops/gla/fused_recurrent.py b/fla/ops/gla/fused_recurrent.py index bb18848f34..c81cd2010d 100644 --- a/fla/ops/gla/fused_recurrent.py +++ b/fla/ops/gla/fused_recurrent.py @@ -1,9 +1,11 @@ # -*- coding: utf-8 -*- # Copyright (c) 2024, Songlin Yang, Yu Zhang +import warnings from typing import Optional, Tuple import torch +from einops import rearrange from fla.ops.common.fused_recurrent import fused_recurrent @@ -87,6 +89,23 @@ def fused_recurrent_gla( >>> assert o.allclose(o_var.view(o.shape)) >>> assert ht.allclose(ht_var) """ + if head_first: + warnings.warn( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead." + ) + q, k, v = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v)) + if gk is not None: + gk = rearrange(gk, 'b h t ... -> b t h ...') + if gv is not None: + gv = rearrange(gv, 'b h t ... -> b t h ...') + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...]." + ) if cu_seqlens is not None: if q.shape[0] != 1: raise ValueError( @@ -116,6 +135,7 @@ def fused_recurrent_gla( output_final_state=output_final_state, reverse=reverse, cu_seqlens=cu_seqlens, - head_first=head_first ) + if head_first: + o = rearrange(o, 'b t h ... -> b h t ...') return o, final_state diff --git a/fla/ops/gla/naive.py b/fla/ops/gla/naive.py index 507a7395c0..8c0b843a55 100644 --- a/fla/ops/gla/naive.py +++ b/fla/ops/gla/naive.py @@ -18,7 +18,7 @@ def naive_recurrent_gla( output_final_state: bool = False ): dtype = q.dtype - q, k, v, gk = map(lambda x: x.float(), (q, k, v, gk)) + q, k, v, gk = map(lambda x: x.transpose(1, 2).float(), (q, k, v, gk)) B, H, T, K, V = *q.shape, v.shape[-1] o = torch.zeros_like(v) scale = K ** -0.5 @@ -38,4 +38,4 @@ def naive_recurrent_gla( if not output_final_state: h = None - return o.to(dtype), h + return o.transpose(1, 2).to(dtype), h diff --git a/fla/ops/gsa/chunk.py b/fla/ops/gsa/chunk.py index 838c07ebc7..b852a1a60f 100644 --- a/fla/ops/gsa/chunk.py +++ b/fla/ops/gsa/chunk.py @@ -6,7 +6,7 @@ import torch import triton import triton.language as tl -from einops import reduce +from einops import rearrange, reduce from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h from fla.ops.gla.chunk import chunk_gla_bwd, chunk_gla_fwd @@ -1150,7 +1150,7 @@ def chunk_gsa( output_final_state: Optional[bool] = False, checkpoint_level: Optional[int] = 2, cu_seqlens: Optional[torch.LongTensor] = None, - head_first: Optional[bool] = True + head_first: Optional[bool] = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: @@ -1162,7 +1162,7 @@ def chunk_gsa( v (torch.Tensor): values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. s (torch.Tensor): - slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`. + slot representations of shape `[B, T, H, M]` if `head_first=False` else `[B, H, T, M]`. g (torch.Tensor): Forget gates of shape `[B, H, T, M]` applied to keys. If not provided, this function is equivalent to vanilla ABC. @@ -1266,7 +1266,8 @@ def chunk_gsa( hv0, output_final_state, checkpoint_level, - cu_seqlens, - head_first + cu_seqlens ) + if head_first: + o = rearrange(o, 'b h t ... -> b t h ...') return o, final_state diff --git a/fla/ops/gsa/fused_recurrent.py b/fla/ops/gsa/fused_recurrent.py index d704ec4732..1dc3e1ec9b 100644 --- a/fla/ops/gsa/fused_recurrent.py +++ b/fla/ops/gsa/fused_recurrent.py @@ -477,7 +477,7 @@ def fused_recurrent_gsa( v (torch.Tensor): values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. s (torch.Tensor): - slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`. + slot representations of shape `[B, T, H, M]` if `head_first=False` else `[B, H, T, M]`. g (torch.Tensor): Forget gates of shape `[B, H, T, M]` applied to keys. scale (Optional[int]): diff --git a/fla/ops/nsa/naive.py b/fla/ops/nsa/naive.py index 4949e8aef1..c365112e3e 100644 --- a/fla/ops/nsa/naive.py +++ b/fla/ops/nsa/naive.py @@ -27,7 +27,7 @@ def naive_nsa( v (torch.Tensor): values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. indices (torch.LongTensor): - Block indices of shape `[B, T, H, S]` if `head_first=True` else `[B, T, H, S]`. + Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. `S` is the number of selected blocks for each query token, which is set to 16 in the paper. block_size (int): Selected block size. Default: 64. @@ -52,7 +52,7 @@ def naive_nsa( "Sequences with variable lengths are not supported for head-first mode" ) if head_first: - q, k, v, indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v, indices)) + q, k, v, indices = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, indices)) dtype = q.dtype G = q.shape[2] // k.shape[2] @@ -92,5 +92,5 @@ def naive_nsa( o[0][cu_seqlens[i]+i_q] = torch.einsum('n h, n h v -> h v', attn, v_i) if head_first: - o = rearrange(o, 'b t h d -> b h t d') + o = rearrange(o, 'b t h ... -> b h t ...') return o.to(dtype) diff --git a/fla/ops/retention/chunk.py b/fla/ops/retention/chunk.py index df5fc63ac4..ca619b2bfb 100644 --- a/fla/ops/retention/chunk.py +++ b/fla/ops/retention/chunk.py @@ -1,9 +1,11 @@ # -*- coding: utf-8 -*- # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +import warnings from typing import Optional, Tuple import torch +from einops import rearrange from fla.ops.simple_gla.chunk import chunk_simple_gla @@ -51,15 +53,21 @@ def chunk_retention( """ if head_first: - n_heads = q.shape[1] - else: - n_heads = q.shape[2] - s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(n_heads), dtype=torch.float))).log() - if head_first: - g = s[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() - else: - g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() - return chunk_simple_gla( + warnings.warn( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead." + ) + q, k, v = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v)) + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...]." + ) + s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(q.shape[2]), dtype=torch.float))).log() + g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() + o, final_state = chunk_simple_gla( q=q, k=k, v=v, @@ -67,6 +75,8 @@ def chunk_retention( g=g, initial_state=initial_state, output_final_state=output_final_state, - head_first=head_first, cu_seqlens=cu_seqlens ) + if head_first: + o = rearrange(o, 'b t h ... -> b h t ...') + return o, final_state diff --git a/fla/ops/retention/fused_recurrent.py b/fla/ops/retention/fused_recurrent.py index 5af37a2b4b..459f8db5eb 100644 --- a/fla/ops/retention/fused_recurrent.py +++ b/fla/ops/retention/fused_recurrent.py @@ -1,9 +1,11 @@ # -*- coding: utf-8 -*- # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +import warnings from typing import Optional, Tuple import torch +from einops import rearrange from fla.ops.simple_gla.fused_recurrent import fused_recurrent_simple_gla @@ -19,13 +21,22 @@ def fused_recurrent_retention( cu_seqlens: Optional[torch.LongTensor] = None, head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: - H = q.shape[1] if head_first else q.shape[2] - s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(H), dtype=torch.float))).log() if head_first: - g = s[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() - else: - g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() - return fused_recurrent_simple_gla( + warnings.warn( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead." + ) + q, k, v = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v)) + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...]." + ) + s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(q.shape[2]), dtype=torch.float))).log() + g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() + o, final_state = fused_recurrent_simple_gla( q=q, k=k, v=v, @@ -35,5 +46,7 @@ def fused_recurrent_retention( output_final_state=output_final_state, reverse=reverse, cu_seqlens=cu_seqlens, - head_first=head_first ) + if head_first: + o = rearrange(o, 'b t h ... -> b h t ...') + return o, final_state diff --git a/fla/ops/simple_gla/chunk.py b/fla/ops/simple_gla/chunk.py index f6d6973ef1..9825fd43b5 100644 --- a/fla/ops/simple_gla/chunk.py +++ b/fla/ops/simple_gla/chunk.py @@ -1,10 +1,12 @@ # -*- coding: utf-8 -*- # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +import warnings from typing import Optional, Tuple import torch import triton +from einops import rearrange from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv, chunk_fwd_o @@ -21,11 +23,9 @@ def chunk_simple_gla_fwd( initial_state: torch.Tensor, output_final_state: bool, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first) if g is not None else None + g = chunk_local_cumsum(g, chunk_size, offsets=offsets) if g is not None else None h, ht = chunk_fwd_h( k=k, v=v, @@ -36,7 +36,6 @@ def chunk_simple_gla_fwd( output_final_state=output_final_state, states_in_fp32=False, offsets=offsets, - head_first=head_first, chunk_size=chunk_size ) o = chunk_fwd_o( @@ -47,8 +46,6 @@ def chunk_simple_gla_fwd( h=h, scale=scale, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=chunk_size ) return g, o, ht @@ -64,8 +61,6 @@ def chunk_simple_gla_bwd( dht: torch.Tensor, scale: float, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # (SY 09/22) states_in_fp32 seems not affecting the error of dg but for safety, set to True @@ -79,7 +74,6 @@ def chunk_simple_gla_bwd( output_final_state=False, states_in_fp32=True, offsets=offsets, - head_first=head_first, chunk_size=chunk_size ) dh, dh0 = chunk_bwd_dh( @@ -95,7 +89,6 @@ def chunk_simple_gla_bwd( scale=scale, states_in_fp32=True, offsets=offsets, - head_first=head_first, chunk_size=chunk_size ) dq, dk, _, dg = chunk_bwd_dqkwg( @@ -108,8 +101,6 @@ def chunk_simple_gla_bwd( dh=dh, scale=scale, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=chunk_size ) dv = chunk_bwd_dv( @@ -120,8 +111,6 @@ def chunk_simple_gla_bwd( dh=dh, scale=scale, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=chunk_size ) return dq, dk, dv, dg, dh0 @@ -141,21 +130,11 @@ def forward( scale, initial_state, output_final_state, - offsets, - head_first + offsets ): - T = q.shape[2] if head_first else q.shape[1] + T = q.shape[1] chunk_size = min(64, max(16, triton.next_power_of_2(T))) - # 2-d indices denoting the offsets of chunks in each sequence - # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, - # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be - # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] - indices = None - if offsets is not None: - indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()]) - indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets) - g, o, ht = chunk_simple_gla_fwd( q=q, k=k, @@ -165,23 +144,19 @@ def forward( initial_state=initial_state, output_final_state=output_final_state, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=chunk_size ) ctx.save_for_backward(q, k, v, g, initial_state) ctx.chunk_size = chunk_size ctx.scale = scale ctx.offsets = offsets - ctx.indices = indices - ctx.head_first = head_first return o.to(q.dtype), ht @staticmethod @input_guard @autocast_custom_bwd def backward(ctx, do, dht): - chunk_size, scale, offsets, indices, head_first = ctx.chunk_size, ctx.scale, ctx.offsets, ctx.indices, ctx.head_first + chunk_size, scale, offsets = ctx.chunk_size, ctx.scale, ctx.offsets q, k, v, g, initial_state = ctx.saved_tensors dq, dk, dv, dg, dh0 = chunk_simple_gla_bwd( q=q, @@ -193,16 +168,11 @@ def backward(ctx, do, dht): dht=dht, scale=scale, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=chunk_size ) - if g is not None: - dg = chunk_local_cumsum(dg, chunk_size, reverse=True, offsets=offsets, - indices=indices, head_first=head_first).to(g.dtype) - else: - dg = None - return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg, None, dh0, None, None, None + dg = chunk_local_cumsum(dg, chunk_size, reverse=True, offsets=offsets).to(g.dtype) if g is not None else None + + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg, None, dh0, None, None @torch.compiler.disable @@ -279,6 +249,19 @@ def chunk_simple_gla( >>> assert o.allclose(o_var.view(o.shape)) >>> assert ht.allclose(ht_var) """ + if head_first: + warnings.warn( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead." + ) + q, k, v, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, g)) + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...]." + ) if cu_seqlens is not None: if q.shape[0] != 1: raise ValueError( @@ -304,7 +287,8 @@ def chunk_simple_gla( scale, initial_state, output_final_state, - cu_seqlens, - head_first + cu_seqlens ) + if head_first: + o = rearrange(o, 'b t h ... -> b h t ...') return o, final_state diff --git a/fla/ops/simple_gla/fused_recurrent.py b/fla/ops/simple_gla/fused_recurrent.py index 21e742fc08..353762e416 100644 --- a/fla/ops/simple_gla/fused_recurrent.py +++ b/fla/ops/simple_gla/fused_recurrent.py @@ -1,9 +1,11 @@ # -*- coding: utf-8 -*- -# Copyright (c) 2024, Songlin Yang, Yu Zhang +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +import warnings from typing import Optional, Tuple import torch +from einops import rearrange from fla.ops.common.fused_recurrent import fused_recurrent @@ -85,6 +87,19 @@ def fused_recurrent_simple_gla( >>> assert o.allclose(o_var.view(o.shape)) >>> assert ht.allclose(ht_var) """ + if head_first: + warnings.warn( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead." + ) + q, k, v, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, g)) + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...]." + ) if cu_seqlens is not None: if q.shape[0] != 1: raise ValueError( @@ -111,7 +126,8 @@ def fused_recurrent_simple_gla( initial_state=initial_state, output_final_state=output_final_state, reverse=reverse, - cu_seqlens=cu_seqlens, - head_first=head_first + cu_seqlens=cu_seqlens ) + if head_first: + o = rearrange(o, 'b t h ... -> b h t ...') return o, final_state diff --git a/fla/ops/utils/cumsum.py b/fla/ops/utils/cumsum.py index 1dd19670ad..42b292f765 100644 --- a/fla/ops/utils/cumsum.py +++ b/fla/ops/utils/cumsum.py @@ -7,6 +7,7 @@ import triton import triton.language as tl +from fla.ops.common.utils import prepare_chunk_indices from fla.utils import check_shared_mem, input_guard BS_LIST = [32, 64] if check_shared_mem() else [16, 32] @@ -228,7 +229,6 @@ def chunk_local_cumsum_scalar( chunk_size: int, reverse: bool = False, offsets: Optional[torch.Tensor] = None, - indices: Optional[torch.Tensor] = None, head_first: bool = False, output_dtype: Optional[torch.dtype] = torch.float ) -> torch.Tensor: @@ -240,6 +240,7 @@ def chunk_local_cumsum_scalar( B = len(offsets) - 1 assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2" BT = chunk_size + indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) grid = (NT, B * H) @@ -262,7 +263,6 @@ def chunk_local_cumsum_vector( chunk_size: int, reverse: bool = False, offsets: Optional[torch.Tensor] = None, - indices: Optional[torch.Tensor] = None, head_first: bool = False, output_dtype: Optional[torch.dtype] = torch.float ) -> torch.Tensor: @@ -271,6 +271,7 @@ def chunk_local_cumsum_vector( else: B, T, H, S = g.shape BT = chunk_size + indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2" @@ -386,16 +387,16 @@ def chunk_local_cumsum( chunk_size: int, reverse: bool = False, offsets: Optional[torch.Tensor] = None, - indices: Optional[torch.Tensor] = None, head_first: bool = False, - output_dtype: Optional[torch.dtype] = torch.float + output_dtype: Optional[torch.dtype] = torch.float, + **kwargs ) -> torch.Tensor: if offsets is not None: assert g.shape[0] == 1, "Only batch size 1 is supported when offsets are provided" if len(g.shape) == 3: - return chunk_local_cumsum_scalar(g, chunk_size, reverse, offsets, indices, head_first, output_dtype) + return chunk_local_cumsum_scalar(g, chunk_size, reverse, offsets, head_first, output_dtype) elif len(g.shape) == 4: - return chunk_local_cumsum_vector(g, chunk_size, reverse, offsets, indices, head_first, output_dtype) + return chunk_local_cumsum_vector(g, chunk_size, reverse, offsets, head_first, output_dtype) else: raise ValueError( f"Unsupported input shape {g.shape}. " diff --git a/tests/ops/test_cumsum.py b/tests/ops/test_cumsum.py index 58143e65ff..4f8a3e0dc7 100644 --- a/tests/ops/test_cumsum.py +++ b/tests/ops/test_cumsum.py @@ -52,7 +52,7 @@ def cumsum_global_reference(s, reverse=False, head_first=False): @pytest.mark.parametrize("H", test_h_list) @pytest.mark.parametrize("D", test_d_list) @pytest.mark.parametrize("chunk_size", [32, 64]) -@pytest.mark.parametrize("dtype", [torch.float, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float, torch.float16]) @pytest.mark.parametrize("head_first", [True, False]) @pytest.mark.parametrize("reverse", [False, True]) @pytest.mark.skipif( @@ -72,7 +72,7 @@ def test_cumsum_local_vector(B, T, H, D, dtype, head_first, reverse, chunk_size) @pytest.mark.parametrize("B", test_b_list) @pytest.mark.parametrize("T", test_t_list) @pytest.mark.parametrize("H", test_h_list) -@pytest.mark.parametrize("dtype", [torch.float, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float, torch.float16]) @pytest.mark.parametrize("head_first", [True, False]) @pytest.mark.parametrize("reverse", [True, False]) @pytest.mark.parametrize("chunk_size", [32, 64]) @@ -94,7 +94,7 @@ def test_cumsum_local_scalar(B, T, H, dtype, head_first, reverse, chunk_size): @pytest.mark.parametrize("T", test_t_list) @pytest.mark.parametrize("H", test_h_list) @pytest.mark.parametrize("D", test_d_list) -@pytest.mark.parametrize("dtype", [torch.float, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float, torch.float16]) @pytest.mark.parametrize("head_first", [False, True]) @pytest.mark.parametrize("reverse", [True, False]) @pytest.mark.skipif( @@ -118,7 +118,7 @@ def test_cumsum_global_vector(B, T, H, D, dtype, head_first, reverse): @pytest.mark.parametrize("B", test_b_list) @pytest.mark.parametrize("T", test_t_list) @pytest.mark.parametrize("H", test_h_list) -@pytest.mark.parametrize("dtype", [torch.float, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float, torch.float16]) @pytest.mark.parametrize("head_first", [False, True]) @pytest.mark.parametrize("reverse", [True, False]) @pytest.mark.skipif( diff --git a/tests/ops/test_gla.py b/tests/ops/test_gla.py index 4a979ff201..c60f995286 100644 --- a/tests/ops/test_gla.py +++ b/tests/ops/test_gla.py @@ -49,10 +49,10 @@ def test_fused_recurrent( ): torch.manual_seed(42) - q = torch.randn((B, H, T, D), dtype=dtype, device=device).requires_grad_() - k = torch.randn((B, H, T, D), dtype=dtype, device=device).requires_grad_() - v = torch.randn((B, H, T, D), dtype=dtype, device=device).requires_grad_() - g = F.logsigmoid(torch.randn((B, H, T, D), dtype=dtype, device=device)).requires_grad_() + q = torch.randn((B, T, H, D), dtype=dtype, device=device).requires_grad_() + k = torch.randn((B, T, H, D), dtype=dtype, device=device).requires_grad_() + v = torch.randn((B, T, H, D), dtype=dtype, device=device).requires_grad_() + g = F.logsigmoid(torch.randn((B, T, H, D), dtype=dtype, device=device)).requires_grad_() h0 = torch.randn(B, H, D, D, device=device).requires_grad_() do = torch.randn_like(v) @@ -136,7 +136,15 @@ def test_chunk( do = torch.randn_like(v) dht = torch.zeros((B, H, D, D), dtype=dtype, device=device) - tri, tri_ht = chunk_gla(q, k, v, g, initial_state=h0, output_final_state=True, head_first=head_first) + tri, tri_ht = chunk_gla( + q, + k, + v, + g, + initial_state=h0, + output_final_state=True, + head_first=head_first + ) ((tri * do).sum() + (tri_ht * dht).sum()).backward() tri_dq, q.grad = q.grad.clone(), None tri_dk, k.grad = k.grad.clone(), None @@ -144,8 +152,24 @@ def test_chunk( tri_dg, g.grad = g.grad.clone(), None tri_dh0, h0.grad = h0.grad.clone(), None - ref, ref_ht = fused_recurrent_gla(q, k, v, g, initial_state=h0, output_final_state=True, head_first=head_first) - ref, _ = fused_recurrent_gla(q, k, v, g, initial_state=h0, output_final_state=False, head_first=head_first) + ref, ref_ht = fused_recurrent_gla( + q, + k, + v, + g, + initial_state=h0, + output_final_state=True, + head_first=head_first + ) + ref, _ = fused_recurrent_gla( + q, + k, + v, + g, + initial_state=h0, + output_final_state=False, + head_first=head_first + ) (ref * do).sum().backward() ref_dq, q.grad = q.grad.clone(), None ref_dk, k.grad = k.grad.clone(), None @@ -181,9 +205,9 @@ def test_chunk_varlen( torch.manual_seed(42) os.environ['TRITON_F32_DEFAULT'] = 'ieee' # randomly split the sequence into N segments - offsets = torch.cat([ + cu_seqlens = torch.cat([ torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 1)[:N-1]], + torch.arange(16, T)[torch.randperm(T - 16)[:N-1]], torch.tensor([T], dtype=torch.long) ], 0).to(device).sort()[0] # seq-first required for inputs with variable lengths @@ -198,15 +222,13 @@ def test_chunk_varlen( q, k, v, g, initial_state=h0, output_final_state=True, - cu_seqlens=offsets, - head_first=False + cu_seqlens=cu_seqlens ) ref, _ = fused_recurrent_gla( q, k, v, g, initial_state=h0, output_final_state=False, - cu_seqlens=offsets, - head_first=False + cu_seqlens=cu_seqlens ) (ref * do).sum().backward() @@ -223,8 +245,7 @@ def test_chunk_varlen( g, initial_state=h0, output_final_state=True, - cu_seqlens=offsets, - head_first=False + cu_seqlens=cu_seqlens ) ((tri * do).sum()).backward() tri_dq, q.grad = q.grad.clone(), None diff --git a/tests/ops/test_retention.py b/tests/ops/test_retention.py index 8692b50814..597ef2d3ba 100644 --- a/tests/ops/test_retention.py +++ b/tests/ops/test_retention.py @@ -30,7 +30,7 @@ @pytest.mark.parametrize("K", test_d_list) @pytest.mark.parametrize("expand_ratio", [1, 2]) @pytest.mark.parametrize("head_first", [True, False]) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.skipif( os.getenv("SKIP_TEST_CHUNK_VARLEN") == "0", reason="Skipping test because TEST_CHUNK_VARLEN is enabled" @@ -84,7 +84,7 @@ def test_chunk( @pytest.mark.parametrize("H", test_h_list) @pytest.mark.parametrize("K", test_d_list) @pytest.mark.parametrize("expand_ratio", [1, 2]) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.skipif( os.getenv("SKIP_TEST_CHUNK_VARLEN") == "1", reason="Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set" @@ -104,7 +104,7 @@ def test_chunk_varlen( # randomly split the sequence into N segments offsets = torch.cat([ torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 1)[:N-1]], + torch.arange(16, T)[torch.randperm(T - 16)[:N-1]], torch.tensor([T], dtype=torch.long) ], 0).to(device).sort()[0] # seq-first required for inputs with variable lengths diff --git a/tests/ops/test_simple_gla.py b/tests/ops/test_simple_gla.py index e25b950b8b..73de630f44 100644 --- a/tests/ops/test_simple_gla.py +++ b/tests/ops/test_simple_gla.py @@ -37,7 +37,7 @@ def chunk_simple_gla_ref( g: torch.Tensor, initial_state: Optional[torch.Tensor] = None, output_final_state: bool = False, - BT: int = 64, + chunk_size: int = 64, scale: Optional[float] = None, head_first: bool = True ): @@ -50,6 +50,7 @@ def chunk_simple_gla_ref( scale = 1.0 / q.shape[-1] ** 0.5 T = q.shape[-2] + BT = chunk_size pad_len = (BT - (T % BT)) % BT if pad_len > 0: # Pad all tensors @@ -59,19 +60,17 @@ def chunk_simple_gla_ref( g = F.pad(g, (0, pad_len)) q, k, v, g = map(lambda x: x.to(torch.float32), [q, k, v, g]) decay = g - chunk_size = BT - b, h, l, d_k = q.shape + b, h, t, d_k = q.shape d_v = v.shape[-1] q = q * scale - q, k, v, decay = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', - c=chunk_size), [q, k, v, decay.unsqueeze(-1)]) + q, k, v, decay = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, decay.unsqueeze(-1)]) decay = decay.squeeze(-1).cumsum(-1) L_mask = ((decay.unsqueeze(-1) - decay.unsqueeze(-2)).tril().exp().float()).tril() S = k.new_zeros(b, h, d_k, d_v) if initial_state is not None: S = initial_state o = torch.zeros_like(v) - for i in range(0, l // chunk_size): + for i in range(0, t // chunk_size): q_i, k_i, v_i = q[:, :, i], k[:, :, i], v[:, :, i] attn = (q_i @ k_i.transpose(-1, -2) * L_mask[:, :, i]) o_inter = (q_i * decay[:, :, i, :, None].exp()) @ S @@ -116,7 +115,7 @@ def parallel_simple_gla_ref(q, k, v, g, scale=None, head_first=True): @pytest.mark.parametrize("H", test_h_list) @pytest.mark.parametrize("D", test_d_list) @pytest.mark.parametrize("gate_logit_normalizer", test_gate_list) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("head_first", [False, True]) @pytest.mark.parametrize("scale", [1, 0.1]) @pytest.mark.skipif( @@ -188,7 +187,7 @@ def test_chunk( @pytest.mark.parametrize("T", test_t_varlen_list) @pytest.mark.parametrize("H", test_h_list) @pytest.mark.parametrize("D", test_d_list) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.skipif( os.getenv("SKIP_TEST_CHUNK_VARLEN") == "1", reason="Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set" @@ -264,7 +263,7 @@ def test_chunk_varlen( @pytest.mark.parametrize("T", test_t_varlen_list) @pytest.mark.parametrize("H", test_h_list) @pytest.mark.parametrize("D", test_d_list) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.skipif( os.getenv("SKIP_TEST_CHUNK_VARLEN") == "1", reason="Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set" @@ -335,7 +334,7 @@ def test_parallel_varlen( @pytest.mark.parametrize("gate_logit_normalizer", test_gate_list) @pytest.mark.parametrize("head_first", [True, False]) @pytest.mark.parametrize("scale", [0.1]) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.skipif( os.getenv("SKIP_TEST_CHUNK_VARLEN") == "0", reason="Skipping test because TEST_CHUNK_VARLEN is enabled" From 112103e46f68fac49578881fa57acc3f2343787f Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Tue, 8 Apr 2025 20:43:13 +0800 Subject: [PATCH 2/3] Remove head_first option --- fla/ops/common/chunk_h.py | 90 ++++------------- fla/ops/common/chunk_h_parallel.py | 150 +++++++--------------------- fla/ops/common/chunk_h_split.py | 120 +++++----------------- fla/ops/common/chunk_o.py | 138 +++++++++++-------------- fla/ops/gla/chunk.py | 155 +++++++++++++++++------------ 5 files changed, 230 insertions(+), 423 deletions(-) diff --git a/fla/ops/common/chunk_h.py b/fla/ops/common/chunk_h.py index 7d6bbdc4f1..a72d5c23db 100644 --- a/fla/ops/common/chunk_h.py +++ b/fla/ops/common/chunk_h.py @@ -55,7 +55,6 @@ def chunk_fwd_kernel_h( USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_h = i_nh // H, i_nh % H @@ -79,18 +78,11 @@ def chunk_fwd_kernel_h( for i_t in range(NT): i_s = i_t // (BS // BT) - if HEAD_FIRST: - p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - o_h = (i_nh * NS + i_s).to(tl.int64) * K*V - p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) - else: - p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - - o_h = ((boh + i_s) * H + i_h).to(tl.int64) * K*V - p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + o_h = ((boh + i_s) * H + i_h).to(tl.int64) * K*V + p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) if i_t % (BS // BT) == 0: tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) @@ -102,26 +94,16 @@ def chunk_fwd_kernel_h( # scalar decay if USE_G: - if HEAD_FIRST: - b_g_last = tl.load(g + i_nh * T + last_idx) - p_g = g + i_nh * T + i_t * BT + tl.arange(0, BT) - p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT) - else: - b_g_last = tl.load(g + bos * H + last_idx * H + i_h) - p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h b_h *= exp(b_g_last) b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.) b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype) # vector decay, h = Diag(gk) @ h if USE_GK: - if HEAD_FIRST: - p_gk = tl.make_block_ptr(gk + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK) - p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) - else: - p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) b_h *= exp(b_gk_last)[:, None] @@ -131,13 +113,8 @@ def chunk_fwd_kernel_h( # vector decay, h = h @ Diag(gv) if USE_GV: - if HEAD_FIRST: - p_gv = tl.make_block_ptr(gv + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV) - p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV) - else: - p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) b_h *= exp(b_gv_last)[None, :] @@ -196,10 +173,8 @@ def chunk_bwd_kernel_dh( STORE_INITIAL_STATE_GRADIENT: tl.constexpr, USE_FINAL_STATE_GRADIENT: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_bg = i_nh // NG i_n, i_hq = i_nh // HQ, i_nh % HQ i_h = i_hq // NG if IS_VARLEN: @@ -222,49 +197,31 @@ def chunk_bwd_kernel_dh( for i_t in range(NT - 1, -1, -1): i_s = i_t // (BS // BT) - if HEAD_FIRST: - o_dh = (i_nh * NS + i_s).to(tl.int64) * K*V - p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) - else: - o_dh = ((boh + i_s) * H + i_h).to(tl.int64) * K*V - p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + o_dh = ((boh + i_s) * H + i_h).to(tl.int64) * K*V + p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) if i_t % (BS // BT) == 0: tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) last_idx = min(i_t * BT + BT, T) - 1 # [BK, BT] - if HEAD_FIRST: - p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - else: - p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) # [BT, BV] b_do = tl.load(p_do, boundary_check=(0, 1)) if USE_G: - if HEAD_FIRST: - p_g = g + i_bg * T + i_t * BT + tl.arange(0, BT) - p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT) - b_g_last = tl.load(g + i_bg * T + last_idx) - else: - p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h - b_g_last = tl.load(g + (bos + last_idx) * H + i_h) + p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h + b_g_last = tl.load(g + (bos + last_idx) * H + i_h) b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.) b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype) b_dh *= exp(b_g_last) if USE_GK: - if HEAD_FIRST: - p_gk = tl.make_block_ptr(gk + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_gk_last = gk + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK) - p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) - else: - p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) b_gk = tl.load(p_gk, boundary_check=(0, 1)) b_q = (b_q * exp(b_gk)).to(b_q.dtype) @@ -272,13 +229,8 @@ def chunk_bwd_kernel_dh( b_dh *= exp(b_gk_last)[:, None] if USE_GV: - if HEAD_FIRST: - p_gv = tl.make_block_ptr(gv + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_gv_last = gv + (i_bg * T + last_idx) * V + i_v * BV + tl.arange(0, BV) - p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV) - else: - p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) b_gv = tl.load(p_gv, boundary_check=(0, 1)) b_do = (b_do * exp(b_gv)).to(b_do.dtype) @@ -340,7 +292,6 @@ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), USE_G=g is not None, USE_GK=gk is not None, USE_GV=gv is not None, - HEAD_FIRST=False ) return h, ht @@ -402,6 +353,5 @@ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), USE_G=g is not None, USE_GK=gk is not None, USE_GV=gv is not None, - HEAD_FIRST=False ) return dh, dh0 diff --git a/fla/ops/common/chunk_h_parallel.py b/fla/ops/common/chunk_h_parallel.py index 163bb41b94..8bafbeec3b 100644 --- a/fla/ops/common/chunk_h_parallel.py +++ b/fla/ops/common/chunk_h_parallel.py @@ -54,8 +54,7 @@ def chunk_fwd_kernel_h_parallel( USE_GV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, - IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr + IS_VARLEN: tl.constexpr ): i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) @@ -79,14 +78,9 @@ def chunk_fwd_kernel_h_parallel( i_n, i_tg = i_b, i_b * NT + i_t i_nh = i_n * H + i_h - if HEAD_FIRST: - p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) - else: - p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) if i_t == 0: if USE_INITIAL_STATE: @@ -104,26 +98,15 @@ def chunk_fwd_kernel_h_parallel( last_idx = min(i_t * BT + BT, T) - 1 # scalar decay if USE_G: - if HEAD_FIRST: - b_g_last = tl.load(g + i_bh * T + last_idx) - p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT) - p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT) - else: - b_g_last = tl.load(g + bos * H + last_idx * H + i_h) - p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.) b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype) # vector decay, h = Diag(gk) @ h if USE_GK: - if HEAD_FIRST: - p_gk = tl.make_block_ptr(gk + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_gk_last = gk + i_bh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK) - p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) - else: - p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) - + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) b_gk = tl.load(p_gk, boundary_check=(0, 1)) @@ -131,25 +114,16 @@ def chunk_fwd_kernel_h_parallel( # vector decay, h = h @ Diag(gv) if USE_GV: - if HEAD_FIRST: - p_gv = tl.make_block_ptr(gv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_gv_last = gv + i_bh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV) - p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV) - else: - p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) - b_gv = tl.load(p_gv, boundary_check=(0, 1)) b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype) b_h = tl.dot(b_k, b_v) if i_t < NT - 1: - if HEAD_FIRST: - p_h = tl.make_block_ptr(h + (i_bh * NT + i_t + 1) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) - else: - p_h = tl.make_block_ptr(h + ((i_tg + 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_h = tl.make_block_ptr(h + ((i_tg + 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) elif STORE_FINAL_STATE: p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) @@ -191,8 +165,7 @@ def chunk_fwd_kernel_h_reduction( USE_GK: tl.constexpr, USE_GV: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, - IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr + IS_VARLEN: tl.constexpr ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_h = i_nh // H, i_nh % H @@ -209,10 +182,7 @@ def chunk_fwd_kernel_h_reduction( # [BK, BV] b_h = tl.zeros([BK, BV], dtype=tl.float32) for i_t in range(NT): - if HEAD_FIRST: - p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) - else: - p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) if i_t > 0: tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) @@ -220,30 +190,19 @@ def chunk_fwd_kernel_h_reduction( last_idx = min(i_t * BT + BT, T) - 1 # scalar decay if USE_G: - if HEAD_FIRST: - b_g_last = tl.load(g + i_nh * T + last_idx) - else: - b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) b_h *= exp(b_g_last) # vector decay, h = Diag(gk) @ h if USE_GK: - if HEAD_FIRST: - p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK) - p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) - else: - p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) b_h *= exp(b_gk_last)[:, None] # vector decay, h = h @ Diag(gv) if USE_GV: - if HEAD_FIRST: - p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV) - p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV) - else: - p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) b_h *= exp(b_gv_last)[None, :] @@ -297,14 +256,13 @@ def chunk_bwd_kernel_dh_parallel( USE_GV: tl.constexpr, STORE_INITIAL_STATE_GRADIENT: tl.constexpr, USE_FINAL_STATE_GRADIENT: tl.constexpr, - IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr + IS_VARLEN: tl.constexpr ): i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) NV = tl.cdiv(V, BV) i_k, i_v = i_kv // NV, i_kv % NV - i_b, i_hq, i_bg = i_bh // HQ, i_bh % HQ, i_bh // NG + i_b, i_hq = i_bh // HQ, i_bh % HQ i_h = i_hq // NG if IS_VARLEN: i_tg = i_t @@ -318,14 +276,9 @@ def chunk_bwd_kernel_dh_parallel( i_n, i_tg = i_b, i_b * NT + i_t i_nh = i_n * HQ + i_hq - if HEAD_FIRST: - p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) - else: - p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) if i_t == NT - 1: if USE_FINAL_STATE_GRADIENT: @@ -342,36 +295,23 @@ def chunk_bwd_kernel_dh_parallel( b_do = tl.load(p_do, boundary_check=(0, 1)) if USE_G: - if HEAD_FIRST: - p_g = g + i_bg * T + i_t * BT + tl.arange(0, BT) - p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT) - else: - p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h + p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.) b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype) if USE_GK: - if HEAD_FIRST: - p_gk = tl.make_block_ptr(gk + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - else: - p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) b_gk = tl.load(p_gk, boundary_check=(0, 1)) b_q = (b_q * exp(b_gk)).to(b_q.dtype) if USE_GV: - if HEAD_FIRST: - p_gv = tl.make_block_ptr(gv + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - else: - p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) b_gv = tl.load(p_gv, boundary_check=(0, 1)) b_do = (b_do * exp(b_gv)).to(b_do.dtype) b_dh = tl.dot(b_q, b_do) if i_t > 0: - if HEAD_FIRST: - p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t - 1) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) - else: - p_dh = tl.make_block_ptr(dh + ((i_tg - 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + ((i_tg - 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) elif STORE_INITIAL_STATE_GRADIENT: p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) @@ -415,11 +355,9 @@ def chunk_bwd_kernel_dh_reduction( USE_GK: tl.constexpr, USE_GV: tl.constexpr, STORE_INITIAL_STATE_GRADIENT: tl.constexpr, - IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr + IS_VARLEN: tl.constexpr ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_bg = i_nh // NG i_n, i_hq = i_nh // HQ, i_nh % HQ i_h = i_hq // NG if IS_VARLEN: @@ -435,39 +373,23 @@ def chunk_bwd_kernel_dh_reduction( # [BK, BV] b_dh = tl.zeros([BK, BV], dtype=tl.float32) for i_t in range(NT - 1, -1, -1): - if HEAD_FIRST: - p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) - else: - p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) b_dh += tl.load(p_dh, boundary_check=(0, 1)).to(tl.float32) if i_t < NT - 1: tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) last_idx = min(i_t * BT + BT, T) - 1 if USE_G: - if HEAD_FIRST: - b_g_last = tl.load(g + i_bg * T + last_idx) - else: - b_g_last = tl.load(g + (bos + last_idx) * H + i_h) + b_g_last = tl.load(g + (bos + last_idx) * H + i_h) b_dh *= exp(b_g_last) if USE_GK: - if HEAD_FIRST: - p_gk_last = gk + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK) - p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) - else: - p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) - + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) b_dh *= exp(b_gk_last)[:, None] if USE_GV: - if HEAD_FIRST: - p_gv_last = gv + (i_bg * T + last_idx) * V + i_v * BV + tl.arange(0, BV) - p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV) - else: - p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) - + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) b_dh *= exp(b_gv_last)[None, :] @@ -520,8 +442,7 @@ def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), BT=BT, USE_G=g is not None, USE_GK=gk is not None, - USE_GV=gv is not None, - HEAD_FIRST=False + USE_GV=gv is not None ) kvt, ht = ht, (torch.empty_like(ht) if output_final_state else None) def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H) @@ -541,8 +462,7 @@ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), BT=BT, USE_G=g is not None, USE_GK=gk is not None, - USE_GV=gv is not None, - HEAD_FIRST=False + USE_GV=gv is not None ) h = h.to(k.dtype) if not states_in_fp32 else h return h, ht @@ -600,8 +520,7 @@ def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NG=NG, USE_G=g is not None, USE_GK=gk is not None, - USE_GV=gv is not None, - HEAD_FIRST=False + USE_GV=gv is not None ) doq0, dh0 = dh0, (torch.empty_like(dh0) if dh0 is not None else None) @@ -624,8 +543,7 @@ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), NG=NG, USE_G=g is not None, USE_GK=gk is not None, - USE_GV=gv is not None, - HEAD_FIRST=False + USE_GV=gv is not None ) dh = dh.to(q.dtype) if not states_in_fp32 else dh return dh, dh0 diff --git a/fla/ops/common/chunk_h_split.py b/fla/ops/common/chunk_h_split.py index dc40b8fa79..79270dee93 100644 --- a/fla/ops/common/chunk_h_split.py +++ b/fla/ops/common/chunk_h_split.py @@ -52,7 +52,6 @@ def chunk_fwd_kernel_h_split( USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): # handle one split at a time # i_h: head index @@ -81,12 +80,8 @@ def chunk_fwd_kernel_h_split( p_hr = tl.make_block_ptr(hr + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) tl.store(p_hr, b_h.to(p_hr.dtype.element_ty), boundary_check=(0, 1)) for i_t in range(tl.cdiv(i_s * S, BT), tl.cdiv(min(i_s * S + S, T), BT)): - if HEAD_FIRST: - p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - else: - p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) # [BK, BT] b_k = tl.load(p_k, boundary_check=(0, 1)) # [BT, BV] @@ -95,26 +90,16 @@ def chunk_fwd_kernel_h_split( # scalar decay if USE_G: - if HEAD_FIRST: - b_g_last = tl.load(g + i_nh * T + last_idx) - p_g = g + i_nh * T + i_t * BT + tl.arange(0, BT) - p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT) - else: - b_g_last = tl.load(g + bos * H + last_idx * H + i_h) - p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h b_h *= exp(b_g_last) b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.) b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype) # vector decay, h = Diag(gk) @ h if USE_GK: - if HEAD_FIRST: - p_gk = tl.make_block_ptr(gk + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK) - p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) - else: - p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) b_h *= exp(b_gk_last)[:, None] @@ -124,13 +109,8 @@ def chunk_fwd_kernel_h_split( # vector decay, h = h @ Diag(gv) if USE_GV: - if HEAD_FIRST: - p_gv = tl.make_block_ptr(gv + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV) - p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV) - else: - p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) b_h *= exp(b_gv_last)[None, :] @@ -187,7 +167,6 @@ def chunk_fwd_kernel_h_reduction( USE_GV: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_h = i_nh // H, i_nh % H @@ -213,31 +192,18 @@ def chunk_fwd_kernel_h_reduction( last_idx = min(i_t * BT + BT, T) - 1 # scalar decay if USE_G: - if HEAD_FIRST: - b_g_last = tl.load(g + i_nh * T + last_idx) - else: - b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) b_h *= exp(b_g_last) # vector decay, h = Diag(gk) @ h if USE_GK: - if HEAD_FIRST: - p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK) - p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) - else: - p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) - + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) b_h *= exp(b_gk_last)[:, None] # vector decay, h = h @ Diag(gv) if USE_GV: - if HEAD_FIRST: - p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV) - p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV) - else: - p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) - + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) b_h *= exp(b_gv_last)[None, :] @@ -294,7 +260,6 @@ def chunk_bwd_kernel_dh_split( USE_FINAL_STATE_GRADIENT: tl.constexpr, STORE_INITIAL_STATE_GRADIENT: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): # handle one split at a time # i_h: head index @@ -312,7 +277,7 @@ def chunk_bwd_kernel_dh_split( i_n, i_s = i_ss // NS, i_ss % NS bos, eos = i_n * T, i_n * T + T i_nh = i_n * HQ + i_hq - i_ng, i_h = i_nh // NG, i_hq // NG + i_h = i_hq // NG # [BK, BV] b_dh = tl.zeros([BK, BV], dtype=tl.float32) @@ -324,12 +289,8 @@ def chunk_bwd_kernel_dh_split( tl.store(p_dhr, b_dh.to(p_dhr.dtype.element_ty), boundary_check=(0, 1)) for i_t in range(tl.cdiv(min(i_s * S + S, T), BT) - 1, tl.cdiv(i_s * S, BT) - 1, -1): - if HEAD_FIRST: - p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - else: - p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) @@ -338,25 +299,15 @@ def chunk_bwd_kernel_dh_split( last_idx = min(i_t * BT + BT, T) - 1 if USE_G: - if HEAD_FIRST: - p_g = g + i_ng * T + i_t * BT + tl.arange(0, BT) - p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT) - b_g_last = tl.load(g + i_ng * T + last_idx) - else: - p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h - b_g_last = tl.load(g + (bos + last_idx) * H + i_h) + p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h + b_g_last = tl.load(g + (bos + last_idx) * H + i_h) b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.) b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype) b_dh *= exp(b_g_last) if USE_GK: - if HEAD_FIRST: - p_gk = tl.make_block_ptr(gk + i_ng * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_gk_last = gk + (i_ng * T + last_idx) * K + i_k * BK + tl.arange(0, BK) - p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) - else: - p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) b_gk = tl.load(p_gk, boundary_check=(0, 1)) b_q = (b_q * exp(b_gk)).to(b_q.dtype) @@ -364,13 +315,8 @@ def chunk_bwd_kernel_dh_split( b_dh *= exp(b_gk_last)[:, None] if USE_GV: - if HEAD_FIRST: - p_gv = tl.make_block_ptr(gv + i_ng * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_gv_last = gv + (i_ng * T + last_idx) * V + i_v * BV + tl.arange(0, BV) - p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV) - else: - p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) b_gv = tl.load(p_gv, boundary_check=(0, 1)) b_do = (b_do * exp(b_gv)).to(b_do.dtype) @@ -427,11 +373,10 @@ def chunk_bwd_kernel_dh_reduction( USE_GV: tl.constexpr, STORE_INITIAL_STATE_GRADIENT: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_hq = i_nh // HQ, i_nh % HQ - i_ng, i_h = i_nh // NG, i_hq // NG + i_h = i_hq // NG if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -453,29 +398,16 @@ def chunk_bwd_kernel_dh_reduction( last_idx = min(i_t * BT + BT, T) - 1 # scalar decay if USE_G: - if HEAD_FIRST: - b_g_last = tl.load(g + i_ng * T + last_idx) - else: - b_g_last = tl.load(g + (bos + last_idx) * H + i_h) + b_g_last = tl.load(g + (bos + last_idx) * H + i_h) b_dh *= exp(b_g_last) if USE_GK: - if HEAD_FIRST: - p_gk_last = gk + (i_ng * T + last_idx) * K + i_k * BK + tl.arange(0, BK) - p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) - else: - p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) - + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) b_dh *= exp(b_gk_last)[:, None] if USE_GV: - if HEAD_FIRST: - p_gv_last = gv + (i_ng * T + last_idx) * V + i_v * BV + tl.arange(0, BV) - p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV) - else: - p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) - + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) b_dh *= exp(b_gv_last)[None, :] @@ -550,7 +482,6 @@ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), USE_G=g is not None, USE_GK=gk is not None, USE_GV=gv is not None, - HEAD_FIRST=head_first ) def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H) chunk_fwd_kernel_h_reduction[grid]( @@ -571,7 +502,6 @@ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), USE_G=g is not None, USE_GK=gk is not None, USE_GV=gv is not None, - HEAD_FIRST=head_first ) return hr, ht @@ -648,7 +578,6 @@ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), USE_G=g is not None, USE_GK=gk is not None, USE_GV=gv is not None, - HEAD_FIRST=head_first, ) def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ) @@ -672,6 +601,5 @@ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), USE_G=g is not None, USE_GK=gk is not None, USE_GV=gv is not None, - HEAD_FIRST=head_first ) return dhr, dh0 diff --git a/fla/ops/common/chunk_o.py b/fla/ops/common/chunk_o.py index c9b38497c7..b5ab19b4f5 100644 --- a/fla/ops/common/chunk_o.py +++ b/fla/ops/common/chunk_o.py @@ -49,7 +49,6 @@ def chunk_fwd_kernel_o( BV: tl.constexpr, USE_G: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -65,22 +64,19 @@ def chunk_fwd_kernel_o( i_tg = i_b * NT + i_t bos, eos = i_b * T, i_b * T + T - s_qk = K if HEAD_FIRST else H*K - s_vo = V if HEAD_FIRST else H*V - s_g = 1 if HEAD_FIRST else H # offset calculation - q += (i_bh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K) - k += (i_bh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K) - v += (i_bh * T*V) if HEAD_FIRST else ((bos * H + i_h) * V) - o += (i_bh * T*V) if HEAD_FIRST else ((bos * H + i_h) * V) - h += ((i_bh * NT + i_t).to(tl.int64) * K*V) if HEAD_FIRST else ((i_tg * H + i_h).to(tl.int64) * K*V) + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + o += (bos * H + i_h) * V + h += (i_tg * H + i_h).to(tl.int64) * K*V b_o = tl.zeros([BT, BV], dtype=tl.float32) b_A = tl.zeros([BT, BT], dtype=tl.float32) for i_k in range(tl.cdiv(K, BK)): - p_q = tl.make_block_ptr(q, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_k = tl.make_block_ptr(k, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) # [BT, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) @@ -95,8 +91,8 @@ def chunk_fwd_kernel_o( b_A += tl.dot(b_q, b_k) if USE_G: - g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h) - p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,)) + g += bos * H + i_h + p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,)) b_g = tl.load(p_g, boundary_check=(0,)) b_o = b_o * exp(b_g)[:, None] b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :]) @@ -105,8 +101,8 @@ def chunk_fwd_kernel_o( m_A = o_i[:, None] >= o_i[None, :] b_A = tl.where(m_A, b_A, 0) - p_v = tl.make_block_ptr(v, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_o = tl.make_block_ptr(o, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) b_v = tl.load(p_v, boundary_check=(0, 1)) # to fix mma -> mma layout conversion @@ -157,7 +153,6 @@ def chunk_bwd_kernel_dqkwg( USE_G: tl.constexpr, USE_DW: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -175,23 +170,20 @@ def chunk_bwd_kernel_dqkwg( bos, eos = i_b * T, i_b * T + T # offset calculation - v += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V - do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V - h += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V - dh += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V - q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K - k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K - dq += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K - dk += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K - s_qk = K if HEAD_FIRST else H*K - s_vo = V if HEAD_FIRST else H*V - s_g = 1 if HEAD_FIRST else H + v += (bos * H + i_h) * V + do += (bos * H + i_h) * V + h += (i_tg * H + i_h).to(tl.int64) * K*V + dh += (i_tg * H + i_h).to(tl.int64) * K*V + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + dq += (bos * H + i_h) * K + dk += (bos * H + i_h) * K # for delta rule only if USE_DW: - dw += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K - dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V - w += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K + dw += (bos * H + i_h) * K + dv += (bos * H + i_h) * V + w += (bos * H + i_h) * K b_dq = tl.zeros([BT, BK], dtype=tl.float32) b_dk = tl.zeros([BT, BK], dtype=tl.float32) @@ -200,8 +192,8 @@ def chunk_bwd_kernel_dqkwg( b_dw = tl.zeros([BT, BK], dtype=tl.float32) if USE_DW else None for i_v in range(tl.cdiv(V, BV)): - p_v = tl.make_block_ptr(v, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) # [BT, BV] @@ -219,36 +211,36 @@ def chunk_bwd_kernel_dqkwg( # [BT, BV] @ [BV, BK] -> [BT, BK] b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) if USE_DW: - p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) b_dv = tl.load(p_dv, boundary_check=(0, 1)) b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype)) if USE_DW and not USE_G: - p_dw = tl.make_block_ptr(dw, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) tl.debug_barrier() o_i = tl.arange(0, BT) - p_q = tl.make_block_ptr(q, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1)) - p_dq = tl.make_block_ptr(dq, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_dk = tl.make_block_ptr(dk, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) if USE_G: b_dg = tl.zeros([BT,], dtype=tl.float32) - g += i_bh * T if HEAD_FIRST else bos * H + i_h - dg += i_bh * T if HEAD_FIRST else bos * H + i_h - p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,)) + g += bos * H + i_h + dg += bos * H + i_h + p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,)) b_g = tl.load(p_g, boundary_check=(0,)) - b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * s_g) + b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * H) b_dg_last *= exp(b_g_last) if USE_DW: - p_w = tl.make_block_ptr(w, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_dw = tl.make_block_ptr(dw, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) b_w = tl.load(p_w, boundary_check=(0, 1)) b_dw = b_dw * exp(b_g)[:, None] tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) @@ -270,7 +262,7 @@ def chunk_bwd_kernel_dqkwg( # [BT, BK] b_dq += tl.dot(b_ds, b_k) b_dk += tl.dot(tl.trans(b_ds), b_q) - p_dg = tl.make_block_ptr(dg, (T,), (s_g,), (i_t * BT,), (BT,), (0,)) + p_dg = tl.make_block_ptr(dg, (T,), (H,), (i_t * BT,), (BT,), (0,)) # (SY 09/21) revcumsum in a separate kernel due to strange triton compiler issue # b_dg = tl.dot(tl.where(o_i[:, None] <= o_i[None, :], 1., 0.), b_dg, allow_tf32=False) + b_dg_last) b_dg = tl.where(o_i < min(BT, T-i_t*BT) - 1, b_dg, b_dg + b_dg_last) @@ -319,7 +311,6 @@ def chunk_bwd_kernel_dv( BV: tl.constexpr, USE_G: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -337,19 +328,16 @@ def chunk_bwd_kernel_dv( b_dv = tl.zeros([BT, BV], dtype=tl.float32) # offset calculation - q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K - k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K - do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V - dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V - s_qk = K if HEAD_FIRST else H*K - s_vo = V if HEAD_FIRST else H*V - s_g = 1 if HEAD_FIRST else H - dh += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + do += (bos * H + i_h) * V + dv += (bos * H + i_h) * V + dh += (i_tg * H + i_h).to(tl.int64) * K*V b_A = tl.zeros([BT, BT], dtype=tl.float32) for i_k in range(tl.cdiv(K, BK)): - p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_q = tl.make_block_ptr(q, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_A += tl.dot(b_k, b_q) @@ -358,10 +346,10 @@ def chunk_bwd_kernel_dv( b_dv += tl.dot(b_k, b_dh.to(b_k.dtype)) if USE_G: - g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h) - p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,)) + g += bos * H + i_h + p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,)) b_g = tl.load(p_g, boundary_check=(0,)) - b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * s_g) + b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * H) b_dv *= safe_exp(-b_g + b_g_last)[:, None] mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]) @@ -369,8 +357,8 @@ def chunk_bwd_kernel_dv( b_A = tl.where(mask, b_A * safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty) else: b_A = tl.where(mask, b_A * scale, 0).to(do.dtype.element_ty) - p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) b_do = tl.load(p_do, boundary_check=(0, 1)) b_dv += tl.dot(b_A.to(b_do.dtype), b_do) tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) @@ -407,7 +395,6 @@ def chunk_bwd_kernel_dv_local( BV: tl.constexpr, USE_G: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H @@ -419,25 +406,22 @@ def chunk_bwd_kernel_dv_local( bos, eos = i_b * T, i_b * T + T # offset calculation - q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K - k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K - do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V - dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V - s_qk = K if HEAD_FIRST else H*K - s_vo = V if HEAD_FIRST else H*V - s_g = 1 if HEAD_FIRST else H + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + do += (bos * H + i_h) * V + dv += (bos * H + i_h) * V b_A = tl.zeros([BT, BT], dtype=tl.float32) for i_k in range(tl.cdiv(K, BK)): - p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_q = tl.make_block_ptr(q, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_A += tl.dot(b_k, b_q) if USE_G: - g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h) - p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,)) + g += bos * H + i_h + p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,)) b_g = tl.load(p_g, boundary_check=(0,)) mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]) @@ -447,8 +431,8 @@ def chunk_bwd_kernel_dv_local( b_A = tl.where(mask, b_A * scale, 0).to(do.dtype.element_ty) for i_v in range(tl.cdiv(V, BV)): - p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) b_do = tl.load(p_do, boundary_check=(0, 1)) b_dv = tl.dot(b_A.to(b_do.dtype), b_do) tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) @@ -489,7 +473,6 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) K=K, V=V, BT=BT, - HEAD_FIRST=False ) return o @@ -538,7 +521,6 @@ def chunk_bwd_dv( BT=BT, BK=BK, BV=BV, - HEAD_FIRST=False ) return dv @@ -585,7 +567,6 @@ def chunk_bwd_dv_local( BT=BT, BK=BK, BV=BV, - HEAD_FIRST=False ) return dv @@ -645,7 +626,6 @@ def chunk_bwd_dqkwg( BT=BT, BK=BK, BV=BV, - HEAD_FIRST=False ) if dg is not None: diff --git a/fla/ops/gla/chunk.py b/fla/ops/gla/chunk.py index 5e69888b7a..473b06d86d 100644 --- a/fla/ops/gla/chunk.py +++ b/fla/ops/gla/chunk.py @@ -1,13 +1,11 @@ # -*- coding: utf-8 -*- # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang -import warnings from typing import Optional, Tuple import torch import triton import triton.language as tl -from einops import rearrange from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h from fla.ops.common.utils import prepare_chunk_indices @@ -851,16 +849,20 @@ def chunk_gla_fwd_intra_gk( g: torch.Tensor, scale: float, offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = False, chunk_size: int = 64 ): - B, T, H, K = k.shape + if head_first: + B, H, T, K = k.shape + else: + B, T, H, K = k.shape BT = min(chunk_size, max(16, triton.next_power_of_2(T))) - indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) BC = min(16, BT) NC = triton.cdiv(BT, BC) - A = q.new_empty(B, T, H, BT, dtype=torch.float) + A = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float) grid = (NT, NC * NC, B * H) chunk_gla_fwd_A_kernel_intra_sub_inter[grid]( q, @@ -876,7 +878,7 @@ def chunk_gla_fwd_intra_gk( BT=BT, BC=BC, NC=NC, - HEAD_FIRST=False + HEAD_FIRST=head_first ) grid = (NT, NC, B * H) @@ -897,13 +899,13 @@ def chunk_gla_fwd_intra_gk( BT=BT, BC=BC, BK=BK, - HEAD_FIRST=False + HEAD_FIRST=head_first ) # split then merge else: BK = min(128, triton.next_power_of_2(K)) NK = triton.cdiv(K, BK) - A_intra = q.new_empty(NK, B, T, H, BC, dtype=torch.float) + A_intra = q.new_empty(NK, B, *((H, T) if head_first else (T, H)), BC, dtype=torch.float) grid = (NK, NT * NC, B * H) chunk_gla_fwd_A_kernel_intra_sub_intra_split[grid]( @@ -922,7 +924,7 @@ def chunk_gla_fwd_intra_gk( BC=BC, BK=BK, NC=NC, - HEAD_FIRST=False + HEAD_FIRST=head_first ) grid = (NT, NC, B * H) @@ -937,7 +939,7 @@ def chunk_gla_fwd_intra_gk( BT=BT, BC=BC, NK=NK, - HEAD_FIRST=False + HEAD_FIRST=head_first ) return A @@ -950,11 +952,15 @@ def chunk_gla_fwd_o_gk( h: torch.Tensor, scale: float, offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = False, chunk_size: int = 64 ): - B, T, H, K, V = *q.shape, v.shape[-1] + if head_first: + B, H, T, K, V = *q.shape, v.shape[-1] + else: + B, T, H, K, V = *q.shape, v.shape[-1] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) - indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) o = torch.empty_like(v) @@ -974,7 +980,7 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) K=K, V=V, BT=BT, - HEAD_FIRST=False + HEAD_FIRST=head_first ) return o @@ -984,15 +990,19 @@ def chunk_gla_bwd_dA( do: torch.Tensor, scale: float, offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = False, chunk_size: int = 64 ): - B, T, H, V = v.shape + if head_first: + B, H, T, V = v.shape + else: + B, T, H, V = v.shape BT = min(chunk_size, max(16, triton.next_power_of_2(T))) - indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) BV = min(64, triton.next_power_of_2(V)) - dA = v.new_empty(B, T, H, BT, dtype=torch.float) + dA = v.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float) grid = (NT, B * H) chunk_gla_bwd_kernel_dA[grid]( v, @@ -1006,7 +1016,7 @@ def chunk_gla_bwd_dA( V=V, BT=BT, BV=BV, - HEAD_FIRST=False + HEAD_FIRST=head_first ) return dA @@ -1018,11 +1028,15 @@ def chunk_gla_bwd_dv( do: torch.Tensor, dh: torch.Tensor, offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = False, chunk_size: int = 64 ): - B, T, H, K, V = *k.shape, do.shape[-1] + if head_first: + B, H, T, K, V = *k.shape, do.shape[-1] + else: + B, T, H, K, V = *k.shape, do.shape[-1] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) - indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) dv = torch.empty_like(do) @@ -1041,7 +1055,7 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) K=K, V=V, BT=BT, - HEAD_FIRST=False + HEAD_FIRST=head_first ) return dv @@ -1052,14 +1066,18 @@ def chunk_gla_bwd_dqk_intra( g: torch.Tensor, dA: torch.Tensor, offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = False, chunk_size: int = 64 ): - B, T, H, K = q.shape + if head_first: + B, H, T, K = q.shape + else: + B, T, H, K = q.shape BT = min(chunk_size, max(16, triton.next_power_of_2(T))) - indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None - NT = triton.cdiv(T, BT) if offsets is None else len(indices) BC = min(16, BT) BK = min(64, triton.next_power_of_2(K)) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) NC = triton.cdiv(BT, BC) NK = triton.cdiv(K, BK) @@ -1082,7 +1100,7 @@ def chunk_gla_bwd_dqk_intra( BC=BC, BK=BK, NC=NC, - HEAD_FIRST=False + HEAD_FIRST=head_first ) return dq, dk @@ -1099,11 +1117,15 @@ def chunk_gla_bwd_dqkg( dk: torch.Tensor, scale: float, offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = False, chunk_size: int = 64 ): - B, T, H, K, V = *k.shape, v.shape[-1] + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) - indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) dg = torch.empty_like(g) @@ -1132,7 +1154,7 @@ def grid(meta): return (triton.cdiv(K, meta['BK']), NT, B * H) K=K, V=V, BT=BT, - HEAD_FIRST=False + HEAD_FIRST=head_first ) return dq2, dk2, dg @@ -1147,10 +1169,14 @@ def chunk_gla_fwd( initial_state: torch.Tensor, output_final_state: bool, offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + T = q.shape[2] if head_first else q.shape[1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) if g_cumsum is None: - g_cumsum = chunk_local_cumsum(g, chunk_size, offsets=offsets) + g_cumsum = chunk_local_cumsum(g, BT, offsets=offsets, indices=indices, head_first=head_first) h, ht = chunk_fwd_h( k=k, @@ -1162,7 +1188,8 @@ def chunk_gla_fwd( output_final_state=output_final_state, states_in_fp32=False, offsets=offsets, - chunk_size=chunk_size + head_first=head_first, + chunk_size=BT ) # the intra A is kept in fp32 @@ -1173,7 +1200,9 @@ def chunk_gla_fwd( g=g_cumsum, scale=scale, offsets=offsets, - chunk_size=chunk_size + indices=indices, + head_first=head_first, + chunk_size=BT ) o = chunk_gla_fwd_o_gk( q=q, @@ -1183,7 +1212,9 @@ def chunk_gla_fwd( h=h, scale=scale, offsets=offsets, - chunk_size=chunk_size + indices=indices, + head_first=head_first, + chunk_size=BT ) return g_cumsum, A, h, ht, o @@ -1201,12 +1232,14 @@ def chunk_gla_bwd( do: torch.Tensor, dht: torch.Tensor, offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = False, chunk_size: int = 64 ): - T = q.shape[1] + T = q.shape[2] if head_first else q.shape[1] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) if g_cumsum is None: - g_cumsum = chunk_local_cumsum(g, BT, offsets=offsets) + g_cumsum = chunk_local_cumsum(g, BT, offsets=offsets, indices=indices, head_first=head_first) if h is None: h, _ = chunk_fwd_h( @@ -1218,6 +1251,7 @@ def chunk_gla_bwd( h0=initial_state, output_final_state=False, offsets=offsets, + head_first=head_first, chunk_size=BT, states_in_fp32=True ) @@ -1233,6 +1267,7 @@ def chunk_gla_bwd( dht=dht, scale=scale, offsets=offsets, + head_first=head_first, chunk_size=BT, states_in_fp32=True ) @@ -1244,6 +1279,8 @@ def chunk_gla_bwd( do=do, dh=dh, offsets=offsets, + indices=indices, + head_first=head_first, chunk_size=BT ) @@ -1253,6 +1290,8 @@ def chunk_gla_bwd( do=do, scale=scale, offsets=offsets, + indices=indices, + head_first=head_first, chunk_size=BT ) dq, dk = chunk_gla_bwd_dqk_intra( @@ -1261,6 +1300,8 @@ def chunk_gla_bwd( g=g_cumsum, dA=dA, offsets=offsets, + indices=indices, + head_first=head_first, chunk_size=BT ) dq, dk, dg = chunk_gla_bwd_dqkg( @@ -1275,6 +1316,8 @@ def chunk_gla_bwd( dk=dk, scale=scale, offsets=offsets, + indices=indices, + head_first=head_first, chunk_size=BT ) return dq, dk, dv, dg, dh0 @@ -1293,11 +1336,17 @@ def forward( scale, initial_state, output_final_state, - offsets + offsets, + head_first ): - T = q.shape[1] + T = q.shape[2] if head_first else q.shape[1] chunk_size = min(64, max(16, triton.next_power_of_2(T))) + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None g_cumsum, A, h, ht, o = chunk_gla_fwd( q=q, k=k, @@ -1308,6 +1357,8 @@ def forward( initial_state=initial_state, output_final_state=output_final_state, offsets=offsets, + indices=indices, + head_first=head_first, chunk_size=chunk_size ) # recompute g_cumsum in bwd pass @@ -1319,13 +1370,15 @@ def forward( ctx.chunk_size = chunk_size ctx.scale = scale ctx.offsets = offsets + ctx.indices = indices + ctx.head_first = head_first return o, ht @staticmethod @input_guard def backward(ctx, do, dht): q, k, v, g, g_cumsum, initial_state, A = ctx.saved_tensors - chunk_size, scale, offsets = ctx.chunk_size, ctx.scale, ctx.offsets + chunk_size, scale, offsets, indices, head_first = ctx.chunk_size, ctx.scale, ctx.offsets, ctx.indices, ctx.head_first dq, dk, dv, dg, dh0 = chunk_gla_bwd( q=q, k=k, @@ -1339,9 +1392,11 @@ def backward(ctx, do, dht): do=do, dht=dht, offsets=offsets, + indices=indices, + head_first=head_first, chunk_size=chunk_size ) - return dq.to(q), dk.to(k), dv.to(v), dg, None, dh0, None, None + return dq.to(q), dk.to(k), dv.to(v), dg, None, dh0, None, None, None @torch.compiler.disable @@ -1418,19 +1473,6 @@ def chunk_gla( >>> assert o.allclose(o_var.view(o.shape)) >>> assert ht.allclose(ht_var) """ - if head_first: - warnings.warn( - "head_first is deprecated and will be removed in a future version. " - "Please use head_first=False for now instead." - ) - q, k, v, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, g)) - if not head_first and q.shape[1] < q.shape[2]: - warnings.warn( - f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " - "This may indicate the inputs were passed in head-first format [B, H, T, ...] " - "when head_first=False was specified. " - "Please verify your input tensor format matches the expected shape [B, T, H, ...]." - ) if cu_seqlens is not None: if q.shape[0] != 1: raise ValueError( @@ -1448,16 +1490,5 @@ def chunk_gla( ) if scale is None: scale = q.shape[-1] ** -0.5 - o, final_state = ChunkGLAFunction.apply( - q, - k, - v, - g, - scale, - initial_state, - output_final_state, - cu_seqlens, - ) - if head_first: - o = rearrange(o, 'b t h ... -> b h t ...') + o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, cu_seqlens, head_first) return o, final_state From 25ec0e094c60f5d17e6d5f819192567970ebf112 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Tue, 8 Apr 2025 21:34:19 +0800 Subject: [PATCH 3/3] Remove `head_first` in GSA --- fla/ops/gla/chunk.py | 415 ++++++++++----------------------- fla/ops/gsa/chunk.py | 320 ++++++++----------------- fla/ops/gsa/fused_recurrent.py | 72 +++--- tests/ops/test_gsa.py | 4 +- 4 files changed, 247 insertions(+), 564 deletions(-) diff --git a/fla/ops/gla/chunk.py b/fla/ops/gla/chunk.py index 473b06d86d..9412d27341 100644 --- a/fla/ops/gla/chunk.py +++ b/fla/ops/gla/chunk.py @@ -1,11 +1,13 @@ # -*- coding: utf-8 -*- # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +import warnings from typing import Optional, Tuple import torch import triton import triton.language as tl +from einops import rearrange from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h from fla.ops.common.utils import prepare_chunk_indices @@ -46,7 +48,6 @@ def chunk_gla_fwd_A_kernel_intra_sub_inter( BK: tl.constexpr, NC: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -68,18 +69,11 @@ def chunk_gla_fwd_A_kernel_intra_sub_inter( o_k = i_k * BK + tl.arange(0, BK) m_k = o_k < K - if HEAD_FIRST: - p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) - p_gk = tl.make_block_ptr(g + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) - p_gn = tl.max_contiguous(tl.multiple_of(g + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK) - else: - p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - p_g = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) - p_gk = tl.make_block_ptr(g + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) - p_gn = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gk = tl.make_block_ptr(g + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gn = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k # [BK,] b_gn = tl.load(p_gn, mask=m_k, other=0) @@ -94,10 +88,7 @@ def chunk_gla_fwd_A_kernel_intra_sub_inter( # [BC, BC] using tf32 to improve precision here. b_A += tl.dot(b_qg, b_kg) - if HEAD_FIRST: - p_A = tl.make_block_ptr(A + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) - else: - p_A = tl.make_block_ptr(A + (bos*H + i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_A = tl.make_block_ptr(A + (bos*H + i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) @@ -129,7 +120,6 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra( BC: tl.constexpr, BK: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -148,18 +138,11 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra( o_k = tl.arange(0, BK) m_k = o_k < K m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T - if HEAD_FIRST: - o_A = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC - p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) - p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) - p_k = tl.max_contiguous(tl.multiple_of(k + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK) - p_gk = tl.max_contiguous(tl.multiple_of(g + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK) - else: - o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_j * BC - p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) - p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) - p_k = k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k - p_gk = g + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_j * BC + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_k = k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + p_gk = g + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k b_q = tl.load(p_q, boundary_check=(0, 1)) b_g = tl.load(p_g, boundary_check=(0, 1)) @@ -170,8 +153,8 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra( b_A = tl.where(o_i >= j, b_A * scale, 0.) tl.store(A + o_A + j, b_A, mask=m_A) - p_k += K if HEAD_FIRST else H*K - p_gk += K if HEAD_FIRST else H*K + p_k += H*K + p_gk += H*K @triton.heuristics({ @@ -204,7 +187,6 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra_split( BK: tl.constexpr, NC: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_k, i_tc, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -227,18 +209,11 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra_split( m_k = o_k < K m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T - if HEAD_FIRST: - o_A = (i_k * B*H + i_bh) * T * BC + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BC - p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - p_k = tl.max_contiguous(tl.multiple_of(k + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK) - p_gk = tl.max_contiguous(tl.multiple_of(g + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK) - else: - o_A = (i_k * all + bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BC + i_h * BC - p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - p_k = k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k - p_gk = g + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + o_A = (i_k * all + bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BC + i_h * BC + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + p_gk = g + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k b_q = tl.load(p_q, boundary_check=(0, 1)) b_g = tl.load(p_g, boundary_check=(0, 1)) @@ -249,8 +224,8 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra_split( b_A += tl.sum(b_q * b_k[None, :] * exp(b_g - b_gk[None, :]), 1) b_A = tl.where(o_i >= j, b_A * scale, 0.) tl.store(A + o_A + j, b_A, mask=m_A) - p_k += K if HEAD_FIRST else H*K - p_gk += K if HEAD_FIRST else H*K + p_k += H*K + p_gk += H*K @triton.heuristics({ @@ -278,7 +253,6 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra_merge( BC: tl.constexpr, NK: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -296,15 +270,9 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra_merge( b_A = tl.zeros([BC, BC], dtype=tl.float32) for i_k in range(0, NK): - if HEAD_FIRST: - p_A = tl.make_block_ptr(A + (i_k*B*H+i_bh)*T*BC, (T, BC), (BC, 1), (i_t*BT + i_c*BC, 0), (BC, BC), (1, 0)) - else: - p_A = tl.make_block_ptr(A + (i_k*all+bos)*H*BC+i_h*BC, (T, BC), (H*BC, 1), (i_t*BT + i_c*BC, 0), (BC, BC), (1, 0)) + p_A = tl.make_block_ptr(A + (i_k*all+bos)*H*BC+i_h*BC, (T, BC), (H*BC, 1), (i_t*BT + i_c*BC, 0), (BC, BC), (1, 0)) b_A += tl.load(p_A, boundary_check=(0, 1)) - if HEAD_FIRST: - p_A2 = tl.make_block_ptr(A2 + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_c * BC, i_c * BC), (BC, BC), (1, 0)) - else: - p_A2 = tl.make_block_ptr(A2 + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_c * BC, i_c * BC), (BC, BC), (1, 0)) + p_A2 = tl.make_block_ptr(A2 + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_c * BC, i_c * BC), (BC, BC), (1, 0)) tl.store(p_A2, b_A.to(A2.dtype.element_ty), boundary_check=(0, 1)) @@ -339,7 +307,6 @@ def chunk_gla_fwd_kernel_o( BK: tl.constexpr, BV: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -358,14 +325,9 @@ def chunk_gla_fwd_kernel_o( b_o = tl.zeros([BT, BV], dtype=tl.float32) for i_k in range(tl.cdiv(K, BK)): - if HEAD_FIRST: - p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) - else: - p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) # [BT, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) @@ -380,14 +342,9 @@ def chunk_gla_fwd_kernel_o( # [BT, BV] if i_k >= 0: b_o += tl.dot(b_qg, b_h.to(b_qg.dtype)) - if HEAD_FIRST: - p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) - else: - p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) # [BT, BV] b_v = tl.load(p_v, boundary_check=(0, 1)) # [BT, BT] @@ -402,10 +359,8 @@ def chunk_gla_fwd_kernel_o( }) @triton.autotune( configs=[ - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8] ], key=['BK', 'NC', 'BT'], ) @@ -427,7 +382,6 @@ def chunk_gla_bwd_kernel_intra( BK: tl.constexpr, NC: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -444,31 +398,19 @@ def chunk_gla_bwd_kernel_intra( o_k = i_k * BK + tl.arange(0, BK) m_k = o_k < K - if HEAD_FIRST: - p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - else: - p_g = tl.make_block_ptr(g + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) # [BC, BK] b_g = tl.load(p_g, boundary_check=(0, 1)) b_dq = tl.zeros([BC, BK], dtype=tl.float32) if i_i > 0: - if HEAD_FIRST: - p_gn = g + (i_bh * T + i_t * BT + i_i * BC) * K + o_k - p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK) - else: - p_gn = g + (bos + i_t * BT + i_i * BC) * H*K + i_h*K + o_k + p_gn = g + (bos + i_t * BT + i_i * BC) * H*K + i_h*K + o_k # [BK,] b_gn = tl.load(p_gn, mask=m_k, other=0) for i_j in range(0, i_i): - if HEAD_FIRST: - p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) - p_gk = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) - p_dA = tl.make_block_ptr(dA + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) - else: - p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0)) - p_gk = tl.make_block_ptr(g+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0)) - p_dA = tl.make_block_ptr(dA+(bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0)) + p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA+(bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0)) # [BC, BK] b_k = tl.load(p_k, boundary_check=(0, 1)) b_gk = tl.load(p_gk, boundary_check=(0, 1)) @@ -481,16 +423,10 @@ def chunk_gla_bwd_kernel_intra( o_i = tl.arange(0, BC) m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T - if HEAD_FIRST: - o_dA = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC - p_kj = tl.max_contiguous(tl.multiple_of(k + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK) - p_gkj = tl.max_contiguous(tl.multiple_of(g + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK) - p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - else: - o_dA = bos*H*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_i * BC - p_kj = k + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k - p_gkj = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k - p_dq = tl.make_block_ptr(dq + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + o_dA = bos*H*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_i * BC + p_kj = k + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_gkj = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_dq = tl.make_block_ptr(dq + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) for j in range(0, min(BC, T - i_t * BT - i_i * BC)): # [BC,] @@ -503,17 +439,13 @@ def chunk_gla_bwd_kernel_intra( # [BC, BK] # (SY 09/17) important to not use bf16 here to have a good precision. b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * exp(b_g - b_gkj[None, :]), 0.) - p_kj += K if HEAD_FIRST else H*K - p_gkj += K if HEAD_FIRST else H*K + p_kj += H*K + p_gkj += H*K tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) tl.debug_barrier() - if HEAD_FIRST: - p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - p_gk = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - else: - p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - p_gk = tl.make_block_ptr(g + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) # [BC, BK] b_k = tl.load(p_k, boundary_check=(0, 1)) @@ -522,23 +454,14 @@ def chunk_gla_bwd_kernel_intra( NC = min(NC, tl.cdiv(T - i_t * BT, BC)) if i_i < NC - 1: - if HEAD_FIRST: - p_gn = g + (i_bh * T + min(i_t * BT + i_i * BC + BC, T) - 1) * K + o_k - p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK) - else: - p_gn = g + (bos + min(i_t * BT + i_i * BC + BC, T) - 1) * H*K + i_h * K + o_k + p_gn = g + (bos + min(i_t * BT + i_i * BC + BC, T) - 1) * H*K + i_h * K + o_k # [BK,] b_gn = tl.load(p_gn, mask=m_k, other=0) for i_j in range(i_i + 1, NC): - if HEAD_FIRST: - p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t*BT + i_j*BC, i_k*BK), (BC, BK), (1, 0)) - p_gq = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t*BT + i_j*BC, i_k*BK), (BC, BK), (1, 0)) - p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (BT, T), (1, BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1)) - else: - p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k*BK), (BC, BK), (1, 0)) - p_gq = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k*BK), (BC, BK), (1, 0)) - p_dA = tl.make_block_ptr(dA + (bos*H+i_h)*BT, (BT, T), (1, H*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1)) + p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k*BK), (BC, BK), (1, 0)) + p_gq = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k*BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + (bos*H+i_h)*BT, (BT, T), (1, H*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1)) # [BC, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_gq = tl.load(p_gq, boundary_check=(0, 1)) @@ -549,27 +472,21 @@ def chunk_gla_bwd_kernel_intra( # (SY 09/17) important to not use bf16 here to have a good precision. b_dk += tl.dot(b_dA, b_qg) b_dk *= exp(b_gn[None, :] - b_gk) - if HEAD_FIRST: - o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC) - p_qj = tl.max_contiguous(tl.multiple_of(q + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK) - p_gqj = tl.max_contiguous(tl.multiple_of(g + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK) - p_dk = tl.make_block_ptr(dk + i_bh*T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - else: - o_dA = bos*H*BT + (i_t * BT + i_i * BC) * H*BT + i_h * BT + i_i * BC + tl.arange(0, BC) - p_qj = q + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k - p_gqj = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k - p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + o_dA = bos*H*BT + (i_t * BT + i_i * BC) * H*BT + i_h * BT + i_i * BC + tl.arange(0, BC) + p_qj = q + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_gqj = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) for j in range(0, min(BC, T - i_t * BT - i_i * BC)): # [BC,] - b_dA = tl.load(dA + o_dA + j * (1 if HEAD_FIRST else H) * BT) + b_dA = tl.load(dA + o_dA + j * H*BT) # [BK,] b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32) b_gqj = tl.load(p_gqj, mask=m_k, other=0).to(tl.float32) # [BC, BK] m_i = o_i[:, None] <= j b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * exp(b_gqj[None, :] - b_gk), 0.) - p_qj += K if HEAD_FIRST else H*K - p_gqj += K if HEAD_FIRST else H*K + p_qj += H*K + p_gqj += H*K tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) @@ -599,7 +516,6 @@ def chunk_gla_bwd_kernel_dA( BT: tl.constexpr, BV: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H @@ -612,19 +528,12 @@ def chunk_gla_bwd_kernel_dA( b_dA = tl.zeros([BT, BT], dtype=tl.float32) for i_v in range(tl.cdiv(V, BV)): - if HEAD_FIRST: - p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) - else: - p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) b_v = tl.load(p_v, boundary_check=(0, 1)) b_do = tl.load(p_do, boundary_check=(0, 1)) b_dA += tl.dot(b_do, b_v) - if HEAD_FIRST: - p_dA = tl.make_block_ptr(dA + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) - else: - p_dA = tl.make_block_ptr(dA + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_dA = tl.make_block_ptr(dA + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] b_dA = tl.where(m_s, b_dA * scale, 0.) tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) @@ -660,7 +569,6 @@ def chunk_gla_bwd_kernel_dv( BK: tl.constexpr, BV: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -675,14 +583,9 @@ def chunk_gla_bwd_kernel_dv( i_tg = i_b * NT + i_t bos, eos = i_b * T, i_b * T + T - if HEAD_FIRST: - p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) - p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - else: - p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) - p_do = tl.make_block_ptr(do + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) b_A = tl.load(p_A, boundary_check=(0, 1)) b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0.) @@ -694,16 +597,10 @@ def chunk_gla_bwd_kernel_dv( o_k = i_k * BK + tl.arange(0, BK) m_k = o_k < K - if HEAD_FIRST: - p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_gk = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_gn = tl.max_contiguous(tl.multiple_of(g + i_bh * T*K + min(i_t * BT + BT, T) * K - K + o_k, BK), BK) - p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) - else: - p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_gk = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_gn = g + (bos + min(i_t * BT + BT, T) - 1)*H*K + i_h * K + o_k - p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = g + (bos + min(i_t * BT + BT, T) - 1)*H*K + i_h * K + o_k + p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_gk = tl.load(p_gk, boundary_check=(0, 1)) @@ -753,7 +650,6 @@ def chunk_gla_bwd_kernel_inter( BK: tl.constexpr, BV: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H @@ -770,28 +666,18 @@ def chunk_gla_bwd_kernel_inter( o_k = i_k * BK + tl.arange(0, BK) m_k = o_k < K - if HEAD_FIRST: - p_gk = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_gn = tl.max_contiguous(tl.multiple_of(g + i_bh * T*K + (min(T, i_t * BT + BT)-1) * K + o_k, BK), BK) - else: - p_gk = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_gn = g + (bos + min(T, i_t * BT + BT)-1) * H*K + i_h * K + o_k + p_gk = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = g + (bos + min(T, i_t * BT + BT)-1) * H*K + i_h * K + o_k b_gn = tl.load(p_gn, mask=m_k, other=0) b_dq = tl.zeros([BT, BK], dtype=tl.float32) b_dk = tl.zeros([BT, BK], dtype=tl.float32) b_dgk = tl.zeros([BK,], dtype=tl.float32) for i_v in range(tl.cdiv(V, BV)): - if HEAD_FIRST: - p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) - p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) - else: - p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) - p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) # [BT, BV] b_v = tl.load(p_v, boundary_check=(0, 1)) b_do = tl.load(p_do, boundary_check=(0, 1)) @@ -809,16 +695,10 @@ def chunk_gla_bwd_kernel_inter( b_dq = b_dq * exp(b_gk) b_dk = b_dk * exp(b_gn[None, :] - b_gk) - if HEAD_FIRST: - p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - else: - p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_dq = tl.make_block_ptr(dq + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_dgk += tl.sum(b_dk * b_k, axis=0) @@ -830,14 +710,9 @@ def chunk_gla_bwd_kernel_inter( # Buggy due to strange triton compiler issue. # m_s = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], 1., 0.) # b_dg = tl.dot(m_s, b_dg, allow_tf32=False) + b_dgk[None, :] - if HEAD_FIRST: - p_dq = tl.make_block_ptr(dq2 + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_dk = tl.make_block_ptr(dk2 + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_dg = tl.make_block_ptr(dg + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - else: - p_dq = tl.make_block_ptr(dq2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_dk = tl.make_block_ptr(dk2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_dg = tl.make_block_ptr(dg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) @@ -849,20 +724,17 @@ def chunk_gla_fwd_intra_gk( g: torch.Tensor, scale: float, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ): - if head_first: - B, H, T, K = k.shape - else: - B, T, H, K = k.shape + B, T, H, K = k.shape BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) BC = min(16, BT) NC = triton.cdiv(BT, BC) - A = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float) + A = q.new_empty(B, T, H, BT, dtype=torch.float) grid = (NT, NC * NC, B * H) chunk_gla_fwd_A_kernel_intra_sub_inter[grid]( q, @@ -878,7 +750,6 @@ def chunk_gla_fwd_intra_gk( BT=BT, BC=BC, NC=NC, - HEAD_FIRST=head_first ) grid = (NT, NC, B * H) @@ -899,13 +770,12 @@ def chunk_gla_fwd_intra_gk( BT=BT, BC=BC, BK=BK, - HEAD_FIRST=head_first ) # split then merge else: BK = min(128, triton.next_power_of_2(K)) NK = triton.cdiv(K, BK) - A_intra = q.new_empty(NK, B, *((H, T) if head_first else (T, H)), BC, dtype=torch.float) + A_intra = q.new_empty(NK, B, T, H, BC, dtype=torch.float) grid = (NK, NT * NC, B * H) chunk_gla_fwd_A_kernel_intra_sub_intra_split[grid]( @@ -924,7 +794,6 @@ def chunk_gla_fwd_intra_gk( BC=BC, BK=BK, NC=NC, - HEAD_FIRST=head_first ) grid = (NT, NC, B * H) @@ -939,7 +808,6 @@ def chunk_gla_fwd_intra_gk( BT=BT, BC=BC, NK=NK, - HEAD_FIRST=head_first ) return A @@ -952,15 +820,12 @@ def chunk_gla_fwd_o_gk( h: torch.Tensor, scale: float, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ): - if head_first: - B, H, T, K, V = *q.shape, v.shape[-1] - else: - B, T, H, K, V = *q.shape, v.shape[-1] + B, T, H, K, V = *q.shape, v.shape[-1] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) o = torch.empty_like(v) @@ -980,7 +845,6 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) K=K, V=V, BT=BT, - HEAD_FIRST=head_first ) return o @@ -990,19 +854,16 @@ def chunk_gla_bwd_dA( do: torch.Tensor, scale: float, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ): - if head_first: - B, H, T, V = v.shape - else: - B, T, H, V = v.shape + B, T, H, V = v.shape BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) BV = min(64, triton.next_power_of_2(V)) - dA = v.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float) + dA = v.new_empty(B, T, H, BT, dtype=torch.float) grid = (NT, B * H) chunk_gla_bwd_kernel_dA[grid]( v, @@ -1016,7 +877,6 @@ def chunk_gla_bwd_dA( V=V, BT=BT, BV=BV, - HEAD_FIRST=head_first ) return dA @@ -1028,15 +888,12 @@ def chunk_gla_bwd_dv( do: torch.Tensor, dh: torch.Tensor, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ): - if head_first: - B, H, T, K, V = *k.shape, do.shape[-1] - else: - B, T, H, K, V = *k.shape, do.shape[-1] + B, T, H, K, V = *k.shape, do.shape[-1] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) dv = torch.empty_like(do) @@ -1055,7 +912,6 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) K=K, V=V, BT=BT, - HEAD_FIRST=head_first ) return dv @@ -1066,17 +922,14 @@ def chunk_gla_bwd_dqk_intra( g: torch.Tensor, dA: torch.Tensor, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ): - if head_first: - B, H, T, K = q.shape - else: - B, T, H, K = q.shape + B, T, H, K = q.shape BT = min(chunk_size, max(16, triton.next_power_of_2(T))) BC = min(16, BT) BK = min(64, triton.next_power_of_2(K)) + + indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) NC = triton.cdiv(BT, BC) NK = triton.cdiv(K, BK) @@ -1100,7 +953,6 @@ def chunk_gla_bwd_dqk_intra( BC=BC, BK=BK, NC=NC, - HEAD_FIRST=head_first ) return dq, dk @@ -1117,19 +969,15 @@ def chunk_gla_bwd_dqkg( dk: torch.Tensor, scale: float, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ): - if head_first: - B, H, T, K, V = *k.shape, v.shape[-1] - else: - B, T, H, K, V = *k.shape, v.shape[-1] + B, T, H, K, V = *k.shape, v.shape[-1] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) dg = torch.empty_like(g) - # work around triton compiler bugs. dq2 = torch.empty_like(dq) dk2 = torch.empty_like(dk) def grid(meta): return (triton.cdiv(K, meta['BK']), NT, B * H) @@ -1154,7 +1002,6 @@ def grid(meta): return (triton.cdiv(K, meta['BK']), NT, B * H) K=K, V=V, BT=BT, - HEAD_FIRST=head_first ) return dq2, dk2, dg @@ -1169,14 +1016,12 @@ def chunk_gla_fwd( initial_state: torch.Tensor, output_final_state: bool, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - T = q.shape[2] if head_first else q.shape[1] + T = q.shape[1] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) if g_cumsum is None: - g_cumsum = chunk_local_cumsum(g, BT, offsets=offsets, indices=indices, head_first=head_first) + g_cumsum = chunk_local_cumsum(g, BT, offsets=offsets) h, ht = chunk_fwd_h( k=k, @@ -1188,7 +1033,6 @@ def chunk_gla_fwd( output_final_state=output_final_state, states_in_fp32=False, offsets=offsets, - head_first=head_first, chunk_size=BT ) @@ -1200,8 +1044,6 @@ def chunk_gla_fwd( g=g_cumsum, scale=scale, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=BT ) o = chunk_gla_fwd_o_gk( @@ -1212,8 +1054,6 @@ def chunk_gla_fwd( h=h, scale=scale, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=BT ) return g_cumsum, A, h, ht, o @@ -1232,14 +1072,12 @@ def chunk_gla_bwd( do: torch.Tensor, dht: torch.Tensor, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ): - T = q.shape[2] if head_first else q.shape[1] + T = q.shape[1] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) if g_cumsum is None: - g_cumsum = chunk_local_cumsum(g, BT, offsets=offsets, indices=indices, head_first=head_first) + g_cumsum = chunk_local_cumsum(g, BT, offsets=offsets) if h is None: h, _ = chunk_fwd_h( @@ -1251,7 +1089,6 @@ def chunk_gla_bwd( h0=initial_state, output_final_state=False, offsets=offsets, - head_first=head_first, chunk_size=BT, states_in_fp32=True ) @@ -1267,7 +1104,6 @@ def chunk_gla_bwd( dht=dht, scale=scale, offsets=offsets, - head_first=head_first, chunk_size=BT, states_in_fp32=True ) @@ -1279,8 +1115,6 @@ def chunk_gla_bwd( do=do, dh=dh, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=BT ) @@ -1290,8 +1124,6 @@ def chunk_gla_bwd( do=do, scale=scale, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=BT ) dq, dk = chunk_gla_bwd_dqk_intra( @@ -1300,8 +1132,6 @@ def chunk_gla_bwd( g=g_cumsum, dA=dA, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=BT ) dq, dk, dg = chunk_gla_bwd_dqkg( @@ -1316,8 +1146,6 @@ def chunk_gla_bwd( dk=dk, scale=scale, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=BT ) return dq, dk, dv, dg, dh0 @@ -1337,16 +1165,10 @@ def forward( initial_state, output_final_state, offsets, - head_first ): - T = q.shape[2] if head_first else q.shape[1] + T = q.shape[1] chunk_size = min(64, max(16, triton.next_power_of_2(T))) - # 2-d indices denoting the offsets of chunks in each sequence - # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, - # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be - # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] - indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None g_cumsum, A, h, ht, o = chunk_gla_fwd( q=q, k=k, @@ -1357,8 +1179,6 @@ def forward( initial_state=initial_state, output_final_state=output_final_state, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=chunk_size ) # recompute g_cumsum in bwd pass @@ -1370,15 +1190,13 @@ def forward( ctx.chunk_size = chunk_size ctx.scale = scale ctx.offsets = offsets - ctx.indices = indices - ctx.head_first = head_first return o, ht @staticmethod @input_guard def backward(ctx, do, dht): q, k, v, g, g_cumsum, initial_state, A = ctx.saved_tensors - chunk_size, scale, offsets, indices, head_first = ctx.chunk_size, ctx.scale, ctx.offsets, ctx.indices, ctx.head_first + chunk_size, scale, offsets = ctx.chunk_size, ctx.scale, ctx.offsets dq, dk, dv, dg, dh0 = chunk_gla_bwd( q=q, k=k, @@ -1392,11 +1210,9 @@ def backward(ctx, do, dht): do=do, dht=dht, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=chunk_size ) - return dq.to(q), dk.to(k), dv.to(v), dg, None, dh0, None, None, None + return dq.to(q), dk.to(k), dv.to(v), dg, None, dh0, None, None @torch.compiler.disable @@ -1473,6 +1289,19 @@ def chunk_gla( >>> assert o.allclose(o_var.view(o.shape)) >>> assert ht.allclose(ht_var) """ + if head_first: + warnings.warn( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead." + ) + q, k, v, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, g)) + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...]." + ) if cu_seqlens is not None: if q.shape[0] != 1: raise ValueError( @@ -1490,5 +1319,7 @@ def chunk_gla( ) if scale is None: scale = q.shape[-1] ** -0.5 - o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, cu_seqlens, head_first) + o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, cu_seqlens) + if head_first: + o = rearrange(o, 'b t h ... -> b h t ...') return o, final_state diff --git a/fla/ops/gsa/chunk.py b/fla/ops/gsa/chunk.py index b852a1a60f..53ca69bbad 100644 --- a/fla/ops/gsa/chunk.py +++ b/fla/ops/gsa/chunk.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +import warnings from typing import Optional, Tuple import torch @@ -9,6 +10,7 @@ from einops import rearrange, reduce from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h +from fla.ops.common.utils import prepare_chunk_indices from fla.ops.gla.chunk import chunk_gla_bwd, chunk_gla_fwd from fla.ops.utils import chunk_local_cumsum, softmax_bwd, softmax_fwd from fla.ops.utils.op import exp, safe_exp @@ -49,10 +51,8 @@ def chunk_gsa_fwd_k_kernel_inter( BV: tl.constexpr, NG: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_bg = i_bh // NG i_b, i_hq = i_bh // HQ, i_bh % HQ i_h = i_hq // NG if IS_VARLEN: @@ -72,14 +72,9 @@ def chunk_gsa_fwd_k_kernel_inter( b_o = tl.zeros([BT, BV], dtype=tl.float32) b_A = tl.zeros([BT, BT], dtype=tl.float32) for i_k in range(tl.cdiv(K, BK)): - if HEAD_FIRST: - p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_k = tl.make_block_ptr(k + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_h = tl.make_block_ptr(h + (i_bg * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) - else: - p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) # [BT, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) @@ -92,14 +87,9 @@ def chunk_gsa_fwd_k_kernel_inter( b_o += tl.dot(b_q, b_h) # [BT, BT] b_A += tl.dot(b_q, b_k) - if HEAD_FIRST: - p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) - else: - p_g = tl.make_block_ptr(g + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_A = tl.make_block_ptr(A + (bos * HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_g = tl.make_block_ptr(g + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + (bos * HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) # [BT, BV] b_g = tl.load(p_g, boundary_check=(0, 1)) b_o = b_o * exp(b_g) @@ -132,10 +122,8 @@ def chunk_gsa_fwd_k_kernel_intra( NC: tl.constexpr, NG: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_bg = i_bh // NG i_b, i_hq = i_bh // HQ, i_bh % HQ i_h = i_hq // NG i_t, i_i = i_c // NC, i_c % NC @@ -152,25 +140,16 @@ def chunk_gsa_fwd_k_kernel_intra( if i_t * BT + i_i * BC > T: return - if HEAD_FIRST: - p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) - p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + min(i_t * BT + i_i * BC, T) * V + o_v, BV), BV) - else: - p_g = tl.make_block_ptr(g + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) - p_gn = g + (bos + min(i_t * BT + i_i * BC, T)) * H*V + i_h * V + o_v + p_g = tl.make_block_ptr(g + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_gn = g + (bos + min(i_t * BT + i_i * BC, T)) * H*V + i_h * V + o_v # [BV,] b_gn = tl.load(p_gn, mask=m_v, other=0) # [BC, BV] b_o = tl.zeros([BC, BV], dtype=tl.float32) for i_j in range(0, i_i): - if HEAD_FIRST: - p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) - p_v = tl.make_block_ptr(v + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) - p_gv = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) - else: - p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0)) - p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) - p_gv = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0)) + p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + p_gv = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) # [BC, BV] b_v = tl.load(p_v, boundary_check=(0, 1)) b_gv = tl.load(p_gv, boundary_check=(0, 1)) @@ -183,18 +162,11 @@ def chunk_gsa_fwd_k_kernel_intra( b_o *= exp(b_g - b_gn[None, :]) o_i = tl.arange(0, BC) - if HEAD_FIRST: - o_A = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC - else: - o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * HQ*BT + i_hq * BT + i_i * BC + o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * HQ*BT + i_hq * BT + i_i * BC m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T for j in range(0, min(BC, T - i_t * BT - i_i * BC)): - if HEAD_FIRST: - p_v = tl.max_contiguous(tl.multiple_of(v + i_bg * T*V + (i_t * BT + i_i * BC + j) * V + o_v, BV), BV) - p_gv = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_i * BC + j) * V + o_v, BV), BV) - else: - p_v = v + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v - p_gv = g + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v + p_v = v + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v + p_gv = g + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v # [BC,] b_A = tl.load(A + o_A + j, mask=m_A, other=0) # [BV,] @@ -204,10 +176,7 @@ def chunk_gsa_fwd_k_kernel_intra( b_vg = b_v[None, :] * exp(b_g - b_gv[None, :]) # avoid 0 * inf = inf b_o += tl.where(o_i[:, None] >= j, b_A[:, None] * b_vg, 0.) - if HEAD_FIRST: - p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) - else: - p_o = tl.make_block_ptr(o + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) b_o += tl.load(p_o, boundary_check=(0, 1)) tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) @@ -242,10 +211,8 @@ def chunk_gsa_bwd_k_kernel_dA( NC: tl.constexpr, NG: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_bg = i_bh // NG i_b, i_hq = i_bh // HQ, i_bh % HQ i_h = i_hq // NG i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC @@ -264,26 +231,16 @@ def chunk_gsa_bwd_k_kernel_dA( if i_t * BT + i_i * BC > T: return - if HEAD_FIRST: - p_dA = tl.make_block_ptr(dA+(i_v*B*H+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) - else: - p_dA = tl.make_block_ptr(dA+((i_v*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t*BT+i_i*BC, i_j*BC), (BC, BC), (1, 0)) + p_dA = tl.make_block_ptr(dA+((i_v*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t*BT+i_i*BC, i_j*BC), (BC, BC), (1, 0)) # [BC, BC] b_dA = tl.zeros([BC, BC], dtype=tl.float32) if i_i > i_j: - if HEAD_FIRST: - p_v = tl.make_block_ptr(v + i_bg * T*V, (V, T), (1, V), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1)) - p_gv = tl.make_block_ptr(g + i_bg * T*V, (V, T), (1, V), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1)) - p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_i * BC) * V + o_v, BV), BV) - p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) - p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) - else: - p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1)) - p_gv = tl.make_block_ptr(g + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1)) - p_gn = g + (bos + i_t*BT + i_i*BC) * H*V + i_h * V + o_v - p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) - p_do = tl.make_block_ptr(do + (bos*HQ+i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1)) + p_gv = tl.make_block_ptr(g + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1)) + p_gn = g + (bos + i_t*BT + i_i*BC) * H*V + i_h * V + o_v + p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + (bos*HQ+i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) # [BV,] b_gn = tl.load(p_gn, mask=m_v, other=0.) # [BC, BV] @@ -297,16 +254,10 @@ def chunk_gsa_bwd_k_kernel_dA( # [BC, BC] b_dA = tl.dot(b_do, b_vg) elif i_i == i_j: - if HEAD_FIRST: - p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) - p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) - p_v = tl.max_contiguous(tl.multiple_of(v + i_bg * T*V + (i_t * BT + i_j * BC) * V + o_v, BV), BV) - p_gv = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_j * BC) * V + o_v, BV), BV) - else: - p_g = tl.make_block_ptr(g + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) - p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) - p_v = v + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v - p_gv = g + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v + p_g = tl.make_block_ptr(g + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_v = v + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v + p_gv = g + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v # [BC, BV] b_g = tl.load(p_g, boundary_check=(0, 1)) b_do = tl.load(p_do, boundary_check=(0, 1)) * scale @@ -323,8 +274,8 @@ def chunk_gsa_bwd_k_kernel_dA( b_dAj = tl.sum(b_do * b_v[None, :] * exp(b_g - b_gv[None, :]), 1) b_dA = tl.where((o_i == j)[None, :], b_dAj[:, None], b_dA) - p_v += (1 if HEAD_FIRST else H) * V - p_gv += (1 if HEAD_FIRST else H) * V + p_v += H*V + p_gv += H*V b_dA = tl.where(m_dA, b_dA, 0.) tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1)) @@ -370,10 +321,8 @@ def chunk_gsa_bwd_k_kernel_dqkvg( BV: tl.constexpr, NG: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_bg = i_bh // NG i_b, i_hq = i_bh // HQ, i_bh % HQ i_h = i_hq // NG if IS_VARLEN: @@ -393,14 +342,9 @@ def chunk_gsa_bwd_k_kernel_dqkvg( o_t = min(i_t * BT + BT, T) m_s = o_i[:, None] >= o_i[None, :] - if HEAD_FIRST: - p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_k = tl.make_block_ptr(k + i_bg * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_A = tl.make_block_ptr(A + (i_k*B*H+i_bh) * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) - else: - p_q = tl.make_block_ptr(q + (bos*HQ+i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_k = tl.make_block_ptr(k + (bos*H+i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_A = tl.make_block_ptr(A + ((i_k*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_q = tl.make_block_ptr(q + (bos*HQ+i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H+i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_A = tl.make_block_ptr(A + ((i_k*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) # [BT, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) @@ -414,26 +358,15 @@ def chunk_gsa_bwd_k_kernel_dqkvg( b_dk = tl.zeros([BT, BK], dtype=tl.float32) for i_v in range(tl.cdiv(V, BV)): o_v = i_v * BV + tl.arange(0, BV) - if HEAD_FIRST: - p_v = tl.make_block_ptr(v + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (o_t - 1) * V + o_v, BV), BV) - p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_dv = tl.make_block_ptr(dv + (i_k*B*H+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_dg = tl.make_block_ptr(dg + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_dgv = tl.make_block_ptr(dgv + (i_k*B*H+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_h = tl.make_block_ptr(h + i_bg * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) - p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) - else: - p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_g = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_gn = g + (bos + o_t - 1) * H*V + i_h * V + o_v - p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_dv = tl.make_block_ptr(dv + ((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_dgv = tl.make_block_ptr(dgv+((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) - p_dh = tl.make_block_ptr(dh + (i_tg * HQ + i_hq) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_g = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gn = g + (bos + o_t - 1) * H*V + i_h * V + o_v + p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + ((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dgv = tl.make_block_ptr(dgv+((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh + (i_tg * HQ + i_hq) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) m_v = o_v < V # [BV,] @@ -468,14 +401,9 @@ def chunk_gsa_bwd_k_kernel_dqkvg( tl.store(p_dgv, b_dgv.to(p_dgv.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) - if HEAD_FIRST: - p_dA = tl.make_block_ptr(dA + i_bh * T*BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) - p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - else: - p_dA = tl.make_block_ptr(dA + (bos*HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) - p_dq = tl.make_block_ptr(dq + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_dk = tl.make_block_ptr(dk + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + (bos*HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_dq = tl.make_block_ptr(dq + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) # [BT, BT] b_dA = tl.load(p_dA, boundary_check=(0, 1)) # [BT, BK] @@ -510,10 +438,8 @@ def chunk_gsa_bwd_k_kernel_intra_dvg( NC: tl.constexpr, NG: tl.constexpr, IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr ): i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_bg = i_bh // NG i_b, i_hq = i_bh // HQ, i_bh % HQ i_h = i_hq // NG i_t, i_i = i_c // NC, i_c % NC @@ -530,26 +456,17 @@ def chunk_gsa_bwd_k_kernel_intra_dvg( if i_t * BT + i_i * BC > T: return - if HEAD_FIRST: - p_gv = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) - p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (min(i_t * BT + i_i * BC + BC, T) - 1) * V + o_v, BV), BV) - else: - p_gv = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) - p_gn = g + (bos + min(i_t * BT + i_i * BC + BC, T)-1)*H*V + i_h*V + o_v + p_gv = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_gn = g + (bos + min(i_t * BT + i_i * BC + BC, T)-1)*H*V + i_h*V + o_v # [BV,] b_gn = tl.load(p_gn, mask=m_v, other=0) # [BC, BV] b_gv = tl.load(p_gv, boundary_check=(0, 1)) b_dv = tl.zeros([BC, BV], dtype=tl.float32) for i_j in range(i_i + 1, NC): - if HEAD_FIRST: - p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) - p_A = tl.make_block_ptr(A + i_bh * T*BT, (BT, T), (1, BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) - p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) - else: - p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) - p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (BT, T), (1, HQ*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1)) - p_do = tl.make_block_ptr(do + (bos*HQ+i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_j*BC, i_v*BV), (BC, BV), (1, 0)) + p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (BT, T), (1, HQ*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1)) + p_do = tl.make_block_ptr(do + (bos*HQ+i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_j*BC, i_v*BV), (BC, BV), (1, 0)) # [BC, BV] b_g = tl.load(p_g, boundary_check=(0, 1)) b_do = tl.load(p_do, boundary_check=(0, 1)) * safe_exp(b_g - b_gn[None, :]) @@ -562,15 +479,9 @@ def chunk_gsa_bwd_k_kernel_intra_dvg( o_i = tl.arange(0, BC) o_c = i_i * BC + tl.arange(0, BC) - if HEAD_FIRST: - p_g = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_i * BC) * V + o_v, BV), BV) - p_A = tl.max_contiguous(tl.multiple_of(A + i_bh * T*BT + (i_t * BT + i_i * BC) * BT + o_c, BC), BC) - p_do = tl.max_contiguous(tl.multiple_of(do + i_bh * T*V + (i_t * BT + i_i * BC) * V + o_v, BV), BV) - else: - p_g = g + (bos + i_t * BT + i_i * BC) * H*V + i_h * V + o_v - p_A = A + (bos + i_t*BT + i_i*BC) * HQ*BT + i_hq * BT + o_c - p_do = do + (bos + i_t*BT + i_i*BC) * HQ*V + i_hq * V + o_v - + p_g = g + (bos + i_t * BT + i_i * BC) * H*V + i_h * V + o_v + p_A = A + (bos + i_t*BT + i_i*BC) * HQ*BT + i_hq * BT + o_c + p_do = do + (bos + i_t*BT + i_i*BC) * HQ*V + i_hq * V + o_v for j in range(0, min(BC, T - i_t * BT - i_i * BC)): # [BC,] b_A = tl.load(p_A) @@ -581,21 +492,14 @@ def chunk_gsa_bwd_k_kernel_intra_dvg( m_i = o_i[:, None] <= j b_dv += tl.where(m_i, exp(b_g[None, :] - b_gv) * b_A[:, None] * b_do[None, :], 0.) - p_g += (1 if HEAD_FIRST else H) * V - p_A += (1 if HEAD_FIRST else HQ) * BT - p_do += (1 if HEAD_FIRST else HQ) * V - if HEAD_FIRST: - p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) - p_v = tl.make_block_ptr(v + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) - p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) - p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) - p_dg = tl.make_block_ptr(dg + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) - else: - p_o = tl.make_block_ptr(o + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) - p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) - p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) - p_dv = tl.make_block_ptr(dv + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) - p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_g += H * V + p_A += HQ * BT + p_do += HQ * V + p_o = tl.make_block_ptr(o + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32) b_v = tl.load(p_v, boundary_check=(0, 1)).to(tl.float32) @@ -615,8 +519,6 @@ def chunk_gsa_fwd_v( initial_state: Optional[torch.Tensor] = None, output_final_state: bool = False, offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: _, A, h, ht, o = chunk_gla_fwd( @@ -629,8 +531,6 @@ def chunk_gsa_fwd_v( initial_state=initial_state, output_final_state=output_final_state, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=chunk_size ) return A, h, ht, o @@ -645,18 +545,15 @@ def chunk_gsa_fwd_k( output_final_state: bool = False, scale: float = 1., offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - if head_first: - B, H, T, K, V = *k.shape, v.shape[-1] - else: - B, T, H, K, V = *k.shape, v.shape[-1] + B, T, H, K, V = *k.shape, v.shape[-1] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) BC = min(16, BT) BV = min(64, triton.next_power_of_2(V)) - HQ = q.shape[1] if head_first else q.shape[2] + HQ = q.shape[2] + + indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) NC = triton.cdiv(BT, BC) NG = HQ // H @@ -670,12 +567,11 @@ def chunk_gsa_fwd_k( h0=h0, output_final_state=output_final_state, offsets=offsets, - head_first=head_first, chunk_size=BT, states_in_fp32=False ) - o = v.new_empty(B, *((HQ, T) if head_first else (T, HQ)), V) - A = q.new_empty(B, *((HQ, T) if head_first else (T, HQ)), BT) + o = v.new_empty(B, T, HQ, V) + A = q.new_empty(B, T, HQ, BT) def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * HQ) chunk_gsa_fwd_k_kernel_inter[grid]( q, @@ -694,7 +590,6 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * HQ) V=V, BT=BT, NG=NG, - HEAD_FIRST=head_first ) def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ) @@ -714,7 +609,6 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ) BV=BV, NC=NC, NG=NG, - HEAD_FIRST=head_first, num_warps=4, num_stages=2 ) @@ -734,8 +628,6 @@ def chunk_gsa_bwd_v( dg: torch.Tensor, scale: float = 1., offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ): dq, dk, dv, dg, dh0 = chunk_gla_bwd( @@ -751,8 +643,6 @@ def chunk_gsa_bwd_v( do=do, dht=dht, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=chunk_size ) return dq, dk, dv, dg, dh0 @@ -771,19 +661,16 @@ def chunk_gsa_bwd_k( dg: torch.Tensor, scale: float = 1., offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ): - if head_first: - B, H, T, K, V = *k.shape, v.shape[-1] - else: - B, T, H, K, V = *k.shape, v.shape[-1] + B, T, H, K, V = *k.shape, v.shape[-1] BT = min(chunk_size, max(16, triton.next_power_of_2(T))) BC = min(16, BT) BK = min(64, triton.next_power_of_2(K)) BV = min(64, triton.next_power_of_2(V)) - HQ = q.shape[1] if head_first else q.shape[2] + HQ = q.shape[2] + + indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None NT = triton.cdiv(T, BT) if offsets is None else len(indices) NC = triton.cdiv(BT, BC) NK = triton.cdiv(K, BK) @@ -800,7 +687,6 @@ def chunk_gsa_bwd_k( h0=h0, output_final_state=False, offsets=offsets, - head_first=head_first, chunk_size=BT, states_in_fp32=False ) @@ -816,11 +702,10 @@ def chunk_gsa_bwd_k( dht=dht, scale=scale, offsets=offsets, - head_first=head_first, chunk_size=BT, states_in_fp32=True ) - dA = q.new_empty(NV, B, *((HQ, T) if head_first else (T, HQ)), BT) + dA = q.new_empty(NV, B, T, HQ, BT) grid = (NV, NT * NC * NC, B * HQ) chunk_gsa_bwd_k_kernel_dA[grid]( v, @@ -840,15 +725,14 @@ def chunk_gsa_bwd_k( BV=BV, NC=NC, NG=NG, - HEAD_FIRST=head_first ) dA = dA.sum(0, dtype=dA.dtype) - A = do.new_empty(NK, B, *((HQ, T) if head_first else (T, HQ)), BT) + A = do.new_empty(NK, B, T, HQ, BT) dq = torch.empty_like(q) - dk = k.new_empty(B, *((HQ, T) if head_first else (T, HQ)), K) - dv = v.new_empty(NK, B, *((HQ, T) if head_first else (T, HQ)), V) - dgv = g.new_empty(NK, B, *((HQ, T) if head_first else (T, HQ)), V, dtype=torch.float) + dk = k.new_empty(B, T, HQ, K) + dv = v.new_empty(NK, B, T, HQ, V) + dgv = g.new_empty(NK, B, T, HQ, V, dtype=torch.float) grid = (NK, NT, B * HQ) chunk_gsa_bwd_k_kernel_dqkvg[grid]( q, @@ -878,7 +762,6 @@ def chunk_gsa_bwd_k( BK=BK, BV=BV, NG=NG, - HEAD_FIRST=head_first ) A = A.sum(0, dtype=A.dtype) dv = dv.sum(0, dtype=dv.dtype) @@ -904,11 +787,10 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ) BV=BV, NC=NC, NG=NG, - HEAD_FIRST=head_first, num_warps=4, num_stages=2 ) - dg = dgv.add_(chunk_local_cumsum(dg, chunk_size=BT, reverse=True, offsets=offsets, indices=indices, head_first=head_first)) + dg = dgv.add_(chunk_local_cumsum(dg, chunk_size=BT, reverse=True, offsets=offsets)) return dq, dk, dv, dg, dh0 @@ -923,8 +805,6 @@ def chunk_gsa_fwd( output_final_state: bool = False, scale: float = 1., offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: hk0, hv0 = None, None @@ -939,8 +819,6 @@ def chunk_gsa_fwd( output_final_state=output_final_state, scale=scale, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=chunk_size ) @@ -957,8 +835,6 @@ def chunk_gsa_fwd( initial_state=hv0, output_final_state=output_final_state, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=chunk_size ) return Ak, hk, hkt, ok, p, Av, hv, hvt, ov @@ -979,8 +855,6 @@ def chunk_gsa_bwd( do: torch.Tensor, dht: Tuple[torch.Tensor, torch.Tensor], offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None, - head_first: bool = False, chunk_size: int = 64 ): hk0, hv0 = None, None @@ -1005,8 +879,6 @@ def chunk_gsa_bwd( dg=None, scale=1., offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=chunk_size ) @@ -1027,8 +899,6 @@ def chunk_gsa_bwd( dg=dg, scale=scale, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=chunk_size ) @@ -1056,20 +926,11 @@ def forward( output_final_state: bool, checkpoint_level: int, offsets: Optional[torch.LongTensor], - head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - T = q.shape[2] if head_first else q.shape[1] + T = q.shape[1] chunk_size = min(64, max(16, triton.next_power_of_2(T))) - # 2-d indices denoting the offsets of chunks in each sequence - # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, - # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be - # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] - indices = None - if offsets is not None: - indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()]) - indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets) - g_org, g = g, chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first) + g_org, g = g, chunk_local_cumsum(g, chunk_size, offsets=offsets) Ak, hk, hkt, ok, p, Av, hv, hvt, ov = chunk_gsa_fwd( q=q, k=k, @@ -1080,8 +941,6 @@ def forward( output_final_state=output_final_state, scale=scale, offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=chunk_size ) @@ -1099,8 +958,6 @@ def forward( ctx.checkpoint_level = checkpoint_level ctx.scale = scale ctx.offsets = offsets - ctx.indices = indices - ctx.head_first = head_first ctx.chunk_size = chunk_size return ov, hkt, hvt @@ -1110,12 +967,10 @@ def backward(ctx, dov, dhkt=None, dhvt=None): q, k, v, s, g, ok, p, Av, hk0, hv0, hk, hv = ctx.saved_tensors scale = ctx.scale offsets = ctx.offsets - indices = ctx.indices - head_first = ctx.head_first chunk_size = ctx.chunk_size if ctx.checkpoint_level >= 1: - g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first) + g = chunk_local_cumsum(g, chunk_size, offsets=offsets) dq, dk, dv, ds, dg, dhk0, dhv0 = chunk_gsa_bwd( q=q, k=k, @@ -1131,8 +986,6 @@ def backward(ctx, dov, dhkt=None, dhvt=None): do=dov, dht=(dhkt, dhvt), offsets=offsets, - indices=indices, - head_first=head_first, chunk_size=chunk_size ) return dq, dk, dv, ds, dg, None, dhk0, dhv0, None, None, None, None @@ -1228,6 +1081,19 @@ def chunk_gsa( >>> assert hk.allclose(hk_var) >>> assert hv.allclose(hv_var) """ + if head_first: + warnings.warn( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead." + ) + q, k, v, s, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, s, g)) + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...]." + ) if cu_seqlens is not None: if q.shape[0] != 1: raise ValueError( @@ -1247,7 +1113,7 @@ def chunk_gsa( if g is None: # TODO: this 3 steps took huge amount of time, ought to be optimized z = s.float().logcumsumexp(2) - g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z + g = torch.cat((z[:, :, :1], z[:, :, :-1]), 1) - z s = torch.exp(s - z).to(k.dtype) if scale is None: scale = q.shape[-1] ** -0.5 diff --git a/fla/ops/gsa/fused_recurrent.py b/fla/ops/gsa/fused_recurrent.py index 1dc3e1ec9b..d08fb5043b 100644 --- a/fla/ops/gsa/fused_recurrent.py +++ b/fla/ops/gsa/fused_recurrent.py @@ -1,11 +1,13 @@ # -*- coding: utf-8 -*- # Copyright (c) 2024, Songlin Yang, Yu Zhang +import warnings from typing import Optional, Tuple import torch import triton import triton.language as tl +from einops import rearrange from fla.ops.common.fused_recurrent import fused_recurrent_bwd_kernel, fused_recurrent_fwd_kernel from fla.ops.utils import chunk_global_cumsum @@ -196,8 +198,7 @@ def fused_recurrent_gsa_fwd( USE_G=False, USE_GK=False, USE_GV=True, - REVERSE=reverse, - HEAD_FIRST=head_first + REVERSE=reverse ) ok = ok.sum(0) @@ -228,7 +229,6 @@ def fused_recurrent_gsa_fwd( USE_GK=True, USE_GV=False, REVERSE=reverse, - HEAD_FIRST=head_first ) ov = ov.sum(0) return ok, hkt, qv, ov, hvt @@ -250,25 +250,16 @@ def fused_recurrent_gsa_bwd( scale: float = 1., reverse: bool = False, offsets: Optional[torch.LongTensor] = None, - head_first: bool = False ) -> Tuple[torch.Tensor]: - if head_first: - B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] - else: - B, T, H, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + B, T, H, K, V, M = *q.shape, v.shape[-1], s.shape[-1] N = B if offsets is None else len(offsets) - 1 BK, BV, BM = min(K, 64), min(V, 64), min(M, 64) NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM) - if head_first: - dqv = q.new_empty(NV, B, H, T, M, dtype=torch.float) - dsv = q.new_empty(NV, B, H, T, M, dtype=torch.float) - dv = q.new_empty(NM, B, H, T, V, dtype=torch.float) - else: - dqv = q.new_empty(NV, B, T, H, M, dtype=torch.float) - dsv = q.new_empty(NV, B, T, H, M, dtype=torch.float) - dv = q.new_empty(NM, B, T, H, V, dtype=torch.float) + dqv = q.new_empty(NV, B, T, H, M, dtype=torch.float) + dsv = q.new_empty(NV, B, T, H, M, dtype=torch.float) + dv = q.new_empty(NM, B, T, H, V, dtype=torch.float) dhk0 = torch.empty_like(hk0)if hk0 is not None else None dhv0 = torch.empty_like(hv0)if hv0 is not None else None @@ -301,25 +292,16 @@ def fused_recurrent_gsa_bwd( USE_GK=True, USE_GV=False, REVERSE=reverse, - HEAD_FIRST=head_first ) dqv = dqv.sum(0) dsv = dsv.sum(0) dv = dv.sum(0) - dgk = chunk_global_cumsum(dqv * qv.float() - dsv * s.float(), - reverse=not reverse, - offsets=offsets, - head_first=head_first) + dgk = chunk_global_cumsum(dqv * qv.float() - dsv * s.float(), reverse=not reverse, offsets=offsets) dok = qv * (dqv - (qv * dqv).sum(-1, True)) - if head_first: - dq = q.new_empty(NM, B, H, T, K, dtype=torch.float) - dk = q.new_empty(NM, B, H, T, K, dtype=torch.float) - dsk = q.new_empty(NK, B, H, T, M, dtype=torch.float) - else: - dq = q.new_empty(NM, B, T, H, K, dtype=torch.float) - dk = q.new_empty(NM, B, T, H, K, dtype=torch.float) - dsk = q.new_empty(NK, B, T, H, M, dtype=torch.float) + dq = q.new_empty(NM, B, T, H, K, dtype=torch.float) + dk = q.new_empty(NM, B, T, H, K, dtype=torch.float) + dsk = q.new_empty(NK, B, T, H, M, dtype=torch.float) gk, gv = None, g grid = (NM, NK, N * H) fused_recurrent_bwd_kernel[grid]( @@ -349,16 +331,12 @@ def fused_recurrent_gsa_bwd( USE_GK=False, USE_GV=True, REVERSE=reverse, - HEAD_FIRST=head_first ) dq = dq.sum(0) dk = dk.sum(0) dsk = dsk.sum(0) - dgv = chunk_global_cumsum(dok.float() * ok.float() - dsk * s.float(), - reverse=not reverse, - offsets=offsets, - head_first=head_first) + dgv = chunk_global_cumsum(dok.float() * ok.float() - dsk * s.float(), reverse=not reverse, offsets=offsets) ds = dsk.add_(dsv) dg = dgk.add_(dgv) @@ -384,9 +362,8 @@ def forward( output_final_state: bool = False, reverse: bool = False, offsets: Optional[torch.LongTensor] = None, - head_first: bool = False ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]: - T = q.shape[2] if head_first else q.shape[1] + T = q.shape[1] if T == 1 and not q.requires_grad: o, (hkt, hvt) = fused_recurrent_gsa_inference( q=q, @@ -397,7 +374,6 @@ def forward( initial_state=(hk0, hv0), output_final_state=output_final_state, scale=scale, - head_first=head_first ) return o, hkt, hvt ok, hkt, qv, ov, hvt = fused_recurrent_gsa_fwd( @@ -411,13 +387,11 @@ def forward( scale=scale, reverse=reverse, offsets=offsets, - head_first=head_first ) ctx.save_for_backward(q, k, v, s, g, qv, hk0, hv0, ok) ctx.scale = scale ctx.reverse = reverse ctx.offsets = offsets - ctx.head_first = head_first return ov.to(q.dtype), hkt, hvt @staticmethod @@ -428,7 +402,6 @@ def backward(ctx, do, dhkt=None, dhvt=None): scale = ctx.scale reverse = ctx.reverse offsets = ctx.offsets - head_first = ctx.head_first # not supported yet. if dhkt is not None or dhvt is not None: @@ -450,9 +423,8 @@ def backward(ctx, do, dhkt=None, dhvt=None): scale=scale, reverse=reverse, offsets=offsets, - head_first=head_first ) - return dq.to(q), dk.to(k), dv.to(v), ds.to(s), dg.to(g), None, dhk0, dhv0, None, None, None, None + return dq.to(q), dk.to(k), dv.to(v), ds.to(s), dg.to(g), None, dhk0, dhv0, None, None, None def fused_recurrent_gsa( @@ -537,6 +509,19 @@ def fused_recurrent_gsa( >>> assert hk.allclose(hk_var) >>> assert hv.allclose(hv_var) """ + if head_first: + warnings.warn( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead." + ) + q, k, v, s, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, s, g)) + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...]." + ) if cu_seqlens is not None: if q.shape[0] != 1: raise ValueError( @@ -567,6 +552,7 @@ def fused_recurrent_gsa( output_final_state, reverse, cu_seqlens, - head_first ) + if head_first: + o = rearrange(o, 'b t h ... -> b h t ...') return o, final_state diff --git a/tests/ops/test_gsa.py b/tests/ops/test_gsa.py index aa8550e38f..2876ffa4e5 100644 --- a/tests/ops/test_gsa.py +++ b/tests/ops/test_gsa.py @@ -144,7 +144,7 @@ def test_fused_recurrent_varlen( # randomly split the sequence into N segments offsets = torch.cat([ torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 1)[:N-1]], + torch.arange(16, T)[torch.randperm(T - 16)[:N-1]], torch.tensor([T], dtype=torch.long) ], 0).to(device).sort()[0] @@ -326,7 +326,7 @@ def test_chunk_varlen( # randomly split the sequence into N segments offsets = torch.cat([ torch.tensor([0], dtype=torch.long), - torch.arange(16, T)[torch.randperm(T - 1)[:N-1]], + torch.arange(16, T)[torch.randperm(T - 16)[:N-1]], torch.tensor([T], dtype=torch.long) ], 0).to(device).sort()[0] # seq-first required for inputs with variable lengths