From 4205b0f29f5d2178cf6e6babffd7da7dd2a24468 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 7 Apr 2025 19:06:40 +0800 Subject: [PATCH 1/4] [API] Update `head_first` parameter default to `False` --- fla/ops/abc/chunk.py | 12 ++-- fla/ops/based/fused_chunk.py | 2 +- fla/ops/based/parallel.py | 2 +- fla/ops/common/chunk_delta_h.py | 8 +-- fla/ops/common/chunk_h.py | 4 +- fla/ops/common/chunk_h_parallel.py | 4 +- fla/ops/common/chunk_h_split.py | 4 +- fla/ops/common/chunk_o.py | 8 +-- fla/ops/common/fused_recurrent.py | 8 +-- fla/ops/delta_rule/chunk.py | 6 +- fla/ops/delta_rule/fused_recurrent.py | 8 +-- fla/ops/delta_rule/parallel.py | 12 ++-- fla/ops/gated_delta_rule/chunk.py | 6 +- fla/ops/gated_delta_rule/wy_fast.py | 2 +- fla/ops/generalized_delta_rule/dplr/chunk.py | 34 ++++++----- .../dplr/chunk_A_bwd.py | 2 +- .../dplr/chunk_A_fwd.py | 2 +- .../dplr/chunk_h_bwd.py | 2 +- .../dplr/chunk_h_fwd.py | 2 +- .../dplr/chunk_o_bwd.py | 6 +- .../dplr/chunk_o_fwd.py | 2 +- .../dplr/fused_recurrent.py | 18 ++++-- .../dplr/wy_fast_fwd.py | 2 +- fla/ops/generalized_delta_rule/iplr/chunk.py | 40 +++++++------ .../iplr/fused_recurrent.py | 30 ++++++---- .../generalized_delta_rule/iplr/wy_fast.py | 2 +- fla/ops/gla/chunk.py | 46 ++++++++------- fla/ops/gla/fused_chunk.py | 2 +- fla/ops/gla/fused_recurrent.py | 32 ++++++----- fla/ops/gsa/chunk.py | 40 +++++++------ fla/ops/gsa/fused_recurrent.py | 56 +++++++++++-------- fla/ops/lightning_attn/chunk.py | 12 ++-- fla/ops/lightning_attn/fused_recurrent.py | 12 ++-- fla/ops/linear_attn/chunk.py | 12 ++-- fla/ops/linear_attn/fused_chunk.py | 12 ++-- fla/ops/linear_attn/fused_recurrent.py | 2 +- fla/ops/nsa/naive.py | 20 ++++--- fla/ops/rebased/parallel.py | 2 +- fla/ops/retention/chunk.py | 12 ++-- fla/ops/retention/fused_chunk.py | 13 ++--- fla/ops/retention/fused_recurrent.py | 9 +-- fla/ops/retention/parallel.py | 12 ++-- fla/ops/rwkv6/chunk.py | 44 ++++++++------- fla/ops/rwkv6/fused_recurrent.py | 36 +++++++----- fla/ops/rwkv7/chunk.py | 13 +++-- fla/ops/rwkv7/fused_recurrent.py | 13 +++-- fla/ops/simple_gla/chunk.py | 54 ++++++++++-------- fla/ops/simple_gla/fused_recurrent.py | 50 ++++++++++------- fla/ops/simple_gla/parallel.py | 18 +++--- fla/ops/titans/naive.py | 2 +- fla/ops/ttt/chunk.py | 39 +++++++------ fla/ops/ttt/fused_chunk.py | 28 ++++++---- fla/ops/ttt/naive.py | 2 +- fla/ops/utils/cumsum.py | 20 ++++--- 54 files changed, 468 insertions(+), 373 deletions(-) diff --git a/fla/ops/abc/chunk.py b/fla/ops/abc/chunk.py index 8538e04800..194e51094c 100644 --- a/fla/ops/abc/chunk.py +++ b/fla/ops/abc/chunk.py @@ -1082,16 +1082,16 @@ def chunk_abc( s: torch.Tensor, initial_state: Optional[Tuple[torch.Tensor]] = None, output_final_state: bool = False, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: q (torch.Tensor): - queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]` s (torch.Tensor): slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]` initial_state (Optional[Tuple[torch.Tensor, torch.Tensor]]): @@ -1100,11 +1100,11 @@ def chunk_abc( Whether to output the final state of shape `[B, H, K, M]` and `[B, H, M, V]`. Default: `False`. head_first (Optional[bool]): Whether the inputs are in the head-first format. - Default: `True`. + Default: `False`. Returns: o (torch.Tensor): - Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. final_state (torch.Tensor): Final state of shape `[B, H, K, M]` and `[B, H, M, V]` if `output_final_state=True` else `None`. """ diff --git a/fla/ops/based/fused_chunk.py b/fla/ops/based/fused_chunk.py index ff5db4fb73..b017af069b 100644 --- a/fla/ops/based/fused_chunk.py +++ b/fla/ops/based/fused_chunk.py @@ -359,7 +359,7 @@ def fused_chunk_based( v: torch.Tensor, scale: Optional[float] = None, use_norm: bool = True, - head_first: bool = True + head_first: bool = False ): assert q.shape[-1] <= 16, 'only support feature dimension up to 16.' if scale is None: diff --git a/fla/ops/based/parallel.py b/fla/ops/based/parallel.py index d4621ea583..573c4844df 100644 --- a/fla/ops/based/parallel.py +++ b/fla/ops/based/parallel.py @@ -395,7 +395,7 @@ def parallel_based( v: torch.Tensor, scale: Optional[float] = None, use_norm: bool = True, - head_first: bool = True + head_first: bool = False ): assert q.shape[-1] <= 128, "only support feature dim up to 128" if scale is None: diff --git a/fla/ops/common/chunk_delta_h.py b/fla/ops/common/chunk_delta_h.py index e7d9d60485..8c4f593005 100644 --- a/fla/ops/common/chunk_delta_h.py +++ b/fla/ops/common/chunk_delta_h.py @@ -424,9 +424,9 @@ def chunk_gated_delta_rule_fwd_h( output_final_state: bool = False, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, chunk_size: int = 64, # SY: remove this argument and force chunk size 64? - save_new_value: bool = True + save_new_value: bool = True, + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: if head_first: B, H, T, K, V = *k.shape, u.shape[-1] @@ -490,8 +490,8 @@ def chunk_gated_delta_rule_bwd_dhu( scale: float, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, - chunk_size: int = 64 # SY: remove this argument and force chunk size 64? + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if head_first: B, H, T, K, V = *q.shape, do.shape[-1] diff --git a/fla/ops/common/chunk_h.py b/fla/ops/common/chunk_h.py index 0aa5a7a93b..8da7b585fa 100644 --- a/fla/ops/common/chunk_h.py +++ b/fla/ops/common/chunk_h.py @@ -302,7 +302,7 @@ def chunk_fwd_h( h0: torch.Tensor, output_final_state: bool, offsets: Optional[torch.Tensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64, split_size: Optional[int] = None, states_in_fp32: bool = False @@ -364,7 +364,7 @@ def chunk_bwd_dh( dht: torch.Tensor, scale: float, offsets: Optional[torch.Tensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64, split_size: Optional[int] = None, states_in_fp32: bool = False diff --git a/fla/ops/common/chunk_h_parallel.py b/fla/ops/common/chunk_h_parallel.py index 51083eda8e..aa951b756c 100644 --- a/fla/ops/common/chunk_h_parallel.py +++ b/fla/ops/common/chunk_h_parallel.py @@ -488,7 +488,7 @@ def chunk_fwd_h( states_in_fp32: bool = False, offsets: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor]: if head_first: @@ -569,7 +569,7 @@ def chunk_bwd_dh( states_in_fp32: bool = False, offsets: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor]: if head_first: diff --git a/fla/ops/common/chunk_h_split.py b/fla/ops/common/chunk_h_split.py index cc017fb6a6..6375328d3c 100644 --- a/fla/ops/common/chunk_h_split.py +++ b/fla/ops/common/chunk_h_split.py @@ -498,7 +498,7 @@ def chunk_fwd_h( offsets: Optional[torch.LongTensor] = None, split_offsets: Optional[torch.LongTensor] = None, split_indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64, split_size: int = 256, states_in_fp32: bool = True @@ -590,7 +590,7 @@ def chunk_bwd_dh( offsets: Optional[torch.Tensor] = None, split_offsets: Optional[torch.Tensor] = None, split_indices: Optional[torch.Tensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64, split_size: int = 256, states_in_fp32: bool = True diff --git a/fla/ops/common/chunk_o.py b/fla/ops/common/chunk_o.py index b1e99d1d28..e0912f8783 100644 --- a/fla/ops/common/chunk_o.py +++ b/fla/ops/common/chunk_o.py @@ -462,7 +462,7 @@ def chunk_fwd_o( scale: Optional[float] = None, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> torch.Tensor: if head_first: @@ -506,7 +506,7 @@ def chunk_bwd_dv( scale: float, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> torch.Tensor: if head_first: @@ -559,7 +559,7 @@ def chunk_bwd_dv_local( scale: float, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> torch.Tensor: if head_first: @@ -615,7 +615,7 @@ def chunk_bwd_dqkwg( indices: Optional[torch.LongTensor] = None, chunk_size: int = 64, scale: float = 1.0, - head_first: bool = True, + head_first: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if head_first: diff --git a/fla/ops/common/fused_recurrent.py b/fla/ops/common/fused_recurrent.py index 263de38d06..4cd0bd6fcf 100644 --- a/fla/ops/common/fused_recurrent.py +++ b/fla/ops/common/fused_recurrent.py @@ -332,7 +332,7 @@ def fused_recurrent_fwd( output_final_state: bool = False, reverse: bool = False, offsets: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ): if head_first: B, H, T, K, V = *k.shape, v.shape[-1] @@ -393,7 +393,7 @@ def fused_recurrent_bwd( initial_state: Optional[torch.Tensor] = None, reverse: bool = False, offsets: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ): if head_first: B, H, T, K, V = *k.shape, v.shape[-1] @@ -487,7 +487,7 @@ def forward( output_final_state: bool = False, reverse: bool = False, offsets: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ): o, ht = fused_recurrent_fwd( q=q, @@ -555,7 +555,7 @@ def fused_recurrent( output_final_state: bool = False, reverse: bool = False, cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ): if scale is None: scale = k.shape[-1] ** -0.5 diff --git a/fla/ops/delta_rule/chunk.py b/fla/ops/delta_rule/chunk.py index 650b63547c..929e23bafc 100644 --- a/fla/ops/delta_rule/chunk.py +++ b/fla/ops/delta_rule/chunk.py @@ -25,7 +25,7 @@ def chunk_delta_rule_fwd( output_final_state: bool, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ): T = q.shape[2] if head_first else q.shape[1] @@ -80,7 +80,7 @@ def chunk_delta_rule_bwd( dht: torch.Tensor, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ): T = q.shape[2] if head_first else q.shape[1] @@ -181,7 +181,7 @@ def forward( initial_state: torch.Tensor, output_final_state: bool, offsets: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, use_qk_l2norm_in_kernel: bool = True ): T = q.shape[2] if head_first else q.shape[1] diff --git a/fla/ops/delta_rule/fused_recurrent.py b/fla/ops/delta_rule/fused_recurrent.py index ef57cc4c2e..559071cda9 100644 --- a/fla/ops/delta_rule/fused_recurrent.py +++ b/fla/ops/delta_rule/fused_recurrent.py @@ -301,7 +301,7 @@ def fused_recurrent_delta_rule_fwd( initial_state: torch.Tensor, output_final_state: bool, offsets: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: if head_first: B, H, T, K, V = *k.shape, v.shape[-1] @@ -359,7 +359,7 @@ def fused_recurrent_delta_rule_bwd( scale: float, initial_state: torch.Tensor, offsets: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: if head_first: B, H, T, K, V = *k.shape, v.shape[-1] @@ -438,7 +438,7 @@ def forward( initial_state: torch.Tensor, output_final_state: bool, offsets: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, use_qk_l2norm_in_kernel: bool = False ): q_orig = q @@ -501,7 +501,7 @@ def fused_recurrent_delta_rule( initial_state: torch.Tensor = None, output_final_state: bool = False, cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, use_qk_l2norm_in_kernel: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" diff --git a/fla/ops/delta_rule/parallel.py b/fla/ops/delta_rule/parallel.py index 722f2dec76..765f82a403 100644 --- a/fla/ops/delta_rule/parallel.py +++ b/fla/ops/delta_rule/parallel.py @@ -305,18 +305,18 @@ def parallel_delta_rule( beta: torch.Tensor, scale: float = None, output_attentions: bool = False, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: q (torch.Tensor): - queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. beta (torch.Tensor): - betas of shape `[B, H, T]` if `head_first=True` else `[B, T, H]`. + betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. scale (Optional[int]): Scale factor for attention scores. If not provided, it will default to `1 / sqrt(K)`. Default: `None`. @@ -328,7 +328,7 @@ def parallel_delta_rule( Returns: o (torch.Tensor): - Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. attn (torch.Tensor): Attention scores of shape `[B, H, T, T]` if `output_attentions=True` else `None`. """ diff --git a/fla/ops/gated_delta_rule/chunk.py b/fla/ops/gated_delta_rule/chunk.py index abbb52a56f..36ba416a33 100644 --- a/fla/ops/gated_delta_rule/chunk.py +++ b/fla/ops/gated_delta_rule/chunk.py @@ -26,7 +26,7 @@ def chunk_gated_delta_rule_fwd( output_final_state: bool, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ): g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first) @@ -85,7 +85,7 @@ def chunk_gated_delta_rule_bwd( dht: torch.Tensor, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ): T = q.shape[2] if head_first else q.shape[1] @@ -193,7 +193,7 @@ def forward( initial_state: torch.Tensor, output_final_state: bool, offsets: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, use_qk_l2norm_in_kernel: bool = False ): chunk_size = 64 diff --git a/fla/ops/gated_delta_rule/wy_fast.py b/fla/ops/gated_delta_rule/wy_fast.py index f80b2251f3..ad43e6a431 100644 --- a/fla/ops/gated_delta_rule/wy_fast.py +++ b/fla/ops/gated_delta_rule/wy_fast.py @@ -327,7 +327,7 @@ def fwd_prepare_wy_repr( beta: torch.Tensor, offsets: Optional[torch.LongTensor], indices: Optional[torch.LongTensor], - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if head_first: diff --git a/fla/ops/generalized_delta_rule/dplr/chunk.py b/fla/ops/generalized_delta_rule/dplr/chunk.py index eac6af87a2..749b8b0de2 100644 --- a/fla/ops/generalized_delta_rule/dplr/chunk.py +++ b/fla/ops/generalized_delta_rule/dplr/chunk.py @@ -31,7 +31,7 @@ def chunk_dplr_fwd( output_final_state: bool, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ): T = q.shape[2] if head_first else q.shape[1] @@ -116,7 +116,7 @@ def forward( initial_state: torch.Tensor, output_final_state: bool, offsets: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ): chunk_size = 16 @@ -325,17 +325,17 @@ def chunk_dplr_delta_rule( r""" Args: q (torch.Tensor): - queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. a (torch.Tensor): - activations of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + activations of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. b (torch.Tensor): - betas of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + betas of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. gk (torch.Tensor): - gk of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. decay term in log space! + gk of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. decay term in log space! scale (Optional[int]): Scale factor for the RetNet attention scores. If not provided, it will default to `1 / sqrt(K)`. Default: `None`. @@ -354,7 +354,7 @@ def chunk_dplr_delta_rule( Returns: o (torch.Tensor): - Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. final_state (torch.Tensor): Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. """ @@ -364,13 +364,19 @@ def chunk_dplr_delta_rule( if cu_seqlens is not None: if q.shape[0] != 1: - raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." - f"Please flatten variable-length inputs before processing.") + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) if head_first: - raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError( + "Sequences with variable lengths are not supported for head-first mode" + ) if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: - raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " - f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) scale = k.shape[-1] ** -0.5 if scale is None else scale o, final_state = ChunkDPLRDeltaRuleFunction.apply( q, diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py index 8feac35b2f..a26a21531e 100644 --- a/fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py +++ b/fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py @@ -369,7 +369,7 @@ def chunk_dplr_bwd_dqk_intra( dgk_last: torch.Tensor, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, scale: float = 1.0, chunk_size: int = 64, ): diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py index 08518c2035..d181bc897b 100644 --- a/fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py +++ b/fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py @@ -282,7 +282,7 @@ def chunk_fwd_intra_dplr_fn( chunk_size: int, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, ): if head_first: B, H, T, K = k.shape diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py index ffed7c5c1c..18980e7165 100644 --- a/fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py +++ b/fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py @@ -131,7 +131,7 @@ def chunk_dplr_bwd_dhu( dv: torch.Tensor, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if head_first: diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py index b382d5905a..82282e1366 100644 --- a/fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py +++ b/fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py @@ -131,7 +131,7 @@ def chunk_dplr_fwd_h( output_final_state: bool = False, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor]: if head_first: diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py index f4a17bcfb2..159ed4c197 100644 --- a/fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py +++ b/fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py @@ -309,7 +309,7 @@ def chunk_dplr_bwd_dv( dh: torch.Tensor, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> torch.Tensor: if head_first: @@ -359,7 +359,7 @@ def chunk_dplr_bwd_o( indices: Optional[torch.LongTensor] = None, chunk_size: int = 64, scale: float = 1.0, - head_first: bool = True, + head_first: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if head_first: @@ -420,7 +420,7 @@ def chunk_dplr_bwd_dAu( scale: float, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> torch.Tensor: if head_first: diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py index 981901295b..cd30aedb51 100644 --- a/fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py +++ b/fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py @@ -106,7 +106,7 @@ def chunk_dplr_fwd_o( h: torch.Tensor, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> torch.Tensor: if head_first: diff --git a/fla/ops/generalized_delta_rule/dplr/fused_recurrent.py b/fla/ops/generalized_delta_rule/dplr/fused_recurrent.py index 17e7f3483a..5903b8fe6e 100644 --- a/fla/ops/generalized_delta_rule/dplr/fused_recurrent.py +++ b/fla/ops/generalized_delta_rule/dplr/fused_recurrent.py @@ -128,7 +128,7 @@ def fused_recurrent_dplr_delta_rule_fwd( output_final_state: bool = False, reverse: bool = False, offsets: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ): if head_first: B, H, T, K, V = *k.shape, v.shape[-1] @@ -264,13 +264,19 @@ def fused_recurrent_dplr_delta_rule( """ if cu_seqlens is not None: if q.shape[0] != 1: - raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." - f"Please flatten variable-length inputs before processing.") + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) if head_first: - raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError( + "Sequences with variable lengths are not supported for head-first mode" + ) if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: - raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " - f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) if scale is None: scale = q.shape[-1] ** -0.5 else: diff --git a/fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py b/fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py index 3ef5ac298d..c5e631a3e1 100644 --- a/fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py +++ b/fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py @@ -238,7 +238,7 @@ def fwd_prepare_wy_repr( A_ab: torch.Tensor, offsets: Optional[torch.LongTensor], indices: Optional[torch.LongTensor], - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if head_first: diff --git a/fla/ops/generalized_delta_rule/iplr/chunk.py b/fla/ops/generalized_delta_rule/iplr/chunk.py index 07f76533b1..ce46d0d7e7 100644 --- a/fla/ops/generalized_delta_rule/iplr/chunk.py +++ b/fla/ops/generalized_delta_rule/iplr/chunk.py @@ -219,7 +219,7 @@ def chunk_generalized_iplr_delta_rule_fwd_o( scale: Optional[float] = None, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> torch.Tensor: if head_first: @@ -269,7 +269,7 @@ def chunk_generalized_iplr_delta_rule_fwd_h( output_final_state: bool = False, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor]: if head_first: @@ -349,7 +349,7 @@ def chunk_generalized_iplr_delta_rule_fwd( output_final_state: bool, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ): T = q.shape[2] if head_first else q.shape[1] @@ -410,7 +410,7 @@ def forward( initial_state: torch.Tensor, output_final_state: bool, offsets: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ): chunk_size = 64 @@ -464,20 +464,20 @@ def chunk_iplr_delta_rule( initial_state: torch.Tensor = None, output_final_state: bool = False, cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ): r""" Args: q (torch.Tensor): - queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. a (torch.Tensor): - activations of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + activations of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. b (torch.Tensor): - betas of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + betas of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. scale (Optional[int]): Scale factor for the RetNet attention scores. If not provided, it will default to `1 / sqrt(K)`. Default: `None`. @@ -492,11 +492,11 @@ def chunk_iplr_delta_rule( consistent with the FlashAttention API. head_first (Optional[bool]): Whether the inputs are in the head-first format, which is not supported for variable-length inputs. - Default: `True`. + Default: `False`. Returns: o (torch.Tensor): - Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. final_state (torch.Tensor): Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. """ @@ -505,13 +505,19 @@ def chunk_iplr_delta_rule( if cu_seqlens is not None: if q.shape[0] != 1: - raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." - f"Please flatten variable-length inputs before processing.") + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) if head_first: - raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError( + "Sequences with variable lengths are not supported for head-first mode" + ) if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: - raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " - f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) scale = k.shape[-1] ** -0.5 if scale is None else scale o, final_state = ChunkGeneralizedIPLRDeltaRuleFunction.apply( q, diff --git a/fla/ops/generalized_delta_rule/iplr/fused_recurrent.py b/fla/ops/generalized_delta_rule/iplr/fused_recurrent.py index 8cea0c212b..f70a7f0771 100644 --- a/fla/ops/generalized_delta_rule/iplr/fused_recurrent.py +++ b/fla/ops/generalized_delta_rule/iplr/fused_recurrent.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright (c) 2024-2025, Songlin Yang, Yu Zhang -from typing import Tuple +from typing import Optional, Tuple import torch import triton @@ -406,7 +406,7 @@ def fused_recurrent_iplr_delta_rule( scale: float = None, initial_state: torch.Tensor = None, output_final_state: bool = False, - offsets: torch.Tensor = None, + cu_seqlens: Optional[torch.Tensor] = None, head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" @@ -430,22 +430,30 @@ def fused_recurrent_iplr_delta_rule( Initial state of shape `[B, H, K, V]`. Default: `None`. output_final_state (Optional[bool]): Whether to output the final state of shape `[B, H, K, V]`. Default: `False`. - offsets (Optional[torch.Tensor]): + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. """ - if offsets is not None: + if cu_seqlens is not None: if q.shape[0] != 1: - raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `offsets`." - f"Please flatten variable-length inputs before processing.") + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `offsets`." + f"Please flatten variable-length inputs before processing." + ) if head_first: - raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") - if initial_state is not None and initial_state.shape[0] != len(offsets) - 1: - raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " - f"i.e., {len(offsets) - 1} rather than {initial_state.shape[0]}.") + raise RuntimeError( + "Sequences with variable lengths are not supported for head-first mode" + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) if scale is None: scale = q.shape[-1] ** -0.5 else: assert scale > 0, "scale must be positive" o, final_state = FusedRecurrentIPLRDeltaRuleFunction.apply( - q, k, v, a, b, scale, initial_state, output_final_state, offsets, head_first) + q, k, v, a, b, scale, initial_state, output_final_state, cu_seqlens, head_first) return o, final_state diff --git a/fla/ops/generalized_delta_rule/iplr/wy_fast.py b/fla/ops/generalized_delta_rule/iplr/wy_fast.py index 9fdfa70915..fcdc7e4418 100644 --- a/fla/ops/generalized_delta_rule/iplr/wy_fast.py +++ b/fla/ops/generalized_delta_rule/iplr/wy_fast.py @@ -253,7 +253,7 @@ def fwd_prepare_wy_repr( k: torch.Tensor, offsets: Optional[torch.LongTensor], indices: Optional[torch.LongTensor], - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if head_first: diff --git a/fla/ops/gla/chunk.py b/fla/ops/gla/chunk.py index c3e40f492d..3648a6192f 100644 --- a/fla/ops/gla/chunk.py +++ b/fla/ops/gla/chunk.py @@ -850,7 +850,7 @@ def chunk_gla_fwd_intra_gk( scale: float, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ): if head_first: @@ -953,7 +953,7 @@ def chunk_gla_fwd_o_gk( scale: float, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ): if head_first: @@ -991,7 +991,7 @@ def chunk_gla_bwd_dA( scale: float, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ): if head_first: @@ -1029,7 +1029,7 @@ def chunk_gla_bwd_dv( dh: torch.Tensor, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ): if head_first: @@ -1067,7 +1067,7 @@ def chunk_gla_bwd_dqk_intra( dA: torch.Tensor, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ): if head_first: @@ -1118,7 +1118,7 @@ def chunk_gla_bwd_dqkg( scale: float, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ): if head_first: @@ -1170,7 +1170,7 @@ def chunk_gla_fwd( output_final_state: bool, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: T = q.shape[2] if head_first else q.shape[1] @@ -1233,7 +1233,7 @@ def chunk_gla_bwd( dht: torch.Tensor, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ): T = q.shape[2] if head_first else q.shape[1] @@ -1409,18 +1409,18 @@ def chunk_gla( initial_state: torch.Tensor = None, output_final_state: bool = False, cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: q (torch.Tensor): - queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. g (torch.Tensor): - Forget gates of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` applied to keys. + Forget gates of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` applied to keys. scale (Optional[int]): Scale factor for the attention scores. If not provided, it will default to `1 / sqrt(K)`. Default: `None`. @@ -1435,11 +1435,11 @@ def chunk_gla( consistent with the FlashAttention API. head_first (Optional[bool]): Whether the inputs are in the head-first format, which is not supported for variable-length inputs. - Default: `True`. + Default: `False`. Returns: o (torch.Tensor): - Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. final_state (torch.Tensor): Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. @@ -1473,13 +1473,19 @@ def chunk_gla( """ if cu_seqlens is not None: if q.shape[0] != 1: - raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." - f"Please flatten variable-length inputs before processing.") + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) if head_first: - raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError( + "Sequences with variable lengths are not supported for head-first mode" + ) if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: - raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " - f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) if scale is None: scale = q.shape[-1] ** -0.5 o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, cu_seqlens, head_first) diff --git a/fla/ops/gla/fused_chunk.py b/fla/ops/gla/fused_chunk.py index 318be21402..584fab62a0 100644 --- a/fla/ops/gla/fused_chunk.py +++ b/fla/ops/gla/fused_chunk.py @@ -616,7 +616,7 @@ def fused_chunk_gla( scale: int = -1, initial_state: torch.Tensor = None, output_final_state: bool = False, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: if scale == -1: scale = q.shape[-1] ** -0.5 diff --git a/fla/ops/gla/fused_recurrent.py b/fla/ops/gla/fused_recurrent.py index d211541d78..b5e31db124 100644 --- a/fla/ops/gla/fused_recurrent.py +++ b/fla/ops/gla/fused_recurrent.py @@ -19,20 +19,20 @@ def fused_recurrent_gla( output_final_state: bool = False, reverse: bool = False, cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: q (torch.Tensor): - queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. gk (torch.Tensor): - Forget gates of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` applied to keys. + Forget gates of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` applied to keys. gv (torch.Tensor): - Forget gates of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` applied to values. + Forget gates of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]` applied to values. scale (Optional[int]): Scale factor for the attention scores. If not provided, it will default to `1 / sqrt(K)`. Default: `None`. @@ -49,11 +49,11 @@ def fused_recurrent_gla( consistent with the FlashAttention API. head_first (Optional[bool]): Whether the inputs are in the head-first format, which is not supported for variable-length inputs. - Default: `True`. + Default: `False`. Returns: o (torch.Tensor): - Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. final_state (torch.Tensor): Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. @@ -87,13 +87,19 @@ def fused_recurrent_gla( """ if cu_seqlens is not None: if q.shape[0] != 1: - raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." - f"Please flatten variable-length inputs before processing.") + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) if head_first: - raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError( + "Sequences with variable lengths are not supported for head-first mode" + ) if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: - raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " - f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) if scale is None: scale = k.shape[-1] ** -0.5 o, final_state = fused_recurrent( diff --git a/fla/ops/gsa/chunk.py b/fla/ops/gsa/chunk.py index 950b76cd9c..2f961b9fe8 100644 --- a/fla/ops/gsa/chunk.py +++ b/fla/ops/gsa/chunk.py @@ -616,7 +616,7 @@ def chunk_gsa_fwd_v( output_final_state: bool = False, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: _, A, h, ht, o = chunk_gla_fwd( @@ -646,7 +646,7 @@ def chunk_gsa_fwd_k( scale: float = 1., offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: if head_first: @@ -735,7 +735,7 @@ def chunk_gsa_bwd_v( scale: float = 1., offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ): dq, dk, dv, dg, dh0 = chunk_gla_bwd( @@ -772,7 +772,7 @@ def chunk_gsa_bwd_k( scale: float = 1., offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ): if head_first: @@ -924,7 +924,7 @@ def chunk_gsa_fwd( scale: float = 1., offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: hk0, hv0 = None, None @@ -980,7 +980,7 @@ def chunk_gsa_bwd( dht: Tuple[torch.Tensor, torch.Tensor], offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ): hk0, hv0 = None, None @@ -1056,7 +1056,7 @@ def forward( output_final_state: bool, checkpoint_level: int, offsets: Optional[torch.LongTensor], - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: T = q.shape[2] if head_first else q.shape[1] chunk_size = min(64, max(16, triton.next_power_of_2(T))) @@ -1155,12 +1155,12 @@ def chunk_gsa( r""" Args: q (torch.Tensor): - queries of shape `[B, HQ, T, K]` if `head_first=True` else `[B, T, HQ, K]`. + queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, H, T, K]`. k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. GQA is performed if `H` is not equal to `HQ`. v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. s (torch.Tensor): slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`. g (torch.Tensor): @@ -1187,11 +1187,11 @@ def chunk_gsa( consistent with the FlashAttention API. head_first (Optional[bool]): Whether the inputs are in the head-first format, which is not supported for variable-length inputs. - Default: `True`. + Default: `False`. Returns: o (torch.Tensor): - Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. final_state (Tuple[torch.Tensor]): Final state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` if `output_final_state=True`. `None` otherwise. @@ -1228,13 +1228,19 @@ def chunk_gsa( """ if cu_seqlens is not None: if q.shape[0] != 1: - raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." - f"Please flatten variable-length inputs before processing.") + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) if head_first: - raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError( + "Sequences with variable lengths are not supported for head-first mode" + ) if initial_state is not None and initial_state[0].shape[0] != len(cu_seqlens) - 1: - raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " - f"i.e., {len(cu_seqlens) - 1} rather than {initial_state[0].shape[0]}.") + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state[0].shape[0]}." + ) assert checkpoint_level in [0, 1, 2] if g is None: # TODO: this 3 steps took huge amount of time, ought to be optimized diff --git a/fla/ops/gsa/fused_recurrent.py b/fla/ops/gsa/fused_recurrent.py index a4eacfda68..5f436f03e4 100644 --- a/fla/ops/gsa/fused_recurrent.py +++ b/fla/ops/gsa/fused_recurrent.py @@ -93,7 +93,7 @@ def fused_recurrent_gsa_inference( initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, output_final_state: bool = False, scale: float = 1., - head_first: bool = True + head_first: bool = False ) -> torch.Tensor: if head_first: B, H, T, K, V, M = *k.shape, v.shape[-1], s.shape[-1] @@ -150,7 +150,7 @@ def fused_recurrent_gsa_fwd( scale: float = 1., reverse: bool = False, offsets: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]: if head_first: B, H, T, K, V, M = *k.shape, v.shape[-1], s.shape[-1] @@ -250,7 +250,7 @@ def fused_recurrent_gsa_bwd( scale: float = 1., reverse: bool = False, offsets: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor]: if head_first: B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] @@ -384,7 +384,7 @@ def forward( output_final_state: bool = False, reverse: bool = False, offsets: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]: T = q.shape[2] if head_first else q.shape[1] if T == 1 and not q.requires_grad: @@ -466,16 +466,16 @@ def fused_recurrent_gsa( output_final_state: Optional[bool] = False, reverse: Optional[bool] = False, cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: q (torch.Tensor): - queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. s (torch.Tensor): slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`. g (torch.Tensor): @@ -497,11 +497,11 @@ def fused_recurrent_gsa( consistent with the FlashAttention API. head_first (Optional[bool]): Whether the inputs are in the head-first format, which is not supported for variable-length inputs. - Default: `True`. + Default: `False`. Returns: o (torch.Tensor): - Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. final_state (Tuple[torch.Tensor]): Final state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]`. @@ -518,32 +518,40 @@ def fused_recurrent_gsa( >>> s = torch.randn(B, T, H, M, device='cuda') >>> g = F.logsigmoid(torch.randn(B, T, H, M, device='cuda')) >>> h0 = (torch.randn(B, H, K, M, device='cuda'), torch.randn(B, H, M, V, device='cuda')) - >>> o, (hk, hv) = fused_recurrent_gsa(q, k, v, s, g, - initial_state=h0, - output_final_state=True, - head_first=False) + >>> o, (hk, hv) = fused_recurrent_gsa( + q, k, v, s, g, + initial_state=h0, + output_final_state=True, + ) # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required >>> q, k, v, s, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, s, g)) # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) - >>> o_var, (hk_var, hv_var) = fused_recurrent_gsa(q, k, v, s, g, - initial_state=h0, - output_final_state=True, - cu_seqlens=cu_seqlens, - head_first=False) + >>> o_var, (hk_var, hv_var) = fused_recurrent_gsa( + q, k, v, s, g, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens, + ) >>> assert o.allclose(o_var.view(o.shape)) >>> assert hk.allclose(hk_var) >>> assert hv.allclose(hv_var) """ if cu_seqlens is not None: if q.shape[0] != 1: - raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." - f"Please flatten variable-length inputs before processing.") + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) if head_first: - raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError( + "Sequences with variable lengths are not supported for head-first mode" + ) if initial_state is not None and initial_state[0].shape[0] != len(cu_seqlens) - 1: - raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " - f"i.e., {len(cu_seqlens) - 1} rather than {initial_state[0].shape[0]}.") + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state[0].shape[0]}." + ) if scale is None: scale = k.shape[-1] ** -0.5 if initial_state is None: diff --git a/fla/ops/lightning_attn/chunk.py b/fla/ops/lightning_attn/chunk.py index d56d913acb..a7fbea144d 100644 --- a/fla/ops/lightning_attn/chunk.py +++ b/fla/ops/lightning_attn/chunk.py @@ -19,16 +19,16 @@ def chunk_lightning_attn( initial_state: Optional[torch.Tensor] = None, output_final_state: bool = False, cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: q (torch.Tensor): - queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. layer_idx (int): The index of the current layer. num_layers (int): @@ -47,11 +47,11 @@ def chunk_lightning_attn( consistent with the FlashAttention API. head_first (Optional[bool]): Whether the inputs are in the head-first format, which is not supported for variable-length inputs. - Default: `True`. + Default: `False`. Returns: o (torch.Tensor): - Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. final_state (torch.Tensor): Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. """ diff --git a/fla/ops/lightning_attn/fused_recurrent.py b/fla/ops/lightning_attn/fused_recurrent.py index 6548188b7b..8db06e33ee 100644 --- a/fla/ops/lightning_attn/fused_recurrent.py +++ b/fla/ops/lightning_attn/fused_recurrent.py @@ -19,16 +19,16 @@ def fused_recurrent_lightning_attn( output_final_state: bool = False, reverse: bool = False, cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: q (torch.Tensor): - queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. layer_idx (int): The index of the current layer. num_layers (int): @@ -47,11 +47,11 @@ def fused_recurrent_lightning_attn( consistent with the FlashAttention API. head_first (Optional[bool]): Whether the inputs are in the head-first format, which is not supported for variable-length inputs. - Default: `True`. + Default: `False`. Returns: o (torch.Tensor): - Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. final_state (torch.Tensor): Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. """ diff --git a/fla/ops/linear_attn/chunk.py b/fla/ops/linear_attn/chunk.py index 8283e70792..021aee68a1 100644 --- a/fla/ops/linear_attn/chunk.py +++ b/fla/ops/linear_attn/chunk.py @@ -18,16 +18,16 @@ def chunk_linear_attn( initial_state: Optional[torch.Tensor] = None, output_final_state: bool = False, normalize: bool = True, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: q (torch.Tensor): - queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]` scale (Optional[int]): Scale factor for the linear attention scores. If not provided, it will default to `1 / sqrt(K)`. Default: `None`. @@ -38,11 +38,11 @@ def chunk_linear_attn( normalize (bool): Whether to normalize the output. Default: `True`. head_first (Optional[bool]): - Whether the inputs are in the head-first format. Default: `True`. + Whether the inputs are in the head-first format. Default: `False`. Returns: o (torch.Tensor): - Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]` final_state (torch.Tensor): Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None` """ diff --git a/fla/ops/linear_attn/fused_chunk.py b/fla/ops/linear_attn/fused_chunk.py index bfcc1212a5..9e963c3c58 100644 --- a/fla/ops/linear_attn/fused_chunk.py +++ b/fla/ops/linear_attn/fused_chunk.py @@ -278,16 +278,16 @@ def fused_chunk_linear_attn( initial_state: torch.Tensor = None, output_final_state: bool = False, normalize: bool = True, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: q (torch.Tensor): - queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]` scale (Optional[int]): Scale factor for linear attention scores. If not provided, it will default to `1 / sqrt(K)`. Default: `None`. @@ -298,11 +298,11 @@ def fused_chunk_linear_attn( normalize (bool): Whether to normalize the output. Default: `True`. head_first (Optional[bool]): - Whether the inputs are in the head-first format. Default: `True`. + Whether the inputs are in the head-first format. Default: `False`. Returns: o (torch.Tensor): - Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]` final_state (torch.Tensor): Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None` """ diff --git a/fla/ops/linear_attn/fused_recurrent.py b/fla/ops/linear_attn/fused_recurrent.py index b50b8c7bfb..b4569c2ace 100644 --- a/fla/ops/linear_attn/fused_recurrent.py +++ b/fla/ops/linear_attn/fused_recurrent.py @@ -237,7 +237,7 @@ def fused_recurrent_linear_attn( initial_state: torch.Tensor = None, output_final_state: bool = False, normalize: bool = False, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: if scale is None: scale = q.shape[-1] ** -0.5 diff --git a/fla/ops/nsa/naive.py b/fla/ops/nsa/naive.py index 79433b80ef..4949e8aef1 100644 --- a/fla/ops/nsa/naive.py +++ b/fla/ops/nsa/naive.py @@ -14,18 +14,18 @@ def naive_nsa( indices: torch.LongTensor, block_size: int = 64, scale: Optional[float] = None, - head_first: bool = False, - cu_seqlens: Optional[torch.LongTensor] = None + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False ) -> torch.Tensor: r""" Args: q (torch.Tensor): - queries of shape `[B, HQ, T, K]` if `head_first=True` else `[B, T, HQ, K]`. + queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, H, T, K]`. k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. indices (torch.LongTensor): Block indices of shape `[B, T, H, S]` if `head_first=True` else `[B, T, H, S]`. `S` is the number of selected blocks for each query token, which is set to 16 in the paper. @@ -34,21 +34,23 @@ def naive_nsa( scale (Optional[int]): Scale factor for attention scores. If not provided, it will default to `1 / sqrt(K)`. Default: `None`. - head_first (Optional[bool]): - Whether the inputs are in the head-first format. Default: `False`. cu_seqlens (torch.LongTensor): Cumulative sequence lengths of shape `[N+1]` used for variable-length training, consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. Returns: o (torch.Tensor): - Outputs of shape `[B, HQ, T, V]` if `head_first=True` else `[B, T, HQ, V]`. + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. """ if scale is None: scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: if head_first: - raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError( + "Sequences with variable lengths are not supported for head-first mode" + ) if head_first: q, k, v, indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v, indices)) diff --git a/fla/ops/rebased/parallel.py b/fla/ops/rebased/parallel.py index 54de25ffb8..a1660c92a2 100644 --- a/fla/ops/rebased/parallel.py +++ b/fla/ops/rebased/parallel.py @@ -447,7 +447,7 @@ def parallel_rebased( use_scale: bool = True, use_normalize: bool = True, return_both: bool = False, - head_first: bool = True + head_first: bool = False ): assert q.shape[-1] <= 128, "only support feature dim up to 128" if use_scale: diff --git a/fla/ops/retention/chunk.py b/fla/ops/retention/chunk.py index cca1bd290e..df5fc63ac4 100644 --- a/fla/ops/retention/chunk.py +++ b/fla/ops/retention/chunk.py @@ -17,16 +17,16 @@ def chunk_retention( initial_state: Optional[torch.Tensor] = None, output_final_state: bool = False, cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: q (torch.Tensor): - queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. scale (Optional[int]): Scale factor for the attention scores. If not provided, it will default to `1 / sqrt(K)`. Default: `None`. @@ -41,11 +41,11 @@ def chunk_retention( consistent with the FlashAttention API. head_first (Optional[bool]): Whether the inputs are in the head-first format, which is not supported for variable-length inputs. - Default: `True`. + Default: `False`. Returns: o (torch.Tensor): - Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. final_state (torch.Tensor): Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. diff --git a/fla/ops/retention/fused_chunk.py b/fla/ops/retention/fused_chunk.py index ff068f647d..6089634b1b 100644 --- a/fla/ops/retention/fused_chunk.py +++ b/fla/ops/retention/fused_chunk.py @@ -326,16 +326,16 @@ def fused_chunk_retention( scale: Optional[float] = None, initial_state: Optional[torch.Tensor] = None, output_final_state: bool = False, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: q (torch.Tensor): - queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]` scale (Optional[int]): Scale factor for the attention scores. If not provided, it will default to `1 / sqrt(K)`. Default: `None`. @@ -344,12 +344,11 @@ def fused_chunk_retention( output_final_state (Optional[bool]): Whether to output the final state of shape `[B, H, K, V]`. Default: `False`. head_first (Optional[bool]): - Whether the inputs are in the head-first format. - Default: `True`. + Whether the inputs are in the head-first format. Default: `False`. Returns: o (torch.Tensor): - Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. final_state (torch.Tensor): Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`. """ diff --git a/fla/ops/retention/fused_recurrent.py b/fla/ops/retention/fused_recurrent.py index b84eb83e73..5af37a2b4b 100644 --- a/fla/ops/retention/fused_recurrent.py +++ b/fla/ops/retention/fused_recurrent.py @@ -17,13 +17,10 @@ def fused_recurrent_retention( output_final_state: bool = False, reverse: bool = False, cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: - if head_first: - n_heads = q.shape[1] - else: - n_heads = q.shape[2] - s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(n_heads), dtype=torch.float))).log() + H = q.shape[1] if head_first else q.shape[2] + s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(H), dtype=torch.float))).log() if head_first: g = s[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() else: diff --git a/fla/ops/retention/parallel.py b/fla/ops/retention/parallel.py index 8186fc78d4..7319005a4d 100644 --- a/fla/ops/retention/parallel.py +++ b/fla/ops/retention/parallel.py @@ -15,16 +15,16 @@ def parallel_retention( scale: Optional[float] = None, output_attentions: bool = False, cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: q (torch.Tensor): - queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]` scale (Optional[int]): Scale factor for attention scores. If not provided, it will default to `1 / sqrt(K)`. Default: `None`. @@ -34,11 +34,11 @@ def parallel_retention( Cumulative sequence lengths of shape `[N+1]` used for variable-length training, consistent with the FlashAttention API. head_first (Optional[bool]): - Whether the inputs are in the head-first format. Default: `True`. + Whether the inputs are in the head-first format. Default: `False`. Returns: o (torch.Tensor): - Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. attn (torch.Tensor): Attention scores of shape `[B, H, T, T]` if `output_attentions=True` else `None` """ diff --git a/fla/ops/rwkv6/chunk.py b/fla/ops/rwkv6/chunk.py index b495dbb21f..296477a00c 100644 --- a/fla/ops/rwkv6/chunk.py +++ b/fla/ops/rwkv6/chunk.py @@ -78,7 +78,7 @@ def chunk_rwkv6_fwd_cumsum( chunk_size: int, offsets: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, - head_first: bool = True + head_first: bool = False ) -> torch.Tensor: if head_first: B, H, T, S = g.shape @@ -852,7 +852,7 @@ def chunk_rwkv6_fwd_intra( scale: float, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ): if head_first: @@ -963,7 +963,7 @@ def chunk_rwkv6_bwd_dh( scale: float, offsets: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64, states_in_fp32: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -1021,7 +1021,7 @@ def chunk_rwkv6_bwd_dqk_intra( dA: torch.Tensor, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ): if head_first: @@ -1077,7 +1077,7 @@ def chunk_rwkv6_bwd_dqkgu( scale: float, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ): if head_first: @@ -1134,7 +1134,7 @@ def chunk_rwkv6_fwd( output_final_state: bool, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: gi, ge = chunk_rwkv6_fwd_cumsum(g, chunk_size=chunk_size, offsets=offsets, indices=indices, head_first=head_first) @@ -1194,7 +1194,7 @@ def chunk_rwkv6_bwd( dht: torch.Tensor, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ): gi, ge = chunk_rwkv6_fwd_cumsum(g, chunk_size=chunk_size, offsets=offsets, indices=indices, head_first=head_first) @@ -1374,18 +1374,18 @@ def chunk_rwkv6( initial_state: torch.Tensor = None, output_final_state: bool = False, cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: q (torch.Tensor): - queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. g (torch.Tensor): - Forget gates of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` applied to keys. + Forget gates of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` applied to keys. u (torch.Tensor): bonus representations of shape `[H]`. scale (Optional[int]): @@ -1402,11 +1402,11 @@ def chunk_rwkv6( consistent with the FlashAttention API. head_first (Optional[bool]): Whether the inputs are in the head-first format, which is not supported for variable-length inputs. - Default: `True`. + Default: `False`. Returns: o (torch.Tensor): - Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. final_state (Optional[torch.Tensor]): Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. @@ -1441,13 +1441,19 @@ def chunk_rwkv6( """ if cu_seqlens is not None: if q.shape[0] != 1: - raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." - f"Please flatten variable-length inputs before processing.") + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) if head_first: - raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError( + "Sequences with variable lengths are not supported for head-first mode" + ) if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: - raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " - f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) if scale is None: scale = q.shape[-1] ** -0.5 o, final_state = ChunkRWKV6Function.apply( diff --git a/fla/ops/rwkv6/fused_recurrent.py b/fla/ops/rwkv6/fused_recurrent.py index 2ff9d2c3cd..e1c96f6e05 100644 --- a/fla/ops/rwkv6/fused_recurrent.py +++ b/fla/ops/rwkv6/fused_recurrent.py @@ -394,7 +394,7 @@ def fused_recurrent_rwkv6_fwd( output_final_state: bool = False, reverse: bool = False, offsets: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ): if head_first: B, H, T, K, V = *k.shape, v.shape[-1] @@ -445,7 +445,7 @@ def fused_recurrent_rwkv6_bwd( initial_state: Optional[torch.Tensor] = None, reverse: bool = False, offsets: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ): if head_first: B, H, T, K, V = *k.shape, v.shape[-1] @@ -558,7 +558,7 @@ def forward( output_final_state: bool = False, reverse: bool = False, offsets: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ): o, ht = fused_recurrent_rwkv6_fwd( q=q, @@ -614,19 +614,19 @@ def fused_recurrent_rwkv6( output_final_state: bool = False, reverse: bool = False, cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: r (torch.Tensor): - reception of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + reception of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. Alias: q, query in linear attention. k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. w (torch.Tensor): - data-dependent decays of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` in log space! Alias: g. + data-dependent decays of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` in log space! Alias: g. u (torch.Tensor): bonus of shape `[H, K]` scale (Optional[int]): @@ -645,11 +645,11 @@ def fused_recurrent_rwkv6( consistent with the FlashAttention API. head_first (Optional[bool]): Whether the inputs are in the head-first format, which is not supported for variable-length inputs. - Default: `True`. + Default: `False`. Returns: o (torch.Tensor): - Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. final_state (Optional[torch.Tensor]): Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. @@ -684,13 +684,19 @@ def fused_recurrent_rwkv6( """ if cu_seqlens is not None: if r.shape[0] != 1: - raise ValueError(f"The batch size is expected to be 1 rather than {r.shape[0]} when using `cu_seqlens`." - f"Please flatten variable-length inputs before processing.") + raise ValueError( + f"The batch size is expected to be 1 rather than {r.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) if head_first: - raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError( + "Sequences with variable lengths are not supported for head-first mode" + ) if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: - raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " - f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) if scale is None: scale = k.shape[-1] ** -0.5 o, final_state = FusedRecurrentRWKV6Function.apply( diff --git a/fla/ops/rwkv7/chunk.py b/fla/ops/rwkv7/chunk.py index 956c458974..bd79cdc871 100644 --- a/fla/ops/rwkv7/chunk.py +++ b/fla/ops/rwkv7/chunk.py @@ -24,17 +24,17 @@ def chunk_rwkv7( """ Args: r (torch.Tensor): - r of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + r of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. w (torch.Tensor): - log decay of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + log decay of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. k (torch.Tensor): - k of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + k of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. v (torch.Tensor): - v of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + v of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. a (torch.Tensor): - a of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + a of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. b (torch.Tensor): - b of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + b of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. scale (float): scale of the attention. initial_state (Optional[torch.Tensor]): @@ -48,6 +48,7 @@ def chunk_rwkv7( consistent with the FlashAttention API. head_first (bool): whether to use head first. Recommended to be False to avoid extra transposes. + Default: `False`. """ return chunk_dplr_delta_rule( q=r, diff --git a/fla/ops/rwkv7/fused_recurrent.py b/fla/ops/rwkv7/fused_recurrent.py index 0ce2d15aec..2c1e49999e 100644 --- a/fla/ops/rwkv7/fused_recurrent.py +++ b/fla/ops/rwkv7/fused_recurrent.py @@ -24,17 +24,17 @@ def fused_recurrent_rwkv7( """ Args: r (torch.Tensor): - r of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + r of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. w (torch.Tensor): - log decay of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + log decay of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. k (torch.Tensor): - k of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + k of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. v (torch.Tensor): - v of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + v of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. a (torch.Tensor): - a of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + a of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. b (torch.Tensor): - b of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + b of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. scale (float): scale of the attention. initial_state (torch.Tensor): @@ -46,6 +46,7 @@ def fused_recurrent_rwkv7( consistent with the FlashAttention API. head_first (bool): whether to use head first. Recommended to be False to avoid extra transposes. + Default: `False`. """ return fused_recurrent_dplr_delta_rule( q=r, diff --git a/fla/ops/simple_gla/chunk.py b/fla/ops/simple_gla/chunk.py index b4c13a54aa..322183f563 100644 --- a/fla/ops/simple_gla/chunk.py +++ b/fla/ops/simple_gla/chunk.py @@ -22,7 +22,7 @@ def chunk_simple_gla_fwd( output_final_state: bool, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first) if g is not None else None @@ -65,7 +65,7 @@ def chunk_simple_gla_bwd( scale: float, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # (SY 09/22) states_in_fp32 seems not affecting the error of dg but for safety, set to True @@ -215,18 +215,18 @@ def chunk_simple_gla( initial_state: Optional[torch.Tensor] = None, output_final_state: bool = False, cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: q (torch.Tensor): - queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. g (torch.Tensor): - Forget gates of shape `[B, H, T]` if `head_first=True` else `[B, T, H]`. + Forget gates of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. Compared to GLA, the gating is head-wise instead of elementwise. scale (Optional[int]): Scale factor for the attention scores. @@ -242,11 +242,11 @@ def chunk_simple_gla( consistent with the FlashAttention API. head_first (Optional[bool]): Whether the inputs are in the head-first format, which is not supported for variable-length inputs. - Default: `True`. + Default: `False`. Returns: o (torch.Tensor): - Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. final_state (torch.Tensor): Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. @@ -261,31 +261,39 @@ def chunk_simple_gla( >>> k = torch.randn(B, T, H, K, device='cuda') >>> v = torch.randn(B, T, H, V, device='cuda') >>> g = F.logsigmoid(torch.randn(B, T, H, device='cuda')) - >>> o, ht = chunk_simple_gla(q, k, v, g, - initial_state=None, - output_final_state=True, - head_first=False) + >>> o, ht = chunk_simple_gla( + q, k, v, g, + initial_state=None, + output_final_state=True, + ) # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required >>> q, k, v, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g)) # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) - >>> o_var, ht_var = chunk_simple_gla(q, k, v, g, - initial_state=None, - output_final_state=True, - cu_seqlens=cu_seqlens, - head_first=False) + >>> o_var, ht_var = chunk_simple_gla( + q, k, v, g, + initial_state=None, + output_final_state=True, + cu_seqlens=cu_seqlens, + ) >>> assert o.allclose(o_var.view(o.shape)) >>> assert ht.allclose(ht_var) """ if cu_seqlens is not None: if q.shape[0] != 1: - raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." - f"Please flatten variable-length inputs before processing.") + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) if head_first: - raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError( + "Sequences with variable lengths are not supported for head-first mode" + ) if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: - raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " - f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) if scale is None: scale = k.shape[-1] ** -0.5 o, final_state = ChunkSimpleGLAFunction.apply( diff --git a/fla/ops/simple_gla/fused_recurrent.py b/fla/ops/simple_gla/fused_recurrent.py index 7012497d85..bbbf9fb16f 100644 --- a/fla/ops/simple_gla/fused_recurrent.py +++ b/fla/ops/simple_gla/fused_recurrent.py @@ -18,18 +18,18 @@ def fused_recurrent_simple_gla( output_final_state: bool = False, reverse: bool = False, cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: q (torch.Tensor): - queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. g (torch.Tensor): - Forget gates of shape `[B, H, T]` if `head_first=True` else `[B, T, H]`. + Forget gates of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. Compared to GLA, the gating is head-wise instead of elementwise. scale (Optional[int]): Scale factor for the attention scores. @@ -47,11 +47,11 @@ def fused_recurrent_simple_gla( consistent with the FlashAttention API. head_first (Optional[bool]): Whether the inputs are in the head-first format, which is not supported for variable-length inputs. - Default: `True`. + Default: `False`. Returns: o (torch.Tensor): - Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. final_state (torch.Tensor): Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. @@ -67,31 +67,39 @@ def fused_recurrent_simple_gla( >>> v = torch.randn(B, T, H, V, device='cuda') >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda')) >>> h0 = torch.randn(B, H, K, V, device='cuda') - >>> o, ht = fused_recurrent_simple_gla(q, k, v, g, - initial_state=h0, - output_final_state=True, - head_first=False) + >>> o, ht = fused_recurrent_simple_gla( + q, k, v, g, + initial_state=h0, + output_final_state=True, + ) # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g)) # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) - >>> o_var, ht_var = fused_recurrent_simple_gla(q, k, v, g, - initial_state=h0, - output_final_state=True, - cu_seqlens=cu_seqlens, - head_first=False) + >>> o_var, ht_var = fused_recurrent_simple_gla( + q, k, v, g, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens, + ) >>> assert o.allclose(o_var.view(o.shape)) >>> assert ht.allclose(ht_var) """ if cu_seqlens is not None: if q.shape[0] != 1: - raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." - f"Please flatten variable-length inputs before processing.") + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) if head_first: - raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError( + "Sequences with variable lengths are not supported for head-first mode" + ) if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: - raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " - f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) if scale is None: scale = k.shape[-1] ** -0.5 o, final_state = fused_recurrent( diff --git a/fla/ops/simple_gla/parallel.py b/fla/ops/simple_gla/parallel.py index d0ad1c8c33..48633bd8d2 100644 --- a/fla/ops/simple_gla/parallel.py +++ b/fla/ops/simple_gla/parallel.py @@ -484,7 +484,7 @@ def parallel_simple_gla_fwd( scale: float, output_attentions: bool = False, chunk_size: int = 128, - head_first: bool = True, + head_first: bool = False, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, ): @@ -552,7 +552,7 @@ def parallel_simple_gla_bwd( do: torch.Tensor, scale: float, chunk_size: int = 128, - head_first: bool = True, + head_first: bool = False, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, ): @@ -679,18 +679,18 @@ def parallel_simple_gla( scale: Optional[float] = None, output_attentions: bool = False, cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: q (torch.Tensor): - queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]` g (torch.Tensor): - Forget gates of shape `[B, H, T]` if `head_first=True` else `[B, T, H]`. + Forget gates of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. Compared to GLA, the gating is head-wise instead of elementwise. scale (Optional[int]): Scale factor for attention scores. @@ -698,14 +698,14 @@ def parallel_simple_gla( output_attentions (bool): Whether to output the materialized attention scores of shape [B, H, T, T]. Default: `False`. head_first (Optional[bool]): - Whether the inputs are in the head-first format. Default: `True`. + Whether the inputs are in the head-first format. Default: `False`. cu_seqlens (torch.LongTensor): Cumulative sequence lengths of shape `[N+1]` used for variable-length training, consistent with the FlashAttention API. Returns: o (torch.Tensor): - Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. attn (torch.Tensor): Attention scores of shape `[B, H, T, T]` if `output_attentions=True` else `None` """ diff --git a/fla/ops/titans/naive.py b/fla/ops/titans/naive.py index 2a1bd4b0a8..e93c827342 100644 --- a/fla/ops/titans/naive.py +++ b/fla/ops/titans/naive.py @@ -313,7 +313,7 @@ def chunk_titans_linear_ref( chunk_size: int = 16, # chunk size initial_state: torch.Tensor = None, output_final_state: bool = False, - head_first: bool = True, + head_first: bool = False, use_chunk: bool = True, ): assert q.dtype == k.dtype == v.dtype diff --git a/fla/ops/ttt/chunk.py b/fla/ops/ttt/chunk.py index 6342364268..34ce2e4a92 100755 --- a/fla/ops/ttt/chunk.py +++ b/fla/ops/ttt/chunk.py @@ -726,7 +726,7 @@ def chunk_ttt_linear_fwd_h( output_final_state: bool = False, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 16, ) -> Tuple[torch.Tensor, torch.Tensor]: if head_first: @@ -798,7 +798,7 @@ def chunk_ttt_linear_fwd_o( scale: Optional[float] = None, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 64 ) -> torch.Tensor: if head_first: @@ -853,7 +853,7 @@ def chunk_ttt_linear_bwd_h( initial_state_bias: Optional[torch.Tensor] = None, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 16, ) -> Tuple[torch.Tensor, torch.Tensor]: if head_first: @@ -923,7 +923,7 @@ def chunk_ttt_linear_bwd_dv_local( scale: float, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 16 ) -> torch.Tensor: if head_first: @@ -979,7 +979,7 @@ def chunk_ttt_linear_bwd_norm( scale: float, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 16 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # torch implementation of `dkh, dw, db, dk, dv` for LN^2 @@ -1075,7 +1075,7 @@ def chunk_ttt_linear_bwd_norm_ref( eps: float, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 16 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # torch implementation of `dkh, dw, db, dk, dv` for LN^2 @@ -1175,7 +1175,7 @@ def chunk_ttt_linear_bwd_dqke( scale: float, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, chunk_size: int = 16, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -1238,7 +1238,7 @@ def chunk_ttt_linear_fwd( output_final_state: bool, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, BT: int = 16 ): h, hb, v_new, final_state, final_state_bias = chunk_ttt_linear_fwd_h( @@ -1289,7 +1289,7 @@ def chunk_ttt_linear_bwd( initial_state_bias: torch.Tensor = None, offsets: Optional[torch.LongTensor] = None, indices: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ): h, v_new, x, y, rstd = chunk_ttt_linear_bwd_h( k=k, @@ -1463,7 +1463,7 @@ def chunk_ttt_linear( initial_state_bias: torch.Tensor = None, output_final_state: bool = False, cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, ): r""" Args: @@ -1495,7 +1495,8 @@ def chunk_ttt_linear( consistent with the FlashAttention API. head_first (Optional[bool]): Whether the inputs are in the head-first format, which is not supported for variable-length inputs. - Default: `True`. + Default: `False`. + Returns: o (torch.Tensor): Outputs of shape `[B, H, T, V]` @@ -1508,13 +1509,19 @@ def chunk_ttt_linear( eta = torch.full_like(q[:, :, :, :1], eta) if cu_seqlens is not None: if q.shape[0] != 1: - raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." - f"Please flatten variable-length inputs before processing.") + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) if head_first: - raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError( + "Sequences with variable lengths are not supported for head-first mode" + ) if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: - raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " - f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) if scale is None: scale = k.shape[-1] ** -0.5 else: diff --git a/fla/ops/ttt/fused_chunk.py b/fla/ops/ttt/fused_chunk.py index 08850c170b..d574910a18 100755 --- a/fla/ops/ttt/fused_chunk.py +++ b/fla/ops/ttt/fused_chunk.py @@ -474,7 +474,7 @@ def fused_chunk_ttt_linear_bwd_h( initial_state: torch.Tensor = None, initial_state_bias: torch.Tensor = None, offsets: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ): assert offsets is None, "bwd of varlen is not implemented yet." if head_first: @@ -547,7 +547,7 @@ def fused_chunk_ttt_linear_bwd_dh( initial_state: torch.Tensor = None, initial_state_bias: torch.Tensor = None, offsets: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ): assert offsets is None, "bwd of varlen is not implemented yet." if head_first: @@ -618,7 +618,7 @@ def fused_chunk_ttt_linear_fwd( initial_state_bias: torch.Tensor, output_final_state: bool, offsets: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, BT: int = 16 ): if head_first: @@ -677,7 +677,7 @@ def fused_chunk_ttt_linear_bwd( initial_state: torch.Tensor = None, initial_state_bias: torch.Tensor = None, offsets: Optional[torch.LongTensor] = None, - head_first: bool = True + head_first: bool = False ): assert offsets is None, "bwd of varlen is not implemented yet." dq, h, v2, x, y, rstd = fused_chunk_ttt_linear_bwd_h( @@ -817,7 +817,7 @@ def fused_chunk_ttt_linear( initial_state_bias: torch.Tensor = None, output_final_state: bool = False, cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = True, + head_first: bool = False, ): r""" Args: @@ -849,7 +849,7 @@ def fused_chunk_ttt_linear( consistent with the FlashAttention API. head_first (Optional[bool]): Whether the inputs are in the head-first format, which is not supported for variable-length inputs. - Default: `True`. + Default: `False`. Returns: o (torch.Tensor): @@ -865,13 +865,19 @@ def fused_chunk_ttt_linear( eta = torch.full_like(q[:, :, :, :1], eta) if cu_seqlens is not None: if q.shape[0] != 1: - raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." - f"Please flatten variable-length inputs before processing.") + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) if head_first: - raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + raise RuntimeError( + "Sequences with variable lengths are not supported for head-first mode" + ) if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: - raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " - f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) if scale is None: scale = k.shape[-1] ** -0.5 else: diff --git a/fla/ops/ttt/naive.py b/fla/ops/ttt/naive.py index 0ad5dbba89..c756ae6c9c 100755 --- a/fla/ops/ttt/naive.py +++ b/fla/ops/ttt/naive.py @@ -83,7 +83,7 @@ def chunk_ttt_linear_ref( initial_state: torch.Tensor = None, initial_state_bias: torch.Tensor = None, output_final_state: bool = False, - head_first: bool = True, + head_first: bool = False, ): assert q.dtype == k.dtype == v.dtype assert k.shape[-1] == v.shape[-1], "The key and value dimension must be the same." diff --git a/fla/ops/utils/cumsum.py b/fla/ops/utils/cumsum.py index 5a5f3e90d3..8092495025 100644 --- a/fla/ops/utils/cumsum.py +++ b/fla/ops/utils/cumsum.py @@ -229,7 +229,7 @@ def chunk_local_cumsum_scalar( reverse: bool = False, offsets: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, - head_first: bool = True, + head_first: bool = False, output_dtype: Optional[torch.dtype] = torch.float ) -> torch.Tensor: if head_first: @@ -263,7 +263,7 @@ def chunk_local_cumsum_vector( reverse: bool = False, offsets: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, - head_first: bool = True, + head_first: bool = False, output_dtype: Optional[torch.dtype] = torch.float ) -> torch.Tensor: if head_first: @@ -300,7 +300,7 @@ def chunk_global_cumsum_scalar( dtype: Optional[torch.dtype] = None, reverse: bool = False, offsets: Optional[torch.Tensor] = None, - head_first: bool = True, + head_first: bool = False, output_dtype: Optional[torch.dtype] = torch.float ) -> torch.Tensor: dtype = dtype or s.dtype @@ -330,7 +330,7 @@ def chunk_global_cumsum_vector( dtype: Optional[torch.dtype] = None, reverse: bool = False, offsets: Optional[torch.Tensor] = None, - head_first: bool = True, + head_first: bool = False, output_dtype: Optional[torch.dtype] = torch.float ) -> torch.Tensor: dtype = dtype or s.dtype @@ -363,7 +363,7 @@ def chunk_global_cumsum( dtype: Optional[torch.dtype] = None, reverse: bool = False, offsets: Optional[torch.Tensor] = None, - head_first: bool = True, + head_first: bool = False, output_dtype: Optional[torch.dtype] = torch.float ) -> torch.Tensor: if offsets is not None: @@ -385,7 +385,7 @@ def chunk_local_cumsum( reverse: bool = False, offsets: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, - head_first: bool = True, + head_first: bool = False, output_dtype: Optional[torch.dtype] = torch.float ) -> torch.Tensor: if offsets is not None: @@ -395,6 +395,8 @@ def chunk_local_cumsum( elif len(g.shape) == 4: return chunk_local_cumsum_vector(g, chunk_size, reverse, offsets, indices, head_first, output_dtype) else: - raise ValueError(f"Unsupported input shape {g.shape}. " - f"which should be (B, H, T, dim) if `head_first=True` " - f"or (batch_size, num_heads, seq_len) otherwise") + raise ValueError( + f"Unsupported input shape {g.shape}. " + f"which should be (B, H, T, dim) if `head_first=True` " + f"or (batch_size, num_heads, seq_len) otherwise" + ) From a2c9831cf6bd44bcde09b01d892764ed6129c2d0 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 7 Apr 2025 19:14:48 +0800 Subject: [PATCH 2/4] Update docstrings --- fla/ops/forgetting_attn/parallel.py | 2 +- fla/ops/gated_delta_rule/chunk.py | 6 ++---- fla/ops/gated_delta_rule/fused_recurrent.py | 2 +- fla/ops/gla/chunk.py | 20 +++++++++++--------- fla/ops/gla/fused_recurrent.py | 20 +++++++++++--------- fla/ops/gsa/chunk.py | 20 +++++++++++--------- fla/ops/gsa/fused_recurrent.py | 4 ++-- fla/ops/rwkv6/chunk.py | 20 +++++++++++--------- fla/ops/rwkv6/fused_recurrent.py | 20 +++++++++++--------- fla/ops/simple_gla/chunk.py | 4 ++-- fla/ops/simple_gla/fused_recurrent.py | 4 ++-- 11 files changed, 65 insertions(+), 57 deletions(-) diff --git a/fla/ops/forgetting_attn/parallel.py b/fla/ops/forgetting_attn/parallel.py index 88fea7f29b..4fd92891ac 100644 --- a/fla/ops/forgetting_attn/parallel.py +++ b/fla/ops/forgetting_attn/parallel.py @@ -620,7 +620,7 @@ def forward(ctx, q, k, v, g, scale, offsets): # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None - g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=False) + g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices) o, lse = parallel_forgetting_attn_fwd( q=q, k=k, diff --git a/fla/ops/gated_delta_rule/chunk.py b/fla/ops/gated_delta_rule/chunk.py index 36ba416a33..3f92abb545 100644 --- a/fla/ops/gated_delta_rule/chunk.py +++ b/fla/ops/gated_delta_rule/chunk.py @@ -333,8 +333,7 @@ def chunk_gated_delta_rule( >>> o, ht = chunk_gated_delta_rule( q, k, v, g, beta, initial_state=h0, - output_final_state=True, - head_first=False + output_final_state=True ) # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) @@ -344,8 +343,7 @@ def chunk_gated_delta_rule( q, k, v, g, beta, initial_state=h0, output_final_state=True, - cu_seqlens=cu_seqlens, - head_first=False + cu_seqlens=cu_seqlens ) """ assert q.dtype == k.dtype == v.dtype diff --git a/fla/ops/gated_delta_rule/fused_recurrent.py b/fla/ops/gated_delta_rule/fused_recurrent.py index 4c73b8a40b..d6ea2875bc 100644 --- a/fla/ops/gated_delta_rule/fused_recurrent.py +++ b/fla/ops/gated_delta_rule/fused_recurrent.py @@ -266,7 +266,7 @@ def fused_recurrent_gated_delta_rule( >>> o, ht = fused_gated_recurrent_delta_rule( q, k, v, g, beta, initial_state=h0, - output_final_state=True, + output_final_state=True ) # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta)) diff --git a/fla/ops/gla/chunk.py b/fla/ops/gla/chunk.py index 3648a6192f..df86f81963 100644 --- a/fla/ops/gla/chunk.py +++ b/fla/ops/gla/chunk.py @@ -1455,19 +1455,21 @@ def chunk_gla( >>> v = torch.randn(B, T, H, V, device='cuda') >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda')) >>> h0 = torch.randn(B, H, K, V, device='cuda') - >>> o, ht = chunk_gla(q, k, v, g, - initial_state=h0, - output_final_state=True, - head_first=False) + >>> o, ht = chunk_gla( + q, k, v, g, + initial_state=h0, + output_final_state=True + ) # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g)) # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) - >>> o_var, ht_var = chunk_gla(q, k, v, g, - initial_state=h0, - output_final_state=True, - cu_seqlens=cu_seqlens, - head_first=False) + >>> o_var, ht_var = chunk_gla( + q, k, v, g, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) >>> assert o.allclose(o_var.view(o.shape)) >>> assert ht.allclose(ht_var) """ diff --git a/fla/ops/gla/fused_recurrent.py b/fla/ops/gla/fused_recurrent.py index b5e31db124..bb18848f34 100644 --- a/fla/ops/gla/fused_recurrent.py +++ b/fla/ops/gla/fused_recurrent.py @@ -69,19 +69,21 @@ def fused_recurrent_gla( >>> v = torch.randn(B, T, H, V, device='cuda') >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda')) >>> h0 = torch.randn(B, H, K, V, device='cuda') - >>> o, ht = fused_recurrent_gla(q, k, v, g, - initial_state=h0, - output_final_state=True, - head_first=False) + >>> o, ht = fused_recurrent_gla( + q, k, v, g, + initial_state=h0, + output_final_state=True + ) # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g)) # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) - >>> o_var, ht_var = fused_recurrent_gla(q, k, v, g, - initial_state=h0, - output_final_state=True, - cu_seqlens=cu_seqlens, - head_first=False) + >>> o_var, ht_var = fused_recurrent_gla( + q, k, v, g, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) >>> assert o.allclose(o_var.view(o.shape)) >>> assert ht.allclose(ht_var) """ diff --git a/fla/ops/gsa/chunk.py b/fla/ops/gsa/chunk.py index 2f961b9fe8..c528f9136d 100644 --- a/fla/ops/gsa/chunk.py +++ b/fla/ops/gsa/chunk.py @@ -1209,19 +1209,21 @@ def chunk_gsa( >>> s = torch.randn(B, T, H, M, device='cuda') >>> g = F.logsigmoid(torch.randn(B, T, H, M, device='cuda')) >>> h0 = (torch.randn(B, H, K, M, device='cuda'), torch.randn(B, H, M, V, device='cuda')) - >>> o, (hk, hv) = chunk_gsa(q, k, v, s, g, - initial_state=h0, - output_final_state=True, - head_first=False) + >>> o, (hk, hv) = chunk_gsa( + q, k, v, s, g, + initial_state=h0, + output_final_state=True + ) # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required >>> q, k, v, s, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, s, g)) # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) - >>> o_var, (hk_var, hv_var) = chunk_gsa(q, k, v, s, g, - initial_state=h0, - output_final_state=True, - cu_seqlens=cu_seqlens, - head_first=False) + >>> o_var, (hk_var, hv_var) = chunk_gsa( + q, k, v, s, g, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) >>> assert o.allclose(o_var.view(o.shape)) >>> assert hk.allclose(hk_var) >>> assert hv.allclose(hv_var) diff --git a/fla/ops/gsa/fused_recurrent.py b/fla/ops/gsa/fused_recurrent.py index 5f436f03e4..d704ec4732 100644 --- a/fla/ops/gsa/fused_recurrent.py +++ b/fla/ops/gsa/fused_recurrent.py @@ -521,7 +521,7 @@ def fused_recurrent_gsa( >>> o, (hk, hv) = fused_recurrent_gsa( q, k, v, s, g, initial_state=h0, - output_final_state=True, + output_final_state=True ) # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required >>> q, k, v, s, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, s, g)) @@ -531,7 +531,7 @@ def fused_recurrent_gsa( q, k, v, s, g, initial_state=h0, output_final_state=True, - cu_seqlens=cu_seqlens, + cu_seqlens=cu_seqlens ) >>> assert o.allclose(o_var.view(o.shape)) >>> assert hk.allclose(hk_var) diff --git a/fla/ops/rwkv6/chunk.py b/fla/ops/rwkv6/chunk.py index 296477a00c..5e790fac83 100644 --- a/fla/ops/rwkv6/chunk.py +++ b/fla/ops/rwkv6/chunk.py @@ -1423,19 +1423,21 @@ def chunk_rwkv6( >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda')) >>> u = torch.randn(H, K, device='cuda') >>> h0 = torch.randn(B, H, K, V, device='cuda') - >>> o, ht = chunk_rwkv6(q, k, v, g, u, - initial_state=h0, - output_final_state=True, - head_first=False) + >>> o, ht = chunk_rwkv6( + q, k, v, g, u, + initial_state=h0, + output_final_state=True + ) # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g)) # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) - >>> o_var, ht_var = chunk_rwkv6(q, k, v, g, u, - initial_state=h0, - output_final_state=True, - cu_seqlens=cu_seqlens, - head_first=False) + >>> o_var, ht_var = chunk_rwkv6( + q, k, v, g, u, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) >>> assert o.allclose(o_var.view(o.shape)) >>> assert ht.allclose(ht_var) """ diff --git a/fla/ops/rwkv6/fused_recurrent.py b/fla/ops/rwkv6/fused_recurrent.py index e1c96f6e05..c95ec015ab 100644 --- a/fla/ops/rwkv6/fused_recurrent.py +++ b/fla/ops/rwkv6/fused_recurrent.py @@ -666,19 +666,21 @@ def fused_recurrent_rwkv6( >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda')) >>> u = torch.randn(H, K, device='cuda') >>> h0 = torch.randn(B, H, K, V, device='cuda') - >>> o, ht = fused_recurrent_rwkv6(q, k, v, g, u, - initial_state=h0, - output_final_state=True, - head_first=False) + >>> o, ht = fused_recurrent_rwkv6( + q, k, v, g, u, + initial_state=h0, + output_final_state=True + ) # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g)) # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) - >>> o_var, ht_var = fused_recurrent_rwkv6(q, k, v, g, u, - initial_state=h0, - output_final_state=True, - cu_seqlens=cu_seqlens, - head_first=False) + >>> o_var, ht_var = fused_recurrent_rwkv6( + q, k, v, g, u, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) >>> assert o.allclose(o_var.view(o.shape)) >>> assert ht.allclose(ht_var) """ diff --git a/fla/ops/simple_gla/chunk.py b/fla/ops/simple_gla/chunk.py index 322183f563..f6d6973ef1 100644 --- a/fla/ops/simple_gla/chunk.py +++ b/fla/ops/simple_gla/chunk.py @@ -264,7 +264,7 @@ def chunk_simple_gla( >>> o, ht = chunk_simple_gla( q, k, v, g, initial_state=None, - output_final_state=True, + output_final_state=True ) # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required >>> q, k, v, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g)) @@ -274,7 +274,7 @@ def chunk_simple_gla( q, k, v, g, initial_state=None, output_final_state=True, - cu_seqlens=cu_seqlens, + cu_seqlens=cu_seqlens ) >>> assert o.allclose(o_var.view(o.shape)) >>> assert ht.allclose(ht_var) diff --git a/fla/ops/simple_gla/fused_recurrent.py b/fla/ops/simple_gla/fused_recurrent.py index bbbf9fb16f..21e742fc08 100644 --- a/fla/ops/simple_gla/fused_recurrent.py +++ b/fla/ops/simple_gla/fused_recurrent.py @@ -70,7 +70,7 @@ def fused_recurrent_simple_gla( >>> o, ht = fused_recurrent_simple_gla( q, k, v, g, initial_state=h0, - output_final_state=True, + output_final_state=True ) # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g)) @@ -80,7 +80,7 @@ def fused_recurrent_simple_gla( q, k, v, g, initial_state=h0, output_final_state=True, - cu_seqlens=cu_seqlens, + cu_seqlens=cu_seqlens ) >>> assert o.allclose(o_var.view(o.shape)) >>> assert ht.allclose(ht_var) From 5a562947fd08a95e9e06ce15d09942158ef427e1 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 7 Apr 2025 19:17:00 +0800 Subject: [PATCH 3/4] Update default arg in cumsum --- fla/ops/utils/cumsum.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/fla/ops/utils/cumsum.py b/fla/ops/utils/cumsum.py index 8092495025..c129883472 100644 --- a/fla/ops/utils/cumsum.py +++ b/fla/ops/utils/cumsum.py @@ -373,9 +373,11 @@ def chunk_global_cumsum( elif len(s.shape) == 4: return chunk_global_cumsum_vector(s, dtype, reverse, offsets, head_first, output_dtype) else: - raise ValueError(f"Unsupported input shape {s.shape}. " - f"which should be [B, H, T]/[B, H, T, D] if `head_first=True` " - f"or [B, T, H]/[B, T, H, D] otherwise") + raise ValueError( + f"Unsupported input shape {s.shape}. " + f"which should be [B, T, H]/[B, T, H, D] if `head_first=False` " + f"or [B, H, T]/[B, H, T, D] otherwise" + ) @input_guard @@ -397,6 +399,6 @@ def chunk_local_cumsum( else: raise ValueError( f"Unsupported input shape {g.shape}. " - f"which should be (B, H, T, dim) if `head_first=True` " - f"or (batch_size, num_heads, seq_len) otherwise" + f"which should be (B, T, H, D) if `head_first=False` " + f"or (B, H, T, D) otherwise" ) From 1461f5babdc72d9c18fe8c89eb19d7282782705c Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 7 Apr 2025 19:39:30 +0800 Subject: [PATCH 4/4] Rename `USE_OFFSETS` to `IS_VARLEN` --- fla/ops/attn/parallel.py | 18 +++---- fla/ops/common/chunk_delta_h.py | 52 +++++++++++++----- fla/ops/common/chunk_h.py | 12 ++--- fla/ops/common/chunk_h_parallel.py | 24 ++++----- fla/ops/common/chunk_h_split.py | 24 ++++----- fla/ops/common/chunk_o.py | 24 ++++----- fla/ops/common/chunk_scaled_dot_kkt.py | 8 +-- fla/ops/common/fused_recurrent.py | 12 ++--- fla/ops/delta_rule/fused_recurrent.py | 12 ++--- fla/ops/delta_rule/wy_fast.py | 16 +++--- fla/ops/forgetting_attn/parallel.py | 18 +++---- fla/ops/gated_delta_rule/fused_recurrent.py | 6 +-- fla/ops/gated_delta_rule/wy_fast.py | 32 +++++------ .../dplr/chunk_A_bwd.py | 12 ++--- .../dplr/chunk_A_fwd.py | 12 ++--- .../dplr/chunk_h_bwd.py | 6 +-- .../dplr/chunk_h_fwd.py | 6 +-- .../dplr/chunk_o_bwd.py | 18 +++---- .../dplr/chunk_o_fwd.py | 6 +-- .../dplr/fused_recurrent.py | 6 +-- .../dplr/wy_fast_bwd.py | 6 +-- .../dplr/wy_fast_fwd.py | 18 +++---- fla/ops/generalized_delta_rule/iplr/chunk.py | 12 ++--- .../iplr/fused_recurrent.py | 12 ++--- .../generalized_delta_rule/iplr/wy_fast.py | 18 +++---- fla/ops/gla/chunk.py | 54 +++++++++---------- fla/ops/gsa/chunk.py | 30 +++++------ fla/ops/hgrn/fused_recurrent.py | 12 ++--- fla/ops/nsa/compression.py | 18 +++---- fla/ops/nsa/parallel.py | 24 ++++----- fla/ops/rwkv6/chunk.py | 48 ++++++++--------- fla/ops/rwkv6/fused_recurrent.py | 24 ++++----- fla/ops/simple_gla/parallel.py | 15 +++--- fla/ops/ttt/chunk.py | 36 ++++++------- fla/ops/ttt/fused_chunk.py | 6 +-- fla/ops/utils/cumsum.py | 24 ++++----- fla/ops/utils/pooling.py | 12 ++--- fla/ops/utils/solve_tril.py | 23 ++++---- 38 files changed, 370 insertions(+), 346 deletions(-) diff --git a/fla/ops/attn/parallel.py b/fla/ops/attn/parallel.py index d19a2e1b13..88de1f1441 100644 --- a/fla/ops/attn/parallel.py +++ b/fla/ops/attn/parallel.py @@ -14,7 +14,7 @@ @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -45,13 +45,13 @@ def parallel_attn_fwd_kernel( BS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_hq = i_bh // HQ, i_bh % HQ i_h = i_hq // G - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -147,7 +147,7 @@ def parallel_attn_bwd_kernel_preprocess( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -180,13 +180,13 @@ def parallel_attn_bwd_kernel_dq( BS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_hq = i_bh // HQ, i_bh % HQ i_h = i_hq // G - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -258,7 +258,7 @@ def parallel_attn_bwd_kernel_dq( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -292,13 +292,13 @@ def parallel_attn_bwd_kernel_dkv( BS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_hq = i_bh // HQ, i_bh % HQ i_h = i_hq // G - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos diff --git a/fla/ops/common/chunk_delta_h.py b/fla/ops/common/chunk_delta_h.py index 8c4f593005..6702da7365 100644 --- a/fla/ops/common/chunk_delta_h.py +++ b/fla/ops/common/chunk_delta_h.py @@ -16,7 +16,7 @@ @triton.heuristics({ 'USE_G': lambda args: args['g'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, 'SAVE_NEW_VALUE': lambda args: args['v_new'] is not None @@ -53,13 +53,13 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( USE_G: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, SAVE_NEW_VALUE: tl.constexpr, ): i_v, i_nh = tl.program_id(0), tl.program_id(1) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos NT = tl.cdiv(T, BT) @@ -207,7 +207,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( 'USE_G': lambda args: args['g'] is not None, 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None, 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, }) @triton.autotune( configs=[ @@ -243,12 +243,12 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64( USE_G: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, USE_FINAL_STATE_GRADIENT: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_v, i_nh = tl.program_id(0), tl.program_id(1) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos NT = tl.cdiv(T, BT) @@ -451,8 +451,21 @@ def chunk_gated_delta_rule_fwd_h( k_new = torch.empty_like(k).fill_(float('-inf')) w_new = torch.empty_like(w).fill_(float('-inf')) def grid(meta): return (triton.cdiv(K, meta['BK']), N*H, triton.cdiv(T, BT)) - proprocess_qkw[grid](None, k, w, g, None, k_new, w_new, offsets, T, H, K, BT, - USE_OFFSETS=offsets is not None, HEAD_FIRST=head_first) + proprocess_qkw[grid]( + q=None, + k=k, + w=w, + g=g, + q_new=None, + k_new=k_new, + w_new=w_new, + offsets=offsets, + T=T, + H=H, + K=K, + BT=BT, + HEAD_FIRST=head_first + ) v_new = torch.empty_like(u) if save_new_value else None def grid(meta): return (triton.cdiv(V, meta['BV']), N*H) @@ -560,7 +573,7 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), N*H) @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, 'USE_Q': lambda args: args['q'] is not None, }) @triton.autotune( @@ -575,14 +588,27 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), N*H) ) @triton.jit(do_not_specialize=['T']) def proprocess_qkw( - q, k, w, g, q_new, k_new, w_new, offsets, - T, H: tl.constexpr, K: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, - USE_OFFSETS: tl.constexpr, HEAD_FIRST: tl.constexpr, USE_Q: tl.constexpr + q, + k, + w, + g, + q_new, + k_new, + w_new, + offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_Q: tl.constexpr ): i_k, i_nh, i_t = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos else: diff --git a/fla/ops/common/chunk_h.py b/fla/ops/common/chunk_h.py index 8da7b585fa..50c5a9a952 100644 --- a/fla/ops/common/chunk_h.py +++ b/fla/ops/common/chunk_h.py @@ -17,7 +17,7 @@ @triton.heuristics({ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -54,12 +54,12 @@ def chunk_fwd_kernel_h( USE_GV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos NT = tl.cdiv(T, BT) @@ -155,7 +155,7 @@ def chunk_fwd_kernel_h( @triton.heuristics({ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -195,14 +195,14 @@ def chunk_bwd_kernel_dh( USE_GV: tl.constexpr, STORE_INITIAL_STATE_GRADIENT: tl.constexpr, USE_FINAL_STATE_GRADIENT: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_bg = i_nh // NG i_n, i_hq = i_nh // HQ, i_nh % HQ i_h = i_hq // NG - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos NT = tl.cdiv(T, BT) diff --git a/fla/ops/common/chunk_h_parallel.py b/fla/ops/common/chunk_h_parallel.py index aa951b756c..87d0eff07f 100644 --- a/fla/ops/common/chunk_h_parallel.py +++ b/fla/ops/common/chunk_h_parallel.py @@ -17,7 +17,7 @@ @triton.heuristics({ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -53,7 +53,7 @@ def chunk_fwd_kernel_h_parallel( USE_GV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) @@ -66,7 +66,7 @@ def chunk_fwd_kernel_h_parallel( # i_tg: (global) chunk index across all sequences i_k, i_v = i_kv // NV, i_kv % NV i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) @@ -157,7 +157,7 @@ def chunk_fwd_kernel_h_parallel( @triton.heuristics({ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -190,12 +190,12 @@ def chunk_fwd_kernel_h_reduction( USE_GK: tl.constexpr, USE_GV: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos NT = tl.cdiv(T, BT) @@ -257,7 +257,7 @@ def chunk_fwd_kernel_h_reduction( @triton.heuristics({ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -296,7 +296,7 @@ def chunk_bwd_kernel_dh_parallel( USE_GV: tl.constexpr, STORE_INITIAL_STATE_GRADIENT: tl.constexpr, USE_FINAL_STATE_GRADIENT: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) @@ -305,7 +305,7 @@ def chunk_bwd_kernel_dh_parallel( i_k, i_v = i_kv // NV, i_kv % NV i_b, i_hq, i_bg = i_bh // HQ, i_bh % HQ, i_bh // NG i_h = i_hq // NG - if USE_OFFSETS: + if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) @@ -379,7 +379,7 @@ def chunk_bwd_kernel_dh_parallel( @triton.heuristics({ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -414,14 +414,14 @@ def chunk_bwd_kernel_dh_reduction( USE_GK: tl.constexpr, USE_GV: tl.constexpr, STORE_INITIAL_STATE_GRADIENT: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_bg = i_nh // NG i_n, i_hq = i_nh // HQ, i_nh % HQ i_h = i_hq // NG - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos NT = tl.cdiv(T, BT) diff --git a/fla/ops/common/chunk_h_split.py b/fla/ops/common/chunk_h_split.py index 6375328d3c..dc40b8fa79 100644 --- a/fla/ops/common/chunk_h_split.py +++ b/fla/ops/common/chunk_h_split.py @@ -13,7 +13,7 @@ @triton.heuristics({ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -51,7 +51,7 @@ def chunk_fwd_kernel_h_split( USE_GV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): # handle one split at a time @@ -60,7 +60,7 @@ def chunk_fwd_kernel_h_split( # i_s: local split index inside a sequence i_k, i_v, i_sh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_ss, i_h = i_sh // H, i_sh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_s = tl.load(split_indices + i_ss * 2).to(tl.int32), tl.load(split_indices + i_ss * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -152,7 +152,7 @@ def chunk_fwd_kernel_h_split( @triton.heuristics({ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -186,12 +186,12 @@ def chunk_fwd_kernel_h_reduction( USE_GK: tl.constexpr, USE_GV: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos NS = tl.cdiv(T, S) @@ -252,7 +252,7 @@ def chunk_fwd_kernel_h_reduction( @triton.heuristics({ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -293,7 +293,7 @@ def chunk_bwd_kernel_dh_split( USE_GV: tl.constexpr, USE_FINAL_STATE_GRADIENT: tl.constexpr, STORE_INITIAL_STATE_GRADIENT: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): # handle one split at a time @@ -302,7 +302,7 @@ def chunk_bwd_kernel_dh_split( # i_s: local split index inside a sequence i_k, i_v, i_sh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_ss, i_hq = i_sh // HQ, i_sh % HQ - if USE_OFFSETS: + if IS_VARLEN: i_n, i_s = tl.load(split_indices + i_ss * 2).to(tl.int32), tl.load(split_indices + i_ss * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -390,7 +390,7 @@ def chunk_bwd_kernel_dh_split( @triton.heuristics({ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -426,13 +426,13 @@ def chunk_bwd_kernel_dh_reduction( USE_GK: tl.constexpr, USE_GV: tl.constexpr, STORE_INITIAL_STATE_GRADIENT: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_hq = i_nh // HQ, i_nh % HQ i_ng, i_h = i_nh // NG, i_hq // NG - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos NS = tl.cdiv(T, S) diff --git a/fla/ops/common/chunk_o.py b/fla/ops/common/chunk_o.py index e0912f8783..06469bf7f7 100644 --- a/fla/ops/common/chunk_o.py +++ b/fla/ops/common/chunk_o.py @@ -16,7 +16,7 @@ @triton.heuristics({ 'USE_G': lambda args: args['g'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -47,13 +47,13 @@ def chunk_fwd_kernel_o( BK: tl.constexpr, BV: tl.constexpr, USE_G: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) @@ -115,7 +115,7 @@ def chunk_fwd_kernel_o( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, 'USE_G': lambda args: args['g'] is not None, 'USE_DW': lambda args: args['dw'] is not None }) @@ -155,14 +155,14 @@ def chunk_bwd_kernel_dqkwg( BV: tl.constexpr, USE_G: tl.constexpr, USE_DW: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_G: dg += i_k * B * H * T - if USE_OFFSETS: + if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) @@ -287,7 +287,7 @@ def chunk_bwd_kernel_dqkwg( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, 'USE_G': lambda args: args['g'] is not None, }) @triton.autotune( @@ -317,12 +317,12 @@ def chunk_bwd_kernel_dv( BK: tl.constexpr, BV: tl.constexpr, USE_G: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) @@ -377,7 +377,7 @@ def chunk_bwd_kernel_dv( @triton.heuristics({ 'USE_G': lambda args: args['g'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, }) @triton.autotune( configs=[ @@ -405,12 +405,12 @@ def chunk_bwd_kernel_dv_local( BK: tl.constexpr, BV: tl.constexpr, USE_G: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos diff --git a/fla/ops/common/chunk_scaled_dot_kkt.py b/fla/ops/common/chunk_scaled_dot_kkt.py index ff30664dce..c327604479 100644 --- a/fla/ops/common/chunk_scaled_dot_kkt.py +++ b/fla/ops/common/chunk_scaled_dot_kkt.py @@ -11,7 +11,7 @@ @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -20,7 +20,7 @@ for num_warps in [2, 4, 8] for num_stages in [2, 3, 4] ], - key=['H', 'K', 'BT', 'USE_OFFSETS'], + key=['H', 'K', 'BT', 'IS_VARLEN'], ) @triton.jit(do_not_specialize=['T']) def chunk_scaled_dot_kkt_fwd_kernel( @@ -35,11 +35,11 @@ def chunk_scaled_dot_kkt_fwd_kernel( BT: tl.constexpr, BK: tl.constexpr, HEAD_FIRST: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos diff --git a/fla/ops/common/fused_recurrent.py b/fla/ops/common/fused_recurrent.py index 4cd0bd6fcf..047e3b9945 100644 --- a/fla/ops/common/fused_recurrent.py +++ b/fla/ops/common/fused_recurrent.py @@ -15,7 +15,7 @@ @triton.heuristics({ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -50,13 +50,13 @@ def fused_recurrent_fwd_kernel( USE_GV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): # indices i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) all = T T = eos - bos @@ -133,7 +133,7 @@ def fused_recurrent_fwd_kernel( 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -173,12 +173,12 @@ def fused_recurrent_bwd_kernel( USE_INITIAL_STATE: tl.constexpr, STORE_INITIAL_STATE_GRADIENT: tl.constexpr, USE_FINAL_STATE_GRADIENT: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) all = T T = eos - bos diff --git a/fla/ops/delta_rule/fused_recurrent.py b/fla/ops/delta_rule/fused_recurrent.py index 559071cda9..f38790aff6 100644 --- a/fla/ops/delta_rule/fused_recurrent.py +++ b/fla/ops/delta_rule/fused_recurrent.py @@ -15,7 +15,7 @@ @triton.heuristics({ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.jit(do_not_specialize=['T']) def fused_recurrent_delta_rule_fwd_kernel( @@ -39,12 +39,12 @@ def fused_recurrent_delta_rule_fwd_kernel( USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, IS_BETA_HEADWISE: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_v, i_k, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) all = T T = eos - bos @@ -114,7 +114,7 @@ def fused_recurrent_delta_rule_fwd_kernel( @triton.heuristics({ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.jit(do_not_specialize=['T']) def fused_recurrent_delta_rule_bwd_kernel( @@ -143,12 +143,12 @@ def fused_recurrent_delta_rule_bwd_kernel( IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar USE_INITIAL_STATE: tl.constexpr, # whether to use dh0 USE_FINAL_STATE_GRADIENT: tl.constexpr, # whether to use dht - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_v, i_k, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) all = T T = eos - bos diff --git a/fla/ops/delta_rule/wy_fast.py b/fla/ops/delta_rule/wy_fast.py index 5a863b9155..b47434dbd2 100644 --- a/fla/ops/delta_rule/wy_fast.py +++ b/fla/ops/delta_rule/wy_fast.py @@ -15,7 +15,7 @@ @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -23,7 +23,7 @@ for num_warps in [2, 4, 8] for num_stages in [2, 3, 4] ], - key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'IS_VARLEN'], ) @triton.jit(do_not_specialize=['T']) def fwd_recompute_w_u_kernel( @@ -43,11 +43,11 @@ def fwd_recompute_w_u_kernel( BK: tl.constexpr, BV: tl.constexpr, HEAD_FIRST: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -89,7 +89,7 @@ def fwd_recompute_w_u_kernel( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -97,7 +97,7 @@ def fwd_recompute_w_u_kernel( for num_warps in NUM_WARPS for num_stages in [2, 3, 4] ], - key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'IS_VARLEN'], ) @triton.jit(do_not_specialize=['T']) def bwd_prepare_wy_repr_kernel( @@ -120,11 +120,11 @@ def bwd_prepare_wy_repr_kernel( BK: tl.constexpr, BV: tl.constexpr, HEAD_FIRST: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos diff --git a/fla/ops/forgetting_attn/parallel.py b/fla/ops/forgetting_attn/parallel.py index 4fd92891ac..26cb303a5c 100644 --- a/fla/ops/forgetting_attn/parallel.py +++ b/fla/ops/forgetting_attn/parallel.py @@ -15,7 +15,7 @@ @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -47,13 +47,13 @@ def parallel_forgetting_attn_fwd_kernel( BS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_hq = i_bh // HQ, i_bh % HQ i_h = i_hq // G - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -164,7 +164,7 @@ def parallel_forgetting_attn_bwd_kernel_preprocess( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -199,13 +199,13 @@ def parallel_forgetting_attn_bwd_kernel_dq( BS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_hq = i_bh // HQ, i_bh % HQ i_h = i_hq // G - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -296,7 +296,7 @@ def parallel_forgetting_attn_bwd_kernel_dq( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -332,13 +332,13 @@ def parallel_forgetting_attn_bwd_kernel_dkv( BS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_hq = i_bh // HQ, i_bh % HQ i_h = i_hq // G - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos diff --git a/fla/ops/gated_delta_rule/fused_recurrent.py b/fla/ops/gated_delta_rule/fused_recurrent.py index d6ea2875bc..1f4931cd04 100644 --- a/fla/ops/gated_delta_rule/fused_recurrent.py +++ b/fla/ops/gated_delta_rule/fused_recurrent.py @@ -15,7 +15,7 @@ @triton.heuristics({ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.jit(do_not_specialize=['T']) def fused_recurrent_gated_delta_rule_fwd_kernel( @@ -40,11 +40,11 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( STORE_FINAL_STATE: tl.constexpr, # whether to store final state IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar, USE_QK_L2NORM_IN_KERNEL: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) all = T T = eos - bos diff --git a/fla/ops/gated_delta_rule/wy_fast.py b/fla/ops/gated_delta_rule/wy_fast.py index ad43e6a431..e205faa0be 100644 --- a/fla/ops/gated_delta_rule/wy_fast.py +++ b/fla/ops/gated_delta_rule/wy_fast.py @@ -12,7 +12,7 @@ @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -20,7 +20,7 @@ for num_warps in [2, 4, 8] for num_stages in [2, 3, 4] ], - key=['H', 'K', 'BT', 'BK', 'BC', 'HEAD_FIRST', 'USE_OFFSETS'], + key=['H', 'K', 'BT', 'BK', 'BC', 'HEAD_FIRST', 'IS_VARLEN'], ) @triton.jit(do_not_specialize=['T']) def fwd_prepare_wy_repr_kernel_chunk32( @@ -38,11 +38,11 @@ def fwd_prepare_wy_repr_kernel_chunk32( BK: tl.constexpr, BC: tl.constexpr, HEAD_FIRST: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -100,7 +100,7 @@ def fwd_prepare_wy_repr_kernel_chunk32( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -108,7 +108,7 @@ def fwd_prepare_wy_repr_kernel_chunk32( for num_warps in [2, 4, 8] for num_stages in [2, 3, 4] ], - key=['H', 'K', 'BT', 'BK', 'BC', 'USE_OFFSETS', 'HEAD_FIRST'], + key=['H', 'K', 'BT', 'BK', 'BC', 'IS_VARLEN', 'HEAD_FIRST'], ) @triton.jit(do_not_specialize=['T']) def fwd_prepare_wy_repr_kernel_chunk64( @@ -125,12 +125,12 @@ def fwd_prepare_wy_repr_kernel_chunk64( BT: tl.constexpr, BK: tl.constexpr, BC: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -239,7 +239,7 @@ def fwd_prepare_wy_repr_kernel_chunk64( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -247,7 +247,7 @@ def fwd_prepare_wy_repr_kernel_chunk64( for num_warps in [2, 4, 8] for num_stages in [2, 3, 4] ], - key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'IS_VARLEN'], ) @triton.jit(do_not_specialize=['T']) def fwd_recompute_w_u_kernel( @@ -268,11 +268,11 @@ def fwd_recompute_w_u_kernel( BK: tl.constexpr, BV: tl.constexpr, HEAD_FIRST: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -418,7 +418,7 @@ def fwd_recompute_w_u( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -426,7 +426,7 @@ def fwd_recompute_w_u( for num_warps in [2, 4] for num_stages in [2, 3, 4] ], - key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'] + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'IS_VARLEN'] ) @triton.jit(do_not_specialize=['T']) def bwd_prepare_wy_repr_kernel( @@ -452,11 +452,11 @@ def bwd_prepare_wy_repr_kernel( BK: tl.constexpr, BV: tl.constexpr, HEAD_FIRST: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py index a26a21531e..6619c1f7cd 100644 --- a/fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py +++ b/fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py @@ -12,7 +12,7 @@ @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -55,14 +55,14 @@ def chunk_dplr_bwd_kernel_intra( BC: tl.constexpr, BK: tl.constexpr, NC: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, GATHER_SUPPORTED: tl.constexpr ): i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H i_t, i_i = i_c // NC, i_c % NC - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) else: @@ -289,7 +289,7 @@ def chunk_dplr_bwd_kernel_intra( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, }) @triton.autotune( configs=[ @@ -314,12 +314,12 @@ def chunk_dplr_bwd_dgk_kernel( K: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, ): i_t, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py index d181bc897b..8615393fa6 100644 --- a/fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py +++ b/fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py @@ -12,7 +12,7 @@ @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -46,13 +46,13 @@ def chunk_dplr_fwd_A_kernel_intra_sub_inter( BC: tl.constexpr, BK: tl.constexpr, NC: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, ): i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H i_i, i_j = i_c // NC, i_c % NC - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -129,7 +129,7 @@ def chunk_dplr_fwd_A_kernel_intra_sub_inter( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -166,14 +166,14 @@ def chunk_dplr_fwd_A_kernel_intra_sub_intra( BC: tl.constexpr, BK: tl.constexpr, NC: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, GATHER_SUPPORTED: tl.constexpr ): i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H i_j = i_i - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py index 18980e7165..a34598d1c3 100644 --- a/fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py +++ b/fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py @@ -15,7 +15,7 @@ @triton.heuristics({ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, }) @triton.autotune( configs=[ @@ -50,12 +50,12 @@ def chunk_dplr_bwd_kernel_dhu( BV: tl.constexpr, USE_FINAL_STATE_GRADIENT: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos NT = tl.cdiv(T, BT) diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py index 82282e1366..06c831327a 100644 --- a/fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py +++ b/fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py @@ -15,7 +15,7 @@ @triton.heuristics({ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, }) @triton.autotune( configs=[ @@ -51,12 +51,12 @@ def chunk_dplr_fwd_kernel_h( NT: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos NT = tl.cdiv(T, BT) diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py index 159ed4c197..eb133b6268 100644 --- a/fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py +++ b/fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py @@ -14,7 +14,7 @@ @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -42,12 +42,12 @@ def chunk_dplr_bwd_kernel_dAu( V: tl.constexpr, BT: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) else: @@ -100,7 +100,7 @@ def chunk_dplr_bwd_kernel_dAu( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, }) @triton.autotune( configs=[ @@ -137,13 +137,13 @@ def chunk_dplr_bwd_o_kernel( BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, ): i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) @@ -228,7 +228,7 @@ def chunk_dplr_bwd_o_kernel( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, }) @triton.autotune( configs=[ @@ -257,12 +257,12 @@ def chunk_dplr_bwd_kernel_dv( BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py index cd30aedb51..7c2a3e12da 100644 --- a/fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py +++ b/fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py @@ -13,7 +13,7 @@ @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, }) @triton.autotune( configs=[ @@ -44,13 +44,13 @@ def chunk_dplr_fwd_kernel_o( BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) diff --git a/fla/ops/generalized_delta_rule/dplr/fused_recurrent.py b/fla/ops/generalized_delta_rule/dplr/fused_recurrent.py index 5903b8fe6e..76059edae6 100644 --- a/fla/ops/generalized_delta_rule/dplr/fused_recurrent.py +++ b/fla/ops/generalized_delta_rule/dplr/fused_recurrent.py @@ -14,7 +14,7 @@ @triton.heuristics({ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -49,13 +49,13 @@ def fused_recurrent_dplr_delta_rule_fwd_kernel( REVERSE: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_v, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) T = eos - bos else: diff --git a/fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py b/fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py index d9ff775184..a5223a47ac 100644 --- a/fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py +++ b/fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py @@ -14,7 +14,7 @@ @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -47,12 +47,12 @@ def bwd_prepare_wy_repr_kernel( BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos diff --git a/fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py b/fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py index c5e631a3e1..46c26ed1e3 100644 --- a/fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py +++ b/fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py @@ -12,7 +12,7 @@ @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -32,12 +32,12 @@ def fwd_prepare_wy_repr_kernel_chunk32( H: tl.constexpr, BT: tl.constexpr, BC: tl.constexpr, # placeholder, do not delete - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -61,7 +61,7 @@ def fwd_prepare_wy_repr_kernel_chunk32( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -82,13 +82,13 @@ def fwd_prepare_wy_repr_kernel_chunk64( H: tl.constexpr, BT: tl.constexpr, BC: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, GATHER_SUPPORTED: tl.constexpr = is_gather_supported ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -151,7 +151,7 @@ def fwd_prepare_wy_repr_kernel_chunk64( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -179,12 +179,12 @@ def fwd_wu_kernel( BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos diff --git a/fla/ops/generalized_delta_rule/iplr/chunk.py b/fla/ops/generalized_delta_rule/iplr/chunk.py index ce46d0d7e7..ccad332556 100644 --- a/fla/ops/generalized_delta_rule/iplr/chunk.py +++ b/fla/ops/generalized_delta_rule/iplr/chunk.py @@ -17,7 +17,7 @@ @triton.heuristics({ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, }) @triton.autotune( configs=[ @@ -51,12 +51,12 @@ def chunk_generalized_iplr_delta_rule_fwd_kernel_h( NT: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos NT = tl.cdiv(T, BT) @@ -112,7 +112,7 @@ def chunk_generalized_iplr_delta_rule_fwd_kernel_h( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, }) @triton.autotune( configs=[ @@ -144,13 +144,13 @@ def chunk_generalized_iplr_delta_rule_fwd_kernel_o( BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) diff --git a/fla/ops/generalized_delta_rule/iplr/fused_recurrent.py b/fla/ops/generalized_delta_rule/iplr/fused_recurrent.py index f70a7f0771..3944515dd7 100644 --- a/fla/ops/generalized_delta_rule/iplr/fused_recurrent.py +++ b/fla/ops/generalized_delta_rule/iplr/fused_recurrent.py @@ -13,7 +13,7 @@ @triton.heuristics({ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -45,14 +45,14 @@ def fused_recurrent_fwd_kernel( BV: tl.constexpr, # BLOCK SIZE along the V dimension USE_INITIAL_STATE: tl.constexpr, # whether to use initial state STORE_FINAL_STATE: tl.constexpr, # whether to store final state - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): # indices i_v, i_nh = tl.program_id(0), tl.program_id(1) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) T = eos - bos else: @@ -115,7 +115,7 @@ def fused_recurrent_fwd_kernel( 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, 'USE_DHT': lambda args: args['dht'] is not None, 'USE_DH0': lambda args: args['dh0'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -157,7 +157,7 @@ def fused_recurrent_bwd_kernel( USE_INITIAL_STATE: tl.constexpr, # whether to use initial state h0 USE_DH0: tl.constexpr, # whether to use dh0 USE_DHT: tl.constexpr, # whether to use dht - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_v, i_nh = tl.program_id(0), tl.program_id(1) @@ -166,7 +166,7 @@ def fused_recurrent_bwd_kernel( db += i_v * B * H * K * T dq += i_v * B * H * K * T da += i_v * B * H * K * T - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) T = eos - bos else: diff --git a/fla/ops/generalized_delta_rule/iplr/wy_fast.py b/fla/ops/generalized_delta_rule/iplr/wy_fast.py index fcdc7e4418..db026e59cf 100644 --- a/fla/ops/generalized_delta_rule/iplr/wy_fast.py +++ b/fla/ops/generalized_delta_rule/iplr/wy_fast.py @@ -14,7 +14,7 @@ @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -36,12 +36,12 @@ def fwd_prepare_wy_repr_kernel_chunk32( BT: tl.constexpr, BK: tl.constexpr, BC: tl.constexpr, # dummy placeholder - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -76,7 +76,7 @@ def fwd_prepare_wy_repr_kernel_chunk32( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -98,12 +98,12 @@ def fwd_prepare_wy_repr_kernel_chunk64( BT: tl.constexpr, BK: tl.constexpr, BC: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -169,7 +169,7 @@ def fwd_prepare_wy_repr_kernel_chunk64( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -195,12 +195,12 @@ def fwd_wu_kernel( BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos diff --git a/fla/ops/gla/chunk.py b/fla/ops/gla/chunk.py index df86f81963..473b06d86d 100644 --- a/fla/ops/gla/chunk.py +++ b/fla/ops/gla/chunk.py @@ -18,7 +18,7 @@ @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -45,13 +45,13 @@ def chunk_gla_fwd_A_kernel_intra_sub_inter( BC: tl.constexpr, BK: tl.constexpr, NC: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H i_i, i_j = i_c // NC, i_c % NC - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -102,7 +102,7 @@ def chunk_gla_fwd_A_kernel_intra_sub_inter( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -128,13 +128,13 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra( BT: tl.constexpr, BC: tl.constexpr, BK: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H i_j = i_i - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -175,7 +175,7 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -203,14 +203,14 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra_split( BC: tl.constexpr, BK: tl.constexpr, NC: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_k, i_tc, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H i_t, i_i = i_tc // NC, i_tc % NC i_j = i_i - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) all = T @@ -254,7 +254,7 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra_split( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -277,12 +277,12 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra_merge( BT: tl.constexpr, BC: tl.constexpr, NK: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) all = T @@ -309,7 +309,7 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra_merge( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -338,12 +338,12 @@ def chunk_gla_fwd_kernel_o( BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) @@ -398,7 +398,7 @@ def chunk_gla_fwd_kernel_o( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -426,13 +426,13 @@ def chunk_gla_bwd_kernel_intra( BC: tl.constexpr, BK: tl.constexpr, NC: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H i_t, i_i = i_c // NC, i_c % NC - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) else: @@ -574,7 +574,7 @@ def chunk_gla_bwd_kernel_intra( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -598,12 +598,12 @@ def chunk_gla_bwd_kernel_dA( V: tl.constexpr, BT: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) else: @@ -631,7 +631,7 @@ def chunk_gla_bwd_kernel_dA( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -659,12 +659,12 @@ def chunk_gla_bwd_kernel_dv( BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) @@ -717,7 +717,7 @@ def chunk_gla_bwd_kernel_dv( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -752,12 +752,12 @@ def chunk_gla_bwd_kernel_inter( BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) diff --git a/fla/ops/gsa/chunk.py b/fla/ops/gsa/chunk.py index c528f9136d..838c07ebc7 100644 --- a/fla/ops/gsa/chunk.py +++ b/fla/ops/gsa/chunk.py @@ -16,7 +16,7 @@ @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -48,14 +48,14 @@ def chunk_gsa_fwd_k_kernel_inter( BK: tl.constexpr, BV: tl.constexpr, NG: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_bg = i_bh // NG i_b, i_hq = i_bh // HQ, i_bh % HQ i_h = i_hq // NG - if USE_OFFSETS: + if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) @@ -112,7 +112,7 @@ def chunk_gsa_fwd_k_kernel_inter( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.jit(do_not_specialize=['T']) def chunk_gsa_fwd_k_kernel_intra( @@ -131,7 +131,7 @@ def chunk_gsa_fwd_k_kernel_intra( BV: tl.constexpr, NC: tl.constexpr, NG: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) @@ -139,7 +139,7 @@ def chunk_gsa_fwd_k_kernel_intra( i_b, i_hq = i_bh // HQ, i_bh % HQ i_h = i_hq // NG i_t, i_i = i_c // NC, i_c % NC - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -213,7 +213,7 @@ def chunk_gsa_fwd_k_kernel_intra( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -241,7 +241,7 @@ def chunk_gsa_bwd_k_kernel_dA( BV: tl.constexpr, NC: tl.constexpr, NG: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) @@ -249,7 +249,7 @@ def chunk_gsa_bwd_k_kernel_dA( i_b, i_hq = i_bh // HQ, i_bh % HQ i_h = i_hq // NG i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) all = T @@ -330,7 +330,7 @@ def chunk_gsa_bwd_k_kernel_dA( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -369,14 +369,14 @@ def chunk_gsa_bwd_k_kernel_dqkvg( BK: tl.constexpr, BV: tl.constexpr, NG: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_bg = i_bh // NG i_b, i_hq = i_bh // HQ, i_bh % HQ i_h = i_hq // NG - if USE_OFFSETS: + if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) @@ -487,7 +487,7 @@ def chunk_gsa_bwd_k_kernel_dqkvg( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.jit(do_not_specialize=['T']) def chunk_gsa_bwd_k_kernel_intra_dvg( @@ -509,7 +509,7 @@ def chunk_gsa_bwd_k_kernel_intra_dvg( BV: tl.constexpr, NC: tl.constexpr, NG: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) @@ -517,7 +517,7 @@ def chunk_gsa_bwd_k_kernel_intra_dvg( i_b, i_hq = i_bh // HQ, i_bh % HQ i_h = i_hq // NG i_t, i_i = i_c // NC, i_c % NC - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos diff --git a/fla/ops/hgrn/fused_recurrent.py b/fla/ops/hgrn/fused_recurrent.py index a6a70f0c7e..d77213c050 100644 --- a/fla/ops/hgrn/fused_recurrent.py +++ b/fla/ops/hgrn/fused_recurrent.py @@ -14,7 +14,7 @@ @triton.heuristics({ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -37,10 +37,10 @@ def fused_recurrent_hgrn_fwd_kernel( BD: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_d, i_n = tl.program_id(0), tl.program_id(1) - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) T = eos - bos else: @@ -75,7 +75,7 @@ def fused_recurrent_hgrn_fwd_kernel( @triton.heuristics({ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -101,10 +101,10 @@ def fused_recurrent_hgrn_bwd_kernel( BD: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, USE_FINAL_STATE_GRADIENT: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_d, i_n = tl.program_id(0), tl.program_id(1) - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) T = eos - bos else: diff --git a/fla/ops/nsa/compression.py b/fla/ops/nsa/compression.py index c6bf3bb97d..14444fbdc1 100644 --- a/fla/ops/nsa/compression.py +++ b/fla/ops/nsa/compression.py @@ -14,7 +14,7 @@ @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -44,12 +44,12 @@ def parallel_nsa_compression_fwd_kernel( BS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, ): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -116,7 +116,7 @@ def parallel_nsa_compression_fwd_kernel( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, }) @triton.autotune( configs=[ @@ -149,12 +149,12 @@ def parallel_nsa_compression_bwd_kernel_dq( BS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -218,7 +218,7 @@ def parallel_nsa_compression_bwd_kernel_dq( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -252,12 +252,12 @@ def parallel_nsa_compression_bwd_kernel_dkv( BS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_c = tl.load(chunk_indices + i_c * 2).to(tl.int32), tl.load(chunk_indices + i_c * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos diff --git a/fla/ops/nsa/parallel.py b/fla/ops/nsa/parallel.py index 425b34fffe..98cdee2a17 100644 --- a/fla/ops/nsa/parallel.py +++ b/fla/ops/nsa/parallel.py @@ -28,7 +28,7 @@ @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -56,12 +56,12 @@ def parallel_nsa_kernel_topk( BC: tl.constexpr, BS: tl.constexpr, BK: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -160,7 +160,7 @@ def parallel_nsa_kernel_topk( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), }) @triton.autotune( @@ -192,13 +192,13 @@ def parallel_nsa_fwd_kernel( BS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, USE_BLOCK_COUNTS: tl.constexpr ): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -286,7 +286,7 @@ def parallel_nsa_kernel_mask( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor) }) @triton.autotune( @@ -321,13 +321,13 @@ def parallel_nsa_bwd_kernel_dq( BS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, USE_BLOCK_COUNTS: tl.constexpr ): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -394,7 +394,7 @@ def parallel_nsa_bwd_kernel_dq( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -428,12 +428,12 @@ def parallel_nsa_bwd_kernel_dkv( BS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_v, i_s, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_s = tl.load(chunk_indices + i_s * 2).to(tl.int32), tl.load(chunk_indices + i_s * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos diff --git a/fla/ops/rwkv6/chunk.py b/fla/ops/rwkv6/chunk.py index 5e790fac83..7047ac2c5e 100644 --- a/fla/ops/rwkv6/chunk.py +++ b/fla/ops/rwkv6/chunk.py @@ -17,7 +17,7 @@ @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -42,11 +42,11 @@ def chunk_rwkv6_fwd_cumsum_kernel( BT: tl.constexpr, BS: tl.constexpr, HEAD_FIRST: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, ): i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -106,7 +106,7 @@ def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H) @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -135,13 +135,13 @@ def chunk_rwkv6_fwd_A_kernel_intra_sub_inter( BC: tl.constexpr, BK: tl.constexpr, NC: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H i_i, i_j = i_c // NC, i_c % NC - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -194,7 +194,7 @@ def chunk_rwkv6_fwd_A_kernel_intra_sub_inter( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -223,13 +223,13 @@ def chunk_rwkv6_fwd_A_kernel_intra_sub_intra( BT: tl.constexpr, BC: tl.constexpr, BK: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H i_j = i_i - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -277,7 +277,7 @@ def chunk_rwkv6_fwd_A_kernel_intra_sub_intra( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -308,14 +308,14 @@ def chunk_rwkv6_fwd_A_kernel_intra_sub_intra_split( BC: tl.constexpr, BK: tl.constexpr, NC: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_k, i_tc, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H i_t, i_i = i_tc // NC, i_tc % NC i_j = i_i - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) all = T @@ -366,7 +366,7 @@ def chunk_rwkv6_fwd_A_kernel_intra_sub_intra_split( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -390,12 +390,12 @@ def chunk_rwkv6_fwd_A_kernel_intra_sub_intra_merge( BT: tl.constexpr, BC: tl.constexpr, NK: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) all = T @@ -424,7 +424,7 @@ def chunk_rwkv6_fwd_A_kernel_intra_sub_intra_merge( @triton.heuristics({ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -460,14 +460,14 @@ def chunk_rwkv6_bwd_kernel_dh( NG: tl.constexpr, STORE_INITIAL_STATE_GRADIENT: tl.constexpr, USE_FINAL_STATE_GRADIENT: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_bg = i_nh // NG i_n, i_hq = i_nh // HQ, i_nh % HQ i_h = i_hq // NG - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos NT = tl.cdiv(T, BT) @@ -521,7 +521,7 @@ def chunk_rwkv6_bwd_kernel_dh( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -549,13 +549,13 @@ def chunk_rwkv6_bwd_kernel_intra( BC: tl.constexpr, BK: tl.constexpr, NC: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H i_t, i_i = i_c // NC, i_c % NC - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) else: @@ -696,7 +696,7 @@ def chunk_rwkv6_bwd_kernel_intra( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -736,13 +736,13 @@ def chunk_rwkv6_bwd_kernel_inter( BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) diff --git a/fla/ops/rwkv6/fused_recurrent.py b/fla/ops/rwkv6/fused_recurrent.py index c95ec015ab..256d845c90 100644 --- a/fla/ops/rwkv6/fused_recurrent.py +++ b/fla/ops/rwkv6/fused_recurrent.py @@ -14,7 +14,7 @@ @triton.heuristics({ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -45,12 +45,12 @@ def fused_recurrent_rwkv6_fwd_kernel( REVERSE: tl.constexpr, # whether to reverse the recurrence USE_INITIAL_STATE: tl.constexpr, # whether to use initial state STORE_FINAL_STATE: tl.constexpr, # whether to store final state - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) all = T T = eos - bos @@ -107,7 +107,7 @@ def fused_recurrent_rwkv6_fwd_kernel( @triton.heuristics({ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -138,12 +138,12 @@ def fused_recurrent_rwkv6_bwd_kernel_dq( BV: tl.constexpr, REVERSE: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) all = T T = eos - bos @@ -205,7 +205,7 @@ def fused_recurrent_rwkv6_bwd_kernel_dq( @triton.heuristics({ 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -238,12 +238,12 @@ def fused_recurrent_rwkv6_bwd_kernel_dkv( BV: tl.constexpr, REVERSE: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, ): i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) all = T T = eos - bos @@ -312,7 +312,7 @@ def fused_recurrent_rwkv6_bwd_kernel_dkv( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -339,11 +339,11 @@ def fused_recurrent_rwkv6_bwd_kernel_dw( BK: tl.constexpr, REVERSE: tl.constexpr, HEAD_FIRST: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_k, i_nh = tl.program_id(0), tl.program_id(1) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) else: bos, eos = i_n * T, i_n * T + T diff --git a/fla/ops/simple_gla/parallel.py b/fla/ops/simple_gla/parallel.py index 48633bd8d2..468d4088ec 100644 --- a/fla/ops/simple_gla/parallel.py +++ b/fla/ops/simple_gla/parallel.py @@ -18,7 +18,7 @@ @triton.heuristics({ 'NV': lambda args: triton.cdiv(args['V'], args['BV']), 'OUTPUT_ATTENTIONS': lambda args: args['attn'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, 'USE_G': lambda args: args['g'] is not None }) @triton.autotune( @@ -52,16 +52,16 @@ def parallel_simple_gla_fwd_kernel( NV: tl.constexpr, OUTPUT_ATTENTIONS: tl.constexpr, HEAD_FIRST: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, USE_G: tl.constexpr ): - tl.static_assert(not (USE_OFFSETS and HEAD_FIRST), "USE_OFFSETS and HEAD_FIRST cannot be True at the same time") + tl.static_assert(not (IS_VARLEN and HEAD_FIRST), "IS_VARLEN and HEAD_FIRST cannot be True at the same time") i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_k, i_v = i_kv // NV, i_kv % NV i_b, i_h = i_bh // H, i_bh % H o += i_k * B * T * H * V - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -355,7 +355,7 @@ def parallel_simple_gla_bwd_kernel_dkv( @triton.heuristics({ 'NV': lambda args: triton.cdiv(args['V'], args['BV']), - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, 'USE_G': lambda args: args['g'] is not None }) @triton.autotune( @@ -389,11 +389,10 @@ def parallel_simple_gla_bwd_kernel( BK: tl.constexpr, BV: tl.constexpr, NV: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, USE_G: tl.constexpr ): - tl.static_assert(not (USE_OFFSETS and HEAD_FIRST), "USE_OFFSETS and HEAD_FIRST cannot be True at the same time") i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_k, i_v = i_kv // NV, i_kv % NV i_b, i_h = i_bh // H, i_bh % H @@ -403,7 +402,7 @@ def parallel_simple_gla_bwd_kernel( if USE_G: dg += i_kv * B * H * T - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos diff --git a/fla/ops/ttt/chunk.py b/fla/ops/ttt/chunk.py index 34ce2e4a92..c5679b4b7f 100755 --- a/fla/ops/ttt/chunk.py +++ b/fla/ops/ttt/chunk.py @@ -17,7 +17,7 @@ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, 'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None, 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, }) @triton.autotune( configs=[ @@ -54,12 +54,12 @@ def chunk_ttt_linear_fwd_kernel_h( USE_INITIAL_STATE: tl.constexpr, USE_INITIAL_STATE_B: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos NT = tl.cdiv(T, BT) @@ -132,7 +132,7 @@ def chunk_ttt_linear_fwd_kernel_h( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, }) @triton.autotune( configs=[ @@ -161,13 +161,13 @@ def chunk_ttt_linear_fwd_kernel_o( BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) @@ -226,7 +226,7 @@ def chunk_ttt_linear_fwd_kernel_o( @triton.heuristics({ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, 'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, }) @triton.autotune( configs=[ @@ -262,12 +262,12 @@ def chunk_ttt_linear_bwd_kernel_h( NT: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, USE_INITIAL_STATE_B: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos NT = tl.cdiv(T, BT) @@ -340,7 +340,7 @@ def chunk_ttt_linear_bwd_kernel_h( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, }) @triton.autotune( configs=[ @@ -366,12 +366,12 @@ def chunk_ttt_linear_bwd_kernel_dv_local( BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -415,7 +415,7 @@ def chunk_ttt_linear_bwd_kernel_dv_local( 'USE_FINAL_STATE_GRADIENT_B': lambda args: args['dhbt'] is not None, 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None, 'USE_INITIAL_STATE_B': lambda args: args['dhb0'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, }) @triton.autotune( configs=[ @@ -463,12 +463,12 @@ def chunk_ttt_linear_bwd_kernel_norm( USE_FINAL_STATE_GRADIENT_B: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, USE_INITIAL_STATE_B: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos NT = tl.cdiv(T, BT) @@ -591,7 +591,7 @@ def chunk_ttt_linear_bwd_kernel_norm( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, }) @triton.autotune( configs=[ @@ -625,12 +625,12 @@ def chunk_bwd_kernel_dqke( BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, ): i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) diff --git a/fla/ops/ttt/fused_chunk.py b/fla/ops/ttt/fused_chunk.py index d574910a18..0b4e986567 100755 --- a/fla/ops/ttt/fused_chunk.py +++ b/fla/ops/ttt/fused_chunk.py @@ -15,7 +15,7 @@ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, 'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None, 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, }) @triton.autotune( configs=[ @@ -51,13 +51,13 @@ def fused_chunk_ttt_linear_fwd_kernel( USE_INITIAL_STATE: tl.constexpr, USE_INITIAL_STATE_B: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr ): # indices i_nh = tl.program_id(0) i_n, i_h = i_nh // H, i_nh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos NT = tl.cdiv(T, BT) diff --git a/fla/ops/utils/cumsum.py b/fla/ops/utils/cumsum.py index c129883472..1dd19670ad 100644 --- a/fla/ops/utils/cumsum.py +++ b/fla/ops/utils/cumsum.py @@ -13,7 +13,7 @@ @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -32,12 +32,12 @@ def chunk_local_cumsum_scalar_kernel( H: tl.constexpr, BT: tl.constexpr, HEAD_FIRST: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, REVERSE: tl.constexpr ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -60,7 +60,7 @@ def chunk_local_cumsum_scalar_kernel( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -82,12 +82,12 @@ def chunk_local_cumsum_vector_kernel( BT: tl.constexpr, BS: tl.constexpr, HEAD_FIRST: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, REVERSE: tl.constexpr ): i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -113,7 +113,7 @@ def chunk_local_cumsum_vector_kernel( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -134,12 +134,12 @@ def chunk_global_cumsum_scalar_kernel( H: tl.constexpr, BT: tl.constexpr, HEAD_FIRST: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, REVERSE: tl.constexpr ): i_bh = tl.program_id(0) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_b).to(tl.int32), tl.load(offsets + i_b + 1).to(tl.int32) else: bos, eos = i_b * T, i_b * T + T @@ -167,7 +167,7 @@ def chunk_global_cumsum_scalar_kernel( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'IS_VARLEN': lambda args: args['offsets'] is not None, }) @triton.autotune( configs=[ @@ -188,12 +188,12 @@ def chunk_global_cumsum_vector_kernel( BT: tl.constexpr, BS: tl.constexpr, HEAD_FIRST: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, REVERSE: tl.constexpr ): i_s, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: bos, eos = tl.load(offsets + i_b).to(tl.int32), tl.load(offsets + i_b + 1).to(tl.int32) else: bos, eos = i_b * T, i_b * T + T diff --git a/fla/ops/utils/pooling.py b/fla/ops/utils/pooling.py index 0dd9059b4a..374c06ff7f 100644 --- a/fla/ops/utils/pooling.py +++ b/fla/ops/utils/pooling.py @@ -12,7 +12,7 @@ @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -34,11 +34,11 @@ def mean_pooling_fwd_kernel( BT: tl.constexpr, BD: tl.constexpr, NT: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) @@ -59,7 +59,7 @@ def mean_pooling_fwd_kernel( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -81,11 +81,11 @@ def mean_pooling_bwd_kernel( BT: tl.constexpr, BD: tl.constexpr, NT: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) diff --git a/fla/ops/utils/solve_tril.py b/fla/ops/utils/solve_tril.py index d0c2b66833..16324c713c 100644 --- a/fla/ops/utils/solve_tril.py +++ b/fla/ops/utils/solve_tril.py @@ -12,7 +12,7 @@ @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -31,12 +31,12 @@ def solve_tril_16x16_kernel( T, H: tl.constexpr, BT: tl.constexpr, - USE_OFFSETS: tl.constexpr, + IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -71,7 +71,7 @@ def solve_tril_16x16_kernel( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -79,7 +79,7 @@ def solve_tril_16x16_kernel( for num_warps in [1, 2, 4, 8] for num_stages in [2, 3, 4, 5] ], - key=['H', 'BT', 'HEAD_FIRST', 'USE_OFFSETS'], + key=['H', 'BT', 'HEAD_FIRST', 'IS_VARLEN'], ) @triton.jit(do_not_specialize=['T']) def merge_16x16_to_32x32_inverse_kernel( @@ -92,11 +92,11 @@ def merge_16x16_to_32x32_inverse_kernel( H: tl.constexpr, BT: tl.constexpr, HEAD_FIRST: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -133,7 +133,7 @@ def merge_16x16_to_32x32_inverse_kernel( @triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None + 'IS_VARLEN': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ @@ -141,7 +141,7 @@ def merge_16x16_to_32x32_inverse_kernel( for num_warps in [2, 4, 8] for num_stages in [2, 3, 4, 5] ], - key=['H', 'BT', 'HEAD_FIRST', 'USE_OFFSETS'], + key=['H', 'BT', 'HEAD_FIRST', 'IS_VARLEN'], ) @triton.jit(do_not_specialize=['T']) def merge_16x16_to_64x64_inverse_kernel( @@ -154,11 +154,11 @@ def merge_16x16_to_64x64_inverse_kernel( H: tl.constexpr, BT: tl.constexpr, HEAD_FIRST: tl.constexpr, - USE_OFFSETS: tl.constexpr + IS_VARLEN: tl.constexpr ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: + if IS_VARLEN: i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos @@ -316,6 +316,5 @@ def solve_tril( H=H, BT=BT, HEAD_FIRST=head_first, - USE_OFFSETS=cu_seqlens is not None ) return Ai