[Attn] Remove `head_first` & rename `offsets` to `cu_seqlens` (#345)
Caution: Review failed. The pull request is closed.

Walkthrough

This pull request primarily renames parameters and adjusts logic in the attention kernels: in the affected functions, `offsets` is renamed to `cu_seqlens` and `indices` to `chunk_indices`.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Caller as Caller
    participant Fwd as parallel_attn_fwd
    participant Kernel as parallel_attn_fwd_kernel
    Caller ->> Fwd: Invoke forward attention (q, k, v, g_cumsum, scale, cu_seqlens)
    Fwd ->> Kernel: Call kernel with cu_seqlens & chunk_indices
    Kernel ->> Kernel: Check IS_VARLEN & process variable-length sequences
    Kernel -->> Fwd: Return output tensor
    Fwd -->> Caller: Deliver final result
```
Actionable comments posted: 0
🧹 Nitpick comments (2)
fla/ops/attn/parallel.py (2)
716-719: Added deprecation warning for `head_first` parameter.

Good practice to warn users about the planned deprecation of the `head_first` parameter, allowing them to update their code before the feature is removed.

Consider adding a `stacklevel` argument to the warning to help users identify where in their code the warning is coming from:

```diff
- warnings.warn(
-     "head_first is deprecated and will be removed in a future version. "
-     "Please use head_first=False for now instead."
- )
+ warnings.warn(
+     "head_first is deprecated and will be removed in a future version. "
+     "Please use head_first=False for now instead.",
+     stacklevel=2
+ )
```

🧰 Tools

🪛 Ruff (0.8.2)

716-716: No explicit `stacklevel` keyword argument found (B028)
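To illustrate what `stacklevel` changes, here is a standalone sketch (not code from this PR) that captures the same warning emitted with and without `stacklevel=2` and compares the line each one is attributed to:

```python
import warnings

def deprecated_api(stacklevel):
    # stacklevel controls which stack frame the warning is attributed to:
    # 1 = this function's warn() call, 2 = the caller's call site.
    warnings.warn("head_first is deprecated", DeprecationWarning, stacklevel=stacklevel)

def caller(stacklevel):
    # Capture the emitted warning so its attributed location can be inspected.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        deprecated_api(stacklevel)
    return caught[0]

w1 = caller(1)  # attributed inside deprecated_api
w2 = caller(2)  # attributed to the call site in caller()
print(w1.lineno, w2.lineno)  # two different line numbers
```

With `stacklevel=2`, the warning points at the user's own call site rather than at the library internals, which is exactly what Ruff's B028 check encourages.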
723-729: Added warning for potential tensor shape mismatch.

Great addition that helps users identify potential misuse of the API when the input tensor dimensions suggest a format mismatch.

Consider adding a `stacklevel` argument here as well for better error reporting:

```diff
- warnings.warn(
-     f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
-     "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
-     "when head_first=False was specified. "
-     "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
- )
+ warnings.warn(
+     f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
+     "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+     "when head_first=False was specified. "
+     "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+     stacklevel=2
+ )
```

🧰 Tools

🪛 Ruff (0.8.2)

724-724: No explicit `stacklevel` keyword argument found (B028)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- `fla/ops/attn/parallel.py` (22 hunks)
- `tests/ops/test_attn.py` (5 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (2)
tests/ops/test_attn.py (2)
fla/ops/common/utils.py (1)
- `prepare_lens` (35-36)

fla/ops/utils/testing.py (1)
- `assert_close` (17-29)

fla/ops/attn/parallel.py (3)

fla/utils.py (1)
- `check_shared_mem` (199-205)

fla/ops/common/utils.py (1)
- `prepare_chunk_indices` (56-61)

fla/ops/utils/cumsum.py (1)
- `chunk_global_cumsum` (364-382)
🪛 Ruff (0.8.2)
fla/ops/attn/parallel.py
716-716: No explicit `stacklevel` keyword argument found (B028)

724-724: No explicit `stacklevel` keyword argument found (B028)
⏰ Context from checks skipped due to timeout of 90000ms (2)
- GitHub Check: test
- GitHub Check: test
🔇 Additional comments (16)
tests/ops/test_attn.py (5)
34-40: Code style consistency improvement.

The parameterization decorators have been updated to use single quotes consistently for parameter names, which improves style consistency throughout the codebase.

107-111: Improved parameter naming for better semantics.

Renaming `offsets` to `cu_seqlens` provides better clarity about the parameter's purpose: representing cumulative sequence lengths in variable-length sequences.
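For concreteness, cumulative sequence lengths can be built from per-sequence lengths with a prefix sum. This small plain-Python sketch (not code from the PR; `build_cu_seqlens` is a hypothetical helper) mirrors the FlashAttention-style convention:

```python
from itertools import accumulate

def build_cu_seqlens(lengths):
    # FlashAttention-style convention: cu_seqlens has len(lengths) + 1
    # entries, starting at 0; sequence i occupies [cu_seqlens[i], cu_seqlens[i+1]).
    return [0] + list(accumulate(lengths))

cu_seqlens = build_cu_seqlens([3, 5, 2])
print(cu_seqlens)  # [0, 3, 8, 10]

# Per-sequence lengths are recovered as adjacent differences:
lens = [b - a for a, b in zip(cu_seqlens, cu_seqlens[1:])]
print(lens)  # [3, 5, 2]
```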
122-126: Consistent parameter naming with implementation.

The parameter names in the `flash_attn_varlen_func` call have been updated to match the renamed parameters in the implementation file, ensuring consistency across the codebase.

137-138: Aligned parameter naming with implementation.

Updated the parameter name from `offsets` to `cu_seqlens` in the `parallel_attn` call, ensuring consistency with the implementation.

144-147: Improved formatting in assertion messages.

Removed unnecessary whitespace in assertion messages, making the code cleaner.
fla/ops/attn/parallel.py (11)
4-4: Added warnings import for deprecation notices.

Good addition to support the new warning messages about deprecated functionality.

19-21: Updated variable length condition check.

The heuristic for checking whether variable-length sequences are in use has been updated to reference the renamed parameter (`cu_seqlens` instead of `offsets`).

28-28: Improved autotune configuration keys.

Added `USE_G` and `IS_VARLEN` to the autotune configuration key list, which will help the tuner better differentiate between different use cases.

39-41: Consistent parameter renaming in kernel function signature.

Parameters `offsets` and `indices` have been renamed to `cu_seqlens` and `chunk_indices` respectively, providing better clarity about their purpose.

60-62: Updated variable loading to use renamed parameters.

The variable loading code has been updated to reference the renamed parameters, maintaining functionality while improving readability.

195-197: Consistent parameter renaming in backward kernel functions.

The parameter renaming has been consistently applied to both forward and backward kernel function signatures, maintaining a unified naming convention throughout the codebase.

Also applies to: 333-335

489-491: Updated variable initialization with renamed parameters.

Function calls and variable initializations have been updated to use the renamed parameters, ensuring consistency throughout the codebase.

526-527: Modified floating-point dtype spelling in backward preprocessing.

Changed `torch.float32` to `torch.float`. Note that `torch.float` is an alias for `torch.float32`, so this is a cosmetic simplification rather than a behavioral change.

571-573: Simplified empty tensor initialization.

The code now directly initializes empty tensors with the appropriate dtype, without additional `empty_like` and cast operations.

730-734: Improved function parameter handling.

The code now handles scale calculation and input validation more clearly, with better spacing and assertions.

679-679: Updated return statement.

The function return statement has been updated to match the new parameter naming convention.
Actionable comments posted: 0
🧹 Nitpick comments (1)
fla/ops/gsa/fused_recurrent.py (1)
512-517: Consider removing deprecated `head_first` parameter

The PR title mentions removing `head_first`, but it's only deprecated with a warning. Consider removing it completely in this PR or creating a follow-up PR for full removal.

🧰 Tools

🪛 Ruff (0.8.2)

513-513: No explicit `stacklevel` keyword argument found (B028)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (23)
- `fla/layers/rwkv6.py` (1 hunks)
- `fla/ops/gla/chunk.py` (56 hunks)
- `fla/ops/gla/fused_chunk.py` (0 hunks)
- `fla/ops/gsa/chunk.py` (42 hunks)
- `fla/ops/gsa/fused_recurrent.py` (12 hunks)
- `fla/ops/hgrn/fused_recurrent.py` (11 hunks)
- `fla/ops/linear_attn/fused_chunk.py` (0 hunks)
- `fla/ops/linear_attn/fused_recurrent.py` (0 hunks)
- `fla/ops/retention/fused_chunk.py` (0 hunks)
- `fla/ops/rwkv6/chunk.py` (57 hunks)
- `fla/ops/rwkv6/fused_recurrent.py` (21 hunks)
- `fla/ops/simple_gla/chunk.py` (11 hunks)
- `fla/ops/simple_gla/parallel.py` (17 hunks)
- `tests/ops/test_based.py` (1 hunks)
- `tests/ops/test_gla.py` (4 hunks)
- `tests/ops/test_gsa.py` (8 hunks)
- `tests/ops/test_hgrn.py` (6 hunks)
- `tests/ops/test_linear_attn.py` (4 hunks)
- `tests/ops/test_retention.py` (4 hunks)
- `tests/ops/test_rwkv6.py` (7 hunks)
- `tests/ops/test_simple_gla.py` (7 hunks)
- `tests/ops/test_solve_tril.py` (3 hunks)
- `tests/ops/test_utils.py` (8 hunks)
💤 Files with no reviewable changes (4)
- fla/ops/linear_attn/fused_chunk.py
- fla/ops/gla/fused_chunk.py
- fla/ops/retention/fused_chunk.py
- fla/ops/linear_attn/fused_recurrent.py
✅ Files skipped from review due to trivial changes (9)
- tests/ops/test_gla.py
- tests/ops/test_based.py
- fla/layers/rwkv6.py
- tests/ops/test_linear_attn.py
- tests/ops/test_gsa.py
- tests/ops/test_hgrn.py
- tests/ops/test_utils.py
- tests/ops/test_simple_gla.py
- tests/ops/test_retention.py
🧰 Additional context used
🧬 Code Graph Analysis (6)
fla/ops/simple_gla/chunk.py (1)

fla/ops/utils/cumsum.py (1)
- `chunk_local_cumsum` (386-406)

tests/ops/test_rwkv6.py (2)

fla/ops/rwkv6/fused_recurrent.py (1)
- `fused_recurrent_rwkv6` (556-681)

fla/ops/rwkv6/chunk.py (1)
- `chunk_rwkv6` (1188-1308)

fla/ops/hgrn/fused_recurrent.py (8)

fla/ops/gla/fused_chunk.py (1)
- `backward` (470-595)

fla/ops/gla/chunk.py (1)
- `backward` (1197-1215)

fla/ops/rwkv6/fused_recurrent.py (1)
- `backward` (537-553)

fla/ops/rwkv6/chunk.py (1)
- `backward` (1167-1184)

fla/ops/delta_rule/chunk.py (1)
- `backward` (194-221)

fla/ops/gated_delta_rule/chunk.py (1)
- `backward` (205-232)

fla/ops/generalized_delta_rule/dplr/chunk.py (1)
- `backward` (135-268)

fla/ops/abc/chunk.py (1)
- `backward` (950-1074)

fla/ops/rwkv6/chunk.py (1)

fla/ops/common/utils.py (2)
- `prepare_chunk_indices` (59-64)
- `prepare_chunk_offsets` (68-72)

fla/ops/simple_gla/parallel.py (2)

fla/ops/common/utils.py (1)
- `prepare_chunk_indices` (59-64)

fla/ops/utils/cumsum.py (2)
- `chunk_local_cumsum` (386-406)
- `chunk_global_cumsum` (364-382)

fla/ops/gla/chunk.py (2)

fla/ops/common/utils.py (1)
- `prepare_chunk_indices` (59-64)

fla/ops/utils/cumsum.py (1)
- `chunk_local_cumsum` (386-406)
🪛 Ruff (0.8.2)
fla/ops/rwkv6/chunk.py
1272-1272: No explicit `stacklevel` keyword argument found (B028)
🔇 Additional comments (60)
fla/ops/hgrn/fused_recurrent.py (7)
17-17: LGTM: Parameter name consistently renamed

The parameter renaming from `offsets` to `cu_seqlens` in the heuristic lambda function is consistent with the PR objective.

34-34: LGTM: Parameter renaming in forward kernel

The parameter and its usage have been consistently renamed from `offsets` to `cu_seqlens` in the forward kernel, maintaining the same functionality while aligning with standard terminology.

Also applies to: 44-44

78-78: LGTM: Parameter renaming in backward kernel

The parameter and its usage have been consistently renamed from `offsets` to `cu_seqlens` in the backward kernel definition and implementation.

Also applies to: 98-98, 108-108

160-160: LGTM: Parameter renaming in `fused_recurrent_hgrn_fwd` function

The parameter has been consistently renamed from `offsets` to `cu_seqlens` in the function signature and its usage in calculating `N` and passing to the kernel.

Also applies to: 163-163, 175-175

188-188: LGTM: Parameter renaming in `fused_recurrent_hgrn_bwd` function

The parameter has been consistently renamed from `offsets` to `cu_seqlens` in the backward function signature and its usage.

Also applies to: 191-191, 206-206

223-223: LGTM: Parameter renaming in `FusedRecurrentHGRNFunction`

The parameter has been consistently renamed from `offsets` to `cu_seqlens` in the Function class, both in the forward method's signature and in context saving/retrieving for the backward pass.

Also applies to: 230-230, 233-233, 240-240, 248-248

259-308: LGTM: Documentation updated correctly

The docstring has been updated to reference `cu_seqlens` instead of `offsets` and provides a clear explanation of its purpose for handling variable-length sequences. The usage example is also updated correctly.

fla/ops/gsa/fused_recurrent.py (3)

154-154: Consistent parameter renaming from `offsets` to `cu_seqlens`

The parameter has been renamed consistently throughout the file. This improves clarity, as `cu_seqlens` more accurately describes the purpose of this parameter (cumulative sequence lengths), aligning with the naming convention used in the FlashAttention API as mentioned in the docstring.

Also applies to: 161-161, 189-189, 219-219, 252-252, 255-255, 282-282, 299-299, 321-321, 339-339, 364-364, 389-389, 394-394, 404-404, 425-425
467-472: Well-documented parameter

The documentation for `cu_seqlens` is clear and informative, explaining both its purpose and format. It also maintains consistency with the FlashAttention API.
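To make the documented format concrete, a packed batch of total length `cu_seqlens[-1]` can be split back into its individual sequences. This is an illustrative plain-Python sketch with lists standing in for tensors; the helper `split_by_cu_seqlens` is hypothetical, not part of the library:

```python
def split_by_cu_seqlens(packed, cu_seqlens):
    # packed: tokens of all sequences concatenated along the first dimension
    # cu_seqlens: [0, l0, l0+l1, ...] cumulative sequence boundaries
    return [packed[s:e] for s, e in zip(cu_seqlens, cu_seqlens[1:])]

packed = list(range(10))      # 10 tokens total
cu_seqlens = [0, 3, 8, 10]    # three sequences of lengths 3, 5, 2
print(split_by_cu_seqlens(packed, cu_seqlens))
# [[0, 1, 2], [3, 4, 5, 6, 7], [8, 9]]
```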
519-524: Improved warning message

The warning message provides clear guidance about potential format mismatches, which helps users identify and fix issues with their input tensors.

🧰 Tools

🪛 Ruff (0.8.2)

519-519: No explicit `stacklevel` keyword argument found (B028)
fla/ops/simple_gla/chunk.py (4)
25-25: Renamed parameter for better semantic clarity

The parameter name change from `offsets` to `cu_seqlens` provides better clarity about its purpose: representing cumulative sequence lengths.

28-28: Updated function call argument to match renamed parameter

Properly updated the `chunk_local_cumsum` call to use the newly renamed parameter.

173-176: Modified cumsum operation and removed redundant type conversion

The code now passes `cu_seqlens` to the `chunk_local_cumsum` function and removes the `.to(g)` conversion for `dg`. The type conversion is now handled directly in the return statement at line 178, which is clearer and more efficient.

178-178: Improved return statement consistency

The return statement maintains proper type conversions for other tensors while directly returning `dg` without an extra conversion step.

fla/ops/gsa/chunk.py (3)
21-22: Updated heuristic condition for variable-length sequences

The lambda function now checks for `cu_seqlens` instead of `offsets`, maintaining the same functionality while using the more semantically clear parameter name.

557-558: Updated chunk index preparation for consistency

The code now correctly uses `prepare_chunk_indices` with the renamed `cu_seqlens` parameter, maintaining consistent naming across the codebase.

793-793: Updated cumsum operation with renamed parameter

The `chunk_local_cumsum` call now uses the `cu_seqlens` parameter name, consistent with the parameter naming changes throughout the file.

fla/ops/rwkv6/chunk.py (4)
13-13: Import statement updated to include additional utility

The import statement now includes `prepare_chunk_offsets`, which is used alongside `prepare_chunk_indices` for handling variable-length sequences.

80-81: Updated chunk preparation logic

The code now uses `prepare_chunk_indices` with the renamed `cu_seqlens` parameter, maintaining consistent naming throughout the codebase.

1189-1193: Renamed parameter from `g` to `w`

The function parameter has been renamed from `g` to `w` for clarity. This matches the docstring description at lines 1208-1209, which describes this parameter as "Forget gates".

1270-1271: Updated parameter references in conditional mapping

The code correctly updates the parameter reference in the `map` function call to use the new parameter name `w` instead of `g`.

fla/ops/simple_gla/parallel.py (7)
33-34: Updated heuristic conditions for better clarity

The heuristic conditions now check for `cu_seqlens` and properly format the condition for `USE_G`. This provides better semantic clarity and maintains consistent parameter naming.
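Triton-style heuristics of this shape map a dict of kernel arguments to constexpr flags. A hedged plain-Python sketch of the pattern (ordinary dicts standing in for the real kernel arguments; the flag and argument names follow the ones mentioned in this review):

```python
# Heuristic lambdas evaluated against the kernel's keyword arguments,
# mirroring the IS_VARLEN / USE_G pattern under review.
heuristics = {
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
    'USE_G': lambda args: args['g_cumsum'] is not None,
}

args = {'cu_seqlens': [0, 3, 8], 'g_cumsum': None}
flags = {name: fn(args) for name, fn in heuristics.items()}
print(flags)  # {'IS_VARLEN': True, 'USE_G': False}
```

Keying the autotuner on these flags, as the review notes, lets each combination of variable-length and gated configurations be tuned separately.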
52-53: Renamed parameters in kernel function signature

The parameters `offsets` and `indices` have been renamed to `cu_seqlens` and `chunk_indices` respectively, for better semantic clarity and consistency across the codebase.

512-513: Updated chunk preparation logic

The code now uses `prepare_chunk_indices` with the renamed `cu_seqlens` parameter, maintaining consistent naming throughout the codebase.

517-517: Updated cumsum function call with renamed parameter

The `chunk_local_cumsum` function call now uses the new parameter name `cu_seqlens`, maintaining consistency with the parameter naming changes.

622-622: Updated method signature for consistency

The `forward` method signature now uses `cu_seqlens` instead of `offsets`, maintaining consistent parameter naming across the codebase.

636-636: Updated context saving with renamed tensor

The saved tensors for the backward pass now include `cu_seqlens` instead of `offsets`, matching the parameter renaming.

613-613: Updated global cumsum function call

The `chunk_global_cumsum` function call now correctly uses the renamed `cu_seqlens` parameter, maintaining consistency with other parameter naming changes.

fla/ops/gla/chunk.py (10)
22-24: Approved: Parameter renaming in heuristic lambda function.

The renaming of `offsets` to `cu_seqlens` in the heuristic lambda function is consistent with the PR objective and makes the code more descriptive, as it clearly indicates "cumulative sequence lengths."

95-97: Approved: Parameter renaming in heuristic lambda function.

Consistent renaming of `offsets` to `cu_seqlens` in another heuristic lambda function.

160-162: Approved: Consistent parameter renaming in remaining heuristic functions.

The renaming has been consistently applied across all IS_VARLEN heuristic lambda functions in the file.

Also applies to: 231-233, 279-281, 357-359, 493-495, 616-618

35-42: Approved: Function signatures updated with new parameter name.

Function signatures have been updated to use `cu_seqlens` instead of `offsets` across multiple functions, maintaining consistent naming throughout the codebase.

Also applies to: 108-115, 173-180, 244-249, 291-301, 368-376, 629-644

55-57: Approved: Consistent variable handling in kernel implementations.

The loading of sequence boundaries from `cu_seqlens` has been correctly implemented in all kernel functions, preserving the original functionality.

Also applies to: 128-129, 196-197, 260-261, 314-316, 523-524, 577-578, 658-659

726-734: Approved: Function parameter and calls updated.

The function signature for `chunk_gla_fwd_intra_gk` has been updated along with the proper usage of `cu_seqlens` in the function body. This includes the call to `prepare_chunk_indices`, which now receives the renamed parameter.

827-829: Approved: Consistent parameter renaming in helper functions.

Parameter renaming has been consistently applied to all helper functions in the file.

Also applies to: 862-863, 896-897, 932-933, 977-978

1008-1020: Approved: Renaming in main function signatures and class methods.

The main function signatures and class methods have been properly updated to use the new parameter name.

Also applies to: 1062-1076, 1166-1168, 1199-1200

1024-1025: Approved: Parameter renaming in function calls.

All function calls have been updated to use the new parameter name, ensuring consistent usage throughout the file.

Also applies to: 1036-1037, 1046-1047, 1056-1057, 1080-1081, 1091-1092, 1106-1107, 1117-1119, 1126-1128, 1134-1136, 1148-1150, 1181-1182, 1212-1214
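Many of these call sites go through `prepare_chunk_indices`. Its behavior can be approximated in plain Python as follows; this is a sketch inferred from the call sites discussed here (one `(sequence, chunk)` pair per chunk of up to `chunk_size` tokens), not a transcription of the actual implementation in `fla/ops/common/utils.py`:

```python
import math

def prepare_chunk_indices(cu_seqlens, chunk_size):
    # For each sequence, emit (seq_idx, chunk_idx) for every chunk of up to
    # chunk_size tokens; the last chunk of a sequence may be partial.
    indices = []
    for i, (s, e) in enumerate(zip(cu_seqlens, cu_seqlens[1:])):
        for c in range(math.ceil((e - s) / chunk_size)):
            indices.append((i, c))
    return indices

print(prepare_chunk_indices([0, 3, 8, 10], chunk_size=4))
# [(0, 0), (1, 0), (1, 1), (2, 0)]
```

A flat list like this lets a 1D kernel grid cover all chunks of all variable-length sequences without padding to the longest sequence.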
732-733: Details

✅ Verification successful

Verify sequential logic consistency with new parameter name.

The updated sequential statements may have introduced a dependency that requires evaluation. Ensure that both the conditional statement and the assignment maintain the original intent.

🏁 Script executed:

```bash
#!/bin/bash
# Find all instances where prepare_chunk_indices is called
echo "Searching for all calls to prepare_chunk_indices"
rg -p "prepare_chunk_indices\(.*\)" --context 3

# Check if the function is defined in the utils file with cu_seqlens as the first parameter
echo "Checking the signature of prepare_chunk_indices"
rg -p "def prepare_chunk_indices" --context 5
```

Length of output: 22972

Sequential Logic Verified – No Further Action Required

The analysis confirms that the update to use the new parameter name is consistent. In the sequential branch (when `cu_seqlens` is `None`), `NT` is computed using `BT` (derived from `chunk_size`), while for variable-length cases the function is called with `chunk_size` as expected per its signature in `fla/ops/common/utils.py`. The logic across the codebase is consistent with this change.

tests/ops/test_solve_tril.py (5)
25-28: Approved: Updated string delimiter style in test parametrizations.

The quotes have been changed from double to single quotes in the pytest parametrize decorators. This is a style change that doesn't affect functionality.

49-50: Approved: Updated string delimiter in assert_close call.

Changed from double to single quotes for consistency.

52-54: Approved: Renamed parameter and updated string delimiters.

The parameter name `offsets` has been changed to `cu_seqlens` in the test function parametrize decorator, consistent with the changes in the implementation code.

30-32: Approved: Updated string delimiters in skipif decorators.

Changed from double to single quotes in the test skip conditions for consistency.

Also applies to: 55-58, 60-62

80-81: Approved: Updated string delimiter in final assert_close call.

Changed from double to single quotes for consistency.
tests/ops/test_rwkv6.py (6)
29-34: Approved: Updated string delimiter style in test parametrizations.

Changed from double to single quotes in the pytest parametrize decorators for style consistency.

54-58: Simplified tensor initialization.

The tensor initialization has been simplified to consistently use the `(B, T, H, D)` shape for all tensors, removing conditional logic that previously changed the shape based on `head_first`.

64-82: Approved: Simplified function calls to fused_recurrent_rwkv6.

The function calls have been simplified to directly pass tensor values without conditional reshaping, making the code more readable and maintainable.

92-100: Approved: Simplified function calls to chunk_rwkv6.

Similar simplification has been applied to the `chunk_rwkv6` function calls, making them consistent with the `fused_recurrent_rwkv6` calls.

138-142: Renamed `offsets` to `cu_seqlens`.

The variable name has been updated from `offsets` to `cu_seqlens` to align with the consistent naming convention used across the codebase.

160-161: Approved: Updated function parameter in test calls.

Function calls in the variable-length sequence tests have been updated to use `cu_seqlens` instead of `offsets`.

Also applies to: 170-171, 188-189
fla/ops/rwkv6/fused_recurrent.py (11)
17-20: Approved: Updated heuristic lambda function.

The parameter name in the heuristic lambda function has been renamed from `offsets` to `cu_seqlens` for better descriptiveness.

38-39: Approved: Updated kernel function signatures.

All kernel function signatures have been updated to use the new parameter name consistently.

Also applies to: 124-125, 215-216, 307-308

55-56: Approved: Updated internal variable references.

Variable references within the kernel functions have been updated to load data from `cu_seqlens` instead of `offsets`.

Also applies to: 140-141, 231-232, 320-321

102-105: Approved: Updated remaining heuristic functions.

All heuristic lambda functions have been updated with the new parameter name.

Also applies to: 192-194, 289-290

362-363: Approved: Updated helper function signatures and calls.

Helper function signatures and their calls have been updated with the new parameter name.

Also applies to: 383-384, 408-409, 429-430, 463-464, 486-487

365-366: Approved: Updated sequential logic.

The variable `N` is now determined based on `cu_seqlens` instead of `offsets`, maintaining the original logic.

Also applies to: 411-412

514-515: Approved: Updated autograd function class.

The `FusedRecurrentRWKV6Function` class has been updated to use the new parameter name in both forward and backward methods.

Also applies to: 526-527, 550-551

531-532: Approved: Updated context saving.

The saved context variable name has been updated to match the new parameter name.

566-567: Approved: Updated public API function signature.

The public-facing function signature has been updated with the new parameter name.

593-596: Approved: Updated docstring description.

The docstring has been updated to describe `cu_seqlens` consistently with its usage in the code.

676-678: Approved: Updated function application in public API.

The function application has been updated to use the new parameter name when passing to the autograd function.
* [Attn] Remove `head_first` & rename `offsets` to `cu_seqlens`
* Delete 256 headdim tests
* [DeltaNet] Remove `head_first` & rename `offsets` to `cu_seqlens`
* Rename `offsets` to `cu_seqlens` across GLA/GSA/RWKV6
* Fix NSA tests
* Fix TTT checks
* Rename `offsets` to `cu_seqlens` in DPLR and IPLR implementations, updating related logic and heuristics for variable-length sequences.
Summary by CodeRabbit

Refactor
- Renamed `offsets` to `cu_seqlens` and `indices` to `chunk_indices`.

Tests
- … `head_first` parameter.