
[Attn] Remove head_first & rename offsets to cu_seqlens#345

Merged
yzhangcs merged 7 commits into main from head_first on Apr 10, 2025

Conversation

@yzhangcs (Member) commented Apr 10, 2025

Summary by CodeRabbit

  • Refactor

    • Enhanced the attention mechanism’s API for variable-length inputs by standardizing parameter naming from offsets to cu_seqlens and indices to chunk_indices.
    • Introduced runtime warnings to alert users about deprecated options and potential input shape mismatches.
  • Tests

    • Updated test parameters and variable names for consistency and clarity in the attention tests, including stylistic changes to string delimiters.
    • Adjusted the logic in test functions to reflect the new parameter names and ensure uniform tensor shapes, removing unnecessary conditionals related to the head_first parameter.
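For readers unfamiliar with the convention, cu_seqlens follows the FlashAttention variable-length API: a 1-D integer tensor of cumulative sequence lengths with a leading zero. A minimal pure-Python sketch of the convention (in the library itself this would be an int32 torch tensor on the same device as the inputs):

```python
from itertools import accumulate

# Hypothetical example: three packed sequences of lengths 3, 5, and 2.
seqlens = [3, 5, 2]

# cu_seqlens is the prefix sum with a leading zero, so sequence i
# spans token positions [cu_seqlens[i], cu_seqlens[i + 1]).
cu_seqlens = [0] + list(accumulate(seqlens))
print(cu_seqlens)  # [0, 3, 8, 10]
```

The old name offsets described the same tensor; the rename aligns the API with the name used by flash_attn_varlen_func and similar interfaces.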

@coderabbitai bot (Contributor) commented Apr 10, 2025

Caution

Review failed

The pull request is closed.

Walkthrough

This pull request primarily renames parameters and adjusts logic in the attention kernels. In the affected functions, offsets and indices have been renamed to cu_seqlens and chunk_indices respectively. The autotuning logic and warning messages have been updated accordingly, and a minor data type adjustment has been made in the preprocessing function. Additionally, tests have been updated for consistent parameterization and naming.

Changes

File(s) Change Summary
fla/.../attn/parallel.py Updated function signatures in parallel_attn_fwd_kernel, parallel_attn_bwd_kernel_dq, and parallel_attn_bwd_kernel_dkv by renaming offsets to cu_seqlens and indices to chunk_indices. Autotuning conditions now check for cu_seqlens and include IS_VARLEN. Added warnings for deprecated usage of head_first and adjusted the data type for delta in preprocessing.
tests/ops/test_attn.py Revised parameterization syntax in test decorators (double quotes to single quotes) and renamed the variable offsets to cu_seqlens in the variable-length tests. Reformatted assertion messages for consistency.
fla/.../common/chunk_delta_h.py Renamed offsets to cu_seqlens in multiple functions and updated logic for handling variable-length sequences.
fla/.../common/chunk_h.py Replaced offsets with cu_seqlens in function signatures and logic across several functions.
fla/.../common/chunk_h_parallel.py Updated parameter names from offsets to cu_seqlens in multiple functions and heuristic conditions.
fla/.../common/chunk_h_split.py Updated parameter names from offsets to cu_seqlens in multiple functions and heuristic conditions.
fla/.../common/chunk_o.py Renamed offsets to cu_seqlens and indices to chunk_indices in several functions, affecting how variable-length sequences are processed.
fla/.../common/chunk_scaled_dot_kkt.py Updated parameters in chunk_scaled_dot_kkt_fwd_kernel and chunk_scaled_dot_kkt_fwd functions, replacing offsets and indices with cu_seqlens and chunk_indices.
fla/.../common/fused_recurrent.py Renamed offsets to cu_seqlens in multiple functions and updated internal logic accordingly.
fla/.../utils/pooling.py Updated mean_pooling_fwd_kernel and mean_pooling_bwd_kernel to use cu_seqlens and chunk_indices instead of offsets and indices.
fla/.../utils/solve_tril.py Changed offsets and indices to cu_seqlens and chunk_indices in several kernel functions, affecting how sequences are processed.
tests/ops/test_delta.py Modified string delimiters in test decorators and assertions from double quotes to single quotes for consistency.
tests/ops/test_gated_delta.py Updated string delimiters in test decorators and assertions from double quotes to single quotes without changing functionality.
tests/ops/test_hgrn.py Changed variable name from offsets to cu_seqlens in the test_fused_recurrent_varlen function, affecting how input sequences are handled.
fla/.../nsa/compression.py Renamed offsets to cu_seqlens in multiple functions and updated logic for handling variable-length sequences.
fla/.../nsa/naive.py Renamed indices to block_indices in the naive_nsa function and updated internal logic accordingly.
fla/.../nsa/parallel.py Updated multiple functions to replace offsets with cu_seqlens and adjusted logic for handling variable-length sequences.
fla/.../ttt/chunk.py Renamed offsets to cu_seqlens in multiple functions and updated logic for handling variable-length sequences.
fla/.../ttt/fused_chunk.py Renamed offsets to cu_seqlens in multiple functions and updated logic for handling variable-length sequences.

Sequence Diagram(s)

sequenceDiagram
    participant Caller as Caller
    participant Fwd as parallel_attn_fwd
    participant Kernel as parallel_attn_fwd_kernel

    Caller ->> Fwd: Invoke forward attention (q, k, v, g_cumsum, scale, cu_seqlens)
    Fwd ->> Kernel: Call kernel with cu_seqlens & chunk_indices
    Kernel ->> Kernel: Check IS_VARLEN & process variable-length sequences
    Kernel -->> Fwd: Return output tensor
    Fwd -->> Caller: Deliver final result

Possibly related PRs

  • [DeltaNet] WY repr speedup #279: Both PRs rename offsets to cu_seqlens in functions that process variable-length sequences and update the corresponding condition checks and data loading.
  • [WIP] Delete head_first option for cumsum #342: Both PRs remove the head_first parameter and its associated logic for handling variable-length sequences.
  • [WY representation] Faster lower triangle inverse #289: Both PRs rename parameters and standardize how offsets and indices are handled for variable-length sequences.

Poem

I'm a rabbit, hopping through lines of code,
With cu_seqlens and chunk_indices in tow.
The kernels now dance with newfound grace,
A streamlined sequence in every place,
Leaping into a future where changes glow! 🐰💻



@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
fla/ops/attn/parallel.py (2)

716-719: Added deprecation warning for head_first parameter.

Good practice to warn users about the planned deprecation of the head_first parameter, allowing them to update their code before the feature is removed.

Consider adding a stacklevel parameter to the warning to help users identify where in their code the warning is coming from:

-        warnings.warn(
-            "head_first is deprecated and will be removed in a future version. "
-            "Please use head_first=False for now instead."
-        )
+        warnings.warn(
+            "head_first is deprecated and will be removed in a future version. "
+            "Please use head_first=False for now instead.",
+            stacklevel=2
+        )
🧰 Tools
🪛 Ruff (0.8.2)

716-716: No explicit stacklevel keyword argument found

(B028)
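To illustrate why Ruff flags B028 (a generic sketch, not code from this PR): without an explicit stacklevel, the warning is attributed to the warnings.warn line inside the library, whereas stacklevel=2 attributes it to the caller's line, which is what users need to locate the offending call in their own code:

```python
import warnings

def deprecated_api():
    # stacklevel=2 makes Python report the *caller's* file and line
    # number, rather than this line inside the library.
    warnings.warn(
        "deprecated_api is deprecated; use its replacement instead.",
        DeprecationWarning,
        stacklevel=2,
    )

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    deprecated_api()  # with stacklevel=2, the warning points here

print(caught[0].category.__name__)  # DeprecationWarning
```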


723-729: Added warning for potential tensor shape mismatch.

Great addition that helps users identify potential misuse of the API when the input tensor dimensions suggest a format mismatch.

Consider adding a stacklevel parameter here as well for better error reporting:

-        warnings.warn(
-            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
-            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
-            "when head_first=False was specified. "
-            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
-        )
+        warnings.warn(
+            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
+            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+            "when head_first=False was specified. "
+            "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+            stacklevel=2
+        )
🧰 Tools
🪛 Ruff (0.8.2)

724-724: No explicit stacklevel keyword argument found

(B028)
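A hedged reconstruction of the heuristic behind this warning (the function name and exact message here are illustrative, not the library's code): with head_first=False, inputs are expected as [B, T, H, D], so seq_len < num_heads is treated as a cheap hint that a head-first [B, H, T, D] tensor was passed by mistake:

```python
import warnings

def check_layout(shape):
    """Warn if a [B, T, H, D] tensor shape looks head-first.

    `shape` is a (B, T, H, D) tuple; the heuristic fires only when the
    sequence dimension is shorter than the head dimension, which is
    unusual for genuine [B, T, H, D] inputs.
    """
    B, T, H, D = shape
    if T < H:
        warnings.warn(
            f"seq_len ({T}) < num_heads ({H}); inputs may be in "
            "head-first [B, H, T, ...] format.",
            stacklevel=2,
        )

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    check_layout((2, 4, 16, 64))     # T=4 < H=16: suspicious, warns
    check_layout((2, 1024, 16, 64))  # typical shape: no warning

print(len(caught))  # 1
```

The check can misfire on very short sequences with many heads, which is why it is a warning rather than an error.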

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8e05790 and 9a46410.

📒 Files selected for processing (2)
  • fla/ops/attn/parallel.py (22 hunks)
  • tests/ops/test_attn.py (5 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (2)
tests/ops/test_attn.py (2)
fla/ops/common/utils.py (1)
  • prepare_lens (35-36)
fla/ops/utils/testing.py (1)
  • assert_close (17-29)
fla/ops/attn/parallel.py (3)
fla/utils.py (1)
  • check_shared_mem (199-205)
fla/ops/common/utils.py (1)
  • prepare_chunk_indices (56-61)
fla/ops/utils/cumsum.py (1)
  • chunk_global_cumsum (364-382)
🪛 Ruff (0.8.2)
fla/ops/attn/parallel.py

716-716: No explicit stacklevel keyword argument found

(B028)


724-724: No explicit stacklevel keyword argument found

(B028)

⏰ Context from checks skipped due to timeout of 90000ms (2)
  • GitHub Check: test
  • GitHub Check: test
🔇 Additional comments (16)
tests/ops/test_attn.py (5)

34-40: Code style consistency improvement.

The parameterization decorators have been updated to use single quotes consistently for parameter names, which improves style consistency throughout the codebase.


107-111: Improved parameter naming for better semantics.

Renaming offsets to cu_seqlens provides better clarity about the parameter's purpose - representing cumulative sequence lengths in variable-length sequences.


122-126: Consistent parameter naming with implementation.

The parameter names in the flash_attn_varlen_func call have been updated to match the renamed parameters in the implementation file, ensuring consistency across the codebase.


137-138: Aligned parameter naming with implementation.

Updated parameter name from offsets to cu_seqlens in the parallel_attn call, ensuring consistency with the implementation.


144-147: Improved formatting in assertion messages.

Removed unnecessary whitespace in assertion messages, making the code cleaner.

fla/ops/attn/parallel.py (11)

4-4: Added warnings import for deprecation notices.

Good addition to support the new warning messages about deprecated functionality.


19-21: Updated variable length condition check.

The heuristic for checking if variable length sequences are in use has been updated to reference the renamed parameter (cu_seqlens instead of offsets).


28-28: Improved autotune configuration keys.

Added USE_G and IS_VARLEN to the autotune configuration key list, which will help the tuner better differentiate between different use cases.


39-41: Consistent parameter renaming in kernel function signature.

Parameters offsets and indices have been renamed to cu_seqlens and chunk_indices respectively, providing better clarity about their purpose.


60-62: Updated variable loading to use renamed parameters.

The variable loading code has been updated to reference the renamed parameters, maintaining functionality while improving readability.


195-197: Consistent parameter renaming in backward kernel functions.

The parameter renaming has been consistently applied to both forward and backward kernel function signatures, maintaining a unified naming convention throughout the codebase.

Also applies to: 333-335


489-491: Updated variable initialization with renamed parameters.

Function calls and variable initializations have been updated to use the renamed parameters, ensuring consistency throughout the codebase.


526-527: Modified floating-point type in backward preprocessing.

Changed torch.float32 to torch.float. Note that torch.float is an alias for torch.float32 in PyTorch, so this is a cosmetic simplification rather than a change in precision.


571-573: Simplified empty tensor initialization.

The code now directly initializes empty tensors with the appropriate dtype without additional empty_like and cast operations.


730-734: Improved function parameter handling.

The code now handles scale calculation and input validation more clearly with better spacing and assertions.


679-679: Updated return statement.

Function return statement updated to match the new parameter naming convention.

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
fla/ops/gsa/fused_recurrent.py (1)

512-517: Consider removing deprecated head_first parameter

The PR title mentions removing head_first, but it's only deprecated with a warning. Consider removing it completely in this PR or creating a follow-up PR for full removal.

🧰 Tools
🪛 Ruff (0.8.2)

513-513: No explicit stacklevel keyword argument found

(B028)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 703a7a9 and 7b46c6c.

📒 Files selected for processing (23)
  • fla/layers/rwkv6.py (1 hunks)
  • fla/ops/gla/chunk.py (56 hunks)
  • fla/ops/gla/fused_chunk.py (0 hunks)
  • fla/ops/gsa/chunk.py (42 hunks)
  • fla/ops/gsa/fused_recurrent.py (12 hunks)
  • fla/ops/hgrn/fused_recurrent.py (11 hunks)
  • fla/ops/linear_attn/fused_chunk.py (0 hunks)
  • fla/ops/linear_attn/fused_recurrent.py (0 hunks)
  • fla/ops/retention/fused_chunk.py (0 hunks)
  • fla/ops/rwkv6/chunk.py (57 hunks)
  • fla/ops/rwkv6/fused_recurrent.py (21 hunks)
  • fla/ops/simple_gla/chunk.py (11 hunks)
  • fla/ops/simple_gla/parallel.py (17 hunks)
  • tests/ops/test_based.py (1 hunks)
  • tests/ops/test_gla.py (4 hunks)
  • tests/ops/test_gsa.py (8 hunks)
  • tests/ops/test_hgrn.py (6 hunks)
  • tests/ops/test_linear_attn.py (4 hunks)
  • tests/ops/test_retention.py (4 hunks)
  • tests/ops/test_rwkv6.py (7 hunks)
  • tests/ops/test_simple_gla.py (7 hunks)
  • tests/ops/test_solve_tril.py (3 hunks)
  • tests/ops/test_utils.py (8 hunks)
💤 Files with no reviewable changes (4)
  • fla/ops/linear_attn/fused_chunk.py
  • fla/ops/gla/fused_chunk.py
  • fla/ops/retention/fused_chunk.py
  • fla/ops/linear_attn/fused_recurrent.py
✅ Files skipped from review due to trivial changes (9)
  • tests/ops/test_gla.py
  • tests/ops/test_based.py
  • fla/layers/rwkv6.py
  • tests/ops/test_linear_attn.py
  • tests/ops/test_gsa.py
  • tests/ops/test_hgrn.py
  • tests/ops/test_utils.py
  • tests/ops/test_simple_gla.py
  • tests/ops/test_retention.py
🧰 Additional context used
🧬 Code Graph Analysis (6)
fla/ops/simple_gla/chunk.py (1)
fla/ops/utils/cumsum.py (1)
  • chunk_local_cumsum (386-406)
tests/ops/test_rwkv6.py (2)
fla/ops/rwkv6/fused_recurrent.py (1)
  • fused_recurrent_rwkv6 (556-681)
fla/ops/rwkv6/chunk.py (1)
  • chunk_rwkv6 (1188-1308)
fla/ops/hgrn/fused_recurrent.py (8)
fla/ops/gla/fused_chunk.py (1)
  • backward (470-595)
fla/ops/gla/chunk.py (1)
  • backward (1197-1215)
fla/ops/rwkv6/fused_recurrent.py (1)
  • backward (537-553)
fla/ops/rwkv6/chunk.py (1)
  • backward (1167-1184)
fla/ops/delta_rule/chunk.py (1)
  • backward (194-221)
fla/ops/gated_delta_rule/chunk.py (1)
  • backward (205-232)
fla/ops/generalized_delta_rule/dplr/chunk.py (1)
  • backward (135-268)
fla/ops/abc/chunk.py (1)
  • backward (950-1074)
fla/ops/rwkv6/chunk.py (1)
fla/ops/common/utils.py (2)
  • prepare_chunk_indices (59-64)
  • prepare_chunk_offsets (68-72)
fla/ops/simple_gla/parallel.py (2)
fla/ops/common/utils.py (1)
  • prepare_chunk_indices (59-64)
fla/ops/utils/cumsum.py (2)
  • chunk_local_cumsum (386-406)
  • chunk_global_cumsum (364-382)
fla/ops/gla/chunk.py (2)
fla/ops/common/utils.py (1)
  • prepare_chunk_indices (59-64)
fla/ops/utils/cumsum.py (1)
  • chunk_local_cumsum (386-406)
🪛 Ruff (0.8.2)
fla/ops/rwkv6/chunk.py

1272-1272: No explicit stacklevel keyword argument found

(B028)

🔇 Additional comments (60)
fla/ops/hgrn/fused_recurrent.py (7)

17-17: LGTM: Parameter name consistently renamed

The parameter renaming from offsets to cu_seqlens in the heuristic lambda function is consistent with the PR objective.


34-34: LGTM: Parameter renaming in forward kernel

The parameter and its usage have been consistently renamed from offsets to cu_seqlens in the forward kernel, maintaining the same functionality while aligning with standard terminology.

Also applies to: 44-44


78-78: LGTM: Parameter renaming in backward kernel

The parameter and its usage have been consistently renamed from offsets to cu_seqlens in the backward kernel definition and implementation.

Also applies to: 98-98, 108-108


160-160: LGTM: Parameter renaming in fused_recurrent_hgrn_fwd function

The parameter has been consistently renamed from offsets to cu_seqlens in the function signature and its usage in calculating N and passing to the kernel.

Also applies to: 163-163, 175-175


188-188: LGTM: Parameter renaming in fused_recurrent_hgrn_bwd function

The parameter has been consistently renamed from offsets to cu_seqlens in the backward function signature and its usage.

Also applies to: 191-191, 206-206


223-223: LGTM: Parameter renaming in FusedRecurrentHGRNFunction

The parameter has been consistently renamed from offsets to cu_seqlens in the Function class, both in the forward method's signature and in context saving/retrieving for the backward pass.

Also applies to: 230-230, 233-233, 240-240, 248-248


259-308: LGTM: Documentation updated correctly

The docstring has been updated to reference cu_seqlens instead of offsets and provides clear explanation of its purpose for handling variable-length sequences. The usage example is also updated correctly.

fla/ops/gsa/fused_recurrent.py (3)

154-154: Consistent parameter renaming from offsets to cu_seqlens

The parameter has been renamed consistently throughout the file. This improves clarity as cu_seqlens more accurately describes the purpose of this parameter (cumulative sequence lengths), aligning with the naming convention used in the FlashAttention API as mentioned in the docstring.

Also applies to: 161-161, 189-189, 219-219, 252-252, 255-255, 282-282, 299-299, 321-321, 339-339, 364-364, 389-389, 394-394, 404-404, 425-425


467-472: Well-documented parameter

The documentation for cu_seqlens is clear and informative, explaining both its purpose and format. It also maintains consistency with the FlashAttention API.


519-524: Improved warning message

The warning message provides clear guidance about potential format mismatches, which helps users identify and fix issues with their input tensors.

🧰 Tools
🪛 Ruff (0.8.2)

519-519: No explicit stacklevel keyword argument found

(B028)

fla/ops/simple_gla/chunk.py (4)

25-25: Renamed parameter for better semantic clarity

The parameter name change from offsets to cu_seqlens provides better clarity about its purpose - representing cumulative sequence lengths.


28-28: Updated function call argument to match renamed parameter

Properly updated chunk_local_cumsum call to use the newly renamed parameter.


173-176: Modified cumsum operation and removed redundant type conversion

The code now passes cu_seqlens to the chunk_local_cumsum function and removes the .to(g) conversion for dg. This is because the type conversion is now directly handled in the return statement at line 178, which is clearer and more efficient.


178-178: Improved return statement consistency

The return statement maintains proper type conversions for other tensors while directly returning dg without an extra conversion step.

fla/ops/gsa/chunk.py (3)

21-22: Updated heuristic condition for variable-length sequences

The Lambda function now checks for cu_seqlens instead of offsets, maintaining the same functionality while using the more semantically clear parameter name.


557-558: Updated chunk index preparation for consistency

The code now correctly uses prepare_chunk_indices with the renamed cu_seqlens parameter, maintaining consistent naming across the codebase.


793-793: Updated cumsum operation with renamed parameter

The chunk_local_cumsum call now uses the cu_seqlens parameter name, consistent with the parameter naming changes throughout the file.

fla/ops/rwkv6/chunk.py (4)

13-13: Import statement updated to include additional utility

The import statement now includes prepare_chunk_offsets which is used alongside prepare_chunk_indices for handling variable-length sequences.


80-81: Updated chunk preparation logic

The code now uses prepare_chunk_indices with the renamed cu_seqlens parameter, maintaining consistent naming throughout the codebase.


1189-1193: Renamed parameter from g to w

The function parameter has been renamed from g to w for clarity, matching the docstring at lines 1208-1209, which describes the parameter as "Forget gates".


1270-1271: Updated parameter references in conditional mapping

The code correctly updates the parameter reference in the map function call to use the new parameter name w instead of g.

fla/ops/simple_gla/parallel.py (7)

33-34: Updated heuristic conditions for better clarity

The heuristic conditions now check for cu_seqlens and properly format the condition for USE_G. This provides better semantic clarity and maintains consistent parameter naming.


52-53: Renamed parameters in kernel function signature

The parameters offsets and indices have been renamed to cu_seqlens and chunk_indices respectively, for better semantic clarity and consistency across the codebase.


512-513: Updated chunk preparation logic

The code now uses prepare_chunk_indices with the renamed cu_seqlens parameter, maintaining consistent naming throughout the codebase.


517-517: Updated cumsum function call with renamed parameter

The chunk_local_cumsum function call now uses the new parameter name cu_seqlens, maintaining consistency with the parameter naming changes.


622-622: Updated method signature for consistency

The forward method signature now uses cu_seqlens instead of offsets, maintaining consistent parameter naming across the codebase.


636-636: Updated context saving with renamed tensor

The saved tensors for backward pass now include cu_seqlens instead of offsets, matching the parameter renaming.


613-613: Updated global cumsum function call

The chunk_global_cumsum function call now correctly uses the renamed cu_seqlens parameter, maintaining consistency with other parameter naming changes.

fla/ops/gla/chunk.py (10)

22-24: Approved: Parameter renaming in heuristic lambda function.

The renaming of offsets to cu_seqlens in the heuristic lambda function is consistent with the PR objective and makes the code more descriptive as it clearly indicates "cumulative sequence lengths."


95-97: Approved: Parameter renaming in heuristic lambda function.

Consistent renaming of offsets to cu_seqlens in another heuristic lambda function.


160-162: Approved: Consistent parameter renaming in remaining heuristic functions.

The renaming has been consistently applied across all IS_VARLEN heuristic lambda functions in the file.

Also applies to: 231-233, 279-281, 357-359, 493-495, 616-618


35-42: Approved: Function signatures updated with new parameter name.

Function signatures have been updated to use cu_seqlens instead of offsets across multiple functions, maintaining consistent naming throughout the codebase.

Also applies to: 108-115, 173-180, 244-249, 291-301, 368-376, 629-644


55-57: Approved: Consistent variable handling in kernel implementations.

The loading of sequence boundaries from cu_seqlens has been correctly implemented in all kernel functions, preserving the original functionality.

Also applies to: 128-129, 196-197, 260-261, 314-316, 523-524, 577-578, 658-659


726-734: Approved: Function parameter and calls updated.

The function signature for chunk_gla_fwd_intra_gk has been updated along with the proper usage of cu_seqlens in the function body. This includes the call to prepare_chunk_indices which now receives the renamed parameter.


827-829: Approved: Consistent parameter renaming in helper functions.

Parameter renaming has been consistently applied to all helper functions in the file.

Also applies to: 862-863, 896-897, 932-933, 977-978


1008-1020: Approved: Renaming in main function signatures and class methods.

The main function signatures and class methods have been properly updated to use the new parameter name.

Also applies to: 1062-1076, 1166-1168, 1199-1200


1024-1025: Approved: Parameter renaming in function calls.

All function calls have been updated to use the new parameter name, ensuring consistent usage throughout the file.

Also applies to: 1036-1037, 1046-1047, 1056-1057, 1080-1081, 1091-1092, 1106-1107, 1117-1119, 1126-1128, 1134-1136, 1148-1150, 1181-1182, 1212-1214


732-733:


✅ Verification successful

Verify sequential logic consistency with new parameter name.

The updated sequential statements may have introduced a dependency that requires evaluation. Ensure that both the conditional statement and the assignment maintain the original intent.


🏁 Script executed:

#!/bin/bash
# Find all instances where prepare_chunk_indices is called
echo "Searching for all calls to prepare_chunk_indices"
rg -p "prepare_chunk_indices\(.*\)" --context 3

# Check if the function is defined in the utils file with cu_seqlens as the first parameter
echo "Checking the signature of prepare_chunk_indices"
rg -p "def prepare_chunk_indices" --context 5

Length of output: 22972


Sequential Logic Verified – No Further Action Required

The analysis confirms that the update to use the new parameter name is consistent. In the sequential branch (when cu_seqlens is None), NT is computed using BT (derived from chunk_size) while for variable-length cases the function is called with chunk_size as expected per its signature in fla/ops/common/utils.py. The logic across the codebase is consistent with this change.
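As a hedged sketch of what prepare_chunk_indices computes in the variable-length path (reconstructed from the discussion above, not the actual Triton-facing implementation): each sequence contributes ceil(seqlen / chunk_size) chunks, flattened into (sequence, chunk) pairs so a flat grid of kernel programs can map a program id back to its sequence and chunk:

```python
from math import ceil

def prepare_chunk_indices_sketch(cu_seqlens, chunk_size):
    """Hedged sketch of chunk-index preparation for varlen batches.

    Returns a list of (sequence_index, chunk_index_within_sequence)
    pairs, one entry per chunk across all packed sequences.
    """
    indices = []
    for i in range(len(cu_seqlens) - 1):
        seqlen = cu_seqlens[i + 1] - cu_seqlens[i]
        for c in range(ceil(seqlen / chunk_size)):
            indices.append((i, c))
    return indices

# Two packed sequences of lengths 5 and 3, with chunk_size 4:
print(prepare_chunk_indices_sketch([0, 5, 8], 4))
# [(0, 0), (0, 1), (1, 0)]
```

In the fixed-length branch (cu_seqlens is None), no such table is needed: NT is simply ceil(T / BT) and the mapping from program id to chunk is arithmetic.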

tests/ops/test_solve_tril.py (5)

25-28: Approved: Updated string delimiter style in test parametrizations.

The quotes have been changed from double to single quotes in the pytest parametrize decorators. This is a style change that doesn't affect functionality.


49-50: Approved: Updated string delimiter in assert_close call.

Changed from double to single quotes for consistency.


52-54: Approved: Renamed parameter and updated string delimiters.

The parameter name offsets has been changed to cu_seqlens in the test function parametrize decorator, consistent with the changes in the implementation code.


30-32: Approved: Updated string delimiters in skipif decorators.

Changed from double to single quotes in the test skip conditions for consistency.

Also applies to: 55-58, 60-62


80-81: Approved: Updated string delimiter in final assert_close call.

Changed from double to single quotes for consistency.

tests/ops/test_rwkv6.py (6)

29-34: Approved: Updated string delimiter style in test parametrizations.

Changed from double to single quotes in the pytest parametrize decorators for style consistency.


54-58: Simplified tensor initialization.

The tensor initialization has been simplified to consistently use (B, T, H, D) shape for all tensors, removing conditional logic that previously may have changed the shape based on head_first.


64-82: Approved: Simplified function calls to fused_recurrent_rwkv6.

The function calls have been simplified to directly pass tensor values without conditional reshaping, making the code more readable and maintainable.


92-100: Approved: Simplified function calls to chunk_rwkv6.

Similar simplification has been applied to the chunk_rwkv6 function calls, making them consistent with the fused_recurrent_rwkv6 calls.


138-142: Renamed 'offsets' to 'cu_seqlens'.

The variable name has been updated from 'offsets' to 'cu_seqlens' to align with the consistent naming convention used across the codebase.


160-161: Approved: Updated function parameter in test calls.

Function calls in the variable length sequence tests have been updated to use cu_seqlens instead of offsets.

Also applies to: 170-171, 188-189

fla/ops/rwkv6/fused_recurrent.py (11)

17-20: Approved: Updated heuristic lambda function.

The parameter name in the heuristic lambda function has been renamed from offsets to cu_seqlens for better descriptiveness.


38-39: Approved: Updated kernel function signatures.

All kernel function signatures have been updated to use the new parameter name consistently.

Also applies to: 124-125, 215-216, 307-308


55-56: Approved: Updated internal variable references.

Variable references within the kernel functions have been updated to load data from cu_seqlens instead of offsets.

Also applies to: 140-141, 231-232, 320-321


102-105: Approved: Updated remaining heuristic functions.

All heuristic lambda functions have been updated with the new parameter name.

Also applies to: 192-194, 289-290


362-363: Approved: Updated helper function signatures and calls.

Helper function signatures and their calls have been updated with the new parameter name.

Also applies to: 383-384, 408-409, 429-430, 463-464, 486-487


365-366: Approved: Updated sequential logic.

The variable N is now determined based on cu_seqlens instead of offsets, maintaining the original logic.

Also applies to: 411-412


514-515: Approved: Updated autograd function class.

The FusedRecurrentRWKV6Function class has been updated to use the new parameter name in both forward and backward methods.

Also applies to: 526-527, 550-551


531-532: Approved: Updated context saving.

The saved context variable name has been updated to match the new parameter name.


566-567: Approved: Updated public API function signature.

The public-facing function signature has been updated with the new parameter name.


593-596: Approved: Updated docstring description.

The docstring has been updated to describe cu_seqlens consistently with its usage in the code.


676-678: Approved: Updated function application in public API.

The function application has been updated to use the new parameter name when passing to the autograd function.

yzhangcs merged commit 2f79d88 into main on Apr 10, 2025 (2 of 6 checks passed)
yzhangcs deleted the head_first branch on April 10, 2025 at 19:32
zhiyuan1i pushed a commit that referenced this pull request Apr 11, 2025
* [Attn] Remove `head_first` & rename `offsets` to `cu_seqlens`

* Delete 256 headdim tests

* [DeltaNet] Remove `head_first` & rename `offsets` to `cu_seqlens`

* Rename `offsets` to `cu_seqlens` across GLA/GSA/RWKV6

* Fix NSA tests

* Fix TTT checks

* Rename `offsets` to `cu_seqlens` in DPLR and IPLR implementations, updating related logic and heuristics for variable-length sequences.
